In [1]:
import pickle
from functools import partial
import itertools

import jax
import jax.numpy as np
from jax import random
from jax.scipy.linalg import block_diag
from tqdm import tqdm

from s5.dataloading import Datasets
from s5.seq_model import BatchClassificationModel, RetrievalModel
from s5.ssm import init_S5SSM
from s5.ssm_init import make_DPLR_HiPPO
from s5.train_helpers import (create_train_state, reduce_lr_on_plateau, \
    linear_warmup, cosine_annealing, constant_lr, validate, prep_batch, train_step, update_learning_rate_per_step, cross_entropy_loss)
from s5.gradients import gradient_monitoring_hook

In [2]:
class Args:
    def __init__(self, **kwargs):
        for key, value in kwargs.items():
            setattr(self, key, value)


# Create the args object
args = Args(
    ssm_size=128,
    jax_seed=0,
    blocks=1,
    d_model=64,
    n_layers=1,
    n_classes=2,
    C_init='complex_normal',
    discretization='zoh',
    dt_min=0.001,
    dt_max=0.1,
    conj_sym=True,
    clip_eigs=True,
    bidirectional=False,
    activation_fn='gelu',
    dropout=0.0,
    prenorm=True,
    batchnorm=False,
    bn_momentum=0.9,
    dir_name='./cache_dir',
    bsz=16,
    ssm_size_base=256,
    p_dropout=0.0,
    mode='pool',
    lr_factor=1,
    ssm_lr_base=1e-3,
    weight_decay=0.05,
    opt_config='standard',
    dt_global=False,
    dataset='mnist-classification',
    epochs=1,
    warmup_end=1,
    lr_min=0,
)

let's try to solve a simpler problem, just give me the hidden state at layer i at time t for a single batch.

In [108]:
def hidden_gradients(state,
               rng,
               batch_inputs,
               batch_integration_timesteps,
               model,
               batchnorm,
               ):
    """Computes gradients of hidden states with respect to parameters and inputs"""
    
    def hidden_fn(inputs, params):
        """Returns only the hidden states from the model"""
        if batchnorm:
            _, hiddens = model.apply(
                {"params": params, "batch_stats": state.batch_stats},
                inputs, batch_integration_timesteps,
                rngs={"dropout": rng},
                mutable=False  # We don't need to track batch stats anymore
            )
            print(hiddens)
        else:
            _, hiddens = model.apply(
                {"params": params},
                inputs, batch_integration_timesteps,
                rngs={"dropout": rng},
                mutable=False,
            )
            # print(hiddens.shape)
            # print(np.abs(hiddens).shape)
        return np.abs(hiddens)
    def hidden_at_time(inputs, params, time, dimension):
        """Returns the hidden state at a specific time and dimension"""
        hiddens = hidden_fn(inputs, params) # shape (bsz, time, hidden_dim)
        # print(np.mean(hiddens[:, time, dimension]).shape)
        return np.mean(hiddens[:, time, dimension])
    print(batch_inputs.shape)
    gradient = jax.grad(hidden_at_time, argnums=0)(batch_inputs, state.params, 784//2, 0)
    print(f"Gradient: {gradient.squeeze().shape}")
    print(gradient.squeeze())
    # print(batch_inputs.shape)
    # # Compute gradients with respect to parameters
    # param_grads = jax.grad(lambda p: hidden_fn(p, batch_inputs).sum())(state.params)
    # 
    # # Compute gradients with respect to inputs
    # input_grads = jax.grad(lambda x: hidden_fn(state.params, x).sum())(batch_inputs)
    # 
    # # Get the hidden states
    # hiddens = hidden_fn(state.params, batch_inputs)
    # 
    # return hiddens, param_grads, input_grads

In [109]:
@partial(jax.jit, static_argnums=(5, 6))
def hidden_gradient_step(state,
               rng,
               batch_inputs,
               batch_labels,
               batch_integration_timesteps,
               model,
               batchnorm,
               ):
    """Performs a single training step given a batch of data"""
    def loss_fn(params):

        if batchnorm:
            (logits, hiddens), mod_vars = model.apply(
                {"params": params, "batch_stats": state.batch_stats},
                batch_inputs, batch_integration_timesteps,
                rngs={"dropout": rng},
                mutable=["intermediates", "batch_stats"],
                capture_intermediates=True
            )
        else:
            (logits, hiddens), mod_vars = model.apply(
                {"params": params},
                batch_inputs, batch_integration_timesteps,
                rngs={"dropout": rng},
                mutable=["intermediates"],
            )
            print("got here")
            print(mod_vars)

        loss = np.mean(cross_entropy_loss(logits, batch_labels))

        return loss, (mod_vars, logits)

    (loss, (mod_vars, logits)), grads = jax.value_and_grad(loss_fn, has_aux=True)(state.params)
    if batchnorm:
        state = state.apply_gradients(grads=grads, batch_stats=mod_vars["batch_stats"])
    else:
        state = state.apply_gradients(grads=grads)
    return state, loss

def train_epoch(state, rng, model, trainloader, seq_len, in_dim, batchnorm, lr_params):
    """
    Training function for an epoch that loops over batches.
    """
    # Store Metrics
    model = model(training=True)
    batch_losses = []

    decay_function, ssm_lr, lr, step, end_step, opt_config, lr_min = lr_params

    for batch_idx, batch in enumerate(itertools.islice(tqdm(trainloader), 10)):
        inputs, labels, integration_times = prep_batch(batch, seq_len, in_dim)
        rng, drop_rng = jax.random.split(rng)
        hiddens, param_grads, input_grads = hidden_gradients(
            state=state,
            rng=rng,
            batch_inputs=inputs,
            batch_integration_timesteps=integration_times,
            model=model,
            batchnorm=batchnorm
        )
        # print("THIS SEEMS TO WORK SOMEHOW")
        # state, loss = hidden_gradient_step(
        #     state,
        #     drop_rng,
        #     inputs,
        #     labels,
        #     integration_times,
        #     model,
        #     batchnorm,
        # )
        batch_losses.append(loss)
        lr_params = (decay_function, ssm_lr, lr, step, end_step, opt_config, lr_min)
        state, step = update_learning_rate_per_step(lr_params, state)

    # Return average loss over batches
    return state, np.mean(np.array(batch_losses)), step

def train(args):
    """
    Main function to train over a certain number of epochs
    """

    best_test_loss = 100000000
    best_test_acc = -10000.0

    ssm_size = args.ssm_size_base
    ssm_lr = args.ssm_lr_base

    # determine the size of initial blocks
    block_size = int(ssm_size / args.blocks)

    # Set global learning rate lr (e.g. encoders, etc.) as function of ssm_lr
    lr = args.lr_factor * ssm_lr

    # Set randomness...
    print("[*] Setting Randomness...")
    key = random.PRNGKey(args.jax_seed)
    init_rng, train_rng = random.split(key, num=2)

    # Get dataset creation function
    create_dataset_fn = Datasets[args.dataset]

    # Dataset dependent logic
    if args.dataset in ["imdb-classification", "listops-classification", "aan-classification"]:
        padded = True
        if args.dataset in ["aan-classification"]:
            # Use retreival model for document matching
            retrieval = True
            print("Using retrieval model for document matching")
        else:
            retrieval = False

    else:
        padded = False
        retrieval = False

    # For speech dataset
    if args.dataset in ["speech35-classification"]:
        speech = True
        print("Will evaluate on both resolutions for speech task")
    else:
        speech = False

    # Create dataset...
    init_rng, key = random.split(init_rng, num=2)
    trainloader, valloader, testloader, aux_dataloaders, n_classes, seq_len, in_dim, train_size = \
        create_dataset_fn(args.dir_name, seed=args.jax_seed, bsz=args.bsz)

    print(f"[*] Starting S5 Training on `{args.dataset}` =>> Initializing...")

    # Initialize state matrix A using approximation to HiPPO-LegS matrix
    Lambda, _, B, V, B_orig = make_DPLR_HiPPO(block_size)

    if args.conj_sym:
        block_size = block_size // 2
        ssm_size = ssm_size // 2

    Lambda = Lambda[:block_size]
    V = V[:, :block_size]
    Vc = V.conj().T

    # If initializing state matrix A as block-diagonal, put HiPPO approximation
    # on each block
    Lambda = (Lambda * np.ones((args.blocks, block_size))).ravel()
    V = block_diag(*([V] * args.blocks))
    Vinv = block_diag(*([Vc] * args.blocks))

    print("Lambda.shape={}".format(Lambda.shape))
    print("V.shape={}".format(V.shape))
    print("Vinv.shape={}".format(Vinv.shape))

    ssm_init_fn = init_S5SSM(H=args.d_model,
                             P=ssm_size,
                             Lambda_re_init=Lambda.real,
                             Lambda_im_init=Lambda.imag,
                             V=V,
                             Vinv=Vinv,
                             C_init=args.C_init,
                             discretization=args.discretization,
                             dt_min=args.dt_min,
                             dt_max=args.dt_max,
                             conj_sym=args.conj_sym,
                             clip_eigs=args.clip_eigs,
                             bidirectional=args.bidirectional)

    if retrieval:
        # Use retrieval head for AAN task
        print("Using Retrieval head for {} task".format(args.dataset))
        model_cls = partial(
            RetrievalModel,
            ssm=ssm_init_fn,
            d_output=n_classes,
            d_model=args.d_model,
            n_layers=args.n_layers,
            padded=padded,
            activation=args.activation_fn,
            dropout=args.p_dropout,
            prenorm=args.prenorm,
            batchnorm=args.batchnorm,
            bn_momentum=args.bn_momentum,
        )

    else:
        model_cls = partial(
            BatchClassificationModel,
            ssm=ssm_init_fn,
            d_output=n_classes,
            d_model=args.d_model,
            n_layers=args.n_layers,
            padded=padded,
            activation=args.activation_fn,
            dropout=args.p_dropout,
            mode=args.mode,
            prenorm=args.prenorm,
            batchnorm=args.batchnorm,
            bn_momentum=args.bn_momentum,
        )

    # initialize training state
    state = create_train_state(model_cls,
                               init_rng,
                               padded,
                               retrieval,
                               in_dim=in_dim,
                               bsz=args.bsz,
                               seq_len=seq_len,
                               weight_decay=args.weight_decay,
                               batchnorm=args.batchnorm,
                               opt_config=args.opt_config,
                               ssm_lr=ssm_lr,
                               lr=lr,
                               dt_global=args.dt_global)
    # print(state.)
    # Training Loop over epochs
    step = 0  # for per step learning rate decay
    steps_per_epoch = int(train_size / args.bsz)
    for epoch in range(args.epochs):
        print(f"[*] Starting Training Epoch {epoch + 1}...")

        if epoch < args.warmup_end:
            print("using linear warmup for epoch {}".format(epoch + 1))
            decay_function = linear_warmup
            end_step = steps_per_epoch * args.warmup_end

        elif args.cosine_anneal:
            print("using cosine annealing for epoch {}".format(epoch + 1))
            decay_function = cosine_annealing
            # for per step learning rate decay
            end_step = steps_per_epoch * args.epochs - (steps_per_epoch * args.warmup_end)
        else:
            print("using constant lr for epoch {}".format(epoch + 1))
            decay_function = constant_lr
            end_step = None

        # TODO: Switch to letting Optax handle this.
        #  Passing this around to manually handle per step learning rate decay.
        lr_params = (decay_function, ssm_lr, lr, step, end_step, args.opt_config, args.lr_min)

        train_rng, skey = random.split(train_rng)
        state, train_loss, step = train_epoch(state,
                                              skey,
                                              model_cls,
                                              trainloader,
                                              seq_len,
                                              in_dim,
                                              args.batchnorm,
                                              lr_params)

In [110]:
train(args)

[*] Setting Randomness...
[*] Generating MNIST Classification Dataset
[*] Starting S5 Training on `mnist-classification` =>> Initializing...
Lambda.shape=(128,)
V.shape=(256, 128)
Vinv.shape=(128, 256)
configuring standard optimization setup
[*] Trainable Parameters: 34122
[*] Starting Training Epoch 1...
using linear warmup for epoch 1


  0%|          | 0/3375 [00:00<?, ?it/s]

(16, 784, 1)


  0%|          | 0/3375 [00:00<?, ?it/s]

Gradient: (16, 784)
[[-1.41488810e-04  1.06206964e-04  8.66978808e-05 ...  0.00000000e+00
   0.00000000e+00  0.00000000e+00]
 [-4.08149390e-05  1.59114177e-04 -4.33817440e-05 ...  0.00000000e+00
   0.00000000e+00  0.00000000e+00]
 [ 1.47605155e-04 -9.43694104e-05 -9.91910783e-05 ...  0.00000000e+00
   0.00000000e+00  0.00000000e+00]
 ...
 [-5.14317217e-05  1.58796785e-04 -3.24650755e-05 ...  0.00000000e+00
   0.00000000e+00  0.00000000e+00]
 [ 1.57166272e-04 -2.47842181e-05 -1.45911908e-04 ...  0.00000000e+00
   0.00000000e+00  0.00000000e+00]
 [-1.56634182e-04  2.05286233e-05  1.47638697e-04 ...  0.00000000e+00
   0.00000000e+00  0.00000000e+00]]





TypeError: cannot unpack non-iterable NoneType object

In [5]:

# best_test_loss = 100000000
# best_test_acc = -10000.0
# 
# ssm_size = args.ssm_size_base
# ssm_lr = args.ssm_lr_base
# 
# # determine the size of initial blocks
# block_size = int(ssm_size / args.blocks)
# 
# # Set global learning rate lr (e.g. encoders, etc.) as function of ssm_lr
# lr = args.lr_factor * ssm_lr
# 
# # Set randomness...
# print("[*] Setting Randomness...")
# key = random.PRNGKey(args.jax_seed)
# init_rng, train_rng = random.split(key, num=2)
# 
# # Get dataset creation function
# create_dataset_fn = Datasets[args.dataset]
# 
# # Dataset dependent logic
# if args.dataset in ["imdb-classification", "listops-classification", "aan-classification"]:
#     padded = True
#     if args.dataset in ["aan-classification"]:
#         # Use retreival model for document matching
#         retrieval = True
#         print("Using retrieval model for document matching")
#     else:
#         retrieval = False
# 
# else:
#     padded = False
#     retrieval = False
# 
# # For speech dataset
# if args.dataset in ["speech35-classification"]:
#     speech = True
#     print("Will evaluate on both resolutions for speech task")
# else:
#     speech = False
# 
# # Create dataset...
# init_rng, key = random.split(init_rng, num=2)
# trainloader, valloader, testloader, aux_dataloaders, n_classes, seq_len, in_dim, train_size = \
#     create_dataset_fn(args.dir_name, seed=args.jax_seed, bsz=args.bsz)
# 
# print(f"[*] Starting S5 Training on `{args.dataset}` =>> Initializing...")
# 
# # Initialize state matrix A using approximation to HiPPO-LegS matrix
# Lambda, _, B, V, B_orig = make_DPLR_HiPPO(block_size)
# 
# if args.conj_sym:
#     block_size = block_size // 2
#     ssm_size = ssm_size // 2
# 
# Lambda = Lambda[:block_size]
# V = V[:, :block_size]
# Vc = V.conj().T
# 
# # If initializing state matrix A as block-diagonal, put HiPPO approximation
# # on each block
# Lambda = (Lambda * np.ones((args.blocks, block_size))).ravel()
# V = block_diag(*([V] * args.blocks))
# Vinv = block_diag(*([Vc] * args.blocks))
# 
# print("Lambda.shape={}".format(Lambda.shape))
# print("V.shape={}".format(V.shape))
# print("Vinv.shape={}".format(Vinv.shape))
# 
# ssm_init_fn = init_S5SSM(H=args.d_model,
#                          P=ssm_size,
#                          Lambda_re_init=Lambda.real,
#                          Lambda_im_init=Lambda.imag,
#                          V=V,
#                          Vinv=Vinv,
#                          C_init=args.C_init,
#                          discretization=args.discretization,
#                          dt_min=args.dt_min,
#                          dt_max=args.dt_max,
#                          conj_sym=args.conj_sym,
#                          clip_eigs=args.clip_eigs,
#                          bidirectional=args.bidirectional)
# 
# if retrieval:
#     # Use retrieval head for AAN task
#     print("Using Retrieval head for {} task".format(args.dataset))
#     model_cls = partial(
#         RetrievalModel,
#         ssm=ssm_init_fn,
#         d_output=n_classes,
#         d_model=args.d_model,
#         n_layers=args.n_layers,
#         padded=padded,
#         activation=args.activation_fn,
#         dropout=args.p_dropout,
#         prenorm=args.prenorm,
#         batchnorm=args.batchnorm,
#         bn_momentum=args.bn_momentum,
#     )
# 
# else:
#     model_cls = partial(
#         BatchClassificationModel,
#         ssm=ssm_init_fn,
#         d_output=n_classes,
#         d_model=args.d_model,
#         n_layers=args.n_layers,
#         padded=padded,
#         activation=args.activation_fn,
#         dropout=args.p_dropout,
#         mode=args.mode,
#         prenorm=args.prenorm,
#         batchnorm=args.batchnorm,
#         bn_momentum=args.bn_momentum,
#     )
# 
# # initialize training state
# state = create_train_state(model_cls,
#                            init_rng,
#                            padded,
#                            retrieval,
#                            in_dim=in_dim,
#                            bsz=args.bsz,
#                            seq_len=seq_len,
#                            weight_decay=args.weight_decay,
#                            batchnorm=args.batchnorm,
#                            opt_config=args.opt_config,
#                            ssm_lr=ssm_lr,
#                            lr=lr,
#                            dt_global=args.dt_global)
# 
# model = model_cls(training=True)
# batch_losses = []
# train_rng, skey = random.split(train_rng)
# lr_params = (decay_function, ssm_lr, lr, step, end_step, args.opt_config, args.lr_min)
# decay_function, ssm_lr, lr, step, end_step, opt_config, lr_min = lr_params
# 
# for batch_idx, batch in enumerate(itertools.islice(tqdm(trainloader), 10)):
#     inputs, labels, integration_times = prep_batch(batch, seq_len, in_dim)
#     
#     rng, drop_rng = jax.random.split(skey)
#     state, loss = train_step(
#         state,
#         drop_rng,
#         inputs,
#         labels,
#         integration_times,
#         model,
#         args.batchnorm,
#     )
#     # time_indices = np.array([0, seq_len//2, seq_len-1]) # this will be the array of the indices within the mnist image (len 784), we will monitor the gradient at the start, beginning, and end
#     # grad_value = gradient_monitoring_hook(
#     #         state=state,
#     #         batch={'inputs': inputs, 'labels': labels},
#     #         time_indices=time_indices,
#     #         state_idx=0  # Monitor first hidden state dimension
#     #     )
#     batch_losses.append(loss)
#     lr_params = (decay_function, ssm_lr, lr, step, end_step, opt_config, lr_min)
#     state, step = update_learning_rate_per_step(lr_params, state)

In [4]:
def explore_dict_structure_as_directory(d, path=""):
    if isinstance(d, dict):
        for key, value in d.items():
            new_path = f"{path}/{key}" if path else key
            print(new_path)
            explore_dict_structure_as_directory(value, new_path)
    else:
        print(f"{path} -> {type(d).__name__}")

# Assuming `state.params` is the dictionary:
explore_dict_structure_as_directory(state.params)

NameError: name 'state' is not defined

In [None]:
state.model

In [5]:
args.batchnorm

True

In [7]:
train(args)

[*] Setting Randomness...
[*] Generating MNIST Classification Dataset


  Referenced from: <9DBE5D5C-AC87-30CA-96DA-F5BC116EDA2B> /Users/jakub/miniconda3/envs/jax-env/lib/python3.11/site-packages/torchvision/image.so
  warn(


[*] Starting S5 Training on `mnist-classification` =>> Initializing...
Lambda.shape=(128,)
V.shape=(256, 128)
Vinv.shape=(128, 256)
configuring standard optimization setup
[*] Trainable Parameters: 34122
{'decoder': {'bias': Array([0., 0., 0., 0., 0., 0., 0., 0., 0., 0.], dtype=float32), 'kernel': Array([[ 2.81169564e-01, -2.19013482e-01, -6.91609159e-02,
        -1.44792110e-01, -9.88199711e-02,  1.05479829e-01,
         1.18012309e-01, -6.14414811e-02,  1.42612875e-01,
         3.11395880e-02],
       [-1.67155378e-02,  1.40935361e-01, -1.52113259e-01,
        -4.82966378e-02, -2.34872401e-01, -2.13007480e-02,
        -2.57653236e-01, -1.33985296e-01,  1.71239257e-01,
        -9.71533731e-02],
       [-2.81250775e-01,  5.37459664e-02,  7.20960051e-02,
         1.14900894e-01,  1.17362468e-02, -1.00356825e-01,
         2.52808668e-02, -1.68488190e-01, -9.68281925e-03,
        -7.25319237e-02],
       [-1.78734630e-01,  1.10142782e-01, -1.11236036e-01,
        -1.16073810e-01,  1.91563

  0%|          | 0/3375 [00:00<?, ?it/s]


KeyError: 'intermediates'