In [16]:
import random
from itertools import product
from collections import namedtuple

import jax.numpy as jnp
from tqdm.auto import tqdm, trange
from xminigrid.benchmarks import save_bz2_pickle
from xminigrid.core.constants import Colors, Tiles
from xminigrid.rendering.text_render import print_ruleset
from xminigrid.core.goals import (
    AgentHoldGoal,
    AgentNearDownGoal,
    AgentNearGoal,
    AgentNearLeftGoal,
    AgentNearRightGoal,
    AgentNearUpGoal,
    TileNearDownGoal,
    TileNearGoal,
    TileNearLeftGoal,
    TileNearRightGoal,
    TileNearUpGoal,
)
from xminigrid.core.grid import pad_along_axis
from xminigrid.core.rules import (
    AgentHoldRule,
    AgentNearDownRule,
    AgentNearLeftRule,
    AgentNearRightRule,
    AgentNearRule,
    AgentNearUpRule,
    EmptyRule,
    TileNearDownRule,
    TileNearLeftRule,
    TileNearRightRule,
    TileNearRule,
    TileNearUpRule,
)

In [2]:
COLORS = [
    Colors.RED,
    Colors.GREEN,
    Colors.BLUE,
    Colors.PURPLE,
    Colors.YELLOW,
    Colors.GREY,
    Colors.WHITE,
    Colors.BROWN,
    Colors.PINK,
    Colors.ORANGE,
]

In [3]:
# we need to distinguish between them, to avoid sampling
# near(goal, goal) goal or rule as goal tiles are not pickable
NEAR_TILES_LHS = list(
    product([Tiles.BALL, Tiles.SQUARE, Tiles.PYRAMID, Tiles.KEY, Tiles.STAR, Tiles.HEX, Tiles.GOAL], COLORS)
)
# these are pickable!
NEAR_TILES_RHS = list(product([Tiles.BALL, Tiles.SQUARE, Tiles.PYRAMID, Tiles.KEY, Tiles.STAR, Tiles.HEX], COLORS))

HOLD_TILES = list(product([Tiles.BALL, Tiles.SQUARE, Tiles.PYRAMID, Tiles.KEY, Tiles.STAR, Tiles.HEX], COLORS))

# to imitate disappearance production rule
PROD_TILES = list(product([Tiles.BALL, Tiles.SQUARE, Tiles.PYRAMID, Tiles.KEY, Tiles.STAR, Tiles.HEX], COLORS))
PROD_TILES = PROD_TILES + [(Tiles.FLOOR, Colors.BLACK)]

In [4]:
def encode(ruleset):
    flatten_encoding = jnp.concatenate([ruleset["goal"].encode(), *[r.encode() for r in ruleset["rules"]]]).tolist()
    return tuple(flatten_encoding)


def diff(list1, list2):
    return list(set(list1) - set(list2))

In [5]:
def sample_goal_tiles():
    # This function now returns just the tiles, not the goal object
    goal_idx = random.randint(0, 10)
    print(goal_idx)
    if goal_idx == 0:
        tile = random.choice(HOLD_TILES)
    elif 1 <= goal_idx <= 5:
        tile = random.choice(NEAR_TILES_LHS)
    else:
        tile_a = random.choice(NEAR_TILES_LHS)
        tile_b = random.choice(NEAR_TILES_RHS)
        return goal_idx, (tile_a, tile_b)
    return goal_idx, (tile,)

In [6]:
def sample_goal(goal_idx, goal_tiles):
    goals = (
        AgentHoldGoal,
        # Agent near variations
        AgentNearGoal,
        AgentNearUpGoal,
        AgentNearDownGoal,
        AgentNearLeftGoal,
        AgentNearRightGoal,
        # Tile near variations
        TileNearGoal,
        TileNearUpGoal,
        TileNearDownGoal,
        TileNearLeftGoal,
        TileNearRightGoal,
    )
    
    if goal_idx == 0:
        return goals[0](tile=jnp.array(goal_tiles[0]))
    elif 1 <= goal_idx <= 5:
        selected_goal = random.choice(goals[1:6])
        return selected_goal(tile=jnp.array(goal_tiles[0]))
    elif 6 <= goal_idx <= 10:
        selected_goal = random.choice(goals[6:])
        return selected_goal(tile_a=jnp.array(goal_tiles[0]), tile_b=jnp.array(goal_tiles[1]))


In [7]:
def sample_rule(prod_tile, used_tiles):
    rules = (
        AgentHoldRule,
        # agent near variations
        AgentNearRule,
        AgentNearUpRule,
        AgentNearDownRule,
        AgentNearLeftRule,
        AgentNearRightRule,
        # tile near variations
        TileNearRule,
        TileNearUpRule,
        TileNearDownRule,
        TileNearLeftRule,
        TileNearRightRule,
    )
    rule_idx = random.randint(0, 10)

    if rule_idx == 0:
        tile = random.choice(diff(HOLD_TILES, used_tiles))
        rule = rules[rule_idx](tile=jnp.array(tile), prod_tile=jnp.array(prod_tile))
        return rule, (tile,)
    elif 1 <= rule_idx <= 5:
        tile = random.choice(diff(HOLD_TILES, used_tiles))
        rule = rules[rule_idx](tile=jnp.array(tile), prod_tile=jnp.array(prod_tile))
        return rule, (tile,)
    elif 6 <= rule_idx <= 10:
        tile_a = random.choice(diff(NEAR_TILES_LHS, used_tiles))
        tile_b = random.choice(diff(NEAR_TILES_RHS, used_tiles))
        rule = rules[rule_idx](tile_a=jnp.array(tile_a), tile_b=jnp.array(tile_b), prod_tile=jnp.array(prod_tile))
        return rule, (tile_a, tile_b)
    else:
        raise RuntimeError("Unknown rule")

In [24]:
def pre_sample_rules(
    chain_depth: int, 
    num_distractor_rules: int,
    num_distractor_objects: int,
    sample_depth: bool,
    sample_distractor_rules: bool,
    prune_chain: bool,
    prune_prob: float = 0.0
):
    used_tiles = []
    rules = []
    init_tiles = []
    chain_tiles = [random.choice(PROD_TILES)]  # Start chain with a random product tile

    num_levels = random.randint(0, chain_depth) if sample_depth else chain_depth

    for level in range(num_levels):
        next_chain_tiles = []
        while chain_tiles:
            prod_tile = chain_tiles.pop()
            if prune_chain and random.random() < prune_prob:
                init_tiles.append(prod_tile)
                continue

            rule, rule_tiles = sample_rule(prod_tile, used_tiles)
            used_tiles.extend(rule_tiles)
            next_chain_tiles.extend(rule_tiles)
            rules.append(rule)

        chain_tiles = next_chain_tiles

    if sample_distractor_rules:
        for _ in range(num_distractor_rules):
            prod_tile = random.choice(diff(PROD_TILES, used_tiles))
            rule, rule_tiles = sample_rule(prod_tile, [])
            rules.append(rule)
            init_tiles.extend(rule_tiles)

    if not rules:  # Ensure there's at least one rule
        rules.append(EmptyRule())

    return rules, used_tiles, init_tiles

In [9]:
def sample_ruleset(goal_idx, goal_tiles, pre_sampled_rules, pre_sampled_init_tiles, num_distractor_objects):
    goal = sample_goal(goal_idx, goal_tiles)  # Dynamic goal sampling
    init_tiles = list(pre_sampled_init_tiles)  # Start with pre-sampled initial tiles

    # Adding goal tiles to initial tiles ensuring they don't overlap with used tiles
    init_tiles.extend([tile for tile in goal_tiles if tile not in init_tiles])

    # Adding a few more distractor objects if needed
    init_tiles.extend(random.choices(diff(NEAR_TILES_LHS, init_tiles), k=num_distractor_objects))

    return {
        "goal": goal,
        "rules": pre_sampled_rules,  # Use pre-sampled rules
        "init_tiles": init_tiles,
        "num_rules": len([r for r in pre_sampled_rules if not isinstance(r, EmptyRule)]),
    }

In [11]:
goal_idx, goal_tiles = sample_goal_tiles()

5


In [22]:
# Example of using the modified functions
pre_sampled_rules, used_tiles, pre_sampled_init_tiles = pre_sample_rules(chain_depth=3, 
                                                                         num_distractor_rules=0, 
                                                                         num_distractor_objects=0, 
                                                                         sample_depth=False, 
                                                                         sample_distractor_rules=False, 
                                                                         prune_chain=False, 
                                                                         prune_prob=0.0)

rulesets = []
for _ in range(5):
    ruleset = sample_ruleset(goal_idx, goal_tiles, pre_sampled_rules, pre_sampled_init_tiles, num_distractor_objects=0)
    rulesets.append({
                "goal": ruleset["goal"].encode(),
                "rules": jnp.vstack([r.encode() for r in ruleset["rules"]]),
                "init_tiles": jnp.array(ruleset["init_tiles"], dtype=jnp.uint8),
                "num_rules": jnp.asarray(ruleset["num_rules"], dtype=jnp.uint8),
    })

In [23]:
RuleSet = namedtuple('RuleSet', ['goal', 'rules', 'init_tiles'])

for ruleset in rulesets:
    # Convert dictionary to namedtuple
    ruleset_nt = RuleSet(
        goal=ruleset['goal'],
        rules=ruleset['rules'],
        init_tiles=ruleset['init_tiles']
    )
    print_ruleset(ruleset_nt)
    print("\n===============================")

GOAL:
AgentNearRightGoal(blue pyramid)

RULES:
AgentHold(pink hexagon) -> purple square
AgentNearDownRule(brown hexagon) -> pink hexagon
TileNear(yellow square, white star) -> brown hexagon

INIT TILES:
blue pyramid

GOAL:
AgentNear(blue pyramid)

RULES:
AgentHold(pink hexagon) -> purple square
AgentNearDownRule(brown hexagon) -> pink hexagon
TileNear(yellow square, white star) -> brown hexagon

INIT TILES:
blue pyramid

GOAL:
AgentNear(blue pyramid)

RULES:
AgentHold(pink hexagon) -> purple square
AgentNearDownRule(brown hexagon) -> pink hexagon
TileNear(yellow square, white star) -> brown hexagon

INIT TILES:
blue pyramid

GOAL:
AgentNearRightGoal(blue pyramid)

RULES:
AgentHold(pink hexagon) -> purple square
AgentNearDownRule(brown hexagon) -> pink hexagon
TileNear(yellow square, white star) -> brown hexagon

INIT TILES:
blue pyramid

GOAL:
AgentNearDownGoal(blue pyramid)

RULES:
AgentHold(pink hexagon) -> purple square
AgentNearDownRule(brown hexagon) -> pink hexagon
TileNear(yello