In [None]:
%load_ext autoreload
%autoreload 2
from lib.utils import *
from lib.ekyn import *

In [None]:
train_idx, test_idx = get_merged_ekyn_snezana_mice_train_test_ids()


def _make_loader(ids, shuffle):
    """Build a windowed DataLoader over the paired recordings listed in `ids`."""
    windows = Windowset(*load_paired_list(ids=ids), CONFIG['WINDOW_SIZE'])
    return DataLoader(windows, batch_size=CONFIG['BATCH_SIZE'], shuffle=shuffle)


# Training batches are shuffled; the held-out (test_idx) loader keeps file order.
trainloader = _make_loader(train_idx, shuffle=True)
devloader = _make_loader(test_idx, shuffle=False)

In [None]:
import math
from torch.nn.functional import relu
from lib.models import ResidualBlockv2

class RegNet(nn.Module):
    """1-D convolutional classifier over raw signal windows.

    Pipeline: strided stem conv (``c1``) -> LayerNorm over the time axis
    (``ln1``) -> ReLU -> max-pool (``mp1``) -> stages of ``ResidualBlockv2``
    (``blocks``) -> global average pool + linear head (``classifier``).

    Args:
        in_features: number of epochs per example; ``forward`` flattens each
            example into ``in_features * samples_per_epoch`` samples.
        depthi: number of residual blocks in each stage.
        widthi: output channel count of each stage; ``widthi[0]`` is also the
            stem conv's channel count. Must have the same length as ``depthi``.
        samples_per_epoch: samples per epoch (keyword-only; default 5000).
        num_classes: number of output logits (keyword-only; default 3).
        *args, **kwargs: forwarded to ``nn.Module.__init__``.

    Raises:
        ValueError: if ``depthi`` and ``widthi`` differ in length.
    """

    def __init__(self, in_features, depthi=(1, 1, 3, 1), widthi=(2, 4, 16, 32), *args,
                 samples_per_epoch=5000, num_classes=3, **kwargs) -> None:
        super().__init__(*args, **kwargs)
        if len(depthi) != len(widthi):
            raise ValueError(
                f"depthi and widthi must have the same length, got {len(depthi)} and {len(widthi)}"
            )
        # Total samples per example after forward() flattens the input.
        self.in_features = in_features * samples_per_epoch

        # Stem conv hyper-parameters.
        kernel_size, padding, stride = 10, 4, 2
        # Conv1d output length: floor((L + 2*padding - dilation*(kernel_size-1) - 1) / stride) + 1.
        seq_len = (self.in_features + 2 * padding - (kernel_size - 1) - 1) // stride + 1

        self.c1 = nn.Conv1d(in_channels=1, out_channels=widthi[0], kernel_size=kernel_size,
                            stride=stride, padding=padding, dilation=1)
        # LayerNorm is over the time axis only, so it is sized to the stem output length.
        self.ln1 = nn.LayerNorm(normalized_shape=seq_len)
        self.mp1 = nn.MaxPool1d(kernel_size=2, stride=2)

        blocks = []
        channels = widthi[0]
        for stage_width, depth in zip(widthi, depthi):
            for i in range(depth):
                if i == 0:
                    # The tracked length is halved once at the start of each non-empty stage.
                    # For the first stage this matches mp1's halving.
                    # NOTE(review): for later stages this assumes ResidualBlockv2 downsamples by 2
                    # (lib.models, not visible here) — TODO confirm. A depth-0 stage skips the halving.
                    seq_len = (seq_len - 1) // 2 + 1
                blocks.append(ResidualBlockv2(n_features=seq_len, in_feature_maps=channels,
                                              out_feature_maps=stage_width))
                channels = stage_width
        self.blocks = nn.Sequential(*blocks)

        self.classifier = nn.Sequential(
            # Global average over the remaining time axis.
            nn.AvgPool1d(kernel_size=seq_len),
            nn.Flatten(start_dim=1),
            nn.Linear(channels, num_classes),
        )

    def forward(self, x):
        """Flatten `x` to (N, 1, in_features) single-channel signals and return (N, num_classes) logits."""
        x = x.reshape(-1, 1, self.in_features)
        x = self.mp1(relu(self.ln1(self.c1(x))))
        x = self.blocks(x)
        return self.classifier(x)

In [None]:
# Per-class weights for the loss, indexed by class id 0, 1, 2.
# NOTE(review): presumably inverse class frequencies of the training set — TODO confirm how they were derived.
CLASS_WEIGHTS = torch.tensor([18.3846, 2.2810, 1.9716])

model = RegNet(in_features=1, depthi=[1, 1, 3, 1], widthi=[2, 4, 16, 32])
criterion = nn.CrossEntropyLoss(weight=CLASS_WEIGHTS).to(DEVICE)
optimizer = optim.Adam(model.parameters(), lr=CONFIG['LEARNING_RATE'])

params = sum(p.numel() for p in model.parameters())
print("Params: ", params)
print(model)

# Loss histories: lossi collects one value per optimizer step in the training cell.
lossi = []
trainlossi = []
devlossi = []

In [None]:
# One (shuffled) training batch, used for the forward-pass shape check.
first_batch = next(iter(trainloader))
Xi, yi = first_batch

In [None]:
# Smoke-test the forward pass. Send the batch to whatever device the model currently
# lives on, so the cell works both before and after training moves the model.
with torch.no_grad():
    smoke_logits = model(Xi.to(next(model.parameters()).device))
assert smoke_logits.shape[-1] == 3, "classifier head should emit 3 logits"
smoke_logits.shape

In [None]:
NUM_EPOCHS = 10

# Batches are moved to DEVICE below, so the parameters must be there too. The model is
# built on the CPU and a later cell calls model.to('cpu'), so without this line the loop
# fails whenever DEVICE is a GPU.
model.to(DEVICE)
for epoch in tqdm(range(NUM_EPOCHS)):
    # Re-enter training mode each epoch in case another cell switched the model to eval.
    model.train()
    for Xi, yi in trainloader:
        Xi, yi = Xi.to(DEVICE), yi.to(DEVICE)
        logits = model(Xi)
        loss = criterion(logits, yi)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        lossi.append(loss.item())  # one entry per optimizer step

In [None]:
fig, ax = plt.subplots()
ax.plot(lossi)
ax.set(title='Training loss per optimizer step', xlabel='step', ylabel='weighted cross-entropy loss');

In [None]:
# Mean of the last complete window of LOSS_SMOOTHING_WINDOW step losses (incomplete tail is ignored).
LOSS_SMOOTHING_WINDOW = 10
n_full = len(lossi) - len(lossi) % LOSS_SMOOTHING_WINDOW
if n_full == 0:
    print(f"Need at least {LOSS_SMOOTHING_WINDOW} recorded losses; have {len(lossi)}")
else:
    print(torch.tensor(lossi[n_full - LOSS_SMOOTHING_WINDOW:n_full]).mean())

In [None]:
model.to('cpu')
Xi, yi = next(iter(trainloader))

# c1 has in_channels=1, so each kernel is weight[k, 0, :]. Indexing instead of squeeze()
# keeps the 2-D shape even when there is a single output channel.
stem_weights = model.c1.weight.detach()[:, 0, :]
n_kernels = stem_weights.shape[0]
with torch.no_grad():
    # Same flattening as RegNet.forward; only the first example is needed for the plot.
    stem_out = model.c1(Xi.reshape(-1, 1, model.in_features)[:1])[0]

# squeeze=False keeps `ax` 2-D even when n_kernels == 1.
fig, ax = plt.subplots(nrows=n_kernels, ncols=2, figsize=(8, 10), squeeze=False)
for i in range(n_kernels):
    ax[i, 0].plot(stem_weights[i])
    ax[i, 1].plot(stem_out[i])
ax[0, 0].set_title('c1 kernels')
ax[0, 1].set_title('c1 output (first example)');

In [None]:
loss, report, y_true, y_pred, y_logits = evaluate(dataloader=trainloader, model=model, criterion=criterion, DEVICE=DEVICE)
disp = ConfusionMatrixDisplay.from_predictions(y_true, y_pred, normalize='true')
# Title the figure so it cannot be confused with the dev-split matrix.
disp.ax_.set_title('Train split (normalized over true labels)')
print(classification_report(y_true, y_pred))
print(f'train loss: {loss}')

In [None]:
# devloader is built from test_idx. y_logits from this call is what the hypnogram plot below uses.
loss, report, y_true, y_pred, y_logits = evaluate(dataloader=devloader, model=model, criterion=criterion, DEVICE=DEVICE)
disp = ConfusionMatrixDisplay.from_predictions(y_true, y_pred, normalize='true')
# Title the figure so it cannot be confused with the train-split matrix.
disp.ax_.set_title('Dev split (normalized over true labels)')
print(classification_report(y_true, y_pred))
print(f'dev loss: {loss}')

In [None]:
import matplotlib.patches as patches
# NOTE(review): `X` and `y` are not defined by any cell visible in this notebook. They presumably
# hold one epoch of raw signal per row and one-hot labels, aligned with `y_logits` from the dev
# evaluation — TODO confirm and define them explicitly so Restart & Run All works.
SAMPLES_PER_EPOCH = 5000  # same per-epoch sample count RegNet assumes
TRACE_YLIM = .0003        # half-height of the trace panel, in volts
start = 190
duration = 200
end = start + duration

fig, ax = plt.subplots(nrows=2, ncols=1, figsize=(16, 5), dpi=200)

# Bottom panel: raw trace with each epoch shaded by its labelled stage.
ax[1].plot(X[start:end].flatten(), 'black', linewidth=.2)
colors = ['red', 'green', 'blue']
stages = y.argmax(axis=1)[start:end]  # computed once instead of on every loop iteration
epochs = []  # x position (in samples) of each epoch's centre
for i in range(duration):
    stage = int(stages[i])
    ax[1].fill_between([i * SAMPLES_PER_EPOCH, (i + 1) * SAMPLES_PER_EPOCH],
                       y1=-TRACE_YLIM, y2=TRACE_YLIM, color=colors[stage], alpha=0.3)
    epochs.append(i * SAMPLES_PER_EPOCH + SAMPLES_PER_EPOCH // 2)

red_patch = patches.Patch(color='red', alpha=0.5, label='Paradoxical')
green_patch = patches.Patch(color='green', alpha=0.5, label='Slow-wave')
blue_patch = patches.Patch(color='blue', alpha=0.5, label='Wakefulness')
ax[1].set_ylim([-TRACE_YLIM, TRACE_YLIM])
ax[1].margins(0, 0)
ax[1].legend(handles=[red_patch, green_patch, blue_patch], loc='upper left', bbox_to_anchor=(1.04, 1),
             fancybox=True, shadow=True, ncol=1)
ax[1].set_xlabel('epoch (index)')
ax[1].set_ylabel('potential (V)')
# Label roughly 20 epochs. max() keeps the step >= 1 when duration < 20 (a step of 0 raises).
tick_step = max(1, duration // 20)
ax[1].set_xticks(epochs[::tick_step], range(duration)[::tick_step])

# Top panel: stacked per-class model scores, placed at the same epoch centres as the bottom panel.
# NOTE(review): stackplot assumes non-negative scores (e.g. softmax probabilities). If evaluate()
# returns raw logits, apply a softmax first — confirm.
scores = y_logits[start:end]
ax[0].stackplot(epochs, scores[:, 0], scores[:, 1], scores[:, 2],
                colors=['#FF000080', '#00FF0080', '#0000FF80'])
ax[0].margins(0, 0)
# Share the bottom panel's sample-based x range so epochs line up vertically.
ax[0].set_xlim(0, duration * SAMPLES_PER_EPOCH)
ax[1].set_xlim(0, duration * SAMPLES_PER_EPOCH)
ax[0].set_xticks([]);