This notebook lets you train a model online, locally, on a particular game, and then inspect and debug the learned model.

In [None]:
# Enable IPython's autoreload extension in mode 2 so that edits to imported
# project modules are picked up without restarting the kernel.
%load_ext autoreload
%autoreload 2

In [None]:
import mediapy

import jax
import jax.numpy as jnp
import jax.random as jr
import jax.tree_util as jtu

import numpy as np

import gameworld.envs # Triggers registration of the environments in Gymnasium
import gymnasium

from axiom import visualize as vis
from axiom import infer as ax
from axiom.models import rmm as rmm_tools
from axiom.models import imm as imm_tools


import defaults
from tqdm import tqdm


import matplotlib.pyplot as plt
import warnings

import pickle
import rich
import os

# Output directory for models (not referenced again in the visible cells).
store_path = "data/models"
# exist_ok=True makes this a no-op when the directory is already there and
# avoids the race between a separate existence check and the creation.
os.makedirs(store_path, exist_ok=True)

# Originally added to hide int64 warnings.
# NOTE(review): this filter silences *all* Python warnings, not only those.
warnings.filterwarnings("ignore")

## Set-up

Specify the game you want to run on, and specify all hyperparams of the model to evaluate.

In [None]:
# Name of the game to run; forwarded to the config parser as --game below and
# reused as the label of the gameplay video in the results section.
game = "Explode"

In [None]:
# Command-line style overrides applied on top of the defaults.
cli_overrides = [
    f"--game={game}",
    "--num_steps=1000",
    "--planning_rollouts=128",  # reduce planning rollouts for faster experimentation
    # uncomment these lines to run with a "fixed" interacting radius
    # "--fixed_r",
    # "--r_interacting=1.25",
    # "--r_interacting_predict=0.416",
]
config = defaults.parse_args(cli_overrides)

# Show the fully resolved configuration.
rich.print(config)

## Experiment

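Train the agent by running it for the specified number of steps (default 10k; the config above overrides this to 1000). A report with the reward curve and a visualization of the rMM model is generated every `config.prune_every` steps (or every 500 steps when no pruning happens during the run). In addition, every 500 steps the most recent gameplay is shown and, if a negative reward occurred in that window, the planner is inspected around that failure.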

In [None]:
# Seed the NumPy global generator and create the JAX PRNG key from the config.
np.random.seed(config.seed)
key = jr.PRNGKey(config.seed)

# Per-timestep logs that the later cells inspect.
inspect, observations, nc, actions, rewards = [], [], [], [], []
xs, probs, tracked, switches, rmm_switches = [], [], [], [], []
identities, used, moving = [], [], []

# Build the environment and fetch the first frame.
env = gymnasium.make(f"Gameworld-{config.game}-v0")
obs, _ = env.reset()
obs = obs.astype(np.uint8)
reward = 0

# Initialise the agent state from the first observation.
key, subkey = jr.split(key)
carry = ax.init(subkey, config, obs, env.action_space.n)

# Log the initial step (with a zero reward and action 0 as placeholders).
layer = config.layer_for_dynamics
observations.append(obs.astype(np.uint8))
rewards.append(reward)
actions.append(0)
xs.append(carry["x"][layer])
tracked.append(carry["tracked_obj_ids"][layer])
identities.append(
    imm_tools.infer_identity(
        carry["imm_model"], xs[-1][..., None], config.imm.color_only_identity
    )
)

In [None]:
# Helper function to investigate a plan.


def investigate_plan(carry, xs, tracked, observations, actions, idx, t):
    """Re-run the planner from a logged timestep and visualize its predictions.

    A copy of ``carry`` is rewound to the logged state at ``idx + t`` (with the
    planner's warm-start fields cleared), the planner is invoked with a fixed
    PRNG key, and the resulting plan is printed and plotted next to what was
    actually observed afterwards.

    Relies on the notebook globals ``config`` and ``env``.

    Args:
        carry: agent state (as produced by ``ax.init`` / ``ax.step_fn``).
        xs: logged per-timestep states of the dynamics layer.
        tracked: logged per-timestep tracked-object masks of the dynamics layer.
        observations: logged per-timestep frames.
        actions: logged per-timestep actions.
        idx: base index into the logs.
        t: offset added to ``idx``; the planner is re-run at ``idx + t``.

    Returns:
        The ``plan_info`` returned by ``ax.plan_fn``.

    Raises:
        IndexError: if the logs do not extend a full planning horizon past
            ``idx + t`` (the actual trajectory is needed for comparison).
    """
    layer = config.layer_for_dynamics
    start = idx + t

    # Mapping the identity over the tree rebuilds its containers, so the
    # assignments below are made on the copy rather than on the caller's carry
    # (assumes the nested containers are standard pytree nodes).
    c = jtu.tree_map(lambda x: x, carry)
    c["mppi_probs"] = None
    c["current_plan"] = None

    c["x"][layer] = xs[start]
    c["tracked_obj_ids"][layer] = tracked[start]

    # Fixed key: investigating the same step twice gives the same plan.
    _, subkey = jr.split(jr.PRNGKey(0))
    action, c, plan_info = ax.plan_fn(subkey, c, config, env.action_space.n)

    # Summed reward per plan, taking index 0 on the last axis (presumably the
    # sample axis, as for "states" -- TODO confirm).
    total_rewards = plan_info["rewards"][:, :, 0].sum(0)
    best = jnp.argsort(total_rewards)[-1]

    print("predicted vs actual action")
    print(action, actions[start])
    print("current plan")
    print(plan_info["current_plan"], total_rewards[best])
    print("prev plan")
    print(plan_info["actions"][:, 0], total_rewards[0])
    print(plan_info["probs"][0])

    def _plan_image(**plot_kwargs):
        # All panels share the frame, plan and tracked mask of the start step.
        return vis.plot_plan(
            observations[start],
            plan_info,
            tracked[start],
            carry["smm_model"].stats,
            **plot_kwargs,
        )

    mediapy.show_images(
        {
            "top-1": _plan_image(topk=1),
            "prev": _plan_image(indices=jnp.array([0])),
            "top-20": _plan_image(topk=20),
            "worst-5": _plan_image(descending=False),
        }
    )

    # Actual logged states over the planning horizon, for comparison.
    horizon = plan_info["states"].shape[0]
    x = jnp.asarray([xs[i] for i in range(start, start + horizon)])

    # Up to 100 highest-reward plans, in ascending order of total reward.
    indices = total_rewards.argsort()[-100:]
    pred_rewards = total_rewards[indices]
    pred_states = plan_info["states"][:, indices, 0].transpose((1, 0, 2, 3))

    mediapy.show_image(
        vis.rollout_samples_lineplot(
            pred_states,
            x,
            # Use the tracked mask of the step the plan starts from, like every
            # other lookup in this function (previously tracked[idx]).
            jnp.argwhere(tracked[start]).flatten(),
            pred_rewards,
            plan_info["rewards"][:, indices[:5], 0],
        ),
        width=800,
    )

    return plan_info

In [None]:
# Buffer passed to, and returned from, ax.reduce_fn_rmm across pruning rounds.
bmr_buffer = None, None

# Periodic inspection settings (in logged frames):
report_every = 500  # cadence of the inspection and size of the window it looks at
tail_margin = 50  # ignore failures this close to the end of the log
lead_in = 20  # number of frames shown before the selected failure

jax.config.update("jax_debug_nans", False)
for t in tqdm(range(config.num_steps)):
    # action selection
    key, subkey = jr.split(key)
    action, carry, plan_info = ax.plan_fn(subkey, carry, config, env.action_space.n)
    probs.append(plan_info["probs"])

    # step env
    obs, reward, done, truncated, info = env.step(action)
    obs = obs.astype(np.uint8)

    # update models
    carry, rec = ax.step_fn(
        carry, config, obs, jnp.array(reward), action, num_tracked=0
    )

    # log stuff
    observations.append(obs)
    actions.append(action)
    rewards.append(reward)
    tracked.append(carry["tracked_obj_ids"][config.layer_for_dynamics])
    nc.append(carry["rmm_model"].used_mask.sum())

    xs.append(carry["x"][config.layer_for_dynamics])
    switches.append(rec["switches"])
    rmm_switches.append(rec["rmm_switches"])
    used.append(carry["used"])
    moving.append(carry["moving"])

    identity_t = imm_tools.infer_identity(
        carry["imm_model"], xs[-1][..., None], config.imm.color_only_identity
    )
    identities.append(identity_t)

    # Gymnasium requires a reset once an episode has terminated *or* been
    # truncated; previously only `done` was checked.
    if done or truncated:
        obs, _ = env.reset()
        obs = obs.astype(np.uint8)
        reward = 0
        # Feed the first frame of the new episode through the models without
        # updating them (update=False), using placeholder reward and action.
        carry, rec = ax.step_fn(
            carry,
            config,
            obs,
            jnp.array(reward),
            jnp.array(0),
            num_tracked=0,
            update=False,
        )

        observations.append(obs)
        rewards.append(reward)
        actions.append(0)
        tracked.append(carry["tracked_obj_ids"][config.layer_for_dynamics])
        nc.append(carry["rmm_model"].used_mask.sum())

        xs.append(carry["x"][config.layer_for_dynamics])
        switches.append(rec["switches"])
        rmm_switches.append(rec["rmm_switches"])
        # No plan was made for the reset frame: log a uniform distribution.
        probs.append(jnp.ones_like(probs[-1]) / probs[-1].shape[-1])
        used.append(carry["used"])
        moving.append(carry["moving"])

        # Keep `identities` the same length as `xs` / `tracked`; it was
        # previously skipped on reset, which misaligned the timeseries.
        identity_t = imm_tools.infer_identity(
            carry["imm_model"], xs[-1][..., None], config.imm.color_only_identity
        )
        identities.append(identity_t)

    if (t + 1) % config.prune_every == 0:
        key, subkey = jr.split(key)
        new_rmm, pairs, *bmr_buffer = ax.reduce_fn_rmm(
            subkey, carry["rmm_model"], *bmr_buffer
        )
        # Report compares the rMM before and after the reduction.
        vis.generate_report(
            rewards, None, nc, carry["rmm_model"], new_rmm, carry["imm_model"]
        )
        carry["rmm_model"] = new_rmm

    if (t + 1) % report_every == 0:
        # Without any pruning during the run no report would be produced above,
        # so generate one here instead.
        if config.prune_every >= config.num_steps:
            vis.generate_report(
                rewards,
                None,
                nc,
                carry["rmm_model"],
                carry["rmm_model"],
                carry["imm_model"],
            )

        # Search for failures only in the part of the window that leaves
        # `tail_margin` logged frames after it. The same slice is used for the
        # check and for picking the index; previously the index came from the
        # full window and could land in the excluded tail.
        recent = jnp.asarray(rewards[-report_every:-tail_margin])
        failures = jnp.argwhere(recent == -1).flatten()
        if failures.size > 0:
            window_start = len(rewards) - report_every
            # Clamp so a failure early in the log cannot produce a negative
            # index that silently wraps around to the end of the lists.
            idx = max(window_start + int(failures[-1]) - lead_in, 0)

            best = jnp.argsort(plan_info["rewards"][:, :, 0].sum(0))[-1]
            # Use the plan's own horizon rather than a hard-coded frame count.
            horizon = plan_info["states"].shape[0]
            mediapy.show_videos(
                {
                    "last_obs": observations[-report_every:],
                    "fail": observations[idx : idx + lead_in],
                    "plan": [
                        vis.plot_obs_and_info(None, plan_info["states"][step, best, 0])
                        for step in range(horizon)
                    ],
                }
            )
            try:
                print("Investigating index", idx)
                investigate_plan(carry, xs, tracked, observations, actions, idx, 1)
            except Exception as err:
                # Best effort: plotting problems must not abort training, but
                # interrupts still propagate and the cause is reported.
                print("failed to plot stuff")
                print(repr(err))
        else:
            mediapy.show_videos({"last_obs": observations[-report_every:]})

## Results

Visualize gameplay of the last 1000 frames.

In [None]:
# Play back the most recent 1000 logged frames as a GIF.
last_frames = observations[-1000:]
mediapy.show_videos({game: last_frames}, fps=40, codec="gif")

Visualize the SMM output of the final frame

In [None]:
# Models from the final agent state, plus the SMM's size and stats.
smm_model, rmm_model = carry["smm_model"], carry["rmm_model"]
width, height = smm_model.width, smm_model.height
stats = smm_model.stats

# Records of the last model step, restricted to the dynamics layer.
layer = config.layer_for_dynamics
qz_layer = rec["qz"][layer]

smm_panels = {
    "qx": vis.plot_qx_smm(
        rec["decoded_mu"][layer],
        rec["decoded_sigma"][layer],
        stats["offset"],
        stats["stdevs"],
        width,
        height,
        qz_layer,
    ),
    "qz": vis.plot_qz_smm(qz_layer, width, height),
    "smm_eloglike": vis.plot_elbo_smm(rec["smm_eloglike"][layer], width, height),
}
mediapy.show_images(smm_panels)

Visualize the "identities" discovered by the model. You can also get a sense of which slots a particular object identity occupies during the experiment.

In [None]:
# Draw the identity model of the final agent state and display the figure.
imm_model = carry["imm_model"]
vis.plot_identity_model(imm_model, return_ax=True)
plt.show()

Check which slots are used and tracked over time, as well as the inferred identities.

In [None]:
# Stack the per-timestep logs into arrays with time along the first axis.
tracked_series = jnp.stack(tracked)
used_series = jnp.stack(xs)[:, :, 2] == 0
identity_series = jnp.stack(identities)

fig, axes = plt.subplots(1, 3, figsize=(12, 4))
panel_tracked, panel_used, panel_identity = axes

panel_tracked.imshow(
    tracked_series, aspect="auto", interpolation="none", cmap="gray"
)
panel_tracked.set_title("Tracked slot timeseries")

panel_used.imshow(used_series, aspect="auto", interpolation="none", cmap="gray")
panel_used.set_title("Used (and hence visible) timeseries")

panel_identity.imshow(identity_series, aspect="auto", interpolation="none")
panel_identity.set_title("Inferred identity timeseries")

plt.show()

Inspect particular RMM clusters, filtered based on some discrete/continuous observations. Add or modify the filters in the select clause.

In [None]:
def _argmax_category(likelihood_idx):
    """Argmax over the last axis of ``alpha[:, :, 0]`` for one discrete likelihood of the rMM."""
    alpha = carry["rmm_model"].model.discrete_likelihoods[likelihood_idx].alpha
    return alpha[:, :, 0].argmax(-1)


# Each entry is one boolean filter over the clusters; add or edit entries here.
cluster_filters = [
    # Only show used clusters
    carry["rmm_model"].used_mask > 0,
    # Only show clusters with identity 2
    _argmax_category(0) == 2,
    # Only show clusters where it's interacting with object identity 0
    _argmax_category(1) == 0,
    # Only show clusters where the object disappears (i.e. has SLDS switch 2)
    _argmax_category(-1) == 2,
]

# A cluster is selected when it passes every filter.
select = cluster_filters[0]
for extra_filter in cluster_filters[1:]:
    select = select & extra_filter

selected_indices = jnp.where(select)[0]
vis.plot_rmm(
    carry["rmm_model"],
    carry["imm_model"],
    width=20,
    height=20,
    colorize="cluster",
    indices=selected_indices,
    return_ax=True,
)
plt.show()

Use `vis.plot_rmm_detail` to look into the details of a particular cluster.

In [None]:
# Show the details of the first cluster that passed the filters. An empty
# selection would otherwise surface as an opaque index error, so fail with a
# message that names the cause instead.
selected_clusters = jnp.argwhere(select).flatten()
if selected_clusters.size == 0:
    raise ValueError(
        "No rMM cluster matches the current filters; relax the `select` clause."
    )
vis.plot_rmm_detail(carry["rmm_model"].model, selected_clusters[0])

### Inspecting particular planner rollouts

To debug further, let's look at a failure case and inspect what the planner predicts. First find an index shortly before the first failure within the selected time window (`start_t` to `end_t`).

In [None]:
# Window of logged timesteps to search, how many frames to show before the
# event, and which reward value to look for.
start_t, end_t = 0, 10000
num_steps_before = 20
reward_type = -1  # change to 1 to time-lock to a reward

# Positions (relative to start_t) at which the requested reward was logged.
reward_hits = jnp.argwhere(
    jnp.asarray(rewards[start_t:end_t]) == reward_type
).flatten()
if reward_hits.size == 0:
    raise ValueError(
        f"No reward equal to {reward_type} found in rewards[{start_t}:{end_t}]"
    )

# Timestep of the first hit, and the (non-negative) index the clip starts at.
# Without the clamp a hit within the first `num_steps_before` steps would give
# a negative index that wraps around to the end of the logs.
reward_t = int(reward_hits[0]) + start_t
idx = max(reward_t - num_steps_before, 0)

print(f"Reward: {reward_type} found at t={reward_t}")
mediapy.show_videos({"reward": observations[idx:reward_t]}, codec="gif")

Given some time-offset t from the starting idx, inspect the plan.

In [None]:
# Offset from `idx` at which the planner is re-run.
t = 0
plan_info = investigate_plan(
    carry=carry,
    xs=xs,
    tracked=tracked,
    observations=observations,
    actions=actions,
    idx=idx,
    t=t,
)

You can also visualize one particular plan index, e.g. just the best one found.

In [None]:
# Sum each plan's rewards over the first axis (index 0 on the last axis) and
# take the last argsort entry, i.e. the plan with the highest total.
plan_returns = plan_info["rewards"][:, :, 0].sum(0)
best = jnp.argsort(plan_returns)[-1]

best_plan_image = vis.plot_plan(
    observations[idx + t],
    plan_info,
    tracked[idx + t],
    carry["smm_model"].stats,
    indices=jnp.array([best]),
)
mediapy.show_image(best_plan_image)

Now, for each tracked object, we can print out the planned tMM switches (i.e. for the first sample of the best planned policy),

In [None]:
# Slots flagged as tracked at the step the plan starts from.
tracked_slots = jnp.argwhere(tracked[idx + t]).flatten()
for object_idx in tracked_slots:
    # Switches of the best plan (index 0 on the third axis) for this slot.
    planned_switches = plan_info["switches"][:, best, 0, object_idx]
    print(object_idx, planned_switches)

as well as the inferred rMM clusters.

In [None]:
# Slots flagged as tracked at the step the plan starts from.
tracked_slots = jnp.argwhere(tracked[idx + t]).flatten()
for object_idx in tracked_slots:
    # rMM switches of the best plan (index 0 on the third axis) for this slot.
    planned_rmm_switches = plan_info["rmm_switches"][:, best, 0, object_idx]
    print(object_idx, planned_rmm_switches)

Moreover, for a particular rollout, timestep and object_idx you can inspect the inferred rMM cluster in more detail, by calling the predict method directly.

In [None]:
from dataclasses import asdict

from axiom.models.rmm import _to_distance_obs_hybrid
from axiom.models.rmm import predict

# Which imagined state of the plan to inspect.
policy_idx = best
sample_idx = 0
object_idx = 1
timestep = 1

rmm = carry["rmm_model"]
imm = carry["imm_model"]
# takes the predicted state from the plan at timestep as input
x_t = plan_info["states"][timestep, policy_idx, sample_idx]
# to predict the next state, we need the previous action in the plan
action_t = plan_info["actions"][timestep - 1, policy_idx]

# The plan was computed from the logged step idx + t, so use the tracked mask
# of that same step (previously tracked[idx], which only matches when t == 0).
tracked_obj_ids = tracked[idx + t]
interact_with_static = False  # not used below; kept for reference
num_switches = config.tmm.n_total_components
object_identities = None
# These rMM settings reach _to_distance_obs_hybrid through **asdict(config.rmm);
# they are spelled out here only so they can be inspected.
r_interacting_predict = config.rmm.r_interacting_predict
forward_predict = config.rmm.forward_predict
stable_r = config.rmm.stable_r
reward_prob_threshold = config.rmm.reward_prob_threshold

c_obs, d_obs = _to_distance_obs_hybrid(
    imm,
    x_t,
    object_idx,
    action_t,
    tmm_switch=10,  # we pass in a dummy value (gets overwritten in predict)
    reward=jnp.array(0),  # we pass in a dummy value (gets overwritten in predict)
    tracked_obj_mask=tracked_obj_ids,
    max_switches=num_switches,
    action_dim=rmm.model.discrete_likelihoods[-3].alpha.shape[-2],
    object_identities=object_identities,
    num_object_classes=rmm.model.discrete_likelihoods[0].alpha.shape[1] - 1,
    **asdict(config.rmm),
)
# Add singleton leading and trailing axes before calling predict.
c_obs = c_obs[None, :, None]
d_obs = jtu.tree_map(lambda d: d[None, :, None], d_obs)

# Compute the tMM switching slot using the rMM
switch_slot, pred_reward, ell, qz, r_cluster = predict(
    rmm,
    c_obs,
    d_obs,
    key=None,
    reward_prob_threshold=reward_prob_threshold,
)

mediapy.show_images(
    {
        "imagined": vis.plot_obs_and_info(None, x_t),
        "rmm_cluster": vis.plot_rmm(
            rmm,
            imm,
            # the single cluster with the highest qz
            indices=jnp.argsort(qz)[-1:],
            colorize="cluster",
        ),
    },
    width=300,
)

top5_qz = jnp.argsort(qz)[-5:]
print("Top 5 qzs")
print(top5_qz)
print(qz[top5_qz])

# Show details only for the top clusters whose qz exceeds 0.1.
for i in top5_qz:
    if qz[i] > 0.1:
        vis.plot_rmm_detail(
            rmm.model,
            i,
            c_obs=c_obs[0],
            d_obs=jtu.tree_map(lambda d: d[0], d_obs),
        )

print("tMM component")
print(carry["tmm_model"].transitions[switch_slot.item()])

## License

Copyright 2025 VERSES AI, Inc.

Licensed under the VERSES Academic Research License (the “License”);
you may not use this file except in compliance with the license.

You may obtain a copy of the License at

    https://github.com/VersesTech/axiom/blob/main/LICENSE

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.