# IMPACT Example Usage

This notebook demonstrates how to use the IMPACT framework for analyzing resting-state fMRI data in Parkinson's Disease.

In [None]:
import os
import sys
import numpy as np
import torch
from pathlib import Path

# Add IMPACT to path
sys.path.append('..')

from impact.data.preprocessor import FMRIPreprocessor
from impact.data.loader import IMPACTDataLoader
from impact.models.impact import IMPACTModel

## 1. Download and Prepare Data

First, we need to download the PPMI dataset from NITRC. You can find the dataset at:
https://fcon_1000.projects.nitrc.org/indi/retro/parkinsons.html

After downloading, organize your data directory as follows:
```
data/
├── raw/
│   ├── sub-001/
│   │   └── func/
│   │       └── sub-001_task-rest_bold.nii.gz
│   ├── sub-002/
│   └── ...
└── metadata.json
```

The `metadata.json` file should contain subject information, including diagnosis labels.

In [None]:
# Data locations, relative to the notebook's working directory:
# raw/ holds the downloaded subject folders, processed/ receives derived arrays.
data_dir = Path('..', 'data')
raw_dir = data_dir.joinpath('raw')
processed_dir = data_dir.joinpath('processed')

# Preprocessor configuration.
# NOTE(review): window_size / window_stride presumably control the sliding
# windows used for the dynamic connectivity matrices — confirm against
# FMRIPreprocessor.
preprocessor = FMRIPreprocessor(
    window_stride=25,
    window_size=50,
    n_ica_components=5,
    atlas='harvard-oxford',
)

## 2. Preprocess Data

Now we'll preprocess the raw fMRI data to extract ROI time series, ICA components, and dynamic connectivity matrices.

In [None]:
# Process each subject.
# Iterate in sorted order: glob order is filesystem-dependent, and a stable
# order makes the progress log comparable across runs.
for subject_dir in sorted(raw_dir.glob('sub-*')):
    # Ignore stray files that happen to match the 'sub-*' pattern
    if not subject_dir.is_dir():
        continue

    # Find fMRI file. Take the lexicographically first match so the choice is
    # deterministic, and fall back to None instead of letting next() raise a
    # bare StopIteration that would abort the whole run.
    fmri_file = next(iter(sorted(subject_dir.rglob('*_bold.nii.gz'))), None)
    if fmri_file is None:
        print(f'Skipping {subject_dir.name}: no *_bold.nii.gz file found')
        continue

    print(f'Processing {subject_dir.name}...')

    # Process subject; the result is indexed below by the three keys
    # 'roi_timeseries', 'ica_timeseries' and 'connectivity_matrices'.
    results = preprocessor.process_subject(fmri_file)

    # Save results under processed/<subject name>/
    output_dir = processed_dir / subject_dir.name
    output_dir.mkdir(parents=True, exist_ok=True)

    np.save(output_dir / 'roi_timeseries.npy', results['roi_timeseries'])
    np.save(output_dir / 'ica_timeseries.npy', results['ica_timeseries'])
    np.save(output_dir / 'connectivity_matrices.npy', results['connectivity_matrices'])

## 3. Train Model

Now we can train the IMPACT model on our preprocessed data.

In [None]:
# Seed NumPy and PyTorch so that model initialisation is repeatable between runs.
# NOTE(review): whether the train/val/test split and batch shuffling are also
# covered depends on which RNG IMPACTDataLoader uses — confirm.
SEED = 42
np.random.seed(SEED)
torch.manual_seed(SEED)

# Set up data loader
data_loader = IMPACTDataLoader(
    data_dir=processed_dir,
    batch_size=8,
    num_workers=4
)

train_loader, val_loader, test_loader = data_loader.get_dataloaders()

# Fail fast: the per-epoch average below divides by the number of validation
# batches, so an empty validation loader would otherwise only surface as a
# ZeroDivisionError after a full training epoch has already run.
if len(val_loader) == 0:
    raise ValueError('Validation loader is empty; cannot compute validation loss')

# Create model.
# The input widths are taken from the last dimension of the training set's
# roi_data / ica_data tensors.
model = IMPACTModel(
    roi_dim=train_loader.dataset.roi_data.size(-1),
    ica_dim=train_loader.dataset.ica_data.size(-1),
    embed_dim=256,
    n_heads=4,
    n_layers=3
)

# Set up training
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = model.to(device)
criterion = torch.nn.CrossEntropyLoss()
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-4)

# Training loop
n_epochs = 100
best_val_loss = float('inf')

for epoch in range(n_epochs):
    # Train
    model.train()
    for batch in train_loader:
        # Each batch is indexed as (dict of input tensors, labels)
        inputs = {k: v.to(device) for k, v in batch[0].items()}
        labels = batch[1].to(device)

        optimizer.zero_grad()
        # The model returns a pair; only the logits are needed for the loss
        logits, _ = model(inputs)
        loss = criterion(logits, labels)
        loss.backward()
        optimizer.step()

    # Validate
    model.eval()
    val_loss = 0.0
    with torch.no_grad():
        for batch in val_loader:
            inputs = {k: v.to(device) for k, v in batch[0].items()}
            labels = batch[1].to(device)

            logits, _ = model(inputs)
            val_loss += criterion(logits, labels).item()

    # Mean of the per-batch losses
    val_loss /= len(val_loader)
    print(f'Epoch {epoch+1}, Validation Loss: {val_loss:.4f}')

    # Save best model (lowest validation loss seen so far)
    if val_loss < best_val_loss:
        best_val_loss = val_loss
        torch.save(model.state_dict(), 'best_model.pt')

## 4. Evaluate Model

Finally, let's evaluate our trained model and visualize the results.

In [None]:
from impact.evaluate import evaluate_model

# Load best model.
# map_location remaps the saved tensors onto the current device, so a
# checkpoint written from a CUDA model still loads on a CPU-only kernel.
model.load_state_dict(torch.load('best_model.pt', map_location=device))

# Create output directory
output_dir = Path('evaluation_results')
output_dir.mkdir(exist_ok=True)

# Evaluate model on the held-out test loader
metrics, attention_weights = evaluate_model(
    model,
    test_loader,
    device,
    output_dir
)

# NOTE(review): the ':.4f' format assumes every metric value is a scalar
# number — confirm against evaluate_model.
print('\nTest Metrics:')
for name, value in metrics.items():
    print(f'{name}: {value:.4f}')