In [1]:
# Place this as the FIRST cell, before importing torch.
import os, random
import numpy as np

SEED = 42

# For Python determinism
os.environ["PYTHONHASHSEED"] = str(SEED)
# Deterministic cuBLAS (required for some CUDA matmul ops)
os.environ["CUBLAS_WORKSPACE_CONFIG"] = ":4096:8"

import torch

# Seed Python, NumPy, Torch (CPU and CUDA)
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)
torch.cuda.manual_seed_all(SEED)

# Turn on deterministic behavior
torch.use_deterministic_algorithms(True)  # may raise on nondeterministic ops
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False
# Keep math consistent
torch.backends.cuda.matmul.allow_tf32 = False
torch.backends.cudnn.allow_tf32 = False

torch.autograd.set_detect_anomaly(True)

# Optional: remove threading non-determinism
# torch.set_num_threads(1)

# Helpers for DataLoader reproducibility
def seed_worker(worker_id):
    worker_seed = torch.initial_seed() % 2**32
    np.random.seed(worker_seed)
    random.seed(worker_seed)

g = torch.Generator()
g.manual_seed(SEED)

  self.setter(val)


<torch._C.Generator at 0x7192fefa7950>

In [2]:
import torch.nn as nn
from torch import optim
from torch.utils.data import DataLoader, random_split
from src.datasets.seeg_dataset import SEEGDataset
from src.models.model import SEEGFusionModel
from src.utils import move_to_device
from tqdm import tqdm
import time
device = 'cuda' if torch.cuda.is_available() else 'cpu'
print(f'Using device: {device}')

Using device: cpu


In [3]:
dataset = SEEGDataset(subjects=['Epat26', 'Epat30', 'Epat31', 'Epat34', 'Epat35'])
train_size = int(0.8 * len(dataset))
val_size = len(dataset) - train_size

train_ds, val_ds = random_split(dataset, [train_size, val_size], generator=g)

dataloaders = {
    'train': DataLoader(train_ds, batch_size=16, shuffle=True,
                        num_workers=0, worker_init_fn=seed_worker, generator=g),
    'val': DataLoader(val_ds, batch_size=16, shuffle=False,
                      num_workers=0, worker_init_fn=seed_worker, generator=g),
}

## Need to implement weighted random sampler
# sampler = WeightedRandomSampler(class_weights, num_samples=len(labels), replacement=True)

[32m2025-11-05 17:33:28.877[0m | [32m[1mSUCCESS [0m | [36msrc.datasets.seeg_dataset[0m:[36m__init__[0m:[36m128[0m - [32m[1m✅ Loaded 218 total samples from 5 subjects.[0m


In [4]:
# get weights for imbalanced classes

def compute_class_weights(train_ds):
    labels = np.array([v[1] for v in train_ds])
    class_sample_count = np.array(
        [len(np.where(labels == t)[0]) for t in np.unique(labels)])
    weight = class_sample_count.sum() / class_sample_count
    return torch.from_numpy(weight).float()

weights = compute_class_weights(train_ds)

In [5]:
print(next(iter(dataloaders['train']))[0]['convergent'].shape)
print(next(iter(dataloaders['train']))[0]['divergent'].shape)

torch.Size([16, 67, 50, 487])
torch.Size([16, 108, 50, 487])


In [6]:
model = SEEGFusionModel(embed_dim=128, n_classes=2)
model.to(device)
optimizer = optim.AdamW(model.parameters(), lr=1e-6)
criterion = nn.CrossEntropyLoss(weight=weights.to(device))

In [7]:
total_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
print(f"Total trainable parameters in the model: {total_params}")

Total trainable parameters in the model: 5294346


In [None]:
n_epochs = 5

print(f"Starting training for {n_epochs} epochs on device: {device}\n{'='*60}")

for epoch in range(1, n_epochs + 1):
    epoch_start = time.time()
    train_loss = 0.0
    train_acc = 0

    model.train()
    print(f"\nEpoch {epoch}/{n_epochs}")
    print("-" * 60)

    # training loop
    for batch_idx, (inputs, labels) in enumerate(tqdm(dataloaders['train'], desc="Training", leave=False)):
        inputs = move_to_device(inputs, device)
        labels = move_to_device(labels, device)

        optimizer.zero_grad()
        outputs = model(inputs)
        loss = criterion(outputs, labels)

        if torch.isnan(loss):
            print(f"NaN loss detected at batch {batch_idx}")
            break

        loss.backward()

        ## DEBUGGING ##
        # Check gradients
        with torch.no_grad():
            total_norm = 0
            for name, param in model.named_parameters():
                if param.grad is not None:
                    param_norm = param.grad.data.norm(2)
                    total_norm += param_norm.item() ** 2
                    if torch.isnan(param.grad).any():
                        print(f"NaN in gradient of {name}")
                    if param_norm.item() > 1e5:
                        print(f"HUGE gradient in {name}: {param_norm.item():.3e}")
            total_norm = total_norm ** 0.5
            print(f"Gradient norm: {total_norm:.3e}")

        # Optionally: predict what the update would do
        with torch.no_grad():
            for name, param in model.named_parameters():
                if param.grad is None: continue
                delta = -optimizer.param_groups[0]['lr'] * param.grad  # approximate for SGD-like step
                param_norm = param.data.norm().item()
                delta_norm = delta.norm().item()
                if delta_norm > 0:
                    ratio = delta_norm / (param_norm + 1e-12)
                else:
                    ratio = 0.0
                if ratio > 1e3 or delta_norm > 1e-1:   # tune thresholds
                    print(f"{name}: param_norm={param_norm:.3e}, delta_norm={delta_norm:.3e}, ratio={ratio:.3e}")

        ## END DEBUGGING ##

        nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
        optimizer.step()

        train_loss += loss.item() * inputs['convergent'].size(0)
        _, preds = torch.max(outputs, 1)
        train_acc += torch.sum(preds == labels.data)

        # print progress every N batches
        if (batch_idx + 1) % 3 == 0 or (batch_idx + 1) == len(dataloaders['train']):
            avg_loss = train_loss / ((batch_idx + 1) * inputs['convergent'].size(0))
            avg_acc = train_acc.double() / ((batch_idx + 1) * inputs['convergent'].size(0))
            tqdm.write(f"[Batch {batch_idx+1}/{len(dataloaders['train'])}] "
                       f"Train Loss: {avg_loss:.4f}, Train Acc: {avg_acc:.4f}")

    # end of training epoch
    epoch_train_loss = train_loss / len(dataloaders['train'].dataset)
    epoch_train_acc = train_acc.double() / len(dataloaders['train'].dataset)

    # validation loop
    model.eval()
    val_loss = 0.0
    val_acc = 0

    with torch.no_grad():
        for inputs, labels in tqdm(dataloaders['val'], desc="Validating", leave=False):
            inputs = move_to_device(inputs, device)
            labels = move_to_device(labels, device)

            outputs = model(inputs)
            loss = criterion(outputs, labels)

            val_loss += loss.item() * inputs['convergent'].size(0)
            _, preds = torch.max(outputs, 1)
            val_acc += torch.sum(preds == labels.data)

    epoch_val_loss = val_loss / len(dataloaders['val'].dataset)
    epoch_val_acc = val_acc.double() / len(dataloaders['val'].dataset)
    epoch_time = time.time() - epoch_start

    print(f"Epoch {epoch} Summary:")
    print(f"  Train Loss: {epoch_train_loss:.4f} | Train Acc: {epoch_train_acc:.4f}")
    print(f"  Validation Loss: {epoch_val_loss:.4f} | Validation Acc: {epoch_val_acc:.4f}")
    print(f"  Time: {epoch_time:.2f} sec")
    print("=" * 60)


Starting training for 5 epochs on device: cpu

Epoch 1/5
------------------------------------------------------------


Training:   0%|          | 0/11 [00:00<?, ?it/s]

resnet_conv_output: mean -2.132e-02, std 4.390e-01, min -2.594e+01, max 1.971e+01


Training:   9%|▉         | 1/11 [01:10<11:42, 70.26s/it]

Gradient norm: 2.993e+00
conv_msresnet.layer3x3_3.0.bn2.bias: param_norm=0.000e+00, delta_norm=1.843e-09, ratio=1.843e+03
conv_msresnet.layer3x3_3.0.downsample.1.bias: param_norm=0.000e+00, delta_norm=1.843e-09, ratio=1.843e+03
conv_msresnet.layer5x5_3.0.bn2.bias: param_norm=0.000e+00, delta_norm=1.707e-09, ratio=1.707e+03
conv_msresnet.layer5x5_3.0.downsample.1.bias: param_norm=0.000e+00, delta_norm=1.707e-09, ratio=1.707e+03
conv_msresnet.layer7x7_3.0.bn2.bias: param_norm=0.000e+00, delta_norm=1.787e-09, ratio=1.787e+03
conv_msresnet.layer7x7_3.0.downsample.1.bias: param_norm=0.000e+00, delta_norm=1.787e-09, ratio=1.787e+03
div_msresnet.layer3x3_3.0.bn2.bias: param_norm=0.000e+00, delta_norm=1.198e-09, ratio=1.198e+03
div_msresnet.layer3x3_3.0.downsample.1.bias: param_norm=0.000e+00, delta_norm=1.198e-09, ratio=1.198e+03
div_msresnet.layer5x5_3.0.bn2.bias: param_norm=0.000e+00, delta_norm=1.318e-09, ratio=1.318e+03
div_msresnet.layer5x5_3.0.downsample.1.bias: param_norm=0.000e+00, de

Training:  18%|█▊        | 2/11 [02:22<10:43, 71.52s/it]

Gradient norm: 2.234e+00
resnet_conv_output: mean -2.204e-02, std 4.468e-01, min -1.906e+01, max 1.968e+01


Training:  27%|██▋       | 3/11 [03:32<09:25, 70.73s/it]

Gradient norm: 2.420e+00
[Batch 3/11] Train Loss: 0.6858, Train Acc: 0.5208
resnet_conv_output: mean -2.160e-02, std 4.401e-01, min -1.338e+01, max 9.473e+00


Training:  36%|███▋      | 4/11 [04:41<08:10, 70.02s/it]

Gradient norm: 1.165e+00
resnet_conv_output: mean -2.031e-02, std 4.459e-01, min -5.998e+01, max 6.394e+01


Training:  45%|████▌     | 5/11 [05:47<06:50, 68.46s/it]

Gradient norm: 2.552e+00
resnet_conv_output: mean -1.162e-02, std 3.926e-01, min -5.539e+01, max 5.396e+01


Training:  55%|█████▍    | 6/11 [07:07<06:02, 72.51s/it]

Gradient norm: 1.670e+00
[Batch 6/11] Train Loss: 0.6925, Train Acc: 0.5521
resnet_conv_output: mean -1.865e-02, std 4.255e-01, min -2.844e+01, max 3.070e+01


Training:  64%|██████▎   | 7/11 [08:18<04:47, 71.99s/it]

Gradient norm: 2.373e+00
resnet_conv_output: mean -2.208e-02, std 4.277e-01, min -1.937e+01, max 1.496e+01
