In [None]:
from IPython.display import display
import jax

import atlas
from atlas.envs.xminigrid.types import XMinigridEnvParams
from atlas.envs.xminigrid.labeling_function import XMinigridLabelingFunction
from atlas.hrm.sampling.random_walk import RandomWalkHRMSampler
from atlas.hrm.sampling.single_path_flat import SinglePathFlatHRMSampler
from atlas.hrm.visualization import render_to_img
from atlas.utils.plotting_utils import plotly_avg_states_per_rm_sampling

from pathlib import Path
# Directory of the installed `atlas` package (first entry of its __path__).
PROJECT_DIR = Path(atlas.__path__[0])

# Sampler to build in the sampler cell: "random_walk" or "sequential".
SAMPLER_TYPE = "sequential"
# How many HRMs to draw in one batched (vmapped) sampling call.
NUM_SAMPLES = 5
# If True, render every sampled HRM as an image.
DISPLAY_SAMPLES = True

In [None]:
# Default XMinigrid environment parameters; the labeling function is built from them.
env_params = XMinigridEnvParams()
labeling_function = XMinigridLabelingFunction(env_params)

# String alphabet of the labeling function, passed to the HRM renderer later on.
alphabet = labeling_function.get_str_alphabet()

In [None]:
# Build the sampler selected by SAMPLER_TYPE, then jit + vmap its `sample`
# method so that a single call maps a batch of PRNG keys to a batch of HRMs.
if SAMPLER_TYPE == "random_walk":
    # Random-walk sampler limited to a single RM with at most 6 states.
    sampler = RandomWalkHRMSampler(
        max_num_rms=1,
        max_num_gen_rms=1,
        max_num_states=6,
        max_num_gen_states=6,
        max_num_edges=1,
        max_num_literals=5,
        avg_state_connectivity=2,
        alphabet_size=labeling_function.get_alphabet_size(),
        # Reuse the alphabet computed in the previous cell so sampling and
        # rendering are guaranteed to use the same alphabet object.
        alphabet=alphabet,
        enforce_mutex=True,
        enforce_sequentiality=True,
        splittiness=0.5,
        use_transition_compat_matrix=False,
        use_call_compat_matrix=False,
        eps=1.0,
        reward_shaping=False,
        gamma=0.96,
    )
elif SAMPLER_TYPE == "sequential":
    # The sampler produces flat HRMs with a single 3-transition path from
    # the initial to the accepting state
    sampler = SinglePathFlatHRMSampler(
        max_num_rms=1,
        max_num_states=5,
        max_num_edges=5,
        max_num_literals=1,
        alphabet_size=labeling_function.get_alphabet_size(),
        num_transitions=3,
        reward_on_acceptance_only=True,
    )
else:
    # Fail loudly: otherwise `sampler` is either undefined (NameError below)
    # or silently left over from a previous run of this cell (hidden state).
    raise ValueError(
        f"Unknown SAMPLER_TYPE {SAMPLER_TYPE!r}; "
        "expected 'random_walk' or 'sequential'."
    )
sample_fn = jax.jit(jax.vmap(sampler.sample))

In [None]:
# Draw NUM_SAMPLES HRMs in one batched call: split a fixed root key into one
# key per sample and feed that batch of keys to the jitted, vmapped sampler.
hrms = sample_fn(
    jax.random.split(jax.random.PRNGKey(0), NUM_SAMPLES)
)

# Optional distribution plot (disabled):
# plotly_avg_states_per_rm_sampling(hrms).show()

# `hrms` is a batched pytree (output of the vmapped sampler), so take the i-th
# entry of every leaf to recover a single HRM before rendering it.
if DISPLAY_SAMPLES:
    for i in range(NUM_SAMPLES):
        single_hrm = jax.tree_util.tree_map(lambda leaf: leaf[i], hrms)
        hrm_img = render_to_img(single_hrm, alphabet=alphabet)
        display(hrm_img)
