## Setup and Imports

In [12]:
import sys
import os

# Add the parent of the working directory to the import path.
# `__file__` is undefined in a notebook, so resolve from the current directory instead.
sys.path.insert(0, os.path.dirname(os.getcwd()))

import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader

# Import pipeline modules
from data.preprocessing import preprocess_scan
from data.dataset import SwiFTPretrainDataset, SwiFTFinetuneDataset, create_dummy_labels
from models.swin4d_transformer_ver7 import SwinTransformer4D
from models.heads import ContrastiveHead, ClassificationHead
from training.losses import NTXentLoss
from configs.config_pretrain import (
    MODEL_CONFIG,
    CONTRASTIVE_CONFIG,
    TRAIN_CONFIG,
    DATA_CONFIG,
)

print("✓ All imports successful!")
print(f"PyTorch version: {torch.__version__}")
print(f"CUDA available: {torch.cuda.is_available()}")
if torch.cuda.is_available():
    print(f"CUDA device: {torch.cuda.get_device_name(0)}")

✓ All imports successful!
PyTorch version: 2.0.1+cpu
CUDA available: False


## TEST 1: Data Preprocessing

Test the data preprocessing pipeline with a dummy scan.

In [13]:
print("=" * 80)
print("TEST 1: Data Preprocessing")
print("=" * 80)

TEST 1: Data Preprocessing


In [14]:
# Create dummy scan [1, 91, 109, 91, 140]
print("Creating dummy scan with shape [1, 91, 109, 91, 140]...")
dummy_scan = torch.randn(1, 91, 109, 91, 140)

# Add some "brain-like" structure (non-zero regions in center)
# This simulates brain regions with higher values than background
dummy_scan[:, 10:85, 15:95, 10:85, :] += 5.0

print(f"✓ Created dummy scan: {dummy_scan.shape}")
print(f"  - Simulated brain region: H[10:85], W[15:95], D[10:85]")

Creating dummy scan with shape [1, 91, 109, 91, 140]...
✓ Created dummy scan: torch.Size([1, 91, 109, 91, 140])
  - Simulated brain region: H[10:85], W[15:95], D[10:85]


In [15]:
# Preprocess (NOW returns 3 values: data, indices, metadata)
print("Running preprocessing pipeline...")
preprocessed, indices, metadata = preprocess_scan(
    dummy_scan,
    target_spatial_size=DATA_CONFIG["target_spatial_size"],
    window_size=DATA_CONFIG["window_size"],
    stride=DATA_CONFIG["stride"],
    normalize=DATA_CONFIG["normalize"],
    to_float16=False,  # Keep float32 for validation
    crop_background=True,  # Use brain-aware cropping
)

print(f"\n✓ Preprocessing complete!")
print(f"  - Output shape: {preprocessed.shape}")
print(f"  - Number of windows: {len(indices)}")
print(f"  - Background value: {metadata['background_value']:.6f}")
print(f"  - Brain bbox: {metadata['brain_bbox']}")
print(f"  - Current spatial size: {metadata['current_spatial_size']}")
print(f"  - Needs padding: {metadata['needs_padding']}")
print(f"  - Data range: [{preprocessed.min():.3f}, {preprocessed.max():.3f}]")
print(f"  - Data stats: mean={preprocessed.mean():.3f}, std={preprocessed.std():.3f}")

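num_frames = dummy_scan.shape[-1]
window_size, stride = DATA_CONFIG['window_size'], DATA_CONFIG['stride']
num_expected = (num_frames - window_size) // stride + 1

print(f"\nℹ️  Temporal windowing explanation:")
print(f"  - Original time points: {num_frames}")
print(f"  - Window size: {window_size}")
print(f"  - Stride: {stride}")
print(f"  - Windows overlap: {'YES' if stride < window_size else 'NO'} (overlap requires stride < window_size)")
print(f"  - Number of windows: ({num_frames} - {window_size}) // {stride} + 1 = {num_expected} (actual: {len(indices)})")
print(f"  - Formula: (T - window_size) // stride + 1")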

Running preprocessing pipeline...
Input shape: torch.Size([1, 91, 109, 91, 140])
After adding channel: torch.Size([1, 1, 91, 109, 91, 140])
Detected brain regions: 91x109x91
Background value: 0.000001
  Cropped to brain regions: H[0:91]=91, W[0:109]=109, D[0:91]=91
After brain-aware cropping: torch.Size([91, 109, 91])
After normalization - background value updated to: -2.948255
After windowing: torch.Size([13, 1, 91, 109, 91, 20]), 13 windows

✓ Preprocessing complete!
  - Output shape: torch.Size([13, 1, 91, 109, 91, 20])
  - Number of windows: 13
  - Background value: -2.948255
  - Brain bbox: {'height': (0, 91), 'width': (0, 109), 'depth': (0, 91)}
  - Current spatial size: torch.Size([91, 109, 91])
  - Needs padding: (5, -13, 5)

In [16]:
# Verify output shape
# Note: Spatial dimensions may NOT be exactly 96x96x96 yet (padding happens in Dataset)
assert preprocessed.shape[1] == 1, f"Expected 1 channel, got {preprocessed.shape[1]}"
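expected_frames = DATA_CONFIG["window_size"]  # Avoid hard-coding the window length
assert preprocessed.shape[-1] == expected_frames, f"Expected {expected_frames} time points, got {preprocessed.shape[-1]}"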
assert len(indices) > 0, "No windows created!"

print("✓ Basic shape assertions passed!")
print(f"  Note: Spatial dimensions are {preprocessed.shape[2:5]}")
print(f"  Final padding to 96x96x96 will happen in Dataset class")

✓ Basic shape assertions passed!
  Note: Spatial dimensions are torch.Size([91, 109, 91])
  Final padding to 96x96x96 will happen in Dataset class
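
The `Needs padding: (5, -13, 5)` entry printed by the preprocessing step is consistent with a per-axis difference between the target size and the current size. Under that reading, a negative value means the axis is already larger than the target and would have to be cropped rather than padded. The sketch below reproduces the numbers under that assumption. The helper names `padding_needed` and `split_padding` are hypothetical and not part of `data.preprocessing`, and the symmetric split is also an assumption, since the Dataset class may place the padding differently.

```python
def padding_needed(current_size, target_size):
    """Return target - current per axis; a negative value means 'crop'."""
    if len(current_size) != len(target_size):
        raise ValueError("size tuples must have the same length")
    return tuple(t - c for c, t in zip(current_size, target_size))


def split_padding(amount):
    """Split a non-negative padding amount into (before, after) parts."""
    if amount < 0:
        raise ValueError("negative amount: this axis must be cropped, not padded")
    before = amount // 2
    return before, amount - before


needs = padding_needed((91, 109, 91), (96, 96, 96))
print(needs)                    # (5, -13, 5)
print(split_padding(needs[0]))  # (2, 3)
```

The width axis (109 voxels) exceeds the 96-voxel target, so padding alone cannot bring this particular dummy scan to 96x96x96.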


### Understanding Temporal Windows

**Why overlapping windows?**

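This pipeline uses overlapping temporal windows to:
1. **Increase training samples**: Each scan yields more windows than a non-overlapping split would (13 instead of 7 for the settings below)
2. **Capture temporal dynamics**: Consecutive windows share frames, so a transition that straddles one window boundary still appears intact in a neighboring window
3. **Follow common practice**: Sliding windows with overlap are widely used for video and time-series models

**Example with the dummy scan above:**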
- Original: 140 time points
- Window size: 20 time points
- Stride: 10 time points

Windows created:
- Window 0: frames [0-19]
- Window 1: frames [10-29] ← overlaps with Window 0 by 10 frames
- Window 2: frames [20-39] ← overlaps with Window 1 by 10 frames
- ...
- Window 12: frames [120-139]

Total: (140 - 20) / 10 + 1 = 13 windows ✓
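
The window list above can be generated with a few lines of plain Python. This is a minimal sketch of the arithmetic only, and `window_bounds` is a hypothetical helper, not the function used inside `preprocess_scan`. It uses floor division, so trailing frames that do not fill a whole window are dropped.

```python
def window_bounds(num_frames, window_size, stride):
    """Return inclusive (first_frame, last_frame) pairs for each sliding window."""
    if window_size <= 0 or stride <= 0:
        raise ValueError("window_size and stride must be positive")
    if num_frames < window_size:
        return []  # Not enough frames for even one window
    count = (num_frames - window_size) // stride + 1
    return [(i * stride, i * stride + window_size - 1) for i in range(count)]


bounds = window_bounds(140, 20, 10)
print(len(bounds))   # 13
print(bounds[0])     # (0, 19)
print(bounds[1])     # (10, 29)
print(bounds[-1])    # (120, 139)
```

With 145 frames instead of 140, the count stays at 13 and the last 5 frames are unused.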

## TEST 2: Model Forward Pass

Test the SwiFT model with a forward pass.

In [17]:
print("=" * 80)
print("TEST 2: Model Forward Pass")
print("=" * 80)

TEST 2: Model Forward Pass


In [18]:
# Create model
print("Initializing SwiFT model...")
model = SwinTransformer4D(**MODEL_CONFIG)
print(f"✓ Model created")

# Print model architecture summary
total_params = sum(p.numel() for p in model.parameters())
trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
print(f"  - Total parameters: {total_params:,}")
print(f"  - Trainable parameters: {trainable_params:,}")

Initializing SwiFT model...
img_size:  (96, 96, 96, 20)
patch_size:  (6, 6, 6, 1)
patch_dim:  (16, 16, 16, 20)
✓ Model created
  - Total parameters: 4,315,212
  - Trainable parameters: 4,315,212


In [19]:
# Create dummy input
batch_size = 2
dummy_input = torch.randn(batch_size, 1, 96, 96, 96, 20)
print(f"Testing forward pass with input shape: {dummy_input.shape}")

Testing forward pass with input shape: torch.Size([2, 1, 96, 96, 96, 20])


In [20]:
# Forward pass
model.eval()
with torch.no_grad():
    output = model(dummy_input)

print(f"✓ Forward pass successful!")
print(f"  - Output shape: {output.shape}")
expected_dim = MODEL_CONFIG['embed_dim'] * (MODEL_CONFIG['c_multiplier'] ** 3)
print(f"  - Expected feature dim: {expected_dim}")

✓ Forward pass successful!
  - Output shape: torch.Size([2, 288, 2, 2, 2, 20])
  - Expected feature dim: 288
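
The printed output shape `[2, 288, 2, 2, 2, 20]` can be reproduced with simple arithmetic. The sketch below assumes `embed_dim=36` and `c_multiplier=2` (inferred from 288 = 36 × 2³, not read from `MODEL_CONFIG`), three 2x spatial downsampling steps (inferred from the 16 → 2 reduction per spatial axis), and an untouched temporal axis. `swift_output_shape` is a hypothetical helper, not part of `models.swin4d_transformer_ver7`.

```python
def swift_output_shape(batch, img_size, patch_size, embed_dim, c_multiplier, num_merges=3):
    """Estimate the encoder output shape from the input and patch sizes."""
    # Tokens per axis after patch embedding: (96, 96, 96, 20) / (6, 6, 6, 1)
    grid = [i // p for i, p in zip(img_size, patch_size)]
    # Each merge step is assumed to halve the three spatial axes only
    spatial = [g // (2 ** num_merges) for g in grid[:3]]
    temporal = grid[3]
    # Channel width is assumed to grow by c_multiplier at every merge step
    channels = embed_dim * c_multiplier ** num_merges
    return (batch, channels, *spatial, temporal)


shape = swift_output_shape(
    batch=2,
    img_size=(96, 96, 96, 20),
    patch_size=(6, 6, 6, 1),
    embed_dim=36,       # assumed value
    c_multiplier=2,     # assumed value
)
print(shape)  # (2, 288, 2, 2, 2, 20)
```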


## TEST 3: Contrastive Pretraining

Pretrain the model using contrastive learning (self-supervised, no labels needed).

In [21]:
print("=" * 80)
print("TEST 3: Contrastive Pretraining")
print("=" * 80)

print("Stage 1: Pretrain backbone with contrastive learning (unlabeled data)")
print("=" * 80)

TEST 3: Contrastive Pretraining
Stage 1: Pretrain backbone with contrastive learning (unlabeled data)


In [22]:
# Create dummy preprocessed data (simulating output after brain-aware cropping)
num_windows = 12
# Note: Using slightly smaller spatial dims to simulate brain-cropped data
data = torch.randn(num_windows, 1, 90, 92, 88, 20)
indices = torch.tensor([[0, i * 10] for i in range(num_windows)])

# Create metadata (simulating preprocessing output)
metadata = {
    'background_value': -2.5,
    'brain_bbox': {'height': (5, 85), 'width': (8, 100), 'depth': (5, 83)},
    'needs_padding': (6, 4, 8),
    'current_spatial_size': (90, 92, 88)
}

print(f"Created dummy data: {data.shape}")
print(f"Created indices: {indices.shape}")
print(f"Background value for padding: {metadata['background_value']}")

Created dummy data: torch.Size([12, 1, 90, 92, 88, 20])
Created indices: torch.Size([12, 2])
Background value for padding: -2.5


In [23]:
# Create contrastive dataset (NOW with metadata for padding)
print("Creating contrastive dataset...")
dataset = SwiFTPretrainDataset(
    data, 
    indices,
    metadata=metadata,  # Pass metadata for background-value padding
    target_spatial_size=(96, 96, 96)
)
dataloader = DataLoader(
    dataset, batch_size=TRAIN_CONFIG["batch_size"], shuffle=True
)
print(f"✓ Dataset created: {len(dataset)} samples")
print(f"  - Batch size: {TRAIN_CONFIG['batch_size']}")
print(f"  - Padding will be applied during data loading")

Creating contrastive dataset...
  Dataset will pad from torch.Size([90, 92, 88]) to (96, 96, 96)
  Using background value: -2.500000
Contrastive dataset: 12 valid samples from 1 scans
✓ Dataset created: 12 samples
  - Batch size: 4
  - Padding will be applied during data loading


In [24]:
# Create model and contrastive head
print("Initializing model and contrastive head...")
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
contrastive_model = SwinTransformer4D(**MODEL_CONFIG).to(device)
contrastive_head = ContrastiveHead(**CONTRASTIVE_CONFIG).to(device)
print(f"✓ Models initialized on {device}")

Initializing model and contrastive head...
img_size:  (96, 96, 96, 20)
patch_size:  (6, 6, 6, 1)
patch_dim:  (16, 16, 16, 20)
✓ Models initialized on cpu


In [25]:
# Create loss and optimizer
criterion = NTXentLoss(
    device=device,
    batch_size=TRAIN_CONFIG["batch_size"],
    temperature=TRAIN_CONFIG["temperature"],
    use_cosine_similarity=TRAIN_CONFIG["use_cosine_similarity"],
)
optimizer = optim.AdamW(
    list(contrastive_model.parameters()) + list(contrastive_head.parameters()),
    lr=TRAIN_CONFIG["learning_rate"],
)
print(f"✓ Loss and optimizer created")
print(f"  - Learning rate: {TRAIN_CONFIG['learning_rate']}")
print(f"  - Temperature: {TRAIN_CONFIG['temperature']}")

✓ Loss and optimizer created
  - Learning rate: 5e-05
  - Temperature: 0.5
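
To make the loss concrete, here is a plain-Python sketch of the NT-Xent (normalized temperature-scaled cross-entropy) objective with cosine similarity. It is written for readability, not speed, and the names `cosine` and `nt_xent` are hypothetical. `NTXentLoss` in `training.losses` may differ in details such as normalization or reduction.

```python
import math


def cosine(a, b):
    """Cosine similarity between two equal-length vectors."""
    norm_a = math.sqrt(sum(x * x for x in a))
    norm_b = math.sqrt(sum(y * y for y in b))
    if norm_a == 0.0 or norm_b == 0.0:
        raise ValueError("cosine similarity is undefined for zero vectors")
    return sum(x * y for x, y in zip(a, b)) / (norm_a * norm_b)


def nt_xent(z1, z2, temperature=0.5):
    """Mean NT-Xent loss over the 2N embeddings of N positive pairs."""
    n = len(z1)
    z = list(z1) + list(z2)  # z[i] and z[i + n] are two views of the same sample
    total = 0.0
    for i in range(2 * n):
        positive = (i + n) % (2 * n)
        # Similarities to every other embedding (self-similarity excluded)
        logits = [cosine(z[i], z[k]) / temperature for k in range(2 * n) if k != i]
        pos_logit = cosine(z[i], z[positive]) / temperature
        # Log-sum-exp with the maximum subtracted for numerical stability
        m = max(logits)
        log_denominator = m + math.log(sum(math.exp(l - m) for l in logits))
        total += log_denominator - pos_logit
    return total / (2 * n)


# Four pairs whose embeddings are all identical: nothing can be told apart
collapsed = [[1.0, 0.0]] * 4
print(round(nt_xent(collapsed, collapsed), 4))  # 1.9459
```

When all embeddings are indistinguishable, the loss equals ln(2N − 1), which is ln 7 ≈ 1.946 for a batch size of 4. The losses printed by the pretraining loop (about 1.95 and 1.92) sit close to that value, which may indicate that the embeddings are barely separated after the first two epochs.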


In [27]:
# Run pretraining iterations (simulate multiple epochs)
print("Running pretraining (5 epochs)...")
print("Note: In real training, you'd use many more epochs (e.g., 300 in SwiFT paper)")
contrastive_model.train()
contrastive_head.train()

for epoch in range(5):
    epoch_loss = 0
    num_batches = 0
    
    for batch_idx, (view1, view2) in enumerate(dataloader):
        view1, view2 = view1.to(device), view2.to(device)
        
        # Verify shapes after padding (only first batch)
        if batch_idx == 0 and epoch == 0:
            assert view1.shape[2:5] == (96, 96, 96), f"Expected 96x96x96, got {view1.shape[2:5]}"
            assert view2.shape[2:5] == (96, 96, 96), f"Expected 96x96x96, got {view2.shape[2:5]}"
            print(f"  ✓ Shapes verified: both views are 96x96x96 after padding")

        # Forward pass
        optimizer.zero_grad()

        # Encode both views
        features1 = contrastive_model(view1)
        features2 = contrastive_model(view2)

        # Project to embedding space
        embeddings1 = contrastive_head(features1)
        embeddings2 = contrastive_head(features2)

        # Compute contrastive loss
        loss = criterion(embeddings1, embeddings2)

        # Backward pass
        loss.backward()
        optimizer.step()

        epoch_loss += loss.item()
        num_batches += 1
        
        if num_batches >= 2:  # Limit batches for demo
            break
    
    avg_loss = epoch_loss / num_batches
    print(f"  Epoch {epoch+1}/5: Loss = {avg_loss:.4f}")

print(f"\n✓ Pretraining complete!")
print("Backbone has learned representations from unlabeled data using contrastive learning")

Running pretraining (5 epochs)...
Note: In real training, you'd use many more epochs (e.g., 300 in SwiFT paper)
  ✓ Shapes verified: both views are 96x96x96 after padding
  Epoch 1/5: Loss = 1.9452
  Epoch 2/5: Loss = 1.9170


KeyboardInterrupt: 

In [None]:
# Save pretrained backbone weights
print("\nSaving pretrained weights...")
pretrained_path = "../checkpoints/pretrained_backbone.pth"
os.makedirs(os.path.dirname(pretrained_path), exist_ok=True)

torch.save({
    'model_state_dict': contrastive_model.state_dict(),
    'config': MODEL_CONFIG,
    'training_info': {
        'method': 'contrastive_pretraining',
        'final_loss': loss.item(),
    }
}, pretrained_path)

print(f"✓ Saved pretrained weights to: {pretrained_path}")
print("Note: Only the backbone is saved, not the contrastive projection head")

## TEST 4: Fine-tuning for Classification

Use the pretrained backbone to fine-tune on a supervised classification task.

In [None]:
print("=" * 80)
print("TEST 4: Fine-tuning for Classification")
print("=" * 80)

print("Stage 2: Fine-tune on supervised task using pretrained features")
print("=" * 80)

In [None]:
# Import finetuning config
from configs.config_finetune import (
    TASK_CONFIG,
    HEAD_CONFIG,
    TRAIN_CONFIG as FINETUNE_TRAIN_CONFIG,
    OPTIMIZER_CONFIG,
)

print("Fine-tuning configuration:")
print(f"  - Task type: {TASK_CONFIG['task_type']}")
print(f"  - Number of classes: {TASK_CONFIG['num_classes']}")
print(f"  - Freeze encoder: {TASK_CONFIG['freeze_encoder']}")
print(f"  - Learning rate: {FINETUNE_TRAIN_CONFIG['learning_rate']}")
print(f"  - Batch size: {FINETUNE_TRAIN_CONFIG['batch_size']}")
print(f"  - Epochs: {FINETUNE_TRAIN_CONFIG['num_epochs']}")

In [None]:
# Load the pretrained backbone
print("Loading pretrained backbone weights...")
finetuned_model = SwinTransformer4D(**MODEL_CONFIG).to(device)

# Load weights from saved checkpoint
checkpoint = torch.load(pretrained_path, map_location=device)
finetuned_model.load_state_dict(checkpoint['model_state_dict'])
print(f"✓ Loaded pretrained weights from: {pretrained_path}")

# Freeze encoder weights if specified
if TASK_CONFIG['freeze_encoder']:
    print("Freezing encoder weights...")
    for param in finetuned_model.parameters():
        param.requires_grad = False
    
    trainable = sum(p.numel() for p in finetuned_model.parameters() if p.requires_grad)
    total = sum(p.numel() for p in finetuned_model.parameters())
    print(f"  - Trainable parameters: {trainable:,} / {total:,}")
else:
    print("Encoder weights are trainable (not frozen)")
    trainable = sum(p.numel() for p in finetuned_model.parameters() if p.requires_grad)
    print(f"  - Trainable parameters: {trainable:,}")

In [None]:
# Create classification head
print("Creating classification head...")
classification_head = ClassificationHead(
    num_classes=HEAD_CONFIG['num_classes'],
    num_features=HEAD_CONFIG['num_features']
).to(device)

head_params = sum(p.numel() for p in classification_head.parameters())
print(f"✓ Classification head created")
print(f"  - Head parameters: {head_params:,}")
print(f"  - Input features: {HEAD_CONFIG['num_features']}")
print(f"  - Output classes: {HEAD_CONFIG['num_classes']}")

In [None]:
# Create finetuning dataset with labels
print("Creating finetuning dataset with labels...")

# Create dummy binary classification labels
num_samples = len(data)
labels = create_dummy_labels(num_samples, task_type='binary')
print(f"  - Created {num_samples} labels")
print(f"  - Label distribution: {labels.sum().item()} positive, {(labels == 0).sum().item()} negative")

# Split into train/val (80/20)
train_size = int(0.8 * num_samples)
train_data = data[:train_size]
train_labels = labels[:train_size]
val_data = data[train_size:]
val_labels = labels[train_size:]

print(f"  - Train samples: {len(train_data)}")
print(f"  - Val samples: {len(val_data)}")

# Create datasets (use same metadata for padding)
train_dataset = SwiFTFinetuneDataset(
    train_data, 
    train_labels,
    metadata=metadata,
    target_spatial_size=(96, 96, 96)
)

val_dataset = SwiFTFinetuneDataset(
    val_data, 
    val_labels,
    metadata=metadata,
    target_spatial_size=(96, 96, 96)
)

# Create dataloaders
train_loader = DataLoader(
    train_dataset, 
    batch_size=FINETUNE_TRAIN_CONFIG['batch_size'], 
    shuffle=True
)
val_loader = DataLoader(
    val_dataset, 
    batch_size=FINETUNE_TRAIN_CONFIG['batch_size'], 
    shuffle=False
)

print(f"✓ Datasets and dataloaders created")
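
One caveat about the sequential 80/20 split: the windows come from a single scan with a stride smaller than the window size, so the last training window and the first validation window share frames. The sketch below quantifies that overlap and shows one possible mitigation, which is to skip enough windows at the boundary that the two sets share no frames. The helpers `frames` and `split_with_gap` are hypothetical, and the window size of 20 and stride of 10 are taken from the earlier temporal-window example. When several scans are available, splitting by scan or subject is another option.

```python
def frames(window_index, window_size=20, stride=10):
    """Set of frame indices covered by a given window."""
    start = window_index * stride
    return set(range(start, start + window_size))


def split_with_gap(num_windows, train_fraction=0.8, window_size=20, stride=10):
    """Sequential split that drops boundary windows sharing frames with training."""
    train_size = int(train_fraction * num_windows)
    # ceil(window_size / stride) - 1 windows overlap the last training window
    gap = -(-window_size // stride) - 1
    train = list(range(train_size))
    val = list(range(train_size + gap, num_windows))
    return train, val


# Naive split of 12 windows: window 8 is the last train item, window 9 the first val item
print(len(frames(8) & frames(9)))  # 10 shared frames

train_idx, val_idx = split_with_gap(12)
print(train_idx[-1], val_idx)      # 8 [10, 11]
print(len(frames(train_idx[-1]) & frames(val_idx[0])))  # 0
```

The cost is a smaller validation set (2 windows instead of 3 here), which matters little for a smoke test but should be weighed on real data.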

In [None]:
# Setup optimizer and loss function
print("Setting up optimizer and loss function...")

# Optimize only the classification head when the encoder is frozen; otherwise optimize both
if TASK_CONFIG['freeze_encoder']:
    params_to_optimize = classification_head.parameters()
else:
    params_to_optimize = list(finetuned_model.parameters()) + list(classification_head.parameters())

finetune_optimizer = optim.AdamW(
    params_to_optimize,
    lr=OPTIMIZER_CONFIG['lr'],
    weight_decay=OPTIMIZER_CONFIG['weight_decay'],
    betas=OPTIMIZER_CONFIG['betas']
)

# Binary classification loss
finetune_criterion = nn.BCEWithLogitsLoss()

print(f"✓ Optimizer and loss function created")
print(f"  - Optimizer: AdamW")
print(f"  - Learning rate: {OPTIMIZER_CONFIG['lr']}")
print(f"  - Loss function: BCEWithLogitsLoss")

In [None]:
# Training loop for fine-tuning
print("Running fine-tuning training (10 epochs for demo)...")
print("Note: In real training, you'd use more epochs (e.g., 50-100)")

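num_epochs = 10
best_val_loss = float('inf')
best_epoch = 0  # Stays 0 if the validation loss never improves (e.g., it is NaN)

# The encoder stays in eval mode throughout, which suits a frozen encoder.
# If the encoder is being fine-tuned, call finetuned_model.train() for the training phase instead.
finetuned_model.eval()
classification_head.train()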

for epoch in range(num_epochs):
    # Training phase
    train_loss = 0.0
    train_correct = 0
    train_total = 0
    
    for batch_idx, (inputs, targets) in enumerate(train_loader):
        inputs = inputs.to(device)
        targets = targets.float().to(device)
        
        # Forward pass
        finetune_optimizer.zero_grad()
        
        # Get features from pretrained encoder (no gradients if frozen)
        with torch.set_grad_enabled(not TASK_CONFIG['freeze_encoder']):
            features = finetuned_model(inputs)
        
        # Get predictions from classification head
        outputs = classification_head(features)
        outputs = outputs.reshape(-1)  # Flatten to [B]; squeeze() would yield a 0-dim tensor when B == 1
        
        # Compute loss
        loss = finetune_criterion(outputs, targets)
        
        # Backward pass (only for classification head if encoder is frozen)
        loss.backward()
        finetune_optimizer.step()
        
        # Track metrics
        train_loss += loss.item()
        predictions = (torch.sigmoid(outputs) > 0.5).float()
        train_correct += (predictions == targets).sum().item()
        train_total += targets.size(0)
    
    # Validation phase
    classification_head.eval()
    val_loss = 0.0
    val_correct = 0
    val_total = 0
    
    with torch.no_grad():
        for inputs, targets in val_loader:
            inputs = inputs.to(device)
            targets = targets.float().to(device)
            
            # Forward pass
            features = finetuned_model(inputs)
            outputs = classification_head(features)
            outputs = outputs.reshape(-1)  # Flatten to [B]; squeeze() would yield a 0-dim tensor when B == 1
            
            # Compute loss
            loss = finetune_criterion(outputs, targets)
            
            # Track metrics
            val_loss += loss.item()
            predictions = (torch.sigmoid(outputs) > 0.5).float()
            val_correct += (predictions == targets).sum().item()
            val_total += targets.size(0)
    
    # Calculate epoch metrics
    train_loss /= len(train_loader)
    train_acc = 100 * train_correct / train_total
    val_loss /= len(val_loader)
    val_acc = 100 * val_correct / val_total
    
    print(f"  Epoch {epoch+1}/{num_epochs}: "
          f"Train Loss={train_loss:.4f}, Train Acc={train_acc:.2f}%, "
          f"Val Loss={val_loss:.4f}, Val Acc={val_acc:.2f}%")
    
    # Track the best epoch (no checkpoint is written inside the loop)
    if val_loss < best_val_loss:
        best_val_loss = val_loss
        best_epoch = epoch + 1
    
    # Set back to train mode for next epoch
    classification_head.train()

print(f"\n✓ Fine-tuning complete!")
print(f"  Best validation loss: {best_val_loss:.4f} at epoch {best_epoch}")
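
The loss and accuracy bookkeeping in the loop can be checked by hand on a few numbers. The sketch below uses plain Python and the numerically stable form of binary cross-entropy on logits, `max(x, 0) - x*y + log(1 + exp(-|x|))`, averaged over the batch as `nn.BCEWithLogitsLoss` does with its default mean reduction. The helper names are hypothetical and the logits are made-up values for illustration.

```python
import math


def bce_with_logits(logit, target):
    """Binary cross-entropy on a raw logit, in a numerically stable form."""
    return max(logit, 0.0) - logit * target + math.log1p(math.exp(-abs(logit)))


def batch_metrics(logits, targets):
    """Return (mean loss, accuracy) for one batch of logits and 0/1 targets."""
    losses = [bce_with_logits(x, y) for x, y in zip(logits, targets)]
    # sigmoid(x) > 0.5 is equivalent to x > 0, so no sigmoid is needed here
    predictions = [1.0 if x > 0 else 0.0 for x in logits]
    correct = sum(p == y for p, y in zip(predictions, targets))
    return sum(losses) / len(losses), correct / len(targets)


loss, acc = batch_metrics([2.0, -1.0, 0.3], [1.0, 0.0, 0.0])
print(f"loss={loss:.4f}, acc={acc:.2%}")
print(round(bce_with_logits(0.0, 1.0), 4))  # 0.6931, i.e. ln 2 for an uninformative logit
```

A validation loss that hovers near ln 2 ≈ 0.693 with roughly 50% accuracy is what an uninformative classifier would produce on balanced binary labels.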

In [None]:
# Save finetuned model
print("\nSaving finetuned model...")
finetuned_checkpoint_path = "../checkpoints/finetuned_model.pth"
os.makedirs(os.path.dirname(finetuned_checkpoint_path), exist_ok=True)

torch.save({
    'encoder_state_dict': finetuned_model.state_dict(),
    'head_state_dict': classification_head.state_dict(),
    'optimizer_state_dict': finetune_optimizer.state_dict(),
    'config': {
        'model_config': MODEL_CONFIG,
        'task_config': TASK_CONFIG,
        'head_config': HEAD_CONFIG,
    },
    'training_info': {
        'best_val_loss': best_val_loss,
        'best_epoch': best_epoch,
        'final_train_loss': train_loss,
        'final_val_loss': val_loss,
    }
}, finetuned_checkpoint_path)

print(f"✓ Saved finetuned model to: {finetuned_checkpoint_path}")
print("  - Saved encoder (backbone) weights")
print("  - Saved classification head weights")
print("  - Saved optimizer state")

## Summary: Complete SwiFT Pipeline

The notebook demonstrates the complete SwiFT workflow:

### Stage 1: Contrastive Pretraining (Self-Supervised)
- **Goal**: Learn robust representations from unlabeled fMRI data
- **Method**: Contrastive learning with temporal augmentation
- **Output**: Pretrained encoder that understands fMRI patterns
- **Key advantage**: No labels needed, can use large amounts of unlabeled data

### Stage 2: Supervised Fine-tuning
- **Goal**: Adapt pretrained encoder to specific downstream task
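- **Method**: Train a lightweight classification head on top of the pretrained encoder, which is frozen when `TASK_CONFIG['freeze_encoder']` is set
- **Output**: Task-specific classifier; with the dummy labels used here, the reported accuracy only verifies that the loop runs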
- **Key advantage**: Requires fewer labeled samples, faster training

### Benefits of this approach:
1. **Data efficiency**: Pretrain on large unlabeled dataset, fine-tune on small labeled dataset
2. **Transfer learning**: Learned features generalize across tasks
3. **Better initialization**: Pretrained weights typically give a better starting point than random initialization
4. **Faster convergence**: Fine-tuning usually converges in fewer epochs than training from scratch

This pretrain-then-fine-tune recipe is widely used in deep learning (e.g., BERT for NLP); this notebook applies it to fMRI with SwiFT.