# Tutorial: Training a model
In this tutorial, we will train an event-based state-space model on a reduced version of the [Spiking Heidelberg Digits](https://zenkelab.org/resources/spiking-heidelberg-datasets-shd/) dataset.
For training on larger datasets or multiple GPUs, we recommend using the training script `run_training.py` instead.

## Data loading
The SHD dataset contains 20 classes: the spoken digits 0 to 9 in both German and English.
We will use a reduced version of the dataset containing only two of these classes (labels 0 and 1), so that the model can be trained to non-trivial performance in reasonable time, even on a CPU.

[Download the training and test dataset](https://zenkelab.org/datasets/) and fill the gaps in the dataset class below.

In [1]:
from torch.utils.data import Dataset, DataLoader, random_split
import h5py
import numpy as np

class SpikingHeidelbergDigits(Dataset):
    def __init__(self, path_to_file):
        self.num_classes = 2
        self.num_channels = 700
        self.path_to_file = path_to_file
        
        # load the dataset
        with h5py.File(path_to_file, 'r') as f:
            self.channels = f['spikes']['units'][:]
            self.timesteps = f['spikes']['times'][:]
            self.labels = f['labels'][:]
        
        # filter the dataset to contain only two classes
        mask = (self.labels == 0) | (self.labels == 1)
        self.channels = self.channels[mask]
        self.timesteps = self.timesteps[mask]
        self.labels = self.labels[mask]
        
    def __len__(self):
        return len(self.labels)
    
    def __getitem__(self, idx):
        # create tonic-like structured arrays
        dtype = np.dtype([("t", int), ("x", int), ("p", int)])
        struct_arr = np.empty_like(self.channels[idx], dtype=dtype)
        
        # convert spike times from seconds to microseconds
        timesteps = self.timesteps[idx] * 1e6
        
        struct_arr['t'] = timesteps
        struct_arr['x'] = self.channels[idx]
        struct_arr['p'] = 1
        
        # one-hot encoding of labels (required for CutMix augmentation)
        label = np.eye(self.num_classes)[self.labels[idx]].astype(np.int32)
            
        return struct_arr, label
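
To make the return format of `__getitem__` concrete, here is a minimal standalone sketch that builds the same kind of structured array and one-hot label from synthetic data. The spike times and channel indices are made up for illustration and do not come from the SHD files.

```python
import numpy as np

# Synthetic stand-in for one sample: spike times in seconds and channel indices
times_s = np.array([0.125, 0.25, 0.5])
units = np.array([12, 650, 3])

# Tonic-like structured array with fields t (time), x (channel), p (polarity)
dtype = np.dtype([("t", int), ("x", int), ("p", int)])
events = np.empty(len(units), dtype=dtype)
events["t"] = times_s * 1e6  # seconds -> microseconds, cast to int on assignment
events["x"] = units
events["p"] = 1              # SHD has no polarity, so use a constant

# One-hot label for class 1 of a two-class problem
label = np.eye(2)[1].astype(np.int32)

print(events)
print(label)
```

Note that assigning floating-point times to an integer field truncates any fractional microseconds.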

In [2]:
# Load the training and test dataset
#train_dataset = SpikingHeidelbergDigits('path/to/training/dataset')
#test_dataset = SpikingHeidelbergDigits('path/to/test/dataset')
train_dataset = SpikingHeidelbergDigits('../../../Datasets/SHD/shd_train.h5')
test_dataset = SpikingHeidelbergDigits('../../../Datasets/SHD/shd_test.h5')

Print the length of each dataset to verify that the data was loaded successfully.

In [3]:
print(f"Number of training samples: {len(train_dataset)}")
print(f"Number of test samples: {len(test_dataset)}")

Number of training samples: 807
Number of test samples: 211


Now, create a validation set by randomly splitting the training dataset, and create data loaders for training, validation, and test datasets.

In [4]:
# Split the training dataset into training and validation
num_train = int(0.8 * len(train_dataset))
num_val = len(train_dataset) - num_train
train_dataset, val_dataset = random_split(train_dataset, [num_train, num_val])

# Create data loaders
from event_ssm.dataloading import event_stream_collate_fn
from functools import partial

collate_fn = partial(event_stream_collate_fn, resolution=(700,), pad_unit=8192)
train_loader = DataLoader(train_dataset, batch_size=32, shuffle=True, drop_last=True, collate_fn=collate_fn)
val_loader = DataLoader(val_dataset, batch_size=32, shuffle=False, drop_last=True, collate_fn=collate_fn)
test_loader = DataLoader(test_dataset, batch_size=32, shuffle=False, collate_fn=collate_fn)
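
Event streams have different lengths, so a collate function has to pad them to a common length before batching. The sketch below illustrates one way to do this, assuming that `pad_unit` means "pad every batch to a multiple of this many events". The helper `pad_batch` is hypothetical and is not the implementation of `event_stream_collate_fn`.

```python
import numpy as np

def pad_batch(samples, pad_unit):
    """Hypothetical helper: pad event streams to a shared length that is a multiple of pad_unit."""
    lengths = np.array([len(s) for s in samples], dtype=np.int32)
    # Round the longest sequence up to the next multiple of pad_unit
    padded_len = -(-int(lengths.max()) // pad_unit) * pad_unit
    tokens = np.zeros((len(samples), padded_len), dtype=np.int64)
    times = np.zeros((len(samples), padded_len), dtype=np.int64)
    for i, s in enumerate(samples):
        tokens[i, : len(s)] = s["x"]
        times[i, : len(s)] = s["t"]
    return tokens, times, lengths

# Two synthetic event streams with 3 and 10 events
dtype = np.dtype([("t", int), ("x", int), ("p", int)])
samples = [np.zeros(3, dtype=dtype), np.zeros(10, dtype=dtype)]
tokens, times, lengths = pad_batch(samples, pad_unit=8)
print(tokens.shape, lengths)
```

Padding to a small set of fixed lengths keeps the number of distinct input shapes low, which matters for JIT-compiled code because each new shape triggers a new compilation.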

## Model definition
We use the [Hydra](https://hydra.cc/docs/intro/) package for configuration management. The model configuration is defined in config files in the `configs` directory; the cell below composes the `base` config with the `tutorial` task override.

In [10]:
from hydra import compose, initialize
from omegaconf import OmegaConf, open_dict

with initialize(version_base=None, config_path="configs", job_name="training tutorial"):
    cfg = compose(config_name="base", overrides=["task=tutorial"])

with open_dict(cfg):    
    # optax updates the schedule every iteration and not every epoch
    cfg.optimizer.total_steps = cfg.training.num_epochs * len(train_loader) // cfg.optimizer.accumulation_steps
    cfg.optimizer.warmup_steps = cfg.optimizer.warmup_epochs * len(train_loader) // cfg.optimizer.accumulation_steps
    
    # scale learning rate by batch size
    cfg.optimizer.ssm_lr = cfg.optimizer.ssm_base_lr * cfg.training.per_device_batch_size * cfg.optimizer.accumulation_steps

print(OmegaConf.to_yaml(cfg))

seed: 1234
data_dir: ./data
output_dir: ./outputs/${now:%Y-%m-%d-%H-%M-%S}
model:
  ssm_init:
    C_init: lecun_normal
    dt_min: 0.005
    dt_max: 0.1
    conj_sym: true
    clip_eigs: false
  ssm:
    discretization: async
    d_model: 16
    d_ssm: 16
    ssm_block_size: 8
    num_stages: 1
    num_layers_per_stage: 6
    dropout: 0.1
    classification_mode: timepool
    prenorm: true
    batchnorm: false
    bn_momentum: 0.95
    pooling_stride: 32
    pooling_mode: timepool
    state_expansion_factor: 1
task:
  name: shd-classification
training:
  num_epochs: 10
  per_device_batch_size: 16
  per_device_eval_batch_size: 16
  num_workers: 4
  time_jitter: 1
  spatial_jitter: 0.55
  noise: 35
  max_drop_chunk: 0.02
  drop_event: 0.1
  time_skew: 1.2
  cut_mix: 0.3
  pad_unit: 8192
  validate_on_test: false
optimizer:
  ssm_base_lr: 1.0e-05
  lr_factor: 10
  warmup_epochs: 1
  ssm_weight_decay: 0.0
  weight_decay: 0.01
  schedule: cosine
  accumulation_steps: 1
  total_steps: 200
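
The printed `total_steps: 200` can be checked by hand. The sketch below redoes the arithmetic in plain Python with the numbers used in this tutorial (807 training samples, an 80/20 split, batch size 32 with `drop_last=True`); the variable names are chosen for illustration only.

```python
num_samples = 807
num_train = int(0.8 * num_samples)           # 645 samples for training
num_val = num_samples - num_train            # 162 samples for validation
batch_size = 32
batches_per_epoch = num_train // batch_size  # 20, since drop_last=True drops the partial batch

num_epochs, warmup_epochs, accumulation_steps = 10, 1, 1
total_steps = num_epochs * batches_per_epoch // accumulation_steps
warmup_steps = warmup_epochs * batches_per_epoch // accumulation_steps

# Learning rate scaled by the effective per-device batch size
ssm_base_lr, per_device_batch_size = 1e-5, 16
ssm_lr = ssm_base_lr * per_device_batch_size * accumulation_steps

print(total_steps, warmup_steps, ssm_lr)
```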
  

Now, create the model using the configuration defined above.

In [11]:
from event_ssm.ssm import init_S5SSM
from event_ssm.seq_model import BatchClassificationModel

ssm_init_fn = init_S5SSM(**cfg.model.ssm_init)
model = BatchClassificationModel(
    ssm=ssm_init_fn,
    num_classes=test_dataset.num_classes,
    num_embeddings=test_dataset.num_channels,
    **cfg.model.ssm,
)


Initialize the training state by feeding a dummy input batch through the model.

In [12]:
import jax
from event_ssm.train_utils import init_model_state

# pick the first batch from the training loader
batch = next(iter(train_loader))
inputs, targets, timesteps, lengths = batch

# initialize the training state
key = jax.random.PRNGKey(cfg.seed)
state = init_model_state(key, model, inputs, timesteps, lengths, cfg.optimizer)

SSM: 16 -> 16 -> 16 (stride 32 with pooling mode timepool)
SSM: 16 -> 16 -> 16
SSM: 16 -> 16 -> 16
SSM: 16 -> 16 -> 16
SSM: 16 -> 16 -> 16
SSM: 16 -> 16 -> 16
[*] Model parameter count: 16370
[*] Using gradient accumulation with 1 steps


## Train the model
For training, we implemented a trainer module that hides the boilerplate behind a simple interface: it loops through the data loader, computes the loss, and updates the model parameters. The loop calls a training-step function and an evaluation-step function on the model. Both are already implemented in `event_ssm.train_utils` as `training_step` and `evaluation_step`, so we can import them here.

In [13]:
from event_ssm.train_utils import training_step, evaluation_step
from event_ssm.trainer import TrainerModule

# just-in-time compile the training and evaluation functions
train_step = jax.jit(training_step)
eval_step = jax.jit(evaluation_step)

# initialize the trainer module
num_devices = 1
trainer = TrainerModule(
    train_state=state,
    training_step_fn=train_step,
    evaluation_step_fn=eval_step,
    world_size=num_devices,
    config=cfg,
)

[*] Logging to ./outputs/2024-05-21-12-04-36
[*] Number of model parameters: 16370


We are now ready to start the training loop. 

**Note:** JAX compiles your program just-in-time (JIT) to optimize performance. This means that the first iteration of the training loop will be slower than the following ones.  
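
The effect is easy to observe on a toy function. The following sketch is unrelated to the model in this tutorial; it times the first and second call of a small jitted function, and the exact timings will depend on your machine.

```python
import time
import jax
import jax.numpy as jnp

@jax.jit
def step(x):
    # Arbitrary toy computation, only here to give the compiler some work
    return jnp.tanh(x @ x.T).sum()

x = jnp.ones((256, 256))

t0 = time.perf_counter()
first = step(x).block_until_ready()   # includes tracing and compilation
t1 = time.perf_counter()
second = step(x).block_until_ready()  # reuses the compiled function
t2 = time.perf_counter()

print(f"first call:  {t1 - t0:.4f}s")
print(f"second call: {t2 - t1:.4f}s")
```

Calling `block_until_ready()` matters here because JAX dispatches work asynchronously; without it, the timer would stop before the computation has finished.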

In [14]:
# generate random key for dropout
key, dropout_key = jax.random.split(key)

# train the model
trainer.train_model(
    train_loader=train_loader,
    val_loader=val_loader,
    test_loader=test_loader,
    dropout_key=dropout_key
)

SSM: 16 -> 16 -> 16 (stride 32 with pooling mode timepool)
SSM: 16 -> 16 -> 16
SSM: 16 -> 16 -> 16
SSM: 16 -> 16 -> 16
SSM: 16 -> 16 -> 16
SSM: 16 -> 16 -> 16
-----------------------------------------------------------------------------------------
| end of epoch   1 | time per epoch: 79.31s |
| Train Metrics | accuracy:  0.51 | loss:  0.76
SSM: 16 -> 16 -> 16 (stride 32 with pooling mode timepool)
SSM: 16 -> 16 -> 16
SSM: 16 -> 16 -> 16
SSM: 16 -> 16 -> 16
SSM: 16 -> 16 -> 16
SSM: 16 -> 16 -> 16
| Eval  Metrics | accuracy:  0.57 | loss:  0.68
-----------------------------------------------------------------------------------------
-----------------------------------------------------------------------------------------
| end of epoch   2 | time per epoch: 21.31s |
| Train Metrics | accuracy:  0.60 | loss:  0.67
| Eval  Metrics | accuracy:  0.62 | loss:  0.63
-----------------------------------------------------------------------------------------
-----------------------------------------------------------------------------------------
| end of epoch   3 | time per epoch: 21.08s |
| Train Metrics | accuracy:  0.69 | loss:  0.62
| Eval  Metrics | accuracy:  0.77 | loss:  0.57
-----------------------------------------------------------------------------------------
-----------------------------------------------------------------------------------------
| end of epoch   4 | time per epoch: 21.59s |
| Train Metrics | accuracy:  0.75 | loss:  0.56
| Eval  Metrics | accuracy:  0.87 | loss:  0.50
-----------------------------------------------------------------------------------------
-----------------------------------------------------------------------------------------
| end of epoch   5 | time per epoch: 20.29s |
| Train Metrics | accuracy:  0.83 | loss:  0.50
| Eval  Metrics | accuracy:  0.86 | loss:  0.42
-----------------------------------------------------------------------------------------
-----------------------------------------------------------------------------------------
| end of epoch   6 | time per epoch: 20.47s |
| Train Metrics | accuracy:  0.82 | loss:  0.44
| Eval  Metrics | accuracy:  0.85 | loss:  0.39
-----------------------------------------------------------------------------------------
-----------------------------------------------------------------------------------------
| end of epoch   7 | time per epoch: 22.43s |
| Train Metrics | accuracy:  0.91 | loss:  0.35
| Eval  Metrics | accuracy:  0.93 | loss:  0.30
-----------------------------------------------------------------------------------------
-----------------------------------------------------------------------------------------
| end of epoch   8 | time per epoch: 20.32s |
| Train Metrics | accuracy:  0.89 | loss:  0.34
| Eval  Metrics | accuracy:  0.94 | loss:  0.27
-----------------------------------------------------------------------------------------
-----------------------------------------------------------------------------------------
| end of epoch   9 | time per epoch: 21.34s |
| Train Metrics | accuracy:  0.92 | loss:  0.29
| Eval  Metrics | accuracy:  0.94 | loss:  0.26
-----------------------------------------------------------------------------------------
-----------------------------------------------------------------------------------------
| end of epoch  10 | time per epoch: 23.07s |
| Train Metrics | accuracy:  0.94 | loss:  0.27
| Eval  Metrics | accuracy:  0.94 | loss:  0.26
-----------------------------------------------------------------------------------------
SSM: 16 -> 16 -> 16 (stride 32 with pooling mode timepool)
SSM: 16 -> 16 -> 16
SSM: 16 -> 16 -> 16
SSM: 16 -> 16 -> 16
SSM: 16 -> 16 -> 16
SSM: 16 -> 16 -> 16
SSM: 16 -> 16 -> 16 (stride 32 with pooling mode timepool)
SSM: 16 -> 16 -> 16
SSM: 16 -> 16 -> 16
SSM: 16 -> 16 -> 16
SSM: 16 -> 16 -> 16
SSM: 16 -> 16 -> 16
-----------------------------------------------------------------------------------------
| End of Training |
| Test  Metrics |  accuracy:  0.96 |  loss:  0.22
-----------------------------------------------------------------------------------------


{'Performance/Test accuracy': 0.9642857313156128,
 'Performance/Test loss': 0.22379271686077118}

## Assignment
The function `apply_ssm` in `event_ssm/ssm.py` implements the recurrent operator with an associative scan. On highly parallel GPUs, this can speed up training on very long sequences.
On CPUs, however, the overhead of the scan operation can slow down training.
Your task is to implement a CPU-friendly version of the recurrent operator in `event_ssm/ssm.py` and compare its training time with that of the original implementation.
For this purpose, we suggest implementing a step-by-step recurrence with `jax.lax.scan` instead of `jax.lax.associative_scan`.
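
As a starting point, the sketch below compares both scan variants on a toy elementwise linear recurrence `x_k = a_k * x_{k-1} + b_k` with `x_{-1} = 0`. It is not the signature or the body of `apply_ssm`: the function names, shapes, and real-valued coefficients are assumptions made for illustration, and the actual operator in `event_ssm/ssm.py` may differ (for example, by using complex-valued state).

```python
import jax
import jax.numpy as jnp

def binary_operator(left, right):
    # Compose two affine maps x -> a * x + b, with `left` applied first
    a_l, b_l = left
    a_r, b_r = right
    return a_r * a_l, a_r * b_l + b_r

def recurrence_parallel(a, b):
    # Parallel prefix computation over the time axis (axis 0)
    _, xs = jax.lax.associative_scan(binary_operator, (a, b))
    return xs

def recurrence_sequential(a, b):
    # Step-by-step recurrence: the carry holds the previous state
    def step(x_prev, inputs):
        a_k, b_k = inputs
        x_k = a_k * x_prev + b_k
        return x_k, x_k
    _, xs = jax.lax.scan(step, jnp.zeros_like(b[0]), (a, b))
    return xs

# Toy coefficients: 64 time steps, state dimension 8
key_a, key_b = jax.random.split(jax.random.PRNGKey(0))
a = jax.random.uniform(key_a, (64, 8), minval=0.5, maxval=1.0)
b = jax.random.normal(key_b, (64, 8))

max_diff = float(jnp.max(jnp.abs(recurrence_parallel(a, b) - recurrence_sequential(a, b))))
print(f"max abs difference: {max_diff:.2e}")
```

Both functions compute the same states up to floating-point error. The sequential version performs one multiply-add per time step, while the associative scan does more total work in exchange for parallelism across time.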