In [None]:
from pathlib import Path

import imageio
from IPython.display import Image, display
import jax
import jax.numpy as jnp

import atlas
from atlas.hrm.ops import (
    get_initial_hrm_state,
    init_hrm,
    load,
    step,
)
from atlas.hrm.visualization import render_to_img

# Project root: one directory above the `atlas` package directory.
# `__path__[0]` is used instead of `__file__` because `__file__` is None for
# namespace packages, while `__path__` is always present on a package.
PROJECT_DIR = Path(atlas.__path__[0]).parents[0]

In [None]:
# Length of every label vector fed to `step`. It is also passed as the HRM's
# literal bound below so the two values cannot drift apart.
# NOTE(review): this assumes max_num_literals should equal the alphabet size
# (both were 11) — confirm before changing ALPHABET_SIZE.
ALPHABET_SIZE = 11

# Initialize an HRM with specific bounds
hrm = init_hrm(
    root_id=12,
    max_num_rms=13,
    max_num_states=5,
    max_num_edges=1,
    max_num_literals=ALPHABET_SIZE,
)
alphabet_size = ALPHABET_SIZE

# Load the file into the HRM we have initialized.
# NOTE(review): the return value of `load` is discarded, so this relies on
# `load` filling `hrm` in place — confirm against `atlas.hrm.ops.load`.
load(hrm, PROJECT_DIR / "notebooks/hrms/data/cw_4l_hrm_full.yaml")

# Create a label trace to go through the HRM. Every label starts as a vector
# of -1s; at each step the proposition index listed below (if any) is set to 1.
# `None` means no entry is set at that step.
# Presumably -1 = proposition absent and 1 = present — confirm against `step`.
TRACE_PROPOSITIONS = [None, 4, 5, 1, None, 0, 1, 2, 3, 1, 9]

base_label = -jnp.ones((alphabet_size,), dtype=jnp.int32)
label_trace = jnp.stack(
    [
        base_label if prop is None else base_label.at[prop].set(1)
        for prop in TRACE_PROPOSITIONS
    ]
)
# One row per trace step, one column per alphabet symbol.
assert label_trace.shape == (len(TRACE_PROPOSITIONS), alphabet_size)
assert label_trace.dtype == jnp.int32

# Create a compiled version of the step function (it will actually be
# compiled the first time it is called)
step_fn = jax.jit(step)

In [None]:
# Feed the trace through the compiled step function one label at a time and
# render the HRM state reached after each step (frame k is titled "Step k").
# The second value returned by `step_fn` is not used here.
hrm_state = get_initial_hrm_state(hrm)
frames = []
for step_idx in range(label_trace.shape[0]):
    hrm_state, _ = step_fn(hrm, hrm_state, label_trace[step_idx])
    frames.append(render_to_img(hrm, hrm_state, title=f"Step {step_idx}"))

In [None]:
# Save the rendered frames as an animated GIF (1 frame per second) and show it
# inline.
# `loop=0` asks for infinite looping explicitly: Pillow's GIF writer does not
# loop when `loop` is omitted, and which imageio GIF backend is used (and its
# defaults) depends on the installed imageio version.
# NOTE(review): newer imageio versions deprecate `fps` for GIFs in favour of
# `duration` (milliseconds), while older versions interpret `duration` in
# seconds, so `fps` is kept here for compatibility with both.
output_path = PROJECT_DIR / "hrm.gif"
imageio.mimsave(output_path, frames, fps=1, loop=0, format="gif")
# Pass the path via `filename=` as a str: not every IPython version accepts a
# `pathlib.Path` as the positional `data` argument.
display(Image(filename=str(output_path)))