# In [None]:
import os

import matplotlib.pyplot as plt
import numpy as np
import torch
from captum.attr import IntegratedGradients, Saliency, LayerConductance
from captum.attr import visualization as viz

class NetworkExplainer:
    """Explainability wrapper for a trained MultiBandAttentionFusion model.

    The wrapped model is called as ``model([subject])``, where ``subject`` is a
    tuple holding one list of graphs per frequency band, and its output is read
    as ``output[0, class_index]`` (one row of class scores per call). Each graph
    must expose a node-feature tensor ``x`` and a ``to(device)`` method --
    presumably PyTorch Geometric ``Data`` objects; confirm against the model.

    The ``explain_*`` methods return a dict with keys ``'predicted_class'``,
    ``'explained_class'`` and ``'band_attributions'``; the latter maps
    ``'band_<i>'`` to a list (one entry per graph / time step) of NumPy arrays
    shaped like that graph's ``x``.
    """

    def __init__(self, model, device='cuda'):
        """
        Args:
            model: Your trained MultiBandAttentionFusion model
            device: 'cuda' or 'cpu' (the model and input graphs are moved here)
        """
        self.model = model.to(device)
        self.model.eval()  # Evaluation mode: no dropout noise in explanations
        self.device = device

    def prepare_input(self, subject_data):
        """Move every graph of ``subject_data`` to ``self.device``.

        Returns a tuple with one list of graphs per band. Attribution code must
        read and modify the graphs in this returned structure (not the ones in
        ``subject_data``), because ``to()`` may return new objects.
        """
        return tuple(
            [g.to(self.device) for g in band_graphs] for band_graphs in subject_data
        )

    def _predict(self, processed_input, target_class):
        """Run a no-grad forward pass; return (predicted_class, class_to_explain)."""
        with torch.no_grad():
            output = self.model([processed_input])
            pred_class = output.argmax(1).item()

        if target_class is None:
            target_class = pred_class

        print(f"Predicted class: {pred_class}, Explaining class: {target_class}")
        return pred_class, target_class

    def _make_graph_forward(self, graph, processed_input):
        """Build a batched forward function that varies only ``graph.x``.

        Each sample along dim 0 of ``features`` is substituted for ``graph.x``
        in turn (all other graphs stay unchanged) and the model outputs are
        concatenated, giving one row of class scores per sample -- the shape
        Captum expects. ``graph.x`` is restored even if the model raises.
        """
        def forward(features):
            original = graph.x
            outputs = []
            try:
                for sample in features:
                    graph.x = sample
                    outputs.append(self.model([processed_input]))
            finally:
                graph.x = original
            return torch.cat(outputs, dim=0)

        return forward

    def explain_with_integrated_gradients(self, subject_data, target_class=None, n_steps=50):
        """
        Use Integrated Gradients to explain the prediction for one subject.

        Each graph's node-feature matrix ``x`` is attributed separately: it is
        interpolated from an all-zero baseline to its real value while every
        other graph is held fixed.

        Args:
            subject_data: Single subject's data (list of band sequences of graphs)
            target_class: Which class to explain (e.g. 0 or 1). If None, uses
                the predicted class.
            n_steps: Number of interpolation steps used to approximate the
                IG path integral (more steps: more accurate, slower).

        Returns:
            Dictionary with 'predicted_class', 'explained_class' and
            'band_attributions' (arrays shaped like each graph's ``x``).
        """
        processed_input = self.prepare_input(subject_data)
        pred_class, target_class = self._predict(processed_input, target_class)

        band_attributions = {}
        for band_idx, band_graphs in enumerate(processed_input):
            band_attrs = []
            for graph in band_graphs:
                # Captum treats dim 0 as the batch and stacks the interpolation
                # steps along it, so the whole feature matrix becomes a single
                # example with a leading batch axis of size 1.
                inputs = graph.x.detach().unsqueeze(0)
                baseline = torch.zeros_like(inputs)

                ig = IntegratedGradients(self._make_graph_forward(graph, processed_input))
                attributions = ig.attribute(
                    inputs,
                    baselines=baseline,
                    target=target_class,
                    n_steps=n_steps,
                )
                band_attrs.append(attributions.squeeze(0).detach().cpu().numpy())

            band_attributions[f'band_{band_idx}'] = band_attrs

        return {
            'predicted_class': pred_class,
            'explained_class': target_class,
            'band_attributions': band_attributions
        }

    def explain_with_saliency(self, subject_data, target_class=None):
        """
        Use Saliency (plain gradients) - the FASTEST method.

        For each graph, returns the signed gradient of the target class score
        with respect to that graph's ``x`` (no absolute value is taken here).
        Gradients are taken w.r.t. a detached copy, so the caller's tensors
        are never marked ``requires_grad`` and repeated calls do not
        accumulate. A graph the output does not depend on gets all zeros.

        Args:
            subject_data: Single subject's data (list of band sequences of graphs)
            target_class: Which class to explain. If None, uses the predicted class.

        Returns:
            Same dictionary layout as ``explain_with_integrated_gradients``.
        """
        processed_input = self.prepare_input(subject_data)
        pred_class, target_class = self._predict(processed_input, target_class)

        band_attributions = {}
        for band_idx, band_graphs in enumerate(processed_input):
            band_attrs = []
            for graph in band_graphs:
                forward = self._make_graph_forward(graph, processed_input)
                leaf = graph.x.detach().clone().requires_grad_(True)
                output = forward(leaf.unsqueeze(0))

                # autograd.grad returns the gradient directly instead of
                # accumulating into leaf.grad / the model parameters' .grad.
                (grad,) = torch.autograd.grad(
                    output[0, target_class], leaf, allow_unused=True
                )
                if grad is None:
                    grad = torch.zeros_like(leaf)
                band_attrs.append(grad.detach().cpu().numpy())

            band_attributions[f'band_{band_idx}'] = band_attrs

        return {
            'predicted_class': pred_class,
            'explained_class': target_class,
            'band_attributions': band_attributions
        }

    def visualize_attributions(self, attributions, band_names=None, node_names=None):
        """
        Plot per-node attribution magnitude for each band (one subplot per band).

        Args:
            attributions: Output from explain_* methods
            band_names: List of band names (e.g., ['Alpha', 'Beta'])
            node_names: List of node/channel names used as x tick labels

        Returns:
            The matplotlib Figure.

        Raises:
            ValueError: if ``attributions`` contains no bands.
        """
        band_attrs = attributions['band_attributions']
        n_bands = len(band_attrs)
        if n_bands == 0:
            raise ValueError("attributions contain no bands to plot")

        if band_names is None:
            band_names = [f'Band {i}' for i in range(n_bands)]

        # squeeze=False always yields a 2-D array of axes, even for one band.
        fig, axes = plt.subplots(n_bands, 1, figsize=(12, 4 * n_bands), squeeze=False)

        for band_idx, band_name in enumerate(band_names):
            band_key = f'band_{band_idx}'
            if band_key not in band_attrs:
                continue
            ax = axes[band_idx, 0]

            # Average the signed attributions over time steps (opposite signs
            # cancel), then sum |attribution| over features to rank nodes.
            avg_attrs = np.mean(band_attrs[band_key], axis=0)
            node_importance = np.abs(avg_attrs).sum(axis=1)

            ax.bar(range(len(node_importance)), node_importance)
            ax.set_title(f'{band_name} - Node Importance')
            ax.set_xlabel('Node Index')
            ax.set_ylabel('Attribution Magnitude')

            # `is not None` + len() also works for NumPy arrays of names.
            if node_names is not None and len(node_names) > 0:
                ax.set_xticks(range(len(node_names)))
                ax.set_xticklabels(node_names, rotation=45)

        plt.tight_layout()
        return fig

    def get_temporal_importance(self, attributions):
        """
        Analyze how importance changes over time.

        Returns:
            Dictionary mapping each band key ('band_<i>') to a list with the
            total absolute attribution (over all nodes and features) of each
            time step.
        """
        return {
            band_key: [np.abs(arr).sum() for arr in attrs_list]
            for band_key, attrs_list in attributions['band_attributions'].items()
        }


# ============ USAGE EXAMPLE ============

def analyze_saved_models(model_path, test_subject, band_names=('Alpha', 'Beta'),
                         output_prefix='', model_kwargs=None):
    """
    Explain a saved model's prediction for one subject, without retraining.

    Computes Saliency and Integrated Gradients attributions, then saves and
    closes three figures: '<prefix>saliency_explanation.png',
    '<prefix>integrated_gradients_explanation.png' and
    '<prefix>temporal_importance.png'.

    Args:
        model_path: Path to saved model weights (.pt state_dict file)
        test_subject: A single subject's data to explain
        band_names: Names of frequency bands (also sets ``num_bands``)
        output_prefix: Prepended to every saved image filename (e.g. 'run3_')
            so repeated calls do not overwrite each other's figures.
        model_kwargs: Optional dict overriding the MultiBandAttentionFusion
            constructor arguments below (e.g. {'num_nodes': 32}).

    Returns:
        (saliency_attrs, ig_attrs) as returned by NetworkExplainer.
    """
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    band_names = list(band_names)

    # Constructor arguments must match those the checkpoint was trained with.
    config = {
        'num_bands': len(band_names),
        'hidden_channels': 64,
        'num_classes': 2,
        'dropout_rate': 0.5,
        'num_nodes': 19,  # Adjust based on your data
        'in_channels': 18,  # Adjust based on your data
    }
    if model_kwargs:
        config.update(model_kwargs)
    model = MultiBandAttentionFusion(**config).to(device)

    # torch.load unpickles the checkpoint: only load files you trust.
    model.load_state_dict(torch.load(model_path, map_location=device))

    explainer = NetworkExplainer(model, device)

    # Fast explanation with Saliency (recommended for quick analysis)
    print("Computing saliency-based attributions...")
    saliency_attrs = explainer.explain_with_saliency(test_subject)

    # More accurate explanation with Integrated Gradients (slower but better)
    print("\nComputing integrated gradients attributions...")
    ig_attrs = explainer.explain_with_integrated_gradients(test_subject)

    # Save each figure explicitly and close it, so repeated calls (e.g. from
    # batch_analyze_runs) do not accumulate open figures.
    for attrs, filename in ((saliency_attrs, 'saliency_explanation.png'),
                            (ig_attrs, 'integrated_gradients_explanation.png')):
        fig = explainer.visualize_attributions(attrs, band_names)
        fig.savefig(f'{output_prefix}{filename}')
        plt.close(fig)

    # Temporal analysis: total |attribution| per time step, one line per band.
    temporal_imp = explainer.get_temporal_importance(ig_attrs)
    labels = {f'band_{i}': name for i, name in enumerate(band_names)}

    fig, ax = plt.subplots(figsize=(10, 5))
    for band_key, importance in temporal_imp.items():
        ax.plot(importance, label=labels.get(band_key, band_key), marker='o')
    ax.set_xlabel('Time Step')
    ax.set_ylabel('Importance')
    ax.set_title('Temporal Importance of Each Band')
    ax.legend()
    fig.savefig(f'{output_prefix}temporal_importance.png')
    plt.close(fig)

    print("\n✅ Analysis complete! Check saved images.")

    return saliency_attrs, ig_attrs


# Example: Load and analyze multiple runs
def batch_analyze_runs(num_runs=10, dataset=None, band_names=('Alpha', 'Beta'),
                       model_dir='results', subject_index=0):
    """Analyze all saved model runs ``<model_dir>/final_model_run{1..num_runs}.pt``.

    Runs whose checkpoint file is missing are skipped. Each run's figures are
    saved with a 'run{N}_' filename prefix so runs do not overwrite each other.

    Args:
        num_runs: Highest run number to look for (runs are numbered from 1).
        dataset: Indexable collection of subjects (required).
        band_names: Names of frequency bands.
        model_dir: Directory holding the saved checkpoints.
        subject_index: Index into ``dataset`` of the subject to explain.

    Returns:
        List of dicts with keys 'run', 'saliency' and 'integrated_gradients'.

    Raises:
        ValueError: if ``dataset`` is None.
    """
    if dataset is None:
        raise ValueError("batch_analyze_runs requires `dataset` (the test subjects)")

    # The same subject is explained for every run so results are comparable.
    test_subject = dataset[subject_index]

    all_attributions = []

    for run in range(1, num_runs + 1):
        model_path = os.path.join(model_dir, f"final_model_run{run}.pt")

        if not os.path.exists(model_path):
            print(f"Skipping run {run}: model not found")
            continue

        print(f"\n=== Analyzing Run {run} ===")

        saliency_attrs, ig_attrs = analyze_saved_models(
            model_path, test_subject, band_names, output_prefix=f"run{run}_"
        )

        all_attributions.append({
            'run': run,
            'saliency': saliency_attrs,
            'integrated_gradients': ig_attrs
        })

    return all_attributions


# ============ QUICK START ============
if __name__ == "__main__":
    # Prerequisites (typically prepared in earlier notebook cells):
    #   * a trained model checkpoint at "results/final_model_run1.pt"
    #   * the test subjects loaded into `dataset`
    band_names = ['alpha', 'beta']

    # --- Explain a single model on one subject ---
    test_subject = dataset[0]  # Replace with your test data
    saliency_attrs, ig_attrs = analyze_saved_models(
        model_path="results/final_model_run1.pt",
        test_subject=test_subject,
        band_names=band_names,
    )

    # --- Alternatively, explain every saved run ---
    # all_results = batch_analyze_runs(num_runs=10, dataset=dataset, band_names=band_names)