In [None]:
import os

import torch

from daart.data import DataGenerator
from daart.models import HardSegmenter
from daart.transforms import ZScore

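### Build data generator

Load the markers and heuristic behavioral-state labels for one experiment and wrap them in a `DataGenerator` that serves batches from train/validation trial splits.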

In [None]:
# define data paths
# The data root can be overridden with the DAART_BASE_DIR environment variable so the
# notebook runs on other machines; it defaults to the original location.
base_dir = os.environ.get('DAART_BASE_DIR', '/media/mattw/fly/behavior')
expt_id = '2019_06_26_fly2'
markers_file = os.path.join(base_dir, 'labels', expt_id + '_labeled.h5')
# NOTE: the labels file is a pickle (.pkl); only point this at files from a trusted source.
labels_file = os.path.join(
    base_dir, 'segmentation', 'states-v2.1', expt_id + '_beh-states-heuristic.pkl')

# define data generator signals
# `signals`, `transforms` and `paths` are parallel lists (one entry per signal):
# markers get a ZScore transform, labels are passed through untransformed (None).
signals = ['markers', 'labels']
transforms = [ZScore(), None]
paths = [markers_file, labels_file]

# Use the GPU when one is available; otherwise fall back to CPU instead of erroring
# on CPU-only machines.
device = 'cuda' if torch.cuda.is_available() else 'cpu'

# allocation of trials across train / validation / test / gap splits
# (see the daart DataGenerator docs for the exact semantics of these counts)
trial_splits = {
    'train_tr': 9,
    'val_tr': 1,
    'test_tr': 0,
    'gap_tr': 0
}

# build data generator; each argument is a list with one entry per experiment
data_gen = DataGenerator(
    [expt_id], [signals], [transforms], [paths], device=device, batch_size=500,
    trial_splits=trial_splits)
print(data_gen)

In [None]:
# Pull a single training batch to sanity-check what the data generator serves:
# first the batch index, then the shape of each signal on its own line.
data, dataset = data_gen.next_batch('train')
print(data['batch_idx'])
print(*(data[signal_name].shape for signal_name in ('markers', 'labels')), sep='\n')

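### Build model

Configure a `HardSegmenter` with a `temporal-mlp` network that maps the marker signal to the label classes; `input_size` and `output_size` must match the data dimensions printed above.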

In [None]:
# define model params
# NOTE(review): input_size / output_size are set by hand and must agree with the data
# served above (marker dimensionality and number of label classes) — confirm against
# the batch shapes printed in the previous cell.
hparams = {
    'model_type': 'temporal-mlp',
    'input_size': 16,  # dimensionality of markers
    'output_size': 5,  # number of classes
    'n_hid_layers': 1,  # hidden layers in network
    'n_hid_units': 32,  # hidden units per hidden layer
    'n_lags': 16,  # width of temporal convolution window
    'activation': 'relu',  # layer nonlinearity
}

# Seed torch's global RNG before building the model so randomly initialized weights are
# reproducible across Restart & Run All (data ordering is seeded separately through
# `rng_seed_train` in the training params).
rng_seed_model = 0
torch.manual_seed(rng_seed_model)

# build model
model = HardSegmenter(hparams)
model.to(device)
print(model)

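### Train model

Fit the model on the training split, monitoring validation performance and saving the best (and last) model under the output directory.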

In [None]:
# define training params
train_kwargs = {
    'learning_rate': 1e-4,  # adam learning rate
    'l2_reg': 0,  # general l2 reg on parameters
    'min_epochs': 1,  # minimum number of training epochs
    'max_epochs': 2,  # maximum number of training epochs
    'val_check_interval': 1,  # frequency with which to log performance on val data
    'rng_seed_train': 0,  # control order in which data are served to model
    'enable_early_stop': False,  # True to use early stopping; False will use max_epochs
    'early_stop_history': 10,  # epochs over which to average early stopping metric
    'save_last_model': True,  # true to save out last (as well as best) model
}

# Output directory for saved models, derived from the data root so it follows any change
# to `base_dir` instead of repeating the absolute path.
save_path = os.path.join(base_dir, 'daart-demo')

# fit model!
model.fit(data_gen, save_path=save_path, **train_kwargs)