In [None]:
from torch.utils.data import DataLoader,ConcatDataset
from lib.datasets import EpochedDataset
from sklearn.model_selection import train_test_split
from lib.ekyn import *
import numpy as np
import scipy.signal as signal
import matplotlib.pyplot as plt
import torch
from lib.ekyn import *
from torch import nn
from torch.nn.functional import relu
import torch
from lib.env import *
import math
import json

class EpochedDataset(torch.utils.data.Dataset):
    """
    Dataset for training w1 resnets with ekyn data.

    Wraps the data of one subject (and, for 'ekyn', one condition) as an
    indexable dataset of (sample, label) pairs.

    NOTE: this notebook-level class shadows the EpochedDataset imported from
    lib.datasets at the top of the notebook.

    Args:
        dataset: which corpus to load, 'ekyn' or 'snezana_mice'.
        id: subject identifier forwarded to the loader.
        condition: recording condition; only used when dataset == 'ekyn'.
        robust: if True use the *_robust loaders, otherwise the plain loaders.
        downsampled: forwarded to the *_robust loaders only; the plain loaders
            are called without it.

    Raises:
        ValueError: if `dataset` is not a supported name.
    """
    def __init__(self,dataset='ekyn',id='A1-1',condition='Vehicle',robust=True,downsampled=True):
        if dataset == 'ekyn':
            if robust:
                X,y = load_ekyn_pt_robust(id=id,condition=condition,downsampled=downsampled)
            else:
                X,y = load_ekyn_pt(id=id,condition=condition)
        elif dataset == 'snezana_mice':
            if robust:
                X,y = load_snezana_mice_pt_robust(id=id,downsampled=downsampled)
            else:
                X,y = load_snezana_mice_pt(id=id)
        else:
            # Without this branch X/y would be unbound and the failure would
            # surface as a confusing UnboundLocalError on the next line.
            raise ValueError(f"unknown dataset {dataset!r}; expected 'ekyn' or 'snezana_mice'")

        self.X = X
        self.y = y
        self.id = id

    def __len__(self):
        """Number of samples, taken from the label container."""
        return len(self.y)

    def __getitem__(self, idx):
        """Return (X[idx:idx+1], y[idx])."""
        # Slice instead of index so the sample keeps a leading length-1 axis.
        return (self.X[idx:idx+1],self.y[idx])

In [None]:

# Recording conditions loaded for every subject.
CONDITIONS = ['Vehicle','PF']
# NOTE(review): only the first 4 training subjects are used; this looks like a
# debugging cap — confirm before reporting results.
N_TRAIN_SUBJECTS = 4
BATCH_SIZE = 512

def make_epoched_loader(subject_ids, shuffle, conditions=CONDITIONS, batch_size=BATCH_SIZE):
    """Build a DataLoader over one EpochedDataset per (subject, condition) pair.

    Datasets are concatenated in subject-major order (all conditions of the
    first subject, then the next subject, ...).
    """
    datasets = [
        EpochedDataset(id=subject_id,condition=condition,downsampled=True)
        for subject_id in subject_ids
        for condition in conditions
    ]
    return DataLoader(ConcatDataset(datasets),batch_size=batch_size,shuffle=shuffle)

# Subject-level split, so no subject appears in both train and dev.
train_idx,test_idx = train_test_split(get_ekyn_ids(),test_size=.25,random_state=0)
trainloader = make_epoched_loader(train_idx[:N_TRAIN_SUBJECTS],shuffle=True)
# Evaluation does not need shuffling; a fixed order keeps dev metrics deterministic.
devloader = make_epoched_loader(test_idx,shuffle=False)

In [None]:
from lib.models import ResidualBlock

class Frodo(nn.Module):
    """
    the little wanderer

    A 1-D CNN: a stack of five ResidualBlocks, an average pool with window
    n_features, and a linear head with 3 outputs.

    Args:
        n_features: length of one input sample; inputs are reshaped to
            (-1, 1, n_features) in forward().
        device: unused; kept so existing callers that pass it still work.
    """
    def __init__(self,n_features,device='cuda') -> None:
        super().__init__()
        self.n_features = n_features
        # ResidualBlock(a, b, n_features) — presumably (in_channels,
        # out_channels, length); fc1 below expects 16 channels at the end.
        self.block1 = ResidualBlock(1,8,n_features)
        self.block2 = ResidualBlock(8,16,n_features)
        self.block3 = ResidualBlock(16,16,n_features)
        self.block4 = ResidualBlock(16,16,n_features)
        self.block5 = ResidualBlock(16,16,n_features)

        # Pools over a window of n_features — assumes the residual blocks keep
        # the sequence length at n_features (TODO confirm), so the output length is 1.
        self.gap = nn.AvgPool1d(kernel_size=n_features)
        self.fc1 = nn.Linear(in_features=16,out_features=3)

    def forward(self,x,classification=True):
        """Run the network.

        Args:
            x: tensor with batch*n_features elements; reshaped to (-1, 1, n_features).
            classification: if True return fc1 logits, otherwise the pooled features.

        Returns:
            (batch, 3) logits, or (batch, channels) pooled features — the batch
            axis is kept even when batch size is 1.
        """
        x = x.view(-1,1,self.n_features)
        for block in (self.block1,self.block2,self.block3,self.block4,self.block5):
            x = block(x)
        # Drop only the pooled length axis. A bare squeeze() also removed the
        # batch axis for a batch of size 1, producing (3,) / (16,) outputs that
        # break CrossEntropyLoss.
        x = self.gap(x).squeeze(-1)
        if classification:
            return self.fc1(x)
        return x

In [None]:
# Loss, model (moved to the GPU) and AdamW optimizer for training.
criterion = torch.nn.CrossEntropyLoss()
model = Frodo(n_features=1000).cuda()
optimizer = torch.optim.AdamW(
    model.parameters(),
    lr=3e-4,
    weight_decay=1e-2,
)
# Total number of scalar parameters in the model.
print(sum(param.numel() for param in model.parameters()))

In [None]:
from tqdm import tqdm
from lib.utils import training_loop,development_loop

N_EPOCHS = 1000

lossi = []  # NOTE(review): not used in this cell; kept in case other cells read it
trainlossi = []
trainf1 = []
devlossi = []
devf1 = []
for i in tqdm(range(N_EPOCHS)):
    # Re-enter training mode every epoch. development_loop may switch the model
    # to eval mode (TODO confirm in lib.utils); calling train() only once before
    # the loop would then leave every later epoch training in eval mode.
    model.train()
    loss,f1 = training_loop(model=model,trainloader=trainloader,criterion=criterion,optimizer=optimizer,device='cuda')
    trainlossi.append(loss)
    trainf1.append(f1)

    loss,f1 = development_loop(model=model,devloader=devloader,criterion=criterion,device='cuda')
    devlossi.append(loss)
    devf1.append(f1)

    # Redraw the learning curves to disk each epoch so progress can be monitored.
    fig,ax = plt.subplots(nrows=1,ncols=2,figsize=(20,4))
    ax[0].plot(trainlossi,label='train')
    ax[0].plot(devlossi,label='dev')
    ax[0].set(title='loss',xlabel='epoch')
    ax[0].legend()
    ax[1].plot(trainf1,label='train')
    ax[1].plot(devf1,label='dev')
    ax[1].set(title='F1',xlabel='epoch')
    ax[1].legend()
    fig.savefig('loss.jpg')
    plt.close(fig)

In [None]:
from lib.utils import evaluate
from sklearn.metrics import ConfusionMatrixDisplay,classification_report

# Final evaluation on the dev subjects: row-normalized confusion matrix,
# per-class report, and the loss returned by evaluate().
loss, report, y_true, y_pred, y_logits = evaluate(
    dataloader=devloader,
    model=model,
    criterion=criterion,
    DEVICE=DEVICE,
)
ConfusionMatrixDisplay.from_predictions(y_true, y_pred, normalize='true')
print(classification_report(y_true, y_pred))
print(loss)