### Utils for printing

In [None]:
import jax
import jax.numpy as jnp

from xminigrid.core.constants import Colors, Tiles
from xminigrid.types import AgentState, RuleSet

In [None]:
# Colour name for each Colors id. `render` passes these to `_wrap_with_color`,
# which embeds them in `[bold <name>]...[/bold <name>]` markup tags; `_encode_tile`
# uses them as the colour word in goal/rule text.
# NOTE(review): Colors.EMPTY and Colors.WHITE both map to "white".
COLOR_NAMES = {
    Colors.EMPTY: "white",
    Colors.RED: "red",
    Colors.GREEN: "green",
    Colors.BLUE: "blue",
    Colors.PURPLE: "purple",
    Colors.YELLOW: "yellow",
    Colors.GREY: "grey",
    Colors.BLACK: "black",
    Colors.ORANGE: "orange",
    Colors.WHITE: "white",
    Colors.BROWN: "brown",
    Colors.PINK: "pink",
}

# One-character glyph per tile id, used by `render` for the grid view.
TILE_STR = {
    Tiles.EMPTY: " ",
    Tiles.FLOOR: ".",
    Tiles.WALL: "☰",
    Tiles.BALL: "⏺",
    Tiles.SQUARE: "▪",
    Tiles.PYRAMID: "▲",
    Tiles.HEX: "⬢",
    Tiles.STAR: "★",
    Tiles.GOAL: "■",
    Tiles.DOOR_LOCKED: "L",
    Tiles.DOOR_CLOSED: "C",
    Tiles.DOOR_OPEN: "O",
    Tiles.KEY: "K",
}

# Tile names used by `_encode_tile` when printing goals, rules and init tiles.
# Tile ids not listed here (e.g. EMPTY, WALL, doors) raise KeyError when encoded.
RULE_TILE_STR = {
    Tiles.FLOOR: "floor",
    Tiles.BALL: "ball",
    Tiles.SQUARE: "square",
    Tiles.PYRAMID: "pyramid",
    Tiles.GOAL: "goal",
    Tiles.KEY: "key",
    Tiles.HEX: "hexagon",
    Tiles.STAR: "star",
}

# Agent glyph, looked up by `agent.direction` in `render`. Presumably
# 0=up, 1=right, 2=down, 3=left (inferred from the arrow glyphs) — TODO confirm
# against AgentState.
PLAYER_STR = {0: "^", 1: ">", 2: "V", 3: "<"}

In [None]:
def _wrap_with_color(string: str, color: str) -> str:
    return f"[bold {color}]{string}[/bold {color}]"


# WARN: will NOT work under jit and needed for debugging mainly.
def render(grid: jax.Array, agent: AgentState | None = None) -> str:
    """Render a grid as colour-markup-wrapped glyphs, one text line per grid row.

    Args:
        grid: array indexed as ``grid[y, x]`` whose last axis holds the pair
            ``(tile_id, color_id)``.
        agent: optional agent state; the cell whose ``(y, x)`` equals
            ``agent.position`` is drawn with the ``PLAYER_STR`` glyph for
            ``agent.direction``, in red.

    Returns:
        Rows joined by ``"\\n"`` (no trailing newline). Every cell is wrapped by
        ``_wrap_with_color``.
    """
    # Move data to the host once. Per-cell `.item()` calls and a fresh
    # `jnp.array((y, x))` comparison for every cell each dispatch a separate
    # device op, which is very slow for a whole grid.
    agent_pos = tuple(agent.position.tolist()) if agent is not None else None

    lines = []
    for y, row in enumerate(grid.tolist()):
        cells = []
        for x, (tile_id, color_id) in enumerate(row):
            if agent_pos == (y, x):
                tile_str = PLAYER_STR[agent.direction.item()]
                color_name = COLOR_NAMES[Colors.RED]
            else:
                tile_str = TILE_STR[tile_id]
                color_name = COLOR_NAMES[color_id]
            cells.append(_wrap_with_color(tile_str, color_name))
        lines.append("".join(cells))

    return "\n".join(lines)

In [None]:
# WARN: This is for debugging mainly! Will refactor later if needed.
def _encode_tile(tile: list[int]) -> str:
    """Describe a ``(tile_id, color_id)`` pair as ``"<color> <tile>"``, e.g. ``"red ball"``.

    Raises:
        KeyError: if the colour id is missing from ``COLOR_NAMES`` or the tile id
            is missing from ``RULE_TILE_STR`` (the colour is looked up first).
    """
    color_word = COLOR_NAMES[tile[1]]
    tile_word = RULE_TILE_STR[tile[0]]
    return f"{color_word} {tile_word}"

In [None]:
def _text_encode_goal(goal: list[int]) -> str:
    goal_id = goal[0]
    if goal_id == 1:
        return f"AgentHold({_encode_tile(goal[1:3])})"
    elif goal_id == 3:
        return f"AgentNear({_encode_tile(goal[1:3])})"
    elif goal_id == 4:
        return f"TileNear({_encode_tile(goal[1:3])}, {_encode_tile(goal[3:5])})"
    elif goal_id == 7:
        return f"TileNearUpGoal({_encode_tile(goal[1:3])}, {_encode_tile(goal[3:5])})"
    elif goal_id == 8:
        return f"TileNearRightGoal({_encode_tile(goal[1:3])}, {_encode_tile(goal[3:5])})"
    elif goal_id == 9:
        return f"TileNearDownGoal({_encode_tile(goal[1:3])}, {_encode_tile(goal[3:5])})"
    elif goal_id == 10:
        return f"TileNearLeftGoal({_encode_tile(goal[1:3])}, {_encode_tile(goal[3:5])})"
    elif goal_id == 11:
        return f"AgentNearUpGoal({_encode_tile(goal[1:3])})"
    elif goal_id == 12:
        return f"AgentNearRightGoal({_encode_tile(goal[1:3])})"
    elif goal_id == 13:
        return f"AgentNearDownGoal({_encode_tile(goal[1:3])})"
    elif goal_id == 14:
        return f"AgentNearLeftGoal({_encode_tile(goal[1:3])})"
    else:
        raise RuntimeError(f"Rendering: Unknown goal id: {goal_id}")

In [None]:
def _text_encode_rule(rule: list[int]) -> str:
    rule_id = rule[0]
    if rule_id == 1:
        return f"AgentHold({_encode_tile(rule[1:3])}) -> {_encode_tile(rule[3:5])}"
    elif rule_id == 2:
        return f"AgentNear({_encode_tile(rule[1:3])}) -> {_encode_tile(rule[3:5])}"
    elif rule_id == 3:
        return f"TileNear({_encode_tile(rule[1:3])}, {_encode_tile(rule[3:5])}) -> {_encode_tile(rule[5:7])}"
    elif rule_id == 4:
        return f"TileNearUpRule({_encode_tile(rule[1:3])}, {_encode_tile(rule[3:5])}) -> {_encode_tile(rule[5:7])}"
    elif rule_id == 5:
        return f"TileNearRightRule({_encode_tile(rule[1:3])}, {_encode_tile(rule[3:5])}) -> {_encode_tile(rule[5:7])}"
    elif rule_id == 6:
        return f"TileNearDownRule({_encode_tile(rule[1:3])}, {_encode_tile(rule[3:5])}) -> {_encode_tile(rule[5:7])}"
    elif rule_id == 7:
        return f"TileNearLeftRule({_encode_tile(rule[1:3])}, {_encode_tile(rule[3:5])}) -> {_encode_tile(rule[5:7])}"
    elif rule_id == 8:
        return f"AgentNearUpRule({_encode_tile(rule[1:3])}) -> {_encode_tile(rule[3:5])}"
    elif rule_id == 9:
        return f"AgentNearRightRule({_encode_tile(rule[1:3])}) -> {_encode_tile(rule[3:5])}"
    elif rule_id == 10:
        return f"AgentNearDownRule({_encode_tile(rule[1:3])}) -> {_encode_tile(rule[3:5])}"
    elif rule_id == 11:
        return f"AgentNearLeftRule({_encode_tile(rule[1:3])}) -> {_encode_tile(rule[3:5])}"
    else:
        raise RuntimeError(f"Rendering: Unknown rule id: {rule_id}")

In [None]:
def print_ruleset(ruleset: RuleSet) -> None:
    """Print a ruleset's goal, its non-empty rules and its non-empty initial tiles.

    Rule and tile rows whose first entry is 0 are treated as padding and skipped.
    ``ruleset.goal``, ``ruleset.rules`` and ``ruleset.init_tiles`` must support
    ``.tolist()``.
    """
    print("GOAL:")
    print(_text_encode_goal(ruleset.goal.tolist()))
    print()

    print("RULES:")
    active_rules = (row for row in ruleset.rules.tolist() if row[0] != 0)
    for rule_text in map(_text_encode_rule, active_rules):
        print(rule_text)
    print()

    print("INIT TILES:")
    active_tiles = (row for row in ruleset.init_tiles.tolist() if row[0] != 0)
    for tile_text in map(_encode_tile, active_tiles):
        print(tile_text)

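### Utils for generating

Note: the import cell below re-imports `_encode_tile`, `_text_encode_goal`, `_text_encode_rule` and `print_ruleset` from `xminigrid.rendering.text_render`. From that point on, the library versions shadow the helpers defined in the printing section above.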

In [1]:
import random
from itertools import product
from collections import namedtuple

import jax.numpy as jnp
from tqdm.auto import tqdm, trange
from xminigrid.benchmarks import save_bz2_pickle
from xminigrid.core.constants import Colors, Tiles
from xminigrid.rendering.text_render import _encode_tile, _text_encode_goal, _text_encode_rule, print_ruleset
from xminigrid.core.goals import (
    AgentHoldGoal,
    AgentNearDownGoal,
    AgentNearGoal,
    AgentNearLeftGoal,
    AgentNearRightGoal,
    AgentNearUpGoal,
    TileNearDownGoal,
    TileNearGoal,
    TileNearLeftGoal,
    TileNearRightGoal,
    TileNearUpGoal,
)
from xminigrid.core.grid import pad_along_axis
from xminigrid.core.rules import (
    AgentHoldRule,
    AgentNearDownRule,
    AgentNearLeftRule,
    AgentNearRightRule,
    AgentNearRule,
    AgentNearUpRule,
    EmptyRule,
    TileNearDownRule,
    TileNearLeftRule,
    TileNearRightRule,
    TileNearRule,
    TileNearUpRule,
)

In [2]:
# Colours that sampled goal / rule tiles may take.
COLORS = [
    Colors.RED,
    Colors.GREEN,
    Colors.BLUE,
    Colors.PURPLE,
    Colors.YELLOW,
    Colors.GREY,
    Colors.WHITE,
    Colors.BROWN,
    Colors.PINK,
    Colors.ORANGE,
]

# Near-tile candidates are split into LHS / RHS pools so that near(goal, goal) is
# never sampled for a goal or a rule. Goal tiles cannot be picked up, so only the
# first (LHS) tile of a near-pair may be a goal tile.
NEAR_TILES_LHS = list(
    product([Tiles.BALL, Tiles.SQUARE, Tiles.PYRAMID, Tiles.KEY, Tiles.STAR, Tiles.HEX, Tiles.GOAL], COLORS)
)
# Pickable tiles only (no Tiles.GOAL).
NEAR_TILES_RHS = list(product([Tiles.BALL, Tiles.SQUARE, Tiles.PYRAMID, Tiles.KEY, Tiles.STAR, Tiles.HEX], COLORS))

# Tiles the agent can hold (pickable only, no Tiles.GOAL).
HOLD_TILES = list(product([Tiles.BALL, Tiles.SQUARE, Tiles.PYRAMID, Tiles.KEY, Tiles.STAR, Tiles.HEX], COLORS))

# Tiles a rule may produce. (FLOOR, BLACK) is appended so that a rule can
# imitate making its input tile disappear.
PROD_TILES = list(product([Tiles.BALL, Tiles.SQUARE, Tiles.PYRAMID, Tiles.KEY, Tiles.STAR, Tiles.HEX], COLORS))
PROD_TILES = PROD_TILES + [(Tiles.FLOOR, Colors.BLACK)]

# All object tiles, including goal tiles. Originally meant for randomly choosing
# a prod_tile when generating subtrees; also used by the demo cell below to
# sample goal tiles.
# NOTE(review): this list is identical to NEAR_TILES_LHS. Unlike PROD_TILES, it
# includes goal tiles and omits (FLOOR, BLACK).
ALL_TILES = list(
    product([Tiles.BALL, Tiles.SQUARE, Tiles.PYRAMID, Tiles.KEY, Tiles.STAR, Tiles.HEX, Tiles.GOAL], COLORS)
)

In [3]:
def encode(ruleset):
    """Flatten a ruleset dict into one hashable tuple of Python scalars.

    The goal's encoding comes first, then each rule's encoding in list order, so
    equal rulesets map to equal tuples (presumably used to de-duplicate sampled
    rulesets — confirm with callers).
    """
    pieces = [ruleset["goal"].encode()]
    pieces += [rule.encode() for rule in ruleset["rules"]]
    return tuple(jnp.concatenate(pieces).tolist())

def diff(list1, list2):
    """Return the elements of ``list1`` that do not occur in ``list2``.

    Duplicates are collapsed, and the result follows set iteration order rather
    than the order of ``list1``.
    """
    kept = set(list1)
    dropped = set(list2)
    return [*(kept - dropped)]

In [52]:
# Helper for `sample_goal`. Goal tiles (Tiles.GOAL) cannot be picked up, so:
#   * for a single-tile goal, AgentHoldGoal (goal_idx 0) is excluded and only the
#     AgentNear* goals (1 <= goal_idx <= 5) are allowed;
#   * for a TileNear* goal, only tile_a may be a goal tile.
# Goal tiles appear in NEAR_TILES_LHS and ALL_TILES, but not in NEAR_TILES_RHS,
# HOLD_TILES or PROD_TILES.
def is_goal_tile(tile):
    """Return True if ``tile`` (a ``(tile_id, color_id)`` pair) is a goal tile.

    Compares against ``Tiles.GOAL`` rather than a hard-coded id, so the check
    stays correct if the tile encoding changes.
    """
    return tile[0] == Tiles.GOAL

In [63]:
def sample_goal(goal_tiles):
    """Sample a random goal over one or two given tiles.

    Args:
        goal_tiles: a sequence of one or two ``(tile_id, color_id)`` tuples.
            * One tile: AgentHoldGoal or one of the AgentNear* goals.
              AgentHoldGoal is never chosen for a goal tile, which cannot be held.
            * Two tiles: one of the TileNear* goals. If only the second tile is
              a goal tile, the pair is reordered so that the goal tile is
              ``tile_a``.

    Returns:
        ``(goal, tiles)``, where ``tiles`` is the tuple of tiles passed to the goal
        constructor, in the order used.

    Raises:
        RuntimeError: if both tiles of a pair are goal tiles, or if
            ``goal_tiles`` does not hold exactly one or two tiles.
    """
    goals = (
        AgentHoldGoal,
        # agent near variations
        AgentNearGoal,
        AgentNearUpGoal,
        AgentNearDownGoal,
        AgentNearLeftGoal,
        AgentNearRightGoal,
        # tile near variations
        TileNearGoal,
        TileNearUpGoal,
        TileNearDownGoal,
        TileNearLeftGoal,
        TileNearRightGoal,
    )

    if len(goal_tiles) == 1:
        print("Generating an AgentHoldGoal or AgentNearXXX goal.")
        tile = goal_tiles[0]
        # Goal tiles cannot be held, so skip AgentHoldGoal (index 0) for them.
        goal_idx = random.randint(1, 5) if is_goal_tile(tile) else random.randint(0, 5)
        goal = goals[goal_idx](tile=jnp.array(tile))
        return goal, (tile,)

    if len(goal_tiles) == 2:
        print("Generating a TileNearXXX goal.")
        tile_a, tile_b = goal_tiles
        if is_goal_tile(tile_b):
            if is_goal_tile(tile_a):
                # near(goal, goal): neither tile can be picked up, so the agent
                # cannot bring them together.
                raise RuntimeError("Incompatible goal tile types.")
            # Only tile_a may be a (non-pickable) goal tile, so move it first.
            # The goal type is sampled independently of the tile order.
            tile_a, tile_b = tile_b, tile_a
        goal_idx = random.randint(6, 10)
        goal = goals[goal_idx](tile_a=jnp.array(tile_a), tile_b=jnp.array(tile_b))
        return goal, (tile_a, tile_b)

    raise RuntimeError("Unknown goal.")

In [80]:
# Sample a tile pair for a TileNear* goal demo. tile_b comes from NEAR_TILES_RHS
# (pickable tiles only), so the pair follows the LHS/RHS convention that only
# tile_a may be a goal tile, and `sample_goal` never raises for it.
goal_tiles = random.choice(NEAR_TILES_LHS), random.choice(NEAR_TILES_RHS)
tile_a, tile_b = goal_tiles
_encode_tile(tile_a), _encode_tile(tile_b)

('pink goal', 'brown pyramid')

In [83]:
# With the two-tile `goal_tiles` from the cell above, this builds a TileNear* goal.
# `goal_tiles` is rebound to the tuple of tiles the goal actually uses.
goal, goal_tiles = sample_goal(goal_tiles)

Generating a TileNearXXX goal.


In [84]:
# Render the sampled goal's encoded vector as text to inspect it.
# NOTE: `_text_encode_goal` here resolves to the version imported from
# xminigrid.rendering.text_render in the import cell, not the local definition.
_text_encode_goal(goal.encode().tolist())

'TileNearUpGoal(pink goal, brown pyramid)'

In [None]:
def sample_rule(prod_tile, used_tiles):
    """Sample a random production rule whose output is ``prod_tile``.

    Input tiles are drawn only from tiles not in ``used_tiles``:
      * AgentHoldRule and the AgentNear* rules take one tile from HOLD_TILES;
      * the TileNear* rules take ``tile_a`` from NEAR_TILES_LHS and ``tile_b``
        from NEAR_TILES_RHS.

    Returns:
        ``(rule, input_tiles)``, where ``input_tiles`` is a tuple of the sampled
        input tile(s).
    """
    rules = (
        AgentHoldRule,
        # agent near variations
        AgentNearRule,
        AgentNearUpRule,
        AgentNearDownRule,
        AgentNearLeftRule,
        AgentNearRightRule,
        # tile near variations
        TileNearRule,
        TileNearUpRule,
        TileNearDownRule,
        TileNearLeftRule,
        TileNearRightRule,
    )
    rule_idx = random.randint(0, 10)

    if 0 <= rule_idx <= 5:
        # Single-input rules: AgentHoldRule and every AgentNear* rule.
        held = random.choice(diff(HOLD_TILES, used_tiles))
        new_rule = rules[rule_idx](tile=jnp.array(held), prod_tile=jnp.array(prod_tile))
        return new_rule, (held,)

    if 6 <= rule_idx <= 10:
        # Two-input TileNear* rules; only the LHS tile may be a goal tile.
        lhs = random.choice(diff(NEAR_TILES_LHS, used_tiles))
        rhs = random.choice(diff(NEAR_TILES_RHS, used_tiles))
        new_rule = rules[rule_idx](tile_a=jnp.array(lhs), tile_b=jnp.array(rhs), prod_tile=jnp.array(prod_tile))
        return new_rule, (lhs, rhs)

    raise RuntimeError("Unknown rule")

In [None]:
# Lightweight record exposing the attributes `print_ruleset` reads
# (`goal`, `rules`, `init_tiles`).
# NOTE(review): this rebinds `RuleSet`, shadowing the `xminigrid.types.RuleSet`
# imported in the printing section.
RuleSet = namedtuple("RuleSet", ("goal", "rules", "init_tiles"))


def print_rulesets(rulesets):
    """Print every ruleset in ``rulesets``, each followed by a separator line.

    Each item must be a mapping with ``"goal"``, ``"rules"`` and ``"init_tiles"``
    keys whose values support ``.tolist()`` (as ``print_ruleset`` requires). The
    item is wrapped in a ``RuleSet`` namedtuple so the fields can be read as
    attributes.
    """
    for entry in rulesets:
        record = RuleSet(*(entry[field] for field in RuleSet._fields))
        print_ruleset(record)
        print("\n===============================")