In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
import numpy as np
from sklearn.model_selection import train_test_split
from scipy.io import loadmat
from tqdm import tqdm
import wandb

from scipy.signal import find_peaks
from skimage.metrics import structural_similarity as ssim
from sklearn.metrics import r2_score
import psutil
import os
import time

In [2]:
wandb.login()
run = wandb.init(project="master-multicomponent-mri", name="single-signal-baseline")

[34m[1mwandb[0m: Currently logged in as: [33mtr-phan[0m ([33mtrphan[0m) to [32mhttps://api.wandb.ai[0m. Use [1m`wandb login --relogin`[0m to force relogin
[34m[1mwandb[0m: Using wandb-core as the SDK backend.  Please refer to https://wandb.me/wandb-core for more information.


In [None]:
class T1Dataset(Dataset):
    def __init__(self, input_data, target_data):
        self.input_data = torch.FloatTensor(input_data.T)  # Shape: (400000, 8)
        self.target_data = torch.FloatTensor(target_data.T)  # Shape: (400000, 32)
    
    def __len__(self):
        return len(self.input_data)
    
    def __getitem__(self, idx):
        return self.input_data[idx], self.target_data[idx]

In [None]:
class T1ReconstructionNet(nn.Module):
    def __init__(self):
        super(T1ReconstructionNet, self).__init__()
        self.network = nn.Sequential(
            nn.Linear(8, 128),
            nn.LeakyReLU(0.1),
            nn.BatchNorm1d(128),
            
            nn.Linear(128, 256),
            nn.LeakyReLU(0.1),
            nn.BatchNorm1d(256),
            
            nn.Linear(256, 512),
            nn.LeakyReLU(0.1),
            nn.BatchNorm1d(512),
            nn.Dropout(0.2),
            
            nn.Linear(512, 256),
            nn.LeakyReLU(0.1),
            nn.BatchNorm1d(256),
            
            nn.Linear(256, 128),
            nn.LeakyReLU(0.1),
            nn.BatchNorm1d(128),
            
            nn.Linear(128, 32)
        )
    
    def forward(self, x):
        return self.network(x)

In [None]:
class EarlyStopping:
    def __init__(self, patience=15, min_delta=1e-6, path='saved_models/baseline_nn.pt'):
        self.patience = patience
        self.min_delta = min_delta
        self.counter = 0
        self.best_loss = None
        self.early_stop = False
        self.path = path
        
    def __call__(self, val_loss, model):
        if self.best_loss is None:
            self.best_loss = val_loss
            self.save_checkpoint(model)
        elif val_loss > self.best_loss - self.min_delta:
            self.counter += 1
            print(f'EarlyStopping counter: {self.counter} out of {self.patience}')
            if self.counter >= self.patience:
                self.early_stop = True
        else:
            self.best_loss = val_loss
            self.save_checkpoint(model)
            self.counter = 0

    def save_checkpoint(self, model):
        torch.save(model.state_dict(), self.path)

In [6]:
data = loadmat('../../data/data_T1_Q32_P8_400k.mat')
input_noisy = data['input_noisy']
ref = data['ref']
no_comp = data['no_comp']

In [7]:
X_train, X_val, y_train, y_val = train_test_split(
    input_noisy.T, ref.T, 
    test_size=0.2, 
    random_state=42
)

# Create datasets
train_dataset = T1Dataset(X_train.T, y_train.T)
val_dataset = T1Dataset(X_val.T, y_val.T)

# Create dataloaders
train_loader = DataLoader(train_dataset, batch_size=256, shuffle=True)
val_loader = DataLoader(val_dataset, batch_size=256, shuffle=False)

In [8]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = T1ReconstructionNet().to(device)
criterion = nn.MSELoss()
optimizer = optim.Adam(model.parameters(), lr=0.001)
scheduler = optim.lr_scheduler.ReduceLROnPlateau(
    optimizer, 
    mode='min', 
    factor=0.5, 
    patience=10,
    verbose=True
)
early_stopping = EarlyStopping(patience=15)
batch_size = 256



In [None]:
wandb.config.update({
    "learning_rate": 0.001,
    "batch_size": batch_size,
    "epochs": 1,
    "architecture": "SingleSignalMLP",
    "optimizer": "Adam",
    "loss_function": "MSELoss",
    "scheduler": "ReduceLROnPlateau",
    "early_stopping_patience": 15
})

In [10]:
for epoch in range(200):
    # Training phase
    model.train()
    train_loss = 0.0
    progress_bar_train = tqdm(train_loader, desc=f'Epoch [{epoch+1}/200] Training')
    
    for inputs, targets in progress_bar_train:
        inputs, targets = inputs.to(device), targets.to(device)
        
        optimizer.zero_grad()
        outputs = model(inputs)
        loss = criterion(outputs, targets)
        loss.backward()
        
        # Gradient clipping
        torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
        
        optimizer.step()
        train_loss += loss.item()
        progress_bar_train.set_postfix({'loss': loss.item()})

    avg_train_loss = train_loss / len(train_loader)

    # Validation phase
    model.eval()
    val_loss = 0.0
    progress_bar_val = tqdm(val_loader, desc=f'Epoch [{epoch+1}/200] Validation')
    
    with torch.no_grad():
        for inputs, targets in progress_bar_val:
            inputs, targets = inputs.to(device), targets.to(device)
            outputs = model(inputs)
            loss = criterion(outputs, targets)
            val_loss += loss.item()
            progress_bar_val.set_postfix({'loss': loss.item()})

    avg_val_loss = val_loss / len(val_loader)

    # Learning rate scheduling
    scheduler.step(avg_val_loss)
    current_lr = optimizer.param_groups[0]['lr']

    # Early stopping
    early_stopping(avg_val_loss, model)

    # Logging
    wandb.log({
        "epoch": epoch + 1,
        "train_loss": avg_train_loss,
        "val_loss": avg_val_loss,
        "learning_rate": current_lr
    })

    print(f'Epoch [{epoch+1}/200], Train Loss: {avg_train_loss:.6f}, Val Loss: {avg_val_loss:.6f}, LR: {current_lr}')

    if early_stopping.early_stop:
        print("Early stopping triggered")
        break

Epoch [1/200] Training: 100%|██████████| 1250/1250 [00:09<00:00, 128.96it/s, loss=0.00263]
Epoch [1/200] Validation: 100%|██████████| 313/313 [00:01<00:00, 278.98it/s, loss=0.0024] 


Epoch [1/200], Train Loss: 0.004034, Val Loss: 0.002669, LR: 0.001


Epoch [2/200] Training: 100%|██████████| 1250/1250 [00:08<00:00, 139.61it/s, loss=0.00278]
Epoch [2/200] Validation: 100%|██████████| 313/313 [00:00<00:00, 343.26it/s, loss=0.00225]


Epoch [2/200], Train Loss: 0.002740, Val Loss: 0.002608, LR: 0.001


Epoch [3/200] Training: 100%|██████████| 1250/1250 [00:08<00:00, 152.47it/s, loss=0.00283]
Epoch [3/200] Validation: 100%|██████████| 313/313 [00:01<00:00, 299.27it/s, loss=0.00238]


EarlyStopping counter: 1 out of 15
Epoch [3/200], Train Loss: 0.002696, Val Loss: 0.002708, LR: 0.001


Epoch [4/200] Training: 100%|██████████| 1250/1250 [00:07<00:00, 162.83it/s, loss=0.0026] 
Epoch [4/200] Validation: 100%|██████████| 313/313 [00:01<00:00, 291.63it/s, loss=0.00241]


EarlyStopping counter: 2 out of 15
Epoch [4/200], Train Loss: 0.002674, Val Loss: 0.002646, LR: 0.001


Epoch [5/200] Training: 100%|██████████| 1250/1250 [00:07<00:00, 162.39it/s, loss=0.00259]
Epoch [5/200] Validation: 100%|██████████| 313/313 [00:00<00:00, 333.19it/s, loss=0.00228]


EarlyStopping counter: 3 out of 15
Epoch [5/200], Train Loss: 0.002661, Val Loss: 0.002642, LR: 0.001


Epoch [6/200] Training: 100%|██████████| 1250/1250 [00:08<00:00, 143.21it/s, loss=0.00275]
Epoch [6/200] Validation: 100%|██████████| 313/313 [00:01<00:00, 277.98it/s, loss=0.0023] 


EarlyStopping counter: 4 out of 15
Epoch [6/200], Train Loss: 0.002652, Val Loss: 0.002619, LR: 0.001


Epoch [7/200] Training: 100%|██████████| 1250/1250 [00:08<00:00, 153.64it/s, loss=0.00258]
Epoch [7/200] Validation: 100%|██████████| 313/313 [00:01<00:00, 306.31it/s, loss=0.00232]


Epoch [7/200], Train Loss: 0.002640, Val Loss: 0.002605, LR: 0.001


Epoch [8/200] Training: 100%|██████████| 1250/1250 [00:08<00:00, 150.54it/s, loss=0.00282]
Epoch [8/200] Validation: 100%|██████████| 313/313 [00:01<00:00, 291.30it/s, loss=0.00228]


Epoch [8/200], Train Loss: 0.002629, Val Loss: 0.002573, LR: 0.001


Epoch [9/200] Training: 100%|██████████| 1250/1250 [00:08<00:00, 150.23it/s, loss=0.00254]
Epoch [9/200] Validation: 100%|██████████| 313/313 [00:01<00:00, 301.70it/s, loss=0.00242]


EarlyStopping counter: 1 out of 15
Epoch [9/200], Train Loss: 0.002618, Val Loss: 0.002651, LR: 0.001


Epoch [10/200] Training: 100%|██████████| 1250/1250 [00:08<00:00, 149.55it/s, loss=0.00239]
Epoch [10/200] Validation: 100%|██████████| 313/313 [00:01<00:00, 295.65it/s, loss=0.00233]


EarlyStopping counter: 2 out of 15
Epoch [10/200], Train Loss: 0.002612, Val Loss: 0.002661, LR: 0.001


Epoch [11/200] Training: 100%|██████████| 1250/1250 [00:08<00:00, 142.19it/s, loss=0.00275]
Epoch [11/200] Validation: 100%|██████████| 313/313 [00:01<00:00, 296.65it/s, loss=0.00233]


EarlyStopping counter: 3 out of 15
Epoch [11/200], Train Loss: 0.002605, Val Loss: 0.002626, LR: 0.001


Epoch [12/200] Training: 100%|██████████| 1250/1250 [00:08<00:00, 147.93it/s, loss=0.00262]
Epoch [12/200] Validation: 100%|██████████| 313/313 [00:01<00:00, 299.08it/s, loss=0.00231]


Epoch [12/200], Train Loss: 0.002595, Val Loss: 0.002548, LR: 0.001


Epoch [13/200] Training: 100%|██████████| 1250/1250 [00:08<00:00, 142.99it/s, loss=0.00259]
Epoch [13/200] Validation: 100%|██████████| 313/313 [00:01<00:00, 269.88it/s, loss=0.00225]


EarlyStopping counter: 1 out of 15
Epoch [13/200], Train Loss: 0.002593, Val Loss: 0.002549, LR: 0.001


Epoch [14/200] Training: 100%|██████████| 1250/1250 [00:08<00:00, 145.34it/s, loss=0.00259]
Epoch [14/200] Validation: 100%|██████████| 313/313 [00:01<00:00, 293.30it/s, loss=0.00231]


EarlyStopping counter: 2 out of 15
Epoch [14/200], Train Loss: 0.002587, Val Loss: 0.002600, LR: 0.001


Epoch [15/200] Training: 100%|██████████| 1250/1250 [00:08<00:00, 155.05it/s, loss=0.00265]
Epoch [15/200] Validation: 100%|██████████| 313/313 [00:01<00:00, 301.72it/s, loss=0.00225]


Epoch [15/200], Train Loss: 0.002581, Val Loss: 0.002546, LR: 0.001


Epoch [16/200] Training: 100%|██████████| 1250/1250 [00:08<00:00, 154.65it/s, loss=0.00251]
Epoch [16/200] Validation: 100%|██████████| 313/313 [00:00<00:00, 330.36it/s, loss=0.00229]


EarlyStopping counter: 1 out of 15
Epoch [16/200], Train Loss: 0.002574, Val Loss: 0.002553, LR: 0.001


Epoch [17/200] Training: 100%|██████████| 1250/1250 [00:08<00:00, 154.93it/s, loss=0.00263]
Epoch [17/200] Validation: 100%|██████████| 313/313 [00:00<00:00, 328.40it/s, loss=0.00218]


Epoch [17/200], Train Loss: 0.002568, Val Loss: 0.002523, LR: 0.001


Epoch [18/200] Training: 100%|██████████| 1250/1250 [00:07<00:00, 156.86it/s, loss=0.00254]
Epoch [18/200] Validation: 100%|██████████| 313/313 [00:01<00:00, 274.87it/s, loss=0.00221]


EarlyStopping counter: 1 out of 15
Epoch [18/200], Train Loss: 0.002562, Val Loss: 0.002532, LR: 0.001


Epoch [19/200] Training: 100%|██████████| 1250/1250 [00:08<00:00, 145.16it/s, loss=0.00272]
Epoch [19/200] Validation: 100%|██████████| 313/313 [00:01<00:00, 304.43it/s, loss=0.00228]


EarlyStopping counter: 2 out of 15
Epoch [19/200], Train Loss: 0.002559, Val Loss: 0.002539, LR: 0.001


Epoch [20/200] Training: 100%|██████████| 1250/1250 [00:08<00:00, 146.93it/s, loss=0.00264]
Epoch [20/200] Validation: 100%|██████████| 313/313 [00:00<00:00, 316.23it/s, loss=0.00224]


Epoch [20/200], Train Loss: 0.002554, Val Loss: 0.002520, LR: 0.001


Epoch [21/200] Training: 100%|██████████| 1250/1250 [00:10<00:00, 124.79it/s, loss=0.00267]
Epoch [21/200] Validation: 100%|██████████| 313/313 [00:01<00:00, 282.75it/s, loss=0.00228]


EarlyStopping counter: 1 out of 15
Epoch [21/200], Train Loss: 0.002548, Val Loss: 0.002548, LR: 0.001


Epoch [22/200] Training: 100%|██████████| 1250/1250 [00:09<00:00, 135.10it/s, loss=0.00229]
Epoch [22/200] Validation: 100%|██████████| 313/313 [00:00<00:00, 340.77it/s, loss=0.00227]


EarlyStopping counter: 2 out of 15
Epoch [22/200], Train Loss: 0.002544, Val Loss: 0.002542, LR: 0.001


Epoch [23/200] Training: 100%|██████████| 1250/1250 [00:08<00:00, 147.03it/s, loss=0.00252]
Epoch [23/200] Validation: 100%|██████████| 313/313 [00:01<00:00, 298.70it/s, loss=0.0022] 


EarlyStopping counter: 3 out of 15
Epoch [23/200], Train Loss: 0.002540, Val Loss: 0.002523, LR: 0.001


Epoch [24/200] Training: 100%|██████████| 1250/1250 [00:08<00:00, 149.91it/s, loss=0.00241]
Epoch [24/200] Validation: 100%|██████████| 313/313 [00:00<00:00, 325.09it/s, loss=0.00222]


Epoch [24/200], Train Loss: 0.002535, Val Loss: 0.002507, LR: 0.001


Epoch [25/200] Training: 100%|██████████| 1250/1250 [00:08<00:00, 155.76it/s, loss=0.00243]
Epoch [25/200] Validation: 100%|██████████| 313/313 [00:00<00:00, 327.94it/s, loss=0.00226]


EarlyStopping counter: 1 out of 15
Epoch [25/200], Train Loss: 0.002532, Val Loss: 0.002517, LR: 0.001


Epoch [26/200] Training: 100%|██████████| 1250/1250 [00:07<00:00, 167.39it/s, loss=0.00269]
Epoch [26/200] Validation: 100%|██████████| 313/313 [00:00<00:00, 335.61it/s, loss=0.00221]


Epoch [26/200], Train Loss: 0.002530, Val Loss: 0.002491, LR: 0.001


Epoch [27/200] Training: 100%|██████████| 1250/1250 [00:07<00:00, 168.37it/s, loss=0.00266]
Epoch [27/200] Validation: 100%|██████████| 313/313 [00:00<00:00, 336.30it/s, loss=0.00227]


EarlyStopping counter: 1 out of 15
Epoch [27/200], Train Loss: 0.002524, Val Loss: 0.002513, LR: 0.001


Epoch [28/200] Training: 100%|██████████| 1250/1250 [00:08<00:00, 149.42it/s, loss=0.00235]
Epoch [28/200] Validation: 100%|██████████| 313/313 [00:01<00:00, 290.12it/s, loss=0.00225]


EarlyStopping counter: 2 out of 15
Epoch [28/200], Train Loss: 0.002522, Val Loss: 0.002509, LR: 0.001


Epoch [29/200] Training: 100%|██████████| 1250/1250 [00:09<00:00, 138.47it/s, loss=0.00251]
Epoch [29/200] Validation: 100%|██████████| 313/313 [00:01<00:00, 228.36it/s, loss=0.00219]


EarlyStopping counter: 3 out of 15
Epoch [29/200], Train Loss: 0.002520, Val Loss: 0.002497, LR: 0.001


Epoch [30/200] Training: 100%|██████████| 1250/1250 [00:08<00:00, 148.20it/s, loss=0.00258]
Epoch [30/200] Validation: 100%|██████████| 313/313 [00:01<00:00, 292.10it/s, loss=0.00216]


EarlyStopping counter: 4 out of 15
Epoch [30/200], Train Loss: 0.002516, Val Loss: 0.002502, LR: 0.001


Epoch [31/200] Training: 100%|██████████| 1250/1250 [00:08<00:00, 154.67it/s, loss=0.00249]
Epoch [31/200] Validation: 100%|██████████| 313/313 [00:00<00:00, 326.29it/s, loss=0.00216]


EarlyStopping counter: 5 out of 15
Epoch [31/200], Train Loss: 0.002514, Val Loss: 0.002494, LR: 0.001


Epoch [32/200] Training: 100%|██████████| 1250/1250 [00:07<00:00, 158.71it/s, loss=0.00247]
Epoch [32/200] Validation: 100%|██████████| 313/313 [00:01<00:00, 265.75it/s, loss=0.00226]


EarlyStopping counter: 6 out of 15
Epoch [32/200], Train Loss: 0.002511, Val Loss: 0.002495, LR: 0.001


Epoch [33/200] Training: 100%|██████████| 1250/1250 [00:07<00:00, 161.75it/s, loss=0.00268]
Epoch [33/200] Validation: 100%|██████████| 313/313 [00:01<00:00, 263.80it/s, loss=0.00218]


EarlyStopping counter: 7 out of 15
Epoch [33/200], Train Loss: 0.002509, Val Loss: 0.002498, LR: 0.001


Epoch [34/200] Training: 100%|██████████| 1250/1250 [00:08<00:00, 152.06it/s, loss=0.00266]
Epoch [34/200] Validation: 100%|██████████| 313/313 [00:00<00:00, 319.96it/s, loss=0.0022] 


EarlyStopping counter: 8 out of 15
Epoch [34/200], Train Loss: 0.002509, Val Loss: 0.002494, LR: 0.001


Epoch [35/200] Training: 100%|██████████| 1250/1250 [00:08<00:00, 148.25it/s, loss=0.00243]
Epoch [35/200] Validation: 100%|██████████| 313/313 [00:00<00:00, 330.21it/s, loss=0.00219]


Epoch [35/200], Train Loss: 0.002507, Val Loss: 0.002479, LR: 0.001


Epoch [36/200] Training: 100%|██████████| 1250/1250 [00:08<00:00, 151.03it/s, loss=0.00223]
Epoch [36/200] Validation: 100%|██████████| 313/313 [00:01<00:00, 306.01it/s, loss=0.00221]


EarlyStopping counter: 1 out of 15
Epoch [36/200], Train Loss: 0.002505, Val Loss: 0.002484, LR: 0.001


Epoch [37/200] Training: 100%|██████████| 1250/1250 [00:07<00:00, 164.50it/s, loss=0.00252]
Epoch [37/200] Validation: 100%|██████████| 313/313 [00:00<00:00, 321.97it/s, loss=0.00219]


EarlyStopping counter: 2 out of 15
Epoch [37/200], Train Loss: 0.002504, Val Loss: 0.002486, LR: 0.001


Epoch [38/200] Training: 100%|██████████| 1250/1250 [00:07<00:00, 157.20it/s, loss=0.00244]
Epoch [38/200] Validation: 100%|██████████| 313/313 [00:01<00:00, 298.67it/s, loss=0.0022] 


EarlyStopping counter: 3 out of 15
Epoch [38/200], Train Loss: 0.002504, Val Loss: 0.002499, LR: 0.001


Epoch [39/200] Training: 100%|██████████| 1250/1250 [00:07<00:00, 160.39it/s, loss=0.00264]
Epoch [39/200] Validation: 100%|██████████| 313/313 [00:00<00:00, 319.22it/s, loss=0.00223]


EarlyStopping counter: 4 out of 15
Epoch [39/200], Train Loss: 0.002504, Val Loss: 0.002492, LR: 0.001


Epoch [40/200] Training: 100%|██████████| 1250/1250 [00:07<00:00, 158.50it/s, loss=0.0024] 
Epoch [40/200] Validation: 100%|██████████| 313/313 [00:00<00:00, 321.30it/s, loss=0.00221]


EarlyStopping counter: 5 out of 15
Epoch [40/200], Train Loss: 0.002502, Val Loss: 0.002484, LR: 0.001


Epoch [41/200] Training: 100%|██████████| 1250/1250 [00:08<00:00, 153.95it/s, loss=0.00257]
Epoch [41/200] Validation: 100%|██████████| 313/313 [00:00<00:00, 328.08it/s, loss=0.00218]


EarlyStopping counter: 6 out of 15
Epoch [41/200], Train Loss: 0.002501, Val Loss: 0.002486, LR: 0.001


Epoch [42/200] Training: 100%|██████████| 1250/1250 [00:07<00:00, 157.53it/s, loss=0.00242]
Epoch [42/200] Validation: 100%|██████████| 313/313 [00:00<00:00, 323.95it/s, loss=0.0022] 


EarlyStopping counter: 7 out of 15
Epoch [42/200], Train Loss: 0.002500, Val Loss: 0.002484, LR: 0.001


Epoch [43/200] Training: 100%|██████████| 1250/1250 [00:07<00:00, 163.10it/s, loss=0.00246]
Epoch [43/200] Validation: 100%|██████████| 313/313 [00:01<00:00, 290.34it/s, loss=0.00218]


EarlyStopping counter: 8 out of 15
Epoch [43/200], Train Loss: 0.002501, Val Loss: 0.002486, LR: 0.001


Epoch [44/200] Training: 100%|██████████| 1250/1250 [00:07<00:00, 161.57it/s, loss=0.00247]
Epoch [44/200] Validation: 100%|██████████| 313/313 [00:00<00:00, 333.02it/s, loss=0.00217]


EarlyStopping counter: 9 out of 15
Epoch [44/200], Train Loss: 0.002500, Val Loss: 0.002483, LR: 0.001


Epoch [45/200] Training: 100%|██████████| 1250/1250 [00:08<00:00, 152.79it/s, loss=0.00228]
Epoch [45/200] Validation: 100%|██████████| 313/313 [00:00<00:00, 330.41it/s, loss=0.00224]


EarlyStopping counter: 10 out of 15
Epoch [45/200], Train Loss: 0.002500, Val Loss: 0.002502, LR: 0.001


Epoch [46/200] Training: 100%|██████████| 1250/1250 [00:08<00:00, 149.06it/s, loss=0.00238]
Epoch [46/200] Validation: 100%|██████████| 313/313 [00:00<00:00, 324.94it/s, loss=0.00214]


EarlyStopping counter: 11 out of 15
Epoch [46/200], Train Loss: 0.002500, Val Loss: 0.002490, LR: 0.0005


Epoch [47/200] Training: 100%|██████████| 1250/1250 [00:08<00:00, 152.06it/s, loss=0.00233]
Epoch [47/200] Validation: 100%|██████████| 313/313 [00:00<00:00, 337.07it/s, loss=0.00219]


Epoch [47/200], Train Loss: 0.002489, Val Loss: 0.002473, LR: 0.0005


Epoch [48/200] Training: 100%|██████████| 1250/1250 [00:07<00:00, 156.85it/s, loss=0.00248]
Epoch [48/200] Validation: 100%|██████████| 313/313 [00:01<00:00, 299.56it/s, loss=0.00219]


EarlyStopping counter: 1 out of 15
Epoch [48/200], Train Loss: 0.002488, Val Loss: 0.002475, LR: 0.0005


Epoch [49/200] Training: 100%|██████████| 1250/1250 [00:07<00:00, 157.64it/s, loss=0.00242]
Epoch [49/200] Validation: 100%|██████████| 313/313 [00:00<00:00, 326.37it/s, loss=0.00217]


EarlyStopping counter: 2 out of 15
Epoch [49/200], Train Loss: 0.002489, Val Loss: 0.002473, LR: 0.0005


Epoch [50/200] Training: 100%|██████████| 1250/1250 [00:07<00:00, 162.43it/s, loss=0.00202]
Epoch [50/200] Validation: 100%|██████████| 313/313 [00:01<00:00, 301.07it/s, loss=0.00221]


EarlyStopping counter: 3 out of 15
Epoch [50/200], Train Loss: 0.002488, Val Loss: 0.002477, LR: 0.0005


Epoch [51/200] Training: 100%|██████████| 1250/1250 [00:08<00:00, 151.84it/s, loss=0.00269]
Epoch [51/200] Validation: 100%|██████████| 313/313 [00:01<00:00, 293.88it/s, loss=0.00218]


EarlyStopping counter: 4 out of 15
Epoch [51/200], Train Loss: 0.002488, Val Loss: 0.002477, LR: 0.0005


Epoch [52/200] Training: 100%|██████████| 1250/1250 [00:08<00:00, 149.31it/s, loss=0.0024] 
Epoch [52/200] Validation: 100%|██████████| 313/313 [00:01<00:00, 249.64it/s, loss=0.00215]


EarlyStopping counter: 5 out of 15
Epoch [52/200], Train Loss: 0.002488, Val Loss: 0.002475, LR: 0.0005


Epoch [53/200] Training: 100%|██████████| 1250/1250 [00:08<00:00, 152.60it/s, loss=0.00258]
Epoch [53/200] Validation: 100%|██████████| 313/313 [00:01<00:00, 231.01it/s, loss=0.00218]


EarlyStopping counter: 6 out of 15
Epoch [53/200], Train Loss: 0.002487, Val Loss: 0.002473, LR: 0.0005


Epoch [54/200] Training: 100%|██████████| 1250/1250 [00:08<00:00, 156.02it/s, loss=0.00274]
Epoch [54/200] Validation: 100%|██████████| 313/313 [00:00<00:00, 331.92it/s, loss=0.00217]


EarlyStopping counter: 7 out of 15
Epoch [54/200], Train Loss: 0.002486, Val Loss: 0.002476, LR: 0.0005


Epoch [55/200] Training: 100%|██████████| 1250/1250 [00:07<00:00, 157.89it/s, loss=0.0025] 
Epoch [55/200] Validation: 100%|██████████| 313/313 [00:00<00:00, 332.83it/s, loss=0.00218]


EarlyStopping counter: 8 out of 15
Epoch [55/200], Train Loss: 0.002487, Val Loss: 0.002476, LR: 0.0005


Epoch [56/200] Training: 100%|██████████| 1250/1250 [00:07<00:00, 160.79it/s, loss=0.00256]
Epoch [56/200] Validation: 100%|██████████| 313/313 [00:01<00:00, 284.12it/s, loss=0.00218]


Epoch [56/200], Train Loss: 0.002487, Val Loss: 0.002469, LR: 0.0005


Epoch [57/200] Training: 100%|██████████| 1250/1250 [00:08<00:00, 154.02it/s, loss=0.00248]
Epoch [57/200] Validation: 100%|██████████| 313/313 [00:00<00:00, 316.29it/s, loss=0.00216]


EarlyStopping counter: 1 out of 15
Epoch [57/200], Train Loss: 0.002485, Val Loss: 0.002474, LR: 0.0005


Epoch [58/200] Training: 100%|██████████| 1250/1250 [00:08<00:00, 147.50it/s, loss=0.00252]
Epoch [58/200] Validation: 100%|██████████| 313/313 [00:01<00:00, 293.60it/s, loss=0.00218]


EarlyStopping counter: 2 out of 15
Epoch [58/200], Train Loss: 0.002486, Val Loss: 0.002475, LR: 0.0005


Epoch [59/200] Training: 100%|██████████| 1250/1250 [00:08<00:00, 155.43it/s, loss=0.00269]
Epoch [59/200] Validation: 100%|██████████| 313/313 [00:01<00:00, 311.61it/s, loss=0.00216]


EarlyStopping counter: 3 out of 15
Epoch [59/200], Train Loss: 0.002485, Val Loss: 0.002474, LR: 0.0005


Epoch [60/200] Training: 100%|██████████| 1250/1250 [00:07<00:00, 166.10it/s, loss=0.00234]
Epoch [60/200] Validation: 100%|██████████| 313/313 [00:00<00:00, 315.87it/s, loss=0.00216]


EarlyStopping counter: 4 out of 15
Epoch [60/200], Train Loss: 0.002486, Val Loss: 0.002473, LR: 0.0005


Epoch [61/200] Training: 100%|██████████| 1250/1250 [00:07<00:00, 160.96it/s, loss=0.0023] 
Epoch [61/200] Validation: 100%|██████████| 313/313 [00:01<00:00, 307.13it/s, loss=0.00218]


EarlyStopping counter: 5 out of 15
Epoch [61/200], Train Loss: 0.002486, Val Loss: 0.002479, LR: 0.0005


Epoch [62/200] Training: 100%|██████████| 1250/1250 [00:07<00:00, 159.20it/s, loss=0.00244]
Epoch [62/200] Validation: 100%|██████████| 313/313 [00:01<00:00, 308.81it/s, loss=0.00218]


EarlyStopping counter: 6 out of 15
Epoch [62/200], Train Loss: 0.002485, Val Loss: 0.002473, LR: 0.0005


Epoch [63/200] Training: 100%|██████████| 1250/1250 [00:07<00:00, 159.72it/s, loss=0.00251]
Epoch [63/200] Validation: 100%|██████████| 313/313 [00:01<00:00, 281.69it/s, loss=0.00219]


EarlyStopping counter: 7 out of 15
Epoch [63/200], Train Loss: 0.002486, Val Loss: 0.002472, LR: 0.0005


Epoch [64/200] Training: 100%|██████████| 1250/1250 [00:08<00:00, 150.21it/s, loss=0.0025] 
Epoch [64/200] Validation: 100%|██████████| 313/313 [00:00<00:00, 319.36it/s, loss=0.00223]


EarlyStopping counter: 8 out of 15
Epoch [64/200], Train Loss: 0.002486, Val Loss: 0.002473, LR: 0.0005


Epoch [65/200] Training: 100%|██████████| 1250/1250 [00:08<00:00, 154.88it/s, loss=0.00238]
Epoch [65/200] Validation: 100%|██████████| 313/313 [00:00<00:00, 342.24it/s, loss=0.00218]


EarlyStopping counter: 9 out of 15
Epoch [65/200], Train Loss: 0.002485, Val Loss: 0.002476, LR: 0.0005


Epoch [66/200] Training: 100%|██████████| 1250/1250 [00:07<00:00, 164.01it/s, loss=0.00261]
Epoch [66/200] Validation: 100%|██████████| 313/313 [00:01<00:00, 289.70it/s, loss=0.00218]


EarlyStopping counter: 10 out of 15
Epoch [66/200], Train Loss: 0.002486, Val Loss: 0.002473, LR: 0.0005


Epoch [67/200] Training: 100%|██████████| 1250/1250 [00:07<00:00, 164.41it/s, loss=0.00261]
Epoch [67/200] Validation: 100%|██████████| 313/313 [00:00<00:00, 331.40it/s, loss=0.0022] 


EarlyStopping counter: 11 out of 15
Epoch [67/200], Train Loss: 0.002485, Val Loss: 0.002477, LR: 0.00025


Epoch [68/200] Training: 100%|██████████| 1250/1250 [00:08<00:00, 150.15it/s, loss=0.00266]
Epoch [68/200] Validation: 100%|██████████| 313/313 [00:01<00:00, 291.68it/s, loss=0.00219]


EarlyStopping counter: 12 out of 15
Epoch [68/200], Train Loss: 0.002480, Val Loss: 0.002469, LR: 0.00025


Epoch [69/200] Training: 100%|██████████| 1250/1250 [00:08<00:00, 152.18it/s, loss=0.00261]
Epoch [69/200] Validation: 100%|██████████| 313/313 [00:01<00:00, 294.18it/s, loss=0.00217]


EarlyStopping counter: 13 out of 15
Epoch [69/200], Train Loss: 0.002479, Val Loss: 0.002469, LR: 0.00025


Epoch [70/200] Training: 100%|██████████| 1250/1250 [00:09<00:00, 137.87it/s, loss=0.00238]
Epoch [70/200] Validation: 100%|██████████| 313/313 [00:01<00:00, 278.05it/s, loss=0.00217]


EarlyStopping counter: 14 out of 15
Epoch [70/200], Train Loss: 0.002479, Val Loss: 0.002469, LR: 0.00025


Epoch [71/200] Training: 100%|██████████| 1250/1250 [00:08<00:00, 148.90it/s, loss=0.00217]
Epoch [71/200] Validation: 100%|██████████| 313/313 [00:01<00:00, 296.16it/s, loss=0.00216]


Epoch [71/200], Train Loss: 0.002479, Val Loss: 0.002468, LR: 0.00025


Epoch [72/200] Training: 100%|██████████| 1250/1250 [00:08<00:00, 144.86it/s, loss=0.00248]
Epoch [72/200] Validation: 100%|██████████| 313/313 [00:00<00:00, 317.71it/s, loss=0.00219]


EarlyStopping counter: 1 out of 15
Epoch [72/200], Train Loss: 0.002479, Val Loss: 0.002467, LR: 0.00025


Epoch [73/200] Training: 100%|██████████| 1250/1250 [00:08<00:00, 142.89it/s, loss=0.00246]
Epoch [73/200] Validation: 100%|██████████| 313/313 [00:01<00:00, 291.97it/s, loss=0.00218]


EarlyStopping counter: 2 out of 15
Epoch [73/200], Train Loss: 0.002478, Val Loss: 0.002468, LR: 0.00025


Epoch [74/200] Training: 100%|██████████| 1250/1250 [00:07<00:00, 161.23it/s, loss=0.00258]
Epoch [74/200] Validation: 100%|██████████| 313/313 [00:00<00:00, 335.96it/s, loss=0.00218]


EarlyStopping counter: 3 out of 15
Epoch [74/200], Train Loss: 0.002479, Val Loss: 0.002468, LR: 0.00025


Epoch [75/200] Training: 100%|██████████| 1250/1250 [00:08<00:00, 154.18it/s, loss=0.00248]
Epoch [75/200] Validation: 100%|██████████| 313/313 [00:00<00:00, 336.71it/s, loss=0.00216]


EarlyStopping counter: 4 out of 15
Epoch [75/200], Train Loss: 0.002478, Val Loss: 0.002470, LR: 0.00025


Epoch [76/200] Training: 100%|██████████| 1250/1250 [00:08<00:00, 147.34it/s, loss=0.00236]
Epoch [76/200] Validation: 100%|██████████| 313/313 [00:01<00:00, 260.76it/s, loss=0.00218]


EarlyStopping counter: 5 out of 15
Epoch [76/200], Train Loss: 0.002479, Val Loss: 0.002468, LR: 0.00025


Epoch [77/200] Training: 100%|██████████| 1250/1250 [00:07<00:00, 162.03it/s, loss=0.00263]
Epoch [77/200] Validation: 100%|██████████| 313/313 [00:01<00:00, 293.31it/s, loss=0.00216]


EarlyStopping counter: 6 out of 15
Epoch [77/200], Train Loss: 0.002478, Val Loss: 0.002469, LR: 0.00025


Epoch [78/200] Training: 100%|██████████| 1250/1250 [00:07<00:00, 159.68it/s, loss=0.00239]
Epoch [78/200] Validation: 100%|██████████| 313/313 [00:01<00:00, 309.20it/s, loss=0.00217]


EarlyStopping counter: 7 out of 15
Epoch [78/200], Train Loss: 0.002478, Val Loss: 0.002469, LR: 0.00025


Epoch [79/200] Training: 100%|██████████| 1250/1250 [00:08<00:00, 154.34it/s, loss=0.00231]
Epoch [79/200] Validation: 100%|██████████| 313/313 [00:00<00:00, 326.33it/s, loss=0.00217]


EarlyStopping counter: 8 out of 15
Epoch [79/200], Train Loss: 0.002478, Val Loss: 0.002469, LR: 0.00025


Epoch [80/200] Training: 100%|██████████| 1250/1250 [00:08<00:00, 152.75it/s, loss=0.00242]
Epoch [80/200] Validation: 100%|██████████| 313/313 [00:00<00:00, 318.09it/s, loss=0.00218]


EarlyStopping counter: 9 out of 15
Epoch [80/200], Train Loss: 0.002478, Val Loss: 0.002469, LR: 0.00025


Epoch [81/200] Training: 100%|██████████| 1250/1250 [00:08<00:00, 150.30it/s, loss=0.00244]
Epoch [81/200] Validation: 100%|██████████| 313/313 [00:01<00:00, 295.84it/s, loss=0.00216]


EarlyStopping counter: 10 out of 15
Epoch [81/200], Train Loss: 0.002478, Val Loss: 0.002468, LR: 0.00025


Epoch [82/200] Training: 100%|██████████| 1250/1250 [00:08<00:00, 141.84it/s, loss=0.0024] 
Epoch [82/200] Validation: 100%|██████████| 313/313 [00:00<00:00, 341.56it/s, loss=0.00219]


EarlyStopping counter: 11 out of 15
Epoch [82/200], Train Loss: 0.002478, Val Loss: 0.002469, LR: 0.00025


Epoch [83/200] Training: 100%|██████████| 1250/1250 [00:07<00:00, 160.25it/s, loss=0.00271]
Epoch [83/200] Validation: 100%|██████████| 313/313 [00:01<00:00, 294.06it/s, loss=0.00216]


EarlyStopping counter: 12 out of 15
Epoch [83/200], Train Loss: 0.002478, Val Loss: 0.002470, LR: 0.000125


Epoch [84/200] Training: 100%|██████████| 1250/1250 [00:07<00:00, 161.47it/s, loss=0.00277]
Epoch [84/200] Validation: 100%|██████████| 313/313 [00:00<00:00, 330.77it/s, loss=0.00218]


Epoch [84/200], Train Loss: 0.002475, Val Loss: 0.002466, LR: 0.000125


Epoch [85/200] Training: 100%|██████████| 1250/1250 [00:08<00:00, 154.72it/s, loss=0.0024] 
Epoch [85/200] Validation: 100%|██████████| 313/313 [00:00<00:00, 337.82it/s, loss=0.00218]


Epoch [85/200], Train Loss: 0.002475, Val Loss: 0.002465, LR: 0.000125


Epoch [86/200] Training: 100%|██████████| 1250/1250 [00:08<00:00, 154.65it/s, loss=0.00234]
Epoch [86/200] Validation: 100%|██████████| 313/313 [00:01<00:00, 282.01it/s, loss=0.00219]


EarlyStopping counter: 1 out of 15
Epoch [86/200], Train Loss: 0.002475, Val Loss: 0.002467, LR: 0.000125


Epoch [87/200] Training: 100%|██████████| 1250/1250 [00:08<00:00, 154.95it/s, loss=0.00275]
Epoch [87/200] Validation: 100%|██████████| 313/313 [00:00<00:00, 332.38it/s, loss=0.00217]


EarlyStopping counter: 2 out of 15
Epoch [87/200], Train Loss: 0.002475, Val Loss: 0.002465, LR: 0.000125


Epoch [88/200] Training: 100%|██████████| 1250/1250 [00:07<00:00, 158.21it/s, loss=0.00235]
Epoch [88/200] Validation: 100%|██████████| 313/313 [00:01<00:00, 299.08it/s, loss=0.00218]


EarlyStopping counter: 3 out of 15
Epoch [88/200], Train Loss: 0.002475, Val Loss: 0.002466, LR: 0.000125


Epoch [89/200] Training: 100%|██████████| 1250/1250 [00:07<00:00, 159.96it/s, loss=0.00235]
Epoch [89/200] Validation: 100%|██████████| 313/313 [00:00<00:00, 318.92it/s, loss=0.00217]


EarlyStopping counter: 4 out of 15
Epoch [89/200], Train Loss: 0.002474, Val Loss: 0.002465, LR: 0.000125


Epoch [90/200] Training: 100%|██████████| 1250/1250 [00:07<00:00, 160.62it/s, loss=0.00225]
Epoch [90/200] Validation: 100%|██████████| 313/313 [00:00<00:00, 326.76it/s, loss=0.00217]


EarlyStopping counter: 5 out of 15
Epoch [90/200], Train Loss: 0.002474, Val Loss: 0.002466, LR: 0.000125


Epoch [91/200] Training: 100%|██████████| 1250/1250 [00:08<00:00, 151.91it/s, loss=0.00239]
Epoch [91/200] Validation: 100%|██████████| 313/313 [00:01<00:00, 310.11it/s, loss=0.00216]


EarlyStopping counter: 6 out of 15
Epoch [91/200], Train Loss: 0.002474, Val Loss: 0.002466, LR: 0.000125


Epoch [92/200] Training: 100%|██████████| 1250/1250 [00:08<00:00, 154.85it/s, loss=0.00234]
Epoch [92/200] Validation: 100%|██████████| 313/313 [00:00<00:00, 340.42it/s, loss=0.00217]


EarlyStopping counter: 7 out of 15
Epoch [92/200], Train Loss: 0.002474, Val Loss: 0.002466, LR: 0.000125


Epoch [93/200] Training: 100%|██████████| 1250/1250 [00:07<00:00, 166.17it/s, loss=0.0024] 
Epoch [93/200] Validation: 100%|██████████| 313/313 [00:01<00:00, 224.30it/s, loss=0.00216]


EarlyStopping counter: 8 out of 15
Epoch [93/200], Train Loss: 0.002474, Val Loss: 0.002466, LR: 0.000125


Epoch [94/200] Training: 100%|██████████| 1250/1250 [00:08<00:00, 155.62it/s, loss=0.00226]
Epoch [94/200] Validation: 100%|██████████| 313/313 [00:00<00:00, 360.61it/s, loss=0.00217]


EarlyStopping counter: 9 out of 15
Epoch [94/200], Train Loss: 0.002474, Val Loss: 0.002467, LR: 0.000125


Epoch [95/200] Training: 100%|██████████| 1250/1250 [00:07<00:00, 162.34it/s, loss=0.00251]
Epoch [95/200] Validation: 100%|██████████| 313/313 [00:00<00:00, 353.99it/s, loss=0.00217]


EarlyStopping counter: 10 out of 15
Epoch [95/200], Train Loss: 0.002474, Val Loss: 0.002466, LR: 0.000125


Epoch [96/200] Training: 100%|██████████| 1250/1250 [00:07<00:00, 164.83it/s, loss=0.00253]
Epoch [96/200] Validation: 100%|██████████| 313/313 [00:01<00:00, 288.58it/s, loss=0.00218]


EarlyStopping counter: 11 out of 15
Epoch [96/200], Train Loss: 0.002474, Val Loss: 0.002466, LR: 6.25e-05


Epoch [97/200] Training: 100%|██████████| 1250/1250 [00:07<00:00, 162.06it/s, loss=0.00249]
Epoch [97/200] Validation: 100%|██████████| 313/313 [00:01<00:00, 305.35it/s, loss=0.00218]


EarlyStopping counter: 12 out of 15
Epoch [97/200], Train Loss: 0.002473, Val Loss: 0.002465, LR: 6.25e-05


Epoch [98/200] Training: 100%|██████████| 1250/1250 [00:08<00:00, 155.01it/s, loss=0.00238]
Epoch [98/200] Validation: 100%|██████████| 313/313 [00:00<00:00, 324.98it/s, loss=0.00218]


EarlyStopping counter: 13 out of 15
Epoch [98/200], Train Loss: 0.002473, Val Loss: 0.002466, LR: 6.25e-05


Epoch [99/200] Training: 100%|██████████| 1250/1250 [00:08<00:00, 152.98it/s, loss=0.00215]
Epoch [99/200] Validation: 100%|██████████| 313/313 [00:01<00:00, 308.69it/s, loss=0.00217]


EarlyStopping counter: 14 out of 15
Epoch [99/200], Train Loss: 0.002472, Val Loss: 0.002465, LR: 6.25e-05


Epoch [100/200] Training: 100%|██████████| 1250/1250 [00:07<00:00, 162.59it/s, loss=0.00251]
Epoch [100/200] Validation: 100%|██████████| 313/313 [00:01<00:00, 306.35it/s, loss=0.00217]

EarlyStopping counter: 15 out of 15
Epoch [100/200], Train Loss: 0.002473, Val Loss: 0.002465, LR: 6.25e-05
Early stopping triggered





In [None]:
class ModelEvaluator:
    def __init__(self, model, device):
        self.model = model
        self.device = device
        self.metrics = {}

    def calculate_psnr(self, target, prediction):
        mse = np.mean((target - prediction) ** 2)
        if mse == 0:
            return float('inf')
        max_pixel = np.max(target)
        psnr = 20 * np.log10(max_pixel / np.sqrt(mse))
        return psnr

    def calculate_peak_metrics(self, target, prediction, prominence=0.1):
        """Calculate peak detection accuracy"""
        peak_metrics = {'true_peaks': 0, 'predicted_peaks': 0, 'matching_peaks': 0}
        
        for i in range(len(target)):
            true_peaks, _ = find_peaks(target[i], prominence=prominence)
            pred_peaks, _ = find_peaks(prediction[i], prominence=prominence)
            
            # Count matching peaks (within ±1 index)
            matches = 0
            for tp in true_peaks:
                for pp in pred_peaks:
                    if abs(tp - pp) <= 1:
                        matches += 1
                        break
            
            peak_metrics['true_peaks'] += len(true_peaks)
            peak_metrics['predicted_peaks'] += len(pred_peaks)
            peak_metrics['matching_peaks'] += matches
        
        return peak_metrics

    def measure_inference_time(self, input_tensor, num_iterations=100):
        self.model.eval()
        start_time = time.time()
        with torch.no_grad():
            for _ in range(num_iterations):
                _ = self.model(input_tensor)
        end_time = time.time()
        return (end_time - start_time) / num_iterations

    def count_parameters(self):
        return sum(p.numel() for p in self.model.parameters() if p.requires_grad)

    def measure_memory_usage(self):
        process = psutil.Process(os.getpid())
        return process.memory_info().rss / 1024 / 1024  # Convert to MB

    def evaluate_model(self, val_loader):
        self.model.eval()
        metrics = {
            'mse': 0.0,
            'mae': 0.0,
            'psnr': 0.0,
            'r2_score': 0.0,
            'peak_accuracy': {},
            'inference_time': 0.0,
            'memory_usage': 0.0,
            'num_parameters': self.count_parameters()
        }
        
        all_targets = []
        all_predictions = []
        
        # Get sample input for inference time measurement
        sample_input, _ = next(iter(val_loader))
        metrics['inference_time'] = self.measure_inference_time(sample_input.to(self.device))
        
        with torch.no_grad():
            for inputs, targets in val_loader:
                inputs = inputs.to(self.device)
                targets = targets.to(self.device)
                outputs = self.model(inputs)
                
                # Convert to numpy for metric calculation
                targets_np = targets.cpu().numpy()
                outputs_np = outputs.cpu().numpy()
                
                all_targets.append(targets_np)
                all_predictions.append(outputs_np)
        
        # Concatenate all batches
        all_targets = np.concatenate(all_targets)
        all_predictions = np.concatenate(all_predictions)
        
        # Calculate metrics
        metrics['mse'] = np.mean((all_targets - all_predictions) ** 2)
        metrics['mae'] = np.mean(np.abs(all_targets - all_predictions))
        metrics['psnr'] = self.calculate_psnr(all_targets, all_predictions)
        metrics['r2_score'] = r2_score(all_targets.flatten(), all_predictions.flatten())
        metrics['peak_metrics'] = self.calculate_peak_metrics(all_targets, all_predictions)
        metrics['memory_usage'] = self.measure_memory_usage()
        
        return metrics

evaluator = ModelEvaluator(model, device)

final_metrics = evaluator.evaluate_model(val_loader)

wandb.log({
    "final_mse": final_metrics['mse'],
    "final_mae": final_metrics['mae'],
    "final_psnr": final_metrics['psnr'],
    "final_r2_score": final_metrics['r2_score'],
    "peak_detection_accuracy": final_metrics['peak_metrics']['matching_peaks'] / final_metrics['peak_metrics']['true_peaks'],
    "inference_time_ms": final_metrics['inference_time'] * 1000,
    "memory_usage_mb": final_metrics['memory_usage'],
    "model_parameters": final_metrics['num_parameters']
})

print("\nFinal Model Evaluation:")
print(f"MSE: {final_metrics['mse']:.6f}")
print(f"MAE: {final_metrics['mae']:.6f}")
print(f"PSNR: {final_metrics['psnr']:.2f} dB")
print(f"R² Score: {final_metrics['r2_score']:.4f}")
print(f"Peak Detection Accuracy: {final_metrics['peak_metrics']['matching_peaks'] / final_metrics['peak_metrics']['true_peaks']:.2%}")
print(f"Average Inference Time: {final_metrics['inference_time']*1000:.2f} ms")
print(f"Memory Usage: {final_metrics['memory_usage']:.1f} MB")
print(f"Number of Parameters: {final_metrics['num_parameters']:,}")


Final Model Evaluation:
MSE: 0.002465
MAE: 0.026724
PSNR: 24.57 dB
R² Score: 0.4523
Peak Detection Accuracy: 35.70%
Average Inference Time: 0.87 ms
Memory Usage: 998.0 MB
Number of Parameters: 336,672


In [None]:
model.eval()

# Visualization function with peak analysis
def visualize_final_results():
    model.eval()
    results_by_peaks = {1: [], 2: [], 3: [], 4: []}
    
    with torch.no_grad():
        for num_peaks in range(1, 5):
            peak_indices = np.where(no_comp.flatten() == num_peaks)[0]
            sample_indices = np.random.choice(peak_indices, min(10, len(peak_indices)), replace=False)
            
            for idx in sample_indices:
                input_signal = torch.FloatTensor(input_noisy[:, idx]).unsqueeze(0).to(device)
                target = ref[:, idx]
                prediction = model(input_signal).cpu().squeeze(0).numpy()
                mse = np.mean((target - prediction)**2)
                results_by_peaks[num_peaks].append(mse)
    
    for num_peaks, mse_values in results_by_peaks.items():
        avg_mse = np.mean(mse_values)
        wandb.log({f"MSE_{num_peaks}_peaks": avg_mse})
        print(f"Average MSE for {num_peaks} peak(s): {avg_mse:.6f}")
        
visualize_final_results()

Average MSE for 1 peak(s): 0.001977
Average MSE for 2 peak(s): 0.001885
Average MSE for 3 peak(s): 0.002719
Average MSE for 4 peak(s): 0.002013


In [13]:
wandb.finish()

0,1
MSE_1_peaks,▁
MSE_2_peaks,▁
MSE_3_peaks,▁
MSE_4_peaks,▁
epoch,▁▁▁▁▁▂▂▂▂▂▃▃▃▃▃▄▄▄▄▄▅▅▅▅▅▆▆▆▆▆▇▇▇▇▇▇████
final_mae,▁
final_mse,▁
final_psnr,▁
final_r2_score,▁
inference_time_ms,▁

0,1
MSE_1_peaks,0.00198
MSE_2_peaks,0.00188
MSE_3_peaks,0.00272
MSE_4_peaks,0.00201
epoch,100.0
final_mae,0.02672
final_mse,0.00247
final_psnr,24.57014
final_r2_score,0.4523
inference_time_ms,0.86999
