In [None]:
# Environment setup, kept commented out: uncomment and run these commands to
# install the packages this notebook imports.
#!cd $HOME/speechBCI/NeuralDecoder && pip install --user -e .
#!cd $HOME/speechBCI/LanguageModelDecoder/runtime/server/x86 && python setup.py install
#!pip install causal-conv1d
#!cd $HOME/mamba && pip install --user -e .
#!cd $HOME/neural_seq_decoder && pip install --user -e .
#!pip install pytorch-lightning
#!pip install tensorboard

### Imports and Script vars

In [None]:
#%load_ext autoreload
#%autoreload 2

In [None]:
import torch
import pickle
import os
import time
import numpy as np
import pytorch_lightning as pl

from models.datasetLoaders import getDatasetLoaders
from models.mamba_phoneme import MambaPhoneme
from models.lightning_wrapper import LightningWrapper
from mamba_ssm.models.config_mamba import MambaConfig
from pytorch_lightning.callbacks import ModelCheckpoint, EarlyStopping
from pytorch_lightning.loggers import TensorBoardLogger

# Mamba hyperparameters.
# NOTE(review): `args` is not created in any cell visible in this notebook. It
# must already exist as a dict-like object (holding the keys read below and in
# later cells, e.g. "seed" and "batchSize") before this cell runs; otherwise the
# first assignment raises NameError on a fresh kernel. TODO: define `args`
# explicitly at the top of the notebook.
args['pppipeline'] = False  # later passed to LightningWrapper as willetts_preprocessing_pipeline
args['nLayers'] = 1         # later passed to MambaConfig as n_layer

# Settings handed to MambaConfig through its `ssm_cfg` argument.
ssm_cfg = {
        'd_state'   : 16,
        'd_conv'    : 8,
        'expand'    : 2,
        'dt_rank'   : "auto",
        'dt_min'    : 0.001,
        'dt_max'    : 0.1,
        'dt_init'   : "random",
        'dt_scale'  : 1.0,
        'dt_init_floor' : 1e-4,
        'conv_bias' : True,
        'bias'      : False,
        'use_fast_path' : True,  # Fused kernel options
        }

# Data paths; the DATA environment variable must be set (KeyError otherwise).
args['baseDir'] = os.environ['DATA'] + '/willett2023'
# NOTE: the misspelt name `datsetPath` is kept because the dataset-loading cell reads it.
datsetPath = args['baseDir'] + "/competitionData/pytorchTFRecords.pkl"

# Seed every RNG in one call: Lightning's seed_everything seeds Python's
# `random` module in addition to NumPy and PyTorch, which the previous
# torch.manual_seed / np.random.seed pair left unseeded.
pl.seed_everything(args["seed"])

### Load Datasets

In [None]:
# Build the train and test loaders from the pickled dataset at the configured
# batch size; the third return value is kept as `loadedData`.
trainLoader, testLoader, loadedData = getDatasetLoaders(datsetPath, args['batchSize'])

# nDays is the number of entries in the "train" part of the loaded data.
args['nDays'] = len(loadedData["train"])

### Initialize model

In [None]:
# Assemble the Mamba configuration from the script arguments first ...
mambaConfig = MambaConfig(
    d_model=args['nInputFeatures'],
    n_layer=args['nLayers'],
    vocab_size=args['nClasses'],
    ssm_cfg=ssm_cfg,
    fused_add_norm=False,
    residual_in_fp32=False,
    rms_norm=False,
)

# ... then build the phoneme model on the requested device in float32.
coreModel = MambaPhoneme(config=mambaConfig, device=args['device'], dtype=torch.float32)

In [None]:
# Model summary: name, parameter count, then the module's printed form.
print(coreModel.modelName)

# Only parameters with requires_grad=True are counted.
trainableParams = [p for p in coreModel.parameters() if p.requires_grad]
print('Number of parameters: ', sum(p.numel() for p in trainableParams))

print('\n--------------------\n')
print(coreModel)
print('\n--------------------\n')

### Train

In [None]:
# Output locations. The checkpoint folder name carries a Unix timestamp so each
# run of this cell writes to its own directory.
# NOTE(review): the checkpoint folder uses args['modelName'] while the logger
# below uses coreModel.modelName -- confirm the two are meant to agree.
timestamp = int(time.time())
outputPath = f"{args['baseDir']}/outputs"
logsPath = f"{outputPath}/logs"
checkpointPath = f"{outputPath}/checkpoints/{args['modelName']}_{timestamp}"

os.makedirs(outputPath, exist_ok=True)

# TensorBoard logger, grouped under the model's name.
logger = TensorBoardLogger(logsPath, name=coreModel.modelName)

# Both callbacks watch the same metric in the same direction.
monitorKwargs = dict(monitor='val_loss', mode='min')

# Keep the two best checkpoints by validation loss.
checkpoint_callback = ModelCheckpoint(
    dirpath=checkpointPath,
    filename='{epoch:02d}-{val_loss:.2f}-{avg_val_cer:.2f}',
    save_top_k=2,
    save_last=False,
    verbose=True,
    **monitorKwargs,
)

# Early stopping on validation loss, with patience=20.
early_stop_callback = EarlyStopping(
    min_delta=0.00,
    patience=20,
    verbose=False,
    **monitorKwargs,
)

callbacks = [checkpoint_callback, early_stop_callback]

# CTC loss with index 0 as the blank label; infinite losses are zeroed.
loss_ctc = torch.nn.CTCLoss(zero_infinity=True, reduction="mean", blank=0)

# NOTE(review): eps=0.1 is far larger than Adam's default -- confirm it is intentional.
optimizer = torch.optim.Adam(
    coreModel.parameters(),
    lr=args["lrStart"],
    weight_decay=args["l2_decay"],
    betas=(0.9, 0.999),
    eps=0.1,
)

# Linear learning-rate schedule from lrStart to lrEnd over nEpochs scheduler steps.
lrEndFactor = args["lrEnd"] / args["lrStart"]
scheduler = torch.optim.lr_scheduler.LinearLR(
    optimizer,
    start_factor=1.0,
    end_factor=lrEndFactor,
    total_iters=args["nEpochs"],
)

# Wrap the core model, loss, optimizer and scheduler for Lightning; the bare
# `model` expression at the end displays the wrapper as the cell output.
model = LightningWrapper(
    coreModel,
    loss_ctc,
    optimizer,
    args,
    scheduler,
    willetts_preprocessing_pipeline=args['pppipeline'],
)
model

In [None]:
# Lightning trainer: runs for at most nEpochs epochs, validates every epoch,
# logs every 100 steps, and keeps the progress bar off.
trainer = pl.Trainer(
    logger=logger,
    callbacks=callbacks,
    max_epochs=args["nEpochs"],
    check_val_every_n_epoch=1,
    log_every_n_steps=100,
    enable_progress_bar=False,
)

# Train on trainLoader; testLoader serves as the validation set.
trainer.fit(model, trainLoader, val_dataloaders=testLoader)