In [18]:
import os
os.environ["JAX_PLATFORMS"] = "cpu"

from typing import Tuple
import jax
import jax.numpy as jnp
import numpy as np
%load_ext autoreload
%autoreload 2

from src.models.base import ModelConfig
from src.models.rnn import ElmanRNN, LSTM, UnitaryRNN
from src.models.lru import LinearRecurrentUnit
from src.data.copy_dataset import CopyDataset

Array = jnp.ndarray


The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [16]:
def _jacobian_lookback_frobenius(
    jac_diags: Array,
    wh_weight: Array,
    mask: Array,
) -> Array:
    """Compute M(T) = ||J_t * J_{t-1} * ... * J_{t-T+1}||_F for each time step t and lookback T.

    For time step t, computes jacobian norms for all possible lookback windows:
    - lag=0: identity matrix (norm = sqrt(H))
    - lag=1: ||J_t||_F
    - lag=2: ||J_t * J_{t-1}||_F
    - lag=T: ||J_t * J_{t-1} * ... * J_0||_F

    Args:
        jac_diags: Jacobian diagonals [B, T, H]
        wh_weight: Hidden-to-hidden weight matrix [H, H]
        mask: Sequence mask [B, T]

    Returns:
        Array of shape [B, T, T+1] where output[b, t, lag] = M(lag) at time t
    """
    B, T, H = jac_diags.shape
    wh = wh_weight.astype(jac_diags.dtype)
    diag_seq = jnp.swapaxes(jac_diags, 0, 1)  # [T, B, H]
    mask_seq = jnp.swapaxes(mask, 0, 1)  # [T, B]

    # Identity matrix norm (for lag=0)
    identity_norm = jnp.sqrt(jnp.array(H, dtype=wh.dtype))
    eye = jnp.broadcast_to(jnp.eye(H, dtype=wh.dtype), (B, H, H))

    # Initialize output array [B, T, T+1]
    all_norms = jnp.zeros((B, T, T + 1), dtype=wh.dtype)

    def scan_step(carry, inputs):
        """Process one time step, computing all lookback norms.
        
        carry: (jacobian_history [B, T, H, H], norms [B, T, T+1])
        inputs: (diag_t, mask_t, t_idx) where t_idx is the current time step
        """
        diag_t, mask_t, t_idx = inputs
        jacobian_history, norms = carry
        
        # Compute J_t for this time step
        J_t = diag_t[:, :, None] * wh[None, :, :]  # [B, H, H]
        mask_bool = mask_t > 0.0
        
        # Update jacobian history: shift and prepend J_t
        jacobian_history = jnp.concatenate([J_t[:, None, :, :], jacobian_history[:, :-1, :, :]], axis=1)
        
        # Collect norms for all lags at this time step
        time_step_norms = jnp.zeros((B, T + 1), dtype=wh.dtype)
        time_step_norms = time_step_norms.at[:, 0].set(identity_norm)  # lag=0: identity
        
        # Compute norms for lags 1 to t_idx+1 using fori_loop
        # We'll compute up to T+1 and mask invalid lags
        def compute_lag_norm(lag, carry_state):
            J_cumulative, norms_array = carry_state
            # Only compute if lag <= t_idx + 1
            valid_lag = lag <= (t_idx + 1)
            # Get J_{t-lag+1} from history (most recent is at index 0)
            hist_idx = lag - 1
            J_hist = jacobian_history[:, hist_idx, :, :]  # [B, H, H]
            # Multiply cumulative by this jacobian (going backwards in time)
            J_cumulative = jnp.einsum("bij,bjk->bik", J_hist, J_cumulative)
            # Apply mask
            J_cumulative = jnp.where(mask_bool[:, None, None], J_cumulative, eye)
            # Compute norm
            frob = jnp.linalg.norm(J_cumulative, axis=(-2, -1))
            frob = jnp.where(mask_bool, frob, jnp.zeros_like(frob))
            # Only update if valid lag
            updated_norms = norms_array.at[:, lag].set(frob)
            norms_array = jnp.where(valid_lag, updated_norms, norms_array)
            return (J_cumulative, norms_array)
        
        # Start with identity matrix
        initial_state = (eye, time_step_norms)
        # Compute for lags 1 to T+1 (we'll mask invalid ones)
        final_state = jax.lax.fori_loop(1, T + 1, compute_lag_norm, initial_state)
        _, time_step_norms = final_state
        
        # Update norms array
        norms = norms.at[:, t_idx, :].set(time_step_norms)
        
        return (jacobian_history, norms), None

    # Initialize carry
    jacobian_history = jnp.zeros((B, T, H, H), dtype=wh.dtype)
    initial_carry = (jacobian_history, all_norms)
    
    # Create inputs: (diag_seq, mask_seq, time_indices)
    time_indices = jnp.arange(T)
    inputs = (diag_seq, mask_seq, time_indices)
    
    # Scan over time steps
    final_carry, _ = jax.lax.scan(scan_step, initial_carry, inputs)
    _, final_norms = final_carry
    
    return final_norms  # [B, T, T+1]

In [19]:
def _compute_l_eff(
    lookback_norms: Array,
    epsilon_values: Tuple[float, ...],
    mask: Array,
) -> Array:
    """Compute l_eff(epsilon) = max{T >= 0 : M(T) > epsilon} for each time step and epsilon.

    Args:
        lookback_norms: Array of shape [B, T, T+1] where lookback_norms[b, t, lag] = M(lag) at time t
        epsilon_values: Tuple of epsilon values to compute l_eff for
        mask: Sequence mask [B, T]

    Returns:
        Array of shape [B, T, num_epsilons] where output[b, t, e] = l_eff(epsilon_values[e]) at time t
    """
    B, T, max_lag = lookback_norms.shape

    # Convert epsilon values to array for broadcasting
    epsilons = jnp.array(epsilon_values, dtype=lookback_norms.dtype)  # [num_epsilons]

    # For each time step t, find maximum lag T where M(T) > epsilon
    # lookback_norms[b, t, :] contains M(0), M(1), ..., M(t) (padded with zeros)
    # We need to find the maximum lag where the norm > epsilon

    # Expand dimensions for broadcasting: [B, T, T+1] vs [num_epsilons]
    # We want to compare each norm with each epsilon
    lookback_norms_expanded = lookback_norms[:, :, :, None]  # [B, T, T+1, 1]
    epsilons_expanded = epsilons[None, None, None, :]  # [1, 1, 1, num_epsilons]

    # Compare: [B, T, T+1, num_epsilons]
    greater_than_epsilon = lookback_norms_expanded > epsilons_expanded

    # For each epsilon, find the maximum lag where condition is true
    # Create lag indices: [0, 1, 2, ..., T]
    lag_indices = jnp.arange(max_lag, dtype=lookback_norms.dtype)  # [T+1]
    lag_indices_expanded = lag_indices[None, None, :, None]  # [1, 1, T+1, 1]

    # Where condition is true, use the lag index; where false, use -1
    valid_lags = jnp.where(
        greater_than_epsilon, lag_indices_expanded, -1.0
    )  # [B, T, T+1, num_epsilons]

    # Take maximum over lag dimension: [B, T, num_epsilons]
    l_eff = jnp.max(valid_lags, axis=2)  # [B, T, num_epsilons]

    # If no lag satisfies the condition (all are -1), set to 0
    l_eff = jnp.maximum(l_eff, 0.0)

    # Apply mask: set l_eff to 0 for masked time steps
    mask_expanded = mask[:, :, None]  # [B, T, 1]
    l_eff = l_eff * mask_expanded

    return l_eff

In [2]:
num_classes = 10

# Model under study. Other architectures imported above (LSTM,
# LinearRecurrentUnit, UnitaryRNN) take the same `model_cfg` and can be
# substituted here.
model_cfg = ModelConfig(
    input_dim=num_classes,
    output_dim=num_classes,
    hidden_dim=32,
)
model = ElmanRNN(model_cfg, nonlinearity="relu")
params = model.initialize(jax.random.PRNGKey(0))

# Copy-task batch; presumably seq_length tokens followed by a fixed delay of
# min_lag == max_lag steps (see CopyDataset for exact layout).
dataset = CopyDataset(
    min_lag=10,
    max_lag=10,
    batch_size=2,
    num_classes=num_classes,
    seq_length=8,
)
inputs, targets, mask = dataset()
inputs_oh = jax.nn.one_hot(inputs, num_classes, dtype=jnp.float32)
print("inputs one-hot:", inputs_oh.shape)
print("mask:", mask.shape)
print("mask:", mask.shape)

ERROR:2025-12-07 21:03:50,527:jax._src.xla_bridge:473: Jax plugin configuration error: Exception when calling jax_plugins.xla_cuda12.initialize()
Traceback (most recent call last):
  File "/home/rphess/conda/envs/recurrent-networks/lib/python3.12/site-packages/jax/_src/xla_bridge.py", line 471, in discover_pjrt_plugins
    plugin_module.initialize()
  File "/home/rphess/conda/envs/recurrent-networks/lib/python3.12/site-packages/jax_plugins/xla_cuda12/__init__.py", line 328, in initialize
    _check_cuda_versions(raise_on_first_error=True)
  File "/home/rphess/conda/envs/recurrent-networks/lib/python3.12/site-packages/jax_plugins/xla_cuda12/__init__.py", line 285, in _check_cuda_versions
    local_device_count = cuda_versions.cuda_device_count()
                         ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
RuntimeError: jaxlib/cuda/versions_helpers.cc:113: operation cuInit(0) failed: CUDA_ERROR_NO_DEVICE


inputs one-hot: (2, 26, 10)
mask: (2, 26)


In [5]:
# Forward pass that also returns intermediate tensors; the nonlinearity
# Jacobian diagonals used below are read from `runtime_tensors`.
outputs, runtime_tensors = model.apply(
    params,
    inputs_oh,
    mask,
    return_features=True,
)

In [6]:
# Recurrent (hidden-to-hidden) weight matrix, used to build J_t = diag(f') @ W.
# NOTE(review): assumed [H, H] and applied as W @ h — confirm the model's layout.
hidden_to_hidden_weight = params["wh"]["w"]

In [8]:
# Diagonal of the nonlinearity's Jacobian, as returned by the forward pass above.
# Presumably [B, T, H] — TODO confirm.
nonlin_jacobian_diag = runtime_tensors.nonlinearity_jacobian_diag

In [21]:
# Thresholds on the lookback norm M(lag) at which the effective lookback l_eff is measured.
epsilon_values = (0.1, 0.5)

In [17]:
# Frobenius norm of the Jacobian product for every (batch, time step, lag): [B, T, T+1].
lookback_norms = _jacobian_lookback_frobenius(
    jac_diags=nonlin_jacobian_diag,
    wh_weight=hidden_to_hidden_weight,
    mask=mask,
)

In [22]:
# Effective lookback per (batch, time step, epsilon): [B, T, len(epsilon_values)].
l_eff = _compute_l_eff(
    lookback_norms=lookback_norms,
    epsilon_values=epsilon_values,
    mask=mask,
)

In [25]:
# l_eff for the first sequence in the batch: one row per time step, one column per epsilon.
l_eff[0]

Array([[ 0.,  0.],
       [ 0.,  0.],
       [ 0.,  0.],
       [ 0.,  0.],
       [ 0.,  0.],
       [ 0.,  0.],
       [ 0.,  0.],
       [ 0.,  0.],
       [ 0.,  0.],
       [ 0.,  0.],
       [ 0.,  0.],
       [ 0.,  0.],
       [ 0.,  0.],
       [ 0.,  0.],
       [ 0.,  0.],
       [ 0.,  0.],
       [ 0.,  0.],
       [ 0.,  0.],
       [11.,  6.],
       [11.,  6.],
       [11.,  6.],
       [10.,  6.],
       [10.,  7.],
       [10.,  6.],
       [10.,  6.],
       [11.,  6.]], dtype=float32)