In [None]:
import ipywidgets as widgets
import os
import pickle
from datetime import datetime
from pathlib import Path
from ipyfilechooser import FileChooser

# Timestamp taken when this cell runs, rendered as month-day-hour:minute:second
# (e.g. "03-14-09:26:53"). format(dt, spec) delegates to dt.strftime(spec).
formatted_now = format(datetime.now(), "%m-%d-%H:%M:%S")

In [None]:
class Param:
    """A single configurable hyperparameter plus the ipywidgets control that edits it.

    Args:
        name: Label text shown to the left of the control, or the heading text
            when ``widgetType == 'title'``.
        arg: Key under which this parameter's value is collected into the args
            dict; ``''`` marks decorative rows (section titles) with no value.
        widgetType: One of ``'title'``, ``'text'``, ``'dropdown'``, ``'checkbox'``,
            ``'radio'``, ``'int'``, ``'float'`` or ``'file'``.
        values: Type-specific extra data. ``'dropdown'``/``'radio'``: the option
            list. ``'int'``/``'float'``: optional ``[min, max, step]``, where any
            trailing entries may be omitted. Unused by the other types. ``None``
            (the default) is treated as an empty list.
        default_value: Initial value of the control. For ``'file'`` it is a path:
            an existing directory is opened with an empty filename, anything else
            is split into its directory and filename.

    Attributes:
        widget: The built widget (see ``_create_widget`` for its structure).

    Raises:
        ValueError: If ``widgetType`` is not one of the supported types.
    """

    def __init__(self, name, arg, widgetType, values=None, default_value=None):
        self.name = name
        self.arg = arg
        self.widgetType = widgetType
        # Copy into a fresh list: the previous ``values=[]`` default was one list
        # object shared by every instance that omitted the argument.
        self.values = list(values) if values is not None else []
        self.default_value = default_value

        self.label_width = '200px'

        self.widget = self._create_widget()

    def _bound(self, index, fallback=None):
        """Return ``self.values[index]`` if that entry exists, else ``fallback``."""
        return self.values[index] if len(self.values) > index else fallback

    def _file_chooser_kwargs(self):
        """Translate ``default_value`` into ``FileChooser`` constructor arguments.

        ``FileChooser`` has no ``default`` argument; an unknown keyword is
        ignored, so the old ``default=...`` call never pre-selected anything.
        It takes ``path`` + ``filename`` instead, and ``select_default=True``
        makes that default the selected value without the user clicking "Select".
        """
        if not self.default_value:
            return {}
        default = os.fspath(self.default_value)
        if os.path.isdir(default):
            directory, filename = default, ''
        else:
            directory, filename = os.path.split(default)
        return {
            'path': directory or os.getcwd(),
            'filename': filename,
            'select_default': True,
        }

    def _create_widget(self):
        """Build the control for this parameter.

        Returns:
            For ``'title'``: a bold ``widgets.HTML`` heading. Otherwise a
            ``widgets.HBox`` whose ``children[0]`` is the label and whose
            ``children[1]`` is the input control that holds the value.

        Raises:
            ValueError: If ``widgetType`` is not recognised.
        """
        if self.widgetType == 'title':
            return widgets.HTML(value=f"<b style='font-size:24px;'>{self.name}</b>")

        label = widgets.HTML(value=self.name, layout=widgets.Layout(width=self.label_width))

        if self.widgetType == 'text':
            widget = widgets.Text(value=self.default_value)
        elif self.widgetType == 'dropdown':
            widget = widgets.Dropdown(options=self.values, value=self.default_value)
        elif self.widgetType == 'checkbox':
            widget = widgets.Checkbox(value=self.default_value)
        elif self.widgetType == 'radio':
            widget = widgets.RadioButtons(options=self.values, value=self.default_value)
        elif self.widgetType == 'int':
            # A missing bound is passed as None. NOTE(review): ipywidgets appears to
            # treat None as "keep the widget's own default bound" (max=100); confirm
            # for the installed version.
            widget = widgets.BoundedIntText(
                value=self.default_value,
                min=self._bound(0),
                max=self._bound(1),
                step=self._bound(2, 1))
        elif self.widgetType == 'float':
            widget = widgets.BoundedFloatText(
                value=self.default_value,
                min=self._bound(0),
                max=self._bound(1),
                step=self._bound(2, 0.1))
        elif self.widgetType == 'file':
            widget = FileChooser(**self._file_chooser_kwargs())
        else:
            raise ValueError('Unknown widget type: {}'.format(self.widgetType))

        return widgets.HBox([label, widget])


In [None]:
# Stand-alone file picker for browsing the data directory; no later cell in this
# notebook reads `fc`. NOTE(review): it points at .../willett2023/pytorchTFRecords.pkl,
# whereas the Dataset parameter defaults to
# .../willett2023/competitionData/pytorchTFRecords.pkl — confirm which is intended.
fc = FileChooser()
fc.reset(
    filename='pytorchTFRecords.pkl',
    path=f"{os.environ['DATA']}/willett2023",
)
fc

In [None]:
# Form definition: each Param becomes one row of the form. 'title' rows are
# section headers with an empty arg, so they contribute nothing to the args dict.
params = [
    # Slurm hyperparams
    Param('Slurm', '', 'title'),

    Param('Partition', 'partition', 'radio', ['interactive', 'short', 'medium'], 'short'),
    Param('Job Name', 'jobName', 'text', [], 'mamba'),
    Param('Time (D-h:m:s)', 'time', 'text', [], '0-04:00:00'),
    Param('Memory (GB)', 'mem', 'int', [1, 128, 4], 16),
    Param('BaseDir', 'baseDir', 'file', [], os.environ['DATA'] + '/willett2023'),
    Param('Dataset', 'dataset', 'file', [], os.environ['DATA'] + '/willett2023/competitionData/pytorchTFRecords.pkl'),

    # Training loop hyperparams
    Param('Training Loop', '', 'title'),

    Param('Epochs', 'nEpochs', 'int', [0, 10000, 100], 500),
    Param('Batch', 'batchSize', 'int', [1, 256], 64),
    # Explicit upper bound: with only a lower bound, BoundedIntText falls back to
    # its own default max (100 in ipywidgets) and silently clamps larger seeds.
    Param('Random Seed', 'seed', 'int', [0, 2**31 - 1], 0),
    Param('Learning Rate Start', 'lrStart', 'float', [0.0, 1.0, 0.01], 0.2),
    Param('Learning Rate End', 'lrEnd', 'float', [0.0, 1.0, 0.01], 0.02),
    Param('L2 Decay', 'l2_decay', 'float', [0.0, 1e-3, 1e-6], 1e-5),
    Param('Device', 'device', 'radio', ['cuda', 'cpu'], 'cuda'),

    # Mamba Hyperparams
    Param('Mamba', '', 'title'),

    Param('Use Preproc. Pipeline', 'pppipeline', 'checkbox', None, True),
    Param('Number of Layers', 'nLayers', 'int', [1, 100], 1),
    Param('State Dimension', 'd_state', 'int', [1, 1000], 16),
    Param('Convolution Dimension', 'd_conv', 'int', [1, 1000], 4),
    Param('Expand Factor', 'expand', 'int', [1, 10], 2),
    # NOTE(review): choosing 'manual' stores the literal string 'manual' in
    # args['dt_rank']; there is no control for entering a numeric rank here —
    # confirm the trainer handles that value.
    Param('DT Rank', 'dt_rank', 'radio', ['auto', 'manual'], 'auto'),
    Param('DT Min', 'dt_min', 'float', [0.0001, 1.0, 0.0001], 0.001),
    Param('DT Max', 'dt_max', 'float', [0.0001, 1.0, 0.0001], 0.1),
    Param('DT Init', 'dt_init', 'radio', ['random', 'fixed'], 'random'),
    Param('DT Scale', 'dt_scale', 'float', [0.1, 10.0, 0.1], 1.0),
    Param('DT Init Floor', 'dt_init_floor', 'float', [1e-6, 1e-3, 1e-6], 1e-4),
    Param('Convolution Bias', 'conv_bias', 'checkbox', None, True),
    Param('Bias', 'bias', 'checkbox', None, False),
    Param('Use Fast Path', 'use_fast_path', 'checkbox', None, True),

    # I/O model dimensions
    Param('Model Dimensions', '', 'title'),

    Param('Number of Classes', 'nClasses', 'int', [1, 1000], 40),
    Param('Number of Input Features', 'nInputFeatures', 'int', [1, 10000], 256),
    Param('Number of Hidden Features', 'nHiddenFeatures', 'int', [1, 10000], 256),

    # Data augmentation params
    Param('Data Augmentation', '', 'title'),

    Param('White Noise SD', 'whiteNoiseSD', 'float', [0.0, 5.0, 0.1], 0.8),
    Param('Constant Offset SD', 'constantOffsetSD', 'float', [0.0, 5.0, 0.1], 0.2),
    Param('Gaussian Smooth Width', 'gaussianSmoothWidth', 'float', [0.0, 5.0, 0.1], 2.0),
]

widgets.VBox([param.widget for param in params], layout=widgets.Layout(align_items='flex-start'))

In [None]:
# Build the args dict from the form, save it, and write the SLURM job script.

# Re-stamp on every run of this cell. Reusing the timestamp taken in the import
# cell meant that re-running this cell after editing the form wrote into the SAME
# experiment directory and overwrote conf.pkl/slurm.sh of a job that may already
# be queued.
formatted_now = datetime.now().strftime("%m-%d-%H:%M:%S")

# Collect every value-bearing row (title rows have arg == ''); children[1] is the
# row's input control.
args = {}
for param in params:
    if param.arg != '':
        args[param.arg] = param.widget.children[1].value

args['modelName'] = f"{args['jobName']}_{args['nLayers']}_layers_{args['d_state']}_d_state_{args['d_conv']}_d_conv_{args['expand']}_expand"
if args['pppipeline']:
    args['modelName'] = "p_" + args['modelName']

# A FileChooser with nothing selected yields None (or an empty value); fail with a
# clear message instead of a TypeError from the path construction below.
if not args['baseDir']:
    raise ValueError("BaseDir is empty: select a base directory in the form first.")

experimentPath = os.path.join(args['baseDir'], 'experiments', formatted_now)
# Create the directory
Path(experimentPath).mkdir(parents=True, exist_ok=True)

# Save pickle args; this file's path is passed to trainer.py in the script below.
with open(os.path.join(experimentPath, 'conf.pkl'), "wb") as f:
    pickle.dump(args, f)

# SLURM boilerplate substitution. The string must begin directly with the shebang:
# sbatch rejects scripts whose first line does not start with "#!".
# NOTE(review): --mem-per-cpu is charged per allocated CPU (8 tasks requested here)
# although the form labels this value "Memory (GB)"; confirm whether --mem was intended.
slurm = f"""#!/bin/bash

#SBATCH --nodes=1
#SBATCH --gres=gpu:v100:1
#SBATCH --ntasks-per-node=8
#SBATCH --mem-per-cpu={args['mem']}G
#SBATCH --time={args['time']}
#SBATCH --job-name={args['jobName']}
#SBATCH --partition={args['partition']}
#SBATCH --output={experimentPath}/{args['jobName']}.%j.out
#SBATCH --mail-type=BEGIN,END
#SBATCH --mail-user={os.environ['EMAIL']}

echo "Training Mamba"

module load Anaconda3/2022.05
source activate $DATA/pnpl

source $HOME/modules.sh

$DATA/pnpl/bin/python $HOME/trainer.py {experimentPath}/conf.pkl
"""

# Write the SLURM script; submission is left to the (commented-out) sbatch line.
with open(os.path.join(experimentPath, 'slurm.sh'), 'w') as f:
    f.write(slurm)

#!sbatch {experimentPath + '/slurm.sh'}

In [None]:
# Display the collected configuration dict as this cell's output.
args