### Utils for printing

In [None]:
import jax
import jax.numpy as jnp

from xminigrid.core.constants import Colors, Tiles
from xminigrid.types import AgentState, RuleSet

In [None]:
COLOR_NAMES = {
    Colors.EMPTY: "white",
    Colors.RED: "red",
    Colors.GREEN: "green",
    Colors.BLUE: "blue",
    Colors.PURPLE: "purple",
    Colors.YELLOW: "yellow",
    Colors.GREY: "grey",
    Colors.BLACK: "black",
    Colors.ORANGE: "orange",
    Colors.WHITE: "white",
    Colors.BROWN: "brown",
    Colors.PINK: "pink",
}

TILE_STR = {
    Tiles.EMPTY: " ",
    Tiles.FLOOR: ".",
    Tiles.WALL: "☰",
    Tiles.BALL: "⏺",
    Tiles.SQUARE: "▪",
    Tiles.PYRAMID: "▲",
    Tiles.HEX: "⬢",
    Tiles.STAR: "★",
    Tiles.GOAL: "■",
    Tiles.DOOR_LOCKED: "L",
    Tiles.DOOR_CLOSED: "C",
    Tiles.DOOR_OPEN: "O",
    Tiles.KEY: "K",
}

# for ruleset printing
RULE_TILE_STR = {
    Tiles.FLOOR: "floor",
    Tiles.BALL: "ball",
    Tiles.SQUARE: "square",
    Tiles.PYRAMID: "pyramid",
    Tiles.GOAL: "goal",
    Tiles.KEY: "key",
    Tiles.HEX: "hexagon",
    Tiles.STAR: "star",
}

PLAYER_STR = {0: "^", 1: ">", 2: "V", 3: "<"}

In [None]:
def _wrap_with_color(string: str, color: str) -> str:
    return f"[bold {color}]{string}[/bold {color}]"


# WARN: will NOT work under jit and needed for debugging mainly.
def render(grid: jax.Array, agent: AgentState | None = None) -> str:
    string = ""

    for y in range(grid.shape[0]):
        for x in range(grid.shape[1]):
            tile_id, tile_color = grid[y, x]
            tile_str = TILE_STR[tile_id.item()]
            tile_color = COLOR_NAMES[tile_color.item()]

            if agent is not None and jnp.all(agent.position == jnp.array((y, x))):
                tile_str = PLAYER_STR[agent.direction.item()]
                tile_color = COLOR_NAMES[Colors.RED]

            string += _wrap_with_color(tile_str, tile_color)

        if y < grid.shape[0] - 1:
            string += "\n"

    return string

In [None]:
# WARN: This is for debugging mainly! Will refactor later if needed.
def _encode_tile(tile: list[int]) -> str:
    return f"{COLOR_NAMES[tile[1]]} {RULE_TILE_STR[tile[0]]}"

In [None]:
def _text_encode_goal(goal: list[int]) -> str:
    goal_id = goal[0]
    if goal_id == 1:
        return f"AgentHold({_encode_tile(goal[1:3])})"
    elif goal_id == 3:
        return f"AgentNear({_encode_tile(goal[1:3])})"
    elif goal_id == 4:
        return f"TileNear({_encode_tile(goal[1:3])}, {_encode_tile(goal[3:5])})"
    elif goal_id == 7:
        return f"TileNearUpGoal({_encode_tile(goal[1:3])}, {_encode_tile(goal[3:5])})"
    elif goal_id == 8:
        return f"TileNearRightGoal({_encode_tile(goal[1:3])}, {_encode_tile(goal[3:5])})"
    elif goal_id == 9:
        return f"TileNearDownGoal({_encode_tile(goal[1:3])}, {_encode_tile(goal[3:5])})"
    elif goal_id == 10:
        return f"TileNearLeftGoal({_encode_tile(goal[1:3])}, {_encode_tile(goal[3:5])})"
    elif goal_id == 11:
        return f"AgentNearUpGoal({_encode_tile(goal[1:3])})"
    elif goal_id == 12:
        return f"AgentNearRightGoal({_encode_tile(goal[1:3])})"
    elif goal_id == 13:
        return f"AgentNearDownGoal({_encode_tile(goal[1:3])})"
    elif goal_id == 14:
        return f"AgentNearLeftGoal({_encode_tile(goal[1:3])})"
    else:
        raise RuntimeError(f"Rendering: Unknown goal id: {goal_id}")

In [None]:
def _text_encode_rule(rule: list[int]) -> str:
    rule_id = rule[0]
    if rule_id == 1:
        return f"AgentHold({_encode_tile(rule[1:3])}) -> {_encode_tile(rule[3:5])}"
    elif rule_id == 2:
        return f"AgentNear({_encode_tile(rule[1:3])}) -> {_encode_tile(rule[3:5])}"
    elif rule_id == 3:
        return f"TileNear({_encode_tile(rule[1:3])}, {_encode_tile(rule[3:5])}) -> {_encode_tile(rule[5:7])}"
    elif rule_id == 4:
        return f"TileNearUpRule({_encode_tile(rule[1:3])}, {_encode_tile(rule[3:5])}) -> {_encode_tile(rule[5:7])}"
    elif rule_id == 5:
        return f"TileNearRightRule({_encode_tile(rule[1:3])}, {_encode_tile(rule[3:5])}) -> {_encode_tile(rule[5:7])}"
    elif rule_id == 6:
        return f"TileNearDownRule({_encode_tile(rule[1:3])}, {_encode_tile(rule[3:5])}) -> {_encode_tile(rule[5:7])}"
    elif rule_id == 7:
        return f"TileNearLeftRule({_encode_tile(rule[1:3])}, {_encode_tile(rule[3:5])}) -> {_encode_tile(rule[5:7])}"
    elif rule_id == 8:
        return f"AgentNearUpRule({_encode_tile(rule[1:3])}) -> {_encode_tile(rule[3:5])}"
    elif rule_id == 9:
        return f"AgentNearRightRule({_encode_tile(rule[1:3])}) -> {_encode_tile(rule[3:5])}"
    elif rule_id == 10:
        return f"AgentNearDownRule({_encode_tile(rule[1:3])}) -> {_encode_tile(rule[3:5])}"
    elif rule_id == 11:
        return f"AgentNearLeftRule({_encode_tile(rule[1:3])}) -> {_encode_tile(rule[3:5])}"
    else:
        raise RuntimeError(f"Rendering: Unknown rule id: {rule_id}")

In [None]:
def print_ruleset(ruleset: RuleSet) -> None:
    print("GOAL:")
    print(_text_encode_goal(ruleset.goal.tolist()))
    print()
    print("RULES:")
    for rule in ruleset.rules.tolist():
        if rule[0] != 0:
            print(_text_encode_rule(rule))
    print()
    print("INIT TILES:")
    for tile in ruleset.init_tiles.tolist():
        if tile[0] != 0:
            print(_encode_tile(tile))

### Utils for generating

In [49]:
import random
import pickle
from itertools import product
from collections import namedtuple, defaultdict

import jax
import jax.numpy as jnp
from tqdm.auto import tqdm, trange
from xminigrid.benchmarks import save_bz2_pickle
from xminigrid.core.constants import Colors, Tiles
from xminigrid.rendering.text_render import _encode_tile, _text_encode_goal, _text_encode_rule, print_ruleset
from xminigrid.core.goals import (
    AgentHoldGoal,
    AgentNearDownGoal,
    AgentNearGoal,
    AgentNearLeftGoal,
    AgentNearRightGoal,
    AgentNearUpGoal,
    TileNearDownGoal,
    TileNearGoal,
    TileNearLeftGoal,
    TileNearRightGoal,
    TileNearUpGoal,
)
from xminigrid.core.grid import pad_along_axis
from xminigrid.core.rules import (
    AgentHoldRule,
    AgentNearDownRule,
    AgentNearLeftRule,
    AgentNearRightRule,
    AgentNearRule,
    AgentNearUpRule,
    EmptyRule,
    TileNearDownRule,
    TileNearLeftRule,
    TileNearRightRule,
    TileNearRule,
    TileNearUpRule,
)

In [8]:
COLORS = [
    Colors.RED,
    Colors.GREEN,
    Colors.BLUE,
    Colors.PURPLE,
    Colors.YELLOW,
    Colors.GREY,
    Colors.WHITE,
    Colors.BROWN,
    Colors.PINK,
    Colors.ORANGE,
]

# we need to distinguish between them, to avoid sampling
# near(goal, goal) goal or rule as goal tiles are not pickable
NEAR_TILES_LHS = list(
    product([Tiles.BALL, Tiles.SQUARE, Tiles.PYRAMID, Tiles.KEY, Tiles.STAR, Tiles.HEX, Tiles.GOAL], COLORS)
)
# these are pickable!
NEAR_TILES_RHS = list(product([Tiles.BALL, Tiles.SQUARE, Tiles.PYRAMID, Tiles.KEY, Tiles.STAR, Tiles.HEX], COLORS))

HOLD_TILES = list(product([Tiles.BALL, Tiles.SQUARE, Tiles.PYRAMID, Tiles.KEY, Tiles.STAR, Tiles.HEX], COLORS))

# to imitate disappearance production rule
PROD_TILES = list(product([Tiles.BALL, Tiles.SQUARE, Tiles.PYRAMID, Tiles.KEY, Tiles.STAR, Tiles.HEX], COLORS))
PROD_TILES = PROD_TILES + [(Tiles.FLOOR, Colors.BLACK)]

# all possible tiles for randomly choosing a prod_tile for generating subtrees
ALL_TILES = list(
    product([Tiles.BALL, Tiles.SQUARE, Tiles.PYRAMID, Tiles.KEY, Tiles.STAR, Tiles.HEX, Tiles.GOAL], COLORS)
)

In [9]:
# Define the namedtuple for storing ruleset information
RuleSet = namedtuple('RuleSet', ['goal', 'rules', 'init_tiles'])

def print_rulesets(rulesets):
    """Converts each ruleset dictionary into a namedtuple and prints it."""
    for ruleset in rulesets:
        # Convert dictionary to namedtuple
        ruleset_nt = RuleSet(
            goal=ruleset['goal'],
            rules=ruleset['rules'],
            init_tiles=ruleset['init_tiles']
        )
        # Print the converted ruleset
        print_ruleset(ruleset_nt)
        print("\n===============================")

In [10]:
def encode(ruleset):
    flatten_encoding = jnp.concatenate([ruleset["goal"].encode(), *[r.encode() for r in ruleset["rules"]]]).tolist()
    return tuple(flatten_encoding)

def diff(list1, list2):
    return list(set(list1) - set(list2))

In [11]:
# checking this for the purpose of generating goals;
# if the tile type is GOAL, goal_idx has to be compatible with NEAR_TILES_LHS which
# is the only list of tiles that has tile type GOAL --> 1 <= goal_idx <= 5

# the encoding of goal types is the following:
# Tiles.BALL = 3, Tiles.SQUARE = 4, Tiles.PYRAMID = 5, 
# Tiles.KEY = 7, Tiles.STAR = 12, Tiles.HEX = 11, Tiles.GOAL = 6
def is_goal_tile(tile):
    return tile[0] == 6

In [12]:
def sample_goal(goal_tiles):
    goals = (
        AgentHoldGoal,
        # agent near variations
        AgentNearGoal,
        AgentNearUpGoal,
        AgentNearDownGoal,
        AgentNearLeftGoal,
        AgentNearRightGoal,
        # tile near variations
        TileNearGoal,
        TileNearUpGoal,
        TileNearDownGoal,
        TileNearLeftGoal,
        TileNearRightGoal,
    )

    if len(goal_tiles) == 1:
        print("Generating an AgentHoldGoal or AgentNearXXX goal.")
        tile = goal_tiles[0]
        if is_goal_tile(tile):  # if it is a GOAL tile type, it can't be the AgentHold Goal
            goal_idx = random.randint(1, 5)
        else:
            goal_idx = random.randint(0, 5)
        goal = goals[goal_idx](tile=jnp.array(tile))
        return goal, (tile,)
    elif len(goal_tiles) == 2:
        print("Generating a TileNearXXX goal.")
        tile_a, tile_b = goal_tiles
        if is_goal_tile(tile_b):
            raise RuntimeError("Imcompatible goal tile types.")
        else:
            goal_idx = random.randint(6, 10)
        goal = goals[goal_idx](tile_a=jnp.array(tile_a), tile_b=jnp.array(tile_b))
        return goal, (tile_a, tile_b)
    else:
        raise RuntimeError("Unknown goal.")

In [13]:
goal_tiles = (random.choice(ALL_TILES),)
tile = goal_tiles[0]
_encode_tile(tile)

'pink ball'

In [14]:
goal, tiles = sample_goal(goal_tiles)

Generating an AgentHoldGoal or AgentNearXXX goal.


In [15]:
_text_encode_goal(goal.encode().tolist())

'AgentHold(pink ball)'

In [16]:
def remove_tiles(source_list, tiles_to_remove):
    temp_list = source_list[:]
    for tile in tiles_to_remove:
        if tile in temp_list:
            temp_list.remove(tile)
    return temp_list

def add_tiles(source_list, tiles_to_add):
    source_list.extend(tiles_to_add)
    return source_list

In [17]:
def sample_rule(avail_tiles, prod_tile, used_tiles, mandatory_tile=None):
    rules = (
        AgentHoldRule,
        # agent near variations
        AgentNearRule,
        AgentNearUpRule,
        AgentNearDownRule,
        AgentNearLeftRule,
        AgentNearRightRule,
        # tile near variations
        TileNearRule,
        TileNearUpRule,
        TileNearDownRule,
        TileNearLeftRule,
        TileNearRightRule,
    )

    if mandatory_tile:
        if mandatory_tile in used_tiles:
            raise RuntimeError(f"The mandatory tile {_encode_tile(mandatory_tile)} has been used before.")
        if not (mandatory_tile in avail_tiles):
            raise RuntimeError(f"The mandatory tile {_encode_tile(mandatory_tile)} is not available.")
    
    filtered_tiles = diff(avail_tiles, used_tiles)  # we want to avoid tiles that are already used
    is_agent_near = random.choice([True, False])  # act as a switch to decide if use 1 tile or 2 tiles

    if len(filtered_tiles) == 0:
        raise RuntimeError("No available tiles left.")
        
    if len(filtered_tiles) == 1 or (mandatory_tile and len(filtered_tiles) >= 2 and is_agent_near):  
        print("Generating an AgentHoldGoal or AgentNearXXX goal.")
        tile = mandatory_tile if mandatory_tile else filtered_tiles[0]
        # if it is a GOAL tile type, it can't be of rules that requie only 1 tile
        if is_goal_tile(tile):
            raise RuntimeError("Rule tile incompatible with the rule.")
        rule_idx = random.randint(0, 5)
        rule = rules[rule_idx](tile=jnp.array(tile), prod_tile=jnp.array(prod_tile))
        avail_tiles = remove_tiles(avail_tiles, [tile])
        avail_tiles = add_tiles(avail_tiles, [prod_tile])
        used_tiles = add_tiles(used_tiles, [tile])
        return rule, (tile,), avail_tiles, used_tiles
    else:
        print("Generating a TileNearXXX goal.")
        if mandatory_tile:
            is_a_mandatory =  random.choice([True, False]) 
            filtered_tiles = remove_tiles(filtered_tiles, [mandatory_tile])
            if is_goal_tile(mandatory_tile) or is_a_mandatory:
                tile_a = mandatory_tile
                avail_for_b = [t for t in filtered_tiles if not is_goal_tile(t)]
                if not avail_for_b:
                    raise RuntimeError("No valid non-GOAL tiles available.")
                tile_b = random.choice(avail_for_b)
            else:
                tile_a, tile_b = random.sample(filtered_tiles, 1)[0], mandatory_tile
        else:
            tile_a, tile_b = random.sample(filtered_tiles, 2)

        rule_idx = random.randint(6, 10)
        rule = rules[rule_idx](tile_a=jnp.array(tile_a), tile_b=jnp.array(tile_b), prod_tile=jnp.array(prod_tile))
        avail_tiles = remove_tiles(avail_tiles, [tile_a, tile_b])
        avail_tiles = add_tiles(avail_tiles, [prod_tile])
        used_tiles = add_tiles(used_tiles, [tile_a, tile_b])
        return rule, (tile_a, tile_b), avail_tiles, used_tiles

In [18]:
avail_tiles = [random.choice(ALL_TILES), random.choice(ALL_TILES), random.choice(ALL_TILES), random.choice(ALL_TILES), random.choice(ALL_TILES)]
for tile in avail_tiles:
    print(_encode_tile(tile))

grey key
brown star
green key
grey hexagon
grey square


In [19]:
prod_tile = random.choice(ALL_TILES)
_encode_tile(prod_tile)

'green star'

In [20]:
used_tiles = [random.choice(ALL_TILES), random.choice(ALL_TILES), random.choice(ALL_TILES)]
for tile in used_tiles:
    print(_encode_tile(tile))

brown star
green pyramid
pink star


In [21]:
rule, rule_tiles, avail, used = sample_rule(avail_tiles, prod_tile, used_tiles, avail_tiles[0])

Generating an AgentHoldGoal or AgentNearXXX goal.


In [22]:
rule

AgentNearDownRule(tile=Array([7, 6], dtype=int32), prod_tile=Array([12,  2], dtype=int32))

In [23]:
_text_encode_rule(rule.encode().tolist())

'AgentNearDownRule(grey key) -> green star'

In [24]:
for tile in avail:
    print(_encode_tile(tile))

brown star
green key
grey hexagon
grey square
green star


In [25]:
for tile in used:
    print(_encode_tile(tile))

brown star
green pyramid
pink star
grey key


### Generating subtrees of different depth from the same pool of tiles

In [26]:
# i allow duplicate tiles here

def sample_subtree(depth, avail_tiles, used_tiles, mandatory_tile=None):
    if depth == 0:
        # base case: return an empty dict, and the avail_tiles and used_tiles after generation of this subtree
        return {}, avail_tiles, used_tiles

    # select a production tile randomly from all possible tiles
    prod_tile = random.choice(ALL_TILES)

    # sample a rule with the current state of available tiles and used tiles
    rule, rule_tiles, updated_avail_tiles, updated_used_tiles = sample_rule(avail_tiles, 
                                                                            prod_tile, 
                                                                            used_tiles, 
                                                                            mandatory_tile)

    # recursively sample subtree for the next depth level
    # passing the current prod_tile as the mandatory_tile for the next layer
    subtree, final_avail_tiles, final_used_tiles = sample_subtree(depth - 1, 
                                                                  updated_avail_tiles, 
                                                                  updated_used_tiles,
                                                                  prod_tile)

    # Merge the current rule into the subtree
    current_rule_data = {'rule': rule, 'rule_tiles': rule_tiles, 'prod_tile': prod_tile}
    if depth in subtree:
        subtree[depth].append(current_rule_data)
    else:
        subtree[depth] = [current_rule_data]

    return subtree, final_avail_tiles, final_used_tiles

In [27]:
def traverse_subtree(subtree):
    for depth, rules in subtree.items():
        print(f"Depth: {depth}")
        for rule_info in rules:
            print(f"  Rule: {rule_info['rule']}")
            print(f"  Rule Tiles: {rule_info['rule_tiles']}")
            print(f"  Production Tile: {rule_info['prod_tile']}")
            # If there are further subtrees, recurse into them
            if 'subtree' in rule_info and rule_info['subtree']:
                traverse_subtree(rule_info['subtree'])

In [28]:
avail_tiles = [random.choice(ALL_TILES), random.choice(ALL_TILES), random.choice(ALL_TILES), random.choice(ALL_TILES), random.choice(ALL_TILES), random.choice(ALL_TILES), random.choice(ALL_TILES), random.choice(ALL_TILES), random.choice(ALL_TILES)]
for tile in avail_tiles:
    print(_encode_tile(tile))

blue key
orange goal
brown star
yellow pyramid
orange hexagon
red goal
brown key
red ball
white pyramid


In [29]:
avail_tiles

[(7, 3), (6, 8), (12, 10), (5, 5), (11, 8), (6, 1), (7, 10), (3, 1), (5, 9)]

In [30]:
used_tiles = []

In [31]:
num_subtrees = 5
subtrees = []
depth = 1
for _ in range(0, num_subtrees):
    subtree, avail_tiles, used_tiles = sample_subtree(depth, avail_tiles, used_tiles)
    subtrees.append(subtree)

Generating a TileNearXXX goal.
Generating a TileNearXXX goal.
Generating a TileNearXXX goal.
Generating a TileNearXXX goal.
Generating a TileNearXXX goal.


In [32]:
subtrees

[{1: [{'rule': TileNearUpRule(tile_a=Array([5, 5], dtype=int32), tile_b=Array([6, 8], dtype=int32), prod_tile=Array([ 3, 11], dtype=int32)),
    'rule_tiles': ((5, 5), (6, 8)),
    'prod_tile': (3, 11)}]},
 {1: [{'rule': TileNearLeftRule(tile_a=Array([6, 1], dtype=int32), tile_b=Array([7, 3], dtype=int32), prod_tile=Array([3, 5], dtype=int32)),
    'rule_tiles': ((6, 1), (7, 3)),
    'prod_tile': (3, 5)}]},
 {1: [{'rule': TileNearDownRule(tile_a=Array([5, 9], dtype=int32), tile_b=Array([12, 10], dtype=int32), prod_tile=Array([12,  2], dtype=int32)),
    'rule_tiles': ((5, 9), (12, 10)),
    'prod_tile': (12, 2)}]},
 {1: [{'rule': TileNearLeftRule(tile_a=Array([3, 1], dtype=int32), tile_b=Array([ 7, 10], dtype=int32), prod_tile=Array([ 6, 10], dtype=int32)),
    'rule_tiles': ((3, 1), (7, 10)),
    'prod_tile': (6, 10)}]},
 {1: [{'rule': TileNearRightRule(tile_a=Array([12,  2], dtype=int32), tile_b=Array([11,  8], dtype=int32), prod_tile=Array([7, 8], dtype=int32)),
    'rule_tiles': ((

In [33]:
for subtree in subtrees:
    for depth, rules in subtree.items():
        print(f"Depth: {depth}")
        for rule_info in rules:
            print(_text_encode_rule(rule_info['rule'].encode().tolist()))
    print("\n")

Depth: 1
TileNearUpRule(yellow pyramid, orange goal) -> pink ball


Depth: 1
TileNearLeftRule(red goal, blue key) -> yellow ball


Depth: 1
TileNearDownRule(white pyramid, brown star) -> green star


Depth: 1
TileNearLeftRule(red ball, brown key) -> brown goal


Depth: 1
TileNearRightRule(green star, orange hexagon) -> orange key




In [34]:
for tile in avail_tiles:
    print(_encode_tile(tile))

pink ball
yellow ball
brown goal
orange key


In [35]:
# Example usage
depth = 2  # Desired depth of the subtree
initial_avail_tiles = avail_tiles.copy()  
initial_used_tiles = used_tiles.copy()
subtree, latest_avail, latest_used = sample_subtree(depth, initial_avail_tiles, initial_used_tiles)
print(subtree)

Generating a TileNearXXX goal.
Generating a TileNearXXX goal.
{1: [{'rule': TileNearUpRule(tile_a=Array([12,  4], dtype=int32), tile_b=Array([3, 5], dtype=int32), prod_tile=Array([ 6, 10], dtype=int32)), 'rule_tiles': ((12, 4), (3, 5)), 'prod_tile': (6, 10)}], 2: [{'rule': TileNearUpRule(tile_a=Array([ 3, 11], dtype=int32), tile_b=Array([7, 8], dtype=int32), prod_tile=Array([12,  4], dtype=int32)), 'rule_tiles': ((3, 11), (7, 8)), 'prod_tile': (12, 4)}]}


In [36]:
traverse_subtree(subtree)

Depth: 1
  Rule: TileNearUpRule(tile_a=Array([12,  4], dtype=int32), tile_b=Array([3, 5], dtype=int32), prod_tile=Array([ 6, 10], dtype=int32))
  Rule Tiles: ((12, 4), (3, 5))
  Production Tile: (6, 10)
Depth: 2
  Rule: TileNearUpRule(tile_a=Array([ 3, 11], dtype=int32), tile_b=Array([7, 8], dtype=int32), prod_tile=Array([12,  4], dtype=int32))
  Rule Tiles: ((3, 11), (7, 8))
  Production Tile: (12, 4)


In [37]:
for depth, rules in subtree.items():
    print(f"Depth: {depth}")
    for rule_info in rules:
        print(_text_encode_rule(rule_info['rule'].encode().tolist()))

Depth: 1
TileNearUpRule(purple star, yellow ball) -> brown goal
Depth: 2
TileNearUpRule(pink ball, orange key) -> purple star


In [38]:
for tile in latest_avail:
    print(_encode_tile(tile))

brown goal
brown goal


### Generating atomic subtrees of only depth of 1 for stacking

Here we sample all possible atomic trees of depth 1 and ensure that each tree starts with a full set of available tiles and an empty set of used tiles.

In [None]:
def sample_atomic_subtree(depth, avail_tiles, used_tiles, mandatory_tile=None):
    if depth == 0:
        # base case: return an empty dict, and the avail_tiles and used_tiles after generation of this subtree
        return {}, avail_tiles, used_tiles

    # select a production tile randomly from all possible tiles
    prod_tile = random.choice(ALL_TILES)

    # sample a rule with the current state of available tiles and used tiles
    rule, rule_tiles, updated_avail_tiles, updated_used_tiles = sample_rule(avail_tiles, 
                                                                            prod_tile, 
                                                                            used_tiles, 
                                                                            mandatory_tile)

    # recursively sample subtree for the next depth level
    # passing the current prod_tile as the mandatory_tile for the next layer
    subtree, final_avail_tiles, final_used_tiles = sample_subtree(depth - 1, 
                                                                  updated_avail_tiles, 
                                                                  updated_used_tiles,
                                                                  prod_tile)

    # Merge the current rule into the subtree
    current_rule_data = {'rule': rule, 'rule_tiles': rule_tiles, 'prod_tile': prod_tile}
    if depth in subtree:
        subtree[depth].append(current_rule_data)
    else:
        subtree[depth] = [current_rule_data]

    return subtree, final_avail_tiles, final_used_tiles

In [918]:
def generate_rule(rule_constructor, rule_tiles, prod_tile):
    if len(rule_tiles) == 1:
        tile = rule_tiles[0]
        if is_goal_tile(tile):
            raise RuntimeError("Rule tile incompatible with rule constructor.")
        rule = rule_constructor(tile=jnp.array(tile), prod_tile=jnp.array(prod_tile))
        return rule, (tile,)
    elif len(rule_tiles) == 2:
        tile_a, tile_b = rule_tiles
        if is_goal_tile(tile_b):
            raise RuntimeError("Rule tiles incompatible with rule constructor.")
        rule = rule_constructor(tile_a=jnp.array(tile_a), tile_b=jnp.array(tile_b), prod_tile=jnp.array(prod_tile))
        return rule, (tile_a, tile_b)
    else:
        raise RuntimeError("Unknown rule.")

In [957]:
rule_constructor = TileNearRule

In [991]:
rule_tiles = (random.choice(ALL_TILES),random.choice(ALL_TILES))
for tile in rule_tiles:
    print(_encode_tile(tile))

purple star
brown star


In [992]:
prod_tile = random.choice(ALL_TILES)
_encode_tile(prod_tile)

'yellow ball'

In [993]:
rule, rule_tiles = generate_rule(rule_constructor, rule_tiles, prod_tile)

In [994]:
_text_encode_rule(rule.encode().tolist())

'TileNear(purple star, brown star) -> yellow ball'

In [996]:
def create_and_store_atomic_tree(rule_constructor, rule_tiles, prod_tile):
    try:
        rule, rule_tiles = generate_rule(rule_constructor, rule_tiles, prod_tile)
        tree = {
            'rule': rule,
            'rule_tiles': rule_tiles,
            'prod_tile': prod_tile,
        }
        return tree
    except Exception as e:
        print(f"Error generating tree for rule tiles {rule_tiles}, {str(e)}")

In [1020]:
def sample_all_atomic_trees():
    trees = []
    rules = (
        (AgentHoldRule, 1),
        (AgentNearRule, 1),
        (AgentNearUpRule, 1),
        (AgentNearDownRule, 1),
        (AgentNearLeftRule, 1),
        (AgentNearRightRule, 1),
        (TileNearRule, 2),
        (TileNearUpRule, 2),
        (TileNearDownRule, 2),
        (TileNearLeftRule, 2),
        (TileNearRightRule, 2),
    )

    for rule_constructor, num_inputs in rules:
        if num_inputs == 1:  # then the input tiles should be chosen from HOLD_TILES
            for rule_tile in HOLD_TILES:
                for prod_tile in PROD_TILES:
                    tree = create_and_store_atomic_tree(rule_constructor, (rule_tile, ), prod_tile)
                    trees.append(tree)
        else:
            for tile_a in NEAR_TILES_LHS:
                for tile_b in NEAR_TILES_RHS:
                    for prod_tile in PROD_TILES:
                        tree = create_and_store_atomic_tree(rule_constructor, (tile_a, tile_b), prod_tile)
                        trees.append(tree)
    return trees
                

In [1002]:
all_atomic_trees = sample_all_atomic_trees()

In [1003]:
all_atomic_trees

[{'rule': AgentHoldRule(tile=Array([3, 1], dtype=int32), prod_tile=Array([3, 1], dtype=int32)),
  'rule_tiles': ((3, 1),),
  'prod_tile': (3, 1)},
 {'rule': AgentHoldRule(tile=Array([3, 1], dtype=int32), prod_tile=Array([3, 2], dtype=int32)),
  'rule_tiles': ((3, 1),),
  'prod_tile': (3, 2)},
 {'rule': AgentHoldRule(tile=Array([3, 1], dtype=int32), prod_tile=Array([3, 3], dtype=int32)),
  'rule_tiles': ((3, 1),),
  'prod_tile': (3, 3)},
 {'rule': AgentHoldRule(tile=Array([3, 1], dtype=int32), prod_tile=Array([3, 4], dtype=int32)),
  'rule_tiles': ((3, 1),),
  'prod_tile': (3, 4)},
 {'rule': AgentHoldRule(tile=Array([3, 1], dtype=int32), prod_tile=Array([3, 5], dtype=int32)),
  'rule_tiles': ((3, 1),),
  'prod_tile': (3, 5)},
 {'rule': AgentHoldRule(tile=Array([3, 1], dtype=int32), prod_tile=Array([3, 6], dtype=int32)),
  'rule_tiles': ((3, 1),),
  'prod_tile': (3, 6)},
 {'rule': AgentHoldRule(tile=Array([3, 1], dtype=int32), prod_tile=Array([3, 9], dtype=int32)),
  'rule_tiles': ((3, 1

In [2]:
import pickle

def save_results(data, filename):
    with open(filename, "wb") as f:
        pickle.dump(data, f)
    print(f"Results saved to {filename}")

def load_results(filename):
    with open(filename, "rb") as f:
        data = pickle.load(f)
    return data

In [1047]:
save_results(all_atomic_trees, "saved_atomic_prod_rules.pkl")

Results saved to saved_atomic_trees.pkl


In [45]:
all_atomic_prod_rules = load_results("saved_atomic_prod_rules.pkl")

In [46]:
n = 1000

for i in range(0, n):
    print(_text_encode_rule(all_atomic_prod_rules[i]['rule'].encode().tolist()))

AgentHold(red ball) -> red ball
AgentHold(red ball) -> green ball
AgentHold(red ball) -> blue ball
AgentHold(red ball) -> purple ball
AgentHold(red ball) -> yellow ball
AgentHold(red ball) -> grey ball
AgentHold(red ball) -> white ball
AgentHold(red ball) -> brown ball
AgentHold(red ball) -> pink ball
AgentHold(red ball) -> orange ball
AgentHold(red ball) -> red square
AgentHold(red ball) -> green square
AgentHold(red ball) -> blue square
AgentHold(red ball) -> purple square
AgentHold(red ball) -> yellow square
AgentHold(red ball) -> grey square
AgentHold(red ball) -> white square
AgentHold(red ball) -> brown square
AgentHold(red ball) -> pink square
AgentHold(red ball) -> orange square
AgentHold(red ball) -> red pyramid
AgentHold(red ball) -> green pyramid
AgentHold(red ball) -> blue pyramid
AgentHold(red ball) -> purple pyramid
AgentHold(red ball) -> yellow pyramid
AgentHold(red ball) -> grey pyramid
AgentHold(red ball) -> white pyramid
AgentHold(red ball) -> brown pyramid
AgentHold(

In [39]:
len(HOLD_TILES)

60

In [40]:
len(PROD_TILES)

61

In [41]:
len(NEAR_TILES_LHS)

70

In [42]:
len(NEAR_TILES_RHS)

60

In [47]:
len(all_atomic_prod_rules)  # is the expected number of atomic production rules

1302960

In [48]:
all_atomic_prod_rules[:100]

[{'rule': AgentHoldRule(tile=Array([3, 1], dtype=int32), prod_tile=Array([3, 1], dtype=int32)),
  'rule_tiles': ((3, 1),),
  'prod_tile': (3, 1)},
 {'rule': AgentHoldRule(tile=Array([3, 1], dtype=int32), prod_tile=Array([3, 2], dtype=int32)),
  'rule_tiles': ((3, 1),),
  'prod_tile': (3, 2)},
 {'rule': AgentHoldRule(tile=Array([3, 1], dtype=int32), prod_tile=Array([3, 3], dtype=int32)),
  'rule_tiles': ((3, 1),),
  'prod_tile': (3, 3)},
 {'rule': AgentHoldRule(tile=Array([3, 1], dtype=int32), prod_tile=Array([3, 4], dtype=int32)),
  'rule_tiles': ((3, 1),),
  'prod_tile': (3, 4)},
 {'rule': AgentHoldRule(tile=Array([3, 1], dtype=int32), prod_tile=Array([3, 5], dtype=int32)),
  'rule_tiles': ((3, 1),),
  'prod_tile': (3, 5)},
 {'rule': AgentHoldRule(tile=Array([3, 1], dtype=int32), prod_tile=Array([3, 6], dtype=int32)),
  'rule_tiles': ((3, 1),),
  'prod_tile': (3, 6)},
 {'rule': AgentHoldRule(tile=Array([3, 1], dtype=int32), prod_tile=Array([3, 9], dtype=int32)),
  'rule_tiles': ((3, 1

In [50]:
# create a default dictionary to store lists of rules under each prod_tile key
rules_by_prod_tile = defaultdict(list)

In [51]:
# Iterate over each rule
for prod_rule in all_atomic_prod_rules:
    # use 'prod_tile' directly from the dictionary as the dictionary key
    prod_tile_key = prod_rule['prod_tile']
    # append the rule to the list of rules for this 'prod_tile'
    rules_by_prod_tile[prod_tile_key].append({
        'rule': prod_rule['rule'], 
        'rule_tiles': prod_rule['rule_tiles'],
    })

In [60]:
rules_by_prod_tile[(1, 7)]

[{'rule': AgentHoldRule(tile=Array([3, 1], dtype=int32), prod_tile=Array([1, 7], dtype=int32)),
  'rule_tiles': ((3, 1),)},
 {'rule': AgentHoldRule(tile=Array([3, 2], dtype=int32), prod_tile=Array([1, 7], dtype=int32)),
  'rule_tiles': ((3, 2),)},
 {'rule': AgentHoldRule(tile=Array([3, 3], dtype=int32), prod_tile=Array([1, 7], dtype=int32)),
  'rule_tiles': ((3, 3),)},
 {'rule': AgentHoldRule(tile=Array([3, 4], dtype=int32), prod_tile=Array([1, 7], dtype=int32)),
  'rule_tiles': ((3, 4),)},
 {'rule': AgentHoldRule(tile=Array([3, 5], dtype=int32), prod_tile=Array([1, 7], dtype=int32)),
  'rule_tiles': ((3, 5),)},
 {'rule': AgentHoldRule(tile=Array([3, 6], dtype=int32), prod_tile=Array([1, 7], dtype=int32)),
  'rule_tiles': ((3, 6),)},
 {'rule': AgentHoldRule(tile=Array([3, 9], dtype=int32), prod_tile=Array([1, 7], dtype=int32)),
  'rule_tiles': ((3, 9),)},
 {'rule': AgentHoldRule(tile=Array([ 3, 10], dtype=int32), prod_tile=Array([1, 7], dtype=int32)),
  'rule_tiles': ((3, 10),)},
 {'ru

In [62]:
total_rules_count = sum(len(rules) for rules in rules_by_prod_tile.values())
print("Total number of rules:", total_rules_count)

Total number of rules: 1302960


In [None]:
save_results(rules_by_prod_tile, "saved_rules_by_prod_tile.pkl")