# Methun Research: CNN-Transformer IDS with XAI (Kaggle)

This notebook runs the **full pipeline** on Kaggle with a free T4 GPU.

| Step | What happens |
|------|-------------|
| **1. Setup** | Clone repo, install deps, verify GPU |
| **2. Data** | Link Kaggle dataset or upload CSV |
| **3. Configure** | Set hyperparameters (auto-detects sample vs full) |
| **4. CNN-Transformer** | Train hybrid model + IG + Grad-CAM |
| **5. Enhanced Transformer** | Train grouped-token model |
| **6. SHAP** | GradientExplainer feature importance for both models (6 and 6b) |
| **7. Test Evaluation** | Held-out test set metrics (20% unseen data) |
| **8. Results** | Display all XAI outputs side by side |
| **9. Artifacts** | List all generated files for download |

> **Before running:**
> 1. Enable **GPU T4 x2** or **GPU P100** via *Settings → Accelerator* (right panel)
> 2. Enable **Internet** via *Settings → Internet → On* (needed for `git clone` and `pip install`)
> 3. Add CICIDS2017 as a Kaggle dataset (see Step 2)

---
## 1. Setup: Clone & Install

In [None]:
import sys, os, subprocess

# ── Clone repository ───────────────────────────────────────────────────
REPO_DIR = '/kaggle/working/IDS_Interpretability'
if not os.path.exists(REPO_DIR):
    !git clone https://github.com/samaraweeramethun-eng/IDS_Interpretability.git {REPO_DIR}
else:
    print(f'Repo already exists at {REPO_DIR}; pulling latest...')
    !cd {REPO_DIR} && git pull origin main

os.chdir(REPO_DIR)
print(f'Working directory: {os.getcwd()}')

# ── Install dependencies ───────────────────────────────────────────────
# Kaggle already ships PyTorch with CUDA, so only the missing packages are installed
!pip install -q shap joblib psutil
!pip install -q -e .

# ── Fallback: add src/ to sys.path ─────────────────────────────────────
SRC_DIR = os.path.join(os.getcwd(), 'src')
if SRC_DIR not in sys.path:
    sys.path.insert(0, SRC_DIR)

# ── Verify ─────────────────────────────────────────────────────────────
import methun_research
print(f'\n✓ methun_research imported from: {methun_research.__file__}')

import torch
print(f'✓ PyTorch {torch.__version__}')
print(f'✓ CUDA available: {torch.cuda.is_available()}')
if torch.cuda.is_available():
    for i in range(torch.cuda.device_count()):
        print(f'✓ GPU {i}: {torch.cuda.get_device_name(i)}')
    free, total = torch.cuda.mem_get_info(0)
    print(f'✓ VRAM: {total/1024**3:.1f} GB total, {free/1024**3:.1f} GB free')
else:
    print('⚠ WARNING: No GPU detected. Enable a GPU in Settings → Accelerator')

---
## 2. Data: Load CICIDS2017

### How to add the dataset on Kaggle:

**Option A: Use an existing Kaggle dataset**
1. Click **+ Add Data** (right panel) → search `cicids2017`
2. Add it. The CSV will be under `/kaggle/input/<dataset-name>/`, e.g. `/kaggle/input/<dataset-name>/cicids2017.csv`
3. If the file name does not start with `cicids`, set `KAGGLE_DATASET_PATH` below (auto-detection only looks for `cicids*.csv`)

**Option B: Upload your own CSV as a dataset**
1. Go to [kaggle.com/datasets](https://www.kaggle.com/datasets) → **New Dataset**
2. Upload `cicids2017.csv` → create the dataset
3. Come back to this notebook → **+ Add Data** → search your dataset name
4. Update `KAGGLE_DATASET_PATH` below (optional if the file name starts with `cicids`)

**Option C: Use the included 5,000-row sample (no upload needed)**
The sample is already in the repo; leave the defaults below.
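Many Kaggle mirrors of CICIDS2017 ship the data as several per-day CSV files rather than a single `cicids2017.csv`, and the auto-detection cell only matches file names starting with `cicids`. One workaround is to concatenate the daily files first. A minimal stdlib sketch, assuming every file shares the same header row; `merge_csvs` is a hypothetical helper, not part of `methun_research`:

```python
import csv
import glob


def merge_csvs(pattern: str, out_path: str) -> int:
    """Concatenate every CSV matching `pattern` into `out_path`.

    Assumes all inputs share one header row (compared after stripping
    whitespace); the header is written once. Returns data rows written.
    Write outside the input folder: /kaggle/input is read-only anyway.
    """
    paths = sorted(glob.glob(pattern, recursive=True))  # resolved before writing
    if not paths:
        raise FileNotFoundError(f'No CSV files match {pattern!r}')
    header = None
    rows_written = 0
    # latin-1 decodes any byte sequence and re-encodes it unchanged,
    # so stray non-UTF-8 characters in label strings survive the copy.
    with open(out_path, 'w', newline='', encoding='latin-1') as out_f:
        writer = csv.writer(out_f)
        for path in paths:
            with open(path, newline='', encoding='latin-1') as in_f:
                reader = csv.reader(in_f)
                file_header = next(reader, None)
                if file_header is None:  # skip empty files
                    continue
                if header is None:
                    header = file_header
                    writer.writerow(header)
                elif [h.strip() for h in file_header] != [h.strip() for h in header]:
                    raise ValueError(f'Header mismatch in {path}')
                for row in reader:
                    writer.writerow(row)
                    rows_written += 1
    return rows_written


# Hypothetical usage on Kaggle (adjust the glob to your dataset's layout):
# merge_csvs('/kaggle/input/<dataset-name>/*.csv', '/kaggle/working/cicids2017.csv')
```

The merged file can then be passed via `KAGGLE_DATASET_PATH`.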

In [None]:
import os, glob

# ── Update this path to match YOUR Kaggle dataset ──────────────────────
# Example: '/kaggle/input/cicids2017-dataset/cicids2017.csv'
KAGGLE_DATASET_PATH = ''  # <-- put your path here, or leave empty to auto-detect

# ── Auto-detection logic ───────────────────────────────────────────────
DATA_PATH = None

# 1. User-specified path
if KAGGLE_DATASET_PATH:
    if os.path.exists(KAGGLE_DATASET_PATH):
        DATA_PATH = KAGGLE_DATASET_PATH
    else:
        print(f'⚠ KAGGLE_DATASET_PATH not found: {KAGGLE_DATASET_PATH}; falling back to auto-detection')

# 2. Search /kaggle/input/ for any CSV whose name starts with "cicids" (case-insensitive)
if DATA_PATH is None:
    candidates = [p for p in glob.glob('/kaggle/input/**/*.csv', recursive=True)
                  if os.path.basename(p).lower().startswith('cicids')]
    # Prefer the largest file (likely the full dataset rather than a sample)
    if candidates:
        candidates.sort(key=os.path.getsize, reverse=True)
        DATA_PATH = candidates[0]
        print(f'Auto-detected Kaggle dataset: {DATA_PATH}')

# 3. Check the repo's data directory
if DATA_PATH is None and os.path.exists('data/cicids2017/cicids2017.csv'):
    DATA_PATH = 'data/cicids2017/cicids2017.csv'

# 4. Fall back to included sample
if DATA_PATH is None:
    DATA_PATH = 'data/cicids2017/cicids2017_sample.csv'

# ── Symlink into the expected location (so configs work unchanged) ─────
repo_data_dir = 'data/cicids2017'
os.makedirs(repo_data_dir, exist_ok=True)
repo_csv = os.path.join(repo_data_dir, 'cicids2017.csv')
# lexists() also catches a dangling symlink left over from an earlier session
if DATA_PATH.startswith('/kaggle/input') and not os.path.lexists(repo_csv):
    os.symlink(DATA_PATH, repo_csv)
    DATA_PATH = repo_csv
    print(f'Symlinked Kaggle dataset → {repo_csv}')

# ── Verify ─────────────────────────────────────────────────────────────
print(f'\nUsing dataset: {DATA_PATH}')
print(f'Size: {os.path.getsize(DATA_PATH) / 1024**2:.1f} MB')

import pandas as pd
df_peek = pd.read_csv(DATA_PATH, nrows=5)
print(f'Columns: {len(df_peek.columns)}')
# CICIDS2017 headers often carry a leading space (e.g. ' Label'), so match loosely
label_cols = [c for c in df_peek.columns if 'label' in c.lower()]
if not label_cols:
    raise ValueError(f'No label column found in {DATA_PATH}')
label_col = label_cols[0]
counts = pd.read_csv(DATA_PATH, usecols=[label_col])[label_col].value_counts()
print(f'\nLabel distribution:\n{counts}')

---
## 3. Configure Training

Hyperparameters are tuned for **Kaggle T4 GPU** (16 GB VRAM, ~13 GB RAM).
If you're using the 5,000-row sample, smaller configs are applied automatically.
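The `X if not USE_SAMPLE else Y` pattern in the next cell can also be written as one base configuration plus a small set of overrides. A minimal sketch with the standard-library `dataclasses.replace`; `TinyConfig` is a hypothetical stand-in that copies a few field names and values from the CNN-Transformer cell. If the real config classes are dataclasses (not confirmed here), the same `replace` call would apply to them:

```python
from dataclasses import dataclass, replace


@dataclass(frozen=True)
class TinyConfig:
    # Hypothetical stand-in mirroring a few CNNTransformerConfig fields (full-data values).
    epochs: int = 25
    batch_size: int = 1024
    d_model: int = 192
    num_layers: int = 3
    max_train_samples: int = 500_000


# Overrides applied only when training on the 5,000-row sample.
SAMPLE_OVERRIDES = dict(epochs=5, batch_size=64, d_model=64,
                        num_layers=1, max_train_samples=0)


def make_config(use_sample: bool) -> TinyConfig:
    base = TinyConfig()
    # replace() returns a new frozen instance with only the given fields changed.
    return replace(base, **SAMPLE_OVERRIDES) if use_sample else base


print(make_config(True))
```

Keeping the sample overrides in one dict makes the difference between the two modes visible at a glance.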

In [None]:
import os, sys, warnings
os.environ['TORCHDYNAMO_DISABLE'] = '1'  # disable TorchDynamo (torch.compile) graph capture
warnings.filterwarnings('ignore')

# Ensure methun_research is importable (survives kernel restart)
SRC_DIR = os.path.join(os.getcwd(), 'src')
if SRC_DIR not in sys.path:
    sys.path.insert(0, SRC_DIR)

from methun_research.config import CNNTransformerConfig, EnhancedConfig

# Match on the file name only, so a dataset folder whose name contains "sample" is not misread
USE_SAMPLE = 'sample' in os.path.basename(DATA_PATH)

# ── CNN-Transformer Config ─────────────────────────────────────────────
cnn_cfg = CNNTransformerConfig(
    input_path=DATA_PATH,
    output_dir='/kaggle/working/artifacts',
    epochs=25 if not USE_SAMPLE else 5,
    batch_size=1024 if not USE_SAMPLE else 64,
    val_batch_size=2048 if not USE_SAMPLE else 128,
    lr=2e-3 if not USE_SAMPLE else 1.5e-3,
    num_workers=2,
    d_model=192 if not USE_SAMPLE else 64,
    conv_channels=96 if not USE_SAMPLE else 32,
    num_layers=3 if not USE_SAMPLE else 1,
    num_heads=8 if not USE_SAMPLE else 4,
    d_ff=768 if not USE_SAMPLE else 256,
    dropout=0.2,
    val_size=0.1,
    test_size=0.2,
    ig_steps=32 if not USE_SAMPLE else 8,
    ig_samples=512 if not USE_SAMPLE else 128,
    max_train_samples=500_000 if not USE_SAMPLE else 0,
)

# ── Enhanced Transformer Config ────────────────────────────────────────
enh_cfg = EnhancedConfig(
    input_path=DATA_PATH,
    output_dir='/kaggle/working/artifacts',
    epochs=35 if not USE_SAMPLE else 5,
    batch_size=1024 if not USE_SAMPLE else 64,
    val_batch_size=2048 if not USE_SAMPLE else 128,
    lr=2e-3,
    num_workers=2,
    d_model=160 if not USE_SAMPLE else 64,
    num_layers=4 if not USE_SAMPLE else 2,
    heads=10 if not USE_SAMPLE else 4,
    d_ff=640 if not USE_SAMPLE else 256,
    dropout=0.15,
    group_size=8,
    use_swa=not USE_SAMPLE,
    use_mixup=not USE_SAMPLE,
    max_train_samples=500_000 if not USE_SAMPLE else 0,
)

print(f'Mode: {"SAMPLE (small model)" if USE_SAMPLE else "FULL DATASET (production model)"}')
print(f'Data split: {100*(1-cnn_cfg.val_size-cnn_cfg.test_size):.0f}% train / {100*cnn_cfg.val_size:.0f}% val / {100*cnn_cfg.test_size:.0f}% test')
print(f'\nCNN-Transformer: {cnn_cfg.epochs} epochs, d_model={cnn_cfg.d_model}, batch={cnn_cfg.batch_size}')
print(f'Enhanced:        {enh_cfg.epochs} epochs, d_model={enh_cfg.d_model}, batch={enh_cfg.batch_size}')
print(f'Max train samples: {cnn_cfg.max_train_samples or "unlimited"}')
print(f'Output dir: {cnn_cfg.output_dir}')

---
## 4. Train CNN-Transformer

Trains the CNN + Transformer hybrid and automatically generates:
- **Integrated Gradients** feature attributions
- **Grad-CAM** activation maps over feature positions
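The two attribution methods listed above can be illustrated without the repo's code. The framework-free sketch below uses a toy logistic model σ(w·x + b), an assumption chosen so the gradient has a closed form (the trainer presumably uses autograd on the real network instead). Integrated Gradients averages the gradient along the straight path from a baseline to the input and multiplies by the input-minus-baseline difference. The 1-D Grad-CAM weights each convolutional channel by its position-averaged gradient and applies a ReLU to the weighted sum, giving one score per feature position; the activation and gradient values passed to it here are made up.

```python
import math
from typing import List


def sigmoid(z: float) -> float:
    return 1.0 / (1.0 + math.exp(-z))


def model(x: List[float], w: List[float], b: float) -> float:
    # Toy logistic model standing in for the trained network (assumption).
    return sigmoid(sum(wi * xi for wi, xi in zip(w, x)) + b)


def grad(x: List[float], w: List[float], b: float) -> List[float]:
    # Closed-form gradient of sigmoid(w.x + b) with respect to each input.
    p = model(x, w, b)
    return [p * (1.0 - p) * wi for wi in w]


def integrated_gradients(x, baseline, w, b, steps=32):
    """Midpoint Riemann-sum approximation of Integrated Gradients."""
    total = [0.0] * len(x)
    for k in range(steps):
        alpha = (k + 0.5) / steps  # position along the path baseline -> x
        point = [bi + alpha * (xi - bi) for xi, bi in zip(x, baseline)]
        for i, g in enumerate(grad(point, w, b)):
            total[i] += g
    # Average path gradient, scaled by the input's distance from the baseline.
    return [(xi - bi) * t / steps for xi, bi, t in zip(x, baseline, total)]


def grad_cam_1d(activations, gradients):
    """Grad-CAM over feature positions.

    activations[c][p] is channel c's activation at position p;
    gradients[c][p] is d(score)/d(activations[c][p]).
    Returns one non-negative importance score per position.
    """
    n_pos = len(activations[0])
    # Channel weight = gradient averaged over positions.
    weights = [sum(g) / n_pos for g in gradients]
    # Weighted sum of channel maps, then ReLU keeps positive evidence only.
    return [max(0.0, sum(wc * act[p] for wc, act in zip(weights, activations)))
            for p in range(n_pos)]


w, b = [2.0, -1.0, 0.5], -0.25
x, x0 = [1.0, 0.5, 2.0], [0.0, 0.0, 0.0]
attr = integrated_gradients(x, x0, w, b, steps=32)
# Completeness check: attributions should sum to roughly f(x) - f(baseline).
print(sum(attr), model(x, w, b) - model(x0, w, b))
print(grad_cam_1d([[1.0, 2.0, 3.0], [0.0, 1.0, 0.0]],
                  [[1.0, 1.0, 1.0], [-3.0, -3.0, -3.0]]))
```

The `steps` argument plays the same role as `ig_steps` in the config: more steps tighten the Riemann approximation, which the completeness check makes visible.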

In [None]:
import gc, time, torch
from methun_research.training.cnn_trainer import train_cnn_transformer

# Show resources before training
import psutil
ram = psutil.virtual_memory()
print(f'System RAM: {ram.used/1024**3:.1f} / {ram.total/1024**3:.1f} GB ({ram.percent}% used)')
if torch.cuda.is_available():
    free, total = torch.cuda.mem_get_info(0)
    print(f'GPU VRAM:   {(total-free)/1024**3:.1f} / {total/1024**3:.1f} GB used')

os.makedirs(cnn_cfg.output_dir, exist_ok=True)
gc.collect()
print('\nTraining CNN-Transformer...')
t0 = time.time()
cnn_path = train_cnn_transformer(cnn_cfg)

elapsed = (time.time() - t0) / 60
print(f'\n✓ Done in {elapsed:.1f} min → {cnn_path}')

# Show resources after training
ram = psutil.virtual_memory()
print(f'\nSystem RAM: {ram.used/1024**3:.1f} / {ram.total/1024**3:.1f} GB ({ram.percent}% used)')
if torch.cuda.is_available():
    free, total = torch.cuda.mem_get_info(0)
    print(f'GPU VRAM:   {(total-free)/1024**3:.1f} / {total/1024**3:.1f} GB used')
    print(f'GPU peak:   {torch.cuda.max_memory_allocated(0)/1024**3:.2f} GB')

---
## 5. Train Enhanced Transformer
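The overview describes this model as a grouped-token Transformer, and its config sets `group_size=8`. One plausible reading, sketched here as an assumption rather than the repo's actual code, is that the flat feature vector is cut into consecutive groups of `group_size` values, each group becoming one token, with the last group zero-padded. `group_features` is a hypothetical helper:

```python
import math
from typing import List


def group_features(x: List[float], group_size: int = 8) -> List[List[float]]:
    """Split a flat feature vector into ceil(len(x) / group_size) tokens.

    The final token is zero-padded so every token has exactly group_size values.
    Hypothetical illustration; the real Enhanced trainer may group differently.
    """
    n_tokens = math.ceil(len(x) / group_size)
    padded = list(x) + [0.0] * (n_tokens * group_size - len(x))
    return [padded[i * group_size:(i + 1) * group_size] for i in range(n_tokens)]


tokens = group_features([float(i) for i in range(20)], group_size=8)
print(len(tokens), tokens[-1])
```

Fewer, wider tokens keep attention cheap, since self-attention cost grows quadratically with the number of tokens.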

In [None]:
import gc, os, time, torch
gc.collect()
if torch.cuda.is_available():
    torch.cuda.empty_cache()
    torch.cuda.reset_peak_memory_stats()

from methun_research.training.enhanced_trainer import train_enhanced

print('Training Enhanced Transformer...')
t0 = time.time()
enh_path = train_enhanced(enh_cfg)

elapsed = (time.time() - t0) / 60
print(f'\n✓ Done in {elapsed:.1f} min → {enh_path}')

if torch.cuda.is_available():
    free, total = torch.cuda.mem_get_info(0)
    print(f'GPU VRAM:   {(total-free)/1024**3:.1f} / {total/1024**3:.1f} GB used')
    print(f'GPU peak:   {torch.cuda.max_memory_allocated(0)/1024**3:.2f} GB')

---
## 6. SHAP Analysis: CNN-Transformer

Runs `GradientExplainer` on the CNN-Transformer checkpoint.
Produces: global importance CSV, summary beeswarm plot, single-sample waterfall plot.
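The global importance CSV is read in Step 8 through its `mean_abs_shap` column. The usual way to derive such a ranking, sketched here on the assumption that `run_shap` follows that convention, is to average the absolute per-sample SHAP values of each feature and sort in descending order. The SHAP values and feature names below are made up:

```python
from typing import Dict, List, Sequence, Tuple


def global_importance(shap_values: Sequence[Sequence[float]],
                      feature_names: Sequence[str]) -> List[Tuple[str, float]]:
    """Rank features by mean |SHAP| across the evaluated samples.

    shap_values: one row per sample, one column per feature.
    """
    n = len(shap_values)
    totals: Dict[str, float] = {name: 0.0 for name in feature_names}
    for row in shap_values:
        for name, v in zip(feature_names, row):
            totals[name] += abs(v)  # sign is dropped: only magnitude counts
    ranked = [(name, total / n) for name, total in totals.items()]
    # Highest mean |SHAP| first; ties broken alphabetically for stable output.
    return sorted(ranked, key=lambda t: (-t[1], t[0]))


vals = [[0.2, -0.6, 0.0],
        [-0.4, 0.2, 0.3]]
print(global_importance(vals, ['f0', 'f1', 'f2']))
```

Taking absolute values matters: a feature that pushes some samples towards "attack" and others towards "benign" would otherwise average out to near zero.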

In [None]:
gc.collect()
if torch.cuda.is_available():
    torch.cuda.empty_cache()

from methun_research.interpretability.shap_runner import run_shap

SHAP_BG   = 2000 if not USE_SAMPLE else 200
SHAP_EVAL = 2000 if not USE_SAMPLE else 200
SHAP_POOL = 150_000 if not USE_SAMPLE else 500

shap_dir = os.path.join(cnn_cfg.output_dir, 'shap')
os.makedirs(shap_dir, exist_ok=True)

print('Running SHAP on CNN-Transformer...')
t0 = time.time()
shap_csv = run_shap(
    checkpoint_path=cnn_path,
    data_path=DATA_PATH,
    output_dir=shap_dir,
    background_size=SHAP_BG,
    eval_size=SHAP_EVAL,
    eval_pool=SHAP_POOL,
    chunk_size=256,
)
print(f'\n✓ Done in {(time.time()-t0)/60:.1f} min → {shap_csv}')

---
## 6b. SHAP Analysis: Enhanced Transformer

In [None]:
gc.collect()
if torch.cuda.is_available():
    torch.cuda.empty_cache()

if enh_path:
    shap_enh_dir = os.path.join(enh_cfg.output_dir, 'shap_enhanced')
    os.makedirs(shap_enh_dir, exist_ok=True)
    print('Running SHAP on Enhanced Transformer...')
    t0 = time.time()
    shap_enh_csv = run_shap(
        checkpoint_path=enh_path,
        data_path=DATA_PATH,
        output_dir=shap_enh_dir,
        background_size=SHAP_BG,
        eval_size=SHAP_EVAL,
        eval_pool=SHAP_POOL,
        chunk_size=256,
    )
    print(f'\n✓ Done in {(time.time()-t0)/60:.1f} min → {shap_enh_csv}')
else:
    print('No Enhanced Transformer checkpoint; skipping.')

---
## 7. Test Set Evaluation Summary

Both trainers automatically evaluate on a **held-out 20% test set** that is never seen during training or validation.
This cell loads the saved test metrics from each checkpoint.
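For reference, the threshold-based entries of `test_metrics` (`precision`, `recall`, `f1_score`, `accuracy`) follow the standard definitions from the binary confusion matrix, and ROC-AUC equals the probability that a randomly chosen attack sample scores higher than a randomly chosen benign one (ties count half). A minimal stdlib sketch, assuming binary labels (1 = attack) and a 0.5 decision threshold; the trainers' actual threshold is not shown in this notebook, and PR-AUC is omitted:

```python
from typing import Dict, Sequence


def binary_metrics(y_true: Sequence[int], scores: Sequence[float],
                   threshold: float = 0.5) -> Dict[str, float]:
    y_pred = [1 if s >= threshold else 0 for s in scores]
    tp = sum(1 for t, p in zip(y_true, y_pred) if t == 1 and p == 1)
    fp = sum(1 for t, p in zip(y_true, y_pred) if t == 0 and p == 1)
    fn = sum(1 for t, p in zip(y_true, y_pred) if t == 1 and p == 0)
    tn = len(y_true) - tp - fp - fn
    precision = tp / (tp + fp) if tp + fp else 0.0
    recall = tp / (tp + fn) if tp + fn else 0.0
    f1 = 2 * precision * recall / (precision + recall) if precision + recall else 0.0

    # ROC-AUC via the rank (Mann-Whitney) statistic: fraction of
    # (attack, benign) pairs ordered correctly. O(n^2), illustration only.
    pos = [s for s, t in zip(scores, y_true) if t == 1]
    neg = [s for s, t in zip(scores, y_true) if t == 0]
    pairs = len(pos) * len(neg)
    wins = sum(1.0 if p > n else 0.5 if p == n else 0.0 for p in pos for n in neg)
    auc = wins / pairs if pairs else float('nan')

    return {'accuracy': (tp + tn) / len(y_true), 'precision': precision,
            'recall': recall, 'f1_score': f1, 'auc_roc': auc}


print(binary_metrics([1, 1, 0, 0], [0.9, 0.4, 0.6, 0.1]))
```

On heavily imbalanced IDS data, accuracy alone can look high even for a weak detector, which is why the summary table also reports PR-AUC, precision, and recall.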

In [None]:
import torch, pandas as pd
from IPython.display import display, HTML

OUT = cnn_cfg.output_dir

def load_test_metrics(ckpt_path, model_name):
    if not ckpt_path or not os.path.exists(ckpt_path):
        print(f'  [{model_name}] Checkpoint not found')
        return None
    ckpt = torch.load(ckpt_path, map_location='cpu', weights_only=False)
    m = ckpt.get('test_metrics')
    if not m:
        print(f'  [{model_name}] No test_metrics in checkpoint')
        return None
    return {
        'Model': model_name,
        'ROC-AUC': f"{m['auc_roc']:.4f}",
        'PR-AUC': f"{m['auc_pr']:.4f}",
        'F1-Score': f"{m['f1_score']:.4f}",
        'Precision': f"{m['precision']:.4f}",
        'Recall': f"{m['recall']:.4f}",
        'Accuracy': f"{m['accuracy']:.4f}",
    }

rows = []
r = load_test_metrics(cnn_path, 'CNN-Transformer')
if r: rows.append(r)
r = load_test_metrics(enh_path, 'Enhanced Transformer')
if r: rows.append(r)

if rows:
    df_test = pd.DataFrame(rows).set_index('Model')
    display(HTML('<h3>🧪 Held-Out Test Set Results (20% of data, never seen during training)</h3>'))
    display(df_test)
else:
    print('No test metrics available.')

---
## 8. Visualise XAI Results

Compare the top-ranked features across the three XAI methods (four rankings in total, since SHAP is run on both models).
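Side-by-side tables are easier to judge with a single agreement number per pair of methods. One simple option, not something this notebook computes, is the Jaccard overlap of two methods' top-k feature sets (shared features divided by the union). The feature lists below are made up; in practice each list would come from a ranking CSV sorted by its importance column:

```python
from itertools import combinations
from typing import Dict, List, Tuple


def topk_jaccard(a: List[str], b: List[str], k: int = 15) -> float:
    """Jaccard overlap of the first k entries of two ranked feature lists."""
    sa, sb = set(a[:k]), set(b[:k])
    union = sa | sb
    return len(sa & sb) / len(union) if union else 1.0


def pairwise_agreement(rankings: Dict[str, List[str]],
                       k: int = 15) -> List[Tuple[str, str, float]]:
    # Every unordered pair of methods, in dict insertion order.
    return [(m1, m2, topk_jaccard(rankings[m1], rankings[m2], k))
            for m1, m2 in combinations(rankings, 2)]


demo = {
    'Integrated Gradients': ['f1', 'f2', 'f3', 'f4'],
    'Grad-CAM': ['f2', 'f1', 'f5', 'f6'],
    'SHAP (CNN-Trans)': ['f1', 'f3', 'f2', 'f7'],
}
for m1, m2, j in pairwise_agreement(demo, k=3):
    print(f'{m1} vs {m2}: {j:.2f}')
```

Jaccard ignores order within the top-k; a rank correlation such as Spearman's would be the natural next step if the ordering itself matters.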

In [None]:
import pandas as pd
from IPython.display import display, HTML

OUT = cnn_cfg.output_dir
TOP_K = 15

def load_ranking(path, val_col, method_name):
    if not os.path.exists(path):
        print(f'  [{method_name}] Not found: {path}')
        return None
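    df = pd.read_csv(path)
    # Keep the feature-name column (assumed to be the first) plus the score column
    df = df[[df.columns[0], val_col]]
    df = df.sort_values(val_col, ascending=False).head(TOP_K).reset_index(drop=True)
    df.index += 1
    # Two-level headers keep each method's Feature/Score pair distinct after concat
    df.columns = pd.MultiIndex.from_product([[method_name], ['Feature', 'Score']])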
    return df

rankings = [
    load_ranking(f'{OUT}/cnn_transformer_integrated_gradients.csv',
                 'avg_abs_integrated_grad', 'Integrated Gradients'),
    load_ranking(f'{OUT}/cnn_transformer_grad_cam.csv',
                 'grad_cam_importance', 'Grad-CAM'),
    load_ranking(f'{OUT}/shap/shap_global_importance_attack.csv',
                 'mean_abs_shap', 'SHAP (CNN-Trans)'),
]

r = load_ranking(f'{OUT}/shap_enhanced/shap_global_importance_attack.csv',
                 'mean_abs_shap', 'SHAP (Enhanced)')
if r is not None:
    rankings.append(r)

rankings = [r for r in rankings if r is not None]
if rankings:
    combined = pd.concat(rankings, axis=1)
    display(HTML(f'<h3>Top {TOP_K} Features Across All XAI Methods</h3>'))
    display(combined)
else:
    print('No ranking files found.')

In [None]:
# Display the generated plots
from IPython.display import Image, display

OUT = cnn_cfg.output_dir
plots = [
    ('Grad-CAM Feature Importance', f'{OUT}/grad_cam_importance.png'),
    ('SHAP Summary (CNN-Transformer)', f'{OUT}/shap/shap_summary_attack.png'),
    ('SHAP Waterfall (CNN-Transformer)', f'{OUT}/shap/shap_waterfall_attack.png'),
    ('SHAP Summary (Enhanced)', f'{OUT}/shap_enhanced/shap_summary_attack.png'),
    ('SHAP Waterfall (Enhanced)', f'{OUT}/shap_enhanced/shap_waterfall_attack.png'),
]

for title, path in plots:
    if os.path.exists(path):
        print(f'\n--- {title} ---')
        display(Image(filename=path, width=700))
    else:
        print(f'[skip] {title}: file not found')

---
## 9. List All Generated Artifacts

Everything in `/kaggle/working/artifacts/` is saved as notebook output.
Open the **Output** tab (right panel) → **Download All** to get the files.
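If downloading many small files is inconvenient, the artifacts directory can be bundled into a single archive first. A minimal sketch using the standard-library `shutil.make_archive`; `bundle_artifacts` and the archive name are arbitrary, hypothetical choices:

```python
import os
import shutil


def bundle_artifacts(src_dir: str, archive_base: str) -> str:
    """Zip the contents of src_dir into <archive_base>.zip and return its path.

    Keep archive_base outside src_dir so the archive does not include itself.
    """
    if not os.path.isdir(src_dir):
        raise FileNotFoundError(src_dir)
    # make_archive appends the '.zip' extension itself.
    return shutil.make_archive(archive_base, 'zip', root_dir=src_dir)


# On Kaggle, e.g.:
# bundle_artifacts('/kaggle/working/artifacts', '/kaggle/working/artifacts_bundle')
```

The resulting `.zip` sits next to the artifacts folder in `/kaggle/working/`, so it shows up in the same Output panel.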

In [None]:
OUT = cnn_cfg.output_dir
print(f'All generated artifacts in {OUT}:')
total_size = 0
for root, dirs, files in os.walk(OUT):
    for f in sorted(files):
        fp = os.path.join(root, f)
        sz = os.path.getsize(fp)
        total_size += sz
        if sz > 1024*1024:
            print(f'  {os.path.relpath(fp, OUT)}  ({sz/1024**2:.1f} MB)')
        else:
            print(f'  {os.path.relpath(fp, OUT)}  ({sz/1024:.1f} KB)')
print(f'\nTotal: {total_size/1024**2:.1f} MB')