# Hybrid CNN–ViT Training Notebook (Updated)

In [None]:
%uv pip install optuna timm torchmetrics

Using Python 3.12.6 environment at: /usr/local
Audited 3 packages in 17ms
Note: you may need to restart the kernel to use updated packages.


## Import Required Libraries

In [None]:
import os
import math
import json

from typing import List

import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from torch.utils.data import DataLoader

import torchvision
import torchvision.transforms as transforms

import optuna
from optuna.trial import Trial

import numpy as np

from torchmetrics import MetricCollection
from torchmetrics.classification import MulticlassAccuracy, MulticlassPrecision, MulticlassRecall, MulticlassF1Score

MODEL_DIR = os.environ.get("MODEL_DIR", "/mnt/vit-hybrid-optuna")
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

## Hybrid CNN-ViT Pipeline

In [None]:
# Paths & IO (Modal-friendly)
# On Modal, mount a Volume at /mnt/vit-hybrid-optuna (or override via the MODEL_DIR env var).
MODEL_DIR = os.environ.get("MODEL_DIR", "/mnt/vit-hybrid-optuna")
os.makedirs(MODEL_DIR, exist_ok=True)

# Every artifact written by this notebook lives directly under MODEL_DIR.
BEST_STATE_PATH, FULL_CKPT_PATH, FINETUNED_PATH, STUDY_JSON_PATH = (
    os.path.join(MODEL_DIR, artifact_name)
    for artifact_name in (
        "best_hybrid_cnn_vit.pth",
        "hybrid_cnn_vit_full.pth",
        "finetuned_model.pth",
        "optuna_best.json",
    )
)

# Reproducibility & Device
torch.manual_seed(42)
np.random.seed(42)

device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
print(f"Using Device: {device}")

# Metrics Setup
num_classes = 10
avg_type = 'macro'  # 'macro' fair to all classes, 'micro' = global accuracy

# One shared collection, reset and reused by evaluate(); every metric uses the
# same class count and averaging mode.
metric_collection = MetricCollection({
    metric_name: metric_cls(num_classes=num_classes, average=avg_type)
    for metric_name, metric_cls in (
        ('accuracy', MulticlassAccuracy),
        ('precision', MulticlassPrecision),
        ('recall', MulticlassRecall),
        ('f1_score', MulticlassF1Score),
    )
}).to(device)

# Data Loading
def get_dataloaders(batch_size=128, num_workers=2, data_root='./data'):
    """Build the CIFAR-10 train and test DataLoaders.

    The training split is augmented (random crop with padding, horizontal
    flip, AutoAugment with the CIFAR10 policy); the test split is only
    converted to a tensor and normalised. Both splits are downloaded into
    ``data_root`` when missing.

    Args:
        batch_size: Batch size used for both loaders.
        num_workers: Worker processes per loader (default matches the
            previously hard-coded value of 2).
        data_root: Directory where CIFAR-10 is stored / downloaded.

    Returns:
        Tuple ``(train_loader, test_loader)``.
    """
    # Single normalisation op shared by both pipelines so the statistics
    # cannot drift apart between train and test.
    normalize = transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010))

    transform_train = transforms.Compose([
        transforms.RandomCrop(32, padding=4),
        transforms.RandomHorizontalFlip(),
        transforms.AutoAugment(transforms.AutoAugmentPolicy.CIFAR10),
        transforms.ToTensor(),
        normalize,
    ])
    transform_test = transforms.Compose([
        transforms.ToTensor(),
        normalize,
    ])

    train_dataset = torchvision.datasets.CIFAR10(
        root=data_root, train=True, download=True, transform=transform_train)
    test_dataset = torchvision.datasets.CIFAR10(
        root=data_root, train=False, download=True, transform=transform_test)

    # Pinned host memory only helps host-to-GPU copies; skip it on CPU-only
    # machines.
    loader_kwargs = dict(
        batch_size=batch_size,
        num_workers=num_workers,
        pin_memory=torch.cuda.is_available(),
    )
    train_loader = DataLoader(train_dataset, shuffle=True, **loader_kwargs)
    test_loader = DataLoader(test_dataset, shuffle=False, **loader_kwargs)
    return train_loader, test_loader

# CNN Feature Extractors
class EfficientNetFeatureExtractor(nn.Module):
    """EfficientNet-B0 multi-scale feature extractor with taps [2, 4, 6].

    ``forward`` returns the activations produced by the feature stages at
    ``TAP_INDICES`` (in order). ``out_channels`` lists the channel count of
    each returned map and is what downstream fusion layers are sized from.
    """

    # Indices into ``self.features`` whose outputs are returned.
    TAP_INDICES = (2, 4, 6)

    def __init__(self, model_name='efficientnet_b0'):
        # NOTE: ``model_name`` is accepted for interface compatibility but is
        # not consulted; the backbone is always torchvision's EfficientNet-B0.
        super().__init__()
        from torchvision.models import efficientnet_b0, EfficientNet_B0_Weights
        base_model = efficientnet_b0(weights=EfficientNet_B0_Weights.DEFAULT)
        self.features = base_model.features
        # Channels at indices [2, 4, 6] for EfficientNet-B0
        self.out_channels = [24, 80, 192]

    def forward(self, x):
        """Run the backbone up to the last tap and return the tapped feature maps."""
        last_tap = max(self.TAP_INDICES)
        feats = []
        for i, layer in enumerate(self.features):
            x = layer(x)
            if i in self.TAP_INDICES:
                feats.append(x)
            if i == last_tap:
                # Stages after the last tap never contribute to the output,
                # so running them only wastes compute and memory.
                break
        return feats

class ResNetFeatureExtractor(nn.Module):
    """Multi-scale ResNet-18 backbone that yields the outputs of layer2, layer3 and layer4."""

    def __init__(self, depth='resnet18'):
        # ``depth`` is kept in the signature but not read; ResNet-18 is always built.
        super().__init__()
        from torchvision.models import resnet18, ResNet18_Weights
        pretrained = resnet18(weights=ResNet18_Weights.DEFAULT)
        # Re-register the stem and the four residual stages on this module,
        # in the same order they are applied in forward().
        for part in ('conv1', 'bn1', 'relu', 'maxpool',
                     'layer1', 'layer2', 'layer3', 'layer4'):
            setattr(self, part, getattr(pretrained, part))
        # Channel counts of the returned maps [layer2, layer3, layer4].
        self.out_channels = [128, 256, 512]

    def forward(self, x):
        """Return ``[f2, f3, f4]``: the outputs of layer2 (128 ch), layer3 (256 ch), layer4 (512 ch)."""
        for stem_op in (self.conv1, self.bn1, self.relu, self.maxpool):
            x = stem_op(x)
        x = self.layer1(x)  # 64 ch, feeds layer2 but is not returned
        pyramid = []
        for stage in (self.layer2, self.layer3, self.layer4):
            x = stage(x)
            pyramid.append(x)
        return pyramid

# Adaptive Patch Embedding
class AdaptivePatchEmbedding(nn.Module):
    """Turn a feature map into a token sequence with a class token and learnable positions.

    Every spatial location becomes one token (1x1 convolution projection).
    Positional embeddings are stored for a ``max_img_size x max_img_size``
    grid and bilinearly resized when the incoming grid has a different
    number of locations.
    """

    def __init__(self, in_channels, embed_dim, max_img_size=8):
        super().__init__()
        self.embed_dim = embed_dim
        self.proj = nn.Conv2d(in_channels, embed_dim, kernel_size=1, stride=1)
        self.max_patches = max_img_size * max_img_size
        # Slot 0 of pos_embed belongs to the class token, hence the +1.
        self.pos_embed = nn.Parameter(torch.zeros(1, self.max_patches + 1, embed_dim))
        self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim))
        for learnable in (self.pos_embed, self.cls_token):
            nn.init.trunc_normal_(learnable, std=0.02)

    def forward(self, x):
        """Map ``(B, C, H, W)`` features to ``(B, 1 + H*W, embed_dim)`` position-encoded tokens."""
        batch, _, height, width = x.shape
        tokens = self.proj(x).flatten(2).transpose(1, 2)  # (B, H*W, embed_dim)
        seq_len = tokens.shape[1] + 1
        cls = self.cls_token.expand(batch, -1, -1)
        tokens = torch.cat([cls, tokens], dim=1)  # (B, 1+H*W, embed_dim)
        if seq_len == self.pos_embed.shape[1]:
            positions = self.pos_embed
        else:
            # Token count differs from the stored grid: resample the grid.
            positions = self.interpolate_pos_encoding(height, width)
        return tokens + positions

    def interpolate_pos_encoding(self, H, W):
        """Return positional embeddings resized to an ``H x W`` grid (class slot untouched)."""
        cls_pos = self.pos_embed[:, 0:1]
        grid_pos = self.pos_embed[:, 1:]
        side = int(math.sqrt(self.pos_embed.shape[1] - 1))
        # (1, N, D) -> (1, D, side, side) so F.interpolate treats it as an image.
        grid = grid_pos.reshape(1, side, side, -1).permute(0, 3, 1, 2)
        grid = F.interpolate(grid, size=(H, W), mode='bilinear', align_corners=False)
        grid = grid.permute(0, 2, 3, 1).reshape(1, -1, self.embed_dim)
        return torch.cat([cls_pos, grid], dim=1)

# Multi-Scale Feature Fusion
class MultiScaleFeatureFusion(nn.Module):
    """Project several feature maps to a common width, align them spatially, and merge them.

    Each input scale gets its own 1x1 projection to ``out_channels``. All
    projected maps are resized to the spatial size of the LAST input, then
    concatenated and mixed by a 1x1 conv + BatchNorm + ReLU.
    """

    def __init__(self, channels_list: List[int], out_channels: int):
        super().__init__()
        per_scale = [nn.Conv2d(in_ch, out_channels, 1) for in_ch in channels_list]
        self.projections = nn.ModuleList(per_scale)
        concat_width = out_channels * len(channels_list)
        self.fusion = nn.Sequential(
            nn.Conv2d(concat_width, out_channels, 1),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True)
        )

    def forward(self, features: List[torch.Tensor]):
        """Fuse ``features`` (one tensor per configured scale) into a single map."""
        assert len(features) == len(self.projections), \
            f"Got {len(features)} features but {len(self.projections)} projections"
        # The last feature map defines the output resolution.
        out_hw = features[-1].shape[2:]
        resized = []
        for idx, (feat, proj) in enumerate(zip(features, self.projections)):
            assert feat.shape[1] == proj.in_channels, \
                f"Fusion input ch mismatch at idx {idx}: tensor has {feat.shape[1]}, proj expects {proj.in_channels}"
            projected = proj(feat)
            needs_resize = projected.shape[2:] != out_hw
            if needs_resize:
                projected = F.interpolate(projected, size=out_hw, mode='bilinear', align_corners=False)
            resized.append(projected)
        return self.fusion(torch.cat(resized, dim=1))

# Vision Transformer Encoder
class MultiHeadSelfAttention(nn.Module):
    """Scaled dot-product self-attention over a token sequence, split across heads.

    Raises:
        ValueError: if ``embed_dim`` is not a multiple of ``num_heads``.
    """

    def __init__(self, embed_dim, num_heads, dropout=0.1):
        super().__init__()
        if embed_dim % num_heads:
            raise ValueError(f"embed_dim ({embed_dim}) must be divisible by num_heads ({num_heads})")
        self.embed_dim = embed_dim
        self.num_heads = num_heads
        self.head_dim = embed_dim // num_heads
        # One linear layer produces queries, keys and values together.
        self.qkv = nn.Linear(embed_dim, embed_dim * 3)
        self.attn_drop = nn.Dropout(dropout)
        self.proj = nn.Linear(embed_dim, embed_dim)
        self.proj_drop = nn.Dropout(dropout)

    def forward(self, x):
        """Attend over ``x`` of shape ``(B, N, C)`` and return a tensor of the same shape."""
        batch, n_tokens, channels = x.shape
        packed = self.qkv(x).reshape(batch, n_tokens, 3, self.num_heads, self.head_dim)
        # -> (3, B, heads, N, head_dim); split the leading axis into q / k / v.
        query, key, value = packed.permute(2, 0, 3, 1, 4).unbind(0)
        scores = (query @ key.transpose(-2, -1)) * (self.head_dim ** -0.5)
        weights = self.attn_drop(scores.softmax(dim=-1))
        # Merge the heads back into the channel dimension.
        context = (weights @ value).transpose(1, 2).reshape(batch, n_tokens, channels)
        return self.proj_drop(self.proj(context))

class TransformerBlock(nn.Module):
    """Pre-norm transformer encoder layer: self-attention then MLP, each with a residual connection."""

    def __init__(self, embed_dim, num_heads, mlp_ratio=4.0, dropout=0.1):
        super().__init__()
        hidden_width = int(embed_dim * mlp_ratio)
        self.norm1 = nn.LayerNorm(embed_dim)
        self.attn = MultiHeadSelfAttention(embed_dim, num_heads, dropout)
        self.norm2 = nn.LayerNorm(embed_dim)
        # Feed-forward sub-layer: expand to hidden_width, GELU, project back.
        self.mlp = nn.Sequential(
            nn.Linear(embed_dim, hidden_width),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(hidden_width, embed_dim),
            nn.Dropout(dropout)
        )

    def forward(self, x):
        """Apply the attention and MLP sub-layers; input and output shapes match."""
        attended = self.attn(self.norm1(x))
        x = x + attended
        transformed = self.mlp(self.norm2(x))
        return x + transformed

# Hybrid CNN-ViT Model
class HybridCNNViT(nn.Module):
    """CNN backbone -> multi-scale fusion -> patch tokens -> transformer -> classifier.

    Args:
        backbone: ``'efficientnet'`` selects EfficientNetFeatureExtractor;
            any other value selects ResNetFeatureExtractor.
        embed_dim: Token width used by the fusion, the transformer and the head.
        num_heads: Attention heads per transformer block.
        num_layers: Number of transformer blocks.
        mlp_ratio: Hidden/embedding width ratio inside each block's MLP.
        dropout: Dropout probability for the blocks and the head.
        num_classes: Number of output logits.
    """

    def __init__(self,
                 backbone='efficientnet',
                 embed_dim=256,
                 num_heads=8,
                 num_layers=4,
                 mlp_ratio=4.0,
                 dropout=0.1,
                 num_classes=10):
        super().__init__()
        use_efficientnet = backbone == 'efficientnet'
        self.backbone = (EfficientNetFeatureExtractor() if use_efficientnet
                         else ResNetFeatureExtractor())

        # The backbone's own channel list is the single source of truth for fusion sizing.
        self.fusion = MultiScaleFeatureFusion(self.backbone.out_channels, embed_dim)
        self.patch_embed = AdaptivePatchEmbedding(embed_dim, embed_dim, max_img_size=8)

        encoder_layers = []
        for _ in range(num_layers):
            encoder_layers.append(TransformerBlock(embed_dim, num_heads, mlp_ratio, dropout))
        self.blocks = nn.ModuleList(encoder_layers)

        self.norm = nn.LayerNorm(embed_dim)
        head_hidden = embed_dim // 2
        self.head = nn.Sequential(
            nn.Linear(embed_dim, head_hidden),
            nn.ReLU(),
            nn.Dropout(dropout),
            nn.Linear(head_hidden, num_classes)
        )

    def forward(self, x):
        """Return class logits for a batch of images."""
        tokens = self.patch_embed(self.fusion(self.backbone(x)))
        for encoder_layer in self.blocks:
            tokens = encoder_layer(tokens)
        tokens = self.norm(tokens)
        # Classification reads only the class token (sequence position 0).
        return self.head(tokens[:, 0])

# Training & Evaluation
def train_epoch(model, loader, optimizer, criterion, device):
    """Run one optimisation pass over ``loader``.

    Returns:
        Tuple ``(mean_batch_loss, accuracy_percent)`` where the loss is the
        average of the per-batch losses and the accuracy is the percentage
        of samples whose arg-max prediction equals the target.
    """
    model.train()
    loss_sum = 0.0
    n_correct = 0
    n_seen = 0
    for batch_x, batch_y in loader:
        batch_x = batch_x.to(device)
        batch_y = batch_y.to(device)

        optimizer.zero_grad()
        logits = model(batch_x)
        batch_loss = criterion(logits, batch_y)
        batch_loss.backward()
        optimizer.step()

        loss_sum += batch_loss.item()
        n_seen += batch_y.size(0)
        predictions = logits.max(1)[1]
        n_correct += predictions.eq(batch_y).sum().item()

    mean_loss = loss_sum / len(loader)
    accuracy_pct = 100. * n_correct / n_seen
    return mean_loss, accuracy_pct

# Modified evaluate function to use torchmetrics
def evaluate(model, loader, criterion, device):
    """Score ``model`` on ``loader`` using the module-level ``metric_collection``.

    Returns:
        Dict holding every metric from ``metric_collection`` plus ``'loss'``
        (mean of the per-batch losses). ``'accuracy'`` is rescaled to a
        percentage; the other metrics are returned as computed.
    """
    model.eval()
    # Start from a clean slate: the collection is shared across calls.
    metric_collection.reset()
    running_loss = 0.0

    with torch.no_grad():
        for batch_x, batch_y in loader:
            batch_x, batch_y = batch_x.to(device), batch_y.to(device)
            logits = model(batch_x)
            running_loss += criterion(logits, batch_y).item()
            metric_collection.update(logits, batch_y)

    computed = metric_collection.compute()
    report = {name: value.item() for name, value in computed.items()}
    report['loss'] = running_loss / len(loader)
    report['accuracy'] *= 100
    return report

# Fixed head-count search space shared by every embed_dim.
CHOICES_HEADS = [4, 6, 8, 12]


def _effective_num_heads(embed_dim: int, trial_choice: int) -> int:
    """Project a sampled head count onto one that divides ``embed_dim``.

    The sampled value is returned unchanged when it already divides
    ``embed_dim``. Otherwise the largest dividing choice not exceeding it is
    used; if there is none, the largest dividing choice overall. When no
    entry of ``CHOICES_HEADS`` divides ``embed_dim`` the result is 1.
    """
    divisors = [heads for heads in CHOICES_HEADS if embed_dim % heads == 0]
    if not divisors:
        # No configured head count fits; a single head always does.
        return 1
    if trial_choice in divisors:
        return trial_choice
    # Deterministic fallback so the Optuna search space can stay static.
    return max((heads for heads in divisors if heads <= trial_choice),
               default=max(divisors))

def objective(trial: Trial):
    """Optuna objective: train one sampled HybridCNNViT and return its best test accuracy.

    Samples backbone, embedding width, head count (projected to a valid
    value by ``_effective_num_heads``), depth, MLP ratio, dropout, learning
    rate and batch size; trains for up to ``n_epochs`` epochs with AdamW and
    cosine annealing, reporting the per-epoch test accuracy to Optuna.

    Returns:
        Best test accuracy (percent) seen during the trial.

    Raises:
        optuna.TrialPruned: when the study's pruner asks to stop the trial.
    """
    # NOTE(review): model selection here uses the CIFAR-10 *test* split, so
    # the reported best value is optimistically biased; consider a held-out
    # validation split.
    n_epochs = 15  # Reduced for faster tuning
    max_patience = 3  # non-improving epochs tolerated before early stopping

    backbone = trial.suggest_categorical('backbone', ['efficientnet', 'resnet'])
    embed_dim = trial.suggest_categorical('embed_dim', [128, 192, 256, 384])

    # Fixed categorical (no dynamic space). We project to a valid value after sampling.
    num_heads_raw = trial.suggest_categorical('num_heads', CHOICES_HEADS)
    num_heads = _effective_num_heads(embed_dim, num_heads_raw)
    trial.set_user_attr('num_heads_effective', num_heads)

    num_layers = trial.suggest_int('num_layers', 3, 6)
    mlp_ratio = trial.suggest_float('mlp_ratio', 2.0, 4.0)
    dropout = trial.suggest_float('dropout', 0.1, 0.3)
    lr = trial.suggest_float('lr', 1e-4, 1e-3, log=True)
    batch_size = trial.suggest_categorical('batch_size', [64, 128, 256])

    train_loader, test_loader = get_dataloaders(batch_size)

    model = HybridCNNViT(
        backbone=backbone,
        embed_dim=embed_dim,
        num_heads=num_heads,
        num_layers=num_layers,
        mlp_ratio=mlp_ratio,
        dropout=dropout,
        num_classes=10
    ).to(device)

    optimizer = torch.optim.AdamW(model.parameters(), lr=lr, weight_decay=0.05)
    # The annealing horizon must equal the number of epochs actually run;
    # a shorter T_max makes the learning rate climb again for the remaining
    # epochs (the final-training loop already pairs T_max with its epoch count).
    scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=n_epochs)
    criterion = nn.CrossEntropyLoss(label_smoothing=0.1)

    best_acc = 0.0
    patience = 0

    for epoch in range(n_epochs):
        _, train_acc = train_epoch(model, train_loader, optimizer, criterion, device)

        test_metrics = evaluate(model, test_loader, criterion, device)
        test_acc = test_metrics["accuracy"]
        test_f1 = test_metrics["f1_score"]

        scheduler.step()

        # Log before any pruning / early-stop exit so the final epoch of a
        # trial is always visible in the output.
        print(f"Epoch {epoch+1}: Train Acc={train_acc:.2f}%, "
              f"Test Acc={test_acc:.2f}%, Test F1={test_f1:.4f}, Test Loss={test_metrics['loss']:.4f}")

        trial.report(test_acc, epoch)
        if trial.should_prune():
            raise optuna.TrialPruned()

        if test_acc > best_acc:
            best_acc = test_acc
            patience = 0
        else:
            patience += 1
            if patience >= max_patience:
                break

    return best_acc

if __name__ == "__main__":
    # Script entry point: (1) hyperparameter search, (2) final training with
    # the best hyperparameters, then model summary output.
    print("Hybrid CNN-ViT Architecture - Optuna Tuning")

    print("\n[1/2] Running Optuna Hyperparameter Optimization...\n")
    study = optuna.create_study(
        direction='maximize',
        pruner=optuna.pruners.MedianPruner(n_startup_trials=3, n_warmup_steps=5)
    )
    # Search budget: at most 3 trials or 1 hour, whichever is reached first.
    # NOTE(review): n_startup_trials equals n_trials, so the median pruner
    # presumably never activates with this budget - confirm against the
    # Optuna docs and raise n_trials for a real search.
    study.optimize(objective, n_trials=3, timeout=3600)

    print("Optimization Results:")
    print(f"Best Accuracy: {study.best_value:.2f}%")
    print(f"Best Hyperparameters: {study.best_params}")

    # Persist Optuna best to JSON (optional, helpful on Modal)
    try:
        import json
        with open(STUDY_JSON_PATH, "w") as f:
            json.dump({"best_value": study.best_value, "best_params": study.best_params}, f, indent=2)
        print(f"✓ Saved Optuna Best Summary to: {STUDY_JSON_PATH}")
    except Exception as e:
        # Best-effort artifact: a failed write must not abort training.
        print(f"Could not Write Study Summary: {e}")

    # Train final model with best hyperparameters
    print("\n[2/2] Training final model with best hyperparameters...")
    best_params = study.best_params

    # Derive effective heads again (consistent with objective)
    num_heads_eff = _effective_num_heads(best_params['embed_dim'], best_params.get('num_heads', 8))

    train_loader, test_loader = get_dataloaders(best_params['batch_size'])

    final_model = HybridCNNViT(
        backbone=best_params['backbone'],
        embed_dim=best_params['embed_dim'],
        num_heads=num_heads_eff,
        num_layers=best_params['num_layers'],
        mlp_ratio=best_params['mlp_ratio'],
        dropout=best_params['dropout'],
        num_classes=10
    ).to(device)

    optimizer = torch.optim.AdamW(final_model.parameters(),
                                  lr=best_params['lr'],
                                  weight_decay=0.05)
    scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=30)
    criterion = nn.CrossEntropyLoss(label_smoothing=0.1)

    best_acc = 0.0
    # In-memory snapshot of the best-epoch weights. Restored after training
    # so that later uses of final_model (notably the full checkpoint, which
    # records best_acc) hold the weights that achieved that accuracy rather
    # than whatever the last epoch produced.
    best_state = None
    print("\nTraining Final Model for 30 Epochs")

    for epoch in range(30):
        train_loss, train_acc = train_epoch(final_model, train_loader, optimizer, criterion, device)

        test_metrics = evaluate(final_model, test_loader, criterion, device)
        test_acc = test_metrics["accuracy"]
        test_f1 = test_metrics["f1_score"]

        scheduler.step()
        if test_acc > best_acc:
            best_acc = test_acc
            # state_dict() tensors alias the live parameters, so copy them.
            best_state = {name: tensor.detach().cpu().clone()
                          for name, tensor in final_model.state_dict().items()}
            try:
                torch.save(final_model.state_dict(), BEST_STATE_PATH)
                print(f"✓ Saved Best state_dict to: {BEST_STATE_PATH} (acc={best_acc:.2f}%)")
            except Exception as e:
                print(f"Save Failed: {e}")

        if (epoch + 1) % 5 == 0:

            print(f"Epoch {epoch+1}/30: Train Acc={train_acc:.2f}%, "
                  f"Test Acc={test_acc:.2f}%, Test F1={test_f1:.4f}, Test Loss={test_metrics['loss']:.4f}")

    if best_state is not None:
        final_model.load_state_dict(best_state)

    print("Final Results:")
    print(f"Best Test Accuracy: {best_acc:.2f}%")
    print(f"Target Achieved: {'✓ YES' if best_acc >= 90 else '✗ NO'}")

    # Model summary
    total_params = sum(p.numel() for p in final_model.parameters())
    trainable_params = sum(p.numel() for p in final_model.parameters() if p.requires_grad)
    print(f"\nModel Parameters: {total_params:,} (Trainable: {trainable_params:,})")
    print(f"\nModel State Saved at: {BEST_STATE_PATH}")

    print("Transfer Learning Utilities")

    def save_full_model(model, hyperparams, path=FULL_CKPT_PATH):
        """Write weights, hyperparameters and training metadata to ``path`` (best-effort).

        A failure is reported on stdout instead of being raised. The recorded
        ``best_accuracy`` is read from ``best_acc`` in the enclosing scope.
        """
        try:
            # Assembled inside the try so that any failure while gathering
            # the payload is reported the same way as a failed write.
            weights = model.state_dict()
            checkpoint = {
                'model_state_dict': weights,
                'hyperparameters': hyperparams,
                'model_architecture': 'HybridCNNViT',
                'training_info': {
                    'dataset': 'CIFAR-10',
                    'num_classes': 10,
                    'best_accuracy': best_acc
                },
            }
            torch.save(checkpoint, path)
            print(f"✓ Full Model Saved to: {path}")
        except Exception as exc:
            print(f"Failed to Save Full Model: {exc}")

    def load_pretrained_model(path=FULL_CKPT_PATH, num_classes_new=10, freeze_backbone=False):
        """Rebuild a HybridCNNViT from a full checkpoint and optionally adapt it.

        Args:
            path: Checkpoint in the format written by ``save_full_model``.
            num_classes_new: Target class count; when it differs from the
                checkpoint's, the classification head is replaced by a
                freshly initialised one.
            freeze_backbone: When true, gradients are disabled for the CNN
                backbone and the fusion module.

        Returns:
            Tuple ``(model, hyperparams)``.
        """
        ckpt = torch.load(path, map_location=device)
        hp = ckpt['hyperparameters']
        # Same head projection as used during tuning / final training.
        heads = _effective_num_heads(hp['embed_dim'], hp.get('num_heads', 8))
        build_kwargs = dict(
            backbone=hp['backbone'],
            embed_dim=hp['embed_dim'],
            num_heads=heads,
            num_layers=hp['num_layers'],
            mlp_ratio=hp['mlp_ratio'],
            dropout=hp['dropout'],
            num_classes=ckpt['training_info']['num_classes'],
        )
        model = HybridCNNViT(**build_kwargs).to(device)
        model.load_state_dict(ckpt['model_state_dict'])

        trained_classes = ckpt['training_info']['num_classes']
        if num_classes_new != trained_classes:
            print(f"Adapting Model: {trained_classes} → {num_classes_new} classes")
            width = hp['embed_dim']
            model.head = nn.Sequential(
                nn.Linear(width, width // 2),
                nn.ReLU(),
                nn.Dropout(hp['dropout']),
                nn.Linear(width // 2, num_classes_new)
            ).to(device)

        if freeze_backbone:
            print("Freezing CNN Backbone Layers...")
            for frozen_part in (model.backbone, model.fusion):
                for param in frozen_part.parameters():
                    param.requires_grad = False
        return model, hp

    def fine_tune_on_new_dataset(model, train_loader, test_loader,
                                 num_epochs = 5, lr=1e-4):
        """Fine-tune ``model`` on the given loaders and return the best test accuracy.

        Only parameters with ``requires_grad=True`` are optimised (AdamW with
        cosine annealing over ``num_epochs``). Whenever the test accuracy
        improves, the state_dict is written to ``FINETUNED_PATH``
        (best-effort: a failed write is printed, not raised).
        """
        print(f"\nFine-tuning for {num_epochs} Epochs with lr={lr}...")
        trainable = [p for p in model.parameters() if p.requires_grad]
        optimizer = torch.optim.AdamW(trainable, lr=lr, weight_decay=0.05)
        scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=num_epochs)
        criterion = nn.CrossEntropyLoss(label_smoothing=0.1)

        best_ft_acc = 0.0
        for epoch in range(num_epochs):
            _, train_acc = train_epoch(model, train_loader, optimizer, criterion, device)
            test_acc = evaluate(model, test_loader, criterion, device)["accuracy"]
            scheduler.step()

            improved = test_acc > best_ft_acc
            if improved:
                best_ft_acc = test_acc
                try:
                    torch.save(model.state_dict(), FINETUNED_PATH)
                    print(f"✓ Saved Finetuned state_dict to: {FINETUNED_PATH} (acc={best_ft_acc:.2f}%)")
                except Exception as exc:
                    print(f"Save Failed: {exc}")

            # Progress line every fifth epoch.
            if (epoch + 1) % 5 == 0:
                print(f"Epoch {epoch+1}/{num_epochs}: Train Acc={train_acc:.2f}%, Test Acc={test_acc:.2f}%")

        print(f"\n✓ Fine-tuning Complete! Best Accuracy: {best_ft_acc:.2f}%")
        return best_ft_acc

    # Save full model with hyperparameters
    save_full_model(final_model, best_params, FULL_CKPT_PATH)

    print("Transfer Learning Guide")

    print("""
        To use this model on other datasets:
        
        1. Similar Datasets (e.g., CIFAR-100, STL-10, Tiny ImageNet):
           - Load pretrained model with freeze_backbone=False
           - Fine-tune entire model with lower learning rate
        
        2. Different Domains (e.g., Medical images, Satellite imagery):
           - Load pretrained model with freeze_backbone=True
           - Train transformer layers and head first
        
        3. Different Image Sizes:
           - Resize images to 32x32 OR
           - Modify patch_embed layer for new resolution and let pos-enc interpolate
    """)

    # Artifact summary. Saves above are best-effort and the fine-tuned
    # weights are only written by fine_tune_on_new_dataset(), which this
    # script does not call, so flag any path that does not exist on disk
    # instead of claiming it was saved.
    print("Files Saved:")
    artifact_lines = (
        ("Best state dict       : ", BEST_STATE_PATH),
        ("Full checkpoint (+hps) : ", FULL_CKPT_PATH),
        ("Finetuned state dict   : ", FINETUNED_PATH),
        ("Optuna best summary    : ", STUDY_JSON_PATH),
    )
    for prefix, artifact_path in artifact_lines:
        note = "" if os.path.exists(artifact_path) else "  [not found]"
        print(f"{prefix}{artifact_path}{note}")

[I 2025-11-03 03:57:46,228] A new study created in memory with name: no-name-58c5b714-23f9-444d-be19-350200d6a5c7


[INFO] Using Device: cuda
[INFO] Hybrid CNN-ViT Architecture - Optuna Tuning
[INFO] 
[1/2] Running Optuna Hyperparameter Optimization...



  0%|                                                                               | 0.00/170M [00:00<?, ?B/s]  0%|▎                                                                      | 688k/170M [00:00<00:24, 6.85MB/s]  5%|███▏                                                                  | 7.70M/170M [00:00<00:03, 43.8MB/s] 10%|███████                                                               | 17.1M/170M [00:00<00:02, 66.7MB/s] 15%|██████████▊                                                           | 26.3M/170M [00:00<00:01, 76.6MB/s] 21%|██████████████▍                                                       | 35.2M/170M [00:00<00:01, 80.7MB/s] 26%|██████████████████▏                                                   | 44.4M/170M [00:00<00:01, 84.8MB/s] 31%|█████████████████████▉                                                | 53.5M/170M [00:00<00:01, 86.6MB/s] 37%|█████████████████████████▉                                            | 63.1M/170M [00:00<00:01, 8

Downloading: "https://download.pytorch.org/models/efficientnet_b0_rwightman-7f5810bc.pth" to /root/.cache/torch/hub/checkpoints/efficientnet_b0_rwightman-7f5810bc.pth


  0%|                                                                              | 0.00/20.5M [00:00<?, ?B/s]  4%|██▌                                                                   | 768k/20.5M [00:00<00:02, 7.43MB/s] 35%|████████████████████████▍                                            | 7.25M/20.5M [00:00<00:00, 42.1MB/s] 74%|███████████████████████████████████████████████████                  | 15.1M/20.5M [00:00<00:00, 60.1MB/s]100%|█████████████████████████████████████████████████████████████████████| 20.5M/20.5M [00:00<00:00, 57.4MB/s]


[INFO] Epoch 1: Train Acc=51.63%, Test Acc=72.38%, Test F1=0.7197, Test Loss=1.1272
[INFO] Epoch 2: Train Acc=66.82%, Test Acc=78.45%, Test F1=0.7821, Test Loss=0.9955
[INFO] Epoch 3: Train Acc=71.05%, Test Acc=79.98%, Test F1=0.7969, Test Loss=0.9528
[INFO] Epoch 4: Train Acc=73.86%, Test Acc=81.92%, Test F1=0.8170, Test Loss=0.9184
[INFO] Epoch 5: Train Acc=75.50%, Test Acc=82.61%, Test F1=0.8241, Test Loss=0.9016
[INFO] Epoch 6: Train Acc=77.33%, Test Acc=83.34%, Test F1=0.8327, Test Loss=0.8783
[INFO] Epoch 7: Train Acc=78.06%, Test Acc=84.48%, Test F1=0.8440, Test Loss=0.8535
[INFO] Epoch 8: Train Acc=79.05%, Test Acc=84.79%, Test F1=0.8464, Test Loss=0.8483
[INFO] Epoch 9: Train Acc=79.80%, Test Acc=85.28%, Test F1=0.8521, Test Loss=0.8421
[INFO] Epoch 10: Train Acc=80.06%, Test Acc=85.30%, Test F1=0.8518, Test Loss=0.8422
[INFO] Epoch 11: Train Acc=79.74%, Test Acc=85.14%, Test F1=0.8504, Test Loss=0.8416
[INFO] Epoch 12: Train Acc=80.27%, Test Acc=85.13%, Test F1=0.8503, Test L

[I 2025-11-03 04:01:31,468] Trial 0 finished with value: 85.30000448226929 and parameters: {'backbone': 'efficientnet', 'embed_dim': 128, 'num_heads': 4, 'num_layers': 4, 'mlp_ratio': 2.3883077120553216, 'dropout': 0.18282095801199652, 'lr': 0.0006486992336048424, 'batch_size': 256}. Best is trial 0 with value: 85.30000448226929.


[INFO] Epoch 1: Train Acc=51.58%, Test Acc=71.02%, Test F1=0.7054, Test Loss=1.1569
[INFO] Epoch 2: Train Acc=65.72%, Test Acc=76.58%, Test F1=0.7634, Test Loss=1.0304
[INFO] Epoch 3: Train Acc=70.13%, Test Acc=79.40%, Test F1=0.7926, Test Loss=0.9714
[INFO] Epoch 4: Train Acc=72.75%, Test Acc=80.83%, Test F1=0.8062, Test Loss=0.9407
[INFO] Epoch 5: Train Acc=74.66%, Test Acc=81.81%, Test F1=0.8170, Test Loss=0.9149
[INFO] Epoch 6: Train Acc=75.83%, Test Acc=82.79%, Test F1=0.8272, Test Loss=0.8873
[INFO] Epoch 7: Train Acc=77.38%, Test Acc=83.00%, Test F1=0.8283, Test Loss=0.8818
[INFO] Epoch 8: Train Acc=78.00%, Test Acc=83.45%, Test F1=0.8329, Test Loss=0.8684
[INFO] Epoch 9: Train Acc=78.65%, Test Acc=83.94%, Test F1=0.8380, Test Loss=0.8632
[INFO] Epoch 10: Train Acc=78.84%, Test Acc=83.91%, Test F1=0.8377, Test Loss=0.8608
[INFO] Epoch 11: Train Acc=78.93%, Test Acc=83.96%, Test F1=0.8383, Test Loss=0.8604
[INFO] Epoch 12: Train Acc=79.09%, Test Acc=83.88%, Test F1=0.8372, Test L

[I 2025-11-03 04:06:09,902] Trial 1 finished with value: 84.19000506401062 and parameters: {'backbone': 'efficientnet', 'embed_dim': 256, 'num_heads': 12, 'num_layers': 5, 'mlp_ratio': 3.746607213994121, 'dropout': 0.16402878887680009, 'lr': 0.0004365165085324784, 'batch_size': 256}. Best is trial 0 with value: 85.30000448226929.


[INFO] Epoch 15: Train Acc=78.97%, Test Acc=83.93%, Test F1=0.8381, Test Loss=0.8647
[INFO] Epoch 1: Train Acc=49.70%, Test Acc=70.15%, Test F1=0.6954, Test Loss=1.1815
[INFO] Epoch 2: Train Acc=63.38%, Test Acc=75.18%, Test F1=0.7489, Test Loss=1.0640
[INFO] Epoch 3: Train Acc=67.76%, Test Acc=77.50%, Test F1=0.7731, Test Loss=1.0149
[INFO] Epoch 4: Train Acc=70.26%, Test Acc=79.15%, Test F1=0.7896, Test Loss=0.9729
[INFO] Epoch 5: Train Acc=71.83%, Test Acc=80.50%, Test F1=0.8030, Test Loss=0.9455
[INFO] Epoch 6: Train Acc=73.66%, Test Acc=81.67%, Test F1=0.8142, Test Loss=0.9169
[INFO] Epoch 7: Train Acc=74.82%, Test Acc=82.02%, Test F1=0.8193, Test Loss=0.9069
[INFO] Epoch 8: Train Acc=75.34%, Test Acc=82.57%, Test F1=0.8243, Test Loss=0.8963
[INFO] Epoch 9: Train Acc=76.11%, Test Acc=82.85%, Test F1=0.8272, Test Loss=0.8923
[INFO] Epoch 10: Train Acc=76.22%, Test Acc=83.05%, Test F1=0.8291, Test Loss=0.8903
[INFO] Epoch 11: Train Acc=76.49%, Test Acc=82.92%, Test F1=0.8278, Test L

[I 2025-11-03 04:11:26,072] Trial 2 finished with value: 83.17000269889832 and parameters: {'backbone': 'efficientnet', 'embed_dim': 256, 'num_heads': 6, 'num_layers': 5, 'mlp_ratio': 3.8810242636868932, 'dropout': 0.21411406441653213, 'lr': 0.0002110510081711246, 'batch_size': 128}. Best is trial 0 with value: 85.30000448226929.


[INFO] Epoch 15: Train Acc=76.41%, Test Acc=83.08%, Test F1=0.8292, Test Loss=0.8822
[INFO] Optimization Results:
[INFO] Best Accuracy: 85.30%
[INFO] Best Hyperparameters: {'backbone': 'efficientnet', 'embed_dim': 128, 'num_heads': 4, 'num_layers': 4, 'mlp_ratio': 2.3883077120553216, 'dropout': 0.18282095801199652, 'lr': 0.0006486992336048424, 'batch_size': 256}
[INFO] ✓ Saved Optuna Best Summary to: /mnt/vit-hybrid-optuna/optuna_best.json
[INFO] 
[2/2] Training final model with best hyperparameters...
[INFO] 
Training Final Model for 30 Epochs
[INFO] ✓ Saved Best state_dict to: /mnt/vit-hybrid-optuna/best_hybrid_cnn_vit.pth (acc=72.69%)
[INFO] ✓ Saved Best state_dict to: /mnt/vit-hybrid-optuna/best_hybrid_cnn_vit.pth (acc=78.15%)
[INFO] ✓ Saved Best state_dict to: /mnt/vit-hybrid-optuna/best_hybrid_cnn_vit.pth (acc=80.03%)
[INFO] ✓ Saved Best state_dict to: /mnt/vit-hybrid-optuna/best_hybrid_cnn_vit.pth (acc=81.62%)
[INFO] ✓ Saved Best state_dict to: /mnt/vit-hybrid-optuna/best_hybrid

## Ensemble Learning: Weighted Averaging (Logits) - Post-Training Experiments (Optional)

In [None]:
"""
Ensemble Learning: Weighted Averaging (Logits)

Purpose
-------
Build an ensemble classifier by combining logits from multiple models using
weighted averaging. This implementation supports a HybridCNNViT member
(if available) and optional ResNet-18 / ViT-Tiny members. It evaluates
members and the ensemble on CIFAR-10 with torchmetrics and performs a small
grid search to find the best weights by accuracy.

Key behaviors
-------------
- Uses CIFAR-10 test set with standard normalization.
- Tries to include HybridCNNViT from:
  1) a class already defined in the current runtime, or
  2) a module named in the TRAIN_MODULE environment variable (e.g., "train_hybrid").
- Optionally includes ResNet-18 and ViT-Tiny if their checkpoints are present.
- Evaluates metrics: accuracy, precision, recall, F1-macro.
- Saves an ensemble_summary.json artifact with weights, metrics, and members.

Assumptions
-----------
- Torch, Torchvision, Torchmetrics, and timm are installed.
- CIFAR-10 will be downloaded if not found.

Usage
-----
- Ensure HybridCNNViT is defined earlier in the notebook OR set:
      os.environ["TRAIN_MODULE"] = "train_hybrid"
  and make sure train_hybrid.py defines `HybridCNNViT` and is on PYTHONPATH.
- Optionally place checkpoints in MODEL_DIR (default: /mnt/vit-hybrid-optuna):
    - hybrid_cnn_vit_full.pth OR both best_hybrid_cnn_vit.pth + optuna_best.json
    - best_resnet18.pth (optional)
    - best_vit_tiny.pth (optional)
- Run this cell. It evaluates members, searches simple weights, prints results,
  and writes ensemble_summary.json.
"""

import os
import json
from typing import List, Optional
import importlib

import torch
import torch.nn as nn
from torchvision import datasets, transforms
from torch.utils.data import DataLoader

from torchmetrics import MetricCollection
from torchmetrics.classification import (
    MulticlassAccuracy, MulticlassPrecision, MulticlassRecall, MulticlassF1Score
)
import timm

# Minimal logging helpers
def _emit(tag: str, *objects, sep=" ", end="\n", file=None, flush=False):
    """Print ``objects`` as a single line prefixed with ``tag``.

    The objects are stringified and joined with ``sep``. When the joined text
    already starts with one of the known tags it is printed unchanged, so a
    tag supplied by the caller is never doubled. ``end``, ``file`` and
    ``flush`` are forwarded to ``print``.
    """
    text = sep.join(map(str, objects))
    pre_tagged = text.startswith(("[INFO]", "[ERROR]", "[RESULT]", "[WARN]"))
    line = text if pre_tagged else f"{tag} {text}"
    print(line, sep="", end=end, file=file, flush=flush)


def info(*a, **k):
    """Log a message under the ``[INFO]`` tag."""
    _emit("[INFO]", *a, **k)


def warn(*a, **k):
    """Log a message under the ``[WARN]`` tag."""
    _emit("[WARN]", *a, **k)


def error(*a, **k):
    """Log a message under the ``[ERROR]`` tag."""
    _emit("[ERROR]", *a, **k)


def result(*a, **k):
    """Log a message under the ``[RESULT]`` tag."""
    _emit("[RESULT]", *a, **k)

# Artifact directory (overridable through the MODEL_DIR env var) and compute device
MODEL_DIR = os.environ.get("MODEL_DIR", "/mnt/vit-hybrid-optuna")
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
info("Using device:", device)
info("Artifacts directory:", MODEL_DIR)

# TorchMetrics setup: reuse a `metric_collection` left behind by an earlier
# cell; only build a fresh one when the name does not exist yet.
try:
    metric_collection  # noqa: F821
except NameError:
    _metrics_ready = False
else:
    _metrics_ready = True

if not _metrics_ready:
    NUM_CLASSES = 10  # CIFAR-10 by default; adjust for other datasets
    AVG = "macro"
    metric_collection = MetricCollection(dict(
        accuracy=MulticlassAccuracy(num_classes=NUM_CLASSES, average=AVG),
        precision=MulticlassPrecision(num_classes=NUM_CLASSES, average=AVG),
        recall=MulticlassRecall(num_classes=NUM_CLASSES, average=AVG),
        f1_score=MulticlassF1Score(num_classes=NUM_CLASSES, average=AVG),
    )).to(device)
    info("Initialized torchmetrics collection.")

# Utility: valid num_heads
_CHOICES_HEADS = [4, 6, 8, 12]


def _effective_num_heads_local(embed_dim: int, trial_choice: Optional[int]) -> int:
    """Pick a head count from ``_CHOICES_HEADS`` that divides ``embed_dim``.

    Returns ``trial_choice`` when it is one of the candidates dividing
    ``embed_dim``; otherwise the largest such candidate; and ``1`` when no
    candidate divides ``embed_dim``.
    """
    divisors = list(filter(lambda heads: embed_dim % heads == 0, _CHOICES_HEADS))
    if trial_choice in divisors:
        return trial_choice  # type: ignore[return-value]
    # `default=1` covers the case where no candidate divides embed_dim.
    return max(divisors, default=1)

# CIFAR-10 test loader
def _get_testloader(batch_size: int = 128):
    """Build a DataLoader over the CIFAR-10 test split stored under MODEL_DIR/data.

    The dataset is first opened without downloading; if torchvision raises
    ``RuntimeError`` a warning is logged and the open is retried with
    ``download=True``.

    Returns
    -------
    tuple
        ``(loader, classes)`` where ``classes`` is the dataset's ``classes``
        attribute.
    """
    normalize = transforms.Normalize((0.4914, 0.4822, 0.4465),
                                     (0.2023, 0.1994, 0.2010))
    preprocess = transforms.Compose([transforms.ToTensor(), normalize])
    data_root = os.path.join(MODEL_DIR, "data")
    common = dict(root=data_root, train=False, transform=preprocess)

    try:
        test_set = datasets.CIFAR10(download=False, **common)
    except RuntimeError:
        warn("CIFAR-10 test set not found in artifacts volume. Attempting download.")
        test_set = datasets.CIFAR10(download=True, **common)

    loader = DataLoader(
        test_set,
        batch_size=batch_size,
        shuffle=False,
        num_workers=2,
        pin_memory=True,
    )
    return loader, test_set.classes

# Try to locate HybridCNNViT
# Probe for a class defined by an earlier cell WITHOUT rebinding the name first:
# assigning None up front would overwrite that class, make the NameError probe
# unreachable, and leave the TRAIN_MODULE fallback dead.
try:
    HybridCNNViT  # type: ignore # noqa: F821
except NameError:
    HybridCNNViT = None  # not defined yet; may still be imported below

if HybridCNNViT is not None:
    info("HybridCNNViT is already defined in this runtime.")
else:
    # Fall back to importing the class from the module named in TRAIN_MODULE.
    # This branch also runs when a previous execution of this cell left
    # HybridCNNViT bound to None.
    train_module_name = os.environ.get("TRAIN_MODULE", None)
    if train_module_name:
        try:
            train_module = importlib.import_module(train_module_name)
            HybridCNNViT = getattr(train_module, "HybridCNNViT", None)
            if HybridCNNViT is None:
                raise AttributeError("HybridCNNViT not found in the provided training module.")
            info(f"Imported HybridCNNViT from module '{train_module_name}'.")
        except Exception as e:
            # Deliberately best-effort: any failure while importing the user's
            # module only drops the Hybrid member instead of aborting the cell.
            error("Failed to import HybridCNNViT from TRAIN_MODULE. Details:", e)
            HybridCNNViT = None
    else:
        warn("HybridCNNViT is not defined and TRAIN_MODULE is not set. Hybrid member will be skipped.")
        HybridCNNViT = None

# Model builders/loaders
# Checkpoint locations inside MODEL_DIR:
#   FULL_CKPT_PATH  - read by load_hybrid_from_full()
#   BEST_STATE_PATH - state_dict read by load_hybrid_from_best()
#   STUDY_JSON_PATH - JSON holding "best_params", read by load_hybrid_from_best()
FULL_CKPT_PATH = os.path.join(MODEL_DIR, "hybrid_cnn_vit_full.pth")
BEST_STATE_PATH = os.path.join(MODEL_DIR, "best_hybrid_cnn_vit.pth")
STUDY_JSON_PATH = os.path.join(MODEL_DIR, "optuna_best.json")

def load_hybrid_from_full(full_ckpt_path: str) -> nn.Module:
    """Rebuild a HybridCNNViT from a 'full' checkpoint and load its weights.

    The checkpoint is read with ``torch.load`` and must provide the keys
    ``"hyperparameters"``, ``"training_info"`` (containing ``"num_classes"``)
    and ``"model_state_dict"``.

    Returns
    -------
    nn.Module
        The model on ``device``, in eval mode, with the stored weights loaded.

    Raises
    ------
    RuntimeError
        If the HybridCNNViT class is unavailable.
    """
    if HybridCNNViT is None:
        raise RuntimeError("HybridCNNViT class is unavailable; cannot load 'full' checkpoint.")

    checkpoint = torch.load(full_ckpt_path, map_location=device)
    params = checkpoint["hyperparameters"]
    # Resolve the stored head count against embed_dim via the local helper.
    n_heads = _effective_num_heads_local(params["embed_dim"], params.get("num_heads"))

    ctor_kwargs = {
        "backbone": params["backbone"],
        "embed_dim": params["embed_dim"],
        "num_heads": n_heads,
        "num_layers": params["num_layers"],
        "mlp_ratio": params["mlp_ratio"],
        "dropout": params["dropout"],
        "num_classes": checkpoint["training_info"]["num_classes"],
    }
    model = HybridCNNViT(**ctor_kwargs)
    model = model.to(device)
    model = model.eval()
    model.load_state_dict(checkpoint["model_state_dict"])
    return model

def load_hybrid_from_best(best_state_path: str, hp_json: str) -> nn.Module:
    """Rebuild a HybridCNNViT from a bare state_dict plus a hyperparameter JSON.

    Parameters
    ----------
    best_state_path : str
        Path to the state_dict file, read with ``torch.load``.
    hp_json : str
        Path to a JSON file whose ``"best_params"`` entry holds the model
        hyperparameters.

    Returns
    -------
    nn.Module
        The model on ``device``, in eval mode, built with 10 output classes.

    Raises
    ------
    RuntimeError
        If the HybridCNNViT class is unavailable.
    """
    if HybridCNNViT is None:
        raise RuntimeError("HybridCNNViT class is unavailable; cannot load 'best' checkpoint.")

    with open(hp_json, "r") as handle:
        study = json.load(handle)
    params = study["best_params"]
    # Resolve the stored head count against embed_dim via the local helper.
    n_heads = _effective_num_heads_local(params["embed_dim"], params.get("num_heads"))

    ctor_kwargs = {
        "backbone": params["backbone"],
        "embed_dim": params["embed_dim"],
        "num_heads": n_heads,
        "num_layers": params["num_layers"],
        "mlp_ratio": params["mlp_ratio"],
        "dropout": params["dropout"],
        "num_classes": 10,  # adjust if your dataset differs
    }
    model = HybridCNNViT(**ctor_kwargs)
    model = model.to(device)
    model = model.eval()
    state = torch.load(best_state_path, map_location=device)
    model.load_state_dict(state)
    return model

def build_resnet18(num_classes: int = 10) -> nn.Module:
    """Create a torchvision ResNet-18 with a new ``num_classes``-way ``fc`` layer.

    The network is created with ``ResNet18_Weights.DEFAULT``; its final ``fc``
    layer is replaced by a fresh ``nn.Linear``. Returns the model on
    ``device`` in eval mode.
    """
    from torchvision.models import ResNet18_Weights, resnet18

    net = resnet18(weights=ResNet18_Weights.DEFAULT)
    in_dim = net.fc.in_features
    net.fc = nn.Linear(in_dim, num_classes)
    net = net.to(device)
    return net.eval()

def build_vit_tiny(num_classes: int = 10) -> nn.Module:
    """Create a pretrained timm ``vit_tiny_patch16_224`` with a ``num_classes`` head.

    After creation the ``patch_size`` attribute of the patch embedding is set
    to ``(4, 4)`` and a warning is logged. Returns the model on ``device`` in
    eval mode.
    """
    vit = timm.create_model("vit_tiny_patch16_224", pretrained=True, num_classes=num_classes)
    # Optional adaptation for small images if that matches your training pipeline.
    # NOTE(review): this only overwrites an attribute; it does not appear to
    # rebuild the patch projection or position embeddings. Confirm that the
    # resulting model really accepts the image size used here and matches the
    # architecture the checkpoint was trained with.
    vit.patch_embed.patch_size = (4, 4)
    warn("Using ViT-Tiny with patch_size=4 to adapt to small images.")
    vit = vit.to(device)
    return vit.eval()

# Collect ensemble members
models: List[nn.Module] = []

# HybridCNNViT member: needs the class plus either the full checkpoint or the
# (state_dict, hyperparameter JSON) pair. The full checkpoint takes precedence.
if HybridCNNViT is None:
    warn("Skipping HybridCNNViT because the class is unavailable.")
elif os.path.exists(FULL_CKPT_PATH):
    info("Loading HybridCNNViT from FULL_CKPT_PATH:", FULL_CKPT_PATH)
    try:
        models.append(load_hybrid_from_full(FULL_CKPT_PATH))
    except Exception as e:
        # Best effort: a broken checkpoint is logged and the member is skipped.
        error("Failed to load HybridCNNViT from FULL_CKPT_PATH. Details:", e)
elif os.path.exists(BEST_STATE_PATH) and os.path.exists(STUDY_JSON_PATH):
    info("Loading HybridCNNViT from BEST_STATE_PATH:", BEST_STATE_PATH)
    try:
        models.append(load_hybrid_from_best(BEST_STATE_PATH, STUDY_JSON_PATH))
    except Exception as e:
        # Best effort: a broken checkpoint is logged and the member is skipped.
        error("Failed to load HybridCNNViT from BEST_STATE_PATH. Details:", e)
else:
    warn("HybridCNNViT checkpoints not found. Skipping Hybrid member.")

# Optional members (if checkpoints exist)
# Each optional member is added only when its checkpoint file is present in
# MODEL_DIR. Any failure while building the model or loading its weights is
# logged and that member is skipped; the cell keeps running.
RESNET_PTH = os.path.join(MODEL_DIR, "best_resnet18.pth")
VIT_PTH    = os.path.join(MODEL_DIR, "best_vit_tiny.pth")

if os.path.exists(RESNET_PTH):
    info("Loading ResNet-18 member from:", RESNET_PTH)
    try:
        # Build the architecture first, then overwrite its weights from the checkpoint.
        m_res = build_resnet18(num_classes=10)
        m_res.load_state_dict(torch.load(RESNET_PTH, map_location=device))
        models.append(m_res.eval())
    except Exception as e:
        error("Failed to load ResNet-18. Details:", e)

if os.path.exists(VIT_PTH):
    info("Loading ViT-Tiny member from:", VIT_PTH)
    try:
        # Build the architecture first, then overwrite its weights from the checkpoint.
        m_vit = build_vit_tiny(num_classes=10)
        m_vit.load_state_dict(torch.load(VIT_PTH, map_location=device))
        models.append(m_vit.eval())
    except Exception as e:
        error("Failed to load ViT-Tiny. Details:", e)

# Report which members were collected (by class name).
info("Total ensemble members:", len(models))
for i, m in enumerate(models, start=1):
    info(f"  Model {i}: {type(m).__name__}")

# Nothing to ensemble: explain how to make the Hybrid member available, then
# abort the cell with a RuntimeError.
if len(models) == 0:
    error("No models are available for the ensemble.")
    error("To include HybridCNNViT, either define the class earlier in the notebook, or set:")
    error("os.environ['TRAIN_MODULE'] = 'train_hybrid'  # ensure train_hybrid.py defines HybridCNNViT and is on PYTHONPATH")
    raise RuntimeError("No ensemble members found. Provide at least one valid checkpoint or define/import HybridCNNViT.")

# Ensemble wrapper (logits)
class LogitsEnsemble(nn.Module):
    """Weighted average of the logits produced by several models.

    Parameters
    ----------
    models : list[nn.Module]
        Ensemble members; every member is called on the same input.
    weights : list[float] or None
        One weight per model. ``None`` means uniform weights. The weights are
        normalized to sum to 1 and stored in the ``weights`` buffer.

    Raises
    ------
    ValueError
        If ``models`` is empty, if ``weights`` does not contain exactly one
        entry per model, or if the weights sum to zero.
    """

    def __init__(self, models: List[nn.Module], weights: Optional[List[float]] = None):
        super().__init__()
        n_members = len(models)
        if n_members == 0:
            raise ValueError("LogitsEnsemble needs at least one model.")
        if weights is None:
            weights = [1.0 / n_members] * n_members
        elif len(weights) != n_members:
            # forward() pairs weights and models with zip(), which would
            # silently drop the unmatched members instead of failing.
            raise ValueError(
                f"Expected {n_members} weights (one per model), got {len(weights)}."
            )
        total = sum(weights)
        if total == 0:
            raise ValueError("Ensemble weights must not sum to zero.")

        self.models = nn.ModuleList(models)
        # Registered as a buffer so the weights follow the module across devices.
        self.register_buffer(
            "weights",
            torch.tensor([w / total for w in weights], dtype=torch.float32),
        )

    @torch.no_grad()
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return the weighted sum of every member's output for ``x``."""
        logits_sum = None
        for w, m in zip(self.weights, self.models):
            out = m(x)
            logits_sum = w * out if logits_sum is None else logits_sum + w * out
        return logits_sum

# Evaluation with metrics
@torch.no_grad()
def eval_with_metrics(model: nn.Module, loader: DataLoader) -> dict:
    """Evaluate ``model`` on ``loader`` using the shared ``metric_collection``.

    The model is put in eval mode and the metric collection is reset before
    the pass, so each call reports metrics for this loader only.

    Parameters
    ----------
    model : nn.Module
        Model returning logits for a batch.
    loader : DataLoader
        Yields ``(inputs, targets)`` batches.

    Returns
    -------
    dict
        One float per metric in ``metric_collection`` plus ``"loss"``.
        ``"accuracy"`` is scaled to a percentage; ``"loss"`` is the mean
        cross-entropy per sample (0.0 for an empty loader).
    """
    model.eval()
    metric_collection.reset()
    # Sum the per-sample losses and divide by the sample count at the end.
    # Averaging per-batch means instead would over-weight a smaller final batch.
    criterion = nn.CrossEntropyLoss(reduction="sum")
    loss_sum, n_samples = 0.0, 0
    for xb, yb in loader:
        xb, yb = xb.to(device), yb.to(device)
        logits = model(xb)
        loss_sum += criterion(logits, yb).item()
        metric_collection.update(logits, yb)
        n_samples += yb.size(0)
    computed = metric_collection.compute()
    out = {k: v.item() for k, v in computed.items()}
    out["loss"] = loss_sum / max(1, n_samples)
    out["accuracy"] *= 100.0  # percentage
    return out

# Grid search over weights
def simple_weight_search(models: List[nn.Module], step: float = 0.1):
    """Grid-search ensemble weights by accuracy on the test loader.

    Each member is first evaluated on its own and its accuracy logged. For two
    or three members, every weight vector on a ``step``-spaced grid summing to
    1 is evaluated and the first one with the highest accuracy is kept. A
    single member short-circuits to weight ``[1.0]``; any other member count
    falls back to a uniform average.

    Returns
    -------
    tuple
        ``(weights, stats, class_names)`` where ``stats`` is the dict returned
        by ``eval_with_metrics`` for the selected weights.
    """
    info("Loading test data for weight search.")
    testloader, class_names = _get_testloader()
    n_models = len(models)

    info("Evaluating individual models.")
    indiv_stats = []
    for position, member in enumerate(models, start=1):
        member_stats = eval_with_metrics(member, testloader)
        info(f"  Model {position} ({type(member).__name__}) accuracy: {member_stats['accuracy']:.2f}%")
        indiv_stats.append(member_stats)

    if n_models == 1:
        warn("Only one model available. Returning weight [1.0].")
        return [1.0], indiv_stats[0], class_names

    grid = [round(k * step, 2) for k in range(int(1 / step) + 1)]
    info(f"Starting weight search (n_models={n_models}, step={step}).")

    if n_models not in (2, 3):
        warn("Weight search for n > 3 is not implemented. Using simple average.")
        uniform = [1.0 / n_models] * n_models
        averaged = LogitsEnsemble(models, uniform).to(device)
        return uniform, eval_with_metrics(averaged, testloader), class_names

    # Enumerate candidate weight vectors lazily; the last weight is whatever
    # remains after the free ones, rounded to the grid precision.
    if n_models == 2:
        candidates = ([w0, round(1.0 - w0, 2)] for w0 in grid)
    else:
        candidates = (
            [w0, w1, round(1.0 - w0 - w1, 2)]
            for w0 in grid
            for w1 in grid
            if w1 <= round(1.0 - w0, 2)
        )

    best_w, best_stats, best_acc = None, None, -1.0
    for candidate in candidates:
        ensemble = LogitsEnsemble(models, candidate).to(device)
        candidate_stats = eval_with_metrics(ensemble, testloader)
        # Strict comparison keeps the earliest candidate on ties.
        if candidate_stats["accuracy"] > best_acc:
            best_acc = candidate_stats["accuracy"]
            best_w = candidate
            best_stats = candidate_stats

    return best_w, best_stats, class_names

# Run ensemble search & report
# Searches weights on a 0.1-spaced grid; `weights`, `stats` and `class_names`
# stay available at module level for the cells/helpers that follow.
weights, stats, class_names = simple_weight_search(models, step=0.1)

# Accuracy is printed with a percent sign; the remaining metrics are printed
# with four decimals as returned in `stats`.
result("Ensemble weights:", weights)
result(f"Accuracy: {stats['accuracy']:.2f}%")
result(f"F1-Macro: {stats['f1_score']:.4f}")
result(f"Precision: {stats['precision']:.4f}")
result(f"Recall: {stats['recall']:.4f}")
result(f"Loss: {stats['loss']:.4f}")

# Save artifacts
# Writes the selected weights, the metrics dict and the member class names to
# MODEL_DIR/ensemble_summary.json. Saving is best-effort: a failure is logged
# and does not abort the cell.
ENSEMBLE_JSON = os.path.join(MODEL_DIR, "ensemble_summary.json")
try:
    with open(ENSEMBLE_JSON, "w") as f:
        json.dump({
            "weights": weights,
            "metrics": stats,
            "members": [type(m).__name__ for m in models]
        }, f, indent=2)
    result("Ensemble summary saved to:", ENSEMBLE_JSON)
except Exception as e:
    error("Failed to save ensemble summary. Details:", e)

# Optional: single-image API
@torch.no_grad()
def predict_image_ensemble(pil_image, ensemble_weights=None, class_names=class_names):
    """
    Classify one PIL image with the module-level ensemble members.

    Parameters
    ----------
    pil_image : PIL.Image.Image
        Image to classify; it is converted to a tensor and normalized with the
        same constants as the test loader. No resizing is performed.
    ensemble_weights : list[float] or None
        Weights to use instead of the module-level ``weights``. A falsy value
        (``None`` or an empty list) selects the module-level ``weights``.
    class_names : list[str]
        Labels indexed by class id. Defaults to the ``class_names`` bound when
        this function was defined.

    Returns
    -------
    dict
        {"label": str, "index": int, "confidence": float}; ``label`` is the
        stringified index when ``class_names`` is empty.
    """
    normalize = transforms.Normalize((0.4914, 0.4822, 0.4465),
                                     (0.2023, 0.1994, 0.2010))
    preprocess = transforms.Compose([transforms.ToTensor(), normalize])
    batch = preprocess(pil_image).unsqueeze(0).to(device)

    chosen_weights = ensemble_weights if ensemble_weights else weights
    ensemble = LogitsEnsemble(models, chosen_weights).to(device)
    # Softmax over the class dimension, then take the only row of the batch.
    probs = ensemble(batch).softmax(1)[0]

    idx = int(probs.argmax().item())
    conf = float(probs.max().item())
    if class_names:
        label = class_names[idx]
    else:
        label = str(idx)
    return {"label": label, "index": idx, "confidence": conf}

info("Helper 'predict_image_ensemble' is ready.")

[INFO] Using device: cuda
[INFO] Artifacts directory: /mnt/vit-hybrid-optuna
[INFO] HybridCNNViT is already defined in this runtime.
[WARN] Skipping HybridCNNViT because the class is unavailable.
[INFO] Total ensemble members: 0
[ERROR] No models are available for the ensemble.
[ERROR] To include HybridCNNViT, either define the class earlier in the notebook, or set:
[ERROR] os.environ['TRAIN_MODULE'] = 'train_hybrid'  # ensure train_hybrid.py defines HybridCNNViT and is on PYTHONPATH


RuntimeError: No ensemble members found. Provide at least one valid checkpoint or define/import HybridCNNViT.