### Utils for printing

In [None]:
import jax
import jax.numpy as jnp

from xminigrid.core.constants import Colors, Tiles
from xminigrid.types import AgentState, RuleSet

In [None]:
# Color id -> color name spliced into the "[bold <color>]" tags built by
# `_wrap_with_color`. Note that EMPTY is rendered with the same name as WHITE.
COLOR_NAMES = {
    Colors.EMPTY: "white",
    Colors.RED: "red",
    Colors.GREEN: "green",
    Colors.BLUE: "blue",
    Colors.PURPLE: "purple",
    Colors.YELLOW: "yellow",
    Colors.GREY: "grey",
    Colors.BLACK: "black",
    Colors.ORANGE: "orange",
    Colors.WHITE: "white",
    Colors.BROWN: "brown",
    Colors.PINK: "pink",
}

# Tile id -> single-character glyph used by `render` for one grid cell.
TILE_STR = {
    Tiles.EMPTY: " ",
    Tiles.FLOOR: ".",
    Tiles.WALL: "☰",
    Tiles.BALL: "⏺",
    Tiles.SQUARE: "▪",
    Tiles.PYRAMID: "▲",
    Tiles.HEX: "⬢",
    Tiles.STAR: "★",
    Tiles.GOAL: "■",
    Tiles.DOOR_LOCKED: "L",
    Tiles.DOOR_CLOSED: "C",
    Tiles.DOOR_OPEN: "O",
    Tiles.KEY: "K",
}

# for ruleset printing
# Tile id -> word used by `_encode_tile`. Tiles missing here (EMPTY, WALL and
# the door tiles) raise KeyError when a rule or goal refers to them.
RULE_TILE_STR = {
    Tiles.FLOOR: "floor",
    Tiles.BALL: "ball",
    Tiles.SQUARE: "square",
    Tiles.PYRAMID: "pyramid",
    Tiles.GOAL: "goal",
    Tiles.KEY: "key",
    Tiles.HEX: "hexagon",
    Tiles.STAR: "star",
}

# Agent direction id -> glyph drawn at the agent position by `render`.
# Presumably 0=up, 1=right, 2=down, 3=left, judging by the glyphs — TODO confirm.
PLAYER_STR = {0: "^", 1: ">", 2: "V", 3: "<"}

In [None]:
def _wrap_with_color(string: str, color: str) -> str:
    return f"[bold {color}]{string}[/bold {color}]"


# WARN: will NOT work under jit and needed for debugging mainly.
def render(grid: jax.Array, agent: AgentState | None = None) -> str:
    """Build a colored text picture of ``grid``, one line per grid row.

    Each cell holds a ``(tile_id, color_id)`` pair and becomes one glyph wrapped
    in color tags. When ``agent`` is given, the cell at ``agent.position`` is
    drawn as a red direction glyph instead of the underlying tile.

    Returns:
        The rows joined by newlines, without a trailing newline.
    """
    rendered_rows = []

    for y in range(grid.shape[0]):
        cells = []
        for x in range(grid.shape[1]):
            tile_id, color_id = grid[y, x]
            glyph = TILE_STR[tile_id.item()]
            color_name = COLOR_NAMES[color_id.item()]

            # the agent overrides whatever tile it is standing on
            agent_here = agent is not None and jnp.all(agent.position == jnp.array((y, x)))
            if agent_here:
                glyph = PLAYER_STR[agent.direction.item()]
                color_name = COLOR_NAMES[Colors.RED]

            cells.append(_wrap_with_color(glyph, color_name))

        rendered_rows.append("".join(cells))

    return "\n".join(rendered_rows)

In [None]:
# WARN: This is for debugging mainly! Will refactor later if needed.
def _encode_tile(tile: list[int]) -> str:
    """Return ``"<color name> <tile name>"`` for a ``[tile_id, color_id]`` pair."""
    tile_id, color_id = tile[0], tile[1]
    return f"{COLOR_NAMES[color_id]} {RULE_TILE_STR[tile_id]}"

In [None]:
def _text_encode_goal(goal: list[int]) -> str:
    """Render an encoded goal as ``Name(tile[, tile])``.

    ``goal[0]`` is the goal id; the tile arguments follow as consecutive
    ``[tile_id, color_id]`` pairs starting at index 1.

    Raises:
        RuntimeError: if the goal id is not one of the ids listed below.
    """
    # goal id -> (printed name, number of tile arguments)
    signatures = {
        1: ("AgentHold", 1),
        3: ("AgentNear", 1),
        4: ("TileNear", 2),
        7: ("TileNearUpGoal", 2),
        8: ("TileNearRightGoal", 2),
        9: ("TileNearDownGoal", 2),
        10: ("TileNearLeftGoal", 2),
        11: ("AgentNearUpGoal", 1),
        12: ("AgentNearRightGoal", 1),
        13: ("AgentNearDownGoal", 1),
        14: ("AgentNearLeftGoal", 1),
    }

    goal_id = goal[0]
    if goal_id not in signatures:
        raise RuntimeError(f"Rendering: Unknown goal id: {goal_id}")

    name, num_tiles = signatures[goal_id]
    tiles_text = ", ".join(_encode_tile(goal[start : start + 2]) for start in range(1, 1 + 2 * num_tiles, 2))
    return f"{name}({tiles_text})"

In [None]:
def _text_encode_rule(rule: list[int]) -> str:
    """Render an encoded rule as ``Name(input[, input]) -> product``.

    ``rule[0]`` is the rule id; the input tiles follow as consecutive
    ``[tile_id, color_id]`` pairs starting at index 1, and the produced tile is
    the pair right after the last input.

    Raises:
        RuntimeError: if the rule id is not one of the ids listed below.
    """
    # rule id -> (printed name, number of input tiles)
    signatures = {
        1: ("AgentHold", 1),
        2: ("AgentNear", 1),
        3: ("TileNear", 2),
        4: ("TileNearUpRule", 2),
        5: ("TileNearRightRule", 2),
        6: ("TileNearDownRule", 2),
        7: ("TileNearLeftRule", 2),
        8: ("AgentNearUpRule", 1),
        9: ("AgentNearRightRule", 1),
        10: ("AgentNearDownRule", 1),
        11: ("AgentNearLeftRule", 1),
    }

    rule_id = rule[0]
    if rule_id not in signatures:
        raise RuntimeError(f"Rendering: Unknown rule id: {rule_id}")

    name, num_inputs = signatures[rule_id]
    prod_start = 1 + 2 * num_inputs
    inputs_text = ", ".join(_encode_tile(rule[start : start + 2]) for start in range(1, prod_start, 2))
    product_text = _encode_tile(rule[prod_start : prod_start + 2])
    return f"{name}({inputs_text}) -> {product_text}"

In [None]:
def print_ruleset(ruleset: RuleSet) -> None:
    """Print the goal, the non-empty rules and the non-empty initial tiles of ``ruleset``.

    Rows whose first entry is 0 are skipped in both the rules and the
    initial tiles.
    """
    print("GOAL:")
    print(_text_encode_goal(ruleset.goal.tolist()))
    print()

    print("RULES:")
    for encoded_rule in ruleset.rules.tolist():
        if encoded_rule[0] == 0:
            continue
        print(_text_encode_rule(encoded_rule))
    print()

    print("INIT TILES:")
    for encoded_tile in ruleset.init_tiles.tolist():
        if encoded_tile[0] == 0:
            continue
        print(_encode_tile(encoded_tile))

### Utils for generating

In [1]:
import random
from itertools import product
from collections import namedtuple

import jax.numpy as jnp
from tqdm.auto import tqdm, trange
from xminigrid.benchmarks import save_bz2_pickle
from xminigrid.core.constants import Colors, Tiles
from xminigrid.rendering.text_render import _encode_tile, _text_encode_goal, _text_encode_rule, print_ruleset
from xminigrid.core.goals import (
    AgentHoldGoal,
    AgentNearDownGoal,
    AgentNearGoal,
    AgentNearLeftGoal,
    AgentNearRightGoal,
    AgentNearUpGoal,
    TileNearDownGoal,
    TileNearGoal,
    TileNearLeftGoal,
    TileNearRightGoal,
    TileNearUpGoal,
)
from xminigrid.core.grid import pad_along_axis
from xminigrid.core.rules import (
    AgentHoldRule,
    AgentNearDownRule,
    AgentNearLeftRule,
    AgentNearRightRule,
    AgentNearRule,
    AgentNearUpRule,
    EmptyRule,
    TileNearDownRule,
    TileNearLeftRule,
    TileNearRightRule,
    TileNearRule,
    TileNearUpRule,
)

In [2]:
# colors a sampled tile may take
COLORS = [
    Colors.RED, Colors.GREEN, Colors.BLUE, Colors.PURPLE, Colors.YELLOW,
    Colors.GREY, Colors.WHITE, Colors.BROWN, Colors.PINK, Colors.ORANGE,
]

# these are pickable!
NEAR_TILES_RHS = [
    (tile, color)
    for tile in (Tiles.BALL, Tiles.SQUARE, Tiles.PYRAMID, Tiles.KEY, Tiles.STAR, Tiles.HEX)
    for color in COLORS
]

# we need to distinguish between LHS and RHS, to avoid sampling a
# near(goal, goal) goal or rule, as goal tiles are not pickable:
# only the left-hand side may additionally be a goal tile
NEAR_TILES_LHS = NEAR_TILES_RHS + [(Tiles.GOAL, color) for color in COLORS]

# tiles the agent can hold: the same pickable set, as an independent list
HOLD_TILES = list(NEAR_TILES_RHS)

# to imitate disappearance production rule, a black floor is a valid product
PROD_TILES = [*NEAR_TILES_RHS, (Tiles.FLOOR, Colors.BLACK)]

# all possible tiles for randomly choosing a prod_tile for generating subtrees
ALL_TILES = list(NEAR_TILES_LHS)

In [3]:
def encode(ruleset):
    """Flatten a ruleset dict into one tuple: the goal encoding followed by every rule encoding.

    ``ruleset["goal"]`` and each item of ``ruleset["rules"]`` must provide
    ``.encode()`` returning a 1-D array.
    """
    pieces = [ruleset["goal"].encode()]
    pieces.extend(rule.encode() for rule in ruleset["rules"])
    return tuple(jnp.concatenate(pieces).tolist())


def diff(list1, list2):
    """Return the distinct items of ``list1`` that do not occur in ``list2``.

    The result is built from a set, so duplicates are dropped and the order is
    not the order of ``list1``.
    """
    candidates = set(list1)
    excluded = set(list2)
    return list(candidates - excluded)

In [4]:
def sample_rule(prod_tile, used_tiles):
    """Sample a random rule that produces ``prod_tile`` from tiles not in ``used_tiles``.

    Returns:
        ``(rule, input_tiles)`` where ``input_tiles`` is a 1-tuple for the
        agent-based rules and a 2-tuple for the tile-near rules.
    """
    # rules taking a single input tile (hold + agent near variations)
    single_tile_rules = (
        AgentHoldRule,
        AgentNearRule,
        AgentNearUpRule,
        AgentNearDownRule,
        AgentNearLeftRule,
        AgentNearRightRule,
    )
    # rules taking two input tiles (tile near variations)
    two_tile_rules = (
        TileNearRule,
        TileNearUpRule,
        TileNearDownRule,
        TileNearLeftRule,
        TileNearRightRule,
    )

    # uniform over all eleven rule types
    rule_idx = random.randint(0, 10)

    if rule_idx < len(single_tile_rules):
        rule_cls = single_tile_rules[rule_idx]
        tile = random.choice(diff(HOLD_TILES, used_tiles))
        return rule_cls(tile=jnp.array(tile), prod_tile=jnp.array(prod_tile)), (tile,)

    rule_cls = two_tile_rules[rule_idx - len(single_tile_rules)]
    tile_a = random.choice(diff(NEAR_TILES_LHS, used_tiles))
    tile_b = random.choice(diff(NEAR_TILES_RHS, used_tiles))
    rule = rule_cls(tile_a=jnp.array(tile_a), tile_b=jnp.array(tile_b), prod_tile=jnp.array(prod_tile))
    return rule, (tile_a, tile_b)

In [5]:
def sample_shared_rules(
    prod_tile,
    chain_depth: int,
    num_distractor_rules: int,
    num_distractor_objects: int,
    sample_depth: bool,
    sample_distractor_rules: bool,
    prune_chain: bool,
    prune_prob: float = 0.0,
):
    """Sample a rule chain whose top rule produces ``prod_tile``; no goal is sampled here.

    The chain is built backwards: every level samples, for each tile the
    previous level needs, a rule that produces it. The inputs of the last level
    become the initial tiles.

    Args:
        prod_tile: ``(tile, color)`` pair the chain must end up producing.
        chain_depth: number of rule levels (upper bound if ``sample_depth``).
        num_distractor_rules: number of distractor rules (upper bound if
            ``sample_distractor_rules``).
        num_distractor_objects: number of distractor tiles added to the initial tiles.
        sample_depth: draw the depth uniformly from ``[0, chain_depth]``.
        sample_distractor_rules: draw the distractor rule count uniformly from
            ``[0, num_distractor_rules]``.
        prune_chain: allow cutting a branch; the cut tile becomes an initial tile.
        prune_prob: per-branch probability of cutting when ``prune_chain`` is set.

    Returns:
        ``(rules, init_tiles)``: a non-empty list of rule objects (an
        ``EmptyRule`` placeholder if nothing was sampled) and a list of
        ``(tile, color)`` pairs.
    """
    # the goal is not sampled here, so the chain starts from the given tile
    used_tiles = [prod_tile]
    chain_tiles = [prod_tile]

    # sample main rules in a chain
    rules = []
    init_tiles = []

    num_levels = random.randint(0, chain_depth) if sample_depth else chain_depth

    # there are no rules, thus the production tile itself must be an initial tile
    if num_levels == 0:
        # WARN: you really should add distractor objects in this case, as without them goal will be obvious
        # append (not extend): the tile is a single (tile, color) pair and must stay a pair
        init_tiles.append(prod_tile)
        # one empty rule as a placeholder, to fill up "rule" key, this will not introduce overhead under jit
        rules.append(EmptyRule())

    for level in range(num_levels):
        next_chain_tiles = []

        # sampling in a chain, rules inputs from previous layer are rule results from this layer
        while chain_tiles:
            target_tile = chain_tiles.pop()
            if prune_chain and random.random() < prune_prob:
                # prune this branch and add this tile to initial tiles
                init_tiles.append(target_tile)
                continue

            rule, rule_tiles = sample_rule(target_tile, used_tiles)
            used_tiles.extend(rule_tiles)
            next_chain_tiles.extend(rule_tiles)
            rules.append(rule)

            # inputs to the last rules in the chain are the initial tiles for the current level
            if level == num_levels - 1:
                init_tiles.extend(rule_tiles)

        chain_tiles = next_chain_tiles

    # sample distractor objects
    init_tiles.extend(random.choices(diff(NEAR_TILES_LHS, used_tiles), k=num_distractor_objects))

    # sample distractor rules
    if sample_distractor_rules:
        num_distractor_rules = random.randint(0, num_distractor_rules)

    for _ in range(num_distractor_rules):
        distractor_prod_tile = random.choice(diff(PROD_TILES, used_tiles))
        # distractors can include already sampled tiles, to create dead-end rules
        rule, rule_tiles = sample_rule(distractor_prod_tile, used_tiles=[])
        rules.append(rule)
        init_tiles.extend(rule_tiles)

    # if for some reason there are no rules, add one empty (we will ignore it later)
    if not rules:
        rules.append(EmptyRule())

    return rules, init_tiles


In [6]:
# Used when generating goals: a tile of type GOAL only occurs in NEAR_TILES_LHS,
# so a goal built on it must be one of the near variations (1 <= goal_idx <= 5).
def is_goal_tile(tile):
    """Return whether the ``(tile, color)`` pair has tile type ``Tiles.GOAL``."""
    tile_type = tile[0]
    return tile_type == Tiles.GOAL

In [7]:
def generate_goal_idx(prod_tile):
    """Pick the index of a single-tile goal type that is valid for ``prod_tile``.

    Index 0 is the hold goal and 1..5 are the agent-near variations (see the
    ``goals`` tuple in ``generate_goal``). GOAL tiles are not pickable, so the
    hold goal is excluded for them.

    Returns:
        An int in ``[1, 5]`` for GOAL tiles and in ``[0, 5]`` otherwise.
    """
    # the check must be called: the bare function object is always truthy,
    # which made the hold goal unreachable for every tile
    lowest_idx = 1 if is_goal_tile(prod_tile) else 0
    # restricting to 5 here because goals over multiple prod tiles are not supported yet
    return random.randint(lowest_idx, 5)

In [8]:
def generate_goal(prod_tile):
    """Build a random single-tile goal over ``prod_tile``.

    The goal type is chosen by ``generate_goal_idx``; only the hold goal and
    the agent-near variations (indices 0..5) are supported.

    Returns:
        A goal instance constructed with ``tile=jnp.array(prod_tile)``.

    Raises:
        RuntimeError: if the sampled index is not a single-tile goal.
    """
    goals = (
        AgentHoldGoal,
        # agent near variations
        AgentNearGoal,
        AgentNearUpGoal,
        AgentNearDownGoal,
        AgentNearLeftGoal,
        AgentNearRightGoal,
        # tile near variations
        # TODO: indices 6..10 need two tiles; not supported until multiple prod tiles are handled
        TileNearGoal,
        TileNearUpGoal,
        TileNearDownGoal,
        TileNearLeftGoal,
        TileNearRightGoal,
    )

    goal_idx = generate_goal_idx(prod_tile)

    if 0 <= goal_idx <= 5:
        return goals[goal_idx](tile=jnp.array(prod_tile))

    # fail loudly instead of returning a (class, tiles) pair that callers
    # would then try to call .encode() on
    raise RuntimeError(f"Unknown goal index: {goal_idx}")

In [9]:
# Define the namedtuple for storing ruleset information
# NOTE(review): this rebinds the name `RuleSet` that the first cell imports from
# xminigrid.types — confirm the shadowing is intended.
RuleSet = namedtuple('RuleSet', ['goal', 'rules', 'init_tiles'])

def print_rulesets(rulesets):
    """Print each ruleset dict via ``print_ruleset``, followed by a separator line.

    Every dict must have the keys ``goal``, ``rules`` and ``init_tiles``.
    """
    for entry in rulesets:
        # print_ruleset reads attributes, so repack the dict as a namedtuple
        as_tuple = RuleSet(entry['goal'], entry['rules'], entry['init_tiles'])
        print_ruleset(as_tuple)
        print("\n===============================")

### Small prototype of sharing entire rule chain

In [10]:
# pick the tile that the shared rule chain has to produce, and show it both as
# text and as its raw (tile, color) pair
rand_prod_tile = random.choice(ALL_TILES)
print(f"{_encode_tile(rand_prod_tile)}: {rand_prod_tile}")

green ball: (3, 2)


In [11]:
# following settings of small benchmark
chain_depth = 2  # number of rule levels in the chain
num_distractor_rules = 0  # no distractor rules
num_distractor_objects = 0  # no distractor objects among the initial tiles
sample_depth = False  # use chain_depth exactly instead of sampling the depth
sample_distractor_rules = False  # use num_distractor_rules exactly
prune_chain = False  # never cut a branch of the chain short
prune_prob = 0.0  # not consulted while prune_chain is False

In [12]:
# sample one rule chain producing rand_prod_tile; it is reused by every ruleset below
rules, init_tiles = sample_shared_rules(
    rand_prod_tile,
    chain_depth,
    num_distractor_rules,
    num_distractor_objects,
    sample_depth,
    sample_distractor_rules,
    prune_chain,
    prune_prob,
)

In [13]:
# encode every rule, stack the encodings and convert to nested Python lists for text rendering
encoded_rules = jnp.vstack([r.encode() for r in rules]).tolist()

In [14]:
# show the sampled chain in readable form
for rule in encoded_rules:
    print(_text_encode_rule(rule))

TileNearLeftRule(white square, blue pyramid) -> green ball
TileNearRightRule(blue goal, green pyramid) -> blue pyramid
AgentNearRightRule(red hexagon) -> white square


In [15]:
# show the tiles that have to be placed initially
for tile in init_tiles:
    print(_encode_tile(tile))

blue goal
green pyramid
red hexagon


In [16]:
# number of rulesets (differing only in their goal) generated by the prototype
num_rulesets = 3

In [17]:
# The rule chain and the initial tiles are identical for every ruleset, so
# encode them once instead of on every iteration; only the goal is re-sampled.
shared_rules_encoded = jnp.vstack([r.encode() for r in rules])
shared_init_tiles_encoded = jnp.array(init_tiles, dtype=jnp.uint8)

rulesets = []
for _ in range(num_rulesets):
    ruleset = {
        "goal": generate_goal(rand_prod_tile),
        "rules": rules,
        "init_tiles": init_tiles,
    }

    rulesets.append({
        "goal": ruleset["goal"].encode(),
        "rules": shared_rules_encoded,
        "init_tiles": shared_init_tiles_encoded,
    })

In [18]:
# display the raw encoded rulesets
rulesets

[{'goal': Array([13,  3,  2,  0,  0], dtype=int32),
  'rules': Array([[ 7,  4,  9,  5,  3,  3,  2],
         [ 5,  6,  3,  5,  2,  5,  3],
         [ 9, 11,  1,  4,  9,  0,  0]], dtype=uint8),
  'init_tiles': Array([[ 6,  3],
         [ 5,  2],
         [11,  1]], dtype=uint8)},
 {'goal': Array([13,  3,  2,  0,  0], dtype=int32),
  'rules': Array([[ 7,  4,  9,  5,  3,  3,  2],
         [ 5,  6,  3,  5,  2,  5,  3],
         [ 9, 11,  1,  4,  9,  0,  0]], dtype=uint8),
  'init_tiles': Array([[ 6,  3],
         [ 5,  2],
         [11,  1]], dtype=uint8)},
 {'goal': Array([11,  3,  2,  0,  0], dtype=int32),
  'rules': Array([[ 7,  4,  9,  5,  3,  3,  2],
         [ 5,  6,  3,  5,  2,  5,  3],
         [ 9, 11,  1,  4,  9,  0,  0]], dtype=uint8),
  'init_tiles': Array([[ 6,  3],
         [ 5,  2],
         [11,  1]], dtype=uint8)}]

In [19]:
# display the same rulesets in readable form
print_rulesets(rulesets)

GOAL:
AgentNearDownGoal(green ball)

RULES:
TileNearLeftRule(white square, blue pyramid) -> green ball
TileNearRightRule(blue goal, green pyramid) -> blue pyramid
AgentNearRightRule(red hexagon) -> white square

INIT TILES:
blue goal
green pyramid
red hexagon

GOAL:
AgentNearDownGoal(green ball)

RULES:
TileNearLeftRule(white square, blue pyramid) -> green ball
TileNearRightRule(blue goal, green pyramid) -> blue pyramid
AgentNearRightRule(red hexagon) -> white square

INIT TILES:
blue goal
green pyramid
red hexagon

GOAL:
AgentNearUpGoal(green ball)

RULES:
TileNearLeftRule(white square, blue pyramid) -> green ball
TileNearRightRule(blue goal, green pyramid) -> blue pyramid
AgentNearRightRule(red hexagon) -> white square

INIT TILES:
blue goal
green pyramid
red hexagon



#### Generate 1k

In [None]:
# number of rulesets for the 1k variant
num_rulesets_1k = 1000

In [None]:
# The rule chain and the initial tiles are identical for every ruleset, so
# encode them once instead of 1000 times; only the goal is re-sampled.
shared_rules_encoded = jnp.vstack([r.encode() for r in rules])
shared_init_tiles_encoded = jnp.array(init_tiles, dtype=jnp.uint8)

rulesets_1k = []
for _ in range(num_rulesets_1k):
    ruleset = {
        "goal": generate_goal(rand_prod_tile),
        "rules": rules,
        "init_tiles": init_tiles,
    }

    rulesets_1k.append({
        "goal": ruleset["goal"].encode(),
        "rules": shared_rules_encoded,
        "init_tiles": shared_init_tiles_encoded,
    })

In [None]:
# prints all 1000 rulesets; the output is long
print_rulesets(rulesets_1k)

### Step-by-step plan for generating multiple rulesets with shared lower-level rules

1. **Define Shared Lower-Level Rules:** Instead of generating the entire rule chain anew for each ruleset, generate a common base of rules that can be shared among different rulesets. This base can be the bottom few layers of the rule chain.
2. **Add Variation to Top-Level Rules:** After generating the shared lower-level rules, add unique top-level rules to each ruleset. These rules can differ by the type of rule or the tiles involved, ensuring that the production tile of these top-level rules varies, thereby affecting the goal tile.
3. **Generate Goals Based on Top Production Tile:** Adjust the goal generation method to dynamically select the goal based on the production tile of the last rule added to the ruleset.
4. **Combine and Encode:** Combine the shared and unique rules into multiple rulesets, encode them, and generate goals accordingly.

#### Modifications to Existing Code

1. **Split sample_shared_rules Function**
* We can reuse the `sample_shared_rules` function to generate the base rules. The tricky part is coming up with a way to generate different top rules. The function must now also return `used_tiles`, which is fed into the function that generates the top rules.
* Create a function `generate_top_rules` for generating the varied top levels for each ruleset.
* This function will probably start from a random production tile and generate random rules until the last layer of required rules. In the last layer, it will have to generate a rule that takes the production tile of the base shared rules as its input, so that the top chain connects to the shared base.

2. **Modify Goal Generation**
Adjust the generate_goal function to derive the goal from the top production tile, which will now vary between rulesets.

3. **Integrate and Test**
Once the functions for generating base and top rules are defined, integrate them in a loop to generate multiple rulesets, apply the unique top-level rules, and dynamically generate goals.

In [217]:
def sample_base_rules(
    prod_tile,
    chain_depth: int,
    num_distractor_rules: int,
    num_distractor_objects: int,
    sample_depth: bool,
    sample_distractor_rules: bool,
    prune_chain: bool,
    prune_prob: float = 0.0,
):
    """Sample the shared base rule chain producing ``prod_tile`` and report the tiles it used.

    The chain is built backwards: every level samples, for each tile the
    previous level needs, a rule that produces it. The inputs of the last level
    become the initial tiles.

    Args:
        prod_tile: ``(tile, color)`` pair the chain must end up producing.
        chain_depth: number of rule levels (upper bound if ``sample_depth``).
        num_distractor_rules: number of distractor rules (upper bound if
            ``sample_distractor_rules``).
        num_distractor_objects: number of distractor tiles added to the initial tiles.
        sample_depth: draw the depth uniformly from ``[0, chain_depth]``.
        sample_distractor_rules: draw the distractor rule count uniformly from
            ``[0, num_distractor_rules]``.
        prune_chain: allow cutting a branch; the cut tile becomes an initial tile.
        prune_prob: per-branch probability of cutting when ``prune_chain`` is set.

    Returns:
        ``(rules, init_tiles, used_tiles)``: a non-empty list of rule objects
        (an ``EmptyRule`` placeholder if nothing was sampled), the list of
        initial ``(tile, color)`` pairs, and every tile consumed or produced by
        the main chain (distractors are not recorded in ``used_tiles``).
    """
    # the goal is not sampled here, so the chain starts from the given tile
    used_tiles = [prod_tile]
    chain_tiles = [prod_tile]

    # sample main rules in a chain
    rules = []
    init_tiles = []

    num_levels = random.randint(0, chain_depth) if sample_depth else chain_depth

    # there are no rules, thus the production tile itself must be an initial tile
    if num_levels == 0:
        # WARN: you really should add distractor objects in this case, as without them goal will be obvious
        # append (not extend): the tile is a single (tile, color) pair and must stay a pair
        init_tiles.append(prod_tile)
        # one empty rule as a placeholder, to fill up "rule" key, this will not introduce overhead under jit
        rules.append(EmptyRule())

    for level in range(num_levels):
        next_chain_tiles = []

        # sampling in a chain, rules inputs from previous layer are rule results from this layer
        while chain_tiles:
            target_tile = chain_tiles.pop()
            if prune_chain and random.random() < prune_prob:
                # prune this branch and add this tile to initial tiles
                init_tiles.append(target_tile)
                continue

            rule, rule_tiles = sample_rule(target_tile, used_tiles)
            used_tiles.extend(rule_tiles)
            next_chain_tiles.extend(rule_tiles)
            rules.append(rule)

            # inputs to the last rules in the chain are the initial tiles for the current level
            if level == num_levels - 1:
                init_tiles.extend(rule_tiles)

        chain_tiles = next_chain_tiles

    # sample distractor objects
    init_tiles.extend(random.choices(diff(NEAR_TILES_LHS, used_tiles), k=num_distractor_objects))

    # sample distractor rules
    if sample_distractor_rules:
        num_distractor_rules = random.randint(0, num_distractor_rules)

    for _ in range(num_distractor_rules):
        distractor_prod_tile = random.choice(diff(PROD_TILES, used_tiles))
        # distractors can include already sampled tiles, to create dead-end rules
        rule, rule_tiles = sample_rule(distractor_prod_tile, used_tiles=[])
        rules.append(rule)
        init_tiles.extend(rule_tiles)

    # if for some reason there are no rules, add one empty (we will ignore it later)
    if not rules:
        rules.append(EmptyRule())

    return rules, init_tiles, used_tiles


In [218]:
# following settings of small benchmark
base_chain_depth = 2  # number of rule levels in the shared base chain
base_num_distractor_rules = 0  # no distractor rules
base_num_distractor_objects = 0  # no distractor objects among the initial tiles
base_sample_depth = False  # use base_chain_depth exactly instead of sampling the depth
base_sample_distractor_rules = False  # use base_num_distractor_rules exactly
base_prune_chain = False  # never cut a branch of the chain short
base_prune_prob = 0.0  # not consulted while base_prune_chain is False

In [219]:
# pick the tile that the shared base chain has to produce, and show it both as
# text and as its raw (tile, color) pair
base_rand_prod_tile = random.choice(ALL_TILES)
print(f"{_encode_tile(base_rand_prod_tile)}: {base_rand_prod_tile}")

green hexagon: (11, 2)


In [220]:
# sample the shared base chain; base_used_tiles is passed on to the top-rule sampler
base_rules, base_init_tiles, base_used_tiles = sample_base_rules(
    base_rand_prod_tile,
    base_chain_depth,
    base_num_distractor_rules,
    base_num_distractor_objects,
    base_sample_depth,
    base_sample_distractor_rules,
    base_prune_chain,
    base_prune_prob,
)

In [221]:
# tiles that have to be placed initially for the base chain
for tile in base_init_tiles:
    print(_encode_tile(tile))

purple ball
pink hexagon


In [222]:
# every tile the base chain consumes or produces, starting with the production tile
for tile in base_used_tiles:
    print(_encode_tile(tile))

green hexagon
brown square
white hexagon
purple ball
pink hexagon


In [223]:
# show the base chain in readable form
for rule in base_rules:
    print(_text_encode_rule(rule.encode().tolist()))

TileNearLeftRule(brown square, white hexagon) -> green hexagon
AgentHold(purple ball) -> white hexagon
AgentNear(pink hexagon) -> brown square


In [224]:
def sample_goal():
    """Sample a random goal of any type together with the tiles it refers to.

    Returns:
        ``(goal, goal_tiles)`` where ``goal_tiles`` is a 1-tuple for the
        agent-based goals and a 2-tuple for the tile-near goals.
    """
    # goals over a single tile (hold + agent near variations)
    single_tile_goals = (
        AgentHoldGoal,
        AgentNearGoal,
        AgentNearUpGoal,
        AgentNearDownGoal,
        AgentNearLeftGoal,
        AgentNearRightGoal,
    )
    # goals over two tiles (tile near variations)
    two_tile_goals = (
        TileNearGoal,
        TileNearUpGoal,
        TileNearDownGoal,
        TileNearLeftGoal,
        TileNearRightGoal,
    )

    # uniform over all eleven goal types
    goal_idx = random.randint(0, 10)

    if goal_idx >= len(single_tile_goals):
        goal_cls = two_tile_goals[goal_idx - len(single_tile_goals)]
        tile_a = random.choice(NEAR_TILES_LHS)
        tile_b = random.choice(NEAR_TILES_RHS)
        return goal_cls(tile_a=jnp.array(tile_a), tile_b=jnp.array(tile_b)), (tile_a, tile_b)

    # only the hold goal (index 0) is restricted to pickable tiles
    tile_pool = HOLD_TILES if goal_idx == 0 else NEAR_TILES_LHS
    tile = random.choice(tile_pool)
    return single_tile_goals[goal_idx](tile=jnp.array(tile)), (tile,)

In [225]:
def sample_rule_with_input(input_tile, prod_tile):
    """Sample a single-input rule that turns ``input_tile`` into ``prod_tile``.

    Only the hold rule (index 0) and the agent-near variations (1..5) are
    sampled; the two-tile rules are listed but not supported yet.

    Args:
        input_tile: ``(tile, color)`` pair the rule consumes.
        prod_tile: ``(tile, color)`` pair the rule produces.

    Returns:
        ``(rule, (input_tile,))``.
    """
    rules = (
        AgentHoldRule,
        # agent near variations
        AgentNearRule,
        AgentNearUpRule,
        AgentNearDownRule,
        AgentNearLeftRule,
        AgentNearRightRule,
        # tile near variations
        TileNearRule,
        TileNearUpRule,
        TileNearDownRule,
        TileNearLeftRule,
        TileNearRightRule,
    )

    # GOAL tiles are not pickable, so a hold rule on them could never fire:
    # exclude index 0 for them and allow it for every other tile.
    # Only up to 5 because the two-tile rules are not implemented here.
    lowest_idx = 1 if is_goal_tile(input_tile) else 0
    rule_idx = random.randint(lowest_idx, 5)

    rule = rules[rule_idx](tile=jnp.array(input_tile), prod_tile=jnp.array(prod_tile))
    return rule, (input_tile,)

In [226]:
def sample_top_rules(
    base_prod_tile,
    base_used_tiles,
    chain_depth: int,
    num_distractor_rules: int,
    num_distractor_objects: int,
    sample_depth: bool,
    sample_distractor_rules: bool,
    prune_chain: bool,
    prune_prob: float = 0.0,
):
    """Sample a goal and a top rule chain whose last level consumes ``base_prod_tile``.

    The chain is built backwards from the sampled goal tiles. Every level except
    the last uses ``sample_rule``; on the last level each rule takes
    ``base_prod_tile`` as its input, which links the chain to the shared base.
    Two debug lines (goal tile and initial tiles) are printed on every call.

    Args:
        base_prod_tile: ``(tile, color)`` pair produced by the base chain.
        base_used_tiles: tiles already used by the base chain; copied, not mutated.
        chain_depth: number of rule levels (upper bound if ``sample_depth``).
        num_distractor_rules: number of distractor rules (upper bound if
            ``sample_distractor_rules``).
        num_distractor_objects: number of distractor tiles added to the initial tiles.
        sample_depth: draw the depth uniformly from ``[0, chain_depth]``.
        sample_distractor_rules: draw the distractor rule count uniformly from
            ``[0, num_distractor_rules]``.
        prune_chain: allow cutting a branch; the cut tile becomes an initial tile.
        prune_prob: per-branch probability of cutting when ``prune_chain`` is set.

    Returns:
        ``(goal, rules, init_tiles)``. ``init_tiles`` holds one copy of
        ``base_prod_tile`` per last-level rule, plus pruned and distractor tiles.
    """
    # the used tiles start with what was already used in generating base rules
    used_tiles = base_used_tiles.copy()
    chain_tiles = []

    # sample goal first
    # NOTE(review): sample_goal does not exclude base_used_tiles, so the goal may
    # reuse a tile of the base chain — confirm this is acceptable.
    goal, goal_tiles = sample_goal()
    # debug output; only the first goal tile is shown even for two-tile goals
    print(f"goal tiles: {_encode_tile(goal_tiles[0])}")
    used_tiles.extend(goal_tiles)
    chain_tiles.extend(goal_tiles)

    # sample main rules in a chain
    rules = []
    init_tiles = []

    num_levels = random.randint(0, chain_depth) if sample_depth else chain_depth

    # there is no rules, just one goal, thus we need to add goal tiles to init tiles
    # NOTE(review): in this case no rule consumes base_prod_tile, so the top part
    # is not linked to the base chain — confirm this is intended.
    if num_levels == 0:
        # WARN: you really should add distractor objects in this case, as without them goal will be obvious
        init_tiles.extend(goal_tiles)
        # one empty rule as a placeholder, to fill up "rule" key, this will not introduce overhead under jit
        rules.append(EmptyRule())

    for level in range(num_levels):
        next_chain_tiles = []

        # sampling in a chain, rules inputs from previous layer are rule results from this layer
        while chain_tiles:
            prod_tile = chain_tiles.pop()
            if prune_chain and random.random() < prune_prob:
                # prune this branch and add this tile to initial tiles
                init_tiles.append(prod_tile)
                continue

            # every level except the last keeps the original sampling logic
            if level == num_levels - 1:
                # On the last level, modify the rule to ensure it uses the base_prod_tile as input
                rule, rule_tiles = sample_rule_with_input(base_prod_tile, prod_tile)
            else:
                rule, rule_tiles = sample_rule(prod_tile, used_tiles)
            used_tiles.extend(rule_tiles)
            next_chain_tiles.extend(rule_tiles)
            rules.append(rule)

            # inputs to the last rules in the chain are the initial tiles for the current level
            # (here these are copies of base_prod_tile, one per last-level rule)
            if level == num_levels - 1:
                init_tiles.extend(rule_tiles)

        chain_tiles = next_chain_tiles

    # sample distractor objects
    init_tiles.extend(random.choices(diff(NEAR_TILES_LHS, used_tiles), k=num_distractor_objects))
    # debug output
    print(f"init_tiles: {init_tiles}")
    # sample distractor rules
    if sample_distractor_rules:
        num_distractor_rules = random.randint(0, num_distractor_rules)

    for _ in range(num_distractor_rules):
        prod_tile = random.choice(diff(PROD_TILES, used_tiles))
        # distractors can include already sampled tiles, to create dead-end rules
        rule, rule_tiles = sample_rule(prod_tile, used_tiles=[])
        rules.append(rule)
        init_tiles.extend(rule_tiles)

    # if for some reason there are no rules, add one empty (we will ignore it later)
    if len(rules) == 0:
        rules.append(EmptyRule())

    return goal, rules, init_tiles
    

In [227]:
# following settings of small benchmark
top_chain_depth = 1  # a single top level, i.e. only rules that consume the base production tile
top_num_distractor_rules = 0  # no distractor rules
top_num_distractor_objects = 0  # no distractor objects among the initial tiles
top_sample_depth = False  # use top_chain_depth exactly instead of sampling the depth
top_sample_distractor_rules = False  # use top_num_distractor_rules exactly
top_prune_chain = False  # never cut a branch of the chain short
top_prune_prob = 0.0  # not consulted while top_prune_chain is False

In [228]:
# sample one goal plus the top rules that connect it to the base production tile
goal, top_rules, top_init_tiles = sample_top_rules(
    base_rand_prod_tile,
    base_used_tiles,
    top_chain_depth,
    top_num_distractor_rules,
    top_num_distractor_objects,
    top_sample_depth,
    top_sample_distractor_rules,
    top_prune_chain,
    top_prune_prob,
)

goal tiles: brown key
init_tiles: [(11, 2), (11, 2)]


In [229]:
# encode every top rule, stack the encodings and convert to nested Python lists for text rendering
encoded_top_rules = jnp.vstack([r.encode() for r in top_rules]).tolist()

In [230]:
# show the top rules in readable form
for rule in encoded_top_rules:
    print(_text_encode_rule(rule))

AgentNearLeftRule(green hexagon) -> white pyramid
AgentNearDownRule(green hexagon) -> brown key


In [231]:
# initial tiles reported by the top sampler (one base production tile per last-level rule)
for tile in top_init_tiles:
    print(_encode_tile(tile))

green hexagon
green hexagon


In [232]:
# display the sampled goal in readable form
_text_encode_goal(goal.encode().tolist())

'TileNearDownGoal(brown key, white pyramid)'

In [233]:
# full chain: the top rules followed by the shared base rules
full_rules = top_rules + base_rules

In [234]:
# show the combined chain in readable form
for rule in full_rules:
    print(_text_encode_rule(rule.encode().tolist()))

AgentNearLeftRule(green hexagon) -> white pyramid
AgentNearDownRule(green hexagon) -> brown key
TileNearLeftRule(brown square, white hexagon) -> green hexagon
AgentHold(purple ball) -> white hexagon
AgentNear(pink hexagon) -> brown square


In [235]:
# number of rulesets that share the base chain but get their own goal and top rules
num_concat_rulesets = 3

In [236]:
concat_rulesets = []

for _ in range(num_concat_rulesets):
    # every ruleset gets its own goal and top rules on top of the shared base chain
    goal, top_rules, top_init_tiles = sample_top_rules(
        base_rand_prod_tile,
        base_used_tiles,
        top_chain_depth,
        top_num_distractor_rules,
        top_num_distractor_objects,
        top_sample_depth,
        top_sample_distractor_rules,
        top_prune_chain,
        top_prune_prob,
    )

    full_rules = top_rules + base_rules
    # The base production tile is created by the base rules, so it must not be
    # placed initially. Every other initial tile of the top chain (pruned
    # branches, distractors) has to be kept, otherwise it would never exist.
    full_init_tiles = base_init_tiles + [tile for tile in top_init_tiles if tile != base_rand_prod_tile]

    ruleset = {
        "goal": goal,
        "rules": full_rules,
        "init_tiles": full_init_tiles,
    }

    concat_rulesets.append({
        "goal": ruleset["goal"].encode(),
        "rules": jnp.vstack([r.encode() for r in ruleset["rules"]]),
        "init_tiles": jnp.array(ruleset["init_tiles"], dtype=jnp.uint8),
    })

goal tiles: blue square
init_tiles: [(11, 2)]
goal tiles: blue key
init_tiles: [(11, 2)]
goal tiles: purple hexagon
init_tiles: [(11, 2)]


In [237]:
# display the combined rulesets in readable form
print_rulesets(concat_rulesets)

GOAL:
AgentNearRightGoal(blue square)

RULES:
AgentNear(green hexagon) -> blue square
TileNearLeftRule(brown square, white hexagon) -> green hexagon
AgentHold(purple ball) -> white hexagon
AgentNear(pink hexagon) -> brown square

INIT TILES:
purple ball
pink hexagon

GOAL:
AgentHold(blue key)

RULES:
AgentNearUpRule(green hexagon) -> blue key
TileNearLeftRule(brown square, white hexagon) -> green hexagon
AgentHold(purple ball) -> white hexagon
AgentNear(pink hexagon) -> brown square

INIT TILES:
purple ball
pink hexagon

GOAL:
AgentHold(purple hexagon)

RULES:
AgentNearDownRule(green hexagon) -> purple hexagon
TileNearLeftRule(brown square, white hexagon) -> green hexagon
AgentHold(purple ball) -> white hexagon
AgentNear(pink hexagon) -> brown square

INIT TILES:
purple ball
pink hexagon



In [239]:
# deeper top chain for the second experiment
top_chain_depth_2 = 3

In [240]:
concat_rulesets_2 = []

for _ in range(num_concat_rulesets):
    # same as before, but with the deeper top chain
    goal, top_rules, top_init_tiles = sample_top_rules(
        base_rand_prod_tile,
        base_used_tiles,
        top_chain_depth_2,
        top_num_distractor_rules,
        top_num_distractor_objects,
        top_sample_depth,
        top_sample_distractor_rules,
        top_prune_chain,
        top_prune_prob,
    )

    full_rules = top_rules + base_rules
    # The base production tile is created by the base rules, so it must not be
    # placed initially. Every other initial tile of the top chain (pruned
    # branches, distractors) has to be kept, otherwise it would never exist.
    full_init_tiles = base_init_tiles + [tile for tile in top_init_tiles if tile != base_rand_prod_tile]

    ruleset = {
        "goal": goal,
        "rules": full_rules,
        "init_tiles": full_init_tiles,
    }

    concat_rulesets_2.append({
        "goal": ruleset["goal"].encode(),
        "rules": jnp.vstack([r.encode() for r in ruleset["rules"]]),
        "init_tiles": jnp.array(ruleset["init_tiles"], dtype=jnp.uint8),
    })

goal tiles: grey star
init_tiles: [(11, 2)]
goal tiles: red star
init_tiles: [(11, 2), (11, 2), (11, 2), (11, 2), (11, 2)]
goal tiles: yellow pyramid
init_tiles: [(11, 2), (11, 2), (11, 2), (11, 2)]


In [242]:
# display the deeper combined rulesets in readable form
print_rulesets(concat_rulesets_2)

GOAL:
AgentNearLeftGoal(grey star)

RULES:
AgentHold(red hexagon) -> grey star
AgentNearUpRule(grey key) -> red hexagon
AgentNearUpRule(green hexagon) -> grey key
TileNearLeftRule(brown square, white hexagon) -> green hexagon
AgentHold(purple ball) -> white hexagon
AgentNear(pink hexagon) -> brown square

INIT TILES:
purple ball
pink hexagon

GOAL:
TileNearUpGoal(red star, red ball)

RULES:
AgentNearDownRule(green pyramid) -> red ball
TileNearUpRule(red goal, yellow ball) -> red star
AgentNearLeftRule(white key) -> yellow ball
TileNearUpRule(blue pyramid, blue hexagon) -> red goal
TileNearRightRule(orange star, purple key) -> green pyramid
AgentNear(green hexagon) -> purple key
AgentNear(green hexagon) -> orange star
AgentNear(green hexagon) -> blue hexagon
AgentNearLeftRule(green hexagon) -> blue pyramid
AgentNearRightRule(green hexagon) -> white key
TileNearLeftRule(brown square, white hexagon) -> green hexagon
AgentHold(purple ball) -> white hexagon
AgentNear(pink hexagon) -> brown 