In [None]:
from lib.nursing import *
from lib.utils import *

WINDOW_SIZE = 101  # passed to window_epoched_signal as `windowsize`
BATCH_SIZE = 32


def make_loader(index: int, shuffle: bool) -> DataLoader:
    """Build a DataLoader for one feature/label pair.

    Loads pair `index` via load_feature_label_pair, windows the features with
    window_epoched_signal(windowsize=WINDOW_SIZE), and wraps (X, y) in a
    TensorDataset batched by BATCH_SIZE.

    Args:
        index: which pair to load (0 is used for training, 1 for dev below).
        shuffle: reshuffle the samples every epoch (True for training only).
    """
    X, y = load_feature_label_pair(index=index)
    X = window_epoched_signal(X, windowsize=WINDOW_SIZE)
    return DataLoader(TensorDataset(X, y), batch_size=BATCH_SIZE, shuffle=shuffle)


trainloader = make_loader(index=0, shuffle=True)
devloader = make_loader(index=1, shuffle=False)

In [None]:
from torch import nn

class MODEL(nn.Module):
    """Conv1d front-end followed by a bidirectional LSTM and an MLP head.

    Input: (batch, in_channels, seq_len). Conv1d takes channels on dim 1.
    Output: (batch, 1) raw scores with no final activation, meant to be
    paired with nn.BCEWithLogitsLoss.
    """

    def __init__(self, *args, in_channels: int = 3, hidden_size: int = 64, **kwargs) -> None:
        super().__init__(*args, **kwargs)
        # Mix the input channels into one feature per time step;
        # padding='same' keeps seq_len unchanged.
        self.c1 = nn.Conv1d(in_channels=in_channels, out_channels=1, kernel_size=5,
                            stride=1, padding='same', bias=False)
        self.lstm = nn.LSTM(input_size=1, hidden_size=hidden_size,
                            batch_first=True, bidirectional=True)
        self.classifier = nn.Sequential(
            nn.Linear(2 * hidden_size, 32),  # forward + backward final states
            nn.ReLU(),
            nn.Linear(32, 1),
        )

    def forward(self, x):
        x = self.c1(x)              # (batch, 1, seq_len)
        x = x.transpose(1, 2)       # (batch, seq_len, 1) for the batch_first LSTM
        _, (h_n, _) = self.lstm(x)  # h_n: (2 * num_layers, batch, hidden_size); unaffected by batch_first
        # Use each direction's FINAL hidden state.
        # - h_n[-2] is the forward state after the last time step.
        # - h_n[-1] is the backward state after the first time step.
        # Slicing output[:, -1, :] would give a backward state that has seen only the last step.
        pooled = h_n[-2:].transpose(0, 1).flatten(start_dim=1)  # (batch, 2 * hidden_size), [fwd, bwd]
        return self.classifier(pooled)
SEED = 0
# Seed the global torch RNG so weight init (and the trainloader's later shuffling,
# which draws from the same generator) is reproducible on Restart & Run All.
torch.manual_seed(SEED)

model = MODEL()
criterion = nn.BCEWithLogitsLoss()  # MODEL outputs raw scores, so the sigmoid lives in the loss
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)  # 1e-3 is Adam's default, made explicit

In [None]:
N_EPOCHS = 5

lossi = []  # training loss of every batch, in the order the batches were seen
f1i = []    # dev-set F1: entry k is measured after k epochs (entry 0 = before training)

for epoch in range(N_EPOCHS):
    loss, metric, y_true, y_pred, y_logits = evaluate(devloader, model, criterion)
    f1i.append(metric['f1'])
    # NOTE(review): evaluate() presumably switches the model to eval mode — confirm in lib.
    # Switch back explicitly so dropout/batch-norm layers, if added later, train correctly.
    model.train()
    for Xi, yi in trainloader:
        logits = model(Xi)
        loss = criterion(logits, yi)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        lossi.append(loss.item())

# Final dev evaluation after the last epoch.
loss, metric, y_true, y_pred, y_logits = evaluate(devloader, model, criterion)
f1i.append(metric['f1'])

In [None]:
LOSS_WINDOW = 10  # batches averaged into each plotted point

# Drop the trailing partial window so the list reshapes cleanly into (-1, LOSS_WINDOW).
n_full = len(lossi) - len(lossi) % LOSS_WINDOW
smoothed_loss = torch.tensor(lossi[:n_full]).view(-1, LOSS_WINDOW).mean(dim=1)

fig, ax = plt.subplots()
ax.plot(smoothed_loss)
ax.set(title=f'Training loss (mean of every {LOSS_WINDOW} batches)',
       xlabel=f'batch / {LOSS_WINDOW}', ylabel='loss');

In [None]:
# f1i[k] is the dev F1 after k training epochs; f1i[0] is the untrained model.
fig, ax = plt.subplots()
ax.plot(f1i, marker='o')
ax.set(title='Dev-set F1', xlabel='epochs completed', ylabel='F1');

In [None]:
# Compare one raw training window with the output of the learned conv filter.
# Xi is the last batch left over from the training loop; presumably (batch, 3, seq_len) — TODO confirm.
with torch.no_grad():  # inspection only, so skip building an autograd graph
    conv_out = model.c1(Xi).squeeze(1)[0]

fig, ax = plt.subplots()
ax.plot(Xi[0].T)
ax.plot(conv_out)
ax.set(title='Input channels and c1 output (first sample of last training batch)',
       xlabel='time step');

In [None]:
loss, metric, y_true, y_pred, y_logits = evaluate(devloader, model, criterion)

fig, axes = plt.subplots(nrows=2, ncols=1, figsize=(30, 5),
                         gridspec_kw={'height_ratios': [1, 5]})

# Top strip: ground truth (green) vs. predicted labels (red) on the dev set.
axes[0].plot(y_true, 'g', linewidth=.5)
axes[0].plot(y_pred, 'r', linewidth=.5)
axes[0].axis('off')

# Bottom panel: stacked areas (1 - y_logits, y_logits) per dev sample.
# NOTE(review): this is only meaningful if y_logits are probabilities in [0, 1] shaped (N, 1),
# because hstack(...).T must give (2, N). If evaluate() returns raw logits, apply torch.sigmoid first — confirm in lib.
axes[1].stackplot(torch.arange(len(y_logits)),
                  torch.hstack([torch.ones_like(y_logits) - y_logits, y_logits]).T)
axes[1].set(xlabel='dev sample', ylabel='score')

fig.savefig('logits.pdf', bbox_inches='tight')
cm_grid(y_true, y_pred, save_path='cm.jpg')

In [28]:
# Full-module pickle, kept for existing consumers of model.pt.
# Loading it needs the MODEL class importable, and recent torch versions default to
# torch.load(weights_only=True), which rejects it unless weights_only=False is passed.
torch.save(model, 'model.pt')
# Also save the state_dict (PyTorch's recommended format), which loads without pickled code:
#   m = MODEL(); m.load_state_dict(torch.load('model_state_dict.pt'))
torch.save(model.state_dict(), 'model_state_dict.pt')