In [1]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader, random_split
import mat73
import numpy as np
from tqdm import tqdm
import wandb
import matplotlib.pyplot as plt
from torch.optim.lr_scheduler import ReduceLROnPlateau

from scipy.signal import find_peaks
from skimage.metrics import structural_similarity as ssim
from sklearn.metrics import r2_score
import psutil
import os
import time

In [None]:
wandb.login()
# run = wandb.init(project="master-multicomponent-mri", name="unet-baseline-fullpatch")
run = wandb.init(project="master-multicomponent-mri", name="unet-baseline-fullpatch-att-3")

[34m[1mwandb[0m: Currently logged in as: [33mtr-phan[0m ([33mtrphan[0m) to [32mhttps://api.wandb.ai[0m. Use [1m`wandb login --relogin`[0m to force relogin
[34m[1mwandb[0m: Using wandb-core as the SDK backend.  Please refer to https://wandb.me/wandb-core for more information.


In [None]:
data = mat73.loadmat('../../data/training_data_T1_3D_9x9x32x47932_noise0.05.mat')
input_noisy_np = data['input_noisy']
input_clean_np = data['input']
ref_np = data['ref']

print("Shape of input data:", input_noisy_np.shape)
print("Shape of reference data:", ref_np.shape)

Shape of input data: (9, 9, 8, 47932)
Shape of reference data: (9, 9, 32, 47932)


In [4]:
nan_count_input_noisy_before = np.isnan(input_noisy_np).sum()
nan_count_ref_before = np.isnan(ref_np).sum()
print(f"NaN count in input_noisy before handling: {nan_count_input_noisy_before}")
print(f"NaN count in ref before handling: {nan_count_ref_before}")

NaN count in input_noisy before handling: 1635840
NaN count in ref before handling: 0


In [5]:
input_noisy_np = np.nan_to_num(input_noisy_np, nan=0.0)
nan_count_input_noisy_after = np.isnan(input_noisy_np).sum()
print(f"NaN count in input_noisy after handling: {nan_count_input_noisy_after}")

NaN count in input_noisy after handling: 0


In [6]:
class T1PatchDataset(Dataset):
    def __init__(self, input_data, target_data):
        self.input_data = input_data
        self.target_data = target_data

    def __len__(self):
        return len(self.input_data)

    def __getitem__(self, idx):
        return self.input_data[idx], self.target_data[idx]

In [7]:
class SpatialAttention(nn.Module):
    def __init__(self, in_channels):
        super().__init__()
        self.attention = nn.Sequential(
            nn.Conv2d(in_channels, in_channels//8, kernel_size=1),
            nn.BatchNorm2d(in_channels//8),
            nn.LeakyReLU(0.1),
            nn.Conv2d(in_channels//8, in_channels, kernel_size=1),
            nn.Sigmoid()
        )
    
    def forward(self, x):
        attention_weights = self.attention(x)
        return x * attention_weights

In [8]:
class ChannelAttention(nn.Module):
    def __init__(self, in_channels, reduction_ratio=16):
        super().__init__()
        self.avg_pool = nn.AdaptiveAvgPool2d(1)
        self.max_pool = nn.AdaptiveMaxPool2d(1)
        
        self.fc = nn.Sequential(
            nn.Linear(in_channels, in_channels // reduction_ratio),
            nn.ReLU(inplace=True),
            nn.Linear(in_channels // reduction_ratio, in_channels)
        )
        
    def forward(self, x):
        b, c, _, _ = x.size()
        
        avg_out = self.fc(self.avg_pool(x).view(b, c))
        max_out = self.fc(self.max_pool(x).view(b, c))
        
        out = avg_out + max_out
        return torch.sigmoid(out).view(b, c, 1, 1)

In [9]:
class DualAttentionBlock(nn.Module):
    def __init__(self, in_channels):
        super().__init__()
        self.channel_att = ChannelAttention(in_channels)
        self.spatial_att = SpatialAttention(in_channels)
        
    def forward(self, x):
        x = x * self.channel_att(x)
        x = x * self.spatial_att(x)
        return x

In [10]:
class RefineBlock(nn.Module):
    def __init__(self, in_channels):
        super().__init__()
        self.refine = nn.Sequential(
            nn.Conv2d(in_channels, in_channels, kernel_size=3, padding=1),
            nn.BatchNorm2d(in_channels),
            nn.LeakyReLU(0.1),
            DualAttentionBlock(in_channels)
        )
        
    def forward(self, x):
        return x + self.refine(x)

In [None]:
class ResidualDoubleConvWithAttention(nn.Module):
    def __init__(self, in_channels, out_channels):
        super().__init__()
        self.double_conv = nn.Sequential(
            nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1),
            nn.BatchNorm2d(out_channels),
            nn.LeakyReLU(0.1),
            nn.Dropout2d(0.1),
            nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=1),
            nn.BatchNorm2d(out_channels),
            nn.LeakyReLU(0.1)
        )
        self.attention = SpatialAttention(out_channels)
        
        # residual connection
        self.residual = nn.Identity() if in_channels == out_channels else \
                       nn.Conv2d(in_channels, out_channels, kernel_size=1)

    def forward(self, x):
        identity = self.residual(x)
        x = self.double_conv(x)
        x = self.attention(x)
        return x + identity  # Residual connection

In [None]:
class ImprovedUNet(nn.Module):
    def __init__(self, in_channels=8, init_features=64):
        super().__init__()
        
        # Encoder
        self.encoder1 = ResidualDoubleConvWithAttention(in_channels, init_features)
        self.pool1 = nn.MaxPool2d(2, padding=1)
        self.encoder2 = ResidualDoubleConvWithAttention(init_features, init_features*2)
        
        # Bridge
        self.bridge = ResidualDoubleConvWithAttention(init_features*2, init_features*4)
        
        # Decoder
        self.upconv1 = nn.ConvTranspose2d(init_features*4, init_features*2, kernel_size=2, stride=2)
        self.decoder1 = ResidualDoubleConvWithAttention(init_features*3, init_features*2)
        
        # Skip connection refinement
        self.refine1 = RefineBlock(init_features)
        
        # Multi-scale feature fusion
        self.fusion = nn.Sequential(
            nn.Conv2d(init_features*2 + init_features, init_features, kernel_size=1),
            nn.BatchNorm2d(init_features),
            nn.LeakyReLU(0.1)
        )
        
        # Final output
        self.final = nn.Sequential(
            nn.Conv2d(init_features, 64, kernel_size=3, padding=1),
            nn.BatchNorm2d(64),
            nn.LeakyReLU(0.1),
            DualAttentionBlock(64),
            nn.Conv2d(64, 32, kernel_size=1)
        )
        
        self.apply(self._init_weights)
    
    def _init_weights(self, m):
        if isinstance(m, (nn.Conv2d, nn.ConvTranspose2d)):
            nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
            if m.bias is not None:
                nn.init.constant_(m.bias, 0)
        elif isinstance(m, nn.BatchNorm2d):
            nn.init.constant_(m.weight, 1)
            nn.init.constant_(m.bias, 0)

    def forward(self, x):
        # Encoding
        enc1 = self.encoder1(x)
        x = self.pool1(enc1)
        enc2 = self.encoder2(x)
        
        # Bridge
        x = self.bridge(enc2)
        
        # Decoding
        x = self.upconv1(x)
        x = x[:, :, :9, :9]
        
        # skip connection
        refined_skip = self.refine1(enc1)
        x = torch.cat([x, refined_skip], dim=1)
        x = self.decoder1(x)
        
        # Multi-scale feature fusion
        x = self.fusion(torch.cat([x, refined_skip], dim=1))
        
        # Final output
        x = self.final(x)
        
        return x

In [13]:
class DeepSupervisionLoss(nn.Module):
    def __init__(self, main_loss_weight=1.0, aux_loss_weight=0.4):
        super().__init__()
        self.main_loss_weight = main_loss_weight
        self.aux_loss_weight = aux_loss_weight
        self.criterion = nn.MSELoss()
        
    def forward(self, outputs, targets):
        if isinstance(outputs, tuple):
            main_out, deep_out1, deep_out2 = outputs
            loss = self.main_loss_weight * self.criterion(main_out, targets)
            loss += self.aux_loss_weight * self.criterion(deep_out1, targets)
            loss += self.aux_loss_weight * self.criterion(deep_out2, targets)
            return loss
        return self.criterion(outputs, targets)

In [None]:
class EarlyStopping:
    def __init__(self, patience=15, min_delta=1e-6, path='saved_model/best_unet_model_attention_3.pt'):
        self.patience = patience
        self.min_delta = min_delta
        self.counter = 0
        self.best_loss = float('inf')
        self.early_stop = False
        self.path = path
        
    def __call__(self, val_loss, model):
        if val_loss < self.best_loss:
            print(f'Validation loss decreased ({self.best_loss:.6f} --> {val_loss:.6f}). Saving model...')
            self.best_loss = val_loss
            self.save_checkpoint(model)
            self.counter = 0
        else:
            self.counter += 1
            print(f'EarlyStopping counter: {self.counter} out of {self.patience}')
            if self.counter >= self.patience:
                self.early_stop = True

    def save_checkpoint(self, model):
        torch.save({
            'model_state_dict': model.state_dict(),
            'val_loss': self.best_loss,
        }, self.path)

In [None]:
# Transpose the input and reference data to be (N, C, H, W) format
input_noisy_torch = np.transpose(input_noisy_np, (3, 2, 0, 1))
ref_torch = np.transpose(ref_np, (3, 2, 0, 1))

input_noisy_torch = torch.tensor(input_noisy_torch, dtype=torch.float32)
ref_torch = torch.tensor(ref_torch, dtype=torch.float32)

In [16]:
input_mean = input_noisy_torch.mean()
input_std = input_noisy_torch.std()
input_noisy_torch = (input_noisy_torch - input_mean) / input_std

In [None]:
# Create Dataset
dataset = T1PatchDataset(input_noisy_torch, ref_torch)

# Split dataset
train_size = int(0.8 * len(dataset))
val_size = len(dataset) - train_size
train_dataset, val_dataset = random_split(dataset, [train_size, val_size])

# Create DataLoaders
batch_size = 64
train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False)

In [None]:
# Training Setup
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = ImprovedUNet().to(device)
criterion = nn.MSELoss()
optimizer = optim.Adam(model.parameters(), lr=0.001)
scheduler = ReduceLROnPlateau(optimizer, mode='min', factor=0.5, patience=5, verbose=True)
early_stopping = EarlyStopping(patience=10, path='best_unet_model_attention_3.pt')

num_epochs = 200

In [22]:
# Log configuration
wandb.config.update({
    "learning_rate": 0.001,
    "batch_size": batch_size,
    "epochs": num_epochs,
    "architecture": "UNet",
    "optimizer": "Adam",
    "loss_function": "MSELoss",
    "scheduler": "ReduceLROnPlateau",
    "early_stopping_patience": 10
})

In [23]:
# Training loop
best_val_loss = float('inf')

for epoch in range(num_epochs):
    # Training phase
    model.train()
    train_loss = 0.0
    progress_bar_train = tqdm(train_loader, desc=f'Epoch [{epoch+1}/{num_epochs}] Training')
    
    for inputs, targets in progress_bar_train:
        inputs, targets = inputs.to(device), targets.to(device)
        
        optimizer.zero_grad()
        outputs = model(inputs)
        loss = criterion(outputs, targets)
        loss.backward()
        
        # Gradient clipping
        torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
        
        optimizer.step()
        train_loss += loss.item()
        progress_bar_train.set_postfix({'loss': loss.item()})

    avg_train_loss = train_loss / len(train_loader)

    # Validation phase
    model.eval()
    val_loss = 0.0
    progress_bar_val = tqdm(val_loader, desc=f'Epoch [{epoch+1}/{num_epochs}] Validation')
    
    with torch.no_grad():
        for inputs, targets in progress_bar_val:
            inputs, targets = inputs.to(device), targets.to(device)
            outputs = model(inputs)
            loss = criterion(outputs, targets)
            val_loss += loss.item()
            progress_bar_val.set_postfix({'loss': loss.item()})

    avg_val_loss = val_loss / len(val_loader)

    # Learning rate scheduling
    scheduler.step(avg_val_loss)
    current_lr = optimizer.param_groups[0]['lr']

    # Early stopping
    early_stopping(avg_val_loss, model)

    # Logging
    wandb.log({
        "epoch": epoch + 1,
        "train_loss": avg_train_loss,
        "val_loss": avg_val_loss,
        "learning_rate": current_lr
    })

    print(f'Epoch [{epoch+1}/{num_epochs}], Train Loss: {avg_train_loss}, Val Loss: {avg_val_loss}, LR: {current_lr}')

    if early_stopping.early_stop:
        print("Early stopping triggered")
        break

print("Training finished")

Epoch [1/200] Training: 100%|██████████| 600/600 [00:14<00:00, 42.51it/s, loss=0.00418]
Epoch [1/200] Validation: 100%|██████████| 150/150 [00:01<00:00, 123.05it/s, loss=0.00322]


Validation loss decreased (inf --> 0.003122). Saving model...
Epoch [1/200], Train Loss: 0.004804139414336533, Val Loss: 0.0031223810153702893, LR: 0.001


Epoch [2/200] Training: 100%|██████████| 600/600 [00:14<00:00, 42.76it/s, loss=0.00347]
Epoch [2/200] Validation: 100%|██████████| 150/150 [00:01<00:00, 120.52it/s, loss=0.00285]


Validation loss decreased (0.003122 --> 0.002640). Saving model...
Epoch [2/200], Train Loss: 0.00294046962284483, Val Loss: 0.0026400376080224913, LR: 0.001


Epoch [3/200] Training: 100%|██████████| 600/600 [00:14<00:00, 41.77it/s, loss=0.00273]
Epoch [3/200] Validation: 100%|██████████| 150/150 [00:01<00:00, 123.63it/s, loss=0.00271]


Validation loss decreased (0.002640 --> 0.002531). Saving model...
Epoch [3/200], Train Loss: 0.0025970397365745156, Val Loss: 0.0025309902740021546, LR: 0.001


Epoch [4/200] Training: 100%|██████████| 600/600 [00:14<00:00, 42.38it/s, loss=0.00217]
Epoch [4/200] Validation: 100%|██████████| 150/150 [00:01<00:00, 120.17it/s, loss=0.00252]


Validation loss decreased (0.002531 --> 0.002326). Saving model...
Epoch [4/200], Train Loss: 0.002401989675903072, Val Loss: 0.0023259759088978173, LR: 0.001


Epoch [5/200] Training: 100%|██████████| 600/600 [00:14<00:00, 42.79it/s, loss=0.00214]
Epoch [5/200] Validation: 100%|██████████| 150/150 [00:01<00:00, 113.15it/s, loss=0.00249]


EarlyStopping counter: 1 out of 10
Epoch [5/200], Train Loss: 0.0022823925242604066, Val Loss: 0.0023279787716455756, LR: 0.001


Epoch [6/200] Training: 100%|██████████| 600/600 [00:14<00:00, 41.14it/s, loss=0.00273]
Epoch [6/200] Validation: 100%|██████████| 150/150 [00:01<00:00, 121.52it/s, loss=0.00217]


Validation loss decreased (0.002326 --> 0.002058). Saving model...
Epoch [6/200], Train Loss: 0.0021944624942261725, Val Loss: 0.002057666053685049, LR: 0.001


Epoch [7/200] Training: 100%|██████████| 600/600 [00:13<00:00, 43.52it/s, loss=0.00184]
Epoch [7/200] Validation: 100%|██████████| 150/150 [00:01<00:00, 118.43it/s, loss=0.00264]


EarlyStopping counter: 1 out of 10
Epoch [7/200], Train Loss: 0.0021302095123489078, Val Loss: 0.0023583986586891113, LR: 0.001


Epoch [8/200] Training: 100%|██████████| 600/600 [00:14<00:00, 42.74it/s, loss=0.00272]
Epoch [8/200] Validation: 100%|██████████| 150/150 [00:01<00:00, 111.34it/s, loss=0.00224]


EarlyStopping counter: 2 out of 10
Epoch [8/200], Train Loss: 0.002081057935526284, Val Loss: 0.002183212290207545, LR: 0.001


Epoch [9/200] Training: 100%|██████████| 600/600 [00:14<00:00, 41.84it/s, loss=0.00193]
Epoch [9/200] Validation: 100%|██████████| 150/150 [00:01<00:00, 103.73it/s, loss=0.00213]


Validation loss decreased (0.002058 --> 0.002009). Saving model...
Epoch [9/200], Train Loss: 0.0020363354766353343, Val Loss: 0.0020092405860001844, LR: 0.001


Epoch [10/200] Training: 100%|██████████| 600/600 [00:14<00:00, 41.98it/s, loss=0.00307]
Epoch [10/200] Validation: 100%|██████████| 150/150 [00:01<00:00, 103.70it/s, loss=0.00207]


Validation loss decreased (0.002009 --> 0.001998). Saving model...
Epoch [10/200], Train Loss: 0.0020041667765084035, Val Loss: 0.001997783982660621, LR: 0.001


Epoch [11/200] Training: 100%|██████████| 600/600 [00:14<00:00, 41.34it/s, loss=0.00302]
Epoch [11/200] Validation: 100%|██████████| 150/150 [00:01<00:00, 107.41it/s, loss=0.00201]


Validation loss decreased (0.001998 --> 0.001934). Saving model...
Epoch [11/200], Train Loss: 0.0019749013881664723, Val Loss: 0.0019341644023855528, LR: 0.001


Epoch [12/200] Training: 100%|██████████| 600/600 [00:14<00:00, 42.59it/s, loss=0.00283]
Epoch [12/200] Validation: 100%|██████████| 150/150 [00:01<00:00, 108.05it/s, loss=0.00205]


Validation loss decreased (0.001934 --> 0.001911). Saving model...
Epoch [12/200], Train Loss: 0.0019378996785962953, Val Loss: 0.001911228314662973, LR: 0.001


Epoch [13/200] Training: 100%|██████████| 600/600 [00:14<00:00, 42.51it/s, loss=0.00192]
Epoch [13/200] Validation: 100%|██████████| 150/150 [00:01<00:00, 112.06it/s, loss=0.00232]


EarlyStopping counter: 1 out of 10
Epoch [13/200], Train Loss: 0.001917302857618779, Val Loss: 0.002080636476166546, LR: 0.001


Epoch [14/200] Training: 100%|██████████| 600/600 [00:14<00:00, 41.92it/s, loss=0.00253]
Epoch [14/200] Validation: 100%|██████████| 150/150 [00:01<00:00, 113.47it/s, loss=0.00196]


Validation loss decreased (0.001911 --> 0.001840). Saving model...
Epoch [14/200], Train Loss: 0.0018989571026759222, Val Loss: 0.0018397855029130975, LR: 0.001


Epoch [15/200] Training: 100%|██████████| 600/600 [00:13<00:00, 45.44it/s, loss=0.00181]
Epoch [15/200] Validation: 100%|██████████| 150/150 [00:01<00:00, 120.77it/s, loss=0.00186]


EarlyStopping counter: 1 out of 10
Epoch [15/200], Train Loss: 0.0018719614710425959, Val Loss: 0.0018512083396005133, LR: 0.001


Epoch [16/200] Training: 100%|██████████| 600/600 [00:14<00:00, 40.34it/s, loss=0.00262]
Epoch [16/200] Validation: 100%|██████████| 150/150 [00:01<00:00, 116.75it/s, loss=0.00215]


EarlyStopping counter: 2 out of 10
Epoch [16/200], Train Loss: 0.0018543166694386553, Val Loss: 0.0019167674736430248, LR: 0.001


Epoch [17/200] Training: 100%|██████████| 600/600 [00:14<00:00, 42.10it/s, loss=0.00256]
Epoch [17/200] Validation: 100%|██████████| 150/150 [00:01<00:00, 109.80it/s, loss=0.00198]


EarlyStopping counter: 3 out of 10
Epoch [17/200], Train Loss: 0.001846499239715437, Val Loss: 0.0019236950553022324, LR: 0.001


Epoch [18/200] Training: 100%|██████████| 600/600 [00:13<00:00, 43.15it/s, loss=0.0028] 
Epoch [18/200] Validation: 100%|██████████| 150/150 [00:01<00:00, 116.22it/s, loss=0.00194]


Validation loss decreased (0.001840 --> 0.001791). Saving model...
Epoch [18/200], Train Loss: 0.0018258441926445811, Val Loss: 0.0017913085144634047, LR: 0.001


Epoch [19/200] Training: 100%|██████████| 600/600 [00:14<00:00, 40.22it/s, loss=0.00189]
Epoch [19/200] Validation: 100%|██████████| 150/150 [00:01<00:00, 120.60it/s, loss=0.00213]


EarlyStopping counter: 1 out of 10
Epoch [19/200], Train Loss: 0.0018086022509184356, Val Loss: 0.0020406962395645677, LR: 0.001


Epoch [20/200] Training: 100%|██████████| 600/600 [00:13<00:00, 43.18it/s, loss=0.00238]
Epoch [20/200] Validation: 100%|██████████| 150/150 [00:01<00:00, 114.63it/s, loss=0.00186]


EarlyStopping counter: 2 out of 10
Epoch [20/200], Train Loss: 0.001797759125280815, Val Loss: 0.0019171474982673923, LR: 0.001


Epoch [21/200] Training: 100%|██████████| 600/600 [00:13<00:00, 42.95it/s, loss=0.0014] 
Epoch [21/200] Validation: 100%|██████████| 150/150 [00:01<00:00, 126.30it/s, loss=0.00196]


EarlyStopping counter: 3 out of 10
Epoch [21/200], Train Loss: 0.0017838441394269467, Val Loss: 0.0018243598286062479, LR: 0.001


Epoch [22/200] Training: 100%|██████████| 600/600 [00:14<00:00, 41.72it/s, loss=0.0022] 
Epoch [22/200] Validation: 100%|██████████| 150/150 [00:01<00:00, 115.62it/s, loss=0.00185]


Validation loss decreased (0.001791 --> 0.001765). Saving model...
Epoch [22/200], Train Loss: 0.0017668734496692196, Val Loss: 0.0017649862356483937, LR: 0.001


Epoch [23/200] Training: 100%|██████████| 600/600 [00:14<00:00, 42.35it/s, loss=0.00189]
Epoch [23/200] Validation: 100%|██████████| 150/150 [00:01<00:00, 125.72it/s, loss=0.00193]


EarlyStopping counter: 1 out of 10
Epoch [23/200], Train Loss: 0.0017518920798708375, Val Loss: 0.0017675568621295195, LR: 0.001


Epoch [24/200] Training: 100%|██████████| 600/600 [00:15<00:00, 39.57it/s, loss=0.00141]
Epoch [24/200] Validation: 100%|██████████| 150/150 [00:01<00:00, 112.09it/s, loss=0.0019] 


Validation loss decreased (0.001765 --> 0.001749). Saving model...
Epoch [24/200], Train Loss: 0.00174029302453467, Val Loss: 0.001749473560291032, LR: 0.001


Epoch [25/200] Training: 100%|██████████| 600/600 [00:13<00:00, 44.87it/s, loss=0.00352]
Epoch [25/200] Validation: 100%|██████████| 150/150 [00:01<00:00, 125.98it/s, loss=0.00181]


EarlyStopping counter: 1 out of 10
Epoch [25/200], Train Loss: 0.0017300998142066723, Val Loss: 0.0018174676783382893, LR: 0.001


Epoch [26/200] Training: 100%|██████████| 600/600 [00:13<00:00, 43.31it/s, loss=0.00188]
Epoch [26/200] Validation: 100%|██████████| 150/150 [00:01<00:00, 121.00it/s, loss=0.00187]


Validation loss decreased (0.001749 --> 0.001745). Saving model...
Epoch [26/200], Train Loss: 0.0017192947713192553, Val Loss: 0.0017446812358684837, LR: 0.001


Epoch [27/200] Training: 100%|██████████| 600/600 [00:13<00:00, 45.41it/s, loss=0.00213]
Epoch [27/200] Validation: 100%|██████████| 150/150 [00:01<00:00, 120.38it/s, loss=0.00193]


EarlyStopping counter: 1 out of 10
Epoch [27/200], Train Loss: 0.0017047171634233867, Val Loss: 0.00180997480560715, LR: 0.001


Epoch [28/200] Training: 100%|██████████| 600/600 [00:12<00:00, 46.64it/s, loss=0.00187]
Epoch [28/200] Validation: 100%|██████████| 150/150 [00:01<00:00, 121.44it/s, loss=0.00199]


EarlyStopping counter: 2 out of 10
Epoch [28/200], Train Loss: 0.0016939710090324903, Val Loss: 0.0018254950370950004, LR: 0.001


Epoch [29/200] Training: 100%|██████████| 600/600 [00:14<00:00, 42.55it/s, loss=0.00191]
Epoch [29/200] Validation: 100%|██████████| 150/150 [00:01<00:00, 125.97it/s, loss=0.0019] 


EarlyStopping counter: 3 out of 10
Epoch [29/200], Train Loss: 0.0016798039601417258, Val Loss: 0.0018050136153275767, LR: 0.001


Epoch [30/200] Training: 100%|██████████| 600/600 [00:13<00:00, 45.64it/s, loss=0.002]  
Epoch [30/200] Validation: 100%|██████████| 150/150 [00:01<00:00, 136.35it/s, loss=0.00192]


EarlyStopping counter: 4 out of 10
Epoch [30/200], Train Loss: 0.0016703810989080617, Val Loss: 0.0017873733886517584, LR: 0.001


Epoch [31/200] Training: 100%|██████████| 600/600 [00:13<00:00, 43.62it/s, loss=0.00266]
Epoch [31/200] Validation: 100%|██████████| 150/150 [00:01<00:00, 128.44it/s, loss=0.00199]


EarlyStopping counter: 5 out of 10
Epoch [31/200], Train Loss: 0.0016573349372871842, Val Loss: 0.0018578331917524337, LR: 0.001


Epoch [32/200] Training: 100%|██████████| 600/600 [00:13<00:00, 44.09it/s, loss=0.00249]
Epoch [32/200] Validation: 100%|██████████| 150/150 [00:01<00:00, 128.08it/s, loss=0.00203]


EarlyStopping counter: 6 out of 10
Epoch [32/200], Train Loss: 0.001645856317675983, Val Loss: 0.001827547150508811, LR: 0.0005


Epoch [33/200] Training: 100%|██████████| 600/600 [00:13<00:00, 43.63it/s, loss=0.00201]
Epoch [33/200] Validation: 100%|██████████| 150/150 [00:01<00:00, 121.01it/s, loss=0.00182]


EarlyStopping counter: 7 out of 10
Epoch [33/200], Train Loss: 0.0015688439221897472, Val Loss: 0.001783499448404958, LR: 0.0005


Epoch [34/200] Training: 100%|██████████| 600/600 [00:13<00:00, 44.69it/s, loss=0.00214]
Epoch [34/200] Validation: 100%|██████████| 150/150 [00:01<00:00, 111.92it/s, loss=0.00181]


EarlyStopping counter: 8 out of 10
Epoch [34/200], Train Loss: 0.0015493095594380673, Val Loss: 0.0017704903497360648, LR: 0.0005


Epoch [35/200] Training: 100%|██████████| 600/600 [00:12<00:00, 47.21it/s, loss=0.00135]
Epoch [35/200] Validation: 100%|██████████| 150/150 [00:01<00:00, 123.90it/s, loss=0.00178]


EarlyStopping counter: 9 out of 10
Epoch [35/200], Train Loss: 0.0015293648516914496, Val Loss: 0.0017522995662875474, LR: 0.0005


Epoch [36/200] Training: 100%|██████████| 600/600 [00:14<00:00, 42.34it/s, loss=0.00247]
Epoch [36/200] Validation: 100%|██████████| 150/150 [00:01<00:00, 122.94it/s, loss=0.00167]


Validation loss decreased (0.001745 --> 0.001735). Saving model...
Epoch [36/200], Train Loss: 0.0015177057274073983, Val Loss: 0.00173460741682599, LR: 0.0005


Epoch [37/200] Training: 100%|██████████| 600/600 [00:13<00:00, 45.92it/s, loss=0.0013] 
Epoch [37/200] Validation: 100%|██████████| 150/150 [00:01<00:00, 129.92it/s, loss=0.00171]


EarlyStopping counter: 1 out of 10
Epoch [37/200], Train Loss: 0.0015023584155521045, Val Loss: 0.0017418917974767586, LR: 0.0005


Epoch [38/200] Training: 100%|██████████| 600/600 [00:13<00:00, 44.47it/s, loss=0.00157]
Epoch [38/200] Validation: 100%|██████████| 150/150 [00:01<00:00, 128.46it/s, loss=0.00177]


EarlyStopping counter: 2 out of 10
Epoch [38/200], Train Loss: 0.0014867247340346996, Val Loss: 0.001755695784619699, LR: 0.0005


Epoch [39/200] Training: 100%|██████████| 600/600 [00:13<00:00, 44.47it/s, loss=0.00232]
Epoch [39/200] Validation: 100%|██████████| 150/150 [00:01<00:00, 132.97it/s, loss=0.00175]


EarlyStopping counter: 3 out of 10
Epoch [39/200], Train Loss: 0.0014709281550797945, Val Loss: 0.0017559050877268116, LR: 0.0005


Epoch [40/200] Training: 100%|██████████| 600/600 [00:13<00:00, 43.01it/s, loss=0.00234]
Epoch [40/200] Validation: 100%|██████████| 150/150 [00:01<00:00, 105.14it/s, loss=0.00181]


EarlyStopping counter: 4 out of 10
Epoch [40/200], Train Loss: 0.0014597993701075515, Val Loss: 0.0017743292590603232, LR: 0.0005


Epoch [41/200] Training: 100%|██████████| 600/600 [00:14<00:00, 41.29it/s, loss=0.00145]
Epoch [41/200] Validation: 100%|██████████| 150/150 [00:01<00:00, 108.40it/s, loss=0.00175]


EarlyStopping counter: 5 out of 10
Epoch [41/200], Train Loss: 0.0014398042069903264, Val Loss: 0.0017832959536463022, LR: 0.0005


Epoch [42/200] Training: 100%|██████████| 600/600 [00:13<00:00, 43.78it/s, loss=0.00213]
Epoch [42/200] Validation: 100%|██████████| 150/150 [00:01<00:00, 108.18it/s, loss=0.00201]


EarlyStopping counter: 6 out of 10
Epoch [42/200], Train Loss: 0.001433387571790566, Val Loss: 0.0018310591555200518, LR: 0.00025


Epoch [43/200] Training: 100%|██████████| 600/600 [00:14<00:00, 41.12it/s, loss=0.00122] 
Epoch [43/200] Validation: 100%|██████████| 150/150 [00:01<00:00, 109.19it/s, loss=0.00179]


EarlyStopping counter: 7 out of 10
Epoch [43/200], Train Loss: 0.0013717608536050344, Val Loss: 0.0017817524028941989, LR: 0.00025


Epoch [44/200] Training: 100%|██████████| 600/600 [00:13<00:00, 44.53it/s, loss=0.0022]  
Epoch [44/200] Validation: 100%|██████████| 150/150 [00:01<00:00, 114.73it/s, loss=0.00177]


EarlyStopping counter: 8 out of 10
Epoch [44/200], Train Loss: 0.0013567781621046985, Val Loss: 0.0018124111276119946, LR: 0.00025


Epoch [45/200] Training: 100%|██████████| 600/600 [00:13<00:00, 44.72it/s, loss=0.00136]
Epoch [45/200] Validation: 100%|██████████| 150/150 [00:01<00:00, 107.49it/s, loss=0.00177]


EarlyStopping counter: 9 out of 10
Epoch [45/200], Train Loss: 0.0013471641152864322, Val Loss: 0.0018040699209086596, LR: 0.00025


Epoch [46/200] Training: 100%|██████████| 600/600 [00:14<00:00, 40.54it/s, loss=0.00168] 
Epoch [46/200] Validation: 100%|██████████| 150/150 [00:01<00:00, 109.79it/s, loss=0.00173]

EarlyStopping counter: 10 out of 10
Epoch [46/200], Train Loss: 0.0013312525448660986, Val Loss: 0.001802956215105951, LR: 0.00025
Early stopping triggered
Training finished





In [None]:
def calculate_psnr(target, prediction):
    """Calculate Peak Signal-to-Noise Ratio"""
    mse = np.mean((target - prediction) ** 2)
    if mse == 0:
        return float('inf')
    max_pixel = np.max(target)
    psnr = 20 * np.log10(max_pixel / np.sqrt(mse))
    return psnr

def calculate_peak_metrics_patch(target, prediction, prominence=0.1):
    """Calculate peak detection accuracy for full patch predictions"""
    peak_metrics = {'true_peaks': 0, 'predicted_peaks': 0, 'matching_peaks': 0}
    
    # Iterate over each voxel in the patch
    batch_size, num_points, height, width = target.shape
    
    for b in range(batch_size):
        for h in range(height):
            for w in range(width):
                # Get spectrum for current voxel
                true_spectrum = target[b, :, h, w]
                pred_spectrum = prediction[b, :, h, w]
                
                # Find peaks
                true_peaks, _ = find_peaks(true_spectrum, prominence=prominence)
                pred_peaks, _ = find_peaks(pred_spectrum, prominence=prominence)
                
                # Count matching peaks
                matches = 0
                for tp in true_peaks:
                    for pp in pred_peaks:
                        if abs(tp - pp) <= 1:
                            matches += 1
                            break
                
                peak_metrics['true_peaks'] += len(true_peaks)
                peak_metrics['predicted_peaks'] += len(pred_peaks)
                peak_metrics['matching_peaks'] += matches
    
    return peak_metrics

def measure_inference_time(model, input_tensor, device, num_iterations=100):
    """Measure average inference time"""
    model.eval()
    start_time = time.time()
    with torch.no_grad():
        for _ in range(num_iterations):
            _ = model(input_tensor.to(device))
    end_time = time.time()
    return (end_time - start_time) / num_iterations

def count_parameters(model):
    """Count number of trainable parameters"""
    return sum(p.numel() for p in model.parameters() if p.requires_grad)

def measure_memory_usage():
    """Measure current memory usage"""
    process = psutil.Process(os.getpid())
    return process.memory_info().rss / 1024 / 1024  # Convert to MB

def calculate_spatial_consistency(targets, predictions, window_size=3):
    """Calculate spatial consistency metric with proper handling of constant patches"""
    batch_size, num_points, height, width = targets.shape
    consistency_scores = []
    
    # Calculate local spatial correlation for each patch
    for b in range(min(batch_size, 100)):  # Limit to 100 samples for efficiency
        for t in range(num_points):
            true_patch = targets[b, t]
            pred_patch = predictions[b, t]
            
            # Skip if either patch is constant
            if np.std(true_patch) == 0 or np.std(pred_patch) == 0:
                continue
                
            # Calculate local spatial correlation
            try:
                correlation = np.corrcoef(true_patch.flatten(), pred_patch.flatten())[0, 1]
                if not np.isnan(correlation):
                    consistency_scores.append(correlation)
            except:
                continue
    
    # Return mean if we have scores, otherwise return 0
    return np.mean(consistency_scores) if consistency_scores else 0.0


def evaluate_model_metrics(model, val_loader, device):
    """Evaluate all model metrics for patch-based models"""
    model.eval()
    metrics = {}
    
    # Get sample input for inference time measurement
    sample_input, _ = next(iter(val_loader))
    metrics['inference_time'] = measure_inference_time(model, sample_input, device)
    
    all_targets = []
    all_predictions = []
    
    with torch.no_grad():
        for inputs, targets in val_loader:
            inputs = inputs.to(device)
            targets = targets.to(device)
            outputs = model(inputs)
            
            # Convert to numpy for metric calculation
            targets_np = targets.cpu().numpy()
            outputs_np = outputs.cpu().numpy()
            
            all_targets.append(targets_np)
            all_predictions.append(outputs_np)
    
    # Concatenate all batches
    all_targets = np.concatenate(all_targets)
    all_predictions = np.concatenate(all_predictions)
    
    # Calculate metrics
    metrics['mse'] = np.mean((all_targets - all_predictions) ** 2)
    metrics['mae'] = np.mean(np.abs(all_targets - all_predictions))
    metrics['psnr'] = calculate_psnr(all_targets, all_predictions)
    
    # Calculate R² score for each spatial position
    r2_scores = []
    batch_size, num_points, height, width = all_targets.shape
    for h in range(height):
        for w in range(width):
            true_spectra = all_targets[:, :, h, w].reshape(-1)
            pred_spectra = all_predictions[:, :, h, w].reshape(-1)
            r2_scores.append(r2_score(true_spectra, pred_spectra))
    metrics['r2_score'] = np.mean(r2_scores)
    
    # Calculate peak metrics for patches
    metrics['peak_metrics'] = calculate_peak_metrics_patch(all_targets, all_predictions)
    
    # System metrics
    metrics['memory_usage'] = measure_memory_usage()
    metrics['num_parameters'] = count_parameters(model)
    
    # Add spatial metrics
    metrics['spatial_consistency'] = calculate_spatial_consistency(all_targets, all_predictions)
    
    return metrics

In [None]:
# checkpoint = torch.load('saved_models/best_unet_model_attention_3.pt')
# model.load_state_dict(checkpoint['model_state_dict'])

model.eval()
final_metrics = evaluate_model_metrics(model, val_loader, device)

wandb.log({
    "final_mse": final_metrics['mse'],
    "final_mae": final_metrics['mae'],
    "final_psnr": final_metrics['psnr'],
    "final_r2_score": final_metrics['r2_score'],
    "peak_detection_accuracy": final_metrics['peak_metrics']['matching_peaks'] / 
                             final_metrics['peak_metrics']['true_peaks'],
    "inference_time_ms": final_metrics['inference_time'] * 1000,
    "memory_usage_mb": final_metrics['memory_usage'],
    "model_parameters": final_metrics['num_parameters']
})

print("\nFinal Model Evaluation:")
print(f"MSE: {final_metrics['mse']:.6f}")
print(f"MAE: {final_metrics['mae']:.6f}")
print(f"PSNR: {final_metrics['psnr']:.2f} dB")
print(f"R² Score: {final_metrics['r2_score']:.4f}")
print(f"Peak Detection Accuracy: {final_metrics['peak_metrics']['matching_peaks'] / final_metrics['peak_metrics']['true_peaks']:.2%}")
print(f"Average Inference Time: {final_metrics['inference_time']*1000:.2f} ms")
print(f"Memory Usage: {final_metrics['memory_usage']:.1f} MB")
print(f"Number of Parameters: {final_metrics['num_parameters']:,}")

Loading best model for final evaluation...
Loaded model checkpoint with validation loss: 0.001735


  checkpoint = torch.load('best_unet_model_attention_3.pt')



Final Model Evaluation:
MSE: 0.001735
MAE: 0.021225
PSNR: 27.60 dB
R² Score: 0.5772
Peak Detection Accuracy: 46.47%
Average Inference Time: 5.35 ms
Memory Usage: 2794.5 MB
Number of Parameters: 1,835,760


In [None]:
inputs, targets = next(iter(val_loader))
inputs, targets = inputs.to(device), targets.to(device)

with torch.no_grad():
    outputs = model(inputs)

inputs = inputs.cpu().numpy()
targets = targets.cpu().numpy()
outputs = outputs.cpu().numpy()

example_idx = 0

plt.figure(figsize=(15, 5))

plt.subplot(131)
plt.imshow(inputs[example_idx, 4, :, :])  # Middle temporal slice (4 out of 8)
plt.title('Input 9x9 patch\n(Middle Time Point)')
plt.colorbar()

plt.subplot(132)
plt.imshow(targets[example_idx, 16, :, :])  # Middle spectral point (16 out of 32)
plt.title('Target 9x9 patch\n(Middle Spectral Point)')
plt.colorbar()

plt.subplot(133)
plt.imshow(outputs[example_idx, 16, :, :])  # Middle spectral point
plt.title('Prediction 9x9 patch\n(Middle Spectral Point)')
plt.colorbar()

plt.tight_layout()
plt.show()

# Plot all 81 voxels (9x9) in a single plot
plt.figure(figsize=(15, 8))

# Plot all voxels
for i in range(9):
    for j in range(9):
        plt.plot(targets[example_idx, :, i, j], 'b-', alpha=0.3, label='Target' if i==0 and j==0 else "")
        plt.plot(outputs[example_idx, :, i, j], 'r--', alpha=0.3, label='Prediction' if i==0 and j==0 else "")

plt.title('T1 Spectra for All Voxels in 9x9 Patch')
plt.xlabel('T1 Index (32 points)')
plt.ylabel('Amplitude')
plt.legend()
plt.grid(True, alpha=0.3)
plt.show()

# Print metrics for the whole 9x9 patch
patch_mse = np.mean((targets[example_idx] - outputs[example_idx])**2)
patch_mae = np.mean(np.abs(targets[example_idx] - outputs[example_idx]))
print(f"Full 9x9 patch MSE: {patch_mse:.6f}")
print(f"Full 9x9 patch MAE: {patch_mae:.6f}")

# Calculate statistics of individual voxel MSEs
voxel_mses = []
for i in range(9):
    for j in range(9):
        mse = np.mean((targets[example_idx, :, i, j] - outputs[example_idx, :, i, j])**2)
        voxel_mses.append(mse)

print(f"\nMean Voxel MSE: {np.mean(voxel_mses):.6f}")
print(f"Std Voxel MSE: {np.std(voxel_mses):.6f}")
print(f"Min Voxel MSE: {np.min(voxel_mses):.6f}")
print(f"Max Voxel MSE: {np.max(voxel_mses):.6f}")

In [None]:
wandb.finish()