In [None]:
# Ad-hoc peek at element [1][1] of the raw alpha-band data.
# NOTE(review): dynamic_data_alpha_raw is only assigned by the torch.load call in a
# later cell, so on a fresh kernel this cell fails with NameError if it is run first.
dynamic_data_alpha_raw[1][1]

In [2]:
# Cell 1: Import libraries and modules
import torch
import import_ipynb
from data_preprocessing import process_band
from data_combiner import combine_band_data
from analysis_pipeline import run_all_analyses
from netwrks_2 import MultiBandAttentionFusion
from explain_utils import (
    explain_model_predictions,
    print_explanation_results,
    explain_single_graph
)

In [3]:




# Cell 2: Load your data
# NOTE: machine-specific absolute paths; edit these when running on another machine.
PATH_TO_SAVED_DYNAMIC_alpha = r"C:\Users\fathi\test\Granger_all_dynamic_data_overlap50.pt"
PATH_TO_SAVED_DYNAMIC_beta = r"C:\Users\fathi\test\Granger_all_dynamic_data_overlap50_beta.pt"

# SECURITY: torch.load unpickles the file, which can execute arbitrary code, so only
# load files you created or trust. weights_only=False is spelled out so the load behaves
# the same on PyTorch releases whose default is True (assumes the files hold pickled
# Python objects rather than bare tensors -- TODO confirm).
dynamic_data_alpha_raw = torch.load(PATH_TO_SAVED_DYNAMIC_alpha, weights_only=False)
dynamic_data_beta_raw = torch.load(PATH_TO_SAVED_DYNAMIC_beta, weights_only=False)

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Cell 3: Process each band
train_alpha, test_alpha, dynamic_alpha = process_band(dynamic_data_alpha_raw)
train_beta, test_beta, dynamic_beta = process_band(dynamic_data_beta_raw)

print(f"Alpha band: {len(dynamic_alpha)} subjects")
print(f"Beta band: {len(dynamic_beta)} subjects")

# Cell 4: Combine bands for multi-band analysis
# band_names must list the bands in the same order as band_data_list.
band_data_list = [dynamic_alpha, dynamic_beta]
combined_dataset = combine_band_data(band_data_list)
band_names = ['Alpha', 'Beta']

print(f"Combined dataset: {len(combined_dataset)} subjects")

# Cell 5: Run analysis
results = run_all_analyses(combined_dataset, model_class=MultiBandAttentionFusion, band_names=band_names)

# Cell 6: Access results
best_model = results['best_single_model']
ensemble_model = results['ensemble_model']
fold_metrics = results['fold_metrics']
cv_results = results['cv_results']

# Cell 7: Save models
# Checkpoint immediately after training: the explanation phase below is slow, and
# interrupting it (or an error inside it) must not cost the trained weights.
torch.save(best_model.state_dict(), 'best_model.pt')
torch.save(ensemble_model.state_dict(), 'ensemble_model.pt')

print("\nFinal Results:")
print(f"Best CV F1: {cv_results['avg_f1']:.4f}")
print(f"Best CV Accuracy: {cv_results['avg_accuracy']:.4f}")
if cv_results['avg_band_importance']:
    print("Band Importance:")
    for band, importance in cv_results['avg_band_importance'].items():
        print(f"  {band}: {importance:.4f}")

# === EXPLANATION PHASE ===
explanations = explain_model_predictions(best_model, combined_dataset, device, band_names=band_names)

# Persist the full-dataset explanations as soon as they exist, so a failure in the
# single-graph demo further down cannot lose them.
torch.save(explanations, "full_model_explanations.pt")
print("Explanation saved.")

print_explanation_results(explanations)

# === SINGLE-GRAPH SANITY CHECK AND EXPLANATION ===
# Pick one subject and a specific band and time point (used for both steps below).
subject_idx = 0
band_idx = 0
time_idx = 0

subject_data = combined_dataset[subject_idx]
graph = subject_data[band_idx][time_idx]

# Aliases kept so any cell still using the earlier sample_* names keeps working.
sample_band_idx, sample_time_idx = band_idx, time_idx
sample_subject, sample_graph = subject_data, graph

print("Sample graph shape check:")
print("  x shape:", sample_graph.x.shape)
print("  edge_index shape:", sample_graph.edge_index.shape)

# The graph's own label is the class whose prediction gets explained.
target_class = graph.y.item()

# Use the trained model (best_model)
model = best_model.to(device)
model.eval()  # VERY important for consistent inference

# Run explanation; the result stays in `explanation` for interactive inspection.
explanation = explain_single_graph(
    model=model,
    graph=graph,
    full_batch=subject_data,
    band_idx=band_idx,
    time_idx=time_idx,
    target_class=target_class,
    device=device
)

  dynamic_data_alpha_raw = torch.load(PATH_TO_SAVED_DYNAMIC_alpha)
  dynamic_data_beta_raw = torch.load(PATH_TO_SAVED_DYNAMIC_beta)


Alpha band: 52 subjects
Beta band: 52 subjects
Combined dataset: 52 subjects
Using device: cpu

===== Fold 1/2 =====
Training set: Class 0: 15, Class 1: 11
Validation set: Class 0: 14, Class 1: 12




Using class weights: [0.8666667 1.7727273]


  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
Training Fold 1: 100%|██████████| 1/1 [00:34<00:00, 34.63s/it, best_epoch=1, best_f1=0.3769, conf=0.62, train_loss=2.0892, val_f1=0.3769, val_loss=0.6874]



Fold 1 - Best Epoch: 1
Best Validation F1: 0.3769
Accuracy: 0.5385 | Recall: 0.5385 | Precision: 0.2899
Average prediction confidence: 0.6188
Confusion Matrix:
[[14  0]
 [12  0]]

Band Importance Analysis (Fold 1):
Alpha: 0.0140 (1.4%)
Beta: 0.9860 (98.6%)

===== Fold 2/2 =====
Training set: Class 0: 14, Class 1: 12
Validation set: Class 0: 15, Class 1: 11




Using class weights: [0.9285714 1.625    ]


Training Fold 2: 100%|██████████| 1/1 [00:34<00:00, 34.30s/it, best_epoch=1, best_f1=0.5032, conf=0.55, train_loss=1.4345, val_f1=0.5032, val_loss=0.6719]



Fold 2 - Best Epoch: 1
Best Validation F1: 0.5032
Accuracy: 0.6154 | Recall: 0.6154 | Precision: 0.7692
Average prediction confidence: 0.5476
Confusion Matrix:
[[15  0]
 [10  1]]

Band Importance Analysis (Fold 2):
Alpha: 0.9433 (94.3%)
Beta: 0.0567 (5.7%)

===== Cross-Validation Results =====
Average F1: 0.4401
Average Accuracy: 0.5769
Average Recall: 0.5769
Average Precision: 0.5296
F1 Standard Deviation: 0.0631

Average Band Importance Across All Folds:
Alpha: 0.4787 (47.9%)
Beta: 0.5213 (52.1%)

===== Analysis Summary =====
Cross-Validation F1: 0.4401 ± 0.0631
Cross-Validation Accuracy: 0.5769 ± 0.0385

Final Results:
Best CV F1: 0.4401
Best CV Accuracy: 0.5769
Band Importance:
  Alpha: 0.4787
  Beta: 0.5213
Explaining subject 1/52 (Label: 1)
  Analyzing Alpha band...
    Explaining time step 0 ...




    Done with time step 0
    Explaining time step 1 ...
    Done with time step 1
    Explaining time step 2 ...
    Done with time step 2
    Explaining time step 3 ...
    Done with time step 3
    Explaining time step 4 ...
    Done with time step 4
    Explaining time step 5 ...
    Done with time step 5
    Explaining time step 6 ...
    Done with time step 6
    Explaining time step 7 ...
    Done with time step 7
    Explaining time step 8 ...
    Done with time step 8
    Explaining time step 9 ...
    Done with time step 9
    Explaining time step 10 ...
    Done with time step 10
    Explaining time step 11 ...
    Done with time step 11
    Explaining time step 12 ...
    Done with time step 12
    Explaining time step 13 ...


KeyboardInterrupt: 

In [None]:
# Mock-up of an explanation report, kept as a bare string literal so the cell has no
# side effects. Per its own heading the figures below are illustrative ("WOULD BE
# SOMETHING LIKE THIS"); they are not output produced by a run of this notebook.
# NOTE(review): presumably mirrors the layout printed by print_explanation_results -- confirm.
'''MODEL EXPLANATION RESULTS WOULD BE SOMETHING LIKE THIS
==================================================
Total subjects analyzed: 52
Prediction accuracy: 0.5769

BAND IMPORTANCE SUMMARY:

Alpha Band:
  Average node importance: 0.0234 ± 0.0089
  Average feature importance: 0.0187 ± 0.0056
  Average temporal importance: 0.0198 ± 0.0067

Beta Band:
  Average node importance: 0.0456 ± 0.0123
  Average feature importance: 0.0398 ± 0.0098
  Average temporal importance: 0.0421 ± 0.0087

OVERALL BAND IMPORTANCE:
Alpha: 0.0234 ± 0.0089
Beta: 0.0456 ± 0.0123

GLOBAL NODE IMPORTANCE (across all bands & subjects):
  Node 15: 0.0523 ± 0.0145
  Node 8: 0.0498 ± 0.0134
  Node 12: 0.0487 ± 0.0128
  Node 3: 0.0465 ± 0.0119
  Node 7: 0.0443 ± 0.0112

GLOBAL FEATURE IMPORTANCE (across all bands & subjects):
  Feature 2: 0.0612 ± 0.0167
  Feature 0: 0.0578 ± 0.0145
  Feature 1: 0.0534 ± 0.0139
  Feature 4: 0.0498 ± 0.0123
  Feature 3: 0.0456 ± 0.0118

CLASS-WISE NODE IMPORTANCE (per class label):

Class 0:
  Node 15: 0.0534
  Node 8: 0.0512
  Node 12: 0.0489
  Node 3: 0.0467
  Node 7: 0.0445

Class 1:
  Node 12: 0.0598
  Node 15: 0.0567
  Node 8: 0.0534
  Node 3: 0.0512
  Node 7: 0.0489

CLASS-WISE FEATURE IMPORTANCE (per class label):

Class 0:
  Feature 2: 0.0623
  Feature 0: 0.0589
  Feature 1: 0.0545
  Feature 4: 0.0512
  Feature 3: 0.0467

Class 1:
  Feature 2: 0.0634
  Feature 0: 0.0598
  Feature 1: 0.0567
  Feature 4: 0.0534
  Feature 3: 0.0489'''