# Checkpoint Inspection 

Make sure that the checkpoint has the desired property, such as freezing the decoder, and make the encoder trainable.

In [1]:
from track_mjx.agent import checkpointing
import os

os.environ["XLA_PYTHON_CLIENT_PREALLOCATE"] = "false"
track_ckpt_path = "/root/vast/scott-yang/track-mjx/model_checkpoints/250611_163603/"
vnl_ckpt_path = "/root/vast/scott-yang/vnl-mjx/model_checkpoints/250617_054335/"

track_ckpt = checkpointing.load_checkpoint_for_eval(track_ckpt_path)
vnl_ckpt = checkpointing.load_checkpoint_for_eval(vnl_ckpt_path)

Loading checkpoint from /root/vast/scott-yang/track-mjx/model_checkpoints/250611_163603/ at step 159
Loading checkpoint from /root/vast/scott-yang/vnl-mjx/model_checkpoints/250617_054335/ at step 1


In [2]:
vnl_decoder = vnl_ckpt["policy"][1]["params"]["decoder"]
vnl_encoder = vnl_ckpt["policy"][1]["params"]["encoder"]
track_decoder = track_ckpt["policy"][1]["params"]["decoder"]
track_encoder = track_ckpt["policy"][1]["params"]["encoder"]

the following code asserts whether the decoder from vnl-playground and vnl-mjx is the same.

In [3]:
import jax

import jax.numpy as jnp


def assert_tree_equal(tree1, tree2):
    def assert_fn(x, y):
        if isinstance(x, jnp.ndarray) and isinstance(y, jnp.ndarray):
            assert x.shape == y.shape, f"Shape mismatch: {x.shape} vs {y.shape}"
            assert jnp.allclose(x, y), "Array values are not equal"
        else:
            assert x == y, f"Value mismatch: {x} vs {y}"
        return None  # map_structure requires a return value

    jax.tree_util.tree_map(assert_fn, tree1, tree2)


# They do have the same decoder
assert_tree_equal(vnl_decoder, track_decoder)

In [4]:
track_encoder

{'LayerNorm_0': {'bias': Array([-0.30320007, -0.09659585, -0.03474695, ..., -0.38287705,
         -0.21678148, -0.18700059], dtype=float32),
  'scale': Array([1.1557599 , 1.0485269 , 0.97350466, ..., 1.2632253 , 0.9909549 ,
         1.0602205 ], dtype=float32)},
 'LayerNorm_1': {'bias': Array([ 0.12961869,  0.05182547,  0.1457802 , ..., -0.14899787,
         -0.12084189, -0.137351  ], dtype=float32),
  'scale': Array([0.98674464, 1.0362111 , 1.0587865 , ..., 1.2773917 , 1.060464  ,
         1.0255595 ], dtype=float32)},
 'LayerNorm_2': {'bias': Array([ 0.01898997,  0.02770316, -0.33379212, ...,  0.03945607,
          0.02037648,  0.03199795], dtype=float32),
  'scale': Array([0.9840111 , 0.91554815, 0.51904404, ..., 1.0008398 , 0.8992554 ,
         0.9187974 ], dtype=float32)},
 'fc2_logvar': {'bias': Array([-0.00880353, -0.5311519 , -0.07623577, -0.04954675, -0.02053421,
         -0.52546614, -0.03946827, -0.01658806, -0.3649219 , -0.02670817,
         -0.02304734, -0.04116683, -0.013

In [5]:
vnl_encoder

{'LayerNorm_0': {'bias': Array([-2.57462356e-03, -5.83517272e-03,  1.03628011e-02,  3.26405419e-03,
         -4.84173594e-04,  4.27233800e-03, -7.33876089e-03,  6.32826006e-03,
          1.87375082e-03,  1.35982723e-03,  4.05338546e-03,  6.26502652e-03,
          3.65910609e-03, -6.28632214e-03, -5.64091559e-03,  1.30444416e-03,
          4.96140355e-03, -2.38967710e-03,  7.55075226e-03,  1.95739907e-03,
          4.21437295e-03, -2.88919779e-03,  5.82450943e-04,  6.16501959e-04,
         -8.00104439e-03,  3.13292607e-03, -1.12989079e-03, -4.07161983e-03,
          1.87105860e-03,  2.14721006e-03,  1.29400883e-02,  1.40996755e-03,
          6.22191839e-03,  4.28046240e-03, -4.55044629e-03, -1.35047489e-03,
         -6.93898823e-04, -2.63355975e-03, -2.91593699e-03, -6.77155913e-04,
          1.73692917e-03, -1.50566033e-04,  6.72755763e-04, -1.11924466e-02,
          3.67012573e-03,  2.19919789e-03,  3.62791703e-03,  3.71443981e-04,
          6.30394975e-03,  2.74428260e-03,  1.4899015

In [16]:
vnl_decoder

{'LayerNorm_0': {'bias': Array([-0.2611952 ,  0.11642706, -0.08585232, ..., -0.07937414,
         -0.14428318,  0.02890676], dtype=float32),
  'scale': Array([1.2039077 , 0.77131075, 1.0509723 , ..., 0.8167469 , 1.0613282 ,
         0.8399805 ], dtype=float32)},
 'LayerNorm_1': {'bias': Array([-0.0281991 ,  0.13998286, -0.09278315, ...,  0.01732832,
          0.0055672 ,  0.00616174], dtype=float32),
  'scale': Array([1.191049  , 0.78367305, 0.8230591 , ..., 0.58677614, 0.8470753 ,
         0.9211945 ], dtype=float32)},
 'LayerNorm_2': {'bias': Array([-0.00235324,  0.07183172,  0.0591858 ,  0.06954012,  0.06868281,
          0.07804646, -0.04419471,  0.05194281,  0.07622087,  0.09152613,
          0.09196759,  0.06744884,  0.07999999,  0.09104982, -0.0850725 ,
          0.03896307,  0.00533044, -0.09660052,  0.045004  ,  0.06158363,
          0.05112151,  0.03934103,  0.06637447,  0.07591236,  0.08283453,
          0.05485939,  0.08449824,  0.0740684 ,  0.07809301,  0.06106612,
       