In [None]:
import ipywidgets as widgets
import os
import yaml
from datetime import datetime
from pathlib import Path
from ipyfilechooser import FileChooser
from models.mamba_phoneme import MambaPhoneme
from mamba_ssm.models.config_mamba import MambaConfig
import torch
import dotenv
import math
import torch
import pickle

from einops import repeat

dotenv.load_dotenv()

In [None]:
class Param:
    """One row of the experiment-configuration form.

    A ``Param`` couples a human-readable label with the key it is stored under
    in the args dict and with the ipywidgets control used to edit it.

    Args:
        name: Text shown in the label (or as the heading for ``'title'``).
        arg: Key used for this value in the args dict; ``''`` for entries
            that carry no value (section titles).
        widgetType: One of ``'title'``, ``'text'``, ``'dropdown'``,
            ``'checkbox'``, ``'radio'``, ``'int'``, ``'float'``, ``'file'``.
        values: Type-dependent extra data: the options for
            ``'dropdown'``/``'radio'``; ``[min[, max[, step]]]`` for
            ``'int'``/``'float'``; ``[default_filename]`` for ``'file'``.
            ``None`` (or omitted) means "no extra data".
        default_value: Initial value of the control (start path for
            ``'file'``).
    """

    def __init__(self, name, arg, widgetType, values=None, default_value=None):
        self.name = name
        self.arg = arg
        self.widgetType = widgetType
        # A ``[]`` default would be one list object shared by every instance
        # created without ``values``; build a fresh one instead. ``None`` is
        # normalised to an empty list so the ``len()`` checks below are safe.
        self.values = [] if values is None else values
        self.default_value = default_value

        self.label_width = '200px'
        self.widget = None

    def _bounds(self, default_step):
        """Return the min/max/step keyword arguments encoded in ``values``.

        Bounds that were not supplied are left out entirely, so the widget's
        own defaults apply instead of an explicit ``None``.
        """
        bounds = {'step': self.values[2] if len(self.values) > 2 else default_step}
        if len(self.values) > 0:
            bounds['min'] = self.values[0]
        if len(self.values) > 1:
            bounds['max'] = self.values[1]
        return bounds

    def create_widget(self):
        """Build the widget for this parameter on first call and cache it.

        Returns:
            For ``'title'`` a bare ``widgets.HTML`` heading; for every other
            type a ``widgets.HBox`` whose children are ``[label, control]``,
            so the editable control is ``self.widget.children[1]``.

        Raises:
            ValueError: If ``widgetType`` is not one of the supported types.
        """
        if self.widget is not None:
            return self.widget

        kind = self.widgetType

        if kind == 'title':
            # Titles have no label/control pair; cache the heading so that
            # re-rendering the form reuses it instead of creating a new one.
            self.widget = widgets.HTML(value=f"<b style='font-size:24px;'>{self.name}</b>")
            return self.widget

        if kind == 'text':
            control = widgets.Text(value=self.default_value)
        elif kind == 'dropdown':
            control = widgets.Dropdown(options=self.values, value=self.default_value)
        elif kind == 'checkbox':
            control = widgets.Checkbox(value=self.default_value)
        elif kind == 'radio':
            control = widgets.RadioButtons(options=self.values, value=self.default_value)
        elif kind == 'int':
            control = widgets.BoundedIntText(value=self.default_value, **self._bounds(1))
        elif kind == 'float':
            control = widgets.BoundedFloatText(value=self.default_value, **self._bounds(0.1))
        elif kind == 'file':
            control = FileChooser(self.default_value)
            if len(self.values) > 0:
                control.default_filename = self.values[0]
        else:
            raise ValueError('Unknown widget type: {}'.format(kind))

        label = widgets.HTML(value=self.name, layout=widgets.Layout(width=self.label_width))
        self.widget = widgets.HBox([label, control])
        return self.widget


### Replicate existing experiment

In [None]:
# Root of the willett2023 data tree, resolved from the DATA environment variable.
baseDir = f"{os.environ['DATA']}/willett2023"

# Optional: pick the saved YAML args of an earlier experiment; when a file is
# selected, the configuration form below is pre-filled from it.
experiment_replicate = FileChooser(f"{baseDir}/experiments/")
display(experiment_replicate)

### Configure experiment

In [None]:
# The model-type selector lives in its own variable because the form observes
# it and swaps the model-specific section whenever its value changes.
model_type = Param('Model Type', 'modelType', 'radio', ['mamba', 's4'], 's4')
general_params = [
    # Slurm hyperparams
    Param('Slurm', '', 'title'),

    #Param('Partition', 'partition', 'radio', ['interactive', 'short', 'medium'], 'short'),
    model_type,
    #Param('Time (D-h:m:s)', 'time', 'text', [], '0-04:00:00'),
    #Param('Memory (GB)', 'mem', 'int', [1, 128, 4], 16),
    Param('BaseDir', 'baseDir', 'file', [], os.environ['DATA'] + '/willett2023'),
    Param('Dataset', 'dataset', 'file', ['pytorchTFRecords.pkl'], os.environ['DATA'] + '/willett2023/competitionData'),

    # Training loop hyperparams
    Param('Training Loop', '', 'title'),

    Param('Epochs', 'nEpochs', 'int', [0, 10000, 100], 100),
    Param('Batch', 'batchSize', 'int', [1, 256], 64),
    Param('Random Seed', 'seed', 'int', [0, 1000000], 0),
    Param('Use Scheduler', 'useScheduler', 'checkbox', None, True),
    Param('Learning Rate Finder', 'lrFinder', 'checkbox', None, False),
    Param('Learning Rate Start/Max', 'lr', 'float', [0.0, 1.0, 0.01], 0.02),
    Param('Learning Rate End/Min', 'lrEnd', 'float', [0.0, 1.0, 0.001], 0.002),
    Param('Weight Decay', 'weightDecay', 'float', [0.0, 1e-3, 1e-6], 0.),
    Param('Device', 'device', 'radio', ['cuda', 'cpu'], 'cuda'),
    Param('Log wandb', 'log_wandb', 'checkbox', None, True),

    # I/O model dimensions
    Param('Model Dimensions', '', 'title'),

    Param('Number of Classes', 'nClasses', 'int', [1, 1000], 40),
    Param('Number of Input Features', 'nInputFeatures', 'int', [1, 10000], 256),
    Param('Number of Hidden Features', 'nHiddenFeatures', 'int', [1, 10000], 256),

    # Data augmentation params
    Param('Data Augmentation', '', 'title'),

    Param('White Noise SD', 'whiteNoiseSD', 'float', [0.0, 5.0, 0.1], 0.8),
    Param('Constant Offset SD', 'constantOffsetSD', 'float', [0.0, 5.0, 0.1], 0.2),
    Param('Gaussian Smooth Width', 'gaussianSmoothWidth', 'float', [0.0, 5.0, 0.1], 2.0),
]

s4_params = [
    Param('S4', '', 'title'),

    Param('Resume from Checkpoint', 'resume', 'file', [], os.environ['DATA'] + '/willett2023/experiments'),
    Param('Use Preproc. Pipeline', 'pppipeline', 'checkbox', None, False),
    Param('Number of Layers', 'nLayers', 'int', [1, 100], 1),
    Param('State Dimension', 'd_state', 'int', [1, 1000], 16),
    Param('Train DT', 'train_log_dt', 'checkbox', None, True),
    Param('Train D', 'train_D', 'checkbox', None, True),
    Param('Train C', 'train_C', 'checkbox', None, True),
    Param('Train A real', 'train_log_A_real', 'checkbox', None, True),
    Param('Train A imag', 'train_A_imag', 'checkbox', None, True),
    Param('Custom Init', 'custom_init', 'checkbox', None, False),
    Param('Custom Identifier', 'custom_identifier', 'text', [], ''),
]

mamba_params = [
    Param('Mamba', '', 'title'),

    Param('Resume from Checkpoint', 'resume', 'file', [], os.environ['DATA'] + '/willett2023/experiments'),
    Param('Use Preproc. Pipeline', 'pppipeline', 'checkbox', None, False),
    Param('Number of Layers', 'nLayers', 'int', [1, 100], 1),
    Param('State Dimension', 'd_state', 'int', [1, 1000], 16),
    Param('Convolution Dimension', 'd_conv', 'int', [1, 1000], 4),
    Param('Expand Factor', 'expand', 'int', [1, 10], 2),
    Param('DT Rank', 'dt_rank', 'radio', ['auto', 'manual'], 'auto'),
    Param('DT Min', 'dt_min', 'float', [0.0001, 1.0, 0.0001], 0.001),
    Param('DT Max', 'dt_max', 'float', [0.0001, 1.0, 0.0001], 0.1),
    Param('DT Init', 'dt_init', 'radio', ['random', 'fixed'], 'random'),
    Param('DT Scale', 'dt_scale', 'float', [0.1, 10.0, 0.1], 1.0),
    Param('DT Init Floor', 'dt_init_floor', 'float', [1e-6, 1e-3, 1e-6], 1e-4),
    Param('Convolution Bias', 'conv_bias', 'checkbox', None, True),
    Param('Bias', 'bias', 'checkbox', None, False),
    Param('Use Fast Path', 'use_fast_path', 'checkbox', None, True),
]

# Model-specific form sections, keyed by the value of the 'Model Type' radio.
model_params = {'s4': s4_params, 'mamba': mamba_params}

# Replicating an experiment: pre-fill the defaults of every parameter (general
# and model-specific) from the chosen YAML file. This must happen before any
# widget is created, because widgets read ``default_value`` only once.
# File choosers are skipped and keep their configured start paths.
if experiment_replicate.selected:
    with open(experiment_replicate.selected, 'r') as f:
        # safe_load returns None for an empty file; treat that as "no overrides".
        replicated_args = yaml.safe_load(f) or {}

    for param in general_params + s4_params + mamba_params:
        if param.arg in replicated_args and param.widgetType != 'file':
            param.default_value = replicated_args[param.arg]

container = widgets.VBox()

def change_model_options(*_change):
    """Show the form section that matches the selected model type.

    Used both as the observer callback of the 'Model Type' radio (which passes
    a change dict that is ignored here) and for the initial render.

    Rebinds the module-level ``params`` so that the cells which build the args
    dict iterate over the parameters that are actually displayed.
    """
    global params
    selected = model_type.widget.children[1].value
    if selected in model_params:
        params = general_params + model_params[selected]
        container.children = [param.create_widget() for param in params]

model_type.create_widget().children[1].observe(change_model_options, 'value')

# Initial render: follow the radio's current value rather than assuming S4, so
# a replicated Mamba experiment opens with the Mamba section.
change_model_options()

display(container)

In [None]:
# Snapshot the current form state into the args dict: one entry per parameter
# that has an args key (section titles use an empty ``arg`` and are skipped).
# The editable control is the second child of each parameter's label/control box.
args = {
    param.arg: param.widget.children[1].value
    for param in params
    if param.arg != ''
}

### SSM Init params

In [None]:
# Seed torch's global RNG from the form so the random recipes below are repeatable.
torch.manual_seed(args['seed'])


# Each dict below maps a recipe name to a callable that builds the initial
# tensor for one SSM parameter; the callables take the two sizes H and N.

def _log_dt_default(H, dt_min=0.001, dt_max=0.1):
    """H random values spread uniformly between log(dt_min) and log(dt_max)."""
    return torch.rand(H) * (
        math.log(dt_max) - math.log(dt_min)
    ) + math.log(dt_min)


def _log_dt_uniform(H, dt):
    """H copies of log(dt)."""
    return torch.ones(H) * math.log(dt)


def _D_default(H):
    """H values drawn uniformly from [0, 1)."""
    return torch.rand(H)


def _D_zero(H):
    """H zeros."""
    return torch.zeros(H)


def _C_default(H, N):
    """Complex normal tensor of shape (H, N // 2)."""
    return torch.randn(H, N // 2, dtype=torch.cfloat)


def _log_A_real_default(H, N):
    """log(0.5) everywhere, shape (H, N // 2)."""
    return torch.log(0.5 * torch.ones(H, N//2))


def _log_A_real_unit_circle(H, N):
    """log(0) = -inf everywhere, shape (H, N // 2).

    NOTE(review): presumably the consumer exponentiates this to obtain a zero
    real part - confirm against the model code.
    """
    return torch.log(torch.zeros(H, N//2))


def _A_imag_default(H, N):
    """pi * [0, 1, ..., N // 2 - 1], repeated along a leading axis of size H."""
    return math.pi * repeat(torch.arange(N//2), 'n -> h n', h=H)


log_dt_inits = {
    'default': _log_dt_default,
    'uniform': _log_dt_uniform,
}

D_inits = {
    'default': _D_default,
    'zero': _D_zero,
}

C_inits = {
    'default': _C_default,
}

log_A_real_inits = {
    'default': _log_A_real_default,
    'unit_circle': _log_A_real_unit_circle,
}

A_imag_inits = {
    'default': _A_imag_default,
}

In [None]:
# Sizes handed to the init recipes: H comes from the input-feature count and
# N from the state dimension chosen in the form.
H = args['nInputFeatures']
N = args['d_state']

# Evaluate one recipe per SSM parameter. The keyword order below is also the
# evaluation order, which matters for the recipes that draw from torch's RNG.
initializers = dict(
    log_dt=log_dt_inits['default'](H),
    D=D_inits['zero'](H),
    C=C_inits['default'](H, N),
    log_A_real=log_A_real_inits['default'](H, N),
    A_imag=A_imag_inits['default'](H, N),
)

In [None]:
# Timestamp used as the experiment-directory prefix, formatted mm-dd-H:M:S
formatted_now = datetime.now().strftime("%m-%d-%H:%M:%S")

# Re-read the form so edits made since the previous cells are captured
# (section titles have an empty ``arg`` and carry no value).
args = {
    param.arg: param.widget.children[1].value
    for param in params
    if param.arg != ''
}

# The file chooser's value may be None when nothing has been selected; fail
# with a clear message instead of a TypeError while building the path below.
if not args.get('baseDir'):
    raise ValueError(
        "No BaseDir selected: confirm a directory in the 'BaseDir' chooser of the form before running this cell."
    )

# 'custom_identifier' and 'custom_init' are only defined in the S4 parameter
# list, so fall back to neutral values when the form does not provide them.
custom_identifier = args.get('custom_identifier', '')
use_custom_init = args.get('custom_init', False)

args['modelName'] = f"{custom_identifier}_{args['modelType']}_{args['nLayers']}_layers_{args['d_state']}_d_state"
if args['pppipeline']:
    args['modelName'] = "p_" + args['modelName']

# os.path.join avoids a doubled separator when baseDir ends with a slash.
experimentPath = os.path.join(args['baseDir'], 'experiments', formatted_now + '_' + args['modelName'])
argsPath = os.path.join(experimentPath, 'args.yaml')

args['experimentPath'] = experimentPath
# Create the directory
Path(experimentPath).mkdir(parents=True, exist_ok=True)

# Save yaml args
with open(argsPath, 'w') as f:
    yaml.dump(args, f)

# Store the SSM initial tensors next to the args when a custom init is requested.
if use_custom_init:
    with open(os.path.join(experimentPath, 'ssm_init.pkl'), 'wb') as f:
        pickle.dump(initializers, f)

print(f"$DATA/pnpl/bin/python $HOME/trainer.py {argsPath}")

In [None]:
# Amount added to the parameter count when the preprocessing pipeline is on.
# NOTE(review): presumably the pipeline's own parameter count - confirm.
PPPIPELINE_N_PARAMS = 3168553

# The form fields for these SLURM settings are commented out in the form
# definition, so ``args`` does not contain them; fall back to the defaults
# those fields declared.
SLURM_DEFAULTS = {'mem': 16, 'time': '0-04:00:00', 'partition': 'short'}
slurm_opts = {key: args.get(key, fallback) for key, fallback in SLURM_DEFAULTS.items()}

# Get number of parameters
if args['modelType'] == 'mamba':
    # Instantiate the model once, only to count its trainable parameters.
    # NOTE(review): the form offers 'auto'/'manual' for dt_rank; verify that
    # the model accepts the string 'manual'.
    coreModel = MambaPhoneme(
        config=MambaConfig(
            d_model=args['nInputFeatures'],
            n_layer=args['nLayers'],
            vocab_size=args['nClasses'],
            ssm_cfg={
            'd_state'   : args["d_state"],
            'd_conv'    : args["d_conv"],
            'expand'    : args["expand"],
            'dt_rank'   : args["dt_rank"],
            'dt_min'    : args["dt_min"],
            'dt_max'    : args["dt_max"],
            'dt_init'   : args["dt_init"],
            'dt_scale'  : args["dt_scale"],
            'dt_init_floor' : args["dt_init_floor"],
            'conv_bias' : args["conv_bias"],
            'bias'      : args["bias"],
            'use_fast_path' : args["use_fast_path"],  # Fused kernel options
            },
            rms_norm=False,
            residual_in_fp32=False,
            fused_add_norm=False,
        ),
        device=args['device'],
        dtype=torch.float32,
    )
    args['nParams'] = sum(p.numel() for p in coreModel.parameters() if p.requires_grad)
    if args['pppipeline']:
        args['nParams'] += PPPIPELINE_N_PARAMS
else:
    # Only the Mamba model class is imported in this notebook, and the Mamba
    # settings are absent from ``args`` for other model types, so the count
    # cannot be computed here.
    args['nParams'] = None

# SLURM boilerplate substitution
slurm = f"""#!/bin/bash

#SBATCH --nodes=1
#SBATCH --gres=gpu:1
#SBATCH --constraint='gpu_sku:V100|gpu_sku:A100'
#SBATCH --ntasks-per-node=8
#SBATCH --mem-per-cpu={slurm_opts['mem']}G
#SBATCH --time={slurm_opts['time']}
#SBATCH --job-name={args['modelType']}
#SBATCH --partition={slurm_opts['partition']}
#SBATCH --output={experimentPath}/{args['modelType']}.%j.out
#SBATCH --mail-type=BEGIN,END
#SBATCH --mail-user={os.environ['EMAIL']}

echo "Training Mamba"

module load Anaconda3/2022.05
source activate $DATA/pnpl

source $HOME/modules.sh

$DATA/pnpl/bin/python $HOME/trainer.py {argsPath}
"""

# Write the SLURM script into the experiment directory; it is not submitted
# here (see the commented-out sbatch line at the end of the cell).
with open(experimentPath + '/slurm.sh', 'w') as f:
    f.write(slurm)

print(f"Experiment {args['modelName']} with # params: {args['nParams']} saved to {argsPath}")
print(f"$DATA/pnpl/bin/python $HOME/trainer.py {argsPath}")

#!sbatch {experimentPath + '/slurm.sh'}