In [1]:
import torch
from torch.utils.data import DataLoader, random_split
import numpy as np
from utils import calculate_kurtosis
import torch.optim as optim
from model import Encoder, Decoder
import torch.nn as nn
import torch.nn.functional as F
import zarr
from data.tdms_to_npy_scaling import patch_matrix
import tqdm

In [2]:
# Assuming 'dataset' is your dataset
#dataset = patch_matrix(patch_size=1500, apply_filter = True)
#print(f"Dataset shape: {dataset.shape}")
#dataset = torch.tensor(dataset, dtype=torch.float32)  # Ensure dataset is a PyTorch tensor
dataset = torch.load('autoencoder/data/dataset_p1500_filtered.pt')
print(f"Dataset shape: {dataset.shape}")
# Define the dataset sizes
num_samples = len(dataset)
train_size = int(0.7 * num_samples)
val_size = int(0.1 * num_samples)
test_size = num_samples - train_size - val_size

# Split the dataset
train_dataset, val_dataset, test_dataset = random_split(dataset, [train_size, val_size, test_size])

# Create DataLoader instances
train_dataloader = DataLoader(train_dataset, batch_size=128, shuffle=True)
val_dataloader = DataLoader(val_dataset, batch_size=128, shuffle=True)
test_dataloader = DataLoader(test_dataset, batch_size=128, shuffle=True)

Dataset shape: torch.Size([31158, 1500, 32])


In [3]:
encod_f = Encoder()
encod_m = Encoder()
decoder = Decoder()

# Define optimizer
optimizer = optim.Adam(list(encod_f.parameters()) + list(encod_m.parameters()) +
                           list(decoder.parameters()), lr=0.001, weight_decay=0.001)

num_epochs = 100
best_val_loss = float('inf')  # Initialize best validation loss to infinity

In [4]:
def kurtosis_loss(k_f, k_m):
    k_f = torch.abs(k_f)
    k_m = torch.abs(k_m)
    maximize_kf_km = - (k_f + k_m)
    maximize_difference =  - torch.abs(k_m - k_f)
    penalty =  10 * torch.relu(k_f - k_m + 0.01)
    loss = maximize_kf_km + maximize_difference + penalty
    return loss.sum()

In [5]:
# Training loop
for epoch in range(num_epochs):
    encod_f.train()
    encod_m.train()
    decoder.train()

    train_loss = 0.0
    # Replace the regular for loop with tqdm.notebook.tqdm (for Jupyter notebooks) or tqdm.tqdm (for regular Python scripts)
    for patches in tqdm.tqdm(train_dataloader):
        optimizer.zero_grad()

        # Forward pass through encoders
        encoded_fetal = encod_f(patches)
        encoded_maternal = encod_m(patches)

        # Calculate kurtosis
        k_f = calculate_kurtosis(encoded_fetal)
        k_m = calculate_kurtosis(encoded_maternal)

        # Forward pass through decoder
        reconstructed = decoder(encoded_fetal + encoded_maternal) 

        # Calculate loss
        loss = F.mse_loss(reconstructed, patches) + 0.1 * kurtosis_loss(k_f, k_m)

        # Backward pass and optimization
        loss.backward()
        optimizer.step()
        train_loss += loss.item()

    # Validation step
    encod_f.eval()
    encod_m.eval()
    decoder.eval()
    val_loss = 0.0
    with torch.no_grad():
        # Replace the regular for loop with tqdm.notebook.tqdm (for Jupyter notebooks) or tqdm.tqdm (for regular Python scripts)
        for patches in tqdm.tqdm(val_dataloader):
            encoded_fetal = encod_f(patches)
            encoded_maternal = encod_m(patches)
            k_f = calculate_kurtosis(encoded_fetal)
            k_m = calculate_kurtosis(encoded_maternal)
            reconstructed = decoder(encoded_fetal + encoded_maternal)
            loss = F.mse_loss(reconstructed, patches)  + 0.1 * kurtosis_loss(k_f, k_m)
            val_loss += loss.item()

    avg_val_loss = val_loss / len(val_dataloader)
    print(f'Epoch [{epoch+1}/{num_epochs}], Train Loss: {train_loss/len(train_dataloader):.4f}, Val Loss: {avg_val_loss:.4f}')
    
# Testing step
encod_f.eval()
encod_m.eval()
decoder.eval()
test_loss = 0.0
with torch.no_grad():
    for patches in test_dataloader:
        encoded_fetal = encod_f(patches)
        encoded_maternal = encod_m(patches)
        k_f = calculate_kurtosis(encoded_fetal)
        k_m = calculate_kurtosis(encoded_maternal)
        reconstructed = decoder(encoded_fetal + encoded_maternal)
        loss = F.mse_loss(reconstructed, patches) + 0.1 * kurtosis_loss(k_f, k_m)
        test_loss += loss.item()

print(f'Test Loss: {test_loss/len(test_dataloader):.4f}')

100%|██████████| 171/171 [08:41<00:00,  3.05s/it]
100%|██████████| 25/25 [00:24<00:00,  1.03it/s]


Epoch [1/100], Train Loss: 14703.5746, Val Loss: 13888.6789


100%|██████████| 171/171 [08:06<00:00,  2.84s/it]
100%|██████████| 25/25 [00:20<00:00,  1.21it/s]


Epoch [2/100], Train Loss: 13994.7994, Val Loss: 14483.8972


100%|██████████| 171/171 [08:26<00:00,  2.96s/it]
100%|██████████| 25/25 [00:20<00:00,  1.24it/s]


Epoch [3/100], Train Loss: 16161.1981, Val Loss: 16337.2802


100%|██████████| 171/171 [07:58<00:00,  2.80s/it]
100%|██████████| 25/25 [00:20<00:00,  1.21it/s]


Epoch [4/100], Train Loss: 17228.0248, Val Loss: 17052.6889


 19%|█▉        | 33/171 [01:35<06:40,  2.90s/it]


KeyboardInterrupt: 