# In [1]:
from torch_geometric.explain import GNNExplainer

from captum.attr import IntegratedGradients

# In [2]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
from sklearn.metrics import f1_score, accuracy_score, recall_score, precision_score, confusion_matrix
from sklearn.model_selection import StratifiedKFold
import numpy as np
from tqdm import tqdm
import import_ipynb

from netwrks_2 import MultiBandAttentionFusion, BandSpecificGraphSAGE



def _identity_collate(batch):
    """Collate function that hands the list of subjects back unchanged."""
    return batch


def _snapshot_state_dict(model):
    """Return an independent deep copy of ``model``'s state dict.

    ``state_dict()`` hands back references to the live parameter tensors and
    ``dict.copy()`` is shallow, so a shallow snapshot keeps changing while the
    optimizer updates the model. A deep copy freezes the values.
    """
    import copy  # local import keeps this block self-contained

    return copy.deepcopy(model.state_dict())


def _count_classes(data):
    """Return ``(class0, other)`` subject counts based on each subject's first graph label."""
    class0 = sum(1 for subject in data if subject[0][0].y.item() == 0)
    return class0, len(data) - class0


def _prepare_batch(batch, device):
    """Move every graph of every subject in ``batch`` to ``device``.

    Graph labels outside ``{0, 1}`` are rewritten so that ``2`` becomes ``1``
    and anything else becomes ``0``. The same rule is applied on every code
    path (initialisation check, training, validation) so labels cannot depend
    on which pass touched a graph first.

    NOTE(review): ``g.to(device)`` may return the same object it was called on
    (torch_geometric data objects appear to behave that way), in which case the
    label rewrite also changes the caller's dataset -- confirm this is intended.
    """
    processed_batch = []
    for subject_data in batch:
        processed_subject = []
        for band_graphs in subject_data:
            moved = [g.to(device) for g in band_graphs]
            for g in moved:
                y = g.y.item()
                if y not in (0, 1):
                    g.y = torch.tensor(1 if y == 2 else 0, dtype=g.y.dtype, device=g.y.device)
            processed_subject.append(moved)
        processed_batch.append(tuple(processed_subject))
    return processed_batch


def _batch_labels(processed_batch, device):
    """Build the 1-D ``long`` target tensor from each subject's first graph label.

    Labels are read from the prepared batch so they always agree with the
    graph labels the model sees. Any value that is still not 0 or 1 falls back
    to class 1.
    """
    values = []
    for subject in processed_batch:
        y = subject[0][0].y.item()
        values.append(int(y) if y in (0, 1) else 1)
    return torch.tensor(values, dtype=torch.long, device=device)


def _run_eval_pass(model, loader, criterion, device):
    """Run ``model`` in eval mode over ``loader`` without gradients.

    Returns:
        ``(total_loss, preds, targets, outputs)`` where ``total_loss`` is the
        sum of per-batch losses, ``preds``/``targets`` are flat lists and
        ``outputs`` is a list with one numpy array of raw model outputs per
        batch.
    """
    model.eval()
    total_loss = 0.0
    preds, targets, outputs_all = [], [], []
    with torch.no_grad():
        for batch in loader:
            processed_batch = _prepare_batch(batch, device)
            batch_targets = _batch_labels(processed_batch, device)
            outputs = model(processed_batch)
            outputs_all.append(outputs.cpu().numpy())
            total_loss += criterion(outputs, batch_targets).item()
            preds.extend(outputs.argmax(1).cpu().numpy())
            targets.extend(batch_targets.cpu().numpy())
    return total_loss, preds, targets, outputs_all


def train_with_cross_validation(dataset, model_class, num_epochs=1, n_folds=2, band_names=None):
    """Train ``model_class`` with stratified k-fold cross-validation.

    Args:
        dataset: Indexable collection of subjects. Each subject is a sequence
            of bands and each band is a sequence of graph objects exposing
            ``.x`` (2-D tensor), ``.y`` (single-element label) and
            ``.to(device)``. The label of a subject is taken from its first
            band's first graph.
        model_class: Callable building the model; it is called with the keyword
            arguments ``num_bands``, ``hidden_channels``, ``num_classes``,
            ``dropout_rate``, ``num_nodes`` and ``in_channels``. The model is
            called with a list of per-subject tuples of band graph lists and
            must provide ``get_band_importance()``. If it has an ``l1_loss``
            attribute, that value is added to the training loss.
        num_epochs: Number of training epochs per fold (must be >= 1).
        n_folds: Number of stratified folds.
        band_names: Optional display names for the bands; defaults to
            ``"Band 1"``, ``"Band 2"``, ...

    Returns:
        A 5-tuple ``(final_model, fold_models, ensemble_model,
        all_fold_metrics, summary)``:

        * ``final_model`` -- model loaded with the best checkpoint of the fold
          that reached the highest validation F1.
        * ``fold_models`` -- one model per fold, loaded with that fold's best
          checkpoint.
        * ``ensemble_model`` -- module averaging the outputs of all fold models.
        * ``all_fold_metrics`` -- per-fold dict of best-epoch metrics, plus
          ``band_importance`` (when available) and ``epoch_history``.
        * ``summary`` -- averaged metrics, F1 standard deviation, best fold
          metrics and averaged band importance (or ``None``).

    Raises:
        ValueError: If ``num_epochs`` is below 1 or a training fold lacks one
            of the two classes.
        RuntimeError: If no initialisation attempt yields a usable loss.
    """
    if num_epochs < 1:
        raise ValueError("num_epochs must be at least 1")

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print(f"Using device: {device}")

    # Model dimensions are taken from the first subject and do not vary per fold.
    num_bands = len(dataset[0])
    num_nodes = dataset[0][0][0].x.size(0)
    in_channels = dataset[0][0][0].x.size(1)

    if band_names is None:
        band_names = [f"Band {i+1}" for i in range(num_bands)]

    def build_model():
        """Create a fresh model on ``device`` with the fixed hyper-parameters."""
        return model_class(
            num_bands=num_bands,
            hidden_channels=64,
            num_classes=2,
            dropout_rate=0.5,
            num_nodes=num_nodes,
            in_channels=in_channels
        ).to(device)

    # Minimum F1 gain over the best epoch so far that counts as an improvement.
    min_delta = 0.001
    stratify_labels = [subject[0][0].y.item() for subject in dataset]

    skf = StratifiedKFold(n_splits=n_folds, shuffle=True, random_state=42)

    all_fold_metrics = []
    all_band_importances = []
    fold_models = []

    # Start below any reachable F1 so a best fold always exists, even when
    # every fold scores 0.
    best_overall_f1 = float('-inf')
    best_overall_metrics = None
    best_overall_model_state = None

    splits = skf.split(range(len(stratify_labels)), stratify_labels)
    for fold_idx, (train_idx, val_idx) in enumerate(splits, start=1):
        print(f"\n===== Fold {fold_idx}/{n_folds} =====")

        train_data = [dataset[i] for i in train_idx]
        val_data = [dataset[i] for i in val_idx]

        train_class0, train_class1 = _count_classes(train_data)
        print(f"Training set: Class 0: {train_class0}, Class 1: {train_class1}")

        val_class0, val_class1 = _count_classes(val_data)
        print(f"Validation set: Class 0: {val_class0}, Class 1: {val_class1}")

        train_loader = DataLoader(train_data, batch_size=1, shuffle=True, collate_fn=_identity_collate)
        val_loader = DataLoader(val_data, batch_size=1, shuffle=False, collate_fn=_identity_collate)

        # --- Pick the best of three random initialisations -----------------
        # NOTE(review): the initialisation is chosen by its loss on the
        # validation fold, which lets validation data influence the model
        # before training. Kept as in the original; confirm it is acceptable.
        best_init_loss = float('inf')
        best_init_state = None
        last_init_error = None
        init_criterion = nn.CrossEntropyLoss()

        for _ in range(3):
            temp_model = build_model()

            for module in temp_model.modules():
                if isinstance(module, nn.Linear):
                    nn.init.xavier_uniform_(module.weight)
                    if module.bias is not None:
                        nn.init.zeros_(module.bias)

            try:
                init_loss, _, _, _ = _run_eval_pass(temp_model, val_loader, init_criterion, device)
            except Exception as e:
                # Best effort: a failing attempt is skipped, not fatal.
                print(f"Error during initialization validation: {e}")
                last_init_error = e
                continue

            if init_loss < best_init_loss:
                best_init_loss = init_loss
                best_init_state = _snapshot_state_dict(temp_model)

        if best_init_state is None:
            # Without this check the code would hit a NameError or silently
            # reuse the previous fold's initial weights.
            raise RuntimeError(
                f"No initialization attempt produced a usable loss in fold {fold_idx}"
            ) from last_init_error

        model = build_model()
        model.load_state_dict(best_init_state)

        optimizer = optim.AdamW(model.parameters(), lr=0.0005, weight_decay=1e-3)

        # --- Class-balanced loss weights with an extra minority boost -------
        if train_class0 == 0 or train_class1 == 0:
            raise ValueError(
                f"Fold {fold_idx} training split is missing a class "
                f"(class 0: {train_class0}, class 1: {train_class1})"
            )
        total = train_class0 + train_class1
        weight_0 = total / (2 * train_class0)
        weight_1 = total / (2 * train_class1)

        minority_boost = 1.5
        if train_class1 < train_class0:
            weight_1 *= minority_boost
        else:
            weight_0 *= minority_boost

        class_weights = torch.tensor([weight_0, weight_1], device=device)
        print(f"Using class weights: {class_weights.cpu().numpy()}")

        criterion = nn.CrossEntropyLoss(weight=class_weights)

        # ``verbose`` is deprecated in recent PyTorch releases, so it is not
        # passed; learning-rate changes are reported manually after each step.
        scheduler = optim.lr_scheduler.ReduceLROnPlateau(
            optimizer,
            mode='max',
            factor=0.5,
            patience=20,
            threshold=0.002,
            min_lr=1e-6
        )

        # Start at -inf so the first epoch always yields a checkpoint.
        best_val_f1 = float('-inf')
        best_metrics = None
        best_epoch = 0
        best_model_state = None

        epoch_metrics = []

        with tqdm(total=num_epochs, desc=f"Training Fold {fold_idx}", leave=True) as pbar:
            for epoch in range(num_epochs):
                # ---------------------------- training ----------------------
                model.train()
                train_loss = 0
                train_preds, train_true = [], []

                for batch in train_loader:
                    try:
                        processed_batch = _prepare_batch(batch, device)
                        targets = _batch_labels(processed_batch, device)

                        optimizer.zero_grad()
                        outputs = model(processed_batch)

                        loss = criterion(outputs, targets)

                        if hasattr(model, 'l1_loss'):
                            loss = loss + model.l1_loss

                        loss.backward()
                        nn.utils.clip_grad_norm_(model.parameters(), max_norm=0.5)
                        optimizer.step()

                        train_loss += loss.item()
                        train_preds.extend(outputs.argmax(1).cpu().numpy())
                        train_true.extend(targets.cpu().numpy())
                    except Exception as e:
                        print(f"Error during training: {e}")
                        raise

                train_f1 = f1_score(train_true, train_preds, average='weighted')
                train_accuracy = accuracy_score(train_true, train_preds)
                train_loss /= len(train_loader)

                # ---------------------------- validation --------------------
                try:
                    val_loss, val_preds, val_true, val_outputs_all = _run_eval_pass(
                        model, val_loader, criterion, device
                    )
                except Exception as e:
                    print(f"Error during validation: {e}")
                    raise

                val_f1 = f1_score(val_true, val_preds, average='weighted')
                val_recall = recall_score(val_true, val_preds, average='weighted')
                val_precision = precision_score(val_true, val_preds, average='weighted')
                val_accuracy = accuracy_score(val_true, val_preds)
                val_loss /= len(val_loader)

                lr_before = optimizer.param_groups[0]['lr']
                scheduler.step(val_f1)
                lr_after = optimizer.param_groups[0]['lr']
                if lr_after != lr_before:
                    tqdm.write(
                        f"Epoch {epoch + 1}: reducing learning rate from {lr_before:.4e} to {lr_after:.4e}"
                    )

                # Mean of the highest softmax probability over validation samples.
                val_outputs_all = np.concatenate(val_outputs_all, axis=0)
                softmax_outputs = torch.nn.functional.softmax(torch.tensor(val_outputs_all), dim=1).numpy()
                avg_confidence = np.mean(np.max(softmax_outputs, axis=1))

                if val_f1 > (best_val_f1 + min_delta):
                    best_val_f1 = val_f1
                    best_metrics = {
                        'f1': val_f1,
                        'recall': val_recall,
                        'precision': val_precision,
                        'accuracy': val_accuracy,
                        'train_loss': train_loss,
                        'val_loss': val_loss,
                        'val_preds': val_preds,
                        'val_true': val_true,
                        'confidence': avg_confidence
                    }
                    best_epoch = epoch + 1
                    # Deep copy: a shallow copy would keep tracking the live
                    # weights and end up holding the last epoch instead.
                    best_model_state = _snapshot_state_dict(model)

                epoch_metrics.append({
                    'epoch': epoch + 1,
                    'train_loss': train_loss,
                    'train_f1': train_f1,
                    'train_accuracy': train_accuracy,
                    'val_loss': val_loss,
                    'val_f1': val_f1,
                    'val_accuracy': val_accuracy,
                    'val_recall': val_recall,
                    'val_precision': val_precision,
                    'confidence': avg_confidence
                })

                pbar.set_postfix(
                    train_loss=f"{train_loss:.4f}",
                    val_loss=f"{val_loss:.4f}",
                    val_f1=f"{val_f1:.4f}",
                    best_f1=f"{best_val_f1:.4f}",
                    best_epoch=best_epoch,
                    conf=f"{avg_confidence:.2f}"
                )
                pbar.update(1)

        print(f"\nFold {fold_idx} - Best Epoch: {best_epoch}")
        print(f"Best Validation F1: {best_metrics['f1']:.4f}")
        print(f"Accuracy: {best_metrics['accuracy']:.4f} | Recall: {best_metrics['recall']:.4f} | Precision: {best_metrics['precision']:.4f}")
        print(f"Average prediction confidence: {best_metrics['confidence']:.4f}")

        cm = confusion_matrix(best_metrics['val_true'], best_metrics['val_preds'])
        print("Confusion Matrix:")
        print(cm)

        # Restore the best checkpoint before reading band importance, and keep
        # an independent model instance for this fold.
        model.load_state_dict(best_model_state)
        fold_model = build_model()
        fold_model.load_state_dict(best_model_state)
        fold_models.append(fold_model)

        band_importance = model.get_band_importance()
        if band_importance is not None:
            print(f"\nBand Importance Analysis (Fold {fold_idx}):")
            for band_name, weight in zip(band_names, band_importance):
                print(f"{band_name}: {weight:.4f} ({weight*100:.1f}%)")

            best_metrics['band_importance'] = {band: float(weight) for band, weight in zip(band_names, band_importance)}
            all_band_importances.append(band_importance)

        # Per-epoch history of this fold (additional key; existing keys unchanged).
        best_metrics['epoch_history'] = epoch_metrics

        all_fold_metrics.append(best_metrics)

        if best_val_f1 > best_overall_f1:
            best_overall_f1 = best_val_f1
            best_overall_metrics = best_metrics
            best_overall_model_state = best_model_state

    avg_f1 = np.mean([m['f1'] for m in all_fold_metrics])
    avg_accuracy = np.mean([m['accuracy'] for m in all_fold_metrics])
    avg_recall = np.mean([m['recall'] for m in all_fold_metrics])
    avg_precision = np.mean([m['precision'] for m in all_fold_metrics])

    print("\n===== Cross-Validation Results =====")
    print(f"Average F1: {avg_f1:.4f}")
    print(f"Average Accuracy: {avg_accuracy:.4f}")
    print(f"Average Recall: {avg_recall:.4f}")
    print(f"Average Precision: {avg_precision:.4f}")

    std_f1 = np.std([m['f1'] for m in all_fold_metrics])
    print(f"F1 Standard Deviation: {std_f1:.4f}")

    avg_band_importance = None
    if all_band_importances:
        avg_band_importance = np.mean(all_band_importances, axis=0)
        print("\nAverage Band Importance Across All Folds:")
        for band_name, weight in zip(band_names, avg_band_importance):
            print(f"{band_name}: {weight:.4f} ({weight*100:.1f}%)")

    final_model = build_model()
    final_model.load_state_dict(best_overall_model_state)

    class EnsembleModel(nn.Module):
        """Average the outputs of all fold models.

        Only the state dicts are stored; a single template model is reloaded
        with each fold's weights in turn during ``forward``.
        """

        def __init__(self, models):
            super(EnsembleModel, self).__init__()
            self.model_state_dicts = [model.state_dict() for model in models]
            self.template_model = build_model()

        def forward(self, x):
            """Return the mean of the per-fold outputs (eval mode, no gradients)."""
            outputs = []
            for state_dict in self.model_state_dicts:
                self.template_model.load_state_dict(state_dict)
                self.template_model.eval()
                with torch.no_grad():
                    outputs.append(self.template_model(x))

            return torch.mean(torch.stack(outputs), dim=0)

    ensemble_model = EnsembleModel(fold_models)

    return final_model, fold_models, ensemble_model, all_fold_metrics, {
        'avg_f1': avg_f1,
        'avg_accuracy': avg_accuracy,
        'avg_recall': avg_recall,
        'avg_precision': avg_precision,
        'f1_std': std_f1,
        'best_fold_metrics': best_overall_metrics,
        'avg_band_importance': {band: float(weight) for band, weight in zip(band_names, avg_band_importance)} if all_band_importances else None
    }



