# Methun Research: CNN-Transformer IDS with XAI

This notebook runs the **full pipeline** on Google Colab:

| Step | What happens |
|------|-------------|
| **1. Setup** | Clone repo, install deps, verify GPU |
| **2. Data** | Upload / mount CICIDS2017 dataset |
| **3. Configure** | Set hyperparameters for both models (full-dataset or sample mode) |
| **4. CNN-Transformer Training** | Train the hybrid model; Integrated Gradients (per-feature attribution) and Grad-CAM (CNN activation-map attribution) run automatically during training |
| **5. Enhanced Transformer Training** | Train the grouped-token model |
| **6-7. SHAP** | GradientExplainer feature importance + plots for each model |
| **8. Test Evaluation** | Held-out test set metrics (20 % unseen data) |
| **9. Results** | Display all XAI outputs side by side |
| **10-11. Export** | Optionally copy artifacts to Google Drive; list all generated files |

> **Runtime:** Select **GPU → T4** via *Runtime → Change runtime type* before running.

---
## 1. Setup: Clone & Install

In [None]:
import sys, os

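# Clone the repository (skipped if it already exists, so the cell can be re-run)
%cd /content
![ -d IDS_Interpretability ] || git clone https://github.com/samaraweeramethun-eng/IDS_Interpretability.git
%cd IDS_Interpretability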

# Install CUDA-enabled PyTorch (matches Colab T4 drivers)
!pip install -q torch --index-url https://download.pytorch.org/whl/cu121

# Install project dependencies
!pip install -q -r requirements.txt

# Explicit editable install (ensures methun_research is importable)
!pip install -q -e .

# Fallback: add src/ to sys.path in case editable install didn't register
SRC_DIR = os.path.join(os.getcwd(), 'src')
if SRC_DIR not in sys.path:
    sys.path.insert(0, SRC_DIR)

# Verify the package is importable
import methun_research
print(f"✓ methun_research imported from: {methun_research.__file__}")

# Verify GPU
import torch
print(f"PyTorch {torch.__version__}")
print(f"CUDA available: {torch.cuda.is_available()}")
if torch.cuda.is_available():
    print(f"GPU: {torch.cuda.get_device_name(0)}")
    free, total = torch.cuda.mem_get_info(0)
    print(f"VRAM: {total/1024**3:.1f} GB total, {free/1024**3:.1f} GB free")
else:
    print("WARNING: No GPU detected; training will be slow. Select a GPU runtime.")

---
## 2. Data: Load CICIDS2017

Choose **one** of the options below:

- **Option A (recommended):** Mount Google Drive if you have the full dataset there  
- **Option B:** Upload a CSV directly from your local machine  
- **Option C:** Use the included 5 000-row sample for a quick smoke test
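
The data cell picks the full CSV when it exists and otherwise falls back to the bundled sample, then locates the label column. A minimal standalone sketch of that selection logic, assuming the same file layout; `resolve_dataset` and `find_label_column` are hypothetical helpers for illustration, not part of `methun_research`.

```python
import os
from typing import Iterable, Sequence


def resolve_dataset(candidates: Sequence[str]) -> str:
    """Return the first path that exists, in priority order."""
    for path in candidates:
        if os.path.exists(path):
            return path
    raise FileNotFoundError(f"none of these paths exist: {list(candidates)}")


def find_label_column(columns: Iterable[str]) -> str:
    """Find the label column, ignoring case and surrounding whitespace."""
    for col in columns:
        if col.strip().lower() == "label":
            return col
    raise KeyError("no 'label' column found")


# Full CSV first (Options A/B), bundled sample (Option C) as the fallback.
CANDIDATES = [
    "data/cicids2017/cicids2017.csv",
    "data/cicids2017/cicids2017_sample.csv",
]
# DATA_PATH = resolve_dataset(CANDIDATES)
```

Matching the stripped name exactly is stricter than the notebook's substring test (`'label' in c.lower()`), which would also match a hypothetical column such as `Label_Encoded`; which behaviour you want depends on your CSV headers.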

In [None]:
import os

# ── Option A: Mount Google Drive ────────────────────────────────────
# Uncomment the lines below if your dataset is on Google Drive.
# Update the path to match where you stored the CSV.

# from google.colab import drive
# drive.mount('/content/drive', force_remount=True)
# DRIVE_CSV = '/content/drive/MyDrive/cicids2017/cicids2017.csv'
# !mkdir -p data/cicids2017
# !ln -sf "$DRIVE_CSV" data/cicids2017/cicids2017.csv

# ── Option B: Upload from local machine ─────────────────────────────
# Uncomment the lines below to upload interactively.

# from google.colab import files
# uploaded = files.upload()  # select your cicids2017.csv
# !mkdir -p data/cicids2017
# !mv "{list(uploaded.keys())[0]}" data/cicids2017/cicids2017.csv

# ── Option C: Use included sample (default) ─────────────────────────
DATA_PATH = 'data/cicids2017/cicids2017_sample.csv'

# Switch to full dataset if it exists
if os.path.exists('data/cicids2017/cicids2017.csv'):
    DATA_PATH = 'data/cicids2017/cicids2017.csv'

print(f'Using dataset: {DATA_PATH}')
print(f'Size: {os.path.getsize(DATA_PATH) / 1024**2:.1f} MB')

import pandas as pd
df_peek = pd.read_csv(DATA_PATH, nrows=5)
print(f'Columns: {len(df_peek.columns)}')
label_col = [c for c in df_peek.columns if 'label' in c.lower()][0]
counts = pd.read_csv(DATA_PATH, usecols=[label_col])[label_col].value_counts()
print(f'\nLabel distribution:\n{counts}')

---
## 3. Configure Training

Adjust hyperparameters below. The defaults are tuned for a **T4 GPU** with the full dataset.  
If you are using the 5 000-row sample, `USE_SAMPLE` is set automatically and the smaller values in each `... if not USE_SAMPLE else ...` expression apply.
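
The config cell reports a train/val/test split derived from `val_size` and `test_size`. If that split is done in two stages (hold out the test set first, then carve validation out of what remains), the validation fraction used in the second stage has to be rescaled. A worked sketch, assuming `test_size=0.2` as the notebook's 20 % test set implies; `split_plan` is a hypothetical helper, and the two-stage procedure is an illustration, not necessarily how the trainers split.

```python
def split_plan(val_size: float, test_size: float) -> dict:
    """Fractions for a two-stage split (hypothetical illustration).

    Stage 1 holds out `test_size` of all rows.
    Stage 2 takes validation rows from the remaining (1 - test_size),
    so its relative fraction is val_size / (1 - test_size).
    """
    if val_size < 0 or test_size < 0 or val_size + test_size >= 1:
        raise ValueError("need val_size, test_size >= 0 and val_size + test_size < 1")
    remainder = 1.0 - test_size
    return {
        "train": remainder - val_size,
        "val": val_size,
        "test": test_size,
        "val_of_remainder": val_size / remainder,
    }


plan = split_plan(val_size=0.1, test_size=0.2)
print({k: round(v, 3) for k, v in plan.items()})
# {'train': 0.7, 'val': 0.1, 'test': 0.2, 'val_of_remainder': 0.125}
```

So a 70 / 10 / 20 split needs 12.5 % (not 10 %) of the post-test remainder for validation.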

In [None]:
import os, sys, warnings
os.environ['TORCHDYNAMO_DISABLE'] = '1'
warnings.filterwarnings('ignore')

# Ensure methun_research is importable (survives kernel restart)
SRC_DIR = os.path.join(os.getcwd(), 'src')
if SRC_DIR not in sys.path:
    sys.path.insert(0, SRC_DIR)

from methun_research.config import CNNTransformerConfig, EnhancedConfig

USE_SAMPLE = 'sample' in DATA_PATH

# ── CNN-Transformer Config (optimised for T4 GPU) ────────────────────
cnn_cfg = CNNTransformerConfig(
    input_path=DATA_PATH,
    output_dir='artifacts',
    epochs=25 if not USE_SAMPLE else 5,
    batch_size=1024 if not USE_SAMPLE else 64,
    val_batch_size=2048 if not USE_SAMPLE else 128,
    lr=2e-3 if not USE_SAMPLE else 1.5e-3,
    num_workers=2,
    d_model=192 if not USE_SAMPLE else 64,
    conv_channels=96 if not USE_SAMPLE else 32,
    num_layers=3 if not USE_SAMPLE else 1,
    num_heads=8 if not USE_SAMPLE else 4,
    d_ff=768 if not USE_SAMPLE else 256,
    dropout=0.2,
    val_size=0.1,
    ig_steps=32 if not USE_SAMPLE else 8,
    ig_samples=512 if not USE_SAMPLE else 128,
    max_train_samples=500_000 if not USE_SAMPLE else 0,
)

# ── Enhanced Transformer Config (optimised for T4 GPU) ───────────────
enh_cfg = EnhancedConfig(
    input_path=DATA_PATH,
    output_dir='artifacts',
    epochs=35 if not USE_SAMPLE else 5,
    batch_size=1024 if not USE_SAMPLE else 64,
    val_batch_size=2048 if not USE_SAMPLE else 128,
    lr=2e-3,
    num_workers=2,
    d_model=160 if not USE_SAMPLE else 64,
    num_layers=4 if not USE_SAMPLE else 2,
    heads=10 if not USE_SAMPLE else 4,
    d_ff=640 if not USE_SAMPLE else 256,
    dropout=0.15,
    group_size=8,
    use_swa=not USE_SAMPLE,
    use_mixup=not USE_SAMPLE,
    max_train_samples=500_000 if not USE_SAMPLE else 0,
)

print(f'Mode: {"SAMPLE (small model)" if USE_SAMPLE else "FULL DATASET (production model)"}')
print(f'\nData split: {100*(1-cnn_cfg.val_size-cnn_cfg.test_size):.0f}% train / {100*cnn_cfg.val_size:.0f}% val / {100*cnn_cfg.test_size:.0f}% test')
print(f'\nCNN-Transformer: {cnn_cfg.epochs} epochs, d_model={cnn_cfg.d_model}, batch={cnn_cfg.batch_size}')
print(f'Enhanced:        {enh_cfg.epochs} epochs, d_model={enh_cfg.d_model}, batch={enh_cfg.batch_size}')
print(f'Max train samples: {cnn_cfg.max_train_samples or "unlimited"}')

---
## 4. Train CNN-Transformer

This trains the CNN + Transformer hybrid and automatically generates:
- **Integrated Gradients** feature attributions
- **Grad-CAM** activation maps over feature positions
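
To make the two attribution methods concrete, here is a self-contained NumPy sketch on toy models: Integrated Gradients for a logistic-regression score, approximated with a midpoint Riemann sum over `steps` interpolation points (presumably the role `ig_steps` plays in the config), and a 1-D Grad-CAM over feature positions for one conv-style activation map feeding a linear head. The toy models, the zero baseline and all function names are assumptions for illustration; the trainer's implementation may differ (e.g. in baseline choice or target layer).

```python
import numpy as np


def sigmoid(z):
    return 1.0 / (1.0 + np.exp(-z))


def integrated_gradients(x, baseline, w, b, steps=32):
    """IG for f(x) = sigmoid(w.x + b) via a midpoint Riemann sum."""
    alphas = (np.arange(steps) + 0.5) / steps            # midpoints in (0, 1)
    path = baseline + alphas[:, None] * (x - baseline)   # (steps, n_features)
    p = sigmoid(path @ w + b)                            # model output on the path
    grads = (p * (1 - p))[:, None] * w                   # df/dx at each path point
    return (x - baseline) * grads.mean(axis=0)


def grad_cam_1d(activations, head_w):
    """Grad-CAM over positions for score = sum(head_w * activations).

    activations, head_w: (channels, positions). With a linear head the
    gradient of the score w.r.t. the activations is simply head_w.
    """
    grads = head_w
    alpha = grads.mean(axis=1, keepdims=True)                  # per-channel weight
    cam = np.maximum((alpha * activations).sum(axis=0), 0.0)   # ReLU
    return cam / cam.max() if cam.max() > 0 else cam


rng = np.random.default_rng(0)
w, b = rng.normal(size=6), 0.1
x, baseline = rng.normal(size=6), np.zeros(6)
attr = integrated_gradients(x, baseline, w, b, steps=256)
# Completeness: attributions should sum to roughly f(x) - f(baseline).
print(attr.sum(), sigmoid(x @ w + b) - sigmoid(baseline @ w + b))

acts = np.abs(rng.normal(size=(4, 10)))
print(grad_cam_1d(acts, rng.normal(size=(4, 10))).round(2))
```

The completeness check is a cheap sanity test for any IG implementation: if the sum drifts far from `f(x) - f(baseline)`, the number of steps is usually too small.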

In [None]:
import gc, time, torch
from methun_research.training.cnn_trainer import train_cnn_transformer

# Show RAM before training
import psutil
ram = psutil.virtual_memory()
print(f'System RAM: {ram.used/1024**3:.1f} / {ram.total/1024**3:.1f} GB ({ram.percent}% used)')
if torch.cuda.is_available():
    free, total = torch.cuda.mem_get_info(0)
    print(f'GPU VRAM:   {(total-free)/1024**3:.1f} / {total/1024**3:.1f} GB used')

gc.collect()
print('\nTraining CNN-Transformer...')
t0 = time.time()
cnn_path = train_cnn_transformer(cnn_cfg)

elapsed = (time.time() - t0) / 60
print(f'\nDone in {elapsed:.1f} min  ->  {cnn_path}')

# Show resource usage after training
ram = psutil.virtual_memory()
print(f'\nSystem RAM: {ram.used/1024**3:.1f} / {ram.total/1024**3:.1f} GB ({ram.percent}% used)')
if torch.cuda.is_available():
    free, total = torch.cuda.mem_get_info(0)
    print(f'GPU VRAM:   {(total-free)/1024**3:.1f} / {total/1024**3:.1f} GB used')
    print(f'GPU peak:   {torch.cuda.max_memory_allocated(0)/1024**3:.2f} GB')

---
## 5. Train Enhanced Transformer
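
The Enhanced config's `group_size=8` and the overview's "grouped-token model" suggest that the flat feature vector is cut into tokens of 8 features each before entering the Transformer. A minimal NumPy sketch of such a tokenisation, assuming consecutive grouping with zero-padding of the last group; `group_features` is hypothetical and the trainer may group features differently.

```python
import numpy as np


def group_features(x: np.ndarray, group_size: int = 8) -> np.ndarray:
    """Reshape (batch, n_features) into (batch, n_tokens, group_size).

    The feature axis is zero-padded up to a multiple of group_size so the
    last token is complete. Each token would then be linearly projected to
    d_model before the Transformer layers (not shown).
    """
    batch, n_features = x.shape
    n_tokens = -(-n_features // group_size)  # ceiling division
    pad = n_tokens * group_size - n_features
    x = np.pad(x, ((0, 0), (0, pad)))        # constant zero padding on the right
    return x.reshape(batch, n_tokens, group_size)


# e.g. 78 numeric features -> 10 tokens of 8 (last token has 2 padded slots)
tokens = group_features(np.ones((4, 78)), group_size=8)
print(tokens.shape)  # (4, 10, 8)
```

Grouping shortens the sequence from one token per feature to `ceil(n_features / group_size)` tokens, which keeps self-attention cost (quadratic in sequence length) small.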

In [None]:
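import gc, time, torch
gc.collect()
if torch.cuda.is_available():
    torch.cuda.empty_cache()
    # Reset the peak counter so "GPU peak" below reflects this model only
    torch.cuda.reset_peak_memory_stats(0)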

from methun_research.training.enhanced_trainer import train_enhanced

print('Training Enhanced Transformer...')
t0 = time.time()
enh_path = train_enhanced(enh_cfg)

elapsed = (time.time() - t0) / 60
print(f'\nDone in {elapsed:.1f} min  ->  {enh_path}')

if torch.cuda.is_available():
    free, total = torch.cuda.mem_get_info(0)
    print(f'GPU VRAM:   {(total-free)/1024**3:.1f} / {total/1024**3:.1f} GB used')
    print(f'GPU peak:   {torch.cuda.max_memory_allocated(0)/1024**3:.2f} GB')

---
## 6. SHAP Analysis: CNN-Transformer

Runs `GradientExplainer` on the CNN-Transformer checkpoint.  
Produces: global importance CSV, summary beeswarm plot, single-sample waterfall plot.
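
`GradientExplainer` is based on expected gradients: it averages `(x - x') * grad f(x' + a * (x - x'))` over background samples `x'` and random `a ~ U(0, 1)`. The global importance CSV then aggregates the per-sample attributions; its column name `mean_abs_shap` suggests a mean absolute value per feature. A NumPy sketch on a linear model (where the attribution reduces to `w * (x - mean of sampled backgrounds)`); the model, sample counts and function names are assumptions, and this is not the `shap` library's code.

```python
import numpy as np


def expected_gradients(x, background, grad_fn, n_samples=200, rng=None):
    """Monte-Carlo expected gradients for one input x of shape (n_features,)."""
    if rng is None:
        rng = np.random.default_rng(0)
    idx = rng.integers(len(background), size=n_samples)
    ref = background[idx]                        # sampled baselines x'
    alpha = rng.uniform(size=(n_samples, 1))     # a ~ U(0, 1)
    points = ref + alpha * (x - ref)             # points on the x' -> x path
    return ((x - ref) * grad_fn(points)).mean(axis=0)


def global_importance(attributions):
    """Mean |attribution| per feature across evaluated samples."""
    return np.abs(attributions).mean(axis=0)


w = np.array([2.0, -1.0, 0.0])


def grad_fn(points):
    """Gradient of the linear model f(x) = w.x is w everywhere."""
    return np.broadcast_to(w, points.shape)


rng = np.random.default_rng(42)
background = rng.normal(size=(500, 3))
X_eval = rng.normal(size=(50, 3))

phi = np.stack([expected_gradients(x, background, grad_fn, rng=rng) for x in X_eval])
print(global_importance(phi))  # third feature has zero weight -> zero importance
```

The notebook's `background_size` and `eval_size` plausibly correspond to the sizes of `background` and `X_eval` here.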

In [None]:
from methun_research.interpretability.shap_runner import run_shap

os.makedirs('artifacts/shap', exist_ok=True)

SHAP_BG   = 2000 if not USE_SAMPLE else 200
SHAP_EVAL = 2000 if not USE_SAMPLE else 200
SHAP_POOL = 150_000 if not USE_SAMPLE else 500

print('Running SHAP on CNN-Transformer...')
t0 = time.time()
shap_csv = run_shap(
    checkpoint_path=cnn_path,
    data_path=DATA_PATH,
    output_dir='artifacts/shap',
    background_size=SHAP_BG,
    eval_size=SHAP_EVAL,
    eval_pool=SHAP_POOL,
    chunk_size=256,
)
print(f'\nDone in {(time.time()-t0)/60:.1f} min  ->  {shap_csv}')

---
## 7. SHAP Analysis: Enhanced Transformer

In [None]:
if enh_path:
    print('Running SHAP on Enhanced Transformer...')
    t0 = time.time()
    shap_enh_csv = run_shap(
        checkpoint_path=enh_path,
        data_path=DATA_PATH,
        output_dir='artifacts/shap_enhanced',
        background_size=SHAP_BG,
        eval_size=SHAP_EVAL,
        eval_pool=SHAP_POOL,
        chunk_size=256,
    )
    print(f'\nDone in {(time.time()-t0)/60:.1f} min  ->  {shap_enh_csv}')
else:
    print('No Enhanced Transformer checkpoint; skipping.')

---
## 8. Test Set Evaluation Summary

Both trainers automatically evaluate on a **held-out 20% test set** that is never seen during training or validation.  
This cell loads the saved test metrics from each checkpoint and presents them side by side.
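
The table mixes threshold-based metrics (precision, recall, F1, accuracy) with ranking metrics (ROC-AUC, PR-AUC). A pure-NumPy sketch of how these can be computed for binary attack-vs-benign labels, assuming attack is the positive class and a 0.5 threshold; PR-AUC is estimated here by average precision, one common choice that may not match the trainers' estimator. The dictionary keys mirror the `test_metrics` keys the summary cell reads.

```python
import numpy as np


def binary_metrics(y_true, scores, threshold=0.5):
    """Attack-vs-benign metrics; y_true in {0, 1}, scores = P(attack)."""
    y = np.asarray(y_true).astype(bool)
    s = np.asarray(scores, dtype=float)
    if y.all() or not y.any():
        raise ValueError("need both classes to compute ROC-AUC / PR-AUC")

    # Threshold-based metrics
    pred = s >= threshold
    tp = np.sum(pred & y)
    fp = np.sum(pred & ~y)
    fn = np.sum(~pred & y)
    precision = tp / (tp + fp) if tp + fp else 0.0
    recall = tp / (tp + fn)
    f1 = 2 * precision * recall / (precision + recall) if precision + recall else 0.0
    accuracy = np.mean(pred == y)

    # ROC-AUC = P(random attack scores above random benign), ties count 1/2.
    # Pairwise comparison is O(P*N) memory: fine here, not for millions of rows.
    pos, neg = s[y], s[~y]
    diff = pos[:, None] - neg[None, :]
    auc_roc = (np.sum(diff > 0) + 0.5 * np.sum(diff == 0)) / diff.size

    # Average precision: mean precision at the rank of each positive.
    order = np.argsort(-s, kind="stable")
    hits = y[order]
    precision_at_k = np.cumsum(hits) / np.arange(1, len(hits) + 1)
    auc_pr = precision_at_k[hits].mean()

    return {"auc_roc": auc_roc, "auc_pr": auc_pr, "f1_score": f1,
            "precision": precision, "recall": recall, "accuracy": accuracy}


print(binary_metrics([0, 0, 1, 1], [0.1, 0.4, 0.35, 0.8]))
```

For this toy input ROC-AUC is 0.75, average precision is 5/6 (about 0.833), precision is 1.0 and recall 0.5, which shows how a model can rank well yet still miss attacks at a fixed threshold.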

In [None]:
import os, torch, pandas as pd
from IPython.display import display, HTML

def load_test_metrics(ckpt_path, model_name):
    """Load test_metrics from a saved checkpoint."""
    if not ckpt_path or not os.path.exists(ckpt_path):  # enh_path may be None
        print(f'  [{model_name}] Checkpoint not found: {ckpt_path}')
        return None
    ckpt = torch.load(ckpt_path, map_location='cpu', weights_only=False)
    if 'test_metrics' not in ckpt:
        print(f'  [{model_name}] No test_metrics in checkpoint (was test_size=0?)')
        return None
    m = ckpt['test_metrics']
    # Keep raw floats; the Styler below formats them to 4 decimals
    return {
        'Model': model_name,
        'ROC-AUC': m['auc_roc'],
        'PR-AUC': m['auc_pr'],
        'F1-Score': m['f1_score'],
        'Precision': m['precision'],
        'Recall': m['recall'],
        'Accuracy': m['accuracy'],
    }

rows = []
r = load_test_metrics(cnn_path, 'CNN-Transformer')
if r: rows.append(r)
r = load_test_metrics(enh_path, 'Enhanced Transformer')
if r: rows.append(r)

if rows:
    df_test = pd.DataFrame(rows).set_index('Model')
    display(HTML('<h3>🧪 Held-Out Test Set Results (20 % of data, never seen during training)</h3>'))
    display(df_test.style.set_properties(**{'text-align': 'center'}).format(precision=4))
else:
    print('No test metrics available. Ensure test_size > 0 in configs.')

---
## 9. Visualise XAI Results

Compare the top-ranked features from Integrated Gradients, Grad-CAM and SHAP (one SHAP column per model).
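
Beyond eyeballing the table, agreement between two methods can be quantified with the Jaccard overlap of their top-k feature sets and a Spearman rank correlation over the features both rank. A sketch with plain pandas, assuming each ranking is a `Series` mapping feature name to importance (for example a ranking CSV read with `set_index` on its feature column); these agreement metrics are our illustration, not something the notebook computes.

```python
import pandas as pd


def topk_jaccard(a: pd.Series, b: pd.Series, k: int = 15) -> float:
    """Jaccard overlap of the k highest-importance features of a and b."""
    top_a, top_b = set(a.nlargest(k).index), set(b.nlargest(k).index)
    return len(top_a & top_b) / len(top_a | top_b)


def spearman_on_shared(a: pd.Series, b: pd.Series) -> float:
    """Spearman correlation = Pearson correlation of ranks (no SciPy needed)."""
    shared = a.index.intersection(b.index)
    return a[shared].rank().corr(b[shared].rank())


# Toy importances indexed by feature name (values are made up)
ig = pd.Series({"A": 0.9, "B": 0.5, "C": 0.1, "D": 0.05})
shap_imp = pd.Series({"A": 0.8, "B": 0.2, "C": 0.3, "D": 0.01})
print(topk_jaccard(ig, shap_imp, k=2))   # 0.333... ({A, B} vs {A, C})
print(spearman_on_shared(ig, shap_imp))  # about 0.8
```

Jaccard focuses on the head of the ranking (what the top-15 table shows), while Spearman also reflects agreement in the long tail.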

In [None]:
import pandas as pd
from IPython.display import display, HTML

TOP_K = 15

def load_ranking(path, val_col, method_name):
    if not os.path.exists(path):
        print(f'  [{method_name}] Not found: {path}')
        return None
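    df = pd.read_csv(path)
    feat_col = next(c for c in df.columns if c != val_col)  # first non-score column
    df = df[[feat_col, val_col]].sort_values(val_col, ascending=False).head(TOP_K).reset_index(drop=True)
    df.index += 1
    df.columns = ['Feature', method_name]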
    return df

rankings = [
    load_ranking('artifacts/cnn_transformer_integrated_gradients.csv',
                 'avg_abs_integrated_grad', 'Integrated Gradients'),
    load_ranking('artifacts/cnn_transformer_grad_cam.csv',
                 'grad_cam_importance', 'Grad-CAM'),
    load_ranking('artifacts/shap/shap_global_importance_attack.csv',
                 'mean_abs_shap', 'SHAP (CNN-Trans)'),
]

# Enhanced SHAP if available
r = load_ranking('artifacts/shap_enhanced/shap_global_importance_attack.csv',
                 'mean_abs_shap', 'SHAP (Enhanced)')
if r is not None:
    rankings.append(r)

rankings = [r for r in rankings if r is not None]
if rankings:
    # Rows are rank-aligned: row i holds each method's i-th most important feature
    combined = pd.concat(rankings, axis=1)
    display(HTML(f'<h3>Top {TOP_K} Features: All XAI Methods</h3>'))
    display(combined)
else:
    print('No ranking files found.')

In [None]:
# Display the generated plots
from IPython.display import Image, display

plots = [
    ('Grad-CAM Feature Importance', 'artifacts/grad_cam_importance.png'),
    ('SHAP Summary (CNN-Transformer)', 'artifacts/shap/shap_summary_attack.png'),
    ('SHAP Waterfall (CNN-Transformer)', 'artifacts/shap/shap_waterfall_attack.png'),
    ('SHAP Summary (Enhanced)', 'artifacts/shap_enhanced/shap_summary_attack.png'),
    ('SHAP Waterfall (Enhanced)', 'artifacts/shap_enhanced/shap_waterfall_attack.png'),
]

for title, path in plots:
    if os.path.exists(path):
        print(f'\n--- {title} ---')
        display(Image(filename=path, width=700))
    else:
        print(f'[skip] {title}: file not found')

---
## 10. Save Artifacts to Google Drive (Optional)

Copy all checkpoints and XAI outputs to your Drive for later use.
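
`shutil.copytree(..., dirs_exist_ok=True)` merges into an existing folder, so re-running the export overwrites same-named files from earlier runs. If you want to keep runs separate, one option is a timestamped sub-folder per run. A standard-library sketch; `export_artifacts` and the `run-<stamp>` naming scheme are hypothetical.

```python
import shutil
import tempfile
import time
from pathlib import Path
from typing import Optional


def export_artifacts(src: str, dest_root: str, stamp: Optional[str] = None) -> Path:
    """Copy the `src` tree into dest_root/run-<stamp>/ and return that path."""
    if not Path(src).is_dir():
        raise FileNotFoundError(f"nothing to export: {src}")
    stamp = stamp or time.strftime("%Y%m%d-%H%M%S")
    dest = Path(dest_root) / f"run-{stamp}"
    # No dirs_exist_ok: raises FileExistsError instead of silently merging
    shutil.copytree(src, dest)
    return dest


# Local demo with a throwaway directory standing in for Drive
with tempfile.TemporaryDirectory() as tmp:
    src = Path(tmp, "artifacts", "shap")
    src.mkdir(parents=True)
    (src / "importance.csv").write_text("feature,mean_abs_shap\n")
    out = export_artifacts(str(src.parent), str(Path(tmp, "drive")), stamp="demo")
    print(sorted(p.name for p in out.rglob("*")))  # ['importance.csv', 'shap']
```

On Colab, after mounting Drive, the call would be `export_artifacts('artifacts', '/content/drive/MyDrive/methun_research_artifacts')`.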

In [None]:
# Uncomment to save artifacts to Google Drive
# import shutil
# from google.colab import drive
# drive.mount('/content/drive', force_remount=False)

# DEST = '/content/drive/MyDrive/methun_research_artifacts'
# shutil.copytree('artifacts', DEST, dirs_exist_ok=True)
# print(f'Artifacts saved to {DEST}')

---
## 11. List All Generated Artifacts

In [None]:
print('All generated artifacts:')
for root, dirs, files in os.walk('artifacts'):
    dirs.sort()  # visit subdirectories in a stable, alphabetical order
    for f in sorted(files):
        fp = os.path.join(root, f)
        sz = os.path.getsize(fp)
        if sz > 1024*1024:
            print(f'  {fp}  ({sz/1024**2:.1f} MB)')
        else:
            print(f'  {fp}  ({sz/1024:.1f} KB)')