In [1]:
import numpy as np
import random

# Colors and object types
COLORS = {
    "red": np.array([255, 0, 0]),
    "green": np.array([0, 255, 0]),
    "blue": np.array([0, 0, 255]),
    "purple": np.array([112, 39, 195]),
    "yellow": np.array([255, 255, 0]),
    "grey": np.array([100, 100, 100]),
}
OBJECT_TYPES = ["key", "ball", "box"]

# Function to create the maze
def create_maze(colors, object_types):
    """Randomly draw the objects for one maze.

    Four distinct colors are sampled for the starting keys. Eight further
    (object_type, color) pairs, none of which duplicates a starting key, are
    then sampled and grouped two at a time.

    Args:
        colors: Sequence of color names to draw from.
        object_types: Sequence of object type names.

    Returns:
        A tuple ``(initial_keys, sets_of_pairs)``. ``initial_keys`` is a list
        of four ``("key", color)`` tuples; ``sets_of_pairs`` is a list of four
        two-element lists of ``(object_type, color)`` tuples.

    Raises:
        ValueError: Propagated from ``random.sample`` when fewer than four
            colors, or fewer than eight non-key pairs, are available.
    """
    initial_keys = []
    for key_color in random.sample(colors, 4):
        initial_keys.append(("key", key_color))

    # Walk the type x color grid (types outermost) and keep every combination
    # that is not already one of the starting keys.
    candidates = []
    for kind in object_types:
        for shade in colors:
            candidate = (kind, shade)
            if candidate not in initial_keys:
                candidates.append(candidate)

    drawn = random.sample(candidates, 8)

    # Consecutive draws are grouped into two-element lists.
    sets_of_pairs = []
    for start in range(0, len(drawn), 2):
        sets_of_pairs.append(drawn[start:start + 2])

    return initial_keys, sets_of_pairs

from pprint import pprint

# Sample one maze from the full color palette and show what was drawn.
initial_keys, sets_of_pairs = create_maze(list(COLORS), OBJECT_TYPES)

print("Initial Keys:")
pprint(initial_keys)

print("Sets of Pairs:")
pprint(sets_of_pairs)


Initial Keys:
[('key', 'yellow'), ('key', 'purple'), ('key', 'blue'), ('key', 'grey')]
Sets of Pairs:
[[('ball', 'yellow'), ('ball', 'purple')],
 [('key', 'red'), ('key', 'green')],
 [('box', 'grey'), ('ball', 'green')],
 [('box', 'green'), ('ball', 'grey')]]


In [2]:
# this one has the test pairs overlap in shape/color with the keys/training pairs
def create_maze_overlap(colors, object_types):
    """Draw a maze whose second pair members reuse already-seen attributes.

    Four distinct colors are sampled for the starting keys. Four "first"
    pairs are then sampled from every (object_type, color) combination that is
    not a starting key. Four "second" pairs are sampled from combinations of
    the object types and colors that appear among the starting keys and the
    first pairs, excluding those already used. First and second pairs are
    zipped into four 2-tuples.

    Some first-round draws leave fewer than four unused combinations (for
    example, when all first pairs share a single non-key type and introduce
    at most one new color). Such draws are discarded and the first pairs are
    redrawn, instead of letting ``random.sample`` fail.

    Args:
        colors: Sequence of hashable color names to draw from.
        object_types: Sequence of hashable object type names.

    Returns:
        A tuple ``(initial_keys, sets_of_pairs)``. ``initial_keys`` is a list
        of four ``("key", color)`` tuples; ``sets_of_pairs`` is a list of four
        ``(first_pair, second_pair)`` tuples.

    Raises:
        ValueError: If there are too few colors/pairs to sample from, or if no
            first-round draw with at least four second-round candidates is
            found within the retry budget.
    """
    # Upper bound on redraws so an impossible configuration fails loudly
    # instead of looping forever.
    max_attempts = 1000

    # Select 4 random colors for initial keys
    initial_key_colors = random.sample(colors, 4)
    initial_keys = [("key", color) for color in initial_key_colors]

    # Every (type, color) combination that is not already a starting key
    all_pairs = [(obj_type, color) for obj_type in object_types for color in colors]
    remaining_pairs = [pair for pair in all_pairs if pair not in initial_keys]

    for _ in range(max_attempts):
        # Select 4 unique pairs from the remaining pairs
        first_selected_pairs = random.sample(remaining_pairs, 4)
        observed = first_selected_pairs + initial_keys

        # dict.fromkeys de-duplicates while keeping first-seen order. A set
        # would make the candidate order (and hence the seeded result) depend
        # on string hash randomisation.
        observed_types = list(dict.fromkeys(pair[0] for pair in observed))
        observed_colors = list(dict.fromkeys(pair[1] for pair in observed))

        # Unused combinations of the observed types and colors
        already_used = set(observed)
        new_pairs_candidates = [
            (obj_type, color)
            for obj_type in observed_types
            for color in observed_colors
            if (obj_type, color) not in already_used
        ]
        if len(new_pairs_candidates) >= 4:
            break
    else:
        raise ValueError(
            "could not find 4 overlapping pairs for the given colors and "
            f"object types after {max_attempts} attempts"
        )

    # Select 4 unique pairs from new pairs candidates
    second_selected_pairs = random.sample(new_pairs_candidates, 4)

    # Zip first and second pairs to form 4 sets of 2
    sets_of_pairs = [(first_selected_pairs[i], second_selected_pairs[i]) for i in range(4)]

    return initial_keys, sets_of_pairs

from pprint import pprint

# Sample one "overlap" maze from the full color palette and show the result.
initial_keys, sets_of_pairs = create_maze_overlap(list(COLORS), OBJECT_TYPES)

print("Initial Keys:")
pprint(initial_keys)

print("Sets of Pairs:")
pprint(sets_of_pairs)


Initial Keys:
[('key', 'purple'), ('key', 'grey'), ('key', 'blue'), ('key', 'green')]
Sets of Pairs:
[(('ball', 'purple'), ('key', 'yellow')),
 (('key', 'red'), ('box', 'green')),
 (('ball', 'yellow'), ('ball', 'grey')),
 (('box', 'grey'), ('box', 'purple'))]


In [3]:
import json

def create_n_pairs_of_mazes(n, colors, object_types, create_fn=create_maze):
    """Generate ``n`` distinct maze configurations.

    Each configuration is a dict ``{'keys': ..., 'pairs': ...}`` built from
    the two values returned by ``create_fn(colors, object_types)``. A freshly
    generated configuration that is identical (same keys and same pairs) to
    one already collected is discarded and regenerated.

    Args:
        n: Number of configurations to generate.
        colors: Passed through to ``create_fn``.
        object_types: Passed through to ``create_fn``.
        create_fn: Callable returning ``(initial_keys, room_pairs)``.

    Returns:
        A list of ``n`` dicts with the keys ``'keys'`` and ``'pairs'``.

    Raises:
        ValueError: If no new, distinct configuration is produced within the
            retry budget (e.g. ``create_fn`` cannot yield ``n`` different
            results).
    """
    # Upper bound on regenerations per slot, so a generator that cannot
    # produce enough distinct mazes fails instead of looping forever.
    max_attempts_per_maze = 1000
    pairs_of_mazes = []

    for _ in range(n):
        for _attempt in range(max_attempts_per_maze):
            init_keys, room_pairs = create_fn(colors, object_types)
            candidate = {'keys': init_keys, 'pairs': room_pairs}
            # pairs_of_mazes holds dicts, so the duplicate test must compare
            # a dict. Testing the bare key/pair lists for membership can
            # never match.
            if candidate not in pairs_of_mazes:
                break
        else:
            raise ValueError(
                f"could not generate a new unique maze after "
                f"{max_attempts_per_maze} attempts"
            )

        pairs_of_mazes.append(candidate)

    return pairs_of_mazes

# ---------------------------------------------------------------
# Generator selection. Exactly one (json_file, create_fn) pair is active.
# ---------------------------------------------------------------

# Active: shared attributes
json_file = 'maze_pairs_shared.json'
create_fn = create_maze_overlap

# Alternative: regular (uncomment both lines to use it instead)
# json_file = 'maze_pairs.json'
# create_fn = create_maze

# Example usage
n = 5  # Number of pairs of mazes to generate
maze_pairs = create_n_pairs_of_mazes(
    n,
    list(COLORS),
    OBJECT_TYPES,
    create_fn=create_fn,
)
pprint(maze_pairs)

# Save to a JSON file. JSON has no tuple type, so the (type, color) tuples
# are written as arrays and come back as lists when reloaded.
with open(json_file, 'w') as file:
    json.dump(maze_pairs, file, indent=4)

print(f'Maze pairs saved to {json_file}')


[{'keys': [('key', 'grey'),
           ('key', 'green'),
           ('key', 'purple'),
           ('key', 'blue')],
  'pairs': [(('key', 'red'), ('ball', 'grey')),
            (('box', 'blue'), ('box', 'red')),
            (('ball', 'blue'), ('ball', 'green')),
            (('box', 'purple'), ('box', 'green'))]},
 {'keys': [('key', 'purple'), ('key', 'blue'), ('key', 'grey'), ('key', 'red')],
  'pairs': [(('ball', 'yellow'), ('box', 'grey')),
            (('box', 'yellow'), ('ball', 'blue')),
            (('ball', 'grey'), ('ball', 'red')),
            (('box', 'purple'), ('key', 'yellow'))]},
 {'keys': [('key', 'purple'),
           ('key', 'blue'),
           ('key', 'green'),
           ('key', 'grey')],
  'pairs': [(('ball', 'grey'), ('ball', 'purple')),
            (('ball', 'yellow'), ('ball', 'blue')),
            (('ball', 'green'), ('key', 'red')),
            (('ball', 'red'), ('key', 'yellow'))]},
 {'keys': [('key', 'yellow'), ('key', 'blue'), ('key', 'red'), ('key', 'grey')

In [None]:
from projects.human_sf import key_room_v3 as key_room
import json 

# Load previously generated maze configurations from disk.
# NOTE(review): the generation cell, as currently configured, writes
# 'maze_pairs_shared.json'. Confirm that 'maze_pairs.json' exists before
# running this cell on a fresh checkout.
json_file = 'maze_pairs.json'
with open(json_file) as file:
    mazes = json.load(file)

# Inspect the first configuration.
maze = mazes[0]
pprint(maze)

import minigrid
import matplotlib.pyplot as plt

# Build the key-room environment from the first loaded maze configuration.
# The commented-out keyword arguments are kept for reference. Presumably they
# are optional KeyRoom settings; confirm against key_room_v3 before enabling.
env = key_room.KeyRoom(
    maze_config=maze,
    # training=True,
    # swap_episodes=1_000_000,
    # num_dists=0,
    # max_steps_per_room=1_000_000,
    # num_task_rooms=2,
    # basic_only=0,
    # color_rooms=False,
    # flat_task=True
)

def prep_ax(ax):
    """Clear ``ax`` and remove its ticks and tick labels on both axes.

    Args:
        ax: Axes-like object exposing ``clear``, ``set_xticks``,
            ``set_yticks``, ``set_xticklabels`` and ``set_yticklabels``.
    """
    # Drop whatever was drawn before.
    ax.clear()

    # Empty the tick positions first, then the tick labels (x before y).
    for setter_name in ("set_xticks", "set_yticks", "set_xticklabels", "set_yticklabels"):
        getattr(ax, setter_name)([])

def display(obs, reward, prime=False):
    """Show one observation in a small figure titled with its reward info.

    The title contains the state label, the given reward, a reward recomputed
    from the observation, and the observation's state features.

    Args:
        obs: Mapping with ``'image'``, ``'state_features'`` and ``'task'``
            entries.
        reward: Reward value shown in the title as ``r``.
        prime: When true, the state label is ``s'`` instead of ``s``.
    """
    # Reward recomputed from the observation: element-wise product of the
    # state features and the task, summed over the last axis.
    computed_r = (obs['state_features'] * obs['task']).sum(-1)

    if prime:
        state_label = "s" + "'"
    else:
        state_label = "s"

    title_parts = [
        state_label,
        f" | r={reward}, r_={computed_r}",
        f"\n {str(obs['state_features'])}",
    ]

    _, ax = plt.subplots(1, 1, figsize=(3, 3))
    prep_ax(ax)
    # imshow needs an (H, W) or channel-last (H, W, C) array. Presumably
    # obs['image'] is already in that layout; confirm.
    ax.imshow(obs['image'])
    ax.set_title("".join(title_parts))
    plt.show()
    plt.pause(0.1)  # Brief pause so the figure gets a chance to render

# Reset the environment and render the first observation.
obs, info = env.reset()

# One small figure with a single, tick-free axis.
fig, ax = plt.subplots(figsize=(3, 3))
prep_ax(ax)
# imshow needs an (H, W) or channel-last (H, W, C) array. Presumably
# obs['image'] is already in that layout; confirm.
ax.imshow(obs['image'])
plt.show()