In [1]:
import torch
from torch.utils.data import DataLoader, random_split
import numpy as np
from utils import calculate_kurtosis
import torch.optim as optim
from model import Encoder, Decoder
import torch.nn as nn
import torch.nn.functional as F
import zarr
from data.tdms_to_npy_scaling import patch_matrix
import tqdm

In [2]:
# Fixed seed so the train/val/test membership is identical across kernel restarts.
SPLIT_SEED = 42
BATCH_SIZE = 64

dataset = torch.load('autoencoder/data/dataset_p1500_filtered.pt')
print(f"Dataset shape: {dataset.shape}")

# 70% train / 10% validation / remainder goes to test, so the three sizes
# always add up to num_samples despite the int() truncation.
num_samples = len(dataset)
train_size = int(0.7 * num_samples)
val_size = int(0.1 * num_samples)
test_size = num_samples - train_size - val_size

# A dedicated, seeded generator keeps the split independent of whatever else
# has consumed the global RNG before this cell runs.
train_dataset, val_dataset, test_dataset = random_split(
    dataset,
    [train_size, val_size, test_size],
    generator=torch.Generator().manual_seed(SPLIT_SEED),
)

# Only the training loader shuffles; validation and test iterate in a fixed
# order so their batches (and therefore their per-batch losses) are the same
# on every evaluation pass.
train_dataloader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True)
val_dataloader = DataLoader(val_dataset, batch_size=BATCH_SIZE, shuffle=False)
test_dataloader = DataLoader(test_dataset, batch_size=BATCH_SIZE, shuffle=False)

Dataset shape: torch.Size([31158, 1500, 32])


In [3]:
# Two independent encoders with the same architecture; both are fed the same
# input patches during training.
encod_f = Encoder()
encod_m = Encoder()

# A single Adam optimizer updates the parameters of both encoders together.
trainable_params = [*encod_f.parameters(), *encod_m.parameters()]
optimizer = optim.Adam(trainable_params, lr=0.001, weight_decay=0.001)

# Number of full passes over the training set.
num_epochs = 100


In [4]:
class KurtosisLoss(nn.Module):
    """Reward large absolute kurtosis and penalize ``|k_f|`` not staying below ``|k_m|``.

    For each pair of elements the loss is::

        -alpha * (|k_f| + |k_m|) + gamma * relu(|k_f| - |k_m| + epsilon)

    and the per-element values are summed into a single scalar.

    Args:
        alpha: Weight of the term that pushes both absolute kurtoses up.
        beta: Stored on the module but not read by ``forward``.
        gamma: Weight of the ordering penalty.
        epsilon: Margin by which ``|k_m|`` must exceed ``|k_f|`` before the
            penalty becomes zero.
    """

    def __init__(self, alpha=1.0, beta=1.0, gamma=10.0, epsilon=0.01):
        super().__init__()
        self.alpha = alpha
        self.beta = beta
        self.gamma = gamma
        self.epsilon = epsilon

    def forward(self, k_f, k_m):
        """Return the summed loss for kurtosis tensors ``k_f`` and ``k_m``.

        Both inputs are flattened to 1-D and combined element-wise, so they
        must contain the same number of elements (or be broadcastable once
        flattened).
        """
        # Only the magnitude of the kurtosis matters, not its sign.
        # reshape (rather than view) so non-contiguous inputs, e.g. transposed
        # tensors, are flattened instead of raising a RuntimeError.
        k_f = torch.abs(k_f).reshape(-1)
        k_m = torch.abs(k_m).reshape(-1)

        # Negative sign: minimizing the loss maximizes |k_f| + |k_m|.
        maximize_kf_km = -self.alpha * (k_f + k_m)

        # Hinge penalty, active whenever |k_m| does not exceed |k_f| by epsilon.
        penalty = self.gamma * torch.relu(k_f - k_m + self.epsilon)

        return (maximize_kf_km + penalty).sum()

criterion = KurtosisLoss(alpha=1.0, beta=1.0, gamma=10.0, epsilon=0.01)


def _summed_eval_loss(dataloader, show_progress=False):
    """Return the criterion value summed over every batch of ``dataloader``.

    Both encoders are switched to eval mode and the pass runs under
    ``torch.no_grad()``, so no parameters or gradients are touched.

    Args:
        dataloader: Iterable yielding batches of patches.
        show_progress: Wrap the loader in a tqdm progress bar when True.

    Returns:
        Python float: sum of the per-batch loss values.
    """
    encod_f.eval()
    encod_m.eval()
    batches = tqdm.tqdm(dataloader) if show_progress else dataloader
    total = 0.0
    with torch.no_grad():
        for batch in batches:
            encoded_fetal = encod_f(batch)
            encoded_maternal = encod_m(batch)
            k_f = calculate_kurtosis(encoded_fetal)
            k_m = calculate_kurtosis(encoded_maternal)
            total += criterion(k_f, k_m).item()
    return total


for epoch in range(num_epochs):
    # Back to train mode: the validation pass at the end of the previous
    # epoch left both encoders in eval mode.
    encod_f.train()
    encod_m.train()

    train_loss = 0.0
    for patches in tqdm.tqdm(train_dataloader):
        optimizer.zero_grad()

        # Both encoders see the same input batch.
        encoded_fetal = encod_f(patches)
        encoded_maternal = encod_m(patches)

        k_f = calculate_kurtosis(encoded_fetal)
        k_m = calculate_kurtosis(encoded_maternal)

        loss = criterion(k_f, k_m)

        loss.backward()
        optimizer.step()
        train_loss += loss.item()

    # Validation step
    val_loss = _summed_eval_loss(val_dataloader, show_progress=True)

    # The criterion already sums within a batch, so the reported numbers are
    # the mean of those per-batch sums (divided by the number of batches).
    avg_val_loss = val_loss / len(val_dataloader)
    print(f'Epoch [{epoch+1}/{num_epochs}], Train Loss: {train_loss/len(train_dataloader):.4f}, Val Loss: {avg_val_loss:.4f}')

# Testing step
test_loss = _summed_eval_loss(test_dataloader)

print(f'Test Loss: {test_loss/len(test_dataloader):.4f}')

100%|██████████| 341/341 [00:07<00:00, 44.26it/s]
100%|██████████| 49/49 [00:00<00:00, 60.98it/s]


Epoch [1/100], Train Loss: 21167.3589, Val Loss: 36820.0327


100%|██████████| 341/341 [00:07<00:00, 47.54it/s]
100%|██████████| 49/49 [00:00<00:00, 67.57it/s]


Epoch [2/100], Train Loss: 20434.6211, Val Loss: 28499.4083


100%|██████████| 341/341 [00:06<00:00, 48.98it/s]
100%|██████████| 49/49 [00:00<00:00, 66.32it/s]


Epoch [3/100], Train Loss: 20467.8247, Val Loss: 30331.7912


100%|██████████| 341/341 [00:07<00:00, 48.14it/s]
100%|██████████| 49/49 [00:00<00:00, 65.51it/s]


Epoch [4/100], Train Loss: 20169.3601, Val Loss: 15219.3568


100%|██████████| 341/341 [00:07<00:00, 47.20it/s]
100%|██████████| 49/49 [00:00<00:00, 67.11it/s]


Epoch [5/100], Train Loss: 22133.8649, Val Loss: 21734.3254


100%|██████████| 341/341 [00:07<00:00, 48.67it/s]
100%|██████████| 49/49 [00:00<00:00, 66.77it/s]


Epoch [6/100], Train Loss: 20794.9755, Val Loss: 34525.5633


100%|██████████| 341/341 [00:07<00:00, 48.12it/s]
100%|██████████| 49/49 [00:00<00:00, 64.52it/s]


Epoch [7/100], Train Loss: 19455.9197, Val Loss: 15847.3662


100%|██████████| 341/341 [00:06<00:00, 49.02it/s]
100%|██████████| 49/49 [00:00<00:00, 67.90it/s]


Epoch [8/100], Train Loss: 21803.8984, Val Loss: 27955.4467


100%|██████████| 341/341 [00:07<00:00, 48.38it/s]
100%|██████████| 49/49 [00:00<00:00, 67.54it/s]


Epoch [9/100], Train Loss: 20968.0562, Val Loss: 20381.8701


100%|██████████| 341/341 [00:07<00:00, 47.79it/s]
100%|██████████| 49/49 [00:00<00:00, 66.25it/s]


Epoch [10/100], Train Loss: 19950.4967, Val Loss: 23819.2410


100%|██████████| 341/341 [00:07<00:00, 47.99it/s]
100%|██████████| 49/49 [00:00<00:00, 67.28it/s]


Epoch [11/100], Train Loss: 20763.7452, Val Loss: 14645.0306


100%|██████████| 341/341 [00:07<00:00, 48.24it/s]
100%|██████████| 49/49 [00:00<00:00, 67.15it/s]


Epoch [12/100], Train Loss: 20076.5811, Val Loss: 19777.2765


100%|██████████| 341/341 [00:07<00:00, 47.94it/s]
100%|██████████| 49/49 [00:00<00:00, 65.43it/s]


Epoch [13/100], Train Loss: 20474.4118, Val Loss: 22916.1762


100%|██████████| 341/341 [00:07<00:00, 48.27it/s]
100%|██████████| 49/49 [00:00<00:00, 66.62it/s]


Epoch [14/100], Train Loss: 20558.6146, Val Loss: 20209.1590


100%|██████████| 341/341 [00:07<00:00, 48.20it/s]
100%|██████████| 49/49 [00:00<00:00, 65.45it/s]


Epoch [15/100], Train Loss: 20664.3403, Val Loss: 6664.2383


100%|██████████| 341/341 [00:06<00:00, 48.83it/s]
100%|██████████| 49/49 [00:00<00:00, 66.02it/s]


Epoch [16/100], Train Loss: 20352.1755, Val Loss: 26464.4414


100%|██████████| 341/341 [00:07<00:00, 47.51it/s]
100%|██████████| 49/49 [00:00<00:00, 65.23it/s]


Epoch [17/100], Train Loss: 21146.2038, Val Loss: 42071.1840


100%|██████████| 341/341 [00:07<00:00, 48.54it/s]
100%|██████████| 49/49 [00:00<00:00, 68.14it/s]


Epoch [18/100], Train Loss: 23052.9468, Val Loss: 1071.3368


100%|██████████| 341/341 [00:07<00:00, 48.34it/s]
100%|██████████| 49/49 [00:00<00:00, 67.44it/s]


Epoch [19/100], Train Loss: 21456.0523, Val Loss: 23282.8797


100%|██████████| 341/341 [00:07<00:00, 48.30it/s]
100%|██████████| 49/49 [00:00<00:00, 66.36it/s]


Epoch [20/100], Train Loss: 21060.8066, Val Loss: 41924.2495


100%|██████████| 341/341 [00:07<00:00, 48.69it/s]
100%|██████████| 49/49 [00:00<00:00, 67.65it/s]


Epoch [21/100], Train Loss: 20826.2133, Val Loss: 16861.9716


100%|██████████| 341/341 [00:07<00:00, 48.35it/s]
100%|██████████| 49/49 [00:00<00:00, 67.45it/s]


Epoch [22/100], Train Loss: 19962.8479, Val Loss: 23236.6955


100%|██████████| 341/341 [00:07<00:00, 48.29it/s]
100%|██████████| 49/49 [00:00<00:00, 67.11it/s]


Epoch [23/100], Train Loss: 21362.9809, Val Loss: 25336.6648


100%|██████████| 341/341 [00:07<00:00, 48.38it/s]
100%|██████████| 49/49 [00:00<00:00, 67.60it/s]


Epoch [24/100], Train Loss: 20268.8603, Val Loss: 20503.3842


100%|██████████| 341/341 [00:07<00:00, 48.36it/s]
100%|██████████| 49/49 [00:00<00:00, 67.60it/s]


Epoch [25/100], Train Loss: 20732.4380, Val Loss: 31680.6348


100%|██████████| 341/341 [00:06<00:00, 48.88it/s]
100%|██████████| 49/49 [00:00<00:00, 66.93it/s]


Epoch [26/100], Train Loss: 20183.1637, Val Loss: 38092.9895


100%|██████████| 341/341 [00:07<00:00, 48.65it/s]
100%|██████████| 49/49 [00:00<00:00, 68.44it/s]


Epoch [27/100], Train Loss: 20947.4233, Val Loss: 17698.6880


100%|██████████| 341/341 [00:06<00:00, 48.84it/s]
100%|██████████| 49/49 [00:00<00:00, 68.58it/s]


Epoch [28/100], Train Loss: 21199.6552, Val Loss: 17954.1286


100%|██████████| 341/341 [00:06<00:00, 48.73it/s]
100%|██████████| 49/49 [00:00<00:00, 68.02it/s]


Epoch [29/100], Train Loss: 21676.8971, Val Loss: 29637.1083


100%|██████████| 341/341 [00:06<00:00, 49.03it/s]
100%|██████████| 49/49 [00:00<00:00, 65.73it/s]


Epoch [30/100], Train Loss: 21102.7703, Val Loss: 23925.2850


100%|██████████| 341/341 [00:07<00:00, 44.37it/s]
100%|██████████| 49/49 [00:00<00:00, 61.85it/s]


Epoch [31/100], Train Loss: 20180.3575, Val Loss: 21936.8414


100%|██████████| 341/341 [00:07<00:00, 45.09it/s]
100%|██████████| 49/49 [00:00<00:00, 55.00it/s]


Epoch [32/100], Train Loss: 21792.7474, Val Loss: 29406.5374


100%|██████████| 341/341 [00:07<00:00, 47.93it/s]
100%|██████████| 49/49 [00:00<00:00, 68.89it/s]


Epoch [33/100], Train Loss: 19852.1660, Val Loss: 14355.1057


100%|██████████| 341/341 [00:07<00:00, 47.02it/s]
100%|██████████| 49/49 [00:00<00:00, 64.69it/s]


Epoch [34/100], Train Loss: 20531.5412, Val Loss: 13879.8450


100%|██████████| 341/341 [00:07<00:00, 47.46it/s]
100%|██████████| 49/49 [00:00<00:00, 66.58it/s]


Epoch [35/100], Train Loss: 21040.5777, Val Loss: 43107.1216


100%|██████████| 341/341 [00:07<00:00, 47.88it/s]
100%|██████████| 49/49 [00:00<00:00, 67.32it/s]


Epoch [36/100], Train Loss: 20899.0797, Val Loss: 24434.6645


100%|██████████| 341/341 [00:07<00:00, 47.26it/s]
100%|██████████| 49/49 [00:00<00:00, 67.43it/s]


Epoch [37/100], Train Loss: 21354.3296, Val Loss: 15053.9582


100%|██████████| 341/341 [00:07<00:00, 47.19it/s]
100%|██████████| 49/49 [00:00<00:00, 64.94it/s]


Epoch [38/100], Train Loss: 20293.3222, Val Loss: 32634.1663


100%|██████████| 341/341 [00:07<00:00, 47.55it/s]
100%|██████████| 49/49 [00:00<00:00, 63.78it/s]


Epoch [39/100], Train Loss: 19493.5112, Val Loss: 54950.6927


100%|██████████| 341/341 [00:07<00:00, 47.73it/s]
100%|██████████| 49/49 [00:00<00:00, 68.07it/s]


Epoch [40/100], Train Loss: 21203.7021, Val Loss: 25691.6381


100%|██████████| 341/341 [00:07<00:00, 48.59it/s]
100%|██████████| 49/49 [00:00<00:00, 68.17it/s]


Epoch [41/100], Train Loss: 20117.5891, Val Loss: 28711.2562


100%|██████████| 341/341 [00:07<00:00, 48.38it/s]
100%|██████████| 49/49 [00:00<00:00, 66.64it/s]


Epoch [42/100], Train Loss: 21410.8383, Val Loss: 29024.8676


100%|██████████| 341/341 [00:07<00:00, 48.04it/s]
100%|██████████| 49/49 [00:00<00:00, 66.79it/s]


Epoch [43/100], Train Loss: 21031.0086, Val Loss: 12261.5690


100%|██████████| 341/341 [00:07<00:00, 48.12it/s]
100%|██████████| 49/49 [00:00<00:00, 66.40it/s]


Epoch [44/100], Train Loss: 21446.2778, Val Loss: 19561.8703


100%|██████████| 341/341 [00:07<00:00, 48.33it/s]
100%|██████████| 49/49 [00:00<00:00, 67.33it/s]


Epoch [45/100], Train Loss: 20204.4080, Val Loss: 19587.1094


100%|██████████| 341/341 [00:07<00:00, 48.54it/s]
100%|██████████| 49/49 [00:00<00:00, 68.72it/s]


Epoch [46/100], Train Loss: 20319.7217, Val Loss: 19862.4338


100%|██████████| 341/341 [00:07<00:00, 48.65it/s]
100%|██████████| 49/49 [00:00<00:00, 68.28it/s]


Epoch [47/100], Train Loss: 21811.4185, Val Loss: 32409.3623


100%|██████████| 341/341 [00:06<00:00, 48.92it/s]
100%|██████████| 49/49 [00:00<00:00, 65.86it/s]


Epoch [48/100], Train Loss: 21462.8025, Val Loss: -769.7218


100%|██████████| 341/341 [00:07<00:00, 48.42it/s]
100%|██████████| 49/49 [00:00<00:00, 67.89it/s]


Epoch [49/100], Train Loss: 20555.4276, Val Loss: 33397.5520


 83%|████████▎ | 284/341 [00:06<00:01, 43.74it/s]


KeyboardInterrupt: 