# Training a CNN–Transformer with Squeeze-and-Excitation Blocks

This notebook implements a dual-input neural network that combines:
- **Convolutional Neural Networks** with Squeeze-and-Excitation (SE) blocks for feature extraction
- **Transformer encoders** for sequential pattern learning
- **Cross-validation** for robust performance evaluation

## Model Architecture
- **Dual Input Processing**: Handles coherence/SCC data and RBP features separately
- **SE Blocks**: Channel attention mechanisms for improved feature learning
- **Transformer Layers**: Self-attention for capturing long-range dependencies
- **Multi-modal Fusion**: Combines features from both input streams

## Key Features
- 10-fold stratified cross-validation
- Early stopping to prevent overfitting
- Comprehensive performance metrics
- Training curve visualization
- Statistical analysis across folds

## Expected Performance
In initial experiments, this model performed well, showing:
- High accuracy on Alzheimer's vs Control classification
- Robust cross-validation results
- Effective fusion of multiple EEG feature types

In [None]:
# Import required libraries
import numpy as np
import pandas as pd
import torch
import os

# Neural network components
import torch.nn as nn
import torch.optim as optim
from torch.nn import TransformerEncoder, TransformerEncoderLayer
import torch.nn.functional as F

# Machine learning utilities
from sklearn.metrics import classification_report

In [None]:
# Load participant information and create train/test splits
from sklearn.model_selection import train_test_split

subjects_info = pd.read_csv('/kaggle/input/open-nuro-dataset/dataset/participants.tsv', delimiter='\t')


def _extract_subject_ids(frame):
    """Return the unique integer subject numbers found in ``frame['participant_id']``.

    IDs are parsed with the pattern ``sub-(\\d+)`` (e.g. ``'sub-007'`` -> ``7``)
    and returned in order of first appearance.
    """
    # expand=False yields a Series directly, so a one-row frame is handled the
    # same way as a many-row frame (no squeeze-to-scalar special case).
    id_strings = frame['participant_id'].str.extract(r'sub-(\d+)', expand=False)
    return id_strings.astype(int).unique().tolist()


# Focus on Alzheimer's (A) and Control (C) groups
groups = ["A", "C"]
train_dfs = []
test_dfs = []
total = []

# Split each group 70/30 on subject level, stratified on the 'Gender' column
for group in groups:
    group_df = subjects_info[subjects_info['Group'] == group]
    train_group, test_group = train_test_split(
        group_df, test_size=0.3, stratify=group_df['Gender'], random_state=42
    )
    total.append(group_df)
    train_dfs.append(train_group)
    test_dfs.append(test_group)

# Combine all splits
train_df = pd.concat(train_dfs)
test_df = pd.concat(test_dfs)
total_df = pd.concat(total)

# Extract subject IDs
training_subjects = _extract_subject_ids(train_df)
testing_subjects = _extract_subject_ids(test_df)
total_subjects = _extract_subject_ids(total_df)

print("Training Subjects:", training_subjects)
print("Testing Subjects:", testing_subjects)
print("Total Subjects:", total_subjects)

In [None]:
def load_data(subjects_info, subjects, data_type='training'):
    """
    Read the per-subject coherence, RBP and SCC ``.npy`` files and stack them.

    For every subject the number of rows in its RBP array is taken as the
    number of epochs; that many copies of the subject's ``Group`` value and of
    the subject ID are appended to the label and group lists.

    Args:
        subjects_info: DataFrame with ``participant_id`` and ``Group`` columns
        subjects: Iterable of subject IDs to load
        data_type: Accepted for the caller's convenience; not used in the body

    Returns:
        Tuple of (coherences, rbps, scc, labels, groups) where the first three
        are arrays concatenated along axis 0 and the last two are per-epoch lists
    """
    base_dir = "/kaggle/input/fork-of-extraction-cleaned"

    coherence_parts, rbp_parts, scc_parts = [], [], []
    epoch_labels, epoch_subjects = [], []

    for idx in subjects:
        # Files are read in the order coherence -> RBP -> SCC
        coherence_parts.append(
            np.load(os.path.join(base_dir, f'coherence/coherence_{idx}.npy'))
        )
        rbp_epochs = np.load(os.path.join(base_dir, f'rbp/rbp_{idx}.npy'))
        rbp_parts.append(rbp_epochs)
        scc_parts.append(
            np.load(os.path.join(base_dir, f'scc_cleaned_base/sub-{idx}_epochs.npy'))
        )

        # The participants table uses zero-padded three-digit IDs
        participant = f"sub-{str(idx).zfill(3)}"
        is_participant = subjects_info['participant_id'] == participant
        diagnosis = subjects_info.loc[is_participant, 'Group'].values[0]

        # One label and one subject ID per epoch of this subject
        n_epochs = rbp_epochs.shape[0]
        epoch_labels += [diagnosis] * n_epochs
        epoch_subjects += [idx] * n_epochs

    stacked = tuple(
        np.concatenate(parts, axis=0)
        for parts in (coherence_parts, rbp_parts, scc_parts)
    )
    return (*stacked, epoch_labels, epoch_subjects)

In [None]:
# Read the tab-separated participants table into a DataFrame
participants_tsv = '/kaggle/input/open-nuro-dataset/dataset/participants.tsv'
subjects_info = pd.read_csv(participants_tsv, sep='\t')

In [None]:
# Load every feature type for all selected subjects in one pass
loaded_total = load_data(subjects_info, total_subjects, data_type='total')
coherences_total, rbps_total, scc_total, total_labels, groups_total = loaded_total

In [None]:
# Report the shape of each stacked feature array
for caption, feature_array in (
    ("Coherences shape:", coherences_total),
    ("RBP shape:", rbps_total),
    ("SCC shape:", scc_total),
):
    print(caption, feature_array.shape)

In [None]:
from sklearn.preprocessing import OneHotEncoder

# Define label mapping and encode labels
label_mapping = {'A': 0, 'C': 1, 'F': 2}  # Alzheimer's, Control, Frontotemporal

# Map string labels to numeric
numeric_labels = pd.Series(total_labels).map(label_mapping)

# Series.map leaves NaN for any label missing from the mapping; fail loudly
# instead of letting NaN reach the encoder as if it were a class.
unmapped = numeric_labels.isna()
if unmapped.any():
    unknown_labels = pd.Series(total_labels)[unmapped].unique().tolist()
    raise ValueError(f"Labels missing from label_mapping: {unknown_labels}")

# Create the one-hot encoder. The dense-output flag is `sparse_output` from
# scikit-learn 1.2 on (`sparse` was removed in 1.4); fall back to the old
# keyword on older releases.
try:
    encoder = OneHotEncoder(categories='auto', sparse_output=False)
except TypeError:
    encoder = OneHotEncoder(categories='auto', sparse=False)

# Fit on the numeric labels and encode them to one-hot format.
# With categories='auto' only the classes actually present get a column.
label_column = numeric_labels.values.reshape(-1, 1)
total_labels_encoded = encoder.fit_transform(label_column)

print("Total labels shape:", total_labels_encoded.shape)
print("Label distribution:", pd.Series(total_labels).value_counts())

In [None]:
from torch.utils.data import TensorDataset, DataLoader

# Wrap the NumPy arrays as float32 tensors; the two feature tensors get a new
# axis at position 1 (used as the single input channel by the Conv2d layers).
# NOTE(review): the first stream is named "coherence" but is built from
# scc_total, not coherences_total — confirm this is intended.
coherence_tensor_total = torch.tensor(scc_total, dtype=torch.float32)[:, None]
rbps_tensor_total = torch.tensor(rbps_total, dtype=torch.float32)[:, None]
labels_tensor_total = torch.tensor(total_labels_encoded, dtype=torch.float32)

# Bundle (first-stream input, RBP input, one-hot label) triples and build a
# shuffled loader that drops the final incomplete batch.
batch_size = 64
dataset_total = TensorDataset(
    coherence_tensor_total,
    rbps_tensor_total,
    labels_tensor_total,
)
train_loader = DataLoader(
    dataset=dataset_total,
    batch_size=batch_size,
    drop_last=True,
    shuffle=True,
)

In [None]:
import math

class SEBlock(nn.Module):
    """Squeeze-and-Excitation block for channel attention.

    The input is average-pooled to one value per channel ("squeeze"), passed
    through a two-layer bottleneck ending in a sigmoid ("excitation"), and the
    resulting per-channel gate in (0, 1) multiplies the input.

    Args:
        in_channels: Number of channels of the 4-D input ``(N, C, H, W)``.
        reduction_ratio: Divisor for the bottleneck width. The hidden width is
            ``in_channels // reduction_ratio`` but never less than 1.
    """

    def __init__(self, in_channels, reduction_ratio=8):
        super().__init__()
        # Clamp to 1: with in_channels < reduction_ratio the integer division
        # gives 0, which would create a zero-width bottleneck whose gate no
        # longer depends on the input.
        hidden_channels = max(1, in_channels // reduction_ratio)
        self.squeeze = nn.AdaptiveAvgPool2d(1)
        self.excitation = nn.Sequential(
            nn.Linear(in_channels, hidden_channels),
            nn.ReLU(inplace=True),
            nn.Linear(hidden_channels, in_channels),
            nn.Sigmoid(),
        )

    def forward(self, x):
        """Return ``x`` rescaled channel-wise; output has the same shape as ``x``."""
        batch_size, channels, _, _ = x.size()
        gate = self.excitation(self.squeeze(x).view(batch_size, channels))
        # (N, C, 1, 1) broadcasts over the spatial dimensions.
        return x * gate.view(batch_size, channels, 1, 1)

class ConvNet(nn.Module):
    """
    Dual-input CNN with Squeeze-and-Excitation blocks and Transformer layers.

    Each input goes through two Conv2d -> SE -> ReLU -> MaxPool stages, is
    flattened and projected to 128 features, gets a sinusoidal positional
    encoding added, and is passed through its own TransformerEncoder. The two
    128-d outputs are concatenated and classified by two linear layers.

    The flattened sizes are fixed by ``fc1`` (16 * 42 features) and ``fc3``
    (16 * 4 features), so the spatial size of each input must match them.

    Args:
        dropout_rate: Dropout probability used after each projection/encoder.
        num_classes: Number of output logits.
    """

    def __init__(self, dropout_rate=0.15, num_classes=2):
        super().__init__()

        # First input path (coherence/SCC data)
        self.conv1 = nn.Conv2d(1, 64, kernel_size=3, padding=1)
        self.se1 = SEBlock(64)
        self.conv2 = nn.Conv2d(64, 16, kernel_size=3, padding=1)
        self.se2 = SEBlock(16)
        self.pool = nn.MaxPool2d(2, 2)
        self.fc1 = nn.Linear(16 * 42, 128)
        self.dropout1 = nn.Dropout(dropout_rate)

        # Transformer for first path
        self.transformer_layer1 = nn.TransformerEncoderLayer(d_model=128, nhead=8)
        self.transformer_encoder1 = nn.TransformerEncoder(self.transformer_layer1, num_layers=2)
        self.dropout2 = nn.Dropout(dropout_rate)

        # Second input path (RBP features)
        self.conv3 = nn.Conv2d(1, 32, kernel_size=3, padding=1)
        self.se3 = SEBlock(32)
        self.conv4 = nn.Conv2d(32, 16, kernel_size=3, padding=1)
        self.se4 = SEBlock(16)
        self.fc3 = nn.Linear(16 * 4, 128)
        self.dropout3 = nn.Dropout(dropout_rate)

        # Transformer for second path
        self.transformer_layer2 = nn.TransformerEncoderLayer(d_model=128, nhead=8)
        self.transformer_encoder2 = nn.TransformerEncoder(self.transformer_layer2, num_layers=1)
        self.dropout4 = nn.Dropout(dropout_rate)

        # Final classification layers
        self.fc2 = nn.Linear(256, 128)
        self.fc4 = nn.Linear(128, num_classes)

    def forward(self, x1, x2):
        """Return class logits of shape ``(batch, num_classes)`` for the two inputs."""
        # Process first input (coherence/SCC)
        x1 = self.pool(F.relu(self.se1(self.conv1(x1))))
        x1 = self.pool(F.relu(self.se2(self.conv2(x1))))
        # flatten(1) keeps the batch dimension fixed, so an unexpected spatial
        # size fails in fc1 with a shape error instead of silently reshaping
        # samples into a different batch size.
        x1 = x1.flatten(1)
        x1 = self.dropout1(F.relu(self.fc1(x1)))
        x1 = self.dropout2(self._encode(x1, self.transformer_encoder1))

        # Process second input (RBP)
        x2 = self.pool(F.relu(self.se3(self.conv3(x2))))
        x2 = self.pool(F.relu(self.se4(self.conv4(x2))))
        x2 = x2.flatten(1)
        x2 = self.dropout3(F.relu(self.fc3(x2)))
        x2 = self.dropout4(self._encode(x2, self.transformer_encoder2))

        # Combine features and classify
        fused = torch.cat((x1, x2), dim=1)
        fused = F.relu(self.fc2(fused))
        return self.fc4(fused)

    def _encode(self, features, encoder):
        """Add positional encoding to ``(B, D)`` features and run ``encoder``."""
        # NOTE(review): the encoders are built without batch_first, and the
        # features are reshaped to (B, 1, D). Under PyTorch's default
        # (seq, batch, feature) layout this makes the mini-batch the sequence,
        # so samples attend to each other and positions index batch order.
        # Behaviour is kept as-is here — confirm this is intended.
        positions = self.get_positional_encoding(
            features.size(0), features.size(1), device=features.device
        )
        encoded = encoder((features + positions).unsqueeze(1))
        return encoded.squeeze(1)

    def get_positional_encoding(self, seq_len, d_model, device=None):
        """Generate a ``(seq_len, d_model)`` sinusoidal positional encoding.

        Even columns hold ``sin(pos * w_k)`` and odd columns ``cos(pos * w_k)``
        with ``w_k = exp(-2k * ln(10000) / d_model)``.

        Args:
            seq_len: Number of positions (rows).
            d_model: Encoding width (columns); odd values are supported.
            device: Optional device to allocate the tensor on (default: CPU).
        """
        position = torch.arange(seq_len, dtype=torch.float32, device=device).unsqueeze(1)
        div_term = torch.exp(
            torch.arange(0, d_model, 2, dtype=torch.float32, device=device)
            * (-math.log(10000.0) / d_model)
        )
        angles = position * div_term
        pos_encoding = torch.zeros(seq_len, d_model, device=device)
        pos_encoding[:, 0::2] = torch.sin(angles)
        # For odd d_model there is one fewer cosine column than sine columns.
        pos_encoding[:, 1::2] = torch.cos(angles[:, : d_model // 2])
        return pos_encoding

In [None]:
from sklearn.model_selection import StratifiedKFold, StratifiedGroupKFold
from sklearn.metrics import accuracy_score, precision_score, recall_score, f1_score, confusion_matrix
import matplotlib.pyplot as plt
from torch.utils.data import Subset

# Initialize cross-validation.
# Splits are stratified on the class label and grouped by subject, so all
# epochs of one subject stay on the same side of every split. (Passing the
# subject IDs as `y` to a plain StratifiedKFold would do the opposite: spread
# each subject's epochs over every fold and leak subjects into the test set.)
skf = StratifiedGroupKFold(n_splits=10, shuffle=True, random_state=42)
inner_skf = StratifiedGroupKFold(n_splits=10, shuffle=True, random_state=42)

# Class index per epoch (column of the one-hot row) used for stratification
binary_labels = np.argmax(total_labels_encoded, axis=1)
subject_groups = np.asarray(groups_total)
class_ids = list(range(total_labels_encoded.shape[1]))
group_indices = skf.split(np.arange(len(dataset_total)), binary_labels, groups=subject_groups)

# Training configuration
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
batch_size = 256
num_epochs = 60
patience = 25  # Early stopping patience

# Storage for results
total_y_true = []
total_y_pred = []
fold_metrics = []

print(f"Training on device: {device}")
print(f"Starting {skf.n_splits}-fold cross-validation")

for fold, (train_indices, test_indices) in enumerate(group_indices, 1):
    print(f"\n=== Fold {fold}/{skf.n_splits} ===")

    # Inner split for validation: first fold of a subject-grouped split of the
    # outer training portion. The returned positions index into train_indices.
    inner_train_pos, inner_val_pos = next(inner_skf.split(
        train_indices,
        binary_labels[train_indices],
        groups=subject_groups[train_indices],
    ))

    # Index views into the full dataset instead of materialised sample lists
    train_subset = Subset(dataset_total, train_indices[inner_train_pos].tolist())
    val_subset = Subset(dataset_total, train_indices[inner_val_pos].tolist())
    test_data = Subset(dataset_total, test_indices.tolist())

    # Create data loaders
    train_loader = DataLoader(train_subset, batch_size=batch_size, shuffle=True)
    val_loader = DataLoader(val_subset, batch_size=batch_size, shuffle=False)
    test_loader = DataLoader(test_data, batch_size=batch_size, shuffle=False)

    # Initialize model and optimizer
    model = ConvNet().to(device)
    # Targets are one-hot float rows, which CrossEntropyLoss accepts as class
    # probabilities.
    criterion = nn.CrossEntropyLoss()
    optimizer = optim.Adam(model.parameters(), lr=0.001)

    # Early stopping variables
    best_val_loss = float('inf')
    early_stop_counter = 0
    best_model_state = None
    train_loss_list = []
    val_loss_list = []

    # Training loop
    for epoch in range(num_epochs):
        # Training phase
        model.train()
        running_loss = 0.0
        for inputs1, inputs2, labels in train_loader:
            inputs1, inputs2, labels = inputs1.to(device), inputs2.to(device), labels.to(device)

            optimizer.zero_grad()
            outputs = model(inputs1, inputs2)
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()
            running_loss += loss.item()

        # Validation phase
        model.eval()
        val_loss = 0.0
        with torch.no_grad():
            for inputs1, inputs2, labels in val_loader:
                inputs1, inputs2, labels = inputs1.to(device), inputs2.to(device), labels.to(device)
                outputs = model(inputs1, inputs2)
                loss = criterion(outputs, labels)
                val_loss += loss.item()

        # Calculate average losses (mean over batches)
        avg_train_loss = running_loss / len(train_loader)
        avg_val_loss = val_loss / len(val_loader)
        train_loss_list.append(avg_train_loss)
        val_loss_list.append(avg_val_loss)

        print(f"Epoch {epoch+1:3d}: Train Loss: {avg_train_loss:.4f}, Val Loss: {avg_val_loss:.4f}")

        # Early stopping check
        if avg_val_loss < best_val_loss:
            best_val_loss = avg_val_loss
            # state_dict() returns references to the live parameters, which
            # later optimizer steps overwrite in place; clone to keep a real
            # snapshot of the best epoch.
            best_model_state = {
                name: tensor.detach().clone()
                for name, tensor in model.state_dict().items()
            }
            early_stop_counter = 0
        else:
            early_stop_counter += 1

        if early_stop_counter >= patience:
            print(f"Early stopping at epoch {epoch+1}")
            break

    # Load best model and evaluate on test set. The snapshot is None only if
    # no epoch ever improved on +inf (e.g. NaN losses); keep the final weights
    # in that case.
    if best_model_state is not None:
        model.load_state_dict(best_model_state)
    model.eval()

    y_true, y_pred = [], []
    with torch.no_grad():
        for inputs1, inputs2, labels in test_loader:
            inputs1, inputs2, labels = inputs1.to(device), inputs2.to(device), labels.to(device)
            outputs = model(inputs1, inputs2)
            _, predicted = torch.max(outputs, 1)
            y_true.extend(labels.cpu().tolist())
            y_pred.extend(predicted.cpu().tolist())

    # Calculate metrics
    y_true_labels = np.argmax(y_true, axis=1)
    acc = accuracy_score(y_true_labels, y_pred)
    rec = recall_score(y_true_labels, y_pred, average='macro')
    # `pos_label` is ignored when average='macro', so sensitivity/specificity
    # are taken from the per-class recalls instead: class 0 (Alzheimer's) is
    # treated as the positive class, class 1 (Control) as the negative class.
    per_class_recall = recall_score(
        y_true_labels, y_pred, labels=class_ids, average=None, zero_division=0
    )
    sens = per_class_recall[0]
    spec = per_class_recall[1]
    prec = precision_score(y_true_labels, y_pred, average='macro')
    f1 = f1_score(y_true_labels, y_pred, average='macro')
    # Fixed label set keeps the matrix shape stable even if a fold lacks a class
    confusion_mat = confusion_matrix(y_true_labels, y_pred, labels=class_ids)

    # Store fold results
    fold_metrics.append({
        "Fold": fold,
        "Accuracy": acc,
        "Sensitivity": sens,
        "Specificity": spec,
        "Precision": prec,
        "F1 Score": f1,
        "Recall": rec,
        "Confusion Matrix": confusion_mat
    })

    # Add to total results
    total_y_true.extend(y_true)
    total_y_pred.extend(y_pred)

    # Print fold results
    print(f"Fold {fold} Results:")
    print(f"  Accuracy: {acc:.4f}, Precision: {prec:.4f}, Recall: {rec:.4f}, F1: {f1:.4f}")

    # Plot training curves
    plt.figure(figsize=(8, 5))
    plt.plot(train_loss_list, label='Train Loss', color='blue')
    plt.plot(val_loss_list, label='Validation Loss', color='red')
    plt.xlabel('Epoch')
    plt.ylabel('Loss')
    plt.title(f'Training Curves - Fold {fold}')
    plt.legend()
    plt.grid(True)
    plt.show()

# Calculate overall metrics on the predictions pooled over all folds
total_y_true_labels = np.argmax(total_y_true, axis=1)
overall_acc = accuracy_score(total_y_true_labels, total_y_pred)
overall_prec = precision_score(total_y_true_labels, total_y_pred, average='macro')
overall_rec = recall_score(total_y_true_labels, total_y_pred, average='macro')
overall_f1 = f1_score(total_y_true_labels, total_y_pred, average='macro')
overall_confusion = confusion_matrix(total_y_true_labels, total_y_pred)

print(f"\n=== Overall Results Across All Folds ===")
print(f"Accuracy: {overall_acc:.4f}")
print(f"Precision: {overall_prec:.4f}")
print(f"Recall: {overall_rec:.4f}")
print(f"F1 Score: {overall_f1:.4f}")
print(f"Confusion Matrix:\n{overall_confusion}")

# Print detailed classification report
report = classification_report(total_y_true_labels, total_y_pred)
print(f"\nDetailed Classification Report:\n{report}")

# Calculate statistics across folds
metric_names = ['Accuracy', 'Sensitivity', 'Specificity', 'Precision', 'F1 Score', 'Recall']
metrics_array = np.array([[m[name] for name in metric_names] for m in fold_metrics])
metrics_mean = np.mean(metrics_array, axis=0)
metrics_std = np.std(metrics_array, axis=0)

print(f"\n=== Cross-Validation Statistics ===")
print(f"{'Metric':<12} {'Mean ± STD':<15}")
print("-" * 27)
for name, mean_val, std_val in zip(metric_names, metrics_mean, metrics_std):
    print(f"{name:<12} {mean_val:.4f} ± {std_val:.4f}")

In [None]:
import seaborn as sns

# Create confusion matrix visualization (counts pooled over all folds)
plt.figure(figsize=(8, 6))
sns.heatmap(overall_confusion,
            annot=True,
            cmap="Blues",
            fmt="d",
            xticklabels=['Alzheimer\'s', 'Control'],
            yticklabels=['Alzheimer\'s', 'Control'],
            cbar_kws={'label': 'Count'})

plt.title("Overall Confusion Matrix", fontsize=14, fontweight='bold')
plt.xlabel("Predicted Labels", fontsize=12)
plt.ylabel("True Labels", fontsize=12)
plt.tight_layout()
plt.show()

# Additional performance visualization
fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(15, 5))

# Metrics across folds
metrics_df = pd.DataFrame(fold_metrics)
metrics_to_plot = ['Accuracy', 'Precision', 'F1 Score', 'Recall']
metrics_df[metrics_to_plot].boxplot(ax=ax1)
ax1.set_title('Performance Metrics Distribution Across Folds')
ax1.set_ylabel('Score')
ax1.grid(True, alpha=0.3)

# Fold-wise accuracy
fold_accuracies = [m['Accuracy'] for m in fold_metrics]
# The reference line is labelled as a mean, so draw the mean of the per-fold
# accuracies rather than the pooled accuracy (the two differ when folds have
# unequal sizes).
mean_fold_acc = float(np.mean(fold_accuracies))
ax2.plot(range(1, len(fold_accuracies) + 1), fold_accuracies,
         marker='o', linewidth=2, markersize=6)
ax2.axhline(y=mean_fold_acc, color='red', linestyle='--', label=f'Overall Mean: {mean_fold_acc:.3f}')
ax2.set_title('Accuracy Across Folds')
ax2.set_xlabel('Fold')
ax2.set_ylabel('Accuracy')
ax2.legend()
ax2.grid(True, alpha=0.3)

plt.tight_layout()
plt.show()