# In [1]:
from torch_geometric.explain import GNNExplainer
from captum.attr import IntegratedGradients, GradientShap
import torch
import numpy as np
import copy


# In [4]:


def _explain_band(model, band_graphs, batch, band_idx, band_name, predicted_class, device):
    """Attribute every time-step graph of one band and aggregate the results.

    Returns the per-band explanation dict consumed by the summary helpers:
    per-time-step attributions plus node/feature attributions averaged over
    the time steps that were explained successfully. When no time step could
    be explained, the aggregate fields stay ``None`` / empty.
    """
    print(f"  Analyzing {band_name} band...")

    band_explanation = {
        'band_name': band_name,
        'band_index': band_idx,
        'time_steps': [],
        'avg_node_importance': None,
        'avg_feature_importance': None,
        'temporal_importance': []
    }

    time_step_explanations = []
    for time_idx, graph in enumerate(band_graphs):
        print(f"    Explaining time step {time_idx} ...")

        # Best effort: a failure on one time step is reported and skipped so the
        # remaining time steps (and bands/subjects) are still explained.
        try:
            graph_explanation = explain_single_graph(
                model, graph, batch, band_idx, time_idx,
                target_class=predicted_class, device=device
            )

            if graph_explanation is None:
                print(f"    Skipping time step {time_idx} due to explanation failure.")
                continue

            time_step_explanations.append({
                'time_step': time_idx,
                'node_importance': graph_explanation['node_importance'],
                'feature_importance': graph_explanation['feature_importance'],
            })

            print(f"    Done with time step {time_idx}")
            torch.cuda.empty_cache()  # Release cached GPU memory between attributions
        except Exception as e:
            print(f"    Error explaining time step {time_idx}: {e}")

    if time_step_explanations:
        node_importances = [t['node_importance'] for t in time_step_explanations]
        feature_importances = [t['feature_importance'] for t in time_step_explanations]

        band_explanation['time_steps'] = time_step_explanations
        band_explanation['avg_node_importance'] = np.mean(node_importances, axis=0)
        band_explanation['avg_feature_importance'] = np.mean(feature_importances, axis=0)
        # One scalar per explained time step: mean node attribution at that step.
        band_explanation['temporal_importance'] = [np.mean(n) for n in node_importances]

    return band_explanation


def explain_model_predictions(model, dataset, device, num_subjects=None, band_names=None):
    """Explain the model's prediction for each subject, band and time step.

    Each subject is run through the model once to obtain its predicted class
    and confidence. Every graph of the subject (one per band and time step) is
    then attributed against that predicted class with ``explain_single_graph``.

    Args:
        model: Model called as ``model([tuple(bands)])`` (a batch of one
            subject). Its output is indexed as ``(batch, num_classes)`` scores.
        dataset: Indexable collection of subjects. ``dataset[i]`` is a sequence
            of bands, each a sequence of per-time-step graphs supporting
            ``.to(device)``. The label is read from ``dataset[i][0][0].y``.
        device: Device the model and graphs are moved to.
        num_subjects: Number of leading subjects to explain. ``None`` or a
            value larger than the dataset means all subjects.
        band_names: Display names for the bands. Defaults to ``Band_0``,
            ``Band_1``, ... based on the band count of the first subject.

    Returns:
        dict with ``subject_explanations`` (per-subject results),
        ``overall_band_importance`` and ``summary``.
    """
    model.eval()
    model.to(device)

    if band_names is None:
        band_names = [f"Band_{i}" for i in range(len(dataset[0]))]

    if num_subjects is None or num_subjects > len(dataset):
        num_subjects = len(dataset)

    subject_explanations = []
    # Honour ``num_subjects`` (the loop previously always covered the whole dataset).
    for subject_idx in range(num_subjects):
        subject_data = dataset[subject_idx]
        subject_label = subject_data[0][0].y.item()

        print(f"Explaining subject {subject_idx + 1}/{num_subjects} (Label: {subject_label})")

        processed_subject = [[g.to(device) for g in band_graphs] for band_graphs in subject_data]
        batch = [tuple(processed_subject)]  # batch of a single subject

        with torch.no_grad():
            prediction = model(batch)
            predicted_class = prediction.argmax(1).item()
            confidence = torch.softmax(prediction, dim=1).max().item()

        band_explanations = []
        for band_idx, band_graphs in enumerate(processed_subject):
            band_explanations.append(
                _explain_band(model, band_graphs, batch, band_idx,
                              band_names[band_idx], predicted_class, device)
            )

        subject_explanations.append({
            'subject_index': subject_idx,
            'true_label': subject_label,
            'predicted_label': predicted_class,
            'confidence': confidence,
            'band_explanations': band_explanations
        })

    overall_band_importance = calculate_overall_band_importance(subject_explanations, band_names)

    return {
        'subject_explanations': subject_explanations,
        'overall_band_importance': overall_band_importance,
        'summary': generate_explanation_summary(subject_explanations, band_names)
    }

import copy
from captum.attr import IntegratedGradients, GradientShap

def _attribution_summary(attributions):
    """Reduce a 2-D attribution tensor to mean |attribution| per row (node) and per column (feature)."""
    magnitude = attributions.detach().abs()
    return {
        'node_importance': magnitude.mean(dim=1).cpu().numpy(),
        'feature_importance': magnitude.mean(dim=0).cpu().numpy(),
    }


def explain_single_graph(model, graph, full_batch, band_idx, time_idx, target_class, device):
    """Attribute one graph's node features to the model's ``target_class`` score.

    The node features of the graph at ``full_batch[0][band_idx][time_idx]`` are
    varied while every other graph of the subject is held fixed. Integrated
    Gradients is tried first. If it raises, GradientShap with a zero baseline
    is tried as a fallback.

    Args:
        model: Model called as ``model([tuple(bands)])``.
        graph: The graph being explained (expected to be the one at
            ``full_batch[0][band_idx][time_idx]``). Only ``graph.x`` is read.
        full_batch: One-element batch whose ``[0]`` entry is the subject's
            sequence of bands (each a list of per-time-step graphs).
        band_idx: Band index of ``graph`` inside the subject.
        time_idx: Time-step index of ``graph`` inside the band.
        target_class: Output index whose score is attributed.
        device: Device for the attribution inputs.

    Returns:
        ``{'node_importance': ndarray, 'feature_importance': ndarray}`` holding
        the mean absolute attribution per row / per column of ``graph.x``, or
        ``None`` if both attribution methods fail.
    """
    inputs = graph.x.detach().clone().to(device)

    # Private copy of the subject, made once per call rather than once per
    # forward pass (the old wrapper deep-copied every graph of the subject on
    # each of the ~20 IG evaluations). The wrapper only swaps the node
    # features of the explained graph, so the caller's batch is never mutated.
    temp_subject = copy.deepcopy(full_batch[0])
    target_graph = temp_subject[band_idx][time_idx]

    def model_wrapper(x):
        target_graph.x = x
        return model([tuple(temp_subject)])  # rewrap as a batch of one subject

    try:
        ig = IntegratedGradients(model_wrapper)
        # internal_batch_size=1: Captum clamps this up to the number of input
        # rows, so each forward call receives one interpolated copy of x. This
        # presumably matters because the model treats all rows of x as nodes of
        # a single graph.
        attributions = ig.attribute(
            inputs,
            target=target_class,
            n_steps=20,
            internal_batch_size=1
        )
        return _attribution_summary(attributions)
    except Exception as e:
        print(f"[ERROR] IntegratedGradients failed: {e}")

    try:
        # NOTE(review): Captum's GradientShap evaluates all n_samples perturbed
        # copies in one forward call, stacked along dim 0, so the model sees a
        # node matrix n_samples times taller than the graph. Verify the model
        # tolerates this; otherwise this fallback's attributions are unreliable.
        gs = GradientShap(model_wrapper)
        baseline = torch.zeros_like(inputs)
        attributions = gs.attribute(inputs, baselines=baseline, target=target_class, n_samples=10)
        return _attribution_summary(attributions)
    except Exception as e2:
        print(f"[ERROR] GradientShap also failed: {e2}")

    # Both attribution methods failed.
    return None
    



def calculate_overall_band_importance(subject_explanations, band_names):
    """Aggregate each band's mean node attribution across subjects.

    For every band explanation that has an ``avg_node_importance``, its mean
    is collected under the band's name. Each band in ``band_names`` then gets
    ``{'mean', 'std', 'values'}``. Bands without any collected value get zeros
    and an empty list. A band name not listed in ``band_names`` raises
    ``KeyError``.
    """
    collected = {name: [] for name in band_names}

    every_band_exp = (exp for subj in subject_explanations for exp in subj['band_explanations'])
    for exp in every_band_exp:
        node_imp = exp['avg_node_importance']
        if node_imp is None:
            continue
        collected[exp['band_name']].append(np.mean(node_imp))

    def _stats(values):
        if not values:
            return {'mean': 0.0, 'std': 0.0, 'values': []}
        return {'mean': np.mean(values), 'std': np.std(values), 'values': values}

    return {name: _stats(vals) for name, vals in collected.items()}

def _elementwise(reduce_fn, arrays):
    """Apply ``reduce_fn(arrays, axis=0)`` and return a list; ``[]`` when ``arrays`` is empty.

    Avoids NumPy returning a NaN scalar (plus a RuntimeWarning) for empty input.
    The arrays must share a shape.
    """
    if not arrays:
        return []
    return reduce_fn(arrays, axis=0).tolist()


def generate_explanation_summary(subject_explanations, band_names):
    """Summarise per-subject explanations into band, global and class-wise statistics.

    Args:
        subject_explanations: Per-subject dicts with ``true_label``,
            ``predicted_label`` and ``band_explanations``. Each band
            explanation carries ``band_name``, ``avg_node_importance``,
            ``avg_feature_importance`` (arrays or ``None``) and
            ``temporal_importance`` (list).
        band_names: Bands to report in ``band_summary``.

    Returns:
        dict with ``total_subjects``, ``band_summary``,
        ``prediction_accuracy`` (0.0 when there are no subjects),
        ``global_node_importance`` / ``global_feature_importance``
        (element-wise ``mean``/``std`` lists, empty when no data) and
        ``classwise_importance`` keyed by true label.
    """
    total = len(subject_explanations)
    correct_predictions = sum(1 for s in subject_explanations if s['true_label'] == s['predicted_label'])

    summary = {
        'total_subjects': total,
        'band_summary': {},
        # Guard the empty case, which previously raised ZeroDivisionError.
        'prediction_accuracy': correct_predictions / total if total else 0.0,
    }

    all_band_exps = [b for s in subject_explanations for b in s['band_explanations']]

    for band in band_names:
        node_importances = []
        feature_importances = []
        temporal_importances = []

        for band_exp in all_band_exps:
            if band_exp['band_name'] != band:
                continue
            # Flatten every per-node / per-feature / per-time-step value into one pool.
            if band_exp['avg_node_importance'] is not None:
                node_importances.extend(band_exp['avg_node_importance'])
            if band_exp['avg_feature_importance'] is not None:
                feature_importances.extend(band_exp['avg_feature_importance'])
            if band_exp['temporal_importance']:
                temporal_importances.extend(band_exp['temporal_importance'])

        summary['band_summary'][band] = {
            'avg_node_importance': np.mean(node_importances) if node_importances else 0,
            'avg_feature_importance': np.mean(feature_importances) if feature_importances else 0,
            'avg_temporal_importance': np.mean(temporal_importances) if temporal_importances else 0,
            'node_importance_std': np.std(node_importances) if node_importances else 0,
            'feature_importance_std': np.std(feature_importances) if feature_importances else 0,
            'temporal_importance_std': np.std(temporal_importances) if temporal_importances else 0
        }

    # === GLOBAL IMPORTANCE ACROSS ALL SUBJECTS & BANDS ===
    # Element-wise over band explanations, so all node/feature vectors must share a length.
    global_node_importances = [
        b['avg_node_importance'] for b in all_band_exps if b['avg_node_importance'] is not None
    ]
    global_feature_importances = [
        b['avg_feature_importance'] for b in all_band_exps if b['avg_feature_importance'] is not None
    ]

    summary['global_node_importance'] = {
        'mean': _elementwise(np.mean, global_node_importances),
        'std': _elementwise(np.std, global_node_importances)
    }

    summary['global_feature_importance'] = {
        'mean': _elementwise(np.mean, global_feature_importances),
        'std': _elementwise(np.std, global_feature_importances)
    }

    # Class-wise summaries, keyed by true label in order of first appearance.
    class_summaries = {}
    for subject in subject_explanations:
        bucket = class_summaries.setdefault(
            subject['true_label'], {'node_importances': [], 'feature_importances': []}
        )
        for band_exp in subject['band_explanations']:
            if band_exp['avg_node_importance'] is not None:
                bucket['node_importances'].append(band_exp['avg_node_importance'])
            if band_exp['avg_feature_importance'] is not None:
                bucket['feature_importances'].append(band_exp['avg_feature_importance'])

    summary['classwise_importance'] = {
        label: {
            'avg_node_importance': _elementwise(np.mean, values['node_importances']),
            'avg_feature_importance': _elementwise(np.mean, values['feature_importances'])
        }
        for label, values in class_summaries.items()
    }

    return summary

def _print_top_k(kind, means, stds=None, top_k=5):
    """Print the ``top_k`` largest entries of ``means`` in descending order.

    Each line reads ``"  {kind} {index}: {mean}"``, with ``" ± {std}"``
    appended when ``stds`` is given. Empty input prints nothing. A scalar is
    treated as a length-1 vector.
    """
    means = np.atleast_1d(np.asarray(means, dtype=float))
    if stds is not None:
        stds = np.atleast_1d(np.asarray(stds, dtype=float))

    # Descending order, then the first k (k <= 0 prints nothing).
    order = np.argsort(means)[::-1][:max(top_k, 0)]
    for i in order:
        if stds is None:
            print(f"  {kind} {i}: {means[i]:.4f}")
        else:
            print(f"  {kind} {i}: {means[i]:.4f} ± {stds[i]:.4f}")


def print_explanation_results(explanations, top_k=5):
    """Pretty-print the output of ``explain_model_predictions``.

    Args:
        explanations: dict with ``summary`` and ``overall_band_importance``
            entries as produced by the summary helpers.
        top_k: How many of the highest-importance nodes/features to list in
            the global and class-wise sections (default 5).
    """
    print("\n" + "=" * 50)
    print("MODEL EXPLANATION RESULTS")
    print("=" * 50)

    summary = explanations['summary']
    print(f"Total subjects analyzed: {summary['total_subjects']}")
    print(f"Prediction accuracy: {summary['prediction_accuracy']:.4f}")

    print("\nBAND IMPORTANCE SUMMARY:")
    for band, stats in summary['band_summary'].items():
        print(f"\n{band} Band:")
        print(f"  Average node importance: {stats['avg_node_importance']:.4f} ± {stats['node_importance_std']:.4f}")
        print(f"  Average feature importance: {stats['avg_feature_importance']:.4f} ± {stats['feature_importance_std']:.4f}")
        print(f"  Average temporal importance: {stats['avg_temporal_importance']:.4f} ± {stats['temporal_importance_std']:.4f}")

    print("\nOVERALL BAND IMPORTANCE:")
    for band, stats in explanations['overall_band_importance'].items():
        print(f"{band}: {stats['mean']:.4f} ± {stats['std']:.4f}")

    print("\nGLOBAL NODE IMPORTANCE (across all bands & subjects):")
    global_node = summary['global_node_importance']
    _print_top_k("Node", global_node['mean'], global_node['std'], top_k)

    print("\nGLOBAL FEATURE IMPORTANCE (across all bands & subjects):")
    global_feat = summary['global_feature_importance']
    _print_top_k("Feature", global_feat['mean'], global_feat['std'], top_k)

    print("\nCLASS-WISE NODE IMPORTANCE (per class label):")
    for label, stats in summary['classwise_importance'].items():
        print(f"\nClass {label}:")
        _print_top_k("Node", stats['avg_node_importance'], top_k=top_k)

    print("\nCLASS-WISE FEATURE IMPORTANCE (per class label):")
    for label, stats in summary['classwise_importance'].items():
        print(f"\nClass {label}:")
        _print_top_k("Feature", stats['avg_feature_importance'], top_k=top_k)
        
