# In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch_geometric.nn import GATConv
from torch.utils.data import Dataset, DataLoader
from sklearn.model_selection import train_test_split
from sklearn.metrics import classification_report
from tqdm import tqdm
import numpy as np
from torch_geometric.nn import SAGEConv

# In [2]:


class BandSpecificGraphSAGE(nn.Module):
    """Encode one band's sequence of graphs into a single pooled vector.

    Every graph in the sequence goes through two ``SAGEConv`` stages (each
    followed by ``BatchNorm1d`` and ReLU, with a ``LayerNorm`` after the
    first stage). The resulting node embeddings are flattened into one row
    per graph, the rows are run through a 2-layer bidirectional LSTM, and the
    layer-normalised LSTM outputs are collapsed over time with softmax
    attention weights.

    Args:
        in_channels: Per-node input feature size given to the first
            ``SAGEConv``.
        hidden_channels: Width of both graph layers and of each LSTM
            direction.
        dropout_rate: Probability used by the ``nn.Dropout`` applied after
            each graph stage in training mode. The LSTM's own inter-layer
            dropout is fixed at 0.0.
        num_nodes: Number of nodes expected in every graph. It fixes the LSTM
            input width to ``hidden_channels * num_nodes``, so each graph's
            flattened embedding must have exactly that many values.
    """

    def __init__(self, in_channels=18, hidden_channels=64, dropout_rate=0.5,
                 num_nodes=19):
        super().__init__()
        self.dropout_rate = dropout_rate
        self.hidden_channels = hidden_channels
        self.num_nodes = num_nodes

        # Spatial encoder: two GraphSAGE convolutions, each with batch norm.
        self.sage1 = SAGEConv(in_channels, hidden_channels)
        self.bn1 = nn.BatchNorm1d(hidden_channels)
        self.sage2 = SAGEConv(hidden_channels, hidden_channels)
        self.bn2 = nn.BatchNorm1d(hidden_channels)

        # Temporal encoder: consumes one flattened graph embedding per step.
        self.lstm = nn.LSTM(
            input_size=hidden_channels * num_nodes,
            hidden_size=hidden_channels,
            num_layers=2,
            batch_first=True,
            bidirectional=True,
            dropout=0.0,
        )

        # Scores each time step of the bidirectional LSTM output.
        self.temporal_attention = nn.Linear(hidden_channels * 2, 1)

        # Normalisation and dropout shared by the forward pass.
        self.layer_norm1 = nn.LayerNorm(hidden_channels)
        self.layer_norm2 = nn.LayerNorm(hidden_channels * 2)
        self.dropout = nn.Dropout(dropout_rate)

    def _embed_graph(self, graph):
        """Return the flattened ``[1, nodes * hidden_channels]`` embedding of one graph."""
        # Stage 1: conv -> batch norm -> ReLU -> layer norm (-> dropout).
        hidden = F.relu(self.bn1(self.sage1(graph.x, graph.edge_index)))
        hidden = self.layer_norm1(hidden)
        if self.training:
            hidden = self.dropout(hidden)

        # Stage 2: conv -> batch norm -> ReLU (-> dropout).
        hidden = F.relu(self.bn2(self.sage2(hidden, graph.edge_index)))
        if self.training:
            hidden = self.dropout(hidden)

        # One row per graph so the rows can be stacked into a sequence.
        return hidden.reshape(1, -1)

    def forward(self, graphs):
        """Pool a sequence of graphs into one ``[1, hidden_channels * 2]`` tensor.

        Args:
            graphs: Iterable of graph objects exposing ``.x`` and
                ``.edge_index`` (presumably ``torch_geometric`` ``Data``
                objects; confirm against the caller).

        Returns:
            The attention-weighted sum of the normalised LSTM outputs, shape
            ``[1, hidden_channels * 2]``. For an empty sequence, a zero
            tensor of the same shape on the module's parameter device.
        """
        step_rows = [self._embed_graph(graph) for graph in graphs]

        if not step_rows:
            fallback_device = next(self.parameters()).device
            return torch.zeros(1, self.hidden_channels * 2, device=fallback_device)

        # [T, F] -> [1, T, F]: a single sequence for the batch_first LSTM.
        sequence = torch.cat(step_rows, dim=0).unsqueeze(0)
        temporal, _ = self.lstm(sequence)
        temporal = self.layer_norm2(temporal)

        # Softmax over the time axis, then weighted sum of the time steps.
        scores = self.temporal_attention(temporal).squeeze(-1)
        step_weights = F.softmax(scores, dim=1)
        pooled = torch.bmm(step_weights.unsqueeze(1), temporal)
        return pooled.squeeze(1)

class MultiBandAttentionFusion(nn.Module):
    """Fuse per-band sequence embeddings with attention and classify each subject.

    Each band has its own ``BandSpecificGraphSAGE`` processor. For every
    subject, the band embeddings go through a shared linear projection, are
    scored by a linear attention layer, and are combined as a
    softmax-weighted sum. The fused vector is layer-normalised and passed
    through a two-layer classifier head.

    Args:
        num_bands: Number of band processors to create.
        hidden_channels: Hidden width handed to every band processor; the
            fused feature width is ``hidden_channels * 2``.
        num_classes: Size of the output logit vector.
        dropout_rate: Dropout probability for the band processors and the
            classifier head.
        num_nodes: Nodes per graph, forwarded to the band processors.
        in_channels: Per-node input feature size, forwarded to the band
            processors.

    Attributes set by ``forward``:
        l1_loss: Scalar regulariser pushing the band attention towards a
            uniform distribution, averaged over the subjects of the most
            recent training-mode call. It is a zero tensor whenever no
            penalty was computed (eval mode or ``l1_factor <= 0``).
        band_weights_history: One entry per training-mode forward call
            holding the batch-mean attention weights (detached, on CPU).
            The list grows without bound; clear it externally if needed.
    """

    def __init__(self, num_bands=5, hidden_channels=64, num_classes=2,
                 dropout_rate=0.5, num_nodes=19, in_channels=18):
        super().__init__()
        self.num_bands = num_bands
        self.hidden_channels = hidden_channels

        # One independent sequence encoder per band.
        self.band_processors = nn.ModuleList(
            BandSpecificGraphSAGE(in_channels, hidden_channels, dropout_rate, num_nodes)
            for _ in range(num_bands)
        )

        self.band_projection = nn.Linear(hidden_channels * 2, hidden_channels * 2)
        self.band_attention = nn.Linear(hidden_channels * 2, 1)
        self.layer_norm = nn.LayerNorm(hidden_channels * 2)

        self.fc1 = nn.Linear(hidden_channels * 2, hidden_channels)
        self.fc2 = nn.Linear(hidden_channels, num_classes)
        self.dropout = nn.Dropout(dropout_rate)

        self.band_weights_history = []
        self.l1_factor = 0.01
        # Defined up front so callers can always read `model.l1_loss`,
        # even before the first training-mode forward pass.
        self.l1_loss = torch.zeros(())

    def forward(self, subjects_data):
        """Classify a batch of subjects.

        Args:
            subjects_data: Iterable of subjects. Each subject is an iterable
                with one graph sequence per band, in the same order as
                ``self.band_processors``.

        Returns:
            Logits of shape ``[num_subjects, num_classes]``.

        Side effects:
            Updates ``self.l1_loss`` on every call and, in training mode,
            appends the batch-mean attention weights to
            ``self.band_weights_history``.
        """
        subject_logits = []
        subject_weights = []
        l1_terms = []
        penalise = self.training and self.l1_factor > 0

        for subject in subjects_data:
            # [num_bands, hidden_channels * 2]: one projected row per band.
            band_features = torch.cat(
                [
                    self.band_projection(self.band_processors[band_idx](band_graphs))
                    for band_idx, band_graphs in enumerate(subject)
                ],
                dim=0,
            )
            num_bands = band_features.shape[0]

            attention_logits = self.band_attention(band_features).squeeze(-1)
            attention_weights = F.softmax(attention_logits, dim=0)

            # Weighted sum over bands as one matmul: [1, B] @ [B, D] -> [1, D].
            attended_features = attention_weights.unsqueeze(0) @ band_features
            subject_weights.append(attention_weights.detach().cpu())

            out = F.relu(self.fc1(self.layer_norm(attended_features)))
            # nn.Dropout is already the identity in eval mode.
            out = self.dropout(out)
            subject_logits.append(self.fc2(out))

            if penalise:
                # L1 distance between the attention and a uniform distribution.
                l1_terms.append(torch.abs(attention_weights - 1.0 / num_bands).sum())

        logits = torch.cat(subject_logits, dim=0)

        if l1_terms:
            # Average over subjects so every subject in the batch contributes
            # and the scale does not depend on batch size.
            self.l1_loss = self.l1_factor * torch.stack(l1_terms).mean()
        else:
            # Do not leave a stale training-time penalty behind.
            self.l1_loss = logits.new_zeros(())

        if self.training:
            self.band_weights_history.append(torch.stack(subject_weights).mean(dim=0))

        return logits

    def get_band_importance(self):
        """Return the mean attention weight per band over the recorded history.

        The history has one entry per training-mode forward call (not per
        epoch), so this is the average over all such calls.

        Returns:
            A NumPy array of shape ``[num_bands]``, or ``None`` if nothing
            has been recorded yet.
        """
        if not self.band_weights_history:
            return None

        # [num_training_calls, num_bands] -> [num_bands]
        stacked = torch.stack(self.band_weights_history)
        return stacked.mean(dim=0).cpu().numpy()
