In [None]:
import time

from IPython.display import display
import jax

from atlas.envs.xminigrid.labeling_function import XMinigridLabelingFunction
from atlas.envs.xminigrid.level_sampling.single_room import XMinigridSingleRoomLevelSampler
from atlas.envs.xminigrid.level_sampling.two_rooms import XMinigridTwoRoomsLevelSampler
from atlas.envs.xminigrid.level_sampling.four_rooms import XMinigridFourRoomsLevelSampler
from atlas.envs.xminigrid.level_sampling.six_rooms import XMinigridSixRoomsLevelSampler
from atlas.envs.xminigrid.level_sampling.meta import XMinigridMetaLevelSampler
from atlas.envs.xminigrid.renderer import XMinigridRenderer
from atlas.envs.xminigrid.problem_sampling.hrm_conditioned import XMinigridHRMConditionedProblemSampler
from atlas.envs.xminigrid.problem_sampling.level_conditioned import XMinigridLevelConditionedProblemSampler
from atlas.envs.xminigrid.types import XMinigridEnvParams
from atlas.hrm.sampling.meta import MetaHRMSampler
from atlas.hrm.sampling.random_walk import RandomWalkHRMSampler
from atlas.hrm.sampling.single_path_flat import SinglePathFlatHRMSampler
from atlas.hrm.visualization import render_to_img
from atlas.problem_samplers.independent import IndependentProblemSampler

In [None]:
# Shared environment objects used by every benchmark cell below.
renderer = XMinigridRenderer()
env_params = XMinigridEnvParams(height=19, width=19)

# Labeling function plus its string alphabet (the alphabet is passed to
# `render_to_img` when drawing HRMs).
label_fn = XMinigridLabelingFunction(env_params)
alphabet = label_fn.get_str_alphabet()

# Meta level sampler wrapping the single/two/four/six-room samplers, all built
# from the same env_params (constructed in this order).
level_sampler = XMinigridMetaLevelSampler(
    env_params,
    [
        room_sampler_cls(env_params)
        for room_sampler_cls in (
            XMinigridSingleRoomLevelSampler,
            XMinigridTwoRoomsLevelSampler,
            XMinigridFourRoomsLevelSampler,
            XMinigridSixRoomsLevelSampler,
        )
    ],
)

In [None]:
# Which HRM sampler to benchmark; supported values: "meta", "random_walk".
hrm_sampler_name = "meta"

if hrm_sampler_name == "meta":
    # Capacity limits shared by the meta sampler and every sub-sampler it wraps.
    hrm_sampler_args = dict(
        max_num_rms=1,
        max_num_states=5,
        max_num_edges=1,
        max_num_literals=5,
        alphabet_size=label_fn.get_alphabet_size(),
    )

    # One SinglePathFlatHRMSampler per path length (1, 2, 3, 4 transitions),
    # each with reward_on_acceptance_only=True.
    # NOTE(review): presumably max_num_states=5 is what allows the longest
    # (4-transition) path — confirm if these limits are changed.
    hrm_sampler = MetaHRMSampler(
        **hrm_sampler_args,
        samplers=[
            SinglePathFlatHRMSampler(
                **hrm_sampler_args, num_transitions=num_transitions, reward_on_acceptance_only=True
            )
            for num_transitions in range(1, 5)
        ]
    )
elif hrm_sampler_name == "random_walk":
    hrm_sampler = RandomWalkHRMSampler(
        max_num_rms=1,
        max_num_gen_rms=1,
        max_num_states=6,
        max_num_gen_states=6,
        max_num_edges=1,
        max_num_literals=5,
        alphabet_size=label_fn.get_alphabet_size(),
        alphabet=label_fn.get_str_alphabet(),
        enforce_mutex=True,
        enforce_sequentiality=True,
        splittiness=0.5,
        use_transition_compat_matrix=False,
        use_call_compat_matrix=False,
        eps=1.0,
        reward_shaping=False,
        gamma=0.96,
    )
else:
    # Fail loudly: otherwise `hrm_sampler` is either undefined (NameError in a
    # later cell) or silently left over from a previous run of this kernel.
    raise ValueError(
        f"Unknown hrm_sampler_name {hrm_sampler_name!r}; expected 'meta' or 'random_walk'."
    )

In [None]:
def speed_test(problem_sampler_cls, num_samples=4096, seed=0):
    """Time one batched call of a problem sampler's ``sample`` method.

    Builds ``problem_sampler_cls(level_sampler, hrm_sampler, label_fn)`` from the
    notebook-level objects, wraps its ``sample`` in ``jax.jit(jax.vmap(...))``,
    runs it once to trigger compilation, then times a second call on the same keys.

    Args:
        problem_sampler_cls: Problem-sampler class constructed as
            ``problem_sampler_cls(level_sampler, hrm_sampler, label_fn)``.
        num_samples: Number of PRNG keys (i.e. problems) in the batch. Defaults to 4096.
        seed: Seed of the PRNG key that is split into the per-sample keys. Defaults to 0.

    Returns:
        ``(duration, levels, hrms)``: seconds taken by the timed call (including
        device execution) and the batched outputs of that call.
    """
    problem_sampler = problem_sampler_cls(level_sampler, hrm_sampler, label_fn)
    sample_fn = jax.jit(jax.vmap(problem_sampler.sample))
    keys = jax.random.split(jax.random.PRNGKey(seed), num_samples)

    # Warm-up: the first call traces and compiles. Block on it so its device work
    # cannot overlap with, and inflate, the timed run below.
    jax.block_until_ready(sample_fn(keys))

    # JAX dispatches asynchronously: without blocking on the result, the timer
    # would measure only dispatch, not the sampling itself.
    start = time.perf_counter()
    levels, hrms = jax.block_until_ready(sample_fn(keys))
    duration = time.perf_counter() - start

    return duration, levels, hrms

In [None]:
def render(levels, hrms, n=1):
    """Show the first ``n`` problems of a batch, each as a level image then an HRM image.

    Args:
        levels: Batched levels; every leaf is indexed along its leading axis.
        hrms: Batched HRMs, indexed the same way as ``levels``.
        n: How many problems to show, starting at index 0.
    """
    batch = (levels, hrms)
    for idx in range(n):
        print(f"Problem {idx}")
        level, hrm = jax.tree_util.tree_map(lambda leaf: leaf[idx], batch)
        level_img = renderer.render_level(level)
        display(level_img)
        hrm_img = render_to_img(hrm, alphabet=alphabet)
        display(hrm_img)

In [None]:
# Independent problem sampler: time one batch, report it, show the first problem.
t, levels, hrms = speed_test(problem_sampler_cls=IndependentProblemSampler)
print("independent", t)
render(levels, hrms, n=1)

In [None]:
# Level-conditioned problem sampler: time one batch, report it, show the first problem.
t, levels, hrms = speed_test(problem_sampler_cls=XMinigridLevelConditionedProblemSampler)
print("level-conditioned", t)
render(levels, hrms, n=1)

In [None]:
# HRM-conditioned problem sampler: time one batch, report it, show the first problem.
t, levels, hrms = speed_test(problem_sampler_cls=XMinigridHRMConditionedProblemSampler)
print("hrm-conditioned", t)
render(levels, hrms, n=1)