In [None]:
from tqdm import tqdm
from sklearn.metrics import f1_score
import numpy as np

import torch
from torch import nn

from lib.utils import get_dataloader
from lib.utils import SleepStageClassifier
from lib.utils import ekyn_ids,snezana_mice_ids,courtney_ids
from sklearn.metrics import ConfusionMatrixDisplay,classification_report
from lib.utils import calculate_f1,plot_training_progress

In [None]:
# Mini-batch size setting.
# NOTE(review): in this notebook the name is only read by a commented-out get_dataloader
# call and is rebound to 2 by the later shape-check cells; the DataLoaders built further
# down pass batch_size=32 directly instead of using this value.
batch_size = 1024

In [None]:
# Keep only the first eight ids of each cohort list imported from lib.utils.
_first_eight = slice(None, 8)
ekyn_ids, snezana_mice_ids, courtney_ids = (
    ids[_first_eight] for ids in (ekyn_ids, snezana_mice_ids, courtney_ids)
)
print(ekyn_ids, snezana_mice_ids, courtney_ids)

In [None]:
def _split_off_last_quarter(ids):
    """Return ``(train, test)``: ``test`` is the trailing ``ceil(len(ids) / 4)`` ids, ``train`` the rest."""
    # Unary minus binds before //, so this is floor(-n / 4) == -ceil(n / 4).
    cut = -len(ids) // 4
    return ids[:cut], ids[cut:]


# Same split rule for every cohort; each print shows the two sizes followed by the ids.
train_ekyn_ids, test_ekyn_ids = _split_off_last_quarter(ekyn_ids)
print(len(train_ekyn_ids), len(test_ekyn_ids), train_ekyn_ids, test_ekyn_ids)

train_snezana_mice_ids, test_snezana_mice_ids = _split_off_last_quarter(snezana_mice_ids)
print(len(train_snezana_mice_ids), len(test_snezana_mice_ids), train_snezana_mice_ids, test_snezana_mice_ids)

train_courtney_ids, test_courtney_ids = _split_off_last_quarter(courtney_ids)
print(len(train_courtney_ids), len(test_courtney_ids), train_courtney_ids, test_courtney_ids)

In [None]:
# Disabled lib.utils dataloader call, kept for reference:
# trainloader = get_dataloader(train_ekyn_ids[:1],snezana_mice_ids=None,courtney_ids=None,batch_size=batch_size,shuffle=True,downsample=False)

# Use the GPU when PyTorch can see one, otherwise the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

# NOTE(review): `model`, `optimizer` and `criterion` are all rebound by later cells in this
# notebook before any training step runs.
model = SleepStageClassifier()
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.AdamW(params=model.parameters(), lr=3e-4)

print(device)

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F

class ConvBlock(nn.Module):
    """Two stacked ``Conv1d -> BatchNorm1d -> ReLU`` stages with a shared kernel size and dilation.

    Args:
        in_channels: number of channels in the input tensor.
        out_channels: number of channels produced by both convolutions.
        kernel_size: width of both convolution kernels.
        dilation: dilation of both convolutions.
        padding: forwarded unchanged to ``nn.Conv1d``. The default ``'same'`` keeps the
            temporal length unchanged for every kernel size (even ones included); an
            integer applies that many zeros on each side.

    ``forward`` maps ``(batch, in_channels, length)`` to ``(batch, out_channels, length_out)``,
    where ``length_out == length`` for ``padding='same'``.
    """

    def __init__(self, in_channels, out_channels, kernel_size=5, dilation=1, padding='same'):
        super(ConvBlock, self).__init__()
        # Let Conv1d resolve 'same' itself. A symmetric integer pad of ((k - 1) * d) // 2 is
        # one sample short whenever (k - 1) * d is odd (e.g. even kernels with dilation 1),
        # which made every such convolution shrink the signal by one sample. Conv1d's
        # native 'same' mode pads asymmetrically in that case and preserves the length.
        self.conv1 = nn.Conv1d(in_channels, out_channels, kernel_size, padding=padding, dilation=dilation)
        self.bn1 = nn.BatchNorm1d(out_channels)
        self.conv2 = nn.Conv1d(out_channels, out_channels, kernel_size, padding=padding, dilation=dilation)
        self.bn2 = nn.BatchNorm1d(out_channels)
        self.relu = nn.ReLU(inplace=True)

    def forward(self, x):
        """Apply both conv/batch-norm/ReLU stages to ``x`` and return the result."""
        x = self.relu(self.bn1(self.conv1(x)))
        x = self.relu(self.bn2(self.conv2(x)))
        return x

class UTime(nn.Module):
    """1-D encoder/decoder with skip connections that scores every sample and every segment.

    The encoder is four ``ConvBlock`` + max-pool stages (pool sizes 10, 8, 6, 4) followed by
    a bottleneck ``ConvBlock``. The decoder mirrors it with nearest-neighbour upsampling by
    4, 6, 8, 10; after each upsampling the encoder output of the same depth is resampled to
    the upsampled length and concatenated along the channel axis.

    Args:
        in_channels: channels of the input signal.
        num_classes: number of score channels produced per sample / per segment.
        base_filters: channel count of the first encoder stage; deeper stages use multiples.
        segment_length: number of samples averaged into one segment score.
    """

    def __init__(self, in_channels=1, num_classes=5, base_filters=16, segment_length=3000):
        super().__init__()
        self.segment_length = segment_length
        width = base_filters

        # Encoder: channel count doubles at every depth.
        self.enc1 = ConvBlock(in_channels, width, dilation=2)
        self.pool1 = nn.MaxPool1d(kernel_size=10, padding=0)
        self.enc2 = ConvBlock(width, width * 2, dilation=2)
        self.pool2 = nn.MaxPool1d(kernel_size=8)
        self.enc3 = ConvBlock(width * 2, width * 4, dilation=2)
        self.pool3 = nn.MaxPool1d(kernel_size=6)
        self.enc4 = ConvBlock(width * 4, width * 8, dilation=2)
        self.pool4 = nn.MaxPool1d(kernel_size=4)

        # Bottleneck at the coarsest resolution.
        self.bottleneck = ConvBlock(width * 8, width * 16, dilation=2)

        # Decoder: each block receives the upsampled features plus the matching skip tensor,
        # hence the summed input widths (16+8, 8+4, 4+2, 2+1 times `width`).
        self.up4 = nn.Upsample(scale_factor=4, mode='nearest')
        self.dec4 = ConvBlock(width * 24, width * 8, kernel_size=4, dilation=1)
        self.up3 = nn.Upsample(scale_factor=6, mode='nearest')
        self.dec3 = ConvBlock(width * 12, width * 4, kernel_size=6, dilation=1)
        self.up2 = nn.Upsample(scale_factor=8, mode='nearest')
        self.dec2 = ConvBlock(width * 6, width * 2, kernel_size=8, dilation=1)
        self.up1 = nn.Upsample(scale_factor=10, mode='nearest')
        self.dec1 = ConvBlock(width * 3, width, kernel_size=10, dilation=1)

        # 1x1 convolution turning decoder features into per-sample class scores.
        self.final_conv = nn.Conv1d(width, num_classes, kernel_size=1)

        # Non-overlapping average over each segment, then a 1x1 convolution over classes.
        self.segment_classifier = nn.Sequential(
            nn.AvgPool1d(kernel_size=segment_length, stride=segment_length),
            nn.Conv1d(num_classes, num_classes, kernel_size=1),
        )

    @staticmethod
    def _fuse(upsampled, skip):
        """Resample ``skip`` to the length of ``upsampled`` and stack the two along channels."""
        resized_skip = F.interpolate(skip, size=upsampled.size(2), mode='nearest')
        return torch.cat([upsampled, resized_skip], dim=1)

    def forward(self, x):
        """Score a batch of signals.

        Args:
            x: tensor of shape ``(batch, in_channels, length)``.

        Returns:
            ``(sample_scores, segment_scores)`` with shapes
            ``(batch, num_classes, length)`` and
            ``(batch, num_classes, length // segment_length)``.
        """
        # Encoder: remember the pre-pooling output of each stage for the skip connections.
        skip1 = self.enc1(x)
        skip2 = self.enc2(self.pool1(skip1))
        skip3 = self.enc3(self.pool2(skip2))
        skip4 = self.enc4(self.pool3(skip3))
        features = self.bottleneck(self.pool4(skip4))

        # Decoder: deepest stage first.
        decoder_stages = (
            (self.up4, skip4, self.dec4),
            (self.up3, skip3, self.dec3),
            (self.up2, skip2, self.dec2),
            (self.up1, skip1, self.dec1),
        )
        for upsample, skip, block in decoder_stages:
            features = block(self._fuse(upsample(features), skip))

        # Pooling floors odd lengths, so bring the scores back to the input length if needed.
        sample_scores = self.final_conv(features)
        target_length = x.size(2)
        if sample_scores.size(2) != target_length:
            sample_scores = F.interpolate(sample_scores, size=target_length, mode='nearest')

        return sample_scores, self.segment_classifier(sample_scores)

# Shape check: push one random batch through a freshly built UTime model.
# NOTE: `model`, `num_classes` and `segment_length` defined here are reused by later cells;
# this `model` instance is the one the optimizer and training loop further down operate on.
batch_size = 2
channels = 1
time_steps = 45000
num_classes = 3
segment_length = 5000

model = UTime(in_channels=channels, num_classes=num_classes, segment_length=segment_length)
x = torch.randn(batch_size, channels, time_steps)
with torch.no_grad():  # only shapes are inspected, so skip building the autograd graph
    sample_scores, segment_scores = model(x)

# The segment head averages non-overlapping windows, so it yields floor(45000 / 5000) = 9 segments.
num_segments = time_steps // segment_length
assert sample_scores.shape == (batch_size, num_classes, time_steps)
assert segment_scores.shape == (batch_size, num_classes, num_segments)

print("Sample scores shape:", sample_scores.shape)  # torch.Size([2, 3, 45000])
print("Segment scores shape:", segment_scores.shape)  # torch.Size([2, 3, 9])

In [None]:
class DiceLoss(nn.Module):
    """Soft Dice loss computed from raw class scores.

    Args:
        smooth: constant added to numerator and denominator so that classes with no
            predicted or target mass do not divide by zero.

    ``forward(pred, target)`` takes raw scores ``pred`` of shape ``(batch, classes, *dims)``
    (any number of trailing dims, including none) and a ``target`` of the same shape holding
    per-class weights such as one-hot labels. It returns ``1 - mean_over_classes(dice)``.

    Raises:
        ValueError: if ``pred`` and ``target`` do not have the same shape.
    """

    def __init__(self, smooth=1e-6):
        super(DiceLoss, self).__init__()
        self.smooth = smooth

    def forward(self, pred, target):
        if pred.shape != target.shape:
            # Broadcasting a mis-shaped target would silently produce a meaningless overlap.
            raise ValueError(
                f"pred and target must have the same shape, got {tuple(pred.shape)} and {tuple(target.shape)}"
            )
        probs = F.softmax(pred, dim=1)  # scores -> class probabilities
        # Reduce over every axis except the class axis, so 2-D, 3-D and N-D inputs all work
        # (for 3-D input this is exactly dims (0, 2)).
        reduce_dims = tuple(axis for axis in range(probs.dim()) if axis != 1)
        intersection = (probs * target).sum(dim=reduce_dims)
        union = probs.sum(dim=reduce_dims) + target.sum(dim=reduce_dims)
        dice = (2. * intersection + self.smooth) / (union + self.smooth)
        return 1 - dice.mean()  # average the per-class Dice scores

In [None]:
from lib.utils import load_ekyn

# Build overlapping training windows of 9 consecutive epochs from one recording.
# NOTE(review): assumes load_ekyn returns X as (n_epochs, samples_per_epoch) and y as one-hot
# (n_epochs, n_classes) — TODO confirm against lib.utils.
ekyn_id = train_ekyn_ids[0]  # named ekyn_id so the builtin `id` is not shadowed
condition = 'PF'
X,y = load_ekyn(id=ekyn_id,condition=condition)
# Tensor.unfold appends the window axis LAST, giving (n_windows, samples_per_epoch, 9).
# Move the window axis in front of the sample axis before flattening so that each row is
# 9 whole epochs laid end to end; flattening directly would interleave the 9 epochs sample
# by sample, and the model's per-segment averaging would then mix all epochs together.
# NOTE(review): any cell that builds evaluation windows must use this same layout.
X = X.unfold(dimension=0,size=9,step=1).transpose(1, 2)
X = X.flatten(1,2)
X = X.unsqueeze(1)  # add the channel axis expected by Conv1d
# Under the assumption above this is (n_windows, n_classes, 9), i.e. already
# (batch, classes, segments), so no transpose is needed for the labels.
y = y.unfold(dimension=0,size=9,step=1)
trainloader = torch.utils.data.DataLoader(torch.utils.data.TensorDataset(X, y), batch_size=32, shuffle=True)

# Optimizer and loss used by the training loop in the next cell.
optimizer = torch.optim.Adam(model.parameters(), lr=3e-4)
criterion = DiceLoss()

In [None]:
# Train on `trainloader`, optimising only the segment-level scores.
model.to(device)
model.train()

lossi = []  # one loss value per optimisation step, in order
n_epochs = 1
# Targets are assumed to be one-hot with shape (batch_size, num_classes, num_segments).
for epoch in range(n_epochs):
    for Xi, yi in tqdm(trainloader):
        Xi = Xi.to(device)
        yi = yi.to(device)

        optimizer.zero_grad()
        # The model returns (per-sample scores, per-segment scores); only the latter is used.
        _, segment_scores = model(Xi)
        loss = criterion(segment_scores, yi)
        loss.backward()
        optimizer.step()

        lossi.append(loss.item())

In [None]:
import matplotlib.pyplot as plt

# Per-batch training loss in the order the optimisation steps were taken.
fig, ax = plt.subplots()
ax.plot(lossi)
ax.set(xlabel="training step (batch)", ylabel="loss", title="Training loss per batch")
plt.show()

In [None]:
from lib.utils import load_ekyn

# Evaluate on one random batch of 9-epoch windows from a second recording.
# NOTE(review): the recording comes from train_ekyn_ids[1], not from test_ekyn_ids.
# NOTE(review): assumes load_ekyn returns X as (n_epochs, samples_per_epoch) and y as one-hot
# (n_epochs, n_classes) — TODO confirm against lib.utils.
ekyn_id = train_ekyn_ids[1]  # named ekyn_id so the builtin `id` is not shadowed
condition = 'Vehicle'
X,y = load_ekyn(id=ekyn_id,condition=condition)
# Tensor.unfold appends the window axis LAST, giving (n_windows, samples_per_epoch, 9).
# Move it in front of the sample axis before flattening so each row is 9 whole epochs laid
# end to end; flattening directly would interleave the epochs sample by sample.
# NOTE(review): the cell that builds the training windows must use this same layout,
# otherwise the model is evaluated on differently ordered samples than it was trained on.
X = X.unfold(dimension=0,size=9,step=1).transpose(1, 2)
X = X.flatten(1,2)
X = X.unsqueeze(1)  # add the channel axis expected by Conv1d
y = y.unfold(dimension=0,size=9,step=1)
testloader = torch.utils.data.DataLoader(torch.utils.data.TensorDataset(X, y), batch_size=32, shuffle=True)

model.cpu()
# Inference mode: BatchNorm must use its running statistics here; in train mode it would
# normalise with this batch's statistics and also update the running averages.
model.eval()
Xi,yi = next(iter(testloader))
with torch.no_grad():
    _, segment_scores = model(Xi)
    segment_scores = F.softmax(segment_scores, dim=1)
# Flatten (batch, segments) so every segment counts as one sample in the report.
y_pred = segment_scores.argmax(dim=1).view(-1)
y_true = yi.argmax(dim=1).view(-1)
print(classification_report(y_true.numpy(), y_pred.numpy()))

In [None]:
# Predicted and true class index of every segment in the first window of the batch:
# pick the window first, then take the argmax over its class axis.
y_pred = segment_scores[0].argmax(dim=0)
y_true = yi[0].argmax(dim=0)

In [None]:
# Shape check on a longer random input.
# NOTE: this rebinds `model` to a fresh, untrained UTime (replacing whatever an earlier cell
# stored under that name) and reads `segment_length` from an earlier cell.
batch_size = 2
channels = 1
time_steps = 105000  # 21 segments when segment_length is 5000, as set in the earlier shape-check cell
num_classes = 3

# Create model
model = UTime(in_channels=channels, num_classes=num_classes, segment_length=segment_length)

# Sample input (batch_size, channels, time_steps)
x = torch.randn(batch_size, channels, time_steps)

# Forward pass; only shapes are inspected, so skip building the autograd graph.
with torch.no_grad():
    sample_scores, segment_scores = model(x)

# The segment head averages non-overlapping windows of `segment_length` samples.
num_segments = time_steps // segment_length
assert sample_scores.shape == (batch_size, num_classes, time_steps)
assert segment_scores.shape == (batch_size, num_classes, num_segments)

print("Sample scores shape:", sample_scores.shape)  # torch.Size([2, 3, 105000])
print("Segment scores shape:", segment_scores.shape)  # torch.Size([2, 3, 21]) with segment_length = 5000