In [48]:
import jax
import jax.numpy as jnp
import xminigrid
import os
from xminigrid.types import RuleSet
from xminigrid.benchmarks import Benchmark, load_benchmark, load_benchmark_from_path, load_bz2_pickle, DATA_PATH, NAME2HFFILENAME
from xminigrid.rendering.text_render import print_ruleset

# utils for the demonstration
from xminigrid.core.grid import room
from xminigrid.types import AgentState
from xminigrid.core.actions import take_action
from xminigrid.core.constants import Tiles, Colors, TILES_REGISTRY
from xminigrid.rendering.rgb_render import render

# rules and goals
from xminigrid.core.goals import check_goal, AgentNearGoal, TileNearGoal
from xminigrid.core.rules import check_rule, AgentNearRule, TileNearRule, TileNearRightRule, AgentNearRightRule, AgentNearUpRule

In [16]:
def encode(ruleset):
    flatten_encoding = jnp.concatenate([ruleset["goal"].encode(), *[r.encode() for r in ruleset["rules"]]]).tolist()
    return tuple(flatten_encoding)

In [24]:
# GOAL: AgentNear(purple square)
goal = AgentNearGoal(tile=TILES_REGISTRY[Tiles.SQUARE, Colors.PURPLE])

In [25]:
# RULES: AgentNear(yellow ball) -> purple square
rule1 = AgentNearRule(
    tile=TILES_REGISTRY[Tiles.BALL, Colors.YELLOW],
    prod_tile=TILES_REGISTRY[Tiles.SQUARE, Colors.PURPLE],
)

In [26]:
rule2 = AgentNearRule(
    tile=TILES_REGISTRY[Tiles.BALL, Colors.GREEN],
    prod_tile=TILES_REGISTRY[Tiles.BALL, Colors.YELLOW],
)

In [27]:
rules = jnp.array([rule.encode(), rule2.encode()])

In [34]:
ruleset = RuleSet(
    goal=goal.encode(),
    rules=rules,
    init_tiles=jnp.array((
        TILES_REGISTRY[Tiles.BALL, Colors.GREEN],
        TILES_REGISTRY[Tiles.SQUARE, Colors.GREEN],
        TILES_REGISTRY[Tiles.PYRAMID, Colors.YELLOW]
    ))
)

In [35]:
print_ruleset(ruleset)

GOAL:
AgentNear(purple square)

RULES:
AgentNear(yellow ball) -> purple square
AgentNear(green ball) -> yellow ball

INIT TILES:
green ball
green square
yellow pyramid


In [39]:
goal = TileNearGoal(tile_a=TILES_REGISTRY[Tiles.BALL, Colors.GREEN], tile_b=TILES_REGISTRY[Tiles.BALL, Colors.BLUE])

In [42]:
rule_1 = TileNearRightRule(
    tile_a=TILES_REGISTRY[Tiles.SQUARE, Colors.BROWN],
    tile_b=TILES_REGISTRY[Tiles.SQUARE, Colors.BROWN],
    prod_tile=TILES_REGISTRY[Tiles.BALL, Colors.BLUE],
)

In [45]:
rule_2 = AgentNearRightRule(
    tile=TILES_REGISTRY[Tiles.BALL, Colors.BROWN],
    prod_tile=TILES_REGISTRY[Tiles.BALL, Colors.GREEN],
)

In [46]:
rule_3 = TileNearRule(
    tile_a=TILES_REGISTRY[Tiles.KEY, Colors.YELLOW],
    tile_b=TILES_REGISTRY[Tiles.KEY, Colors.BLUE],
    prod_tile=TILES_REGISTRY[Tiles.BALL, Colors.BROWN],
)

In [49]:
rule_4 = AgentNearUpRule(
    tile=TILES_REGISTRY[Tiles.SQUARE, Colors.WHITE],
    prod_tile=TILES_REGISTRY[Tiles.SQUARE, Colors.BROWN],
)

In [50]:
rule_5 = AgentNearRule(
    tile=TILES_REGISTRY[Tiles.SQUARE, Colors.PURPLE],
    prod_tile=TILES_REGISTRY[Tiles.SQUARE, Colors.YELLOW],
)

In [53]:
rules = jnp.array([rule_1.encode(), rule_2.encode(), rule_3.encode(), rule_4.encode(), rule_5.encode()])

In [52]:
initial_tiles = jnp.array((
    TILES_REGISTRY[Tiles.KEY, Colors.YELLOW],
    TILES_REGISTRY[Tiles.KEY, Colors.BLUE],
    TILES_REGISTRY[Tiles.SQUARE, Colors.WHITE],
    TILES_REGISTRY[Tiles.SQUARE, Colors.BROWN],
    TILES_REGISTRY[Tiles.GOAL, Colors.GREEN],
    TILES_REGISTRY[Tiles.GOAL, Colors.WHITE],
    TILES_REGISTRY[Tiles.SQUARE, Colors.PURPLE],
))

In [55]:
ruleset = RuleSet(
    goal=goal.encode(),
    rules=rules,
    init_tiles=initial_tiles
)

In [56]:
print_ruleset(ruleset)

GOAL:
TileNear(green ball, blue ball)

RULES:
TileNearRightRule(brown square, brown square) -> blue ball
AgentNearRightRule(brown ball) -> green ball
TileNear(yellow key, blue key) -> brown ball
AgentNearUpRule(white square) -> brown square
AgentNear(purple square) -> yellow square

INIT TILES:
yellow key
blue key
white square
brown square
green goal
white goal
purple square


### Benchmarks

In [None]:
print("Benchmarks available:", xminigrid.registered_benchmarks())

In [5]:
benchmark_trivial = xminigrid.load_benchmark(name="trivial-1m")
print("Total rulesets:", benchmark_trivial.num_rulesets())
print("Ruleset with id 128: \n", benchmark_trivial.get_ruleset(ruleset_id=128))
print("Random ruleset: \n", benchmark_trivial.sample_ruleset(jax.random.PRNGKey(0)))

Total rulesets: 1000000
Ruleset with id 128: 
 RuleSet(goal=Array([ 9, 12,  3,  5,  3], dtype=int32), rules=Array([[0, 0, 0, 0, 0, 0, 0]], dtype=uint8), init_tiles=Array([[12,  3],
       [ 5,  3],
       [ 5,  9],
       [ 5,  6],
       [ 3,  9]], dtype=uint8))
Random ruleset: 
 RuleSet(goal=Array([ 7,  7, 10,  5,  8], dtype=int32), rules=Array([[0, 0, 0, 0, 0, 0, 0]], dtype=uint8), init_tiles=Array([[ 7, 10],
       [ 5,  8],
       [ 6,  2],
       [ 5, 10],
       [ 7,  1]], dtype=uint8))


In [6]:
benchmark_small = xminigrid.load_benchmark(name="small-1m")
print("Total rulesets:", benchmark_small.num_rulesets())
print("Ruleset with id 128: \n", benchmark_small.get_ruleset(ruleset_id=128))
print("Random ruleset: \n", benchmark_small.sample_ruleset(jax.random.PRNGKey(0)))

Total rulesets: 1000000
Ruleset with id 128: 
 RuleSet(goal=Array([12,  7,  8,  0,  0], dtype=int32), rules=Array([[10, 12, 10, 11,  2,  0,  0],
       [ 0,  0,  0,  0,  0,  0,  0],
       [ 0,  0,  0,  0,  0,  0,  0],
       [ 0,  0,  0,  0,  0,  0,  0]], dtype=uint8), init_tiles=Array([[ 7,  8],
       [ 7, 10],
       [11,  3],
       [12, 10],
       [ 0,  0],
       [ 0,  0],
       [ 0,  0],
       [ 0,  0],
       [ 0,  0],
       [ 0,  0]], dtype=uint8))
Random ruleset: 
 RuleSet(goal=Array([3, 6, 4, 0, 0], dtype=int32), rules=Array([[11,  3,  6,  6,  4,  0,  0],
       [ 0,  0,  0,  0,  0,  0,  0],
       [ 0,  0,  0,  0,  0,  0,  0],
       [ 0,  0,  0,  0,  0,  0,  0]], dtype=uint8), init_tiles=Array([[3, 6],
       [3, 8],
       [5, 8],
       [0, 0],
       [0, 0],
       [0, 0],
       [0, 0],
       [0, 0],
       [0, 0],
       [0, 0]], dtype=uint8))


In [7]:
benchmark_medium = xminigrid.load_benchmark(name="medium-1m")
print("Total rulesets:", benchmark_medium.num_rulesets())
print("Ruleset with id 128: \n", benchmark_medium.get_ruleset(ruleset_id=128))
print("Random ruleset: \n", benchmark_medium.sample_ruleset(jax.random.PRNGKey(0)))

Total rulesets: 1000000
Ruleset with id 128: 
 RuleSet(goal=Array([12,  3,  3,  0,  0], dtype=int32), rules=Array([[ 7,  3,  9,  5,  2,  3,  3],
       [ 3, 11, 11,  4,  2,  5,  2],
       [10,  3,  2,  3,  9,  0,  0],
       [ 8,  5,  5,  5,  3,  0,  0],
       [ 2, 11,  4,  5,  5,  0,  0],
       [ 0,  0,  0,  0,  0,  0,  0],
       [ 0,  0,  0,  0,  0,  0,  0],
       [ 0,  0,  0,  0,  0,  0,  0],
       [ 0,  0,  0,  0,  0,  0,  0]], dtype=uint8), init_tiles=Array([[11, 11],
       [ 4,  2],
       [ 3,  2],
       [12, 11],
       [ 4,  1],
       [ 5,  5],
       [11,  4],
       [ 0,  0],
       [ 0,  0],
       [ 0,  0],
       [ 0,  0],
       [ 0,  0],
       [ 0,  0],
       [ 0,  0],
       [ 0,  0],
       [ 0,  0]], dtype=uint8))
Random ruleset: 
 RuleSet(goal=Array([ 8,  7,  6, 11,  9], dtype=int32), rules=Array([[11,  4,  1, 11,  9,  0,  0],
       [ 6, 12,  4, 12, 11,  7,  6],
       [10,  3,  4, 12,  4,  0,  0],
       [ 7, 12,  3,  3, 11,  4,  1],
       [ 4,  4,  8,

In [8]:
print_ruleset(benchmark_trivial.get_ruleset(ruleset_id=128))

GOAL:
TileNearDownGoal(blue star, blue pyramid)

RULES:

INIT TILES:
blue star
blue pyramid
white pyramid
grey pyramid
white ball


In [9]:
print_ruleset(benchmark_small.get_ruleset(ruleset_id=128))

GOAL:
AgentNearRightGoal(orange key)

RULES:
AgentNearDownRule(brown star) -> green hexagon

INIT TILES:
orange key
brown key
blue hexagon
brown star


In [33]:
print_ruleset(benchmark_medium.get_ruleset(ruleset_id=2000))

GOAL:
TileNear(green ball, blue ball)

RULES:
TileNearRightRule(brown square, brown square) -> blue ball
AgentNearRightRule(brown ball) -> green ball
TileNear(yellow key, blue key) -> brown ball
AgentNearUpRule(white square) -> brown square
AgentNear(purple square) -> yellow square

INIT TILES:
yellow key
blue key
white square
brown square
green goal
white goal
purple square


### Trying own script

In [None]:
import random
from itertools import product

import jax.numpy as jnp
from tqdm.auto import tqdm, trange
from xminigrid.benchmarks import save_bz2_pickle
from xminigrid.core.constants import Colors, Tiles
from xminigrid.core.goals import (
    AgentHoldGoal,
    AgentNearDownGoal,
    AgentNearGoal,
    AgentNearLeftGoal,
    AgentNearRightGoal,
    AgentNearUpGoal,
    TileNearDownGoal,
    TileNearGoal,
    TileNearLeftGoal,
    TileNearRightGoal,
    TileNearUpGoal,
)
from xminigrid.core.grid import pad_along_axis
from xminigrid.core.rules import (
    AgentHoldRule,
    AgentNearDownRule,
    AgentNearLeftRule,
    AgentNearRightRule,
    AgentNearRule,
    AgentNearUpRule,
    EmptyRule,
    TileNearDownRule,
    TileNearLeftRule,
    TileNearRightRule,
    TileNearRule,
    TileNearUpRule,
)

In [None]:
COLORS = [
    Colors.RED,
    Colors.GREEN,
    Colors.BLUE,
    Colors.PURPLE,
    Colors.YELLOW,
    Colors.GREY,
    Colors.WHITE,
    Colors.BROWN,
    Colors.PINK,
    Colors.ORANGE,
]

In [None]:
# we need to distinguish between them, to avoid sampling
# near(goal, goal) goal or rule as goal tiles are not pickable
NEAR_TILES_LHS = list(
    product([Tiles.BALL, Tiles.SQUARE, Tiles.PYRAMID, Tiles.KEY, Tiles.STAR, Tiles.HEX, Tiles.GOAL], COLORS)
)
# these are pickable!
NEAR_TILES_RHS = list(product([Tiles.BALL, Tiles.SQUARE, Tiles.PYRAMID, Tiles.KEY, Tiles.STAR, Tiles.HEX], COLORS))

HOLD_TILES = list(product([Tiles.BALL, Tiles.SQUARE, Tiles.PYRAMID, Tiles.KEY, Tiles.STAR, Tiles.HEX], COLORS))

# to imitate disappearance production rule
PROD_TILES = list(product([Tiles.BALL, Tiles.SQUARE, Tiles.PYRAMID, Tiles.KEY, Tiles.STAR, Tiles.HEX], COLORS))
PROD_TILES = PROD_TILES + [(Tiles.FLOOR, Colors.BLACK)]

In [None]:
def encode(ruleset):
    flatten_encoding = jnp.concatenate([ruleset["goal"].encode(), *[r.encode() for r in ruleset["rules"]]]).tolist()
    return tuple(flatten_encoding)

In [None]:
def diff(list1, list2):
    return list(set(list1) - set(list2))

In [None]:
goals = (
        AgentHoldGoal,
        # agent near variations
        AgentNearGoal,
        AgentNearUpGoal,
        AgentNearDownGoal,
        AgentNearLeftGoal,
        AgentNearRightGoal,
        # tile near variations
        TileNearGoal,
        TileNearUpGoal,
        TileNearDownGoal,
        TileNearLeftGoal,
        TileNearRightGoal,
    )

In [None]:
rules = (
        AgentHoldRule,
        # agent near variations
        AgentNearRule,
        AgentNearUpRule,
        AgentNearDownRule,
        AgentNearLeftRule,
        AgentNearRightRule,
        # tile near variations
        TileNearRule,
        TileNearUpRule,
        TileNearDownRule,
        TileNearLeftRule,
        TileNearRightRule,
    )

In [None]:
goals[1]

In [None]:
tile_goal = NEAR_TILES_LHS[1]
print(f"tile_goal: {tile_goal}")

In [None]:
goal = goals[1](tile=jnp.array(tile_goal))
goal

In [None]:
used_tiles = []
chain_tiles = []

In [None]:
used_tiles.extend(tile_goal)
chain_tiles.extend(tile_goal)

In [None]:
used_tiles

In [None]:
chain_tiles

In [None]:
rules = []
init_tiles = []

In [None]:
prod_tile = chain_tiles[1]


In [None]:
tile_rule = HOLD_TILES[3]
print(f"tile_rule: {tile_rule}")

In [None]:
rule = rules[1](tile=jnp.array(tile), prod_tile=jnp.array(prod_tile))

### My Own Script

In [59]:
import random
import jax.numpy as jnp

from xminigrid.core.constants import Colors, Tiles
from xminigrid.core.goals import *
from xminigrid.core.rules import *
from xminigrid.types import RuleSet

from itertools import product

In [61]:
COLORS = [
    Colors.RED,
    Colors.GREEN,
    Colors.BLUE,
    Colors.PURPLE,
    Colors.YELLOW,
    Colors.GREY,
    Colors.WHITE,
    Colors.BROWN,
    Colors.PINK,
    Colors.ORANGE,
]

In [62]:
# Define tiles and colors
tiles = list(product([Tiles.BALL, Tiles.SQUARE, Tiles.PYRAMID, Tiles.KEY, Tiles.STAR, Tiles.HEX], COLORS))

In [86]:
def generate_subtree(depth, prod_tile):
    rules = []
    current_prod_tiles = [prod_tile]

    for i in range(depth):
        next_prod_tiles = []
        for pt in current_prod_tiles:
            tile_a = random.choice(tiles)
            tile_b = random.choice(tiles)

            # Example rule: tiles near each other cause a production
            rule = TileNearRule(tile_a=jnp.array(tile_a), tile_b=jnp.array(tile_b), prod_tile=jnp.array(pt))
            rules.append(rule)
            next_prod_tiles.append(tile_a)  # Chaining the effect
            next_prod_tiles.append(tile_b)

        current_prod_tiles = next_prod_tiles

    return rules

In [99]:
prod_tile = (Tiles.KEY, Colors.BLUE)

In [100]:
# Generate a shared subtree
shared_subtree = generate_subtree(2, prod_tile)

In [101]:
shared_subtree

[TileNearRule(tile_a=Array([3, 8], dtype=int32), tile_b=Array([7, 4], dtype=int32), prod_tile=Array([7, 3], dtype=int32)),
 TileNearRule(tile_a=Array([3, 5], dtype=int32), tile_b=Array([5, 1], dtype=int32), prod_tile=Array([3, 8], dtype=int32)),
 TileNearRule(tile_a=Array([12,  6], dtype=int32), tile_b=Array([12, 10], dtype=int32), prod_tile=Array([7, 4], dtype=int32))]

In [102]:
def generate_ruleset_with_shared_subtree(subtree, goal, init_tiles):
    rules = subtree.copy()  # Use the pre-generated subtree
    return RuleSet(
        goal=goal.encode(),
        rules=jnp.array([r.encode() for r in rules]),
        init_tiles=jnp.array(init_tiles)  # Ensure this is correctly formatted
    )

In [103]:
random_init_tiles = [random.choice(tiles) for _ in range(5)]

In [105]:
goals = [
    AgentNearGoal(tile=jnp.array(prod_tile)),
    AgentNearUpGoal(tile=jnp.array(prod_tile)),
    AgentNearDownGoal(tile=jnp.array(prod_tile))
]

In [106]:
rulesets = [generate_ruleset_with_shared_subtree(shared_subtree, goal, random_init_tiles) for goal in goals]

In [107]:
rulesets

[RuleSet(goal=Array([3, 7, 3, 0, 0], dtype=uint8), rules=Array([[ 3,  3,  8,  7,  4,  7,  3],
        [ 3,  3,  5,  5,  1,  3,  8],
        [ 3, 12,  6, 12, 10,  7,  4]], dtype=uint8), init_tiles=Array([[ 7,  3],
        [ 3,  1],
        [11,  8],
        [ 5, 11],
        [ 5,  4]], dtype=int32)),
 RuleSet(goal=Array([11,  7,  3,  0,  0], dtype=int32), rules=Array([[ 3,  3,  8,  7,  4,  7,  3],
        [ 3,  3,  5,  5,  1,  3,  8],
        [ 3, 12,  6, 12, 10,  7,  4]], dtype=uint8), init_tiles=Array([[ 7,  3],
        [ 3,  1],
        [11,  8],
        [ 5, 11],
        [ 5,  4]], dtype=int32)),
 RuleSet(goal=Array([13,  7,  3,  0,  0], dtype=int32), rules=Array([[ 3,  3,  8,  7,  4,  7,  3],
        [ 3,  3,  5,  5,  1,  3,  8],
        [ 3, 12,  6, 12, 10,  7,  4]], dtype=uint8), init_tiles=Array([[ 7,  3],
        [ 3,  1],
        [11,  8],
        [ 5, 11],
        [ 5,  4]], dtype=int32))]

In [108]:
for ruleset in rulesets:
    print_ruleset(ruleset)
    print("\n===============")

GOAL:
AgentNear(blue key)

RULES:
TileNear(orange ball, purple key) -> blue key
TileNear(yellow ball, red pyramid) -> orange ball
TileNear(grey star, brown star) -> purple key

INIT TILES:
blue key
red ball
orange hexagon
pink pyramid
purple pyramid

GOAL:
AgentNearUpGoal(blue key)

RULES:
TileNear(orange ball, purple key) -> blue key
TileNear(yellow ball, red pyramid) -> orange ball
TileNear(grey star, brown star) -> purple key

INIT TILES:
blue key
red ball
orange hexagon
pink pyramid
purple pyramid

GOAL:
AgentNearDownGoal(blue key)

RULES:
TileNear(orange ball, purple key) -> blue key
TileNear(yellow ball, red pyramid) -> orange ball
TileNear(grey star, brown star) -> purple key

INIT TILES:
blue key
red ball
orange hexagon
pink pyramid
purple pyramid



In [109]:
prod_tile_2 = (Tiles.SQUARE, Colors.RED)

In [118]:
# Generate a shared subtree
shared_subtree_2 = generate_subtree(3, prod_tile_2)

In [119]:
goals_2 = [
    AgentNearGoal(tile=jnp.array(prod_tile_2)),
    AgentNearUpGoal(tile=jnp.array(prod_tile_2)),
    AgentNearDownGoal(tile=jnp.array(prod_tile_2))
]

In [120]:
rulesets = [generate_ruleset_with_shared_subtree(shared_subtree_2, goal, random_init_tiles) for goal in goals_2]

In [121]:
for ruleset in rulesets:
    print_ruleset(ruleset)
    print("\n===============")

GOAL:
AgentNear(red square)

RULES:
TileNear(green square, blue hexagon) -> red square
TileNear(yellow star, brown key) -> green square
TileNear(pink pyramid, red ball) -> blue hexagon
TileNear(pink key, grey star) -> yellow star
TileNear(orange pyramid, brown key) -> brown key
TileNear(blue pyramid, grey key) -> pink pyramid
TileNear(red ball, orange ball) -> red ball

INIT TILES:
blue key
red ball
orange hexagon
pink pyramid
purple pyramid

GOAL:
AgentNearUpGoal(red square)

RULES:
TileNear(green square, blue hexagon) -> red square
TileNear(yellow star, brown key) -> green square
TileNear(pink pyramid, red ball) -> blue hexagon
TileNear(pink key, grey star) -> yellow star
TileNear(orange pyramid, brown key) -> brown key
TileNear(blue pyramid, grey key) -> pink pyramid
TileNear(red ball, orange ball) -> red ball

INIT TILES:
blue key
red ball
orange hexagon
pink pyramid
purple pyramid

GOAL:
AgentNearDownGoal(red square)

RULES:
TileNear(green square, blue hexagon) -> red square
TileN

In [122]:
import random
from itertools import product

import jax.numpy as jnp
from tqdm.auto import tqdm, trange
from xminigrid.benchmarks import save_bz2_pickle
from xminigrid.core.constants import Colors, Tiles
from xminigrid.core.goals import (
    AgentHoldGoal,
    AgentNearDownGoal,
    AgentNearGoal,
    AgentNearLeftGoal,
    AgentNearRightGoal,
    AgentNearUpGoal,
    TileNearDownGoal,
    TileNearGoal,
    TileNearLeftGoal,
    TileNearRightGoal,
    TileNearUpGoal,
)
from xminigrid.core.grid import pad_along_axis
from xminigrid.core.rules import (
    AgentHoldRule,
    AgentNearDownRule,
    AgentNearLeftRule,
    AgentNearRightRule,
    AgentNearRule,
    AgentNearUpRule,
    EmptyRule,
    TileNearDownRule,
    TileNearLeftRule,
    TileNearRightRule,
    TileNearRule,
    TileNearUpRule,
)

In [123]:
COLORS = [
    Colors.RED,
    Colors.GREEN,
    Colors.BLUE,
    Colors.PURPLE,
    Colors.YELLOW,
    Colors.GREY,
    Colors.WHITE,
    Colors.BROWN,
    Colors.PINK,
    Colors.ORANGE,
]

In [124]:
# we need to distinguish between them, to avoid sampling
# near(goal, goal) goal or rule as goal tiles are not pickable
NEAR_TILES_LHS = list(
    product([Tiles.BALL, Tiles.SQUARE, Tiles.PYRAMID, Tiles.KEY, Tiles.STAR, Tiles.HEX, Tiles.GOAL], COLORS)
)
# these are pickable!
NEAR_TILES_RHS = list(product([Tiles.BALL, Tiles.SQUARE, Tiles.PYRAMID, Tiles.KEY, Tiles.STAR, Tiles.HEX], COLORS))

HOLD_TILES = list(product([Tiles.BALL, Tiles.SQUARE, Tiles.PYRAMID, Tiles.KEY, Tiles.STAR, Tiles.HEX], COLORS))

# to imitate disappearance production rule
PROD_TILES = list(product([Tiles.BALL, Tiles.SQUARE, Tiles.PYRAMID, Tiles.KEY, Tiles.STAR, Tiles.HEX], COLORS))
PROD_TILES = PROD_TILES + [(Tiles.FLOOR, Colors.BLACK)]

In [125]:
def encode(ruleset):
    flatten_encoding = jnp.concatenate([ruleset["goal"].encode(), *[r.encode() for r in ruleset["rules"]]]).tolist()
    return tuple(flatten_encoding)


def diff(list1, list2):
    return list(set(list1) - set(list2))

In [138]:
import jax
import jax.numpy as jnp

from xminigrid.core.constants import Colors, Tiles
from xminigrid.types import AgentState, RuleSet

COLOR_NAMES = {
    Colors.EMPTY: "white",
    Colors.RED: "red",
    Colors.GREEN: "green",
    Colors.BLUE: "blue",
    Colors.PURPLE: "purple",
    Colors.YELLOW: "yellow",
    Colors.GREY: "grey",
    Colors.BLACK: "black",
    Colors.ORANGE: "orange",
    Colors.WHITE: "white",
    Colors.BROWN: "brown",
    Colors.PINK: "pink",
}

TILE_STR = {
    Tiles.EMPTY: " ",
    Tiles.FLOOR: ".",
    Tiles.WALL: "☰",
    Tiles.BALL: "⏺",
    Tiles.SQUARE: "▪",
    Tiles.PYRAMID: "▲",
    Tiles.HEX: "⬢",
    Tiles.STAR: "★",
    Tiles.GOAL: "■",
    Tiles.DOOR_LOCKED: "L",
    Tiles.DOOR_CLOSED: "C",
    Tiles.DOOR_OPEN: "O",
    Tiles.KEY: "K",
}

# for ruleset printing
RULE_TILE_STR = {
    Tiles.FLOOR: "floor",
    Tiles.BALL: "ball",
    Tiles.SQUARE: "square",
    Tiles.PYRAMID: "pyramid",
    Tiles.GOAL: "goal",
    Tiles.KEY: "key",
    Tiles.HEX: "hexagon",
    Tiles.STAR: "star",
}

PLAYER_STR = {0: "^", 1: ">", 2: "V", 3: "<"}


def _wrap_with_color(string: str, color: str) -> str:
    return f"[bold {color}]{string}[/bold {color}]"


# WARN: will NOT work under jit and needed for debugging mainly.
def render(grid: jax.Array, agent: AgentState | None = None) -> str:
    string = ""

    for y in range(grid.shape[0]):
        for x in range(grid.shape[1]):
            tile_id, tile_color = grid[y, x]
            tile_str = TILE_STR[tile_id.item()]
            tile_color = COLOR_NAMES[tile_color.item()]

            if agent is not None and jnp.all(agent.position == jnp.array((y, x))):
                tile_str = PLAYER_STR[agent.direction.item()]
                tile_color = COLOR_NAMES[Colors.RED]

            string += _wrap_with_color(tile_str, tile_color)

        if y < grid.shape[0] - 1:
            string += "\n"

    return string


# WARN: This is for debugging mainly! Will refactor later if needed.
def _encode_tile(tile: list[int]) -> str:
    return f"{COLOR_NAMES[tile[1]]} {RULE_TILE_STR[tile[0]]}"


def _text_encode_goal(goal: list[int]) -> str:
    goal_id = goal[0]
    if goal_id == 1:
        return f"AgentHold({_encode_tile(goal[1:3])})"
    elif goal_id == 3:
        return f"AgentNear({_encode_tile(goal[1:3])})"
    elif goal_id == 4:
        return f"TileNear({_encode_tile(goal[1:3])}, {_encode_tile(goal[3:5])})"
    elif goal_id == 7:
        return f"TileNearUpGoal({_encode_tile(goal[1:3])}, {_encode_tile(goal[3:5])})"
    elif goal_id == 8:
        return f"TileNearRightGoal({_encode_tile(goal[1:3])}, {_encode_tile(goal[3:5])})"
    elif goal_id == 9:
        return f"TileNearDownGoal({_encode_tile(goal[1:3])}, {_encode_tile(goal[3:5])})"
    elif goal_id == 10:
        return f"TileNearLeftGoal({_encode_tile(goal[1:3])}, {_encode_tile(goal[3:5])})"
    elif goal_id == 11:
        return f"AgentNearUpGoal({_encode_tile(goal[1:3])})"
    elif goal_id == 12:
        return f"AgentNearRightGoal({_encode_tile(goal[1:3])})"
    elif goal_id == 13:
        return f"AgentNearDownGoal({_encode_tile(goal[1:3])})"
    elif goal_id == 14:
        return f"AgentNearLeftGoal({_encode_tile(goal[1:3])})"
    else:
        raise RuntimeError(f"Rendering: Unknown goal id: {goal_id}")


def _text_encode_rule(rule: list[int]) -> str:
    rule_id = rule[0]
    if rule_id == 1:
        return f"AgentHold({_encode_tile(rule[1:3])}) -> {_encode_tile(rule[3:5])}"
    elif rule_id == 2:
        return f"AgentNear({_encode_tile(rule[1:3])}) -> {_encode_tile(rule[3:5])}"
    elif rule_id == 3:
        return f"TileNear({_encode_tile(rule[1:3])}, {_encode_tile(rule[3:5])}) -> {_encode_tile(rule[5:7])}"
    elif rule_id == 4:
        return f"TileNearUpRule({_encode_tile(rule[1:3])}, {_encode_tile(rule[3:5])}) -> {_encode_tile(rule[5:7])}"
    elif rule_id == 5:
        return f"TileNearRightRule({_encode_tile(rule[1:3])}, {_encode_tile(rule[3:5])}) -> {_encode_tile(rule[5:7])}"
    elif rule_id == 6:
        return f"TileNearDownRule({_encode_tile(rule[1:3])}, {_encode_tile(rule[3:5])}) -> {_encode_tile(rule[5:7])}"
    elif rule_id == 7:
        return f"TileNearLeftRule({_encode_tile(rule[1:3])}, {_encode_tile(rule[3:5])}) -> {_encode_tile(rule[5:7])}"
    elif rule_id == 8:
        return f"AgentNearUpRule({_encode_tile(rule[1:3])}) -> {_encode_tile(rule[3:5])}"
    elif rule_id == 9:
        return f"AgentNearRightRule({_encode_tile(rule[1:3])}) -> {_encode_tile(rule[3:5])}"
    elif rule_id == 10:
        return f"AgentNearDownRule({_encode_tile(rule[1:3])}) -> {_encode_tile(rule[3:5])}"
    elif rule_id == 11:
        return f"AgentNearLeftRule({_encode_tile(rule[1:3])}) -> {_encode_tile(rule[3:5])}"
    else:
        raise RuntimeError(f"Rendering: Unknown rule id: {rule_id}")


def print_ruleset(ruleset: RuleSet) -> None:
    print("GOAL:")
    print(_text_encode_goal(ruleset.goal.tolist()))
    print()
    print("RULES:")
    for rule in ruleset.rules.tolist():
        if rule[0] != 0:
            print(_text_encode_rule(rule))
    print()
    print("INIT TILES:")
    for tile in ruleset.init_tiles.tolist():
        if tile[0] != 0:
            print(_encode_tile(tile))

In [192]:
def sample_goal():
    goals = (
        AgentHoldGoal,
        # agent near variations
        AgentNearGoal,
        AgentNearUpGoal,
        AgentNearDownGoal,
        AgentNearLeftGoal,
        AgentNearRightGoal,
        # tile near variations
        TileNearGoal,
        TileNearUpGoal,
        TileNearDownGoal,
        TileNearLeftGoal,
        TileNearRightGoal,
    )
    goal_idx = random.randint(0, 10)
    if goal_idx == 0:
        tile = random.choice(HOLD_TILES)
        goal = goals[0](tile=jnp.array(tile))
        print(goals[goal_idx], _encode_tile(tile))
        return goal, (tile,)
    elif 1 <= goal_idx <= 5:
        tile = random.choice(NEAR_TILES_LHS)
        goal = goals[goal_idx](tile=jnp.array(tile))
        print(goals[goal_idx], _encode_tile(tile))
        return goal, (tile,)
    elif 6 <= goal_idx <= 10:
        tile_a = random.choice(NEAR_TILES_LHS)
        tile_b = random.choice(NEAR_TILES_RHS)
        goal = goals[goal_idx](tile_a=jnp.array(tile_a), tile_b=jnp.array(tile_b))
        print(goals[goal_idx], _encode_tile(tile_a), _encode_tile(tile_b))
        return goal, (tile_a, tile_b)
    else:
        raise RuntimeError("Unknown goal")

In [194]:
sample_goal()

<class 'xminigrid.core.goals.AgentNearLeftGoal'> blue ball


(AgentNearLeftGoal(tile=Array([3, 3], dtype=int32)), ((3, 3),))

In [196]:
import random
from itertools import product

import jax.numpy as jnp
from tqdm.auto import tqdm, trange
from xminigrid.benchmarks import save_bz2_pickle
from xminigrid.core.constants import Colors, Tiles
from xminigrid.core.goals import (
    AgentHoldGoal,
    AgentNearDownGoal,
    AgentNearGoal,
    AgentNearLeftGoal,
    AgentNearRightGoal,
    AgentNearUpGoal,
    TileNearDownGoal,
    TileNearGoal,
    TileNearLeftGoal,
    TileNearRightGoal,
    TileNearUpGoal,
)
from xminigrid.core.grid import pad_along_axis
from xminigrid.core.rules import (
    AgentHoldRule,
    AgentNearDownRule,
    AgentNearLeftRule,
    AgentNearRightRule,
    AgentNearRule,
    AgentNearUpRule,
    EmptyRule,
    TileNearDownRule,
    TileNearLeftRule,
    TileNearRightRule,
    TileNearRule,
    TileNearUpRule,
)

In [197]:
COLORS = [
    Colors.RED,
    Colors.GREEN,
    Colors.BLUE,
    Colors.PURPLE,
    Colors.YELLOW,
    Colors.GREY,
    Colors.WHITE,
    Colors.BROWN,
    Colors.PINK,
    Colors.ORANGE,
]

In [198]:
# we need to distinguish between them, to avoid sampling
# near(goal, goal) goal or rule as goal tiles are not pickable
NEAR_TILES_LHS = list(
    product([Tiles.BALL, Tiles.SQUARE, Tiles.PYRAMID, Tiles.KEY, Tiles.STAR, Tiles.HEX, Tiles.GOAL], COLORS)
)
# these are pickable!
NEAR_TILES_RHS = list(product([Tiles.BALL, Tiles.SQUARE, Tiles.PYRAMID, Tiles.KEY, Tiles.STAR, Tiles.HEX], COLORS))

HOLD_TILES = list(product([Tiles.BALL, Tiles.SQUARE, Tiles.PYRAMID, Tiles.KEY, Tiles.STAR, Tiles.HEX], COLORS))

# to imitate disappearance production rule
PROD_TILES = list(product([Tiles.BALL, Tiles.SQUARE, Tiles.PYRAMID, Tiles.KEY, Tiles.STAR, Tiles.HEX], COLORS))
PROD_TILES = PROD_TILES + [(Tiles.FLOOR, Colors.BLACK)]

In [199]:
def encode(ruleset):
    flatten_encoding = jnp.concatenate([ruleset["goal"].encode(), *[r.encode() for r in ruleset["rules"]]]).tolist()
    return tuple(flatten_encoding)


def diff(list1, list2):
    return list(set(list1) - set(list2))

In [201]:
def sample_rule(prod_tile, used_tiles):
    rules = (
        AgentHoldRule,
        # agent near variations
        AgentNearRule,
        AgentNearUpRule,
        AgentNearDownRule,
        AgentNearLeftRule,
        AgentNearRightRule,
        # tile near variations
        TileNearRule,
        TileNearUpRule,
        TileNearDownRule,
        TileNearLeftRule,
        TileNearRightRule,
    )
    rule_idx = random.randint(0, 10)

    if rule_idx == 0:
        tile = random.choice(diff(HOLD_TILES, used_tiles))
        rule = rules[rule_idx](tile=jnp.array(tile), prod_tile=jnp.array(prod_tile))
        return rule, (tile,)
    elif 1 <= rule_idx <= 5:
        tile = random.choice(diff(HOLD_TILES, used_tiles))
        rule = rules[rule_idx](tile=jnp.array(tile), prod_tile=jnp.array(prod_tile))
        return rule, (tile,)
    elif 6 <= rule_idx <= 10:
        tile_a = random.choice(diff(NEAR_TILES_LHS, used_tiles))
        tile_b = random.choice(diff(NEAR_TILES_RHS, used_tiles))
        rule = rules[rule_idx](tile_a=jnp.array(tile_a), tile_b=jnp.array(tile_b), prod_tile=jnp.array(prod_tile))
        return rule, (tile_a, tile_b)
    else:
        raise RuntimeError("Unknown rule")

In [406]:
def sample_goal_tiles():
    # This function now returns just the tiles, not the goal object
    goal_idx = random.randint(0, 10)
    print(goal_idx)
    if goal_idx == 0:
        tile = random.choice(HOLD_TILES)
    elif 1 <= goal_idx <= 5:
        tile = random.choice(NEAR_TILES_LHS)
    else:
        tile_a = random.choice(NEAR_TILES_LHS)
        tile_b = random.choice(NEAR_TILES_RHS)
        return goal_idx, (tile_a, tile_b)
    return goal_idx, (tile,)

In [407]:
def sample_goal(goal_idx, goal_tiles):
    goals = (
        AgentHoldGoal,
        # Agent near variations
        AgentNearGoal,
        AgentNearUpGoal,
        AgentNearDownGoal,
        AgentNearLeftGoal,
        AgentNearRightGoal,
        # Tile near variations
        TileNearGoal,
        TileNearUpGoal,
        TileNearDownGoal,
        TileNearLeftGoal,
        TileNearRightGoal,
    )
    
    if goal_idx == 0:
        return goals[0](tile=jnp.array(goal_tiles[0]))
    elif 1 <= goal_idx <= 5:
        selected_goal = random.choice(goals[1:6])
        return selected_goal(tile=jnp.array(goal_tiles[0]))
    elif 6 <= goal_idx <= 10:
        selected_goal = random.choice(goals[6:])
        return selected_goal(tile_a=jnp.array(goal_tiles[0]), tile_b=jnp.array(goal_tiles[1]))


In [409]:
goal_idx, goal_tiles = sample_goal_tiles()
goal_tiles

7


((12, 8), (12, 1))

In [415]:
goal_idx

7

In [423]:
goal = sample_goal(goal_idx, goal_tiles)
goal

TileNearRightGoal(tile_a=Array([12,  8], dtype=int32), tile_b=Array([12,  1], dtype=int32))

In [437]:
def pre_sample_rules(
    chain_depth: int, 
    num_distractor_rules: int,
    num_distractor_objects: int,
    sample_depth: bool,
    sample_distractor_rules: bool,
    prune_chain: bool,
    prune_prob: float = 0.0
):
    used_tiles = []
    rules = []
    init_tiles = []
    chain_tiles = [random.choice(PROD_TILES)]  # Start chain with a random product tile

    num_levels = random.randint(0, chain_depth) if sample_depth else chain_depth

    for level in range(num_levels):
        next_chain_tiles = []
        while chain_tiles:
            prod_tile = chain_tiles.pop()
            if prune_chain and random.random() < prune_prob:
                init_tiles.append(prod_tile)
                continue

            rule, rule_tiles = sample_rule(prod_tile, used_tiles)
            used_tiles.extend(rule_tiles)
            next_chain_tiles.extend(rule_tiles)
            rules.append(rule)

        chain_tiles = next_chain_tiles

    if sample_distractor_rules:
        for _ in range(num_distractor_rules):
            prod_tile = random.choice(diff(PROD_TILES, used_tiles))
            rule, rule_tiles = sample_rule(prod_tile, [])
            rules.append(rule)
            init_tiles.extend(rule_tiles)

    if not rules:  # Ensure there's at least one rule
        rules.append(EmptyRule())

    return rules, used_tiles, init_tiles

In [440]:
def sample_ruleset(goal_idx, goal_tiles, pre_sampled_rules, pre_sampled_init_tiles, num_distractor_objects):
    goal = sample_goal(goal_idx, goal_tiles)  # Dynamic goal sampling
    init_tiles = list(pre_sampled_init_tiles)  # Start with pre-sampled initial tiles

    # Adding goal tiles to initial tiles ensuring they don't overlap with used tiles
    init_tiles.extend([tile for tile in goal_tiles if tile not in init_tiles])

    # Adding a few more distractor objects if needed
    init_tiles.extend(random.choices(diff(NEAR_TILES_LHS, init_tiles), k=num_distractor_objects))

    return {
        "goal": goal,
        "rules": pre_sampled_rules,  # Use pre-sampled rules
        "init_tiles": init_tiles,
        "num_rules": len([r for r in pre_sampled_rules if not isinstance(r, EmptyRule)]),
    }

In [461]:
# Example of using the modified functions
pre_sampled_rules, used_tiles, pre_sampled_init_tiles = pre_sample_rules(chain_depth=3, 
                                                                         num_distractor_rules=0, 
                                                                         num_distractor_objects=0, 
                                                                         sample_depth=False, 
                                                                         sample_distractor_rules=False, 
                                                                         prune_chain=False, 
                                                                         prune_prob=0.0)

rulesets = []
for _ in range(5):
    ruleset = sample_ruleset(goal_idx, goal_tiles, pre_sampled_rules, pre_sampled_init_tiles, num_distractor_objects=0)
    rulesets.append({
                "goal": ruleset["goal"].encode(),
                "rules": jnp.vstack([r.encode() for r in ruleset["rules"]]),
                "init_tiles": jnp.array(ruleset["init_tiles"], dtype=jnp.uint8),
                "num_rules": jnp.asarray(ruleset["num_rules"], dtype=jnp.uint8),
    })

In [462]:
rulesets

[{'goal': Array([ 4,  6,  6, 11, 11], dtype=uint8),
  'rules': Array([[ 3,  6,  8, 11,  5, 12,  6],
         [ 4,  4,  3,  7,  6, 11,  5],
         [ 4, 11,  2,  7,  9,  6,  8],
         [ 8,  5,  4,  7,  9,  0,  0],
         [ 4,  4, 10,  5, 11, 11,  2],
         [ 7,  5,  2,  3,  4,  7,  6],
         [ 2,  7,  4,  4,  3,  0,  0]], dtype=uint8),
  'init_tiles': Array([[ 6,  6],
         [11, 11]], dtype=uint8),
  'num_rules': Array(7, dtype=uint8)},
 {'goal': Array([ 7,  6,  6, 11, 11], dtype=int32),
  'rules': Array([[ 3,  6,  8, 11,  5, 12,  6],
         [ 4,  4,  3,  7,  6, 11,  5],
         [ 4, 11,  2,  7,  9,  6,  8],
         [ 8,  5,  4,  7,  9,  0,  0],
         [ 4,  4, 10,  5, 11, 11,  2],
         [ 7,  5,  2,  3,  4,  7,  6],
         [ 2,  7,  4,  4,  3,  0,  0]], dtype=uint8),
  'init_tiles': Array([[ 6,  6],
         [11, 11]], dtype=uint8),
  'num_rules': Array(7, dtype=uint8)},
 {'goal': Array([ 7,  6,  6, 11, 11], dtype=int32),
  'rules': Array([[ 3,  6,  8, 11,  5,

In [463]:
RuleSet = namedtuple('RuleSet', ['goal', 'rules', 'init_tiles'])

for ruleset in rulesets:
    # Convert dictionary to namedtuple
    ruleset_nt = RuleSet(
        goal=ruleset['goal'],
        rules=ruleset['rules'],
        init_tiles=ruleset['init_tiles']
    )
    print_ruleset(ruleset_nt)
    print("\n===============================")

GOAL:
TileNear(grey goal, pink hexagon)

RULES:
TileNear(orange goal, yellow hexagon) -> grey star
TileNearUpRule(blue square, grey key) -> yellow hexagon
TileNearUpRule(green hexagon, white key) -> orange goal
AgentNearUpRule(purple pyramid) -> white key
TileNearUpRule(brown square, pink pyramid) -> green hexagon
TileNearLeftRule(green pyramid, purple ball) -> grey key
AgentNear(purple key) -> blue square

INIT TILES:
grey goal
pink hexagon

GOAL:
TileNearUpGoal(grey goal, pink hexagon)

RULES:
TileNear(orange goal, yellow hexagon) -> grey star
TileNearUpRule(blue square, grey key) -> yellow hexagon
TileNearUpRule(green hexagon, white key) -> orange goal
AgentNearUpRule(purple pyramid) -> white key
TileNearUpRule(brown square, pink pyramid) -> green hexagon
TileNearLeftRule(green pyramid, purple ball) -> grey key
AgentNear(purple key) -> blue square

INIT TILES:
grey goal
pink hexagon

GOAL:
TileNearUpGoal(grey goal, pink hexagon)

RULES:
TileNear(orange goal, yellow hexagon) -> grey 

In [424]:
def sample_ruleset(
    goal_idx,
    goal_tiles,
    chain_depth: int,
    num_distractor_rules: int,
    num_distractor_objects: int,
    sample_depth: bool,
    sample_distractor_rules: bool,
    prune_chain: bool,
    # actually, we can vary prune_prob on each sample to diversify even further
    prune_prob: float = 0.0,
):
    used_tiles = list(goal_tiles)
    goal = sample_goal(goal_idx, goal_tiles)

    # sample main rules in a chain
    rules = []
    init_tiles = []

    num_levels = random.randint(0, chain_depth) if sample_depth else chain_depth
    # there is no rules, just one goal, thus we need to add goal tiles to init tiles
    if num_levels == 0:
        # WARN: you really should add distractor objects in this case, as without them goal will be obvious
        init_tiles.extend(goal_tiles)
        # one empty rule as a placeholder, to fill up "rule" key, this will not introduce overhead under jit
        rules.append(EmptyRule())

    chain_tiles = list(goal_tiles)  # Starting point for chaining rules
    
    # for logging
    for level in range(num_levels):
        next_chain_tiles = []

        # sampling in a chain, rules inputs from previous layer are rule results from this layer
        while chain_tiles:
            prod_tile = chain_tiles.pop()
            if prune_chain and random.random() < prune_prob:
                # prune this branch and add this tile to initial tiles
                init_tiles.append(prod_tile)
                continue

            rule, rule_tiles = sample_rule(prod_tile, used_tiles)
            used_tiles.extend(rule_tiles)
            next_chain_tiles.extend(rule_tiles)
            rules.append(rule)

            # inputs to the last rules in the chain are the initial tiles for the current level
            if level == num_levels - 1:
                init_tiles.extend(rule_tiles)

        chain_tiles = next_chain_tiles

    # sample distractor objects
    init_tiles.extend(random.choices(diff(NEAR_TILES_LHS, used_tiles), k=num_distractor_objects))
    # sample distractor rules
    if sample_distractor_rules:
        num_distractor_rules = random.randint(0, num_distractor_rules)

    for _ in range(num_distractor_rules):
        prod_tile = random.choice(diff(PROD_TILES, used_tiles))
        # distractors can include already sampled tiles, to create dead-end rules
        rule, rule_tiles = sample_rule(prod_tile, used_tiles=[])
        rules.append(rule)
        init_tiles.extend(rule_tiles)

    # if for some reason there are no rules, add one empty (we will ignore it later)
    if len(rules) == 0:
        rules.append(EmptyRule())

    return {
        "goal": goal,
        "rules": rules,
        "init_tiles": init_tiles,
        # additional info (for example for biasing sampling by number of rules)
        # you can add other field if needed, just copy-paste this file!
        # saving counts, as later they will be padded to the same size
        "num_rules": len([r for r in rules if not isinstance(r, EmptyRule)]),
    }

In [425]:
# Sample goal tiles once, outside the ruleset generation loop
goal_idx, goal_tiles = sample_goal_tiles()

7


In [429]:
chain_depth = 2
num_distractor_rules = 0
num_distractor_objects = 0
sample_depth = False
sample_distractor_rules = False
prune_chain = True
prune_prob = 0

In [430]:
# Generate multiple rulesets with the same goal tiles but potentially different goals
rulesets = []
for _ in range(10):  # Creating 10 different rulesets
    ruleset = sample_ruleset(
        goal_idx, goal_tiles, chain_depth, num_distractor_rules,
        num_distractor_objects, sample_depth, sample_distractor_rules,
        prune_chain, prune_prob
    )
    rulesets.append(
            {
                "goal": ruleset["goal"].encode(),
                "rules": jnp.vstack([r.encode() for r in ruleset["rules"]]),
                "init_tiles": jnp.array(ruleset["init_tiles"], dtype=jnp.uint8),
                "num_rules": jnp.asarray(ruleset["num_rules"], dtype=jnp.uint8),
            }
    )

In [431]:
from collections import namedtuple

RuleSet = namedtuple('RuleSet', ['goal', 'rules', 'init_tiles'])

for ruleset in rulesets:
    # Convert dictionary to namedtuple
    ruleset_nt = RuleSet(
        goal=ruleset['goal'],
        rules=ruleset['rules'],
        init_tiles=ruleset['init_tiles']
    )
    print_ruleset(ruleset_nt)
    print("\n===============================")


GOAL:
TileNearDownGoal(grey goal, pink hexagon)

RULES:
AgentNearDownRule(green ball) -> pink hexagon
AgentHold(red pyramid) -> grey goal
AgentNearRightRule(pink ball) -> red pyramid
TileNearLeftRule(orange star, yellow square) -> green ball

INIT TILES:
pink ball
orange star
yellow square

GOAL:
TileNear(grey goal, pink hexagon)

RULES:
AgentHold(white pyramid) -> pink hexagon
TileNearDownRule(pink square, purple pyramid) -> grey goal
TileNearDownRule(orange star, grey pyramid) -> purple pyramid
TileNearRightRule(purple square, brown pyramid) -> pink square
AgentNearRightRule(green ball) -> white pyramid

INIT TILES:
orange star
grey pyramid
purple square
brown pyramid
green ball

GOAL:
TileNearLeftGoal(grey goal, pink hexagon)

RULES:
TileNearUpRule(blue key, grey square) -> pink hexagon
AgentNearDownRule(grey star) -> grey goal
AgentNearRightRule(white key) -> grey star
TileNearDownRule(green key, pink square) -> grey square
TileNear(red hexagon, orange ball) -> blue key

INIT TILES

In [386]:
def sample_fixed_rules(goal_tiles, chain_depth, used_tiles):
    rules = []
    init_tiles = []
    chain_tiles = list(goal_tiles)

    for level in range(chain_depth):
        next_chain_tiles = []
        while chain_tiles:
            prod_tile = chain_tiles.pop()
            rule, rule_tiles = sample_rule(prod_tile, used_tiles)
            used_tiles.extend(rule_tiles)
            next_chain_tiles.extend(rule_tiles)
            rules.append(rule)
        chain_tiles = next_chain_tiles

    return rules, used_tiles, init_tiles

In [387]:
# Pre-sample rules
used_tiles = list(goal_tiles)  # This could be an empty list if you want completely independent rules
fixed_rules, used_tiles, init_tiles = sample_fixed_rules(goal_tiles, chain_depth, used_tiles)

In [388]:
def sample_ruleset(goal_idx, goal_tiles):
    goal = sample_goal(goal_idx, goal_tiles)
    return {
        "goal": goal,
        "rules": fixed_rules,  # Use the pre-sampled fixed rules
        "init_tiles": init_tiles,
        "num_rules": len([r for r in fixed_rules if not isinstance(r, EmptyRule)]),
    }

In [391]:
# Generate multiple rulesets with different goals
rulesets = []
for _ in range(10):  # Creating 10 different rulesets
    ruleset = sample_ruleset(goal_idx, goal_tiles)
    rulesets.append({
        "goal": ruleset["goal"].encode(),
        "rules": jnp.vstack([r.encode() for r in ruleset["rules"]]),
        "init_tiles": jnp.array(ruleset["init_tiles"], dtype=jnp.uint8),
        "num_rules": jnp.asarray(ruleset["num_rules"], dtype=jnp.uint8),
    })

In [392]:
from collections import namedtuple

RuleSet = namedtuple('RuleSet', ['goal', 'rules', 'init_tiles'])

for ruleset in rulesets:
    # Convert dictionary to namedtuple
    ruleset_nt = RuleSet(
        goal=ruleset['goal'],
        rules=ruleset['rules'],
        init_tiles=ruleset['init_tiles']
    )
    print_ruleset(ruleset_nt)
    print("\n===============================")


GOAL:
AgentNearRightGoal(pink goal)

RULES:
TileNearUpRule(orange square, purple star) -> pink goal
AgentNearRightRule(orange hexagon) -> purple star
TileNearUpRule(purple key, yellow ball) -> orange square

INIT TILES:

GOAL:
AgentNearRightGoal(pink goal)

RULES:
TileNearUpRule(orange square, purple star) -> pink goal
AgentNearRightRule(orange hexagon) -> purple star
TileNearUpRule(purple key, yellow ball) -> orange square

INIT TILES:

GOAL:
AgentNearLeftGoal(pink goal)

RULES:
TileNearUpRule(orange square, purple star) -> pink goal
AgentNearRightRule(orange hexagon) -> purple star
TileNearUpRule(purple key, yellow ball) -> orange square

INIT TILES:

GOAL:
AgentNear(pink goal)

RULES:
TileNearUpRule(orange square, purple star) -> pink goal
AgentNearRightRule(orange hexagon) -> purple star
TileNearUpRule(purple key, yellow ball) -> orange square

INIT TILES:

GOAL:
AgentNearUpGoal(pink goal)

RULES:
TileNearUpRule(orange square, purple star) -> pink goal
AgentNearRightRule(orange hex