# CNN + Capsules + Learnable Edge-Gated CapsGNN for GI Image Classification

This notebook extends the previous pipeline with one new component (the proposed novelty):
- **Learnable Edge-Gated Capsule Graph (k-NN + Edge Gate MLP)**

Models compared in ablation:
1. CNN only
2. CNN + Capsules
3. CNN + CapsGNN (fixed cosine-kNN graph)
4. CNN + Caps + EdgeGatedGNN (**novelty**)

In [None]:
# Install dependencies (IPython `!` shell escapes — these lines only run inside a notebook).
!pip install -q torch torchvision torchaudio
!pip install -q numpy tqdm scikit-learn matplotlib

In [None]:
import os
import glob
import random
import json
import zipfile
import numpy as np
from tqdm import tqdm

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim

from torchvision import datasets, transforms, models
from torchvision.models import ResNet18_Weights
from torch.utils.data import DataLoader, Subset
from sklearn.metrics import accuracy_score, classification_report, confusion_matrix

import matplotlib.pyplot as plt

# Reproducibility: seed Python, NumPy and PyTorch (CPU and every visible GPU).
SEED = 42
for seed_fn in (random.seed, np.random.seed, torch.manual_seed, torch.cuda.manual_seed_all):
    seed_fn(SEED)

device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
print('Device:', device)

# Input / optimisation hyper-parameters.
IMG_SIZE = 224
BATCH_SIZE = 16
EPOCHS = 15
LR = 1e-4
WEIGHT_DECAY = 1e-5

# Capsule-graph settings.
K_NEIGHBORS = 8
MAX_CAPS_NODES = 128

# Dataset sanity check and early stopping.
EXPECTED_NUM_CLASSES = 8
EARLY_STOPPING_PATIENCE = 7
MIN_DELTA = 1e-4

# DataLoader workers: at most 4, and 0 when os.cpu_count() cannot tell (returns None).
NUM_WORKERS = min(4, os.cpu_count() or 0)

# Locate the Kvasir v2 dataset: the first existing directory among the candidates wins.
# glob returns matches in filesystem order (unspecified), so sort them to make the
# auto-detected Kaggle path deterministic when several copies are mounted.
kaggle_matches = sorted(glob.glob('/kaggle/input/**/kvasir-dataset-v2', recursive=True))
kaggle_auto = kaggle_matches[0] if kaggle_matches else None

DATA_CANDIDATES = [
    os.path.join(os.getcwd(), 'kvasir-dataset-v2'),
    r'c:/Users/sayye/OneDrive/Desktop/Minor_Project/kvasir-dataset-v2',
    r'/kaggle/input/datasets/sameer512100/kvasir-minor-project/kvasir-dataset-v2',
    r'/kaggle/input/kvasir-dataset-v2',
    kaggle_auto,  # None when nothing matched; falsy entries are skipped below
    '/kaggle/working/kvasir-dataset-v2',
]
DATA_DIR = next((p for p in DATA_CANDIDATES if p and os.path.isdir(p)), None)
if DATA_DIR is None:
    raise FileNotFoundError('kvasir-dataset-v2 not found. Update DATA_CANDIDATES in this cell.')

# Best-model checkpoints and run summaries are written here.
CHECKPOINT_DIR = os.path.join(os.getcwd(), 'checkpoints_edge_gated')
os.makedirs(CHECKPOINT_DIR, exist_ok=True)
print('DATA_DIR:', DATA_DIR)
print('Checkpoint dir:', CHECKPOINT_DIR)

In [None]:
# ImageNet channel statistics, shared by the train and eval pipelines.
_IMAGENET_MEAN = [0.485, 0.456, 0.406]
_IMAGENET_STD = [0.229, 0.224, 0.225]

# Training: resize + light geometric/colour augmentation.
train_transform = transforms.Compose([
    transforms.Resize((IMG_SIZE, IMG_SIZE)),
    transforms.RandomHorizontalFlip(p=0.5),
    transforms.RandomRotation(15),
    transforms.ColorJitter(brightness=0.15, contrast=0.15, saturation=0.10, hue=0.02),
    transforms.ToTensor(),
    transforms.Normalize(mean=_IMAGENET_MEAN, std=_IMAGENET_STD),
])

# Validation / test: deterministic resize + normalise only.
eval_transform = transforms.Compose([
    transforms.Resize((IMG_SIZE, IMG_SIZE)),
    transforms.ToTensor(),
    transforms.Normalize(mean=_IMAGENET_MEAN, std=_IMAGENET_STD),
])

# Untransformed scan, used only to learn the class list and the dataset size.
base_dataset = datasets.ImageFolder(DATA_DIR)
class_names = base_dataset.classes
num_classes = len(class_names)
if num_classes != EXPECTED_NUM_CLASSES:
    raise ValueError(f'Expected {EXPECTED_NUM_CLASSES} classes, found {num_classes}: {class_names}')

# Seeded 70/15/15 split; the test split absorbs the rounding remainder.
n_samples = len(base_dataset)
train_size, val_size = int(0.70 * n_samples), int(0.15 * n_samples)
test_size = n_samples - train_size - val_size

generator = torch.Generator().manual_seed(SEED)
indices = torch.randperm(n_samples, generator=generator).tolist()
val_end = train_size + val_size
train_indices, val_indices, test_indices = indices[:train_size], indices[train_size:val_end], indices[val_end:]

# One ImageFolder per split so each split carries its own transform; the index
# lists above refer to the same (sorted) sample order in every instance.
train_base, val_base, test_base = (
    datasets.ImageFolder(DATA_DIR, transform=tf)
    for tf in (train_transform, eval_transform, eval_transform)
)

train_dataset = Subset(train_base, train_indices)
val_dataset = Subset(val_base, val_indices)
test_dataset = Subset(test_base, test_indices)

pin_mem = torch.cuda.is_available()
loader_kwargs = dict(batch_size=BATCH_SIZE, num_workers=NUM_WORKERS, pin_memory=pin_mem)
train_loader = DataLoader(train_dataset, shuffle=True, **loader_kwargs)
val_loader = DataLoader(val_dataset, shuffle=False, **loader_kwargs)
test_loader = DataLoader(test_dataset, shuffle=False, **loader_kwargs)

print('Classes:', class_names)
print(f'Train/Val/Test: {len(train_dataset)}/{len(val_dataset)}/{len(test_dataset)}')

In [None]:
def squash(x, dim=-1, eps=1e-8):
    """Apply the capsule "squash" non-linearity along ``dim``.

    Each vector ``s`` becomes ``(|s|^2 / (1 + |s|^2)) * s / sqrt(|s|^2 + eps)``:
    its direction is kept while its length is mapped into [0, 1).

    Args:
        x: Input tensor.
        dim: Axis holding the capsule vector components.
        eps: Added under the square root so all-zero vectors do not divide by zero.

    Returns:
        Tensor with the same shape as ``x``.
    """
    squared_length = x.pow(2).sum(dim=dim, keepdim=True)
    length = (squared_length + eps).sqrt()
    unit = x / length
    return (squared_length / (1.0 + squared_length)) * unit

class PrimaryCapsules(nn.Module):
    """Turn a CNN feature map into a flat set of squashed primary capsules.

    One convolution emits ``num_capsules * caps_dim`` channels; these are regrouped
    so every spatial location contributes ``num_capsules`` vectors of length
    ``caps_dim``.
    """

    def __init__(self, in_channels, num_capsules=8, caps_dim=16, kernel_size=1, stride=1):
        super().__init__()
        self.num_capsules = num_capsules
        self.caps_dim = caps_dim
        out_channels = num_capsules * caps_dim
        self.conv = nn.Conv2d(in_channels, out_channels, kernel_size=kernel_size, stride=stride)

    def forward(self, x):
        """Map ``[B, C, H, W]`` features to capsules ``[B, H*W*num_capsules, caps_dim]``."""
        maps = self.conv(x)
        batch, _, height, width = maps.shape
        # [B, caps*dim, H, W] -> [B, caps, dim, H, W]
        grouped = maps.view(batch, self.num_capsules, self.caps_dim, height, width)
        # -> [B, H, W, caps, dim]: node order is row-major over (H, W), then capsule index.
        grouped = grouped.movedim((1, 2), (3, 4))
        nodes = grouped.flatten(1, 3)
        return squash(nodes)

def reduce_capsule_nodes(capsules, max_nodes=128):
    """Cap the number of capsule nodes at ``max_nodes``.

    Inputs with at most ``max_nodes`` nodes are returned unchanged (the same
    object). Larger inputs are shrunk by 1-D adaptive average pooling along the
    node axis, so each output node averages a contiguous run of input nodes.

    Args:
        capsules: Tensor of shape ``[B, N, D]``.
        max_nodes: Upper bound on the number of nodes kept.

    Returns:
        Tensor of shape ``[B, min(N, max_nodes), D]``.
    """
    _, node_count, _ = capsules.shape
    if node_count > max_nodes:
        # adaptive_avg_pool1d pools the last axis, so move nodes there and back.
        pooled = F.adaptive_avg_pool1d(capsules.transpose(1, 2), max_nodes)
        return pooled.transpose(1, 2)
    return capsules

def build_knn_graph_fixed(capsules, k=8):
    """Build a row-normalised cosine k-NN adjacency over capsule nodes.

    Each node is linked to its ``k`` most cosine-similar *other* nodes with edge
    weight ``relu(cos)``, plus a unit self-loop, and every row is then normalised
    to sum to 1. The diagonal is masked before the top-k search. Otherwise each
    node picks itself first (cos = 1), which has two effects:

    * it wastes one of the ``k`` neighbour slots;
    * it doubles the self-loop weight once the identity is added.

    With ``k=1`` that left no edges between distinct nodes.

    Args:
        capsules: Tensor of shape ``[B, N, D]``.
        k: Neighbours per node; clipped to ``N - 1``.

    Returns:
        Tensor of shape ``[B, N, N]`` whose rows sum to 1.
    """
    b, n, _ = capsules.shape
    x = F.normalize(capsules, p=2, dim=-1)
    sim = torch.bmm(x, x.transpose(1, 2))  # [B, N, N] cosine similarity

    adj = torch.zeros_like(sim)
    k_eff = min(k, n - 1)
    if k_eff > 0:
        # Exclude self-matches; the self-loop is added once, below.
        self_mask = torch.eye(n, dtype=torch.bool, device=capsules.device)
        topk_vals, topk_idx = torch.topk(sim.masked_fill(self_mask, float('-inf')), k=k_eff, dim=-1)
        # Negative similarities become zero-weight edges.
        adj.scatter_(dim=-1, index=topk_idx, src=F.relu(topk_vals))

    eye = torch.eye(n, device=capsules.device, dtype=sim.dtype).unsqueeze(0).expand(b, -1, -1)
    adj = adj + eye
    row_sum = adj.sum(dim=-1, keepdim=True).clamp_min(1e-8)
    return adj / row_sum

class LearnableEdgeGate(nn.Module):
    """Build a gated k-NN adjacency whose edge weights are modulated by an MLP.

    Candidate edges are the ``k`` most cosine-similar *other* nodes of each node.
    The diagonal is masked before top-k, so a node never fills one of its own
    neighbour slots, and the self-loop is added exactly once afterwards.

    For each candidate edge (i, j) the MLP scores
    ``[x_i, x_j, |x_i - x_j|, x_i * x_j, cos(x_i, x_j)]``. A sigmoid turns the score
    into a gate in (0, 1), which multiplies ``relu(cos)``. Rows are then normalised
    to sum to 1.
    """

    def __init__(self, node_dim, hidden_dim=64):
        super().__init__()
        in_dim = 4 * node_dim + 1  # xi, xj, |xi-xj|, xi*xj, cosine(xi,xj)
        self.edge_mlp = nn.Sequential(
            nn.Linear(in_dim, hidden_dim),
            nn.ReLU(inplace=True),
            nn.Linear(hidden_dim, 1)
        )

    def forward(self, capsules, k=8):
        """Return a row-normalised ``[B, N, N]`` adjacency for ``[B, N, D]`` capsules.

        Args:
            capsules: Node features of shape ``[B, N, D]``.
            k: Neighbours per node; clipped to ``N - 1``.
        """
        b, n, d = capsules.shape
        x_norm = F.normalize(capsules, p=2, dim=-1)
        sim = torch.bmm(x_norm, x_norm.transpose(1, 2))  # [B, N, N] cosine similarity

        adj = torch.zeros_like(sim)
        k_eff = min(k, n - 1)
        if k_eff > 0:
            # Exclude self-matches; the self-loop is added once, below.
            self_mask = torch.eye(n, dtype=torch.bool, device=capsules.device)
            topk_vals, topk_idx = torch.topk(
                sim.masked_fill(self_mask, float('-inf')), k=k_eff, dim=-1
            )  # [B, N, K]

            xi = capsules.unsqueeze(2).expand(-1, -1, k_eff, -1)  # [B, N, K, D]
            idx_exp = topk_idx.unsqueeze(-1).expand(-1, -1, -1, d)
            # expand() is a view, so the [B, N, N, D] tensor is not materialised.
            all_nodes = capsules.unsqueeze(1).expand(-1, n, -1, -1)  # [B, N, N, D]
            xj = torch.gather(all_nodes, dim=2, index=idx_exp)  # [B, N, K, D]

            edge_feat = torch.cat([
                xi,
                xj,
                torch.abs(xi - xj),
                xi * xj,
                topk_vals.unsqueeze(-1)
            ], dim=-1)

            gate = torch.sigmoid(self.edge_mlp(edge_feat)).squeeze(-1)  # [B, N, K]
            adj.scatter_(dim=-1, index=topk_idx, src=gate * F.relu(topk_vals))

        eye = torch.eye(n, device=capsules.device, dtype=sim.dtype).unsqueeze(0).expand(b, -1, -1)
        adj = adj + eye
        row_sum = adj.sum(dim=-1, keepdim=True).clamp_min(1e-8)
        return adj / row_sum

class GraphConv(nn.Module):
    """Dense graph convolution: linear -> adjacency aggregation -> BatchNorm -> ReLU -> dropout.

    BatchNorm treats every node of every graph in the batch as one sample.
    """

    def __init__(self, in_dim, out_dim, dropout=0.1):
        super().__init__()
        self.linear = nn.Linear(in_dim, out_dim)
        self.dropout = nn.Dropout(dropout)
        self.bn = nn.BatchNorm1d(out_dim)

    def forward(self, x, adj):
        """Apply the layer to ``x`` ``[B, N, in_dim]`` using ``adj`` ``[B, N, N]``."""
        aggregated = torch.bmm(adj, self.linear(x))  # [B, N, out_dim]
        normed = self.bn(aggregated.flatten(0, 1)).view_as(aggregated)
        return self.dropout(torch.relu(normed))

class ResNet18Backbone(nn.Module):
    """torchvision ResNet-18 trunk with its last two children (avgpool, fc) removed.

    For 224x224 inputs the output is a ``[B, 512, 7, 7]`` feature map.

    Args:
        pretrained: Load ``ResNet18_Weights.IMAGENET1K_V1`` when True, random init otherwise.
        freeze_early_layers: Turn off gradients for the first six children of the
            trunk (conv1, bn1, relu, maxpool, layer1, layer2 in torchvision's ResNet).
            Only ``requires_grad`` is changed, so BatchNorm running statistics in
            those layers are still updated whenever the module is in train mode.
    """

    def __init__(self, pretrained=True, freeze_early_layers=True):
        super().__init__()
        base = models.resnet18(weights=ResNet18_Weights.IMAGENET1K_V1 if pretrained else None)
        trunk = list(base.children())[:-2]
        self.features = nn.Sequential(*trunk)  # [B, 512, 7, 7]

        if freeze_early_layers:
            for child in list(self.features)[:6]:
                child.requires_grad_(False)

    def forward(self, x):
        """Return the final convolutional feature map for an image batch."""
        return self.features(x)

class CNNBaseline(nn.Module):
    """Ablation arm 1: ResNet-18 features -> global average pool -> linear classifier."""

    def __init__(self, num_classes):
        super().__init__()
        self.backbone = ResNet18Backbone(pretrained=True, freeze_early_layers=True)
        self.pool = nn.AdaptiveAvgPool2d(1)
        self.classifier = nn.Linear(512, num_classes)

    def forward(self, x):
        """Return ``[B, num_classes]`` logits for an image batch ``x``."""
        pooled = self.pool(self.backbone(x))  # [B, 512, 1, 1]
        return self.classifier(torch.flatten(pooled, start_dim=1))

class CNNCapsClassifier(nn.Module):
    """Ablation arm 2: ResNet-18 -> primary capsules -> mean capsule -> MLP head (no graph)."""

    def __init__(self, num_classes, max_caps_nodes=128):
        super().__init__()
        self.max_caps_nodes = max_caps_nodes
        self.backbone = ResNet18Backbone(pretrained=True, freeze_early_layers=True)
        self.primary_caps = PrimaryCapsules(in_channels=512, num_capsules=8, caps_dim=16)
        self.classifier = nn.Sequential(
            nn.Linear(16, 64),
            nn.ReLU(inplace=True),
            nn.Dropout(0.3),
            nn.Linear(64, num_classes)
        )

    def forward(self, x):
        """Return ``[B, num_classes]`` logits for an image batch ``x``."""
        capsules = reduce_capsule_nodes(self.primary_caps(self.backbone(x)), self.max_caps_nodes)
        # Average over nodes: one 16-d summary capsule per image.
        return self.classifier(capsules.mean(dim=1))

class CNNCapsGNN(nn.Module):
    """Ablation arm 3: capsules on a fixed cosine k-NN graph, two GraphConv layers, MLP head."""

    def __init__(self, num_classes, k_neighbors=8, max_caps_nodes=128):
        super().__init__()
        self.k_neighbors = k_neighbors
        self.max_caps_nodes = max_caps_nodes
        self.backbone = ResNet18Backbone(pretrained=True, freeze_early_layers=True)
        self.primary_caps = PrimaryCapsules(in_channels=512, num_capsules=8, caps_dim=16)
        self.gnn1 = GraphConv(16, 32, dropout=0.15)
        self.gnn2 = GraphConv(32, 32, dropout=0.15)
        self.classifier = nn.Sequential(
            nn.Linear(32, 64),
            nn.ReLU(inplace=True),
            nn.Dropout(0.3),
            nn.Linear(64, num_classes)
        )

    def forward(self, x):
        """Return ``[B, num_classes]`` logits for an image batch ``x``."""
        nodes = reduce_capsule_nodes(self.primary_caps(self.backbone(x)), self.max_caps_nodes)
        adjacency = build_knn_graph_fixed(nodes, k=self.k_neighbors)
        hidden = nodes
        for layer in (self.gnn1, self.gnn2):
            hidden = layer(hidden, adjacency)
        # Mean-pool node states into one graph-level embedding per image.
        return self.classifier(hidden.mean(dim=1))

class CNNCapsEdgeGNN(nn.Module):
    """Ablation arm 4 (proposed): capsule graph with learnable, gated k-NN edges.

    Same layout as ``CNNCapsGNN`` except that the adjacency comes from
    ``LearnableEdgeGate``, whose edge MLP is trained jointly with the rest of the model.
    """

    def __init__(self, num_classes, k_neighbors=8, max_caps_nodes=128):
        super().__init__()
        self.k_neighbors = k_neighbors
        self.max_caps_nodes = max_caps_nodes

        # Submodules are registered in the original order, so random
        # initialisation consumes the RNG identically.
        self.backbone = ResNet18Backbone(pretrained=True, freeze_early_layers=True)
        self.primary_caps = PrimaryCapsules(in_channels=512, num_capsules=8, caps_dim=16)
        self.edge_gate_builder = LearnableEdgeGate(node_dim=16, hidden_dim=64)
        self.gnn1 = GraphConv(16, 32, dropout=0.15)
        self.gnn2 = GraphConv(32, 32, dropout=0.15)
        self.classifier = nn.Sequential(
            nn.Linear(32, 64),
            nn.ReLU(inplace=True),
            nn.Dropout(0.3),
            nn.Linear(64, num_classes)
        )

    def forward(self, x):
        """Return ``[B, num_classes]`` logits for an image batch ``x``."""
        capsules = reduce_capsule_nodes(self.primary_caps(self.backbone(x)), self.max_caps_nodes)
        adjacency = self.edge_gate_builder(capsules, k=self.k_neighbors)
        node_states = self.gnn2(self.gnn1(capsules, adjacency), adjacency)
        # Mean-pool node states into one graph-level embedding per image.
        return self.classifier(node_states.mean(dim=1))

# Ablation arms in training order. Each value is a zero-argument factory, so every
# call builds a fresh, independently initialised model.
model_factories = {}
model_factories['CNN only'] = lambda: CNNBaseline(num_classes=num_classes)
model_factories['CNN + Capsules'] = lambda: CNNCapsClassifier(
    num_classes=num_classes, max_caps_nodes=MAX_CAPS_NODES)
model_factories['CNN + CapsGNN'] = lambda: CNNCapsGNN(
    num_classes=num_classes, k_neighbors=K_NEIGHBORS, max_caps_nodes=MAX_CAPS_NODES)
model_factories['CNN + Caps + EdgeGatedGNN'] = lambda: CNNCapsEdgeGNN(
    num_classes=num_classes, k_neighbors=K_NEIGHBORS, max_caps_nodes=MAX_CAPS_NODES)
print('Ablation models:', list(model_factories))

In [None]:
def backup_checkpoints(checkpoint_dir, backup_zip_path):
    """Zip every ``.pth`` and ``.json`` file under ``checkpoint_dir`` into ``backup_zip_path``.

    Archive member names are relative to ``checkpoint_dir``, and an existing
    archive at ``backup_zip_path`` is overwritten. If ``checkpoint_dir`` is not an
    existing directory, nothing happens and no archive is created.
    """
    if os.path.isdir(checkpoint_dir):
        with zipfile.ZipFile(backup_zip_path, mode='w', compression=zipfile.ZIP_DEFLATED) as archive:
            for dirpath, _dirnames, filenames in os.walk(checkpoint_dir):
                wanted = [name for name in filenames if name.endswith(('.pth', '.json'))]
                for name in wanted:
                    full_path = os.path.join(dirpath, name)
                    archive.write(full_path, arcname=os.path.relpath(full_path, checkpoint_dir))

def run_epoch(model, loader, criterion, optimizer=None):
    """Run one pass over ``loader``: train if ``optimizer`` is given, otherwise evaluate.

    Args:
        model: Network mapping an image batch to class logits ``[B, C]``.
        loader: Iterable of ``(images, labels)`` batches.
        criterion: Loss returning the *mean* over a batch (e.g. the default
            ``nn.CrossEntropyLoss()``). It is re-weighted by batch size below.
        optimizer: If not ``None``, the model is put in train mode and updated
            after every batch. If ``None``, it is put in eval mode and no gradients
            are tracked.

    Returns:
        ``(avg_loss, accuracy)``. ``avg_loss`` is the per-sample mean loss over the
        whole pass: each batch mean is weighted by its batch size, so a smaller
        final batch is not over-weighted. ``accuracy`` is the fraction of correct
        argmax predictions.

    Raises:
        ValueError: If ``loader`` yields no samples.
    """
    train_mode = optimizer is not None
    model.train(train_mode)  # identical to model.eval() when train_mode is False

    all_preds, all_targets = [], []
    loss_sum = 0.0
    n_seen = 0

    for images, labels in tqdm(loader, leave=False):
        images = images.to(device, non_blocking=True)
        labels = labels.to(device, non_blocking=True)

        with torch.set_grad_enabled(train_mode):
            logits = model(images)
            loss = criterion(logits, labels)

            if train_mode:
                optimizer.zero_grad()
                loss.backward()
                optimizer.step()

        batch_size = labels.size(0)
        loss_sum += loss.item() * batch_size
        n_seen += batch_size
        all_preds.append(logits.argmax(dim=1).detach().cpu().numpy())
        all_targets.append(labels.detach().cpu().numpy())

    if n_seen == 0:
        # np.concatenate would otherwise fail with a far less helpful message.
        raise ValueError('run_epoch received an empty loader; nothing to compute.')

    acc = accuracy_score(np.concatenate(all_targets), np.concatenate(all_preds))
    return loss_sum / n_seen, acc

def plot_training_history(model_name, history):
    """Show side-by-side loss and accuracy curves for one training run.

    Args:
        model_name: Label used in both subplot titles.
        history: Dict with equal-length lists under ``'epochs'``, ``'train_loss'``,
            ``'val_loss'``, ``'train_acc'`` and ``'val_acc'``. Accuracies are
            fractions and are plotted as percentages.
    """
    epochs = history['epochs']
    # (title, y-label, train key, val key, train legend, val legend, scale or None)
    panels = [
        ('Loss Curves', 'Loss', 'train_loss', 'val_loss', 'Train Loss', 'Val Loss', None),
        ('Accuracy Curves', 'Accuracy (%)', 'train_acc', 'val_acc', 'Train Acc', 'Val Acc', 100),
    ]

    plt.figure(figsize=(12, 4))
    for position, (title, ylabel, train_key, val_key, train_label, val_label, scale) in enumerate(panels, start=1):
        plt.subplot(1, 2, position)
        train_vals = history[train_key]
        val_vals = history[val_key]
        if scale is not None:
            train_vals = np.array(train_vals) * scale
            val_vals = np.array(val_vals) * scale
        plt.plot(epochs, train_vals, marker='o', label=train_label)
        plt.plot(epochs, val_vals, marker='o', label=val_label)
        plt.title(f'{title} - {model_name}')
        plt.xlabel('Epoch')
        plt.ylabel(ylabel)
        plt.grid(alpha=0.25)
        plt.legend()

    plt.tight_layout()
    plt.show()

def train_single_model(model_name, model_factory):
    """Train one ablation model with LR scheduling and early stopping.

    All RNGs are re-seeded with ``SEED`` before the model is built. Every ablation
    arm therefore starts from the same random state (new-layer initialisation,
    shuffle order, augmentation draws), regardless of which models were trained
    earlier in the session. This keeps the comparison fair and independent of
    training order.

    The state dict is saved to ``CHECKPOINT_DIR`` whenever validation accuracy
    improves by more than ``MIN_DELTA``. The first epoch always saves, so
    ``save_path`` exists after any run with ``EPOCHS >= 1``, even if validation
    accuracy never rises above zero.

    Args:
        model_name: Human-readable arm name; also used to derive the checkpoint file name.
        model_factory: Zero-argument callable that builds the model to train.

    Returns:
        Dict with ``model_name``, ``save_path``, ``best_val_acc``, ``best_epoch``
        (0 only if no epoch ran) and the per-epoch ``history``.
    """
    # Same starting RNG state for every arm (see docstring).
    random.seed(SEED)
    np.random.seed(SEED)
    torch.manual_seed(SEED)
    torch.cuda.manual_seed_all(SEED)

    model = model_factory().to(device)
    criterion = nn.CrossEntropyLoss()
    # Optimise only trainable parameters (e.g. frozen backbone layers are skipped).
    optimizer = optim.Adam(filter(lambda p: p.requires_grad, model.parameters()), lr=LR, weight_decay=WEIGHT_DECAY)
    # Halve the LR after 3 epochs without validation-accuracy improvement.
    scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='max', factor=0.5, patience=3)

    best_val_acc = 0.0
    epochs_no_improve = 0
    best_epoch = 0
    save_file = f"best_{model_name.lower().replace(' ', '_').replace('+', 'plus')}.pth"
    save_path = os.path.join(CHECKPOINT_DIR, save_file)

    history = {
        'epochs': [],
        'train_loss': [],
        'train_acc': [],
        'val_loss': [],
        'val_acc': []
    }

    for epoch in range(1, EPOCHS + 1):
        train_loss, train_acc = run_epoch(model, train_loader, criterion, optimizer=optimizer)
        val_loss, val_acc = run_epoch(model, val_loader, criterion, optimizer=None)

        scheduler.step(val_acc)

        history['epochs'].append(epoch)
        history['train_loss'].append(float(train_loss))
        history['train_acc'].append(float(train_acc))
        history['val_loss'].append(float(val_loss))
        history['val_acc'].append(float(val_acc))

        print(f"[{model_name}] Epoch {epoch:02d}/{EPOCHS} | Train Loss: {train_loss:.4f} | Train Acc: {train_acc * 100:.2f}% | Val Loss: {val_loss:.4f} | Val Acc: {val_acc * 100:.2f}%")

        # The first epoch always checkpoints, so a checkpoint file always exists.
        improved = best_epoch == 0 or val_acc > best_val_acc + MIN_DELTA
        if improved:
            best_val_acc = val_acc
            best_epoch = epoch
            epochs_no_improve = 0
            torch.save(model.state_dict(), save_path)
            print(f"[{model_name}] Best model saved -> {save_path}")
        else:
            epochs_no_improve += 1

        if epochs_no_improve >= EARLY_STOPPING_PATIENCE:
            print(f"[{model_name}] Early stopping at epoch {epoch}")
            break

    plot_training_history(model_name, history)

    return {
        'model_name': model_name,
        'save_path': save_path,
        'best_val_acc': float(best_val_acc),
        'best_epoch': int(best_epoch),
        'history': history
    }

# Train every ablation arm in order and keep its validation summary.
ablation_results = []
for arm_name, build_model in model_factories.items():
    banner = '=' * 24
    print(f"\n{banner} Training: {arm_name} {banner}")
    ablation_results.append(train_single_model(arm_name, build_model))

print('\nAblation Summary (Validation):')
for entry in ablation_results:
    print(f"{entry['model_name']:<30} | Best Val Acc: {entry['best_val_acc'] * 100:.2f}% | Best Epoch: {entry['best_epoch']}")

# Bar chart of the best validation accuracy reached by each arm.
arm_labels = [entry['model_name'] for entry in ablation_results]
arm_val_pct = [entry['best_val_acc'] * 100 for entry in ablation_results]
plt.figure(figsize=(9, 4))
bar_container = plt.bar(arm_labels, arm_val_pct)
plt.ylabel('Best Validation Accuracy (%)')
plt.title('Validation Accuracy Comparison Across Models')
plt.xticks(rotation=20, ha='right')
plt.grid(axis='y', alpha=0.25)
for rect, pct in zip(bar_container, arm_val_pct):
    label_x = rect.get_x() + rect.get_width() / 2
    plt.text(label_x, rect.get_height() + 0.2, f'{pct:.2f}', ha='center', va='bottom', fontsize=9)
plt.tight_layout()
plt.show()

# Model selection uses validation accuracy only; the test set is not touched here.
best_run = max(ablation_results, key=lambda entry: entry['best_val_acc'])
print(f"\nSelected best model for test/eval: {best_run['model_name']}")

# Persist the run summaries and a zipped copy of all checkpoints.
run_manifest_path = os.path.join(CHECKPOINT_DIR, 'ablation_results.json')
with open(run_manifest_path, 'w', encoding='utf-8') as manifest_file:
    json.dump(ablation_results, manifest_file, indent=2)

backup_zip_path = os.path.join(os.getcwd(), 'checkpoints_edge_gated_backup.zip')
backup_checkpoints(CHECKPOINT_DIR, backup_zip_path)
print(f"Checkpoint backup zip created: {backup_zip_path}")

In [None]:
def evaluate_model(model, loader, criterion):
    """Evaluate ``model`` on ``loader`` in eval mode without gradients.

    Args:
        model: Network mapping an image batch to class logits ``[B, C]``.
        loader: Iterable of ``(images, labels)`` batches.
        criterion: Loss returning the *mean* over a batch (e.g. the default
            ``nn.CrossEntropyLoss()``). It is re-weighted by batch size below.

    Returns:
        ``(avg_loss, accuracy, targets, predictions)``:

        * ``avg_loss`` is the per-sample mean loss (batch means weighted by batch size).
        * ``accuracy`` is the fraction of correct argmax predictions.
        * ``targets`` and ``predictions`` are 1-D NumPy arrays in loader order.

    Raises:
        ValueError: If ``loader`` yields no samples.
    """
    model.eval()
    preds_all, targets_all = [], []
    loss_sum = 0.0
    n_seen = 0

    with torch.no_grad():
        for images, labels in tqdm(loader, leave=False):
            images = images.to(device, non_blocking=True)
            labels = labels.to(device, non_blocking=True)

            logits = model(images)
            loss = criterion(logits, labels)

            batch_size = labels.size(0)
            loss_sum += loss.item() * batch_size
            n_seen += batch_size
            preds_all.append(logits.argmax(dim=1).cpu().numpy())
            targets_all.append(labels.cpu().numpy())

    if n_seen == 0:
        # np.concatenate would otherwise fail with a far less helpful message.
        raise ValueError('evaluate_model received an empty loader; nothing to evaluate.')

    preds_all = np.concatenate(preds_all)
    targets_all = np.concatenate(targets_all)
    acc = accuracy_score(targets_all, preds_all)
    return loss_sum / n_seen, acc, targets_all, preds_all

criterion = nn.CrossEntropyLoss()
test_results = []
# Pin the label set to every class index. target_names must match the labels,
# and the confusion matrix should stay num_classes x num_classes even if a class
# is absent from both y_true and y_pred in this split.
label_ids = list(range(num_classes))

for res in ablation_results:
    name = res['model_name']
    ckpt = res['save_path']
    if not os.path.exists(ckpt):
        print(f'Skipping {name}: checkpoint not found at {ckpt}')
        continue

    # Rebuild the same architecture, then restore its best validation checkpoint.
    model = model_factories[name]().to(device)
    model.load_state_dict(torch.load(ckpt, map_location=device))

    test_loss, test_acc, y_true, y_pred = evaluate_model(model, test_loader, criterion)
    report = classification_report(y_true, y_pred, labels=label_ids, target_names=class_names,
                                   output_dict=True, zero_division=0)
    cm = confusion_matrix(y_true, y_pred, labels=label_ids)

    test_results.append({
        'model_name': name,
        'test_loss': float(test_loss),
        'test_acc': float(test_acc),
        'y_true': y_true,
        'y_pred': y_pred,
        'report': report,
        'confusion_matrix': cm.tolist()
    })

if len(test_results) == 0:
    raise RuntimeError('No test results found. Run the training cell first.')

print('\nAblation Summary (Test):')
for tr in test_results:
    macro_f1 = tr['report']['macro avg']['f1-score'] * 100
    print(f"{tr['model_name']:<30} | Test Loss: {tr['test_loss']:.4f} | Test Acc: {tr['test_acc'] * 100:.2f}% | Macro-F1: {macro_f1:.2f}%")

# Pin the label set so the printed report lists every class (and target_names
# always matches) even if a class is missing from both y_true and y_pred.
report_labels = list(range(len(class_names)))

for tr in test_results:
    model_name = tr['model_name']
    y_true = tr['y_true']
    y_pred = tr['y_pred']
    report = tr['report']
    cm = np.array(tr['confusion_matrix'])

    print(f"\n{'='*12} Detailed Report: {model_name} {'='*12}")
    print(classification_report(y_true, y_pred, labels=report_labels, target_names=class_names, zero_division=0))

    per_class_f1 = [report[c]['f1-score'] * 100 for c in class_names]

    plt.figure(figsize=(16, 4.5))

    # Panel 1: confusion matrix with per-cell counts.
    plt.subplot(1, 3, 1)
    im = plt.imshow(cm, cmap='Blues')
    plt.title(f'Confusion Matrix\n{model_name}')
    plt.xlabel('Predicted')
    plt.ylabel('True')
    plt.xticks(np.arange(len(class_names)), class_names, rotation=45, ha='right')
    plt.yticks(np.arange(len(class_names)), class_names)
    # White text on dark cells keeps counts legible on the 'Blues' colormap.
    text_threshold = cm.max() / 2
    for i in range(cm.shape[0]):
        for j in range(cm.shape[1]):
            cell_color = 'white' if cm[i, j] > text_threshold else 'black'
            plt.text(j, i, cm[i, j], ha='center', va='center', color=cell_color, fontsize=7)
    plt.colorbar(im, fraction=0.046, pad=0.04)

    # Panel 2: per-class F1.
    plt.subplot(1, 3, 2)
    bars = plt.bar(class_names, per_class_f1)
    plt.title(f'Per-Class F1 (%)\n{model_name}')
    plt.ylabel('F1-score (%)')
    plt.ylim(0, 100)
    plt.xticks(rotation=45, ha='right')
    plt.grid(axis='y', alpha=0.25)
    for b, v in zip(bars, per_class_f1):
        plt.text(b.get_x() + b.get_width()/2, v + 0.5, f'{v:.1f}', ha='center', va='bottom', fontsize=7)

    # Panel 3: accuracy and macro-averaged metrics.
    plt.subplot(1, 3, 3)
    metrics = ['Accuracy', 'Macro Precision', 'Macro Recall', 'Macro F1']
    values = [
        tr['test_acc'] * 100,
        report['macro avg']['precision'] * 100,
        report['macro avg']['recall'] * 100,
        report['macro avg']['f1-score'] * 100
    ]
    bars = plt.bar(metrics, values)
    plt.title(f'Overall Metrics (%)\n{model_name}')
    plt.ylim(0, 100)
    plt.xticks(rotation=20, ha='right')
    plt.grid(axis='y', alpha=0.25)
    for b, v in zip(bars, values):
        plt.text(b.get_x() + b.get_width()/2, v + 0.5, f'{v:.1f}', ha='center', va='bottom', fontsize=8)

    plt.tight_layout()
    plt.show()

def _json_default(obj):
    """Convert NumPy values in ``test_results`` into JSON-serialisable Python objects.

    ``y_true``/``y_pred`` are NumPy arrays returned by ``evaluate_model``, and
    ``json.dump`` rejects them with ``TypeError``. Arrays become lists and NumPy
    scalars become Python scalars; any other unsupported object is still rejected.
    """
    if isinstance(obj, np.ndarray):
        return obj.tolist()
    if isinstance(obj, np.generic):
        return obj.item()
    raise TypeError(f'Object of type {type(obj).__name__} is not JSON serializable')


test_summary_path = os.path.join(CHECKPOINT_DIR, 'test_results_detailed.json')
with open(test_summary_path, 'w', encoding='utf-8') as f:
    json.dump(test_results, f, indent=2, default=_json_default)
print(f'Detailed test results saved at: {test_summary_path}')

# Grouped bars: test accuracy next to macro-F1 for every evaluated model.
model_labels = [entry['model_name'] for entry in test_results]
acc_pct = [entry['test_acc'] * 100 for entry in test_results]
f1_pct = [entry['report']['macro avg']['f1-score'] * 100 for entry in test_results]

positions = np.arange(len(model_labels))
bar_width = 0.36
half = bar_width / 2

plt.figure(figsize=(10, 4))
plt.bar(positions - half, acc_pct, width=bar_width, label='Test Accuracy (%)')
plt.bar(positions + half, f1_pct, width=bar_width, label='Macro-F1 (%)')
plt.xticks(positions, model_labels, rotation=20, ha='right')
plt.ylabel('Score (%)')
plt.title('Final Model Comparison on Test Set')
plt.grid(axis='y', alpha=0.25)
plt.legend()
plt.tight_layout()
plt.show()

best_test = max(test_results, key=lambda entry: entry['test_acc'])
print(f"\nBest test model: {best_test['model_name']} ({best_test['test_acc'] * 100:.2f}%)")

In [None]:
# Export trained artifacts for download/reuse (Kaggle-friendly)
import shutil
from IPython.display import FileLink, display

EXPORT_DIR = os.path.join(os.getcwd(), 'model_export_edge_gated')
os.makedirs(EXPORT_DIR, exist_ok=True)

# 1) Copy every checkpoint that training actually wrote.
copied = []
for checkpoint_path in (entry.get('save_path', '') for entry in ablation_results):
    if not (checkpoint_path and os.path.exists(checkpoint_path)):
        continue
    target = os.path.join(EXPORT_DIR, os.path.basename(checkpoint_path))
    shutil.copy2(checkpoint_path, target)
    copied.append(target)

# 2) Copy the JSON run summaries, when present.
for meta_name in ('ablation_results.json', 'test_results_detailed.json'):
    meta_src = os.path.join(CHECKPOINT_DIR, meta_name)
    if os.path.exists(meta_src):
        shutil.copy2(meta_src, os.path.join(EXPORT_DIR, meta_name))

# 3) Zip the whole export directory, storing paths relative to EXPORT_DIR.
export_zip = os.path.join(os.getcwd(), 'edge_gated_model_export.zip')
with zipfile.ZipFile(export_zip, mode='w', compression=zipfile.ZIP_DEFLATED) as bundle:
    for folder, _subdirs, file_names in os.walk(EXPORT_DIR):
        for file_name in file_names:
            file_path = os.path.join(folder, file_name)
            bundle.write(file_path, arcname=os.path.relpath(file_path, EXPORT_DIR))

print('Exported checkpoint files:', len(copied))
for exported in copied:
    print('-', exported)

print('\nDownload this zip from Kaggle Files panel:')
print(export_zip)
display(FileLink(export_zip))

# Optional helper snippet for future sessions
load_snippet = (
    "model = CNNCapsEdgeGNN(num_classes=8, k_neighbors=8, max_caps_nodes=128).to(device)",
    "model.load_state_dict(torch.load('/kaggle/input/YOUR_DATASET/best_cnn_plus_caps_plus_edgegatedgnn.pth', map_location=device))",
    "model.eval()",
)
print('\nFuture load example:')
for snippet_line in load_snippet:
    print(snippet_line)

## Notes for Paper Write-up

- Novelty module: **Learnable Edge-Gated Capsule Graph (k-NN + Edge Gate MLP)**
- Compare against fixed-edge CapsGNN to isolate novelty gain.
- Report Accuracy and Macro-F1; add qualitative edge-gate interpretation plots in a future iteration.