In [4]:
import sys
from pathlib import Path

# Ensure project root is in path so the project-local packages imported below
# (transformer, preprocess, cnn) can be resolved. This must run before those
# imports. Path.cwd().parent treats the parent of the current working
# directory as the project root.
# NOTE(review): assumes the notebook is launched from a direct subdirectory of
# the project root (e.g. a notebooks/ folder) — confirm; from anywhere else
# the project imports below will fail with ModuleNotFoundError.
project_root = Path.cwd().parent
if str(project_root) not in sys.path:
    sys.path.insert(0, str(project_root))

import warnings
# NOTE(review): this suppresses *every* warning for the rest of the session,
# including warnings.warn calls from torch/numpy and from project code.
warnings.filterwarnings('ignore')

import torch
import numpy as np
from torch.utils.data import DataLoader

# Project-local modules; resolvable only after the sys.path insertion above.
import transformer.config as config
from preprocess import load_cached_data
from cnn.utils import calculate_metrics
from transformer.model import TransformerClassifier
from transformer.dataset import EmbeddingDataset
from transformer.utils import compute_embeddings, extract_features


In [None]:
def load_ensemble(device):
    """Load the cached top-k transformer fold models as an eval-mode ensemble.

    Reads ``transformer_ensemble_info.pt`` from ``config.CACHE_DIR``. For each
    fold index listed under ``"top_indices"``, it rebuilds a
    ``TransformerClassifier`` and loads the weights from
    ``transformer_model_fold_{idx}.pt``.

    Args:
        device: Passed to ``torch.load(map_location=...)`` and to ``.to(...)``
            for every model (e.g. ``"cpu"`` or a ``torch.device``).

    Returns:
        A ``(models, info)`` tuple.
        - ``models`` is a list of models in eval mode, in ``top_indices``
          order.
        - ``info`` is the loaded ensemble-info dict.
        - Returns ``(None, None)`` when the ensemble-info file does not exist.
        - Fold files that are missing are skipped and reported on stdout, so
          ``models`` may be shorter than ``top_indices``, or even empty.
    """
    ensemble_path = config.CACHE_DIR / "transformer_ensemble_info.pt"
    if not ensemble_path.exists():
        return None, None

    # NOTE: weights_only=False unpickles arbitrary Python objects. Only load
    # cache files produced by this project, never files from untrusted sources.
    info = torch.load(ensemble_path, map_location=device, weights_only=False)
    top_indices = info["top_indices"]
    input_dim = info["input_dim"]
    num_extra_features = info.get("num_extra_features", 0)
    # Prefer a hidden size recorded alongside the other architecture
    # parameters. 256 (the previously hard-coded value) is the fallback so
    # existing info files keep working.
    hidden_dim = info.get("hidden_dim", 256)

    models = []
    missing_folds = []
    for idx in top_indices:
        fold_path = config.CACHE_DIR / f"transformer_model_fold_{idx}.pt"
        if not fold_path.exists():
            missing_folds.append(idx)
            continue
        checkpoint = torch.load(fold_path, map_location=device, weights_only=False)
        model = TransformerClassifier(
            input_dim=input_dim,
            hidden_dim=hidden_dim,
            dropout=config.DROPOUT,
            num_extra_features=num_extra_features,
        ).to(device)
        model.load_state_dict(checkpoint["model_state_dict"])
        model.eval()
        models.append(model)

    # Report skipped folds with print, not warnings.warn: this notebook
    # installs a global "ignore" warnings filter, which would hide the message.
    if missing_folds:
        print(
            f"Warning: missing transformer fold checkpoints for folds {missing_folds}; "
            f"loaded {len(models)}/{len(top_indices)} models."
        )

    return models, info