In [None]:
import os
import sys
from pathlib import Path

# Make the repository root importable so the project-local `utils` and `src`
# packages below resolve (path is relative to the notebook's working directory).
path_root = '../'
sys.path.append(str(path_root))

import numpy as np
import pandas as pd
from matplotlib import pyplot as plt

from sklearn.linear_model import Ridge
from sklearn.metrics import r2_score

import torch
# Trainer must come from the same package as EarlyStopping (`lightning.pytorch`).
from lightning.pytorch import Trainer
from lightning.pytorch.callbacks.early_stopping import EarlyStopping
# NOTE(review): this is the legacy `pytorch_lightning` package; its classes are
# distinct from `lightning.pytorch` ones and should not be mixed with the
# `lightning.pytorch.Trainer` imported above -- import from
# `lightning.pytorch.callbacks` when pairing a checkpoint callback with it.
from pytorch_lightning.callbacks import ModelCheckpoint

from utils.ibl_data_loaders import SingleSessionDataset
from src.models import ReducedRankModel

In [None]:
# Session ids (eids) are the file names (up to the first '.') of the `.npz` files
# in ./data/. The list is sorted because `os.listdir` returns entries in arbitrary,
# filesystem-dependent order. Sorting makes positional selection such as `eids[1]`
# pick the same session on every machine and every run.
eids = sorted(
    fname.split('.')[0] for fname in os.listdir('./data/') if fname.endswith('.npz')
)
eids

In [1]:
# Run on GPU when available. The same device is stored in `config` and is also
# passed to the dataset.
DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Global seed so torch randomness (e.g. weight initialisation) is reproducible
# under Restart & Run All.
SEED = 42
torch.manual_seed(SEED)

config = {
    'target': 'wheel_speed',
    'temporal_rank': 15,
    'batch_size': 4,
    'learning_rate': 5e-3,
    'weight_decay': 1e-1,
    'lr_factor': 0.1,
    'lr_patience': 5,
    'device': DEVICE,
    # os.cpu_count() returns None when the count cannot be determined. Fall back
    # to 0 (load in the main process) so this is always a valid worker count.
    # NOTE(review): the dataset is built with DEVICE; if it holds CUDA tensors,
    # multi-process loading may misbehave -- confirm how ReducedRankModel uses this.
    'n_workers': os.cpu_count() or 0,
}

In [None]:
# Load the session at index 1 of `eids`, using the decoding target from `config`
# and the configured device.
eid = eids[1]
session_dataset = SingleSessionDataset(eid, config['target'], DEVICE)

# Record the data-dependent dimensions in `config`.
# Presumably these are read by ReducedRankModel to size itself -- TODO confirm.
config['n_units'] = session_dataset.n_units
config['n_t_steps'] = session_dataset.n_t_steps

In [None]:
# 80/10/10 train/validation/test split. The test split absorbs the rounding
# remainder, so the three lengths always sum to len(session_dataset).
SPLIT_SEED = 42

data_len = len(session_dataset)
train_len, val_len = int(0.8 * data_len), int(0.1 * data_len)
test_len = data_len - train_len - val_len
# Early stopping monitors val_loss, so an empty validation (or test) split is fatal.
assert min(train_len, val_len, test_len) > 0, (
    f"session too small to split 80/10/10: {data_len} samples"
)

# A dedicated, seeded generator makes the partition reproducible across kernel
# restarts, independent of how much global RNG state earlier cells consumed.
train, val, test = torch.utils.data.random_split(
    session_dataset,
    [train_len, val_len, test_len],
    generator=torch.Generator().manual_seed(SPLIT_SEED),
)

In [None]:
# Train with early stopping on validation loss. Keep the checkpoint with the lowest
# val_loss so that test metrics come from the best-validation weights rather than
# the final epoch (which can be up to EARLY_STOP_PATIENCE epochs past the optimum).
# ModelCheckpoint must come from `lightning.pytorch`, the same package as
# EarlyStopping and the Trainer; the `pytorch_lightning` class is a different type.
from lightning.pytorch.callbacks import ModelCheckpoint

MAX_EPOCHS = 500
EARLY_STOP_PATIENCE = 50
MODEL_DIR = './models/'

checkpoint_callback = ModelCheckpoint(
    monitor='val_loss',
    mode='min',
    dirpath=MODEL_DIR,
    # `{{epoch}}` / `{{val_loss}}` are escaped so Lightning, not the f-string, fills them.
    filename=f"{eid}-{config['target']}-{{epoch:02d}}-{{val_loss:.2f}}",
)

trainer = Trainer(
    max_epochs=MAX_EPOCHS,
    callbacks=[
        EarlyStopping(monitor='val_loss', mode='min', patience=EARLY_STOP_PATIENCE),
        checkpoint_callback,
    ],
)

# NOTE(review): the splits are passed as (train, test, val); confirm this matches
# ReducedRankModel's constructor order -- (train, val, test) would be more usual.
model = ReducedRankModel(train, test, val, config)
trainer.fit(model)

# 'best' reloads the lowest-val_loss checkpoint tracked by checkpoint_callback.
test_metrics = trainer.test(ckpt_path='best')
checkpoint_callback.best_model_path