In [8]:
# Place this as the FIRST cell, before importing torch.
import os, random
import numpy as np

SEED = 42

# For Python determinism
os.environ["PYTHONHASHSEED"] = str(SEED)
# Deterministic cuBLAS (required for some CUDA matmul ops)
os.environ["CUBLAS_WORKSPACE_CONFIG"] = ":4096:8"

import torch

# Seed Python, NumPy, Torch (CPU and CUDA)
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)
torch.cuda.manual_seed_all(SEED)

# Turn on deterministic behavior
torch.use_deterministic_algorithms(True)  # may raise on nondeterministic ops
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False
# Keep math consistent
torch.backends.cuda.matmul.allow_tf32 = False
torch.backends.cudnn.allow_tf32 = False

# Optional: remove threading non-determinism
# torch.set_num_threads(1)

# Helpers for DataLoader reproducibility
def seed_worker(worker_id):
    worker_seed = torch.initial_seed() % 2**32
    np.random.seed(worker_seed)
    random.seed(worker_seed)

g = torch.Generator()
g.manual_seed(SEED)

<torch._C.Generator at 0x7a5f44730f70>

In [9]:
import torch.nn as nn
from torch import optim
from torch.utils.data import DataLoader, random_split
from src.datasets.seeg_dataset import SEEGDataset
from src.models.model import SEEGFusionModel
from src.utils import move_to_device
from tqdm import tqdm
import time
device = 'cuda' if torch.cuda.is_available() else 'cpu'
print(f'Using device: {device}')

Using device: cpu


In [10]:
dataset = SEEGDataset(subjects=['Epat26', 'Epat30', 'Epat31'])
train_size = int(0.8 * len(dataset))
val_size = len(dataset) - train_size

train_ds, val_ds = random_split(dataset, [train_size, val_size], generator=g)

dataloaders = {
    'train': DataLoader(train_ds, batch_size=8, shuffle=True,
                        num_workers=0, worker_init_fn=seed_worker, generator=g),
    'val': DataLoader(val_ds, batch_size=8, shuffle=False,
                      num_workers=0, worker_init_fn=seed_worker, generator=g),
}

[32m2025-11-05 16:08:10.822[0m | [32m[1mSUCCESS [0m | [36msrc.datasets.seeg_dataset[0m:[36m__init__[0m:[36m128[0m - [32m[1m✅ Loaded 128 total samples from 3 subjects.[0m


In [11]:
print(next(iter(dataloaders['train']))[0]['convergent'].shape)
print(next(iter(dataloaders['train']))[0]['divergent'].shape)

torch.Size([8, 66, 50, 487])
torch.Size([8, 103, 50, 487])


In [12]:
model = SEEGFusionModel(embed_dim=128, n_classes=2)
optimizer = optim.AdamW(model.parameters(), lr=1e-6)
criterion = nn.CrossEntropyLoss()

In [13]:
total_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
print(f"Total trainable parameters in the model: {total_params}")

Total trainable parameters in the model: 5294346


In [14]:
model.to(device)
torch.autograd.set_detect_anomaly(True)
n_epochs = 3

print(f"Starting training for {n_epochs} epochs on device: {device}\n{'='*60}")

for epoch in range(1, n_epochs + 1):
    epoch_start = time.time()
    train_loss = 0.0
    train_acc = 0

    model.train()
    print(f"\nEpoch {epoch}/{n_epochs}")
    print("-" * 60)

    # training loop
    for batch_idx, (inputs, labels) in enumerate(tqdm(dataloaders['train'], desc="Training", leave=False)):
        inputs = move_to_device(inputs, device)
        labels = move_to_device(labels, device)

        optimizer.zero_grad()
        outputs = model(inputs)
        loss = criterion(outputs, labels)

        if torch.isnan(loss):
            print(f"NaN loss detected at batch {batch_idx}")
            break

        loss.backward()

        ## DEBUGGING ##
        # Check gradients
        with torch.no_grad():
            total_norm = 0
            for name, param in model.named_parameters():
                if param.grad is not None:
                    param_norm = param.grad.data.norm(2)
                    total_norm += param_norm.item() ** 2
                    if torch.isnan(param.grad).any():
                        print(f"NaN in gradient of {name}")
                    if param_norm.item() > 1e5:
                        print(f"HUGE gradient in {name}: {param_norm.item():.3e}")
            total_norm = total_norm ** 0.5
            print(f"Gradient norm: {total_norm:.3e}")

        # Optionally: predict what the update would do
        with torch.no_grad():
            for name, param in model.named_parameters():
                if param.grad is None: continue
                delta = -optimizer.param_groups[0]['lr'] * param.grad  # approximate for SGD-like step
                param_norm = param.data.norm().item()
                delta_norm = delta.norm().item()
                if delta_norm > 0:
                    ratio = delta_norm / (param_norm + 1e-12)
                else:
                    ratio = 0.0
                if ratio > 1e3 or delta_norm > 1e-1:   # tune thresholds
                    print(f"{name}: param_norm={param_norm:.3e}, delta_norm={delta_norm:.3e}, ratio={ratio:.3e}")

        ## END DEBUGGING ##

        nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
        optimizer.step()

        train_loss += loss.item() * inputs['convergent'].size(0)
        _, preds = torch.max(outputs, 1)
        train_acc += torch.sum(preds == labels.data)

        # print progress every N batches
        if (batch_idx + 1) % 3 == 0 or (batch_idx + 1) == len(dataloaders['train']):
            avg_loss = train_loss / ((batch_idx + 1) * inputs['convergent'].size(0))
            avg_acc = train_acc.double() / ((batch_idx + 1) * inputs['convergent'].size(0))
            tqdm.write(f"[Batch {batch_idx+1}/{len(dataloaders['train'])}] "
                       f"Train Loss: {avg_loss:.4f}, Train Acc: {avg_acc:.4f}")

    # end of training epoch
    epoch_train_loss = train_loss / len(dataloaders['train'].dataset)
    epoch_train_acc = train_acc.double() / len(dataloaders['train'].dataset)

    # validation loop
    model.eval()
    val_loss = 0.0
    val_acc = 0

    with torch.no_grad():
        for inputs, labels in tqdm(dataloaders['val'], desc="Validating", leave=False):
            inputs = move_to_device(inputs, device)
            labels = move_to_device(labels, device)

            outputs = model(inputs)
            loss = criterion(outputs, labels)

            val_loss += loss.item() * inputs['convergent'].size(0)
            _, preds = torch.max(outputs, 1)
            val_acc += torch.sum(preds == labels.data)

    epoch_val_loss = val_loss / len(dataloaders['val'].dataset)
    epoch_val_acc = val_acc.double() / len(dataloaders['val'].dataset)
    epoch_time = time.time() - epoch_start

    print(f"Epoch {epoch} Summary:")
    print(f"  Train Loss: {epoch_train_loss:.4f} | Train Acc: {epoch_train_acc:.4f}")
    print(f"  Validation Loss: {epoch_val_loss:.4f} | Validation Acc: {epoch_val_acc:.4f}")
    print(f"  Time: {epoch_time:.2f} sec")
    print("=" * 60)


Starting training for 3 epochs on device: cpu

Epoch 1/3
------------------------------------------------------------


Training:   0%|          | 0/13 [00:00<?, ?it/s]

resnet_conv_output: mean -2.111e-02, std 4.223e-01, min -1.866e+01, max 1.984e+01


Training:   8%|▊         | 1/13 [00:39<07:49, 39.11s/it]

Gradient norm: 4.128e+00
conv_msresnet.layer3x3_3.0.bn2.bias: param_norm=0.000e+00, delta_norm=1.799e-09, ratio=1.799e+03
conv_msresnet.layer3x3_3.0.downsample.1.bias: param_norm=0.000e+00, delta_norm=1.799e-09, ratio=1.799e+03
conv_msresnet.layer5x5_3.0.bn2.bias: param_norm=0.000e+00, delta_norm=1.764e-09, ratio=1.764e+03
conv_msresnet.layer5x5_3.0.downsample.1.bias: param_norm=0.000e+00, delta_norm=1.764e-09, ratio=1.764e+03
conv_msresnet.layer7x7_3.0.bn2.bias: param_norm=0.000e+00, delta_norm=1.799e-09, ratio=1.799e+03
conv_msresnet.layer7x7_3.0.downsample.1.bias: param_norm=0.000e+00, delta_norm=1.799e-09, ratio=1.799e+03
div_msresnet.layer3x3_3.0.bn2.bias: param_norm=0.000e+00, delta_norm=1.171e-09, ratio=1.171e+03
div_msresnet.layer3x3_3.0.downsample.1.bias: param_norm=0.000e+00, delta_norm=1.171e-09, ratio=1.171e+03
div_msresnet.layer5x5_3.0.bn2.bias: param_norm=0.000e+00, delta_norm=1.156e-09, ratio=1.156e+03
div_msresnet.layer5x5_3.0.downsample.1.bias: param_norm=0.000e+00, de

Training:  15%|█▌        | 2/13 [01:14<06:45, 36.85s/it]

Gradient norm: 5.559e+00
resnet_conv_output: mean -2.127e-02, std 4.286e-01, min -1.347e+01, max 1.538e+01


Training:  23%|██▎       | 3/13 [01:49<06:00, 36.02s/it]

Gradient norm: 3.413e+00
[Batch 3/13] Train Loss: 0.7282, Train Acc: 0.4167
resnet_conv_output: mean -1.827e-02, std 4.173e-01, min -2.578e+01, max 2.005e+01


Training:  31%|███       | 4/13 [02:21<05:11, 34.56s/it]

Gradient norm: 3.817e+00
resnet_conv_output: mean -1.952e-02, std 4.242e-01, min -2.228e+01, max 1.326e+01


Training:  38%|███▊      | 5/13 [03:01<04:52, 36.61s/it]

Gradient norm: 1.630e+00
resnet_conv_output: mean -1.796e-02, std 4.156e-01, min -2.044e+01, max 1.421e+01


Training:  46%|████▌     | 6/13 [03:38<04:16, 36.69s/it]

Gradient norm: 1.160e+00
[Batch 6/13] Train Loss: 0.6946, Train Acc: 0.5208


                                                        

KeyboardInterrupt: 