In [None]:
import os
import numpy as np
import torch
import torch.nn as nn
import copy
import matplotlib.pyplot as plt
import seaborn as sns
from sklearn.model_selection import StratifiedKFold
from sklearn.metrics import accuracy_score, confusion_matrix, classification_report, roc_curve, auc, precision_recall_curve
from torch_geometric.data import Data
from torch_geometric.loader import DataLoader
from torch_geometric.nn import GATConv, global_mean_pool, global_add_pool, global_max_pool
from torch_geometric.utils import add_self_loops, dense_to_sparse
from torch.utils.data import TensorDataset
from tslearn.datasets import UCR_UEA_datasets
import warnings
from scipy import stats
from collections import defaultdict
import pandas as pd
from scipy.signal import find_peaks
warnings.filterwarnings('ignore')

# Configuration
RANDOM_SEED = 42
np.random.seed(RANDOM_SEED)
torch.manual_seed(RANDOM_SEED)
if torch.cuda.is_available():
    torch.cuda.manual_seed(RANDOM_SEED)

class Config:
    MAX_EPOCHS = 200
    BATCH_SIZE = 32
    LR = 0.001
    N_SPLITS = 5
    HIDDEN_DIM = 128
    HEADS = 8
    DROPOUT = 0.3
    WEIGHT_DECAY = 1e-4
    PATIENCE = 15
    PLOT_DIR = "./plots"
    RESULTS_DIR = "./results"
    
    # Model architecture
    NUM_GAT_LAYERS = 3
    USE_EDGE_FEATURES = True
    POOLING_STRATEGY = 'hybrid'  # 'mean', 'max', 'add', 'hybrid'
    USE_RESIDUAL = True
    USE_BATCH_NORM = True
    
    # Novel components
    USE_TEMPORAL_ENCODING = True
    USE_MULTI_SCALE = True
    USE_WAVELET_FEATURES = True
    
    def __init__(self):
        os.makedirs(self.PLOT_DIR, exist_ok=True)
        os.makedirs(self.RESULTS_DIR, exist_ok=True)

config = Config()
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

def analyze_dataset(X, y, dataset_name):
    """Comprehensive dataset analysis and visualization."""
    stats_dict = {
        'Dataset Name': dataset_name,
        'Number of Samples': len(X),
        'Time Series Length': X.shape[1],
        'Number of Classes': len(np.unique(y)),
        'Class Distribution': dict(zip(*np.unique(y, return_counts=True))),
        'Mean Length': X.shape[1],
        'Mean Value': X.mean(),
        'Std Value': X.std(),
        'Max Value': X.max(),
        'Min Value': X.min()
    }
    
    # Create visualization
    fig = plt.figure(figsize=(15, 10))
    gs = plt.GridSpec(3, 2)
    
    # Plot 1: Example time series from each class
    ax1 = fig.add_subplot(gs[0, :])
    classes = np.unique(y)
    colors = plt.cm.tab10(np.linspace(0, 1, len(classes)))
    for i, cls in enumerate(classes):
        idx = np.where(y == cls)[0][0]
        ax1.plot(X[idx], color=colors[i], label=f'Class {cls}', alpha=0.8)
    ax1.set_title('Example Time Series from Each Class')
    ax1.legend()
    ax1.grid(True, alpha=0.3)
    
    # Plot 2: Class distribution
    ax2 = fig.add_subplot(gs[1, 0])
    sns.barplot(x=list(stats_dict['Class Distribution'].keys()),
                y=list(stats_dict['Class Distribution'].values()),
                palette='tab10', ax=ax2)
    ax2.set_title('Class Distribution')
    
    # Plot 3: Signal characteristics
    ax3 = fig.add_subplot(gs[1, 1])
    for i, cls in enumerate(classes):
        class_data = X[y == cls].flatten()
        sns.kdeplot(data=class_data, ax=ax3, label=f'Class {cls}')
    ax3.set_title('Signal Distribution by Class')
    ax3.legend()
    
    # Plot 4: Peaks and valleys analysis
    ax4 = fig.add_subplot(gs[2, :])
    example_idx = np.where(y == classes[0])[0][0]
    example_series = X[example_idx]
    peaks, _ = find_peaks(example_series)
    valleys, _ = find_peaks(-example_series)
    
    ax4.plot(example_series, label='Signal')
    ax4.plot(peaks, example_series[peaks], "x", label='Peaks')
    ax4.plot(valleys, example_series[valleys], "o", label='Valleys')
    ax4.set_title('Peak and Valley Analysis')
    ax4.legend()
    
    plt.tight_layout()
    plt.savefig(os.path.join(config.PLOT_DIR, f'{dataset_name}_analysis.png'), dpi=300, bbox_inches='tight')
    plt.close()
    
    return pd.DataFrame([stats_dict]).T

def extract_wavelets(X):
    """Extract wavelet features from time series."""
    import pywt
    wavelets = []
    for x in X:
        coeffs = pywt.wavedec(x, 'db4', level=3)
        features = np.concatenate([c for c in coeffs])
        wavelets.append(features)
    return np.array(wavelets)

def compute_temporal_encoding(length, max_freq=10000):
    """Compute temporal position encoding."""
    position = np.arange(length)
    freq_term = np.exp(-np.log(max_freq) * np.arange(64) / 64)
    pos_enc = np.zeros((length, 64))
    
    for i in range(0, 64, 2):
        pos_enc[:, i] = np.sin(position / freq_term[i])
        pos_enc[:, i+1] = np.cos(position / freq_term[i])
        
    return torch.FloatTensor(pos_enc)

class MultiScaleGATConv(nn.Module):
    """Multi-scale GAT layer with multiple window sizes."""
    def __init__(self, in_channels, out_channels, heads=8, window_sizes=[3, 5, 7]):
        super().__init__()
        self.window_sizes = window_sizes
        self.convs = nn.ModuleList([
            GATConv(in_channels, out_channels // len(window_sizes), 
                   heads=heads, dropout=config.DROPOUT)
            for _ in window_sizes
        ])
        self.attention_weights = None

    def forward(self, x, edge_index):
        outputs = []
        self.attention_weights = []
        
        for conv, window_size in zip(self.convs, self.window_sizes):
            # Create edges for current window size
            edge_list = []
            num_nodes = x.size(0)
            for i in range(num_nodes):
                for j in range(max(0, i-window_size), min(num_nodes, i+window_size+1)):
                    if i != j:
                        edge_list.append([i, j])
            
            if edge_list:
                window_edge_index = torch.tensor(edge_list, device=x.device).t()
                window_edge_index = add_self_loops(window_edge_index, num_nodes=num_nodes)[0]
                
                # Get output and attention weights
                out, (_, att_weights) = conv(x, window_edge_index, return_attention_weights=True)
                outputs.append(out)
                self.attention_weights.append((window_edge_index, att_weights))
            
        return torch.cat(outputs, dim=-1)

class EdgeFeatureGATConv(GATConv):
    """GAT layer with edge feature support."""
    def __init__(self, in_channels, out_channels, heads=8, edge_dim=1):
        super().__init__(in_channels, out_channels, heads=heads, edge_dim=edge_dim)
        
    def forward(self, x, edge_index, edge_attr=None):
        if edge_attr is None:
            # Create simple edge features based on node distance
            edge_attr = torch.abs(edge_index[0] - edge_index[1]).float().unsqueeze(-1)
        return super().forward(x, edge_index, edge_attr=edge_attr)

class DynamicTemporalGraph(nn.Module):
    """Dynamic graph construction with learnable adjacency."""
    def __init__(self, input_dim, hidden_dim):
        super().__init__()
        self.feature_transform = nn.Sequential(
            nn.Linear(input_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, hidden_dim)
        )
        
    def forward(self, x):
        # Transform features
        h = self.feature_transform(x)
        
        # Compute similarity matrix
        sim_matrix = torch.matmul(h, h.transpose(-2, -1))
        
        # Apply softmax to get attention weights
        adj_matrix = torch.softmax(sim_matrix / np.sqrt(h.size(-1)), dim=-1)
        
        # Convert to sparse format
        edge_index, edge_attr = dense_to_sparse(adj_matrix)
        
        return edge_index, edge_attr

class EnhancedGNNClassifier(nn.Module):
    def __init__(self, input_dim, hidden_dim, num_classes):
        super().__init__()
        self.input_dim = input_dim
        self.hidden_dim = hidden_dim
        
        # Temporal position encoding
        if config.USE_TEMPORAL_ENCODING:
            self.temporal_enc = nn.Parameter(torch.randn(1000, input_dim))
            input_dim = input_dim * 2
            
        # Feature transformation
        self.feature_transform = nn.Linear(input_dim, hidden_dim)
        
        # Dynamic graph construction
        self.graph_constructor = DynamicTemporalGraph(hidden_dim, hidden_dim)
        
        # Multi-scale GAT layers
        self.gat_layers = nn.ModuleList([
            MultiScaleGATConv(
                hidden_dim if i == 0 else hidden_dim * config.HEADS,
                hidden_dim,
                heads=config.HEADS
            )
            for i in range(config.NUM_GAT_LAYERS)
        ])
        
        # Batch normalization layers
        if config.USE_BATCH_NORM:
            self.batch_norms = nn.ModuleList([
                nn.BatchNorm1d(hidden_dim * config.HEADS)
                for _ in range(config.NUM_GAT_LAYERS)
            ])
        
        # Pooling layers
        if config.POOLING_STRATEGY == 'hybrid':
            self.pool_dim = hidden_dim * config.HEADS * 3
        else:
            self.pool_dim = hidden_dim * config.HEADS
            
        # Classification head
        self.classifier = nn.Sequential(
            nn.Linear(self.pool_dim, hidden_dim),
            nn.ReLU(),
            nn.Dropout(config.DROPOUT),
            nn.Linear(hidden_dim, num_classes)
        )
        
        # Store attention weights for interpretability
        self.attention_weights = []
        
    def forward(self, data):
        x, edge_index, batch = data.x, data.edge_index, data.batch
        self.attention_weights = []
        
        # Add temporal encoding if enabled
        if config.USE_TEMPORAL_ENCODING:
            seq_len = x.size(0)
            x = torch.cat([x, self.temporal_enc[:seq_len]], dim=-1)
        
        # Initial feature transformation
        x = self.feature_transform(x)
        
        # Dynamic graph construction
        if config.USE_EDGE_FEATURES:
            edge_index, edge_attr = self.graph_constructor(x)
        else:
            edge_attr = None
            
        # Apply GAT layers with residual connections
        for i, gat_layer in enumerate(self.gat_layers):
            identity = x
            x = gat_layer(x, edge_index)
            
            if config.USE_BATCH_NORM:
                x = self.batch_norms[i](x)
                
            x = torch.relu(x)
            
            if config.USE_RESIDUAL and x.size(-1) == identity.size(-1):
                x = x + identity
                
            self.attention_weights.append(gat_layer.attention_weights)
            
        # Pooling
        if config.POOLING_STRATEGY == 'hybrid':
            x_mean = global_mean_pool(x, batch)
            x_max = global_max_pool(x, batch)
            x_add = global_add_pool(x, batch)
            x = torch.cat([x_mean, x_max, x_add], dim=-1)
        elif config.POOLING_STRATEGY == 'mean':
            x = global_mean_pool(x, batch)
        elif config.POOLING_STRATEGY == 'max':
            x = global_max_pool(x, batch)
            
        # Classification
        out = self.classifier(x)
        return out
    
    def get_attention_maps(self, data):
        """Get attention maps for all layers."""
        self.eval()
        with torch.no_grad():
            _ = self.forward(data)
            
        attention_maps = []
        for layer_weights in self.attention_weights:
            layer_maps = []
            for edge_index, weights in layer_weights:
                # Convert to dense attention matrix
                num_nodes = data.x.size(0)
                dense_att = torch.zeros(num_nodes, num_nodes, device=data.x.device)
                dense_att[edge_index[0], edge_index[1]] = weights.mean(dim=1)
                layer_maps.append(dense_att.cpu().numpy())
            attention_maps.append(layer_maps)
            
        return attention_maps
    
class TrainingManager:
    def __init__(self, model, optimizer, criterion, device, config):
        self.model = model
        self.optimizer = optimizer
        self.criterion = criterion
        self.device = device
        self.config = config
        self.early_stopping = EarlyStopping(patience=config.PATIENCE)
        self.history = defaultdict(list)
        
    def train_epoch(self, train_loader):
        self.model.train()
        total_loss = 0
        predictions = []
        targets = []
        outputs = []
        
        for batch in train_loader:
            batch = batch.to(self.device)
            self.optimizer.zero_grad()
            
            output = self.model(batch)
            loss = self.criterion(output, batch.y.squeeze())
            
            loss.backward()
            torch.nn.utils.clip_grad_norm_(self.model.parameters(), 1.0)
            self.optimizer.step()
            
            total_loss += loss.item()
            predictions.extend(output.argmax(dim=1).cpu().numpy())
            targets.extend(batch.y.cpu().numpy())
            outputs.extend(torch.softmax(output, dim=1).cpu().numpy())
            
        metrics = compute_metrics(targets, predictions, outputs)
        metrics['loss'] = total_loss / len(train_loader)
        
        return metrics
    
    def evaluate(self, val_loader):
        self.model.eval()
        total_loss = 0
        predictions = []
        targets = []
        outputs = []
        attention_maps = []
        
        with torch.no_grad():
            for batch in val_loader:
                batch = batch.to(self.device)
                output = self.model(batch)
                loss = self.criterion(output, batch.y.squeeze())
                
                total_loss += loss.item()
                predictions.extend(output.argmax(dim=1).cpu().numpy())
                targets.extend(batch.y.cpu().numpy())
                outputs.extend(torch.softmax(output, dim=1).cpu().numpy())
                
                # Collect attention maps for first batch
                if len(attention_maps) == 0:
                    attention_maps = self.model.get_attention_maps(batch)
        
        metrics = compute_metrics(targets, predictions, outputs)
        metrics['loss'] = total_loss / len(val_loader)
        
        return metrics, attention_maps
    
    def train(self, train_loader, val_loader, epochs):
        best_val_auc = 0
        best_model = None
        
        for epoch in range(epochs):
            # Training phase
            train_metrics = self.train_epoch(train_loader)
            
            # Validation phase
            val_metrics, attention_maps = self.evaluate(val_loader)
            
            # Store history
            for k, v in train_metrics.items():
                self.history[f'train_{k}'].append(v)
            for k, v in val_metrics.items():
                self.history[f'val_{k}'].append(v)
            
            # Print progress
            if epoch % 10 == 0:
                print(f'Epoch {epoch}:')
                print(f"Train - Loss: {train_metrics['loss']:.4f}, "
                      f"AUC: {train_metrics['auc']:.4f}, "
                      f"Acc: {train_metrics['accuracy']:.4f}")
                print(f"Val - Loss: {val_metrics['loss']:.4f}, "
                      f"AUC: {val_metrics['auc']:.4f}, "
                      f"Acc: {val_metrics['accuracy']:.4f}")
            
            # Save best model
            if val_metrics['auc'] > best_val_auc:
                best_val_auc = val_metrics['auc']
                best_model = copy.deepcopy(self.model.state_dict())
                best_attention_maps = attention_maps
            
            # Early stopping
            self.early_stopping(val_metrics['loss'])
            if self.early_stopping.early_stop:
                print(f"Early stopping at epoch {epoch}")
                break
        
        self.model.load_state_dict(best_model)
        return best_val_auc, best_attention_maps

def compute_metrics(y_true, y_pred, y_prob):
    """Compute comprehensive set of metrics."""
    metrics = {
        'accuracy': accuracy_score(y_true, y_pred),
        'auc': roc_auc_score(y_true, y_prob[:, 1]) if y_prob.shape[1] == 2 else None,
    }
    
    # Add precision-recall metrics
    precision, recall, _ = precision_recall_curve(y_true, y_prob[:, 1])
    metrics['avg_precision'] = average_precision_score(y_true, y_prob[:, 1])
    
    # Add confusion matrix based metrics
    cm = confusion_matrix(y_true, y_pred)
    metrics['specificity'] = cm[0,0] / (cm[0,0] + cm[0,1])
    metrics['sensitivity'] = cm[1,1] / (cm[1,1] + cm[1,0])
    
    return metrics

def visualize_training_history(history, save_path):
    """Plot training history with confidence intervals."""
    plt.style.use('seaborn-paper')
    fig, axes = plt.subplots(2, 2, figsize=(15, 12))
    
    metrics = ['loss', 'accuracy', 'auc', 'avg_precision']
    titles = ['Loss', 'Accuracy', 'ROC AUC', 'Average Precision']
    
    for ax, metric, title in zip(axes.flat, metrics, titles):
        train_metric = history[f'train_{metric}']
        val_metric = history[f'val_{metric}']
        epochs = range(1, len(train_metric) + 1)
        
        ax.plot(epochs, train_metric, 'b-', label='Train')
        ax.plot(epochs, val_metric, 'r-', label='Validation')
        ax.set_title(title)
        ax.set_xlabel('Epoch')
        ax.set_ylabel(title)
        ax.legend()
        ax.grid(True, alpha=0.3)
    
    plt.tight_layout()
    plt.savefig(save_path, dpi=300, bbox_inches='tight')
    plt.close()

def visualize_attention_evolution(attention_maps, save_path):
    """Visualize how attention patterns evolve across layers."""
    plt.style.use('seaborn-paper')
    num_layers = len(attention_maps)
    
    fig, axes = plt.subplots(1, num_layers, figsize=(5*num_layers, 4))
    
    for i, layer_maps in enumerate(attention_maps):
        # Average attention weights across different scales
        avg_attention = np.mean([m for m in layer_maps], axis=0)
        
        sns.heatmap(avg_attention, ax=axes[i], cmap='YlOrRd', 
                   xticklabels=20, yticklabels=20)
        axes[i].set_title(f'Layer {i+1}')
        
    plt.suptitle('Evolution of Attention Patterns Across Layers', y=1.05)
    plt.tight_layout()
    plt.savefig(save_path, dpi=300, bbox_inches='tight')
    plt.close()

def visualize_multi_scale_attention(attention_maps, time_series, save_path):
    """Visualize attention at different scales."""
    plt.style.use('seaborn-paper')
    num_scales = len(attention_maps[0])
    
    fig = plt.figure(figsize=(15, 3*num_scales))
    gs = plt.GridSpec(num_scales, 1)
    
    for scale_idx in range(num_scales):
        ax = fig.add_subplot(gs[scale_idx])
        
        # Plot time series
        ax.plot(time_series, color='blue', alpha=0.6, label='Signal')
        
        # Plot attention weights
        attention = attention_maps[0][scale_idx]  # Use first layer
        attention_weights = attention.mean(axis=0)
        
        # Normalize attention weights for visualization
        norm_weights = (attention_weights - attention_weights.min()) / \
                      (attention_weights.max() - attention_weights.min())
                      
        for i in range(len(time_series)):
            ax.axvspan(i-0.5, i+0.5, color='red', alpha=0.1*norm_weights[i])
            
        ax.set_title(f'Scale {scale_idx + 1}')
        ax.legend()
        
    plt.tight_layout()
    plt.savefig(save_path, dpi=300, bbox_inches='tight')
    plt.close()

def plot_advanced_roc_curve(y_true, y_scores, fold_idx, save_path):
    """Plot ROC curve with confidence intervals."""
    plt.style.use('seaborn-paper')
    fig, ax = plt.subplots(figsize=(8, 8))
    
    # Calculate ROC curve and AUC for each fold
    fprs, tprs, roc_aucs = [], [], []
    for i in range(len(fold_idx)):
        fold_true = y_true[fold_idx[i]]
        fold_score = y_scores[fold_idx[i]]
        
        fpr, tpr, _ = roc_curve(fold_true, fold_score[:, 1])
        roc_auc = auc(fpr, tpr)
        
        fprs.append(fpr)
        tprs.append(tpr)
        roc_aucs.append(roc_auc)
    
    # Calculate mean ROC curve and AUC
    mean_fpr = np.linspace(0, 1, 100)
    mean_tpr = np.zeros_like(mean_fpr)
    for fpr, tpr in zip(fprs, tprs):
        mean_tpr += np.interp(mean_fpr, fpr, tpr)
    mean_tpr /= len(fprs)
    
    mean_auc = np.mean(roc_aucs)
    std_auc = np.std(roc_aucs)
    
    # Plot mean ROC curve
    ax.plot(mean_fpr, mean_tpr, color='b',
            label=f'Mean ROC (AUC = {mean_auc:.3f} ± {std_auc:.3f})',
            lw=2, alpha=.8)
    
    # Plot confidence interval
    tprs_upper = np.minimum(mean_tpr + np.std(tprs, axis=0), 1)
    tprs_lower = np.maximum(mean_tpr - np.std(tprs, axis=0), 0)
    ax.fill_between(mean_fpr, tprs_lower, tprs_upper, color='grey', alpha=.2,
                    label=r'$\pm$ 1 std. dev.')
    
    # Plot random chance line
    ax.plot([0, 1], [0, 1], linestyle='--', lw=2, color='r',
            label='Random Chance', alpha=.8)
    
    ax.set_xlim([-0.01, 1.01])
    ax.set_ylim([-0.01, 1.01])
    ax.set_xlabel('False Positive Rate', fontsize=12)
    ax.set_ylabel('True Positive Rate', fontsize=12)
    ax.set_title('Receiver Operating Characteristic (ROC) Curve', fontsize=14)
    ax.legend(loc="lower right", fontsize=10)
    ax.grid(True, alpha=0.3)
    
    plt.tight_layout()
    plt.savefig(save_path, dpi=300, bbox_inches='tight')
    plt.close()

def generate_results_table(results_dict, dataset_name, save_path):
    """Generate comprehensive results table with confidence intervals."""
    metrics = ['accuracy', 'auc', 'specificity', 'sensitivity', 'avg_precision']
    table_data = []
    
    for metric in metrics:
        values = results_dict[metric]
        mean = np.mean(values)
        ci = stats.t.interval(0.95, len(values)-1, 
                            loc=mean, 
                            scale=stats.sem(values))
        
        table_data.append({
            'Metric': metric.capitalize(),
            'Mean': f"{mean:.3f}",
            '95% CI': f"({ci[0]:.3f}, {ci[1]:.3f})"
        })
    
    df = pd.DataFrame(table_data)
    df.to_csv(save_path, index=False)
    return df

def run_experiment(dataset_name, config):
    """Run complete experiment pipeline with enhanced evaluation."""
    print(f"\n=== Dataset: {dataset_name} ===")
    
    # Load and analyze dataset
    X_all, y_all, unique_labels = load_dataset(dataset_name)
    dataset_stats = analyze_dataset(X_all, y_all, dataset_name)
    print("\nDataset Statistics:")
    print(dataset_stats)
    
    # Convert to graphs with enhanced features
    graphs = timeseries_to_graph(X_all)
    labels = torch.tensor(y_all, dtype=torch.long)
    
    # Initialize results storage
    results = defaultdict(list)
    fold_predictions = []
    fold_indices = []
    attention_maps_collection = []
    
    # Cross-validation
    skf = StratifiedKFold(n_splits=config.N_SPLITS, shuffle=True, random_state=42)
    
    for fold, (train_idx, test_idx) in enumerate(skf.split(graphs, labels), 1):
        print(f"\nFold {fold}/{config.N_SPLITS}")
        
        # Split data
        train_graphs = [graphs[i] for i in train_idx]
        test_graphs = [graphs[i] for i in test_idx]
        train_labels = labels[train_idx]
        test_labels = labels[test_idx]
        
        # Create dataloaders
        train_loader = DataLoader(train_graphs, batch_size=config.BATCH_SIZE, shuffle=True)
        test_loader = DataLoader(test_graphs, batch_size=config.BATCH_SIZE)
        
        # Initialize model and training components
        model = EnhancedGNNClassifier(
            input_dim=2, 
            hidden_dim=config.HIDDEN_DIM,
            num_classes=len(unique_labels)
        ).to(device)
        
        optimizer = optim.Adam(
            model.parameters(), 
            lr=config.LR,
            weight_decay=config.WEIGHT_DECAY
        )
        
        criterion = nn.CrossEntropyLoss()
        trainer = TrainingManager(model, optimizer, criterion, device, config)
        
        # Train model
        best_val_auc, attention_maps = trainer.train(
            train_loader, 
            test_loader,
            config.MAX_EPOCHS
        )
        
        # Store results
        for k, v in trainer.history.items():
            if k.startswith('val_'):
                metric_name = k.split('_')[1]
                results[metric_name].append(v[-1])
        
        attention_maps_collection.append(attention_maps)
        fold_indices.append(test_idx)
        
        # Create visualizations for this fold
        visualize_attention_evolution(
            attention_maps,
            os.path.join(config.PLOT_DIR, f'{dataset_name}_fold{fold}_attention_evolution.png')
        )
        
        if fold == 1:  # Save training history for first fold
            visualize_training_history(
                trainer.history,
                os.path.join(config.PLOT_DIR, f'{dataset_name}_training_history.png')
            )
    
    # Generate final results
    results_df = generate_results_table(
        results,
        dataset_name,
        os.path.join(config.RESULTS_DIR, f'{dataset_name}_results.csv')
    )
    
    print("\nFinal Results:")
    print(results_df)
    
    return results, attention_maps_collection

if __name__ == "__main__":
    print(f"Using device: {device}")
    
    for dataset_name in ["ECG200", "ECGFiveDays", "TwoLeadECG"]:
        try:
            results, attention_maps = run_experiment(dataset_name, config)
        except Exception as e:
            print(f"Error processing dataset {dataset_name}: {str(e)}")
            continue
    
    print("\nAll experiments completed. Results saved in:", config.RESULTS_DIR)

In [1]:
import os
import numpy as np
import torch
import torch.nn as nn
import copy
import matplotlib
matplotlib.use('Agg')  # For headless environments
import matplotlib.pyplot as plt
import seaborn as sns
from sklearn.model_selection import StratifiedKFold
from sklearn.metrics import accuracy_score, confusion_matrix, classification_report, roc_curve, auc
from torch_geometric.data import Data
from torch_geometric.loader import DataLoader
from torch_geometric.nn import GATConv, global_mean_pool
from torch_geometric.utils import add_self_loops
import torch.optim as optim
from torch.utils.data import TensorDataset, DataLoader as TorchDataLoader
from tslearn.datasets import UCR_UEA_datasets
from scipy.stats import bootstrap
import warnings
warnings.filterwarnings('ignore')

# Configuration
RANDOM_SEED = 42
np.random.seed(RANDOM_SEED)
torch.manual_seed(RANDOM_SEED)
if torch.cuda.is_available():
    torch.cuda.manual_seed(RANDOM_SEED)

MAX_EPOCHS = 100
BATCH_SIZE = 32
LR = 0.001
N_SPLITS = 5
HIDDEN_DIM = 64
HEADS = 8
DROPOUT = 0.2
PLOT_DIR = "./plots"
os.makedirs(PLOT_DIR, exist_ok=True)
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

DATASETS = ["ECG200", "ECGFiveDays", "TwoLeadECG"]  # Example datasets

class EarlyStopping:
    def __init__(self, patience=10, min_delta=0):
        self.patience = patience
        self.min_delta = min_delta
        self.counter = 0
        self.best_loss = None
        self.early_stop = False

    def __call__(self, val_loss):
        if self.best_loss is None:
            self.best_loss = val_loss
        elif val_loss > self.best_loss - self.min_delta:
            self.counter += 1
            if self.counter >= self.patience:
                self.early_stop = True
        else:
            self.best_loss = val_loss
            self.counter = 0

def load_dataset(name):
    """Load UCR/UEA dataset, normalize, and return all data and labels."""
    ucr = UCR_UEA_datasets()
    X_train, y_train, X_test, y_test = ucr.load_dataset(name)
    
    # Squeeze unnecessary dimensions for univariate TS
    if X_train.ndim == 3 and X_train.shape[1] == 1:
        X_train = X_train.squeeze(1)
    if X_test.ndim == 3 and X_test.shape[1] == 1:
        X_test = X_test.squeeze(1)
    
    unique_labels = np.unique(np.concatenate([y_train, y_test]))
    label_to_idx = {lab: i for i, lab in enumerate(unique_labels)}
    y_train = np.array([label_to_idx[lab] for lab in y_train], dtype=np.int64)
    y_test = np.array([label_to_idx[lab] for lab in y_test], dtype=np.int64)
    
    # Normalize data
    X_mean = X_train.mean()
    X_std = X_train.std()
    X_train = (X_train - X_mean) / (X_std + 1e-8)
    X_test = (X_test - X_mean) / (X_std + 1e-8)
    
    X_all = np.concatenate([X_train, X_test], axis=0)
    y_all = np.concatenate([y_train, y_test], axis=0)
    
    return X_all, y_all, unique_labels, (X_train, y_train, X_test, y_test)

def timeseries_to_graph(X, window=5):
    """Convert time series to graph data by connecting each node to neighbors within a given window."""
    graphs = []
    for i in range(X.shape[0]):
        x_val = X[i]
        length = x_val.shape[0]
        positions = np.linspace(0, 1, length)
        x_feat = np.column_stack([x_val, positions])
        x_feat = torch.tensor(x_feat, dtype=torch.float32)
        
        edge_list = []
        w = min(window, length)
        for j in range(length):
            for k in range(max(0, j-w), min(length, j+w+1)):
                if j != k:
                    edge_list.append([j, k])
        
        if edge_list:
            edge_index = torch.tensor(edge_list, dtype=torch.long).t()
        else:
            edge_index = torch.empty((2, 0), dtype=torch.long)
        
        edge_index, _ = add_self_loops(edge_index, num_nodes=length)
        data = Data(x=x_feat, edge_index=edge_index)
        graphs.append(data)
    return graphs

class GATConvWithAlpha(GATConv):
    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.last_alpha = None
        self._cached_edge_index = None

    def forward(self, x, edge_index, return_attention_weights=False):
        out, (edge_index_out, alpha) = super().forward(x, edge_index, return_attention_weights=True)
        self.last_alpha = alpha
        self._cached_edge_index = edge_index_out
        return out

class GNNTimeSeriesClassifier(nn.Module):
    """GAT-based classifier for time series represented as graphs."""
    def __init__(self, input_dim=2, hidden_dim=64, num_classes=2, heads=8, dropout=0.2):
        super().__init__()
        self.input_proj = nn.Linear(input_dim, hidden_dim)
        self.conv1 = GATConvWithAlpha(hidden_dim, hidden_dim, heads=heads, dropout=dropout)
        self.conv2 = GATConvWithAlpha(hidden_dim*heads, hidden_dim, heads=heads, dropout=dropout, concat=False)
        
        self.ln1 = nn.LayerNorm(hidden_dim * heads)
        self.ln2 = nn.LayerNorm(hidden_dim)
        
        self.fc1 = nn.Linear(hidden_dim, hidden_dim)
        self.fc2 = nn.Linear(hidden_dim, num_classes)
        self.dropout = nn.Dropout(dropout)

    def forward(self, data):
        x, edge_index, batch = data.x, data.edge_index, data.batch
        x = self.input_proj(x)
        x = torch.relu(x)
        
        x1 = self.conv1(x, edge_index)
        x1 = self.ln1(x1)
        x1 = torch.relu(x1)
        x1 = self.dropout(x1)
        
        x2 = self.conv2(x1, edge_index)
        x2 = self.ln2(x2)
        x2 = torch.relu(x2)
        
        x_pool = global_mean_pool(x2, batch)
        x_pool = self.fc1(x_pool)
        x_pool = torch.relu(x_pool)
        x_pool = self.dropout(x_pool)
        out = self.fc2(x_pool)
        
        return out

    def get_attention_weights(self, data):
        """Obtain normalized node-level attention weights."""
        self.eval()
        with torch.no_grad():
            _ = self.forward(data)
            alpha1 = self.conv1.last_alpha
            edge_idx = self.conv1._cached_edge_index

        num_nodes = data.x.size(0)
        node_importance = torch.zeros(num_nodes, device=data.x.device)
        
        alpha_mean = alpha1.mean(dim=1)
        for i in range(edge_idx.size(1)):
            dst_node = edge_idx[1, i].item()
            node_importance[dst_node] += alpha_mean[i].item()
        
        node_importance = node_importance / (node_importance.sum() + 1e-9)
        return node_importance.cpu().numpy()

def train_epoch(model, loader, criterion, optimizer, device):
    model.train()
    total_loss = 0
    predictions = []
    targets = []
    
    for batch in loader:
        batch = batch.to(device)
        optimizer.zero_grad()
        
        output = model(batch)
        loss = criterion(output, batch.y.squeeze())
        
        loss.backward()
        optimizer.step()
        
        total_loss += loss.item()
        predictions.extend(output.argmax(dim=1).cpu().numpy())
        targets.extend(batch.y.cpu().numpy())
    
    return total_loss / len(loader), accuracy_score(targets, predictions)

def evaluate(model, loader, criterion, device):
    model.eval()
    total_loss = 0
    predictions = []
    targets = []
    outputs = []
    
    with torch.no_grad():
        for batch in loader:
            batch = batch.to(device)
            output = model(batch)
            loss = criterion(output, batch.y.squeeze())
            
            total_loss += loss.item()
            predictions.extend(output.argmax(dim=1).cpu().numpy())
            targets.extend(batch.y.cpu().numpy())
            outputs.extend(torch.softmax(output, dim=1).cpu().numpy())
    
    return (
        total_loss / len(loader),
        accuracy_score(targets, predictions),
        np.array(predictions),
        np.array(targets),
        np.array(outputs)
    )

def train_model(model, train_loader, val_loader, criterion, optimizer, device, max_epochs=100):
    early_stopping = EarlyStopping(patience=10)
    best_val_acc = 0
    best_model = None
    
    for epoch in range(max_epochs):
        train_loss, train_acc = train_epoch(model, train_loader, criterion, optimizer, device)
        val_loss, val_acc, _, _, _ = evaluate(model, val_loader, criterion, device)
        
        if val_acc > best_val_acc:
            best_val_acc = val_acc
            best_model = copy.deepcopy(model.state_dict())
        
        # Optional: print every 10 epochs
        if epoch % 10 == 0:
            print(f'Epoch {epoch}: Train Loss = {train_loss:.4f}, Train Acc = {train_acc:.4f}, '
                  f'Val Loss = {val_loss:.4f}, Val Acc = {val_acc:.4f}')
        
        early_stopping(val_loss)
        if early_stopping.early_stop:
            print(f"Early stopping at epoch {epoch}")
            break
    
    model.load_state_dict(best_model)
    return model, best_val_acc

def bootstrap_confidence_interval(data, stat_func=np.mean, confidence_level=0.95, n_resamples=10000):
    """Compute bootstrap confidence interval for a given statistic."""
    res = bootstrap((data,), stat_func, confidence_level=confidence_level, n_resamples=n_resamples, method='basic')
    return stat_func(data), res.confidence_interval

def plot_confusion_matrix(cm, classes, title, save_path):
    plt.figure(figsize=(8, 6))
    sns.heatmap(cm, annot=True, fmt='d', cmap='Blues', xticklabels=classes, yticklabels=classes)
    plt.title(title, fontsize=16)
    plt.ylabel('True Label', fontsize=14)
    plt.xlabel('Predicted Label', fontsize=14)
    plt.tight_layout()
    plt.savefig(save_path, format='pdf')
    plt.close()

def plot_roc_with_ci(fprs, tprs, title, save_path):
    """Plot mean ROC curve and 95% CI band from multiple folds."""
    # Interpolate TPRs at common FPR points
    mean_fpr = np.linspace(0, 1, 100)
    interp_tprs = []
    for fpr, tpr in zip(fprs, tprs):
        interp_tpr = np.interp(mean_fpr, fpr, tpr)
        interp_tprs.append(interp_tpr)
    interp_tprs = np.array(interp_tprs)
    
    mean_tpr = interp_tprs.mean(axis=0)
    std_tpr = interp_tprs.std(axis=0)
    tpr_upper = np.minimum(mean_tpr + 1.96 * std_tpr, 1)
    tpr_lower = np.maximum(mean_tpr - 1.96 * std_tpr, 0)
    mean_auc = auc(mean_fpr, mean_tpr)
    
    plt.figure(figsize=(8,6))
    plt.plot(mean_fpr, mean_tpr, color='darkorange', lw=2, label=f'Mean ROC (AUC = {mean_auc:.2f})')
    plt.fill_between(mean_fpr, tpr_lower, tpr_upper, color='grey', alpha=0.2, label='95% CI')
    plt.plot([0,1],[0,1], color='navy', lw=2, linestyle='--')
    plt.xlim([0,1])
    plt.ylim([0,1.05])
    plt.xlabel('False Positive Rate', fontsize=14)
    plt.ylabel('True Positive Rate', fontsize=14)
    plt.title(title, fontsize=16)
    plt.legend(loc='lower right')
    plt.tight_layout()
    plt.savefig(save_path, format='pdf')
    plt.close()

def visualize_attention(series, attention_weights, title, save_path):
    """Visualize time series with attention weights as a color overlay."""
    plt.figure(figsize=(12, 6))
    plt.plot(series, label='Time Series', color='blue')
    
    # Normalize attention weights for visualization
    att_norm = (attention_weights - attention_weights.min()) / (attention_weights.max() - attention_weights.min() + 1e-9)

    # Create a colormap (e.g., Reds)
    cmap = plt.cm.Reds
    for i in range(len(series)):
        color = cmap(att_norm[i])
        plt.axvspan(i-0.5, i+0.5, color=color, alpha=0.3)
    
    plt.title(title, fontsize=16)
    plt.xlabel('Time Step', fontsize=14)
    plt.ylabel('Normalized Value', fontsize=14)
    plt.legend()
    plt.tight_layout()
    plt.savefig(save_path, format='pdf')
    plt.close()

def plot_example_timeseries(X, y, unique_labels, dataset_name, save_path):
    """Plot a few example time series from each class for qualitative inspection."""
    plt.figure(figsize=(12, 8))
    num_classes = len(unique_labels)
    examples_per_class = min(3, len(X)//num_classes)
    
    for i, cls in enumerate(unique_labels):
        class_indices = np.where(y == i)[0]
        chosen = np.random.choice(class_indices, size=examples_per_class, replace=False)
        for j, idx in enumerate(chosen):
            plt.subplot(num_classes, examples_per_class, i*examples_per_class + j + 1)
            plt.plot(X[idx], color='blue')
            plt.title(f'Class: {cls}', fontsize=12)
            plt.xticks([])
            plt.yticks([])
    plt.suptitle(f'{dataset_name} - Example Time Series', fontsize=16)
    plt.tight_layout(rect=[0,0,1,0.95])
    plt.savefig(save_path, format='pdf')
    plt.close()

def run_experiment(dataset_name):
    print(f"\n=== Dataset: {dataset_name} ===")
    
    # Load and preprocess data
    X_all, y_all, unique_labels, (X_train_orig, y_train_orig, X_test_orig, y_test_orig) = load_dataset(dataset_name)
    graphs = timeseries_to_graph(X_all)
    labels = torch.tensor(y_all, dtype=torch.long)
    num_classes = len(unique_labels)

    # Print dataset details
    print(f"Number of classes: {num_classes}")
    print(f"Total samples: {len(X_all)} (Train: {len(X_train_orig)}, Test: {len(X_test_orig)})")
    class_counts = np.bincount(y_all)
    for i, lab in enumerate(unique_labels):
        print(f"Class {lab}: {class_counts[i]} samples")

    # Plot example time series
    plot_example_timeseries(X_all, y_all, unique_labels, dataset_name, 
                            save_path=os.path.join(PLOT_DIR, f"{dataset_name}_example_ts.pdf"))

    all_val_accs = []
    all_test_accs = []
    all_test_preds = []
    all_test_targets = []
    all_test_probs = []
    fprs_list = []
    tprs_list = []

    skf = StratifiedKFold(n_splits=N_SPLITS, shuffle=True, random_state=RANDOM_SEED)
    for fold, (train_idx, test_idx) in enumerate(skf.split(graphs, labels), 1):
        print(f"\nFold {fold}/{N_SPLITS}")
        
        val_size = int(0.2 * len(train_idx))
        train_indices = train_idx[:-val_size]
        val_indices = train_idx[-val_size:]
        
        # Create datasets
        train_dataset = []
        for i in train_indices:
            graph = graphs[i]
            graph.y = labels[i]
            train_dataset.append(graph)
            
        val_dataset = []
        for i in val_indices:
            graph = graphs[i]
            graph.y = labels[i]
            val_dataset.append(graph)
            
        test_dataset = []
        for i in test_idx:
            graph = graphs[i]
            graph.y = labels[i]
            test_dataset.append(graph)
        
        train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True)
        val_loader = DataLoader(val_dataset, batch_size=BATCH_SIZE, shuffle=False)
        test_loader = DataLoader(test_dataset, batch_size=BATCH_SIZE, shuffle=False)
        
        model = GNNTimeSeriesClassifier(
            input_dim=2,
            hidden_dim=HIDDEN_DIM,
            num_classes=num_classes,
            heads=HEADS,
            dropout=DROPOUT
        ).to(device)
        
        optimizer = optim.Adam(model.parameters(), lr=LR, weight_decay=1e-4)
        
        # Class weighting
        train_labels_np = labels[train_indices].numpy()
        class_counts_train = np.bincount(train_labels_np)
        class_weights = torch.FloatTensor(1.0 / class_counts_train).to(device)
        criterion = nn.CrossEntropyLoss(weight=class_weights)
        
        model, val_acc = train_model(
            model=model,
            train_loader=train_loader,
            val_loader=val_loader,
            criterion=criterion,
            optimizer=optimizer,
            device=device,
            max_epochs=MAX_EPOCHS
        )
        
        test_loss, test_acc, predictions, targets, probabilities = evaluate(
            model=model,
            loader=test_loader,
            criterion=criterion,
            device=device
        )
        
        # Compute AUC if binary classification
        if num_classes == 2:
            fpr, tpr, _ = roc_curve(targets, probabilities[:, 1])
            fprs_list.append(fpr)
            tprs_list.append(tpr)
        
        all_val_accs.append(val_acc)
        all_test_accs.append(test_acc)
        all_test_preds.extend(predictions)
        all_test_targets.extend(targets)
        all_test_probs.extend(probabilities)
        
        print(f"Fold {fold} - Test Accuracy: {test_acc:.4f}")
        
        # Visualize attention weights for first test sample (if available)
        if len(test_dataset) > 0:
            example_data = test_dataset[0].to(device)
            attention_weights = model.get_attention_weights(example_data)
            example_series = example_data.x[:, 0].cpu().numpy()
            
            visualize_attention(
                series=example_series,
                attention_weights=attention_weights,
                title=f"{dataset_name} - Fold {fold} Attention Visualization",
                save_path=os.path.join(PLOT_DIR, f"{dataset_name}_fold{fold}_attention.pdf")
            )
    
    all_test_preds = np.array(all_test_preds)
    all_test_targets = np.array(all_test_targets)
    all_test_probs = np.array(all_test_probs)

    # Compute bootstrap CIs for accuracy and AUC
    mean_acc, ci_acc = bootstrap_confidence_interval(all_test_accs)
    print("\nFinal Results:")
    print(f"Validation Accuracy (mean ± std): {np.mean(all_val_accs):.4f} ± {np.std(all_val_accs):.4f}")
    print(f"Test Accuracy: {mean_acc:.4f} (95% CI: {ci_acc.low:.4f}-{ci_acc.high:.4f})")

    if num_classes == 2:
        # Compute AUC across all folds (using combined predictions)
        combined_fpr, combined_tpr, _ = roc_curve(all_test_targets, all_test_probs[:,1])
        combined_auc = auc(combined_fpr, combined_tpr)
        
        # Bootstrap AUC CI
        def auc_stat(data):
            # data: (targets, probs)
            targets_ = data[0].astype(int)
            probs_ = data[1]
            fpr_, tpr_, _ = roc_curve(targets_, probs_)
            return auc(fpr_, tpr_)
        
        auc_samples = []
        rng = np.random.default_rng(RANDOM_SEED)
        for _ in range(10000):
            idx = rng.integers(0, len(all_test_targets), len(all_test_targets))
            auc_samples.append(auc_stat((all_test_targets[idx], all_test_probs[idx,1])))
        auc_samples = np.array(auc_samples)
        mean_auc_ = np.mean(auc_samples)
        low_auc_ = np.percentile(auc_samples, 2.5)
        high_auc_ = np.percentile(auc_samples, 97.5)
        
        print(f"Test AUC: {mean_auc_:.4f} (95% CI: {low_auc_:.4f}-{high_auc_:.4f})")
    
    # Plot confusion matrix
    cm = confusion_matrix(all_test_targets, all_test_preds)
    plot_confusion_matrix(
        cm=cm,
        classes=[str(l) for l in unique_labels],
        title=f"{dataset_name} Confusion Matrix",
        save_path=os.path.join(PLOT_DIR, f"{dataset_name}_confusion_matrix.pdf")
    )
    
    # Plot aggregated ROC curve with CIs (for binary classification)
    if num_classes == 2 and len(fprs_list) == N_SPLITS:
        plot_roc_with_ci(
            fprs=fprs_list,
            tprs=tprs_list,
            title=f"{dataset_name} ROC Curve (with 95% CI)",
            save_path=os.path.join(PLOT_DIR, f"{dataset_name}_roc_curve.pdf")
        )
    
    # Print classification report
    print("\nClassification Report:")
    print(classification_report(all_test_targets, all_test_preds, target_names=[str(l) for l in unique_labels]))

if __name__ == "__main__":
    print(f"Using device: {device}")
    for dataset_name in DATASETS:
        try:
            run_experiment(dataset_name)
        except Exception as e:
            print(f"Error processing dataset {dataset_name}: {str(e)}")
            continue
    
    print("\nAll experiments completed. Results saved in:", PLOT_DIR)


  from .autonotebook import tqdm as notebook_tqdm


Using device: cpu

=== Dataset: ECG200 ===
Number of classes: 2
Total samples: 200 (Train: 100, Test: 100)
Class -1: 67 samples
Class 1: 133 samples

Fold 1/5
Epoch 0: Train Loss = 0.7026, Train Acc = 0.4609, Val Loss = 0.6961, Val Acc = 0.6562
Epoch 10: Train Loss = 0.6858, Train Acc = 0.5703, Val Loss = 0.6897, Val Acc = 0.3438
Epoch 20: Train Loss = 0.6707, Train Acc = 0.6875, Val Loss = 0.6589, Val Acc = 0.4062
Epoch 30: Train Loss = 0.4710, Train Acc = 0.7969, Val Loss = 0.4841, Val Acc = 0.7500
Epoch 40: Train Loss = 0.3532, Train Acc = 0.8438, Val Loss = 0.4912, Val Acc = 0.7500
Epoch 50: Train Loss = 0.3623, Train Acc = 0.8516, Val Loss = 0.4403, Val Acc = 0.8125
Early stopping at epoch 52
Fold 1 - Test Accuracy: 0.9000

Fold 2/5
Epoch 0: Train Loss = 0.7153, Train Acc = 0.4766, Val Loss = 0.7008, Val Acc = 0.6562
Epoch 10: Train Loss = 0.6914, Train Acc = 0.5078, Val Loss = 0.6899, Val Acc = 0.3438
Epoch 20: Train Loss = 0.6463, Train Acc = 0.7188, Val Loss = 0.6127, Val Acc =

In [3]:
import os
import copy
import warnings
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
import matplotlib
matplotlib.use('Agg')  # For headless environments
import matplotlib.pyplot as plt
import seaborn as sns
from scipy.stats import bootstrap
from sklearn.model_selection import StratifiedKFold
from sklearn.metrics import (accuracy_score, confusion_matrix, classification_report, roc_curve, auc,
                             balanced_accuracy_score, matthews_corrcoef)
from torch_geometric.data import Data
from torch_geometric.loader import DataLoader
from torch_geometric.nn import GATConv, global_mean_pool
from torch_geometric.utils import add_self_loops
from tslearn.datasets import UCR_UEA_datasets
from captum.attr import Saliency

warnings.filterwarnings('ignore')

# Configuration
RANDOM_SEED = 42
np.random.seed(RANDOM_SEED)
torch.manual_seed(RANDOM_SEED)
if torch.cuda.is_available():
    torch.cuda.manual_seed(RANDOM_SEED)

MAX_EPOCHS = 100
BATCH_SIZE = 32
LR = 0.001
N_SPLITS = 5
HIDDEN_DIM = 64
HEADS = 8
DROPOUT = 0.2
PLOT_DIR = "./plots"
os.makedirs(PLOT_DIR, exist_ok=True)
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

DATASETS = ["ECG200", "ECGFiveDays", "TwoLeadECG"]  # Example datasets

class EarlyStopping:
    def __init__(self, patience=10, min_delta=0):
        self.patience = patience
        self.min_delta = min_delta
        self.counter = 0
        self.best_loss = None
        self.early_stop = False

    def __call__(self, val_loss):
        if self.best_loss is None:
            self.best_loss = val_loss
        elif val_loss > self.best_loss - self.min_delta:
            self.counter += 1
            if self.counter >= self.patience:
                self.early_stop = True
        else:
            self.best_loss = val_loss
            self.counter = 0

def load_dataset(name):
    """Load UCR/UEA dataset, normalize, and return all data and labels."""
    ucr = UCR_UEA_datasets()
    X_train, y_train, X_test, y_test = ucr.load_dataset(name)
    
    # Squeeze unnecessary dimensions for univariate TS
    if X_train.ndim == 3 and X_train.shape[1] == 1:
        X_train = X_train.squeeze(1)
    if X_test.ndim == 3 and X_test.shape[1] == 1:
        X_test = X_test.squeeze(1)
    
    unique_labels = np.unique(np.concatenate([y_train, y_test]))
    label_to_idx = {lab: i for i, lab in enumerate(unique_labels)}
    y_train = np.array([label_to_idx[lab] for lab in y_train], dtype=np.int64)
    y_test = np.array([label_to_idx[lab] for lab in y_test], dtype=np.int64)
    
    # Normalize data
    X_mean = X_train.mean()
    X_std = X_train.std()
    X_train = (X_train - X_mean) / (X_std + 1e-8)
    X_test = (X_test - X_mean) / (X_std + 1e-8)
    
    X_all = np.concatenate([X_train, X_test], axis=0)
    y_all = np.concatenate([y_train, y_test], axis=0)
    
    return X_all, y_all, unique_labels, (X_train, y_train, X_test, y_test)

def timeseries_to_graph(X, window=5):
    """Convert time series to graph data by connecting each node to neighbors within a given window."""
    graphs = []
    for i in range(X.shape[0]):
        x_val = X[i]
        length = x_val.shape[0]
        positions = np.linspace(0, 1, length)
        x_feat = np.column_stack([x_val, positions])
        x_feat = torch.tensor(x_feat, dtype=torch.float32)
        
        edge_list = []
        w = min(window, length)
        for j in range(length):
            for k in range(max(0, j-w), min(length, j+w+1)):
                if j != k:
                    edge_list.append([j, k])
        
        if edge_list:
            edge_index = torch.tensor(edge_list, dtype=torch.long).t()
        else:
            edge_index = torch.empty((2, 0), dtype=torch.long)
        
        edge_index, _ = add_self_loops(edge_index, num_nodes=length)
        data = Data(x=x_feat, edge_index=edge_index)
        graphs.append(data)
    return graphs

class GATConvWithAlpha(GATConv):
    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.last_alpha = None
        self._cached_edge_index = None

    def forward(self, x, edge_index, return_attention_weights=False):
        out, (edge_index_out, alpha) = super().forward(x, edge_index, return_attention_weights=True)
        self.last_alpha = alpha
        self._cached_edge_index = edge_index_out
        return out

class GNNTimeSeriesClassifier(nn.Module):
    """GAT-based classifier for time series represented as graphs."""
    def __init__(self, input_dim=2, hidden_dim=64, num_classes=2, heads=8, dropout=0.2):
        super().__init__()
        self.input_proj = nn.Linear(input_dim, hidden_dim)
        self.conv1 = GATConvWithAlpha(hidden_dim, hidden_dim, heads=heads, dropout=dropout)
        self.conv2 = GATConvWithAlpha(hidden_dim*heads, hidden_dim, heads=heads, dropout=dropout, concat=False)
        
        self.ln1 = nn.LayerNorm(hidden_dim * heads)
        self.ln2 = nn.LayerNorm(hidden_dim)
        
        self.fc1 = nn.Linear(hidden_dim, hidden_dim)
        self.fc2 = nn.Linear(hidden_dim, num_classes)
        self.dropout = nn.Dropout(dropout)

    def forward(self, data):
        x, edge_index, batch = data.x, data.edge_index, data.batch
        x = self.input_proj(x)
        x = torch.relu(x)
        
        x1 = self.conv1(x, edge_index)
        x1 = self.ln1(x1)
        x1 = torch.relu(x1)
        x1 = self.dropout(x1)
        
        x2 = self.conv2(x1, edge_index)
        x2 = self.ln2(x2)
        x2 = torch.relu(x2)
        
        x_pool = global_mean_pool(x2, batch)
        x_pool = self.fc1(x_pool)
        x_pool = torch.relu(x_pool)
        x_pool = self.dropout(x_pool)
        out = self.fc2(x_pool)
        
        return out

    def get_attention_weights(self, data):
        """Obtain normalized node-level attention weights from the first GAT layer."""
        self.eval()
        with torch.no_grad():
            _ = self.forward(data)
            alpha1 = self.conv1.last_alpha
            edge_idx = self.conv1._cached_edge_index

        num_nodes = data.x.size(0)
        node_importance = torch.zeros(num_nodes, device=data.x.device)
        
        alpha_mean = alpha1.mean(dim=1)
        for i in range(edge_idx.size(1)):
            dst_node = edge_idx[1, i].item()
            node_importance[dst_node] += alpha_mean[i].item()
        
        node_importance = node_importance / (node_importance.sum() + 1e-9)
        return node_importance.cpu().numpy()

def train_epoch(model, loader, criterion, optimizer, device):
    model.train()
    total_loss = 0
    predictions = []
    targets = []
    
    for batch in loader:
        batch = batch.to(device)
        optimizer.zero_grad()
        
        output = model(batch)
        loss = criterion(output, batch.y.squeeze())
        
        loss.backward()
        optimizer.step()
        
        total_loss += loss.item()
        predictions.extend(output.argmax(dim=1).cpu().numpy())
        targets.extend(batch.y.cpu().numpy())
    
    return total_loss / len(loader), accuracy_score(targets, predictions)

def evaluate(model, loader, criterion, device):
    model.eval()
    total_loss = 0
    predictions = []
    targets = []
    outputs = []
    
    with torch.no_grad():
        for batch in loader:
            batch = batch.to(device)
            output = model(batch)
            loss = criterion(output, batch.y.squeeze())
            
            total_loss += loss.item()
            predictions.extend(output.argmax(dim=1).cpu().numpy())
            targets.extend(batch.y.cpu().numpy())
            outputs.extend(torch.softmax(output, dim=1).cpu().numpy())
    
    return (
        total_loss / len(loader),
        accuracy_score(targets, predictions),
        np.array(predictions),
        np.array(targets),
        np.array(outputs)
    )

def train_model(model, train_loader, val_loader, criterion, optimizer, device, max_epochs=100):
    early_stopping = EarlyStopping(patience=10)
    best_val_acc = 0
    best_model = None
    
    for epoch in range(max_epochs):
        train_loss, train_acc = train_epoch(model, train_loader, criterion, optimizer, device)
        val_loss, val_acc, _, _, _ = evaluate(model, val_loader, criterion, device)
        
        if val_acc > best_val_acc:
            best_val_acc = val_acc
            best_model = copy.deepcopy(model.state_dict())
        
        if epoch % 10 == 0:
            print(f'Epoch {epoch}: Train Loss = {train_loss:.4f}, Train Acc = {train_acc:.4f}, '
                  f'Val Loss = {val_loss:.4f}, Val Acc = {val_acc:.4f}')
        
        early_stopping(val_loss)
        if early_stopping.early_stop:
            print(f"Early stopping at epoch {epoch}")
            break
    
    model.load_state_dict(best_model)
    return model, best_val_acc

def bootstrap_confidence_interval(data, stat_func=np.mean, confidence_level=0.95, n_resamples=10000):
    """Compute bootstrap confidence interval for a given statistic."""
    res = bootstrap((data,), stat_func, confidence_level=confidence_level, n_resamples=n_resamples, method='basic')
    return stat_func(data), res.confidence_interval

def plot_confusion_matrix(cm, classes, title, save_path):
    plt.figure(figsize=(8, 6))
    sns.heatmap(cm, annot=True, fmt='d', cmap='Blues', xticklabels=classes, yticklabels=classes)
    plt.title(title, fontsize=16)
    plt.ylabel('True Label', fontsize=14)
    plt.xlabel('Predicted Label', fontsize=14)
    plt.tight_layout()
    plt.savefig(save_path, format='pdf')
    plt.close()

def plot_roc_with_ci(fprs, tprs, title, save_path):
    """Plot mean ROC curve and 95% CI band from multiple folds."""
    mean_fpr = np.linspace(0, 1, 100)
    interp_tprs = []
    for fpr, tpr in zip(fprs, tprs):
        interp_tpr = np.interp(mean_fpr, fpr, tpr)
        interp_tprs.append(interp_tpr)
    interp_tprs = np.array(interp_tprs)
    
    mean_tpr = interp_tprs.mean(axis=0)
    std_tpr = interp_tprs.std(axis=0)
    tpr_upper = np.minimum(mean_tpr + 1.96 * std_tpr, 1)
    tpr_lower = np.maximum(mean_tpr - 1.96 * std_tpr, 0)
    mean_auc = auc(mean_fpr, mean_tpr)
    
    plt.figure(figsize=(8,6))
    plt.plot(mean_fpr, mean_tpr, color='darkorange', lw=2, label=f'Mean ROC (AUC = {mean_auc:.2f})')
    plt.fill_between(mean_fpr, tpr_lower, tpr_upper, color='grey', alpha=0.2, label='95% CI')
    plt.plot([0,1],[0,1], color='navy', lw=2, linestyle='--')
    plt.xlim([0,1])
    plt.ylim([0,1.05])
    plt.xlabel('False Positive Rate', fontsize=14)
    plt.ylabel('True Positive Rate', fontsize=14)
    plt.title(title, fontsize=16)
    plt.legend(loc='lower right')
    plt.tight_layout()
    plt.savefig(save_path, format='pdf')
    plt.close()

def visualize_attention(series, attention_weights, title, save_path):
    """Visualize time series with attention weights as a color overlay."""
    plt.figure(figsize=(12, 6))
    plt.plot(series, label='Time Series', color='blue')
    att_norm = (attention_weights - attention_weights.min()) / (attention_weights.max() - attention_weights.min() + 1e-9)
    cmap = plt.cm.Reds
    for i in range(len(series)):
        color = cmap(att_norm[i])
        plt.axvspan(i-0.5, i+0.5, color=color, alpha=0.3)
    plt.title(title, fontsize=16)
    plt.xlabel('Time Step', fontsize=14)
    plt.ylabel('Normalized Value', fontsize=14)
    plt.legend()
    plt.tight_layout()
    plt.savefig(save_path, format='pdf')
    plt.close()

def plot_example_timeseries(X, y, unique_labels, dataset_name, save_path):
    """Plot a few example time series from each class for qualitative inspection."""
    plt.figure(figsize=(12, 8))
    num_classes = len(unique_labels)
    examples_per_class = min(3, len(X)//num_classes)
    
    for i, cls in enumerate(unique_labels):
        class_indices = np.where(y == i)[0]
        chosen = np.random.choice(class_indices, size=examples_per_class, replace=False)
        for j, idx in enumerate(chosen):
            plt.subplot(num_classes, examples_per_class, i*examples_per_class + j + 1)
            plt.plot(X[idx], color='blue')
            plt.title(f'Class: {cls}', fontsize=12)
            plt.xticks([])
            plt.yticks([])
    plt.suptitle(f'{dataset_name} - Example Time Series', fontsize=16)
    plt.tight_layout(rect=[0,0,1,0.95])
    plt.savefig(save_path, format='pdf')
    plt.close()

def plot_graph_with_attention(data, attention_weights, title, save_path):
    """Visualize graph structure with nodes colored by attention weights.
    Here we position nodes along x-axis (time) and y-axis as their original value."""
    x_coords = data.x[:,1].cpu().numpy() * len(data.x)  # position scaled to length
    y_vals = data.x[:,0].cpu().numpy()
    att_norm = (attention_weights - attention_weights.min()) / (attention_weights.max() - attention_weights.min() + 1e-9)

    plt.figure(figsize=(12, 6))
    # Draw edges
    edge_index = data.edge_index.cpu().numpy()
    for i in range(edge_index.shape[1]):
        src, dst = edge_index[:, i]
        plt.plot([x_coords[src], x_coords[dst]], [y_vals[src], y_vals[dst]], color='lightgray', linewidth=1, alpha=0.5)

    # Draw nodes
    plt.scatter(x_coords, y_vals, c=att_norm, cmap='Reds', s=50, edgecolors='black')
    plt.colorbar(label='Attention Weight (normalized)')
    plt.title(title, fontsize=16)
    plt.xlabel('Time Step (scaled)')
    plt.ylabel('Normalized Value')
    plt.tight_layout()
    plt.savefig(save_path, format='pdf')
    plt.close()

def plot_attention_distribution(att_weights_all, title, save_path):
    """Plot the distribution (histogram) of attention weights aggregated across samples."""
    plt.figure(figsize=(8,6))
    plt.hist(att_weights_all, bins=30, color='steelblue', edgecolor='black', alpha=0.7)
    plt.title(title, fontsize=16)
    plt.xlabel('Attention Weight')
    plt.ylabel('Frequency')
    plt.tight_layout()
    plt.savefig(save_path, format='pdf')
    plt.close()

def plot_average_attention_profile(all_att_arrays, title, save_path):
    """Plot mean and confidence interval of attention weights across samples at each timestep."""
    # all_att_arrays: list of arrays of shape (num_nodes,)
    # We need to align them by length. Let's assume all have same length for simplicity.
    lengths = [len(a) for a in all_att_arrays]
    min_len = min(lengths)  # if not equal, truncate
    trimmed = np.array([a[:min_len] for a in all_att_arrays])
    mean_att = trimmed.mean(axis=0)
    std_att = trimmed.std(axis=0)

    timesteps = np.arange(min_len)
    plt.figure(figsize=(10,6))
    plt.plot(timesteps, mean_att, color='red', linewidth=2, label='Mean Attention')
    plt.fill_between(timesteps, mean_att - 1.96*std_att, mean_att + 1.96*std_att, color='pink', alpha=0.3, label='95% CI')
    plt.title(title, fontsize=16)
    plt.xlabel('Time Step')
    plt.ylabel('Attention Weight')
    plt.legend()
    plt.tight_layout()
    plt.savefig(save_path, format='pdf')
    plt.close()

def compute_saliency(model, data, device):
    """Compute saliency (gradient-based) using Captum to understand feature importance."""
    model.eval()
    data = data.to(device)
    # We consider input node features as baseline
    saliency = Saliency(model_forward_wrapper(model))
    baseline = data.x.clone().detach().requires_grad_(True)
    attributions = saliency.attribute(data.x.unsqueeze(0), target=int(data.y[0].item()), additional_forward_args=(data.edge_index, data.batch))
    # attributions shape: [1, num_nodes, num_features]
    return attributions.squeeze(0).cpu().numpy()

def model_forward_wrapper(model):
    """Wrapper for model forward to use Captum with node-level graph data."""
    def forward(x, edge_index, batch):
        # Rebuild a Data object
        data = Data(x=x[0], edge_index=edge_index, batch=batch)
        out = model(data)
        return out
    return forward

def run_experiment(dataset_name):
    print(f"\n=== Dataset: {dataset_name} ===")
    
    # Load and preprocess data
    X_all, y_all, unique_labels, (X_train_orig, y_train_orig, X_test_orig, y_test_orig) = load_dataset(dataset_name)
    graphs = timeseries_to_graph(X_all)
    labels = torch.tensor(y_all, dtype=torch.long)
    num_classes = len(unique_labels)

    # Print dataset details
    print(f"Number of classes: {num_classes}")
    print(f"Total samples: {len(X_all)} (Train: {len(X_train_orig)}, Test: {len(X_test_orig)})")
    class_counts = np.bincount(y_all)
    for i, lab in enumerate(unique_labels):
        print(f"Class {lab}: {class_counts[i]} samples")

    # Plot example time series
    plot_example_timeseries(X_all, y_all, unique_labels, dataset_name, 
                            save_path=os.path.join(PLOT_DIR, f"{dataset_name}_example_ts.pdf"))

    all_val_accs = []
    all_test_accs = []
    all_bal_accs = []
    all_mccs = []
    all_test_preds = []
    all_test_targets = []
    all_test_probs = []
    fprs_list = []
    tprs_list = []
    all_test_attentions = []

    skf = StratifiedKFold(n_splits=N_SPLITS, shuffle=True, random_state=RANDOM_SEED)
    for fold, (train_idx, test_idx) in enumerate(skf.split(graphs, labels), 1):
        print(f"\nFold {fold}/{N_SPLITS}")
        
        val_size = int(0.2 * len(train_idx))
        train_indices = train_idx[:-val_size]
        val_indices = train_idx[-val_size:]
        
        # Create datasets
        train_dataset = []
        for i in train_indices:
            graph = graphs[i]
            graph.y = labels[i]
            train_dataset.append(graph)
            
        val_dataset = []
        for i in val_indices:
            graph = graphs[i]
            graph.y = labels[i]
            val_dataset.append(graph)
            
        test_dataset = []
        for i in test_idx:
            graph = graphs[i]
            graph.y = labels[i]
            test_dataset.append(graph)
        
        train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True)
        val_loader = DataLoader(val_dataset, batch_size=BATCH_SIZE, shuffle=False)
        test_loader = DataLoader(test_dataset, batch_size=BATCH_SIZE, shuffle=False)
        
        model = GNNTimeSeriesClassifier(
            input_dim=2,
            hidden_dim=HIDDEN_DIM,
            num_classes=num_classes,
            heads=HEADS,
            dropout=DROPOUT
        ).to(device)
        
        optimizer = optim.Adam(model.parameters(), lr=LR, weight_decay=1e-4)
        
        # Class weighting
        train_labels_np = labels[train_indices].numpy()
        class_counts_train = np.bincount(train_labels_np)
        class_weights = torch.FloatTensor(1.0 / class_counts_train).to(device)
        criterion = nn.CrossEntropyLoss(weight=class_weights)
        
        model, val_acc = train_model(
            model=model,
            train_loader=train_loader,
            val_loader=val_loader,
            criterion=criterion,
            optimizer=optimizer,
            device=device,
            max_epochs=MAX_EPOCHS
        )
        
        test_loss, test_acc, predictions, targets, probabilities = evaluate(
            model=model,
            loader=test_loader,
            criterion=criterion,
            device=device
        )

        bal_acc = balanced_accuracy_score(targets, predictions)
        mcc = matthews_corrcoef(targets, predictions)

        # Compute AUC if binary classification
        if num_classes == 2:
            fpr, tpr, _ = roc_curve(targets, probabilities[:, 1])
            fprs_list.append(fpr)
            tprs_list.append(tpr)
        
        # Store metrics
        all_val_accs.append(val_acc)
        all_test_accs.append(test_acc)
        all_bal_accs.append(bal_acc)
        all_mccs.append(mcc)
        all_test_preds.extend(predictions)
        all_test_targets.extend(targets)
        all_test_probs.extend(probabilities)

        print(f"Fold {fold} - Test Accuracy: {test_acc:.4f}, Balanced Acc: {bal_acc:.4f}, MCC: {mcc:.4f}")

        # Attention visualization on a test sample
        if len(test_dataset) > 0:
            example_data = test_dataset[0].clone().to(device)
            attention_weights = model.get_attention_weights(example_data)
            example_series = example_data.x[:, 0].cpu().numpy()
            visualize_attention(
                series=example_series,
                attention_weights=attention_weights,
                title=f"{dataset_name} - Fold {fold} Attention on TS",
                save_path=os.path.join(PLOT_DIR, f"{dataset_name}_fold{fold}_attention_ts.pdf")
            )
            plot_graph_with_attention(
                data=example_data.cpu(),
                attention_weights=attention_weights,
                title=f"{dataset_name} - Fold {fold} Graph with Node Attention",
                save_path=os.path.join(PLOT_DIR, f"{dataset_name}_fold{fold}_attention_graph.pdf")
            )
            
            # Also compute saliency for interpretation
            saliency_vals = compute_saliency(model, example_data.clone(), device)
            # Plot saliency as a heatmap
            plt.figure(figsize=(10,4))
            plt.imshow(saliency_vals.T, aspect='auto', cmap='coolwarm')
            plt.colorbar(label='Saliency')
            plt.title(f"{dataset_name} - Fold {fold} Saliency Heatmap")
            plt.xlabel('Node Index')
            plt.ylabel('Feature (0:value, 1:position)')
            plt.tight_layout()
            plt.savefig(os.path.join(PLOT_DIR, f"{dataset_name}_fold{fold}_saliency.pdf"), format='pdf')
            plt.close()

        # Collect attention weights from all test samples for aggregated analysis
        test_att_weights = []
        for d in test_dataset:
            d = d.to(device)
            aw = model.get_attention_weights(d)
            test_att_weights.append(aw)
        all_test_attentions.extend(test_att_weights)

    all_test_preds = np.array(all_test_preds)
    all_test_targets = np.array(all_test_targets)
    all_test_probs = np.array(all_test_probs)

    # Compute bootstrap CIs for accuracy and AUC
    mean_acc, ci_acc = bootstrap_confidence_interval(all_test_accs)
    mean_bal_acc, ci_bal_acc = bootstrap_confidence_interval(all_bal_accs)
    mean_mcc, ci_mcc = bootstrap_confidence_interval(all_mccs)

    print("\nFinal Results:")
    print(f"Validation Accuracy (mean ± std): {np.mean(all_val_accs):.4f} ± {np.std(all_val_accs):.4f}")
    print(f"Test Accuracy: {mean_acc:.4f} (95% CI: {ci_acc.low:.4f}-{ci_acc.high:.4f})")
    print(f"Balanced Accuracy: {mean_bal_acc:.4f} (95% CI: {ci_bal_acc.low:.4f}-{ci_bal_acc.high:.4f})")
    print(f"MCC: {mean_mcc:.4f} (95% CI: {ci_mcc.low:.4f}-{ci_mcc.high:.4f})")

    if num_classes == 2:
        # Compute AUC across all samples
        combined_fpr, combined_tpr, _ = roc_curve(all_test_targets, all_test_probs[:,1])
        combined_auc = auc(combined_fpr, combined_tpr)
        
        # Bootstrap AUC
        def auc_stat(data):
            targets_, probs_ = data
            fpr_, tpr_, _ = roc_curve(targets_.astype(int), probs_)
            return auc(fpr_, tpr_)

        auc_samples = []
        rng = np.random.default_rng(RANDOM_SEED)
        for _ in range(10000):
            idx = rng.integers(0, len(all_test_targets), len(all_test_targets))
            auc_samples.append(auc_stat((all_test_targets[idx], all_test_probs[idx,1])))
        auc_samples = np.array(auc_samples)
        mean_auc_ = np.mean(auc_samples)
        low_auc_ = np.percentile(auc_samples, 2.5)
        high_auc_ = np.percentile(auc_samples, 97.5)
        
        print(f"Test AUC: {mean_auc_:.4f} (95% CI: {low_auc_:.4f}-{high_auc_:.4f})")

    # Plot confusion matrix
    cm = confusion_matrix(all_test_targets, all_test_preds)
    plot_confusion_matrix(
        cm=cm,
        classes=[str(l) for l in unique_labels],
        title=f"{dataset_name} Confusion Matrix",
        save_path=os.path.join(PLOT_DIR, f"{dataset_name}_confusion_matrix.pdf")
    )
    
    # Plot aggregated ROC curve with CIs (for binary classification)
    if num_classes == 2 and len(fprs_list) == N_SPLITS:
        plot_roc_with_ci(
            fprs=fprs_list,
            tprs=tprs_list,
            title=f"{dataset_name} ROC Curve (with 95% CI)",
            save_path=os.path.join(PLOT_DIR, f"{dataset_name}_roc_curve.pdf")
        )

    # Classification report
    print("\nClassification Report:")
    print(classification_report(all_test_targets, all_test_preds, target_names=[str(l) for l in unique_labels]))

    # Attention distribution and average profile
    all_test_attentions_arr = np.concatenate(all_test_attentions)
    plot_attention_distribution(all_test_attentions_arr,
                                title=f"{dataset_name} - Attention Weight Distribution",
                                save_path=os.path.join(PLOT_DIR, f"{dataset_name}_attention_distribution.pdf"))

    plot_average_attention_profile(all_test_attentions,
                                   title=f"{dataset_name} - Average Attention Profile",
                                   save_path=os.path.join(PLOT_DIR, f"{dataset_name}_average_attention_profile.pdf"))

    # Save metrics to CSV
    results_dict = {
        'Val_Accuracy': all_val_accs,
        'Test_Accuracy': all_test_accs,
        'Balanced_Accuracy': all_bal_accs,
        'MCC': all_mccs
    }
    if num_classes == 2:
        results_dict['Test_AUC'] = [auc_stat((all_test_targets, all_test_probs[:,1]))]*len(all_test_accs)

    # Mean and CI also stored
    summary_dict = {
        'Mean_Accuracy': [mean_acc],
        'CI_Accuracy_Low': [ci_acc.low],
        'CI_Accuracy_High': [ci_acc.high],
        'Mean_Balanced_Accuracy': [mean_bal_acc],
        'CI_Bal_Acc_Low': [ci_bal_acc.low],
        'CI_Bal_Acc_High': [ci_bal_acc.high],
        'Mean_MCC': [mean_mcc],
        'CI_MCC_Low': [ci_mcc.low],
        'CI_MCC_High': [ci_mcc.high]
    }

    if num_classes == 2:
        summary_dict['Mean_AUC'] = [mean_auc_]
        summary_dict['CI_AUC_Low'] = [low_auc_]
        summary_dict['CI_AUC_High'] = [high_auc_]

    import pandas as pd
    pd.DataFrame(results_dict).to_csv(os.path.join(PLOT_DIR, f"{dataset_name}_fold_metrics.csv"), index=False)
    pd.DataFrame(summary_dict).to_csv(os.path.join(PLOT_DIR, f"{dataset_name}_summary_metrics.csv"), index=False)
    
    print(f"Results and plots saved to {PLOT_DIR}")

if __name__ == "__main__":
    print(f"Using device: {device}")
    for dataset_name in DATASETS:
        try:
            run_experiment(dataset_name)
        except Exception as e:
            print(f"Error processing dataset {dataset_name}: {str(e)}")
            continue
    
    print("\nAll experiments completed. Results saved in:", PLOT_DIR)


Using device: cpu

=== Dataset: ECG200 ===
Number of classes: 2
Total samples: 200 (Train: 100, Test: 100)
Class -1: 67 samples
Class 1: 133 samples

Fold 1/5
Epoch 0: Train Loss = 0.7026, Train Acc = 0.4609, Val Loss = 0.6961, Val Acc = 0.6562
Epoch 10: Train Loss = 0.6858, Train Acc = 0.5703, Val Loss = 0.6897, Val Acc = 0.3438
Epoch 20: Train Loss = 0.6707, Train Acc = 0.6875, Val Loss = 0.6589, Val Acc = 0.4062
Epoch 30: Train Loss = 0.4710, Train Acc = 0.7969, Val Loss = 0.4841, Val Acc = 0.7500
Epoch 40: Train Loss = 0.3532, Train Acc = 0.8438, Val Loss = 0.4912, Val Acc = 0.7500
Epoch 50: Train Loss = 0.3623, Train Acc = 0.8516, Val Loss = 0.4403, Val Acc = 0.8125
Early stopping at epoch 52
Fold 1 - Test Accuracy: 0.9000, Balanced Acc: 0.9066, MCC: 0.7917
Error processing dataset ECG200: invalid index of a 0-dim tensor. Use `tensor.item()` in Python or `tensor.item<T>()` in C++ to convert a 0-dim tensor to a number

=== Dataset: ECGFiveDays ===
Number of classes: 2
Total samples

In [None]:
import os
import copy
import json
import warnings
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
import matplotlib
matplotlib.use('Agg')  # For headless environments
import matplotlib.pyplot as plt
import seaborn as sns
from scipy.stats import bootstrap
from sklearn.model_selection import StratifiedKFold
from sklearn.metrics import (accuracy_score, confusion_matrix, classification_report, roc_curve, auc,
                             balanced_accuracy_score, matthews_corrcoef)
from torch_geometric.data import Data
from torch_geometric.loader import DataLoader
from torch_geometric.nn import GATConv, global_mean_pool
from torch_geometric.utils import add_self_loops
from tslearn.datasets import UCR_UEA_datasets
from captum.attr import Saliency

warnings.filterwarnings('ignore')

# Configuration
RANDOM_SEED = 42
np.random.seed(RANDOM_SEED)
torch.manual_seed(RANDOM_SEED)
if torch.cuda.is_available():
    torch.cuda.manual_seed(RANDOM_SEED)

MAX_EPOCHS = 100
BATCH_SIZE = 32
LR = 0.001
N_SPLITS = 5
HIDDEN_DIM = 64
HEADS = 8
DROPOUT = 0.2
RESULTS_DIR = "./results"
os.makedirs(RESULTS_DIR, exist_ok=True)
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

DATASETS = ["ECG200", "ECGFiveDays", "TwoLeadECG"]  # Example datasets

class EarlyStopping:
    def __init__(self, patience=10, min_delta=0):
        self.patience = patience
        self.min_delta = min_delta
        self.counter = 0
        self.best_loss = None
        self.early_stop = False

    def __call__(self, val_loss):
        if self.best_loss is None:
            self.best_loss = val_loss
        elif val_loss > self.best_loss - self.min_delta:
            self.counter += 1
            if self.counter >= self.patience:
                self.early_stop = True
        else:
            self.best_loss = val_loss
            self.counter = 0

def load_dataset(name):
    """Load UCR/UEA dataset without normalization (to avoid global data leakage).
    We'll do normalization per-fold using training data only."""
    ucr = UCR_UEA_datasets()
    X_train, y_train, X_test, y_test = ucr.load_dataset(name)
    
    # Squeeze for univariate
    if X_train.ndim == 3 and X_train.shape[1] == 1:
        X_train = X_train.squeeze(1)
    if X_test.ndim == 3 and X_test.shape[1] == 1:
        X_test = X_test.squeeze(1)
    
    unique_labels = np.unique(np.concatenate([y_train, y_test]))
    label_to_idx = {lab: i for i, lab in enumerate(unique_labels)}
    y_train = np.array([label_to_idx[lab] for lab in y_train], dtype=np.int64)
    y_test = np.array([label_to_idx[lab] for lab in y_test], dtype=np.int64)
    
    X_all = np.concatenate([X_train, X_test], axis=0)
    y_all = np.concatenate([y_train, y_test], axis=0)
    
    return X_all, y_all, unique_labels, (X_train, y_train, X_test, y_test)

def normalize_data(X, mean_, std_):
    return (X - mean_) / (std_ + 1e-8)

def timeseries_to_graph(X, window=5):
    """Convert time series to graphs: nodes are timesteps; edges connect neighbors within a window."""
    graphs = []
    for i in range(X.shape[0]):
        x_val = X[i]
        length = x_val.shape[0]
        positions = np.linspace(0, 1, length)
        x_feat = np.column_stack([x_val, positions])
        x_feat = torch.tensor(x_feat, dtype=torch.float32)
        
        edge_list = []
        w = min(window, length)
        for j in range(length):
            for k in range(max(0, j-w), min(length, j+w+1)):
                if j != k:
                    edge_list.append([j, k])
        
        if edge_list:
            edge_index = torch.tensor(edge_list, dtype=torch.long).t()
        else:
            edge_index = torch.empty((2, 0), dtype=torch.long)
        
        edge_index, _ = add_self_loops(edge_index, num_nodes=length)
        data = Data(x=x_feat, edge_index=edge_index)
        graphs.append(data)
    return graphs

class GATConvWithAlpha(GATConv):
    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.last_alpha = None
        self._cached_edge_index = None

    def forward(self, x, edge_index, return_attention_weights=False):
        out, (edge_index_out, alpha) = super().forward(x, edge_index, return_attention_weights=True)
        self.last_alpha = alpha
        self._cached_edge_index = edge_index_out
        return out

class GNNTimeSeriesClassifier(nn.Module):
    """GAT-based classifier for time series represented as graphs."""
    def __init__(self, input_dim=2, hidden_dim=64, num_classes=2, heads=8, dropout=0.2):
        super().__init__()
        self.input_proj = nn.Linear(input_dim, hidden_dim)
        self.conv1 = GATConvWithAlpha(hidden_dim, hidden_dim, heads=heads, dropout=dropout)
        self.conv2 = GATConvWithAlpha(hidden_dim*heads, hidden_dim, heads=heads, dropout=dropout, concat=False)
        
        self.ln1 = nn.LayerNorm(hidden_dim * heads)
        self.ln2 = nn.LayerNorm(hidden_dim)
        
        self.fc1 = nn.Linear(hidden_dim, hidden_dim)
        self.fc2 = nn.Linear(hidden_dim, num_classes)
        self.dropout = nn.Dropout(dropout)

    def forward(self, data):
        x, edge_index, batch = data.x, data.edge_index, data.batch
        x = self.input_proj(x)
        x = torch.relu(x)
        
        x1 = self.conv1(x, edge_index)
        x1 = self.ln1(x1)
        x1 = torch.relu(x1)
        x1 = self.dropout(x1)
        
        x2 = self.conv2(x1, edge_index)
        x2 = self.ln2(x2)
        x2 = torch.relu(x2)
        
        x_pool = global_mean_pool(x2, batch)
        x_pool = self.fc1(x_pool)
        x_pool = torch.relu(x_pool)
        x_pool = self.dropout(x_pool)
        out = self.fc2(x_pool)
        
        return out

    def get_attention_weights(self, data):
        """Node-level attention weights from first GAT layer."""
        self.eval()
        with torch.no_grad():
            _ = self.forward(data)
            alpha1 = self.conv1.last_alpha
            edge_idx = self.conv1._cached_edge_index

        num_nodes = data.x.size(0)
        node_importance = torch.zeros(num_nodes, device=data.x.device)
        
        alpha_mean = alpha1.mean(dim=1)
        for i in range(edge_idx.size(1)):
            dst_node = edge_idx[1, i].item()
            node_importance[dst_node] += alpha_mean[i].item()
        
        node_importance = node_importance / (node_importance.sum() + 1e-9)
        return node_importance.cpu().numpy()

def train_epoch(model, loader, criterion, optimizer, device):
    model.train()
    total_loss = 0
    predictions = []
    targets = []
    
    for batch in loader:
        batch = batch.to(device)
        optimizer.zero_grad()
        
        output = model(batch)
        loss = criterion(output, batch.y)
        
        loss.backward()
        optimizer.step()
        
        total_loss += loss.item()
        predictions.extend(output.argmax(dim=1).cpu().numpy())
        targets.extend(batch.y.cpu().numpy())
    
    return total_loss / len(loader), accuracy_score(targets, predictions)

def evaluate(model, loader, criterion, device):
    model.eval()
    total_loss = 0
    predictions = []
    targets = []
    outputs = []
    
    with torch.no_grad():
        for batch in loader:
            batch = batch.to(device)
            output = model(batch)
            loss = criterion(output, batch.y)
            
            total_loss += loss.item()
            predictions.extend(output.argmax(dim=1).cpu().numpy())
            targets.extend(batch.y.cpu().numpy())
            outputs.extend(torch.softmax(output, dim=1).cpu().numpy())
    
    return (
        total_loss / len(loader),
        accuracy_score(targets, predictions),
        np.array(predictions),
        np.array(targets),
        np.array(outputs)
    )

def train_model(model, train_loader, val_loader, criterion, optimizer, device, max_epochs=100):
    early_stopping = EarlyStopping(patience=10)
    best_val_acc = 0
    best_model = None
    
    for epoch in range(max_epochs):
        train_loss, train_acc = train_epoch(model, train_loader, criterion, optimizer, device)
        val_loss, val_acc, _, _, _ = evaluate(model, val_loader, criterion, device)
        
        if val_acc > best_val_acc:
            best_val_acc = val_acc
            best_model = copy.deepcopy(model.state_dict())
        
        # Logging every 10 epochs
        if epoch % 10 == 0:
            print(f'Epoch {epoch}: Train Loss = {train_loss:.4f}, Train Acc = {train_acc:.4f}, '
                  f'Val Loss = {val_loss:.4f}, Val Acc = {val_acc:.4f}')
        
        early_stopping(val_loss)
        if early_stopping.early_stop:
            print(f"Early stopping at epoch {epoch}")
            break
    
    model.load_state_dict(best_model)
    return model, best_val_acc

def bootstrap_confidence_interval(data, stat_func=np.mean, confidence_level=0.95, n_resamples=10000):
    """Compute bootstrap confidence interval for a given statistic."""
    res = bootstrap((data,), stat_func, confidence_level=confidence_level, n_resamples=n_resamples, method='basic')
    return stat_func(data), res.confidence_interval

def plot_confusion_matrix(cm, classes, title, save_path):
    plt.figure(figsize=(8, 6))
    sns.heatmap(cm, annot=True, fmt='d', cmap='Blues', xticklabels=classes, yticklabels=classes)
    plt.title(title, fontsize=16)
    plt.ylabel('True Label', fontsize=14)
    plt.xlabel('Predicted Label', fontsize=14)
    plt.tight_layout()
    plt.savefig(save_path, format='png')
    plt.close()

def plot_roc_with_ci(fprs, tprs, title, save_path):
    """Plot mean ROC curve with 95% CI."""
    mean_fpr = np.linspace(0, 1, 100)
    interp_tprs = []
    for fpr, tpr in zip(fprs, tprs):
        interp_tpr = np.interp(mean_fpr, fpr, tpr)
        interp_tprs.append(interp_tpr)
    interp_tprs = np.array(interp_tprs)
    
    mean_tpr = interp_tprs.mean(axis=0)
    std_tpr = interp_tprs.std(axis=0)
    tpr_upper = np.minimum(mean_tpr + 1.96 * std_tpr, 1)
    tpr_lower = np.maximum(mean_tpr - 1.96 * std_tpr, 0)
    mean_auc = auc(mean_fpr, mean_tpr)
    
    plt.figure(figsize=(8,6))
    plt.plot(mean_fpr, mean_tpr, color='darkorange', lw=2, label=f'Mean ROC (AUC = {mean_auc:.2f})')
    plt.fill_between(mean_fpr, tpr_lower, tpr_upper, color='grey', alpha=0.2, label='95% CI')
    plt.plot([0,1],[0,1], color='navy', lw=2, linestyle='--')
    plt.xlim([0,1])
    plt.ylim([0,1.05])
    plt.xlabel('False Positive Rate', fontsize=14)
    plt.ylabel('True Positive Rate', fontsize=14)
    plt.title(title, fontsize=16)
    plt.legend(loc='lower right')
    plt.tight_layout()
    plt.savefig(save_path, format='png')
    plt.close()

def visualize_attention(series, attention_weights, title, save_path):
    """Plot time series with attention weights as a translucent overlay."""
    plt.figure(figsize=(12, 6))
    plt.plot(series, label='Time Series', color='blue')
    att_norm = (attention_weights - attention_weights.min()) / (attention_weights.max() - attention_weights.min() + 1e-9)
    cmap = plt.cm.Reds
    for i in range(len(series)):
        color = cmap(att_norm[i])
        plt.axvspan(i-0.5, i+0.5, color=color, alpha=0.3)
    plt.title(title, fontsize=16)
    plt.xlabel('Time Step', fontsize=14)
    plt.ylabel('Normalized Value', fontsize=14)
    plt.legend()
    plt.tight_layout()
    plt.savefig(save_path, format='png')
    plt.close()

def plot_example_timeseries(X, y, unique_labels, dataset_name, save_path):
    """Plot examples of time series from each class."""
    plt.figure(figsize=(12, 8))
    num_classes = len(unique_labels)
    examples_per_class = min(3, len(X)//num_classes)
    
    for i, cls in enumerate(unique_labels):
        class_indices = np.where(y == i)[0]
        chosen = np.random.choice(class_indices, size=examples_per_class, replace=False)
        for j, idx in enumerate(chosen):
            plt.subplot(num_classes, examples_per_class, i*examples_per_class + j + 1)
            plt.plot(X[idx], color='blue')
            plt.title(f'Class: {cls}', fontsize=12)
            plt.xticks([])
            plt.yticks([])
    plt.suptitle(f'{dataset_name} - Example Time Series', fontsize=16)
    plt.tight_layout(rect=[0,0,1,0.95])
    plt.savefig(save_path, format='png')
    plt.close()

def plot_graph_with_attention(data, attention_weights, title, save_path):
    """Visualize graph structure with nodes colored by attention."""
    x_coords = data.x[:,1].cpu().numpy() * len(data.x)  # scale position
    y_vals = data.x[:,0].cpu().numpy()
    att_norm = (attention_weights - attention_weights.min()) / (attention_weights.max() - attention_weights.min() + 1e-9)

    plt.figure(figsize=(12, 6))
    # Draw edges
    edge_index = data.edge_index.cpu().numpy()
    for i in range(edge_index.shape[1]):
        src, dst = edge_index[:, i]
        plt.plot([x_coords[src], x_coords[dst]], [y_vals[src], y_vals[dst]], color='lightgray', linewidth=1, alpha=0.5)

    # Draw nodes
    sc = plt.scatter(x_coords, y_vals, c=att_norm, cmap='Reds', s=50, edgecolors='black')
    plt.colorbar(sc, label='Attention Weight (normalized)')
    plt.title(title, fontsize=16)
    plt.xlabel('Time Step (scaled)')
    plt.ylabel('Normalized Value')
    plt.tight_layout()
    plt.savefig(save_path, format='png')
    plt.close()

def plot_distribution(data_array, title, xlabel, save_path):
    """Plot a histogram distribution of given data."""
    plt.figure(figsize=(8,6))
    plt.hist(data_array, bins=30, color='steelblue', edgecolor='black', alpha=0.7)
    plt.title(title, fontsize=16)
    plt.xlabel(xlabel)
    plt.ylabel('Frequency')
    plt.tight_layout()
    plt.savefig(save_path, format='png')
    plt.close()

def plot_mean_ci_profile(all_arrays, title, xlabel, ylabel, save_path):
    """Plot mean and CI of profiles (e.g. attention over time) across samples."""
    lengths = [len(a) for a in all_arrays]
    min_len = min(lengths)
    trimmed = np.array([a[:min_len] for a in all_arrays])
    mean_ = trimmed.mean(axis=0)
    std_ = trimmed.std(axis=0)

    timesteps = np.arange(min_len)
    plt.figure(figsize=(10,6))
    plt.plot(timesteps, mean_, color='red', linewidth=2, label='Mean')
    plt.fill_between(timesteps, mean_ - 1.96*std_, mean_ + 1.96*std_, color='pink', alpha=0.3, label='95% CI')
    plt.title(title, fontsize=16)
    plt.xlabel(xlabel)
    plt.ylabel(ylabel)
    plt.legend()
    plt.tight_layout()
    plt.savefig(save_path, format='png')
    plt.close()

def model_forward_wrapper(model):
    def forward(x, edge_index, batch):
        data = Data(x=x[0], edge_index=edge_index, batch=batch)
        out = model(data)
        return out
    return forward

def compute_saliency(model, data, device):
    """Compute saliency (gradient-based) on node features."""
    model.eval()
    data = data.to(device)
    saliency = Saliency(model_forward_wrapper(model))
    # For classification target is a scalar, we get class from data.y
    target_class = int(data.y.item())
    attributions = saliency.attribute(data.x.unsqueeze(0), target=target_class,
                                      additional_forward_args=(data.edge_index, data.batch))
    return attributions.squeeze(0).cpu().numpy()

def run_experiment(dataset_name):
    print(f"\n=== Dataset: {dataset_name} ===")
    
    # Create dataset-specific directory
    ds_dir = os.path.join(RESULTS_DIR, dataset_name)
    os.makedirs(ds_dir, exist_ok=True)
    
    # Load raw data (no normalization)
    X_all, y_all, unique_labels, (X_train_orig, y_train_orig, X_test_orig, y_test_orig) = load_dataset(dataset_name)
    num_classes = len(unique_labels)

    print(f"Number of classes: {num_classes}")
    print(f"Total samples: {len(X_all)} (Train: {len(X_train_orig)}, Test: {len(X_test_orig)})")
    class_counts = np.bincount(y_all)
    for i, lab in enumerate(unique_labels):
        print(f"Class {lab}: {class_counts[i]} samples")

    # Plot example time series
    plot_example_timeseries(X_all, y_all, unique_labels, dataset_name, 
                            save_path=os.path.join(ds_dir, f"{dataset_name}_example_ts.png"))

    all_val_accs = []
    all_test_accs = []
    all_bal_accs = []
    all_mccs = []
    all_test_preds = []
    all_test_targets = []
    all_test_probs = []
    fprs_list = []
    tprs_list = []
    all_test_attentions = []
    all_test_saliencies = []

    skf = StratifiedKFold(n_splits=N_SPLITS, shuffle=True, random_state=RANDOM_SEED)
    for fold, (train_idx, test_idx) in enumerate(skf.split(X_all, y_all), 1):
        print(f"\nFold {fold}/{N_SPLITS}")

        # Split into train/val/test sets for this fold
        # val = last 20% of train indices
        val_size = int(0.2 * len(train_idx))
        train_indices = train_idx[:-val_size]
        val_indices = train_idx[-val_size:]
        
        # Normalize using only train subset
        X_train_fold = X_all[train_indices]
        mean_ = X_train_fold.mean()
        std_ = X_train_fold.std()

        X_train_norm = normalize_data(X_all[train_indices], mean_, std_)
        X_val_norm = normalize_data(X_all[val_indices], mean_, std_)
        X_test_norm = normalize_data(X_all[test_idx], mean_, std_)

        # Build graphs
        train_graphs = timeseries_to_graph(X_train_norm)
        val_graphs = timeseries_to_graph(X_val_norm)
        test_graphs = timeseries_to_graph(X_test_norm)
        
        labels_tensor = torch.tensor(y_all, dtype=torch.long)
        
        for i, g in enumerate(train_graphs):
            g.y = labels_tensor[train_indices[i]].unsqueeze(0)
        for i, g in enumerate(val_graphs):
            g.y = labels_tensor[val_indices[i]].unsqueeze(0)
        for i, g in enumerate(test_graphs):
            g.y = labels_tensor[test_idx[i]].unsqueeze(0)

        train_loader = DataLoader(train_graphs, batch_size=BATCH_SIZE, shuffle=True)
        val_loader = DataLoader(val_graphs, batch_size=BATCH_SIZE, shuffle=False)
        test_loader = DataLoader(test_graphs, batch_size=BATCH_SIZE, shuffle=False)
        
        model = GNNTimeSeriesClassifier(
            input_dim=2,
            hidden_dim=HIDDEN_DIM,
            num_classes=num_classes,
            heads=HEADS,
            dropout=DROPOUT
        ).to(device)
        
        optimizer = optim.Adam(model.parameters(), lr=LR, weight_decay=1e-4)
        
        # Class weighting
        train_labels_fold = y_all[train_indices]
        class_counts_train = np.bincount(train_labels_fold)
        class_weights = torch.FloatTensor(1.0 / class_counts_train).to(device)
        criterion = nn.CrossEntropyLoss(weight=class_weights)
        
        model, val_acc = train_model(
            model=model,
            train_loader=train_loader,
            val_loader=val_loader,
            criterion=criterion,
            optimizer=optimizer,
            device=device,
            max_epochs=MAX_EPOCHS
        )
        
        test_loss, test_acc, predictions, targets, probabilities = evaluate(
            model=model,
            loader=test_loader,
            criterion=criterion,
            device=device
        )

        bal_acc = balanced_accuracy_score(targets, predictions)
        mcc = matthews_corrcoef(targets, predictions)

        # Compute AUC if binary
        if num_classes == 2:
            fpr, tpr, _ = roc_curve(targets, probabilities[:, 1])
            fprs_list.append(fpr)
            tprs_list.append(tpr)
        
        # Store metrics
        all_val_accs.append(val_acc)
        all_test_accs.append(test_acc)
        all_bal_accs.append(bal_acc)
        all_mccs.append(mcc)
        all_test_preds.extend(predictions)
        all_test_targets.extend(targets)
        all_test_probs.extend(probabilities)

        print(f"Fold {fold} - Test Accuracy: {test_acc:.4f}, Balanced Acc: {bal_acc:.4f}, MCC: {mcc:.4f}")

        # Interpretability on multiple test samples
        # Let's pick 3 random test samples to visualize attention and saliency
        test_sample_indices = np.random.choice(len(test_graphs), size=min(3, len(test_graphs)), replace=False)
        for idx_ in test_sample_indices:
            example_data = test_graphs[idx_].clone().to(device)
            attention_weights = model.get_attention_weights(example_data)
            example_series = example_data.x[:, 0].cpu().numpy()

            # Attention on TS
            visualize_attention(
                series=example_series,
                attention_weights=attention_weights,
                title=f"{dataset_name} - Fold {fold} Sample {idx_} Attention (TS)",
                save_path=os.path.join(ds_dir, f"{dataset_name}_fold{fold}_sample{idx_}_attention_ts.png")
            )

            # Graph with attention
            plot_graph_with_attention(
                data=example_data.cpu(),
                attention_weights=attention_weights,
                title=f"{dataset_name} - Fold {fold} Sample {idx_} Graph Attention",
                save_path=os.path.join(ds_dir, f"{dataset_name}_fold{fold}_sample{idx_}_attention_graph.png")
            )
            
            # Saliency
            saliency_vals = compute_saliency(model, example_data.clone(), device)
            # Plot saliency heatmap
            plt.figure(figsize=(10,4))
            plt.imshow(saliency_vals.T, aspect='auto', cmap='coolwarm')
            plt.colorbar(label='Saliency')
            plt.title(f"{dataset_name} - Fold {fold} Sample {idx_} Saliency Heatmap")
            plt.xlabel('Node Index')
            plt.ylabel('Feature (0:value, 1:position)')
            plt.tight_layout()
            plt.savefig(os.path.join(ds_dir, f"{dataset_name}_fold{fold}_sample{idx_}_saliency.png"), format='png')
            plt.close()

        # Gather attention & saliency for all test samples
        for d in test_graphs:
            d = d.to(device)
            aw = model.get_attention_weights(d)
            all_test_attentions.append(aw)
            sal = compute_saliency(model, d.clone(), device)
            # We can aggregate saliency on the "value" feature only for interpretability
            all_test_saliencies.append(sal[:,0])  # focus on value dimension

    # Convert to np arrays
    all_test_preds = np.array(all_test_preds)
    all_test_targets = np.array(all_test_targets)
    all_test_probs = np.array(all_test_probs)

    # Compute bootstrap CIs
    mean_acc, ci_acc = bootstrap_confidence_interval(np.array(all_test_accs))
    mean_bal_acc, ci_bal_acc = bootstrap_confidence_interval(np.array(all_bal_accs))
    mean_mcc, ci_mcc = bootstrap_confidence_interval(np.array(all_mccs))

    # AUC bootstrap if binary
    mean_auc_, low_auc_, high_auc_ = None, None, None
    if num_classes == 2:
        def auc_stat(data):
            targets_, probs_ = data
            fpr_, tpr_, _ = roc_curve(targets_.astype(int), probs_)
            return auc(fpr_, tpr_)

        auc_samples = []
        rng = np.random.default_rng(RANDOM_SEED)
        for _ in range(10000):
            idx = rng.integers(0, len(all_test_targets), len(all_test_targets))
            auc_samples.append(auc_stat((all_test_targets[idx], all_test_probs[idx,1])))
        auc_samples = np.array(auc_samples)
        mean_auc_ = np.mean(auc_samples)
        low_auc_ = np.percentile(auc_samples, 2.5)
        high_auc_ = np.percentile(auc_samples, 97.5)

    # Confusion matrix
    cm = confusion_matrix(all_test_targets, all_test_preds)
    plot_confusion_matrix(
        cm=cm,
        classes=[str(l) for l in unique_labels],
        title=f"{dataset_name} Confusion Matrix",
        save_path=os.path.join(ds_dir, f"{dataset_name}_confusion_matrix.png")
    )

    # ROC curve with CI
    if num_classes == 2 and len(fprs_list) == N_SPLITS:
        plot_roc_with_ci(
            fprs=fprs_list,
            tprs=tprs_list,
            title=f"{dataset_name} ROC Curve (with 95% CI)",
            save_path=os.path.join(ds_dir, f"{dataset_name}_roc_curve.png")
        )

    # Classification report
    cls_report = classification_report(all_test_targets, all_test_preds, target_names=[str(l) for l in unique_labels])
    print("\nClassification Report:")
    print(cls_report)

    # Plot attention and saliency distributions
    all_test_attentions_arr = np.concatenate(all_test_attentions)
    plot_distribution(all_test_attentions_arr,
                      title=f"{dataset_name} - Attention Weight Distribution",
                      xlabel='Attention Weight',
                      save_path=os.path.join(ds_dir, f"{dataset_name}_attention_distribution.png"))

    # Average attention profile
    plot_mean_ci_profile(all_test_attentions,
                         title=f"{dataset_name} - Average Attention Profile",
                         xlabel='Time Step',
                         ylabel='Attention Weight',
                         save_path=os.path.join(ds_dir, f"{dataset_name}_average_attention_profile.png"))

    # Saliency distribution & profile
    all_sal = np.concatenate(all_test_saliencies)
    plot_distribution(all_sal,
                      title=f"{dataset_name} - Saliency Distribution (Value Feature)",
                      xlabel='Saliency Value',
                      save_path=os.path.join(ds_dir, f"{dataset_name}_saliency_distribution.png"))

    plot_mean_ci_profile(all_test_saliencies,
                         title=f"{dataset_name} - Average Saliency Profile (Value Feature)",
                         xlabel='Time Step',
                         ylabel='Saliency',
                         save_path=os.path.join(ds_dir, f"{dataset_name}_average_saliency_profile.png"))

    # Save metrics to CSV and JSON
    import pandas as pd

    # Per-fold metrics
    results_dict = {
        'Val_Accuracy': all_val_accs,
        'Test_Accuracy': all_test_accs,
        'Balanced_Accuracy': all_bal_accs,
        'MCC': all_mccs
    }
    if num_classes == 2:
        # Compute per-fold AUC (not bootstrapped, just direct from combined might not be correct per fold,
        # but we already have aggregated AUC from all samples)
        pass

    pd.DataFrame(results_dict).to_csv(os.path.join(ds_dir, f"{dataset_name}_fold_metrics.csv"), index=False)

    # Summary metrics with CIs
    summary_dict = {
        'Mean_Accuracy': mean_acc,
        'CI_Accuracy_Low': ci_acc.low,
        'CI_Accuracy_High': ci_acc.high,
        'Mean_Balanced_Accuracy': mean_bal_acc,
        'CI_Bal_Acc_Low': ci_bal_acc.low,
        'CI_Bal_Acc_High': ci_bal_acc.high,
        'Mean_MCC': mean_mcc,
        'CI_MCC_Low': ci_mcc.low,
        'CI_MCC_High': ci_mcc.high
    }

    if num_classes == 2 and mean_auc_ is not None:
        summary_dict.update({
            'Mean_AUC': mean_auc_,
            'CI_AUC_Low': low_auc_,
            'CI_AUC_High': high_auc_
        })

    # Save summary to CSV
    pd.DataFrame([summary_dict]).to_csv(os.path.join(ds_dir, f"{dataset_name}_summary_metrics.csv"), index=False)

    # Save summary to JSON
    with open(os.path.join(ds_dir, f"{dataset_name}_summary_metrics.json"), 'w') as f:
        json.dump(summary_dict, f, indent=4)
    
    # Save classification report
    with open(os.path.join(ds_dir, f"{dataset_name}_classification_report.txt"), 'w') as f:
        f.write(cls_report)

    print(f"Results and plots saved to {ds_dir}")

if __name__ == "__main__":
    print(f"Using device: {device}")
    for dataset_name in DATASETS:
        try:
            run_experiment(dataset_name)
        except Exception as e:
            print(f"Error processing dataset {dataset_name}: {str(e)}")
            continue
    
    print("\nAll experiments completed. Results saved in:", RESULTS_DIR)


Using device: cpu

=== Dataset: ECG200 ===
Number of classes: 2
Total samples: 200 (Train: 100, Test: 100)
Class -1: 67 samples
Class 1: 133 samples

Fold 1/5
Epoch 0: Train Loss = 0.7026, Train Acc = 0.4609, Val Loss = 0.6961, Val Acc = 0.6562
Epoch 10: Train Loss = 0.6858, Train Acc = 0.5703, Val Loss = 0.6897, Val Acc = 0.3438
Epoch 20: Train Loss = 0.6707, Train Acc = 0.6875, Val Loss = 0.6590, Val Acc = 0.4062
Epoch 30: Train Loss = 0.4736, Train Acc = 0.7969, Val Loss = 0.4828, Val Acc = 0.7500
Epoch 40: Train Loss = 0.3536, Train Acc = 0.8594, Val Loss = 0.4859, Val Acc = 0.7500
Epoch 50: Train Loss = 0.3882, Train Acc = 0.8281, Val Loss = 0.4613, Val Acc = 0.8125
Early stopping at epoch 52
Fold 1 - Test Accuracy: 0.9000, Balanced Acc: 0.9066, MCC: 0.7917

Fold 2/5
Epoch 0: Train Loss = 0.7153, Train Acc = 0.4766, Val Loss = 0.7008, Val Acc = 0.6562
Epoch 10: Train Loss = 0.6914, Train Acc = 0.5078, Val Loss = 0.6899, Val Acc = 0.3438
Epoch 20: Train Loss = 0.6464, Train Acc = 0