In [None]:
import os

from pathlib import Path
import timeit

import jax
import jax.numpy as jnp
import matplotlib.pyplot as plt

import atlas
from atlas.hrm import ops
from atlas.hrm.types import HRM, HRMState, Label

# Directory holding the HRM YAML definitions: <parent of the atlas package dir>/notebooks/hrms/data.
HRMS_FOLDER = Path(atlas.__path__[0]).parent.joinpath("notebooks/hrms/data")

In [None]:
def _time_fun(hrm: HRM, hrm_state: HRMState, label: Label, step_fn, reps=5):
    """Time ``reps`` calls of ``step_fn(hrm, hrm_state, label)``.

    Every repetition starts from the same ``hrm_state`` (the state is not
    advanced between repetitions), so the durations are directly comparable
    and the first one exposes one-off costs such as jit compilation.

    Args:
        hrm: HRM passed unchanged to ``step_fn``.
        hrm_state: HRM state every repetition starts from.
        label: Label passed unchanged to ``step_fn``.
        step_fn: Callable returning a ``(next_hrm_state, aux)`` pair, e.g.
            ``ops.step`` or ``jax.jit(ops.step)``.
        reps: Number of timed calls.

    Returns:
        ``(next_hrm_state, durations)``: the state produced by the last call
        (``None`` when ``reps == 0``) and a 1-D ``jnp`` array with the
        wall-clock seconds of each call.
    """
    next_hrm_state = None
    durations = []
    for _ in range(reps):
        start = timeit.default_timer()
        output = step_fn(hrm, hrm_state, label)
        # JAX dispatches work asynchronously: wait for *every* array in the
        # output (not just one field of the state) so the measurement covers
        # the complete step.
        jax.block_until_ready(output)
        durations.append(timeit.default_timer() - start)
        next_hrm_state, _ = output
    return next_hrm_state, jnp.array(durations)

def _print_durations_sps(durations):
    for i, duration in enumerate(durations):
        print(f"\t{i}\tt={duration:<10.7f}\tsps={1 / duration:<10.2f}")


# Flat HRM

In [None]:
# Two functions for the experiments: one without jitting, other with.
# Initialize here to make sure that the jitted function is not persisted across
# different executions of the same cell!
# NOTE(review): JAX may still reuse compilation artifacts cached for `ops.step`
# across separate `jax.jit` wrappers (hence "if needed!" in experiment 3) — confirm.
step_fn_njit = ops.step
step_fn_jit = jax.jit(step_fn_njit)

# Init the HRM
hrm = ops.init_hrm(
    root_id=0, max_num_rms=1, max_num_states=3, max_num_edges=1, max_num_literals=2
)
ops.load(hrm, os.path.join(HRMS_FOLDER, "simple_flat_hrm.yaml"))

print(
    "1. Run a step from the initial HRM state without jit compilation.\n"
    "   Obs: First step takes longer, the following are consistently similar."
)
initial_hrm_state = ops.get_initial_hrm_state(hrm)
_, durations = _time_fun(hrm, initial_hrm_state, jnp.array([1, -1]), step_fn_njit)
_print_durations_sps(durations)

print(
    "2. Run a step from the next HRM state without jit compilation.\n"
    "   Obs: Same as before, but it is not clear whether the first step always takes much longer."
)
# Advance one (non-jitted) step, then time steps starting from that new state.
next_hrm_state, _ = step_fn_njit(hrm, initial_hrm_state, jnp.array([1, -1]))
_, durations = _time_fun(hrm, next_hrm_state, jnp.array([1, -1]), step_fn_njit)
_print_durations_sps(durations)

print(
    "3. Run a step from the initial HRM state *with* jit compilation.\n"
    "   Obs: First step takes longer due to compilation (if needed!), the following are much shorter."
)
_, durations = _time_fun(hrm, initial_hrm_state, jnp.array([1, -1]), step_fn_jit)
_print_durations_sps(durations)

print(
    "4. Run a step from the initial HRM state using another label *with* jit compilation.\n"
    "   Obs: First step takes very short time due to the fact that the function was already compiled before."
)
_, durations = _time_fun(hrm, initial_hrm_state, jnp.array([-1, -1]), step_fn_jit)
_print_durations_sps(durations)

print(
    "5. Run a step from the next HRM state *with* jit compilation.\n"
    "   Obs: Same as before."
)
_, durations = _time_fun(hrm, next_hrm_state, jnp.array([-1, -1]), step_fn_jit)
_print_durations_sps(durations)

# Non-Flat 2-Level HRM

In [None]:
# Fresh step functions for this experiment: one without jitting, other with.
step_fn_njit = ops.step
step_fn_jit = jax.jit(step_fn_njit)

# Init the HRM
hrm = ops.init_hrm(
    root_id=0, max_num_rms=3, max_num_states=5, max_num_edges=1, max_num_literals=4
)
ops.load(hrm, os.path.join(HRMS_FOLDER, "cw_diamond_2l_hrm.yaml"))

print(
    "1. Run a step from the initial HRM state without jit compilation.\n"
    "   Obs: First step takes longer, the following are consistently similar."
)
initial_hrm_state = ops.get_initial_hrm_state(hrm)
_, durations = _time_fun(hrm, initial_hrm_state, jnp.array([-1, 1, -1, -1]), step_fn_njit)
_print_durations_sps(durations)

print(
    "2. Run a step from the initial HRM state *with* jit compilation.\n"
    "   Obs: First step takes longer due to compilation, the following are orders of magnitude shorter."
)
_, durations = _time_fun(hrm, initial_hrm_state, jnp.array([-1, 1, -1, -1]), step_fn_jit)
_print_durations_sps(durations)

print(
    "3. Run a step from the initial HRM state using another label *with* jit compilation.\n"
    "   Obs: First step takes very short time due to the fact that the function was already compiled before."
)
_, durations = _time_fun(hrm, initial_hrm_state, jnp.array([1, -1, -1, -1]), step_fn_jit)
_print_durations_sps(durations)

print(
    "4. Run a step from the next HRM state *with* jit compilation.\n"
    "   Obs: Same as before."
)
# Advance one (non-jitted) step, then actually time the jitted step from that
# new state (with the label of experiment 3), instead of re-printing stale timings.
next_hrm_state, _ = step_fn_njit(hrm, initial_hrm_state, jnp.array([-1, 1, -1, -1]))
_, durations = _time_fun(hrm, next_hrm_state, jnp.array([1, -1, -1, -1]), step_fn_jit)
_print_durations_sps(durations)

# Non-Flat 4-Level HRM

In [None]:
# Fresh step functions for this experiment: one without jitting, other with.
step_fn_njit = ops.step
step_fn_jit = jax.jit(step_fn_njit)

# Init the HRM; labels are vectors with one entry per alphabet symbol.
max_alphabet_size = 11
hrm = ops.init_hrm(
    root_id=12,
    max_num_rms=13,
    max_num_states=5,
    max_num_edges=1,
    max_num_literals=max_alphabet_size,
)
ops.load(hrm, os.path.join(HRMS_FOLDER, "cw_4l_hrm_full.yaml"))

# Base label (all -1), the label with index 4 set to 1, and the initial state.
base_label = -jnp.ones((max_alphabet_size,), dtype=jnp.int32)
label_idx4 = base_label.at[4].set(1)
initial_hrm_state = ops.get_initial_hrm_state(hrm)

print(
    "1. Run a step from the initial HRM state without jit compilation.\n"
    "   Obs: First step takes longer, the following are consistently similar."
)
_, durations = _time_fun(hrm, initial_hrm_state, label_idx4, step_fn_njit)
_print_durations_sps(durations)

print(
    "2. Run a step from the initial HRM state *with* jit compilation.\n"
    "   Obs: First step takes longer due to compilation, the following are orders of magnitude shorter."
)
_, durations = _time_fun(hrm, initial_hrm_state, label_idx4, step_fn_jit)
_print_durations_sps(durations)

print(
    "3. Run a sequence of steps to check if any of them is incurring a\n"
    "   recompilation (i.e., needs considerably more time to run)."
)
# Index set to 1 in each successive label; None means the unmodified base label.
active_indices = (4, 5, 1, None, 0, 1, 2, 3, 1, 9)
step_labels = [
    base_label if idx is None else base_label.at[idx].set(1)
    for idx in active_indices
]

# Chain the steps: each one starts from the state produced by the previous one.
hrm_state = initial_hrm_state
durations = []
for label in step_labels:
    hrm_state, duration = _time_fun(hrm, hrm_state, label, step_fn_jit, reps=1)
    durations.extend(duration)
_print_durations_sps(durations)

# Sweep Across Different HRM Sizes

In [None]:
# Fresh step functions. The jitted one is retraced/recompiled whenever an input
# shape changes (e.g. the label length below), hence the warm-up call before
# timing each configuration.
step_fn_njit = ops.step
step_fn_jit = jax.jit(step_fn_njit)

MAX_NUM_RMS = [1, 5, 10]
MAX_NUM_STATES = [5, 10]
MAX_NUM_LITERALS = [1, 5, 10]
MAX_ALPHABET_SIZE = [10, 100, 1000]

fig, ax = plt.subplots()
ax.set_title("Jitted step throughput vs. alphabet size")
ax.set_xlabel('alphabet size')
ax.set_ylabel('sps')
# The alphabet sizes grow geometrically, so a log x-axis spaces them evenly.
ax.set_xscale("log")

for max_num_rms in MAX_NUM_RMS:
    for max_num_states in MAX_NUM_STATES:
        for max_num_literals in MAX_NUM_LITERALS:
            alph_sps = []  # mean steps per second, one entry per alphabet size
            for max_alphabet_size in MAX_ALPHABET_SIZE:
                # NOTE(review): the cells above use a label length equal to
                # max_num_literals (and load this file with max_num_literals=2);
                # here the two are varied independently — confirm this is intended.
                hrm = ops.init_hrm(
                    root_id=0,
                    max_num_rms=max_num_rms,
                    max_num_states=max_num_states,
                    max_num_edges=1,
                    max_num_literals=max_num_literals,
                )
                ops.load(hrm, os.path.join(HRMS_FOLDER, "simple_flat_hrm.yaml"))

                label = -jnp.ones((max_alphabet_size,), dtype=jnp.int32)
                label = label.at[4].set(1)

                hrm_state = ops.get_initial_hrm_state(hrm)

                # Warm-up call: triggers any (re)compilation for these input
                # shapes so it is excluded from the timed repetitions below.
                _time_fun(hrm, hrm_state, label, step_fn_jit, reps=1)

                # Proper timing
                _, durations = _time_fun(hrm, hrm_state, label, step_fn_jit, reps=10)
                alph_sps.append(float(1 / durations.mean()))
            ax.plot(
                MAX_ALPHABET_SIZE,
                alph_sps,
                marker="o",
                label=f"R{max_num_rms}-S{max_num_states}-L{max_num_literals}",
            )

ax.legend(fontsize="small", ncol=2)
plt.show()