In [15]:
import torch
import torch.nn as nn
from torch import optim
from torch.utils.data import DataLoader, random_split
from src.datasets.seeg_dataset import SEEGDataset
from src.models.model import SEEGFusionModel
from src.utils import move_to_device
from tqdm import tqdm
import time
# Seed torch's global RNGs (CPU and CUDA) so model initialisation, DataLoader
# shuffling and any unseeded random_split are reproducible on Restart & Run All.
SEED = 42
torch.manual_seed(SEED)

# Run on the GPU when one is visible to torch, otherwise fall back to the CPU.
device = 'cuda' if torch.cuda.is_available() else 'cpu'

In [21]:
# Data configuration: which subjects to load, the train fraction, batch size,
# and the seed that fixes the train/val partition.
SUBJECTS = ['Epat26']
TRAIN_FRACTION = 0.8
BATCH_SIZE = 4
SPLIT_SEED = 42

dataset = SEEGDataset(subjects=SUBJECTS)
train_size = int(TRAIN_FRACTION * len(dataset))
val_size = len(dataset) - train_size
# Both splits must be non-empty, otherwise the per-epoch averages divide by zero.
assert train_size > 0 and val_size > 0, (
    f"dataset too small to split: {len(dataset)} samples -> "
    f"train={train_size}, val={val_size}"
)

# A dedicated, seeded generator makes the split identical on every run, and
# independent of whatever else has consumed the global RNG before this cell.
train_ds, val_ds = random_split(
    dataset,
    [train_size, val_size],
    generator=torch.Generator().manual_seed(SPLIT_SEED),
)

dataloaders = {
    'train': DataLoader(train_ds, batch_size=BATCH_SIZE, shuffle=True),
    'val': DataLoader(val_ds, batch_size=BATCH_SIZE, shuffle=False),
}

2025-11-04 15:42:52.209 | SUCCESS  | src.datasets.seeg_dataset:__init__:100 - ✅ Loaded 53 total samples from 1 subjects.


In [22]:
# Pull ONE training batch and inspect its per-view input tensors. Calling
# next(iter(...)) twice would build two iterators and load two different
# shuffled batches just to print two shapes.
sample_batch = next(iter(dataloaders['train']))
sample_inputs = sample_batch[0]  # the batch is (inputs, labels); inputs is a dict of views
print(sample_inputs['convergent'].shape)
print(sample_inputs['divergent'].shape)

torch.Size([4, 66, 50, 487])
torch.Size([4, 103, 50, 487])


In [23]:
# Model / optimisation hyper-parameters.
EMBED_DIM = 128
N_CLASSES = 2
LEARNING_RATE = 1e-4

# Fusion model with a 2-way classification head, trained with AdamW on cross-entropy.
model = SEEGFusionModel(embed_dim=EMBED_DIM, n_classes=N_CLASSES)
optimizer = optim.AdamW(model.parameters(), lr=LEARNING_RATE)
criterion = nn.CrossEntropyLoss()

In [None]:
model.to(device)
n_epochs = 10
log_every = 3  # print running training metrics every `log_every` batches (and on the last one)


def _run_epoch(loader, train, desc):
    """Run `model` once over every batch in `loader` and return (mean loss, accuracy).

    Args:
        loader: DataLoader yielding `(inputs, labels)` batches. `inputs` is passed
            through `move_to_device` and straight into `model`; `labels` is
            presumably a 1-D tensor of class indices (it is compared against the
            argmax of the logits) -- TODO confirm against SEEGDataset.
        train: if True, put the model in train mode and step `optimizer` after
            every batch; if False, use eval mode with gradient tracking disabled.
        desc: label shown on the tqdm progress bar.

    Returns:
        Tuple `(mean_loss, accuracy)` averaged over all samples actually seen.

    Raises:
        ValueError: if `loader` yields no samples (would otherwise divide by zero).
    """
    model.train(train)  # model.train(False) is equivalent to model.eval()
    n_batches = len(loader)
    total_loss, total_correct, total_seen = 0.0, 0, 0

    with torch.set_grad_enabled(train):
        for batch_idx, (inputs, labels) in enumerate(tqdm(loader, desc=desc, leave=False), start=1):
            inputs = move_to_device(inputs, device)
            labels = move_to_device(labels, device)

            if train:
                optimizer.zero_grad()
            outputs = model(inputs)
            loss = criterion(outputs, labels)
            if train:
                loss.backward()
                optimizer.step()

            # Weight by the real batch size and count samples seen: the final batch
            # can be smaller than the loader's batch_size, so (batches * current size)
            # is not the number of samples processed.
            batch_size = labels.size(0)
            total_loss += loss.item() * batch_size
            total_correct += (outputs.argmax(dim=1) == labels).sum().item()
            total_seen += batch_size

            if train and (batch_idx % log_every == 0 or batch_idx == n_batches):
                tqdm.write(f"[Batch {batch_idx}/{n_batches}] "
                           f"Train Loss: {total_loss / total_seen:.4f}, "
                           f"Train Acc: {total_correct / total_seen:.4f}")

    if total_seen == 0:
        raise ValueError(f"{desc}: dataloader yielded no samples")
    return total_loss / total_seen, total_correct / total_seen


print(f"Starting training for {n_epochs} epochs on device: {device}\n{'='*60}")

for epoch in range(1, n_epochs + 1):
    epoch_start = time.perf_counter()  # monotonic clock for elapsed time
    print(f"\nEpoch {epoch}/{n_epochs}")
    print("-" * 60)

    epoch_train_loss, epoch_train_acc = _run_epoch(dataloaders['train'], train=True, desc="Training")
    # The dataloaders dict is keyed 'train'/'val'; there is no 'valid' entry.
    epoch_valid_loss, epoch_valid_acc = _run_epoch(dataloaders['val'], train=False, desc="Validating")
    epoch_time = time.perf_counter() - epoch_start

    print(f"Epoch {epoch} Summary:")
    print(f"  Train Loss: {epoch_train_loss:.4f} | Train Acc: {epoch_train_acc:.4f}")
    print(f"  Valid Loss: {epoch_valid_loss:.4f} | Valid Acc: {epoch_valid_acc:.4f}")
    print(f"  Time: {epoch_time:.2f} sec")
    print("=" * 60)


Starting training for 10 epochs on device: cpu

Epoch 1/10
------------------------------------------------------------


                                                        

KeyboardInterrupt: 