# 🧬 StrandWeaver XGBoost Retraining on Colab (GPU)

Retrain all **7 XGBoost models** with **hybrid resampling**, **focal loss**,
and **GPU-accelerated hyperparameter tuning** using XGBoost 2.x
(`device='cuda'`, `tree_method='hist'`).

**v2.1 Improvements (S1-S4):**
- **S1**: SV-dense training data (10-50× higher SV density for balanced classes)
- **S2**: Two-stage SV detection — binary detector (SV vs none) + 4-class subtype
- **S3**: Focal loss objective for SV models (down-weights easy examples)
- **S4**: 8 new SV-specific features (19 → 27 total)

**What this does:**
1. Mounts Google Drive and extracts pre-generated graph CSVs (200 genomes, 1.4 GB)
2. Applies hybrid resampling for class-imbalanced tasks (undersample majority
   to 100k + oversample minorities to median)
3. Runs **Optuna Bayesian hyperparameter sweep** across all 7 models on GPU
4. Retrains with sweep-winning configs + 5-fold cross-validation
5. Saves model weights + metrics back to Google Drive for download

**Hybrid strategy** was benchmarked against 6 alternatives on 1.2M edges from
200 synthetic genomes: **+33% F1-macro** over the previous class-weighting baseline.

**Runtime:** Set to **GPU** via `Runtime → Change runtime type → T4 GPU`

**Prep (local, one-time):**
```bash
cd strandweaver      # dev branch
./scripts/package_training_data.sh   # → graph_csvs.tar.gz (~400 MB)
# Upload graph_csvs.tar.gz to Google Drive: My Drive/Colab Notebooks/
```

## 1. Setup

In [None]:
# ── Verify GPU ──────────────────────────────────────────────────────────
# Probe for an NVIDIA GPU with nvidia-smi. When the binary is missing
# (CPU-only runtimes), subprocess.run raises instead of returning a non-zero
# exit code, so both outcomes are mapped to GPU_AVAILABLE = False.
import subprocess

_NVIDIA_SMI_CMD = ['nvidia-smi', '--query-gpu=name,memory.total', '--format=csv,noheader']
try:
    result = subprocess.run(_NVIDIA_SMI_CMD, capture_output=True, text=True)
except OSError as exc:
    # FileNotFoundError (nvidia-smi not on PATH) is an OSError subclass.
    # Synthesize a failed CompletedProcess so `result` keeps the same type.
    result = subprocess.CompletedProcess(_NVIDIA_SMI_CMD, returncode=127,
                                         stdout='', stderr=str(exc))

# Read by the model-spec cell to choose device='cuda' vs 'cpu'.
GPU_AVAILABLE = result.returncode == 0
if GPU_AVAILABLE:
    print(f"\u2713 GPU detected: {result.stdout.strip()}")
else:
    print("\u26a0 No GPU detected \u2014 will use CPU (slower)")

In [None]:
# â”€â”€ Install dependencies â”€â”€â”€â”€â”€â”€â”€â”€â”€â”€â”€â”€â”€â”€â”€â”€â”€â”€â”€â”€â”€â”€â”€â”€â”€â”€â”€â”€â”€â”€â”€â”€â”€â”€â”€â”€â”€â”€â”€â”€â”€â”€â”€â”€
!pip install -q xgboost scikit-learn numpy pandas optuna
print("\nâœ“ Dependencies installed (including Optuna for Bayesian HP search)")

## 2. Load Training Data from Google Drive

In [None]:
# ── Mount Google Drive & extract training data ─────────────────────────
from google.colab import drive
import os, tarfile, glob, shutil

drive.mount('/content/drive')

# ── Paths ──────────────────────────────────────────────────────────────
# The input tarball and the final results archive live on Drive. Training
# output is written to the VM's local disk (OUTPUT_DIR) first.
GDRIVE_DIR = '/content/drive/MyDrive/Colab Notebooks'
TARBALL = os.path.join(GDRIVE_DIR, 'graph_csvs.tar.gz')
OUTPUT_DIR = '/content/trained_models_v2'
GDRIVE_OUTPUT = os.path.join(GDRIVE_DIR, 'trained_models_v2.tar.gz')

# Model save subdirectories (model name -> folder). The three SV models
# share 'sv_detector'.
SAVE_MAP = {
    'edge_ai':       'edgewarden',
    'path_gnn':      'pathgnn',
    'diploid_ai':    'diploid',
    'ul_routing':    'ul_routing',
    'sv_ai':         'sv_detector',
    'sv_ai_binary':  'sv_detector',
    'sv_ai_subtype': 'sv_detector',
}
# Technology tags. NOTE(review): not used in this cell; presumably they
# match read_technology values consumed by later steps.
TECH_LIST = ['hifi', 'ont_r9', 'ont_r10', 'illumina', 'adna']

os.makedirs(OUTPUT_DIR, exist_ok=True)

# Extract the tarball to local SSD (much faster I/O than Drive).
# Raise instead of `assert` so the check is not stripped under `python -O`.
if not os.path.exists(TARBALL):
    raise FileNotFoundError(f"Tarball not found at {TARBALL}")
print(f"Extracting {TARBALL} ...")
with tarfile.open(TARBALL, 'r:gz') as tar:
    if hasattr(tarfile, 'data_filter'):
        # The 'data' filter (Python 3.12+, backported to recent 3.8-3.11
        # patch releases) refuses absolute paths, '..' traversal and links
        # that point outside /content/.
        tar.extractall('/content/', filter='data')
    else:
        tar.extractall('/content/')

# The archive may unpack to either directory name; prefer the 10x set.
DATA_DIR = '/content/training_data_10x' if os.path.isdir('/content/training_data_10x') else '/content/training_data'

all_csvs = glob.glob(f'{DATA_DIR}/**/*.csv', recursive=True)
print(f"\u2713 Extracted {len(all_csvs)} CSVs to {DATA_DIR}")
print(f"\u2713 Output dir: {OUTPUT_DIR}")
print(f"\u2713 Results will be saved back to: {GDRIVE_OUTPUT}")

## 3. Model Definitions & Training Infrastructure

In [None]:
import numpy as np
import pandas as pd
import xgboost as xgb
import json
import pickle
import time
from pathlib import Path
from collections import Counter
from sklearn.model_selection import train_test_split, StratifiedKFold
from sklearn.preprocessing import StandardScaler, LabelEncoder
from sklearn.metrics import accuracy_score, f1_score, mean_squared_error, r2_score

# ── Schema v2.0 — Feature definitions (must match graph_training_data.py) ──
# Each *_FEATURES list is the ordered input set of one model family.
# load_csvs builds X as combined[features], so list order is X's column order.
# *_PROVENANCE lists name bookkeeping columns that are not model inputs.

# Metadata columns prepended to every CSV row (skipped during training)
METADATA_COLUMNS = [
    'genome_id', 'genome_size', 'chromosome_id', 'read_technology',
    'coverage_depth', 'error_rate', 'ploidy', 'gc_content_global',
    'repeat_density_global', 'heterozygosity_rate', 'random_seed',
    'generator_version', 'schema_version',
]

# Inputs of the edge_ai model: per-read-pair overlap features (suffixes
# _r1/_r2 refer to the two reads of the edge).
EDGE_AI_FEATURES = [
    'overlap_length', 'overlap_identity', 'read1_length', 'read2_length',
    'coverage_r1', 'coverage_r2', 'gc_content_r1', 'gc_content_r2',
    'repeat_fraction_r1', 'repeat_fraction_r2',
    'kmer_diversity_r1', 'kmer_diversity_r2',
    'branching_factor_r1', 'branching_factor_r2',
    'hic_support', 'mapping_quality_r1', 'mapping_quality_r2',
    # v2.0: graph topology
    'clustering_coeff_r1', 'clustering_coeff_r2', 'component_size',
    # v2.0: sequence complexity
    'entropy_r1', 'entropy_r2', 'homopolymer_max_r1', 'homopolymer_max_r2',
]

# Non-feature bookkeeping columns of the edge_ai CSVs.
EDGE_AI_PROVENANCE = [
    'node_id_r1', 'node_id_r2',
    'read1_haplotype', 'read2_haplotype',
    'genomic_distance', 'is_repeat_region',
]

# Inputs of the path_gnn model (binary in_correct_path target).
PATH_GNN_FEATURES = [
    'overlap_length', 'overlap_identity', 'coverage_consistency',
    'gc_similarity', 'repeat_match', 'branching_score',
    'path_support', 'hic_contact', 'mapping_quality',
    'kmer_match', 'sequence_complexity', 'orientation_score',
    'distance_score', 'topology_score', 'ul_support', 'sv_evidence',
]

# Non-feature bookkeeping columns of the path_gnn CSVs (same names as the
# edge_ai provenance columns).
PATH_GNN_PROVENANCE = [
    'node_id_r1', 'node_id_r2',
    'read1_haplotype', 'read2_haplotype',
    'genomic_distance', 'is_repeat_region',
]

# Per-node signal features; the inputs of the diploid_ai (phasing) model.
NODE_SIGNAL_FEATURES = [
    'coverage', 'gc_content', 'repeat_fraction', 'kmer_diversity',
    'branching_factor', 'hic_contact_density', 'allele_frequency',
    'heterozygosity', 'phase_consistency', 'mappability',
    'hic_intra_contacts', 'hic_inter_contacts',
    'hic_contact_ratio', 'hic_phase_signal',
    # v2.0: graph topology
    'clustering_coeff', 'component_size',
    # v2.0: sequence complexity
    'shannon_entropy', 'dinucleotide_bias',
    'homopolymer_max_run', 'homopolymer_density', 'low_complexity_fraction',
    # v2.0: coverage distribution
    'coverage_skewness', 'coverage_kurtosis', 'coverage_cv',
    'coverage_p10', 'coverage_p90',
]

# Non-feature bookkeeping columns of per-node CSVs ('read_technology' also
# appears in METADATA_COLUMNS; the set union below de-duplicates it).
NODE_PROVENANCE = [
    'node_id', 'read_haplotype', 'read_start_pos', 'read_end_pos',
    'read_length', 'is_in_repeat', 'read_technology',
]

# Inputs of the ul_routing regression model.
UL_ROUTE_FEATURES = [
    'path_length', 'num_branches', 'coverage_mean', 'coverage_std',
    'sequence_identity', 'mapping_quality', 'num_gaps', 'gap_size_mean',
    'kmer_consistency', 'orientation_consistency', 'ul_span', 'route_complexity',
]

# Inputs shared by all three SV models (sv_ai, sv_ai_binary, sv_ai_subtype).
SV_DETECT_FEATURES = [
    'coverage_mean', 'coverage_std', 'coverage_median',
    'gc_content', 'repeat_fraction', 'kmer_diversity',
    'branching_complexity', 'hic_disruption_score',
    'ul_support', 'mapping_quality',
    'region_length', 'breakpoint_precision',
    'allele_balance', 'phase_switch_rate',
    # v2.0: coverage distribution
    'coverage_cv', 'coverage_skewness', 'coverage_kurtosis',
    'coverage_p10', 'coverage_p90',
    # v2.1: SV-specific features (S4 improvement)
    'depth_ratio_flank', 'split_read_count', 'clip_fraction',
    'bubble_size', 'path_divergence', 'ul_spanning',
    'coverage_drop_magnitude', 'orientation_switch_rate',
]

# Columns to skip when loading CSVs (metadata + provenance)
# NOTE(review): load_csvs selects `features` explicitly and does not consult
# this set; confirm whether it is still used further down the notebook.
_NON_FEATURE_COLUMNS = set(METADATA_COLUMNS) | set(EDGE_AI_PROVENANCE) | set(
    PATH_GNN_PROVENANCE) | set(NODE_PROVENANCE)

# ── Model specifications ───────────────────────────────────────────────
# GPU-tuned XGBoost parameters (XGBoost 2.x API):
#   tree_method='hist'  — unified histogram method (gpu_hist is deprecated)
#   device='cuda'       — run on GPU (falls back to 'cpu' if no GPU)
#   max_bin=1024        — finer splits, uses more GPU memory (default 256)
# GPU_AVAILABLE comes from the GPU-probe cell at the top of the notebook.
DEVICE = 'cuda' if GPU_AVAILABLE else 'cpu'
# Spread (**GPU_PARAMS) into every model's xgb_params in MODEL_SPECS.
GPU_PARAMS = {
    'tree_method': 'hist',
    'device': DEVICE,
    'max_bin': 1024,
}
print(f"Device: {DEVICE}, max_bin: {GPU_PARAMS['max_bin']}")

# Per-model training specification. Keys:
#   csv_glob        - glob (relative to the data dir) for this model's CSVs
#   features        - ordered feature columns (column order of X)
#   label_col       - CSV column holding the target
#   task            - 'binary' | 'multiclass' | 'regression'
#   label_transform - optional callable applied to str(label) after loading
#   label_filter    - optional predicate on str(label); False rows are dropped
#   focal_loss, focal_alpha, focal_gamma - opt-in focal-loss settings
#                     (presumably forwarded as train_model's focal_loss_config)
#   xgb_params      - starting XGBoost parameters (GPU_PARAMS merged in)
#   desc            - human-readable description printed below
# The three sv_ai* entries read the same CSVs: sv_ai is the single-stage
# 5-class model, and sv_ai_binary + sv_ai_subtype form the two-stage pipeline.
MODEL_SPECS = {
    'edge_ai': {
        'csv_glob': '**/edge_ai_training_g*.csv',
        'features': EDGE_AI_FEATURES,
        'label_col': 'label',
        'task': 'multiclass',
        'xgb_params': {
            'max_depth': 6, 'learning_rate': 0.1, 'n_estimators': 100,
            **GPU_PARAMS,
        },
        'desc': 'Edge scoring (TRUE/ALLELIC/CHIMERIC/SV_BREAK/REPEAT)',
    },
    'path_gnn': {
        'csv_glob': '**/path_gnn_training_g*.csv',
        'features': PATH_GNN_FEATURES,
        'label_col': 'in_correct_path',
        'task': 'binary',
        'xgb_params': {
            'max_depth': 6, 'learning_rate': 0.1, 'n_estimators': 100,
            **GPU_PARAMS,
        },
        'desc': 'Path edge scoring (binary)',
    },
    'diploid_ai': {
        'csv_glob': '**/diploid_ai_training_g*.csv',
        'features': NODE_SIGNAL_FEATURES,
        'label_col': 'haplotype_label',
        'task': 'multiclass',
        # 'HAP_A' -> 'A', 'HAP_B' -> 'B'
        'label_transform': lambda lbl: lbl.replace('HAP_', ''),
        'xgb_params': {
            'max_depth': 10, 'learning_rate': 0.03, 'n_estimators': 500,
            'subsample': 0.8, 'colsample_bytree': 0.8,
            'min_child_weight': 5, 'gamma': 0.1,
            'reg_alpha': 0.1, 'reg_lambda': 1.0,
            **GPU_PARAMS,
        },
        'desc': 'Haplotype phasing (A/B) \u2014 v2.0 topology+complexity features',
    },
    'ul_routing': {
        'csv_glob': '**/ul_route_training_g*.csv',
        'features': UL_ROUTE_FEATURES,
        'label_col': 'route_score',
        'task': 'regression',
        'xgb_params': {
            'max_depth': 6, 'learning_rate': 0.1, 'n_estimators': 100,
            **GPU_PARAMS,
        },
        'desc': 'Ultra-long routing score (regression, continuous 0-1)',
    },
    'sv_ai': {
        'csv_glob': '**/sv_detect_training_g*.csv',
        'features': SV_DETECT_FEATURES,
        'label_col': 'sv_type',
        'task': 'multiclass',
        'focal_loss': True,
        'focal_alpha': 0.25,
        'focal_gamma': 2.0,
        'xgb_params': {
            'max_depth': 8, 'learning_rate': 0.05, 'n_estimators': 300,
            'subsample': 0.8, 'colsample_bytree': 0.8,
            'min_child_weight': 3,
            **GPU_PARAMS,
        },
        'desc': 'SV detection (5-class, focal loss) \u2014 v2.1 27 features',
    },
    'sv_ai_binary': {
        'csv_glob': '**/sv_detect_training_g*.csv',
        'features': SV_DETECT_FEATURES,
        'label_col': 'sv_type',
        'task': 'binary',
        # Stage 1: collapse every SV subtype into the positive class.
        'label_transform': lambda lbl: 0 if lbl == 'none' else 1,
        'xgb_params': {
            'max_depth': 6, 'learning_rate': 0.1, 'n_estimators': 200,
            'subsample': 0.8, 'colsample_bytree': 0.8,
            **GPU_PARAMS,
        },
        'desc': 'SV binary detector (SV vs no-SV, stage 1 of two-stage)',
    },
    'sv_ai_subtype': {
        'csv_glob': '**/sv_detect_training_g*.csv',
        'features': SV_DETECT_FEATURES,
        'label_col': 'sv_type',
        'task': 'multiclass',
        # Stage 2: train only on rows that carry an SV.
        'label_filter': lambda lbl: lbl != 'none',
        'focal_loss': True,
        'focal_alpha': 0.25,
        'focal_gamma': 2.0,
        'xgb_params': {
            'max_depth': 8, 'learning_rate': 0.05, 'n_estimators': 300,
            'subsample': 0.8, 'colsample_bytree': 0.8,
            'min_child_weight': 3,
            **GPU_PARAMS,
        },
        'desc': 'SV subtype classifier (4-class, focal loss, stage 2 of two-stage)',
    },
}

# One summary line per model: description, feature count and option flags.
print(f"\nSchema v2.1 \u2014 Defined {len(MODEL_SPECS)} models")
for name, spec in MODEL_SPECS.items():
    fl = ' [focal]' if spec.get('focal_loss') else ''
    lf = ' [filtered]' if spec.get('label_filter') else ''
    print(f"  {name}: {spec['desc']} ({len(spec['features'])} features{fl}{lf})")

In [None]:
# ── Resampling & Training utilities ────────────────────────────────────

def undersample(X, y, max_per_class=100_000, rng=None):
    """Randomly downsample every class that has more than ``max_per_class`` rows.

    Classes at or below the cap are kept whole. The surviving row indices are
    shuffled, so the output order is randomised even when nothing is dropped.

    Args:
        X: array indexable by an integer index array (``X[idx]``), row-aligned
            with ``y``.
        y: 1-D label array.
        max_per_class: upper bound on rows kept per class.
        rng: ``numpy.random.Generator``; defaults to ``default_rng(42)``.

    Returns:
        ``(X_sub, y_sub)`` with at most ``max_per_class`` rows per class.
        Empty input yields empty output instead of raising.
    """
    if rng is None:
        rng = np.random.default_rng(42)
    classes, counts = np.unique(y, return_counts=True)
    keep = []
    for cls, cnt in zip(classes, counts):
        idx = np.flatnonzero(y == cls)
        if cnt > max_per_class:
            idx = rng.choice(idx, max_per_class, replace=False)
        keep.append(idx)
    # np.concatenate([]) raises ValueError, so empty y needs an explicit index.
    keep = np.concatenate(keep) if keep else np.empty(0, dtype=np.intp)
    rng.shuffle(keep)
    return X[keep], y[keep]


def oversample(X, y, target_count=None, rng=None):
    """Random-oversample (with replacement) every class below ``target_count``.

    The original rows come first, in their original order. The duplicated
    minority rows are appended after them, class by class in sorted class
    order.

    Args:
        X: array indexable by an integer index array, row-aligned with ``y``.
        y: 1-D label array.
        target_count: size each smaller class is topped up to; defaults to
            the median class size. Classes at or above it are untouched.
        rng: ``numpy.random.Generator``; defaults to ``default_rng(42)``.

    Returns:
        ``(X_over, y_over)`` as newly concatenated arrays. Empty input yields
        empty output instead of raising.
    """
    if rng is None:
        rng = np.random.default_rng(42)
    classes, counts = np.unique(y, return_counts=True)
    if target_count is None:
        # The median of an empty array is NaN and int(NaN) raises. With no
        # classes there is nothing to top up, so any target will do.
        target_count = int(np.median(counts)) if counts.size else 0
    parts_X, parts_y = [X], [y]
    for cls, cnt in zip(classes, counts):
        if cnt >= target_count:
            continue
        idx = np.flatnonzero(y == cls)
        sampled = rng.choice(idx, target_count - cnt, replace=True)
        parts_X.append(X[sampled])
        parts_y.append(y[sampled])
    return np.concatenate(parts_X), np.concatenate(parts_y)


def hybrid_resample(X, y, max_majority=100_000, rng=None):
    """Cap every class at ``max_majority`` rows, then top up the minority classes.

    Stage 1 (``undersample``) randomly drops rows from any class above the
    cap. Stage 2 (``oversample``) duplicates rows of every class smaller than
    the median class size *of the capped data* until it reaches that median.

    The same ``rng`` is passed to both stages. When it is None, each stage
    creates its own ``default_rng(42)``, so both start from identical random
    streams.

    Benchmarked against 6 alternatives on 1.2 M edges from 200 synthetic
    genomes: +33% F1-macro over class-weighting baseline.
    """
    capped_X, capped_y = undersample(X, y, max_per_class=max_majority, rng=rng)
    return oversample(capped_X, capped_y, rng=rng)


# â”€â”€ Focal Loss for class-imbalanced classification â”€â”€â”€â”€â”€â”€â”€â”€â”€â”€â”€â”€â”€â”€â”€
def _focal_loss_objective(alpha=0.25, gamma=2.0, n_classes=5):
    """Return (objective_fn, eval_metric_fn) for XGBoost custom focal loss.

    Focal loss down-weights easy/well-classified examples so the model
    focuses on hard minority cases like inversions and duplications.

    References:
        Lin et al., "Focal Loss for Dense Object Detection", ICCV 2017
    """
    def focal_obj(predt: np.ndarray, dtrain: xgb.DMatrix):
        labels = dtrain.get_label().astype(int)
        n = len(labels)
        # predt shape: (n * n_classes,) â€” reshape to (n, n_classes)
        preds = predt.reshape(n, n_classes)
        # Softmax
        preds = preds - preds.max(axis=1, keepdims=True)
        exp_p = np.exp(preds)
        softmax = exp_p / exp_p.sum(axis=1, keepdims=True)
        softmax = np.clip(softmax, 1e-7, 1.0 - 1e-7)

        # One-hot encode labels
        one_hot = np.zeros_like(softmax)
        one_hot[np.arange(n), labels] = 1.0

        # p_t = probability of true class
        p_t = (softmax * one_hot).sum(axis=1, keepdims=True)

        # Focal weight: alpha * (1 - p_t)^gamma
        focal_weight = alpha * (1.0 - p_t) ** gamma

        # Gradient and hessian of focal cross-entropy
        grad = focal_weight * (softmax - one_hot)
        hess = focal_weight * softmax * (1.0 - softmax)
        # Ensure positive hessian
        hess = np.maximum(hess, 1e-7)

        return grad.reshape(-1), hess.reshape(-1)

    def focal_eval(predt: np.ndarray, dtrain: xgb.DMatrix):
        labels = dtrain.get_label().astype(int)
        n = len(labels)
        preds = predt.reshape(n, n_classes)
        preds = preds - preds.max(axis=1, keepdims=True)
        exp_p = np.exp(preds)
        softmax = exp_p / exp_p.sum(axis=1, keepdims=True)
        softmax = np.clip(softmax, 1e-7, 1.0 - 1e-7)

        p_t = softmax[np.arange(n), labels]
        loss = -alpha * ((1.0 - p_t) ** gamma) * np.log(p_t)
        return 'focal_loss', float(np.mean(loss))

    return focal_obj, focal_eval


def load_csvs(data_dir, csv_glob, features, label_col,
              label_transform=None, label_filter=None):
    """Load and concatenate all matching CSVs into ``(X, labels, label_dist)``.

    Args:
        data_dir: root directory. Glob metacharacters in it are escaped, so a
            path such as ``run[1]`` is matched literally.
        csv_glob: pattern relative to ``data_dir`` (``**`` recurses).
        features: feature column names; their order is the column order of X.
        label_col: name of the label column.
        label_transform: optional callable applied to ``str(label)`` of every
            kept row (e.g. ``'HAP_A'`` -> ``'A'``).
        label_filter: optional predicate on ``str(label)``. Rows for which it
            is falsy are dropped (v2.1, e.g. removing the ``'none'`` class for
            subtype-only training).

    Returns:
        ``(X, labels, label_dist)``: ``X`` is float32 of shape
        ``(n_rows, len(features))``, ``labels`` is a 1-D array and
        ``label_dist`` is a ``Counter`` of labels. Returns
        ``(None, None, None)`` when no file matches or no file yields rows.

    Files that cannot be read or lack a required column are reported and
    skipped (best-effort), so one bad genome does not abort a whole run.
    """
    required = list(features) + [label_col]
    # Parse only the columns we use (plus schema_version for the printout).
    # The CSVs also carry metadata/provenance columns that would otherwise be
    # parsed and held in memory for every file.
    wanted = set(required) | {'schema_version'}

    pattern = f'{glob.escape(str(data_dir))}/{csv_glob}'
    csv_files = sorted(glob.glob(pattern, recursive=True))
    if not csv_files:
        return None, None, None

    dfs = []
    for f in csv_files:
        name = Path(f).name
        try:
            # A callable usecols never errors on absent columns, so the
            # explicit missing-column report below still works.
            df = pd.read_csv(f, usecols=lambda col: col in wanted)
            missing = [c for c in required if c not in df.columns]
            if missing:
                print(f"  \u26a0 Skipping {name}: missing columns {missing[:5]}")
                continue
            if label_filter is not None:
                mask = df[label_col].map(lambda v: label_filter(str(v))).astype(bool)
                df = df[mask]
                if df.empty:
                    continue
            dfs.append(df)
        except Exception as e:
            print(f"  \u26a0 Error reading {name}: {e}")

    if not dfs:
        return None, None, None

    combined = pd.concat(dfs, ignore_index=True)
    if 'schema_version' in combined.columns:
        versions = combined['schema_version'].unique()
        print(f"  Schema version(s): {list(versions)}")

    X = combined[list(features)].values.astype(np.float32)
    labels = combined[label_col].values
    if label_transform:
        labels = np.array([label_transform(str(l)) for l in labels])
    return X, labels, Counter(labels)


def train_model(X, y, xgb_params, task, val_split=0.15, seed=42,
                focal_loss_config=None):
    """Train one XGBoost model on a held-out split; return ``(model, metrics)``.

    Args:
        X: feature matrix.
        y: targets. For classification these must be numeric class ids
            (the focal path casts ``get_label()`` to int, and XGBoost's
            sklearn API expects 0..k-1). NOTE(review): assumed to be
            label-encoded by the caller.
        xgb_params: base XGBoost parameters (may include ``n_estimators``).
        task: 'binary', 'multiclass' or 'regression'. A classification target
            with exactly two distinct labels is treated as binary.
        val_split: fraction held out for validation (stratified for
            classification when more than one class is present).
        seed: seed for the split and for XGBoost.
        focal_loss_config: optional dict with 'alpha', 'gamma', 'n_classes'.
            Enables the custom focal-loss objective for non-binary targets;
            ignored for binary targets.

    Returns:
        ``(model, metrics)``. ``model`` is an ``XGBRegressor`` or
        ``XGBClassifier``, or a raw ``xgb.Booster`` on the focal path.
        ``metrics`` holds validation RMSE/R2 (regression) or accuracy,
        weighted/macro F1 and per-class F1 (classification), plus train/val
        sizes. The ``focal_*`` keys are present only when the focal objective
        was actually used.

    All paths early-stop after 10 rounds without validation improvement. The
    focal path scores only trees up to ``best_iteration``, which the sklearn
    wrappers do automatically.
    """
    n_distinct = len(set(y.tolist()))
    is_binary = task == 'binary' or n_distinct == 2

    if task == 'regression':
        X_tr, X_va, y_tr, y_va = train_test_split(
            X, y, test_size=val_split, random_state=seed)
        params = dict(xgb_params, random_state=seed, verbosity=0,
                      early_stopping_rounds=10, objective='reg:squarederror',
                      eval_metric='rmse')
        model = xgb.XGBRegressor(**params)
        model.fit(X_tr, y_tr, eval_set=[(X_va, y_va)], verbose=False)
        y_pred = model.predict(X_va)
        rmse = float(np.sqrt(mean_squared_error(y_va, y_pred)))
        r2 = float(r2_score(y_va, y_pred))
        return model, {'val_rmse': round(rmse, 4), 'val_r2': round(r2, 4),
                       'train_size': len(X_tr), 'val_size': len(X_va)}

    # Classification (no sample_weight; class balance comes from resampling).
    stratify = y if n_distinct > 1 else None
    X_tr, X_va, y_tr, y_va = train_test_split(
        X, y, test_size=val_split, random_state=seed, stratify=stratify)

    # `use_label_encoder` was removed in XGBoost 2.x, so it is not passed.
    params = dict(xgb_params, random_state=seed, verbosity=0,
                  early_stopping_rounds=10)
    use_focal = bool(focal_loss_config) and not is_binary

    if use_focal:
        # ── Focal loss path ────────────────────────────────────────────
        n_classes = focal_loss_config.get('n_classes', n_distinct)
        alpha = focal_loss_config.get('alpha', 0.25)
        gamma = focal_loss_config.get('gamma', 2.0)
        focal_obj, focal_eval = _focal_loss_objective(alpha, gamma, n_classes)

        # xgb.train takes these as call arguments, not booster params.
        num_boost_round = params.pop('n_estimators', 300)
        early_stopping_rounds = params.pop('early_stopping_rounds', 10)
        # The custom objective and metric replace any built-in ones.
        params.pop('objective', None)
        params.pop('eval_metric', None)
        params['disable_default_eval_metric'] = True
        params['num_class'] = n_classes

        dtrain = xgb.DMatrix(X_tr, label=y_tr)
        dval = xgb.DMatrix(X_va, label=y_va)
        bst = xgb.train(
            params, dtrain,
            num_boost_round=num_boost_round,
            obj=focal_obj,
            custom_metric=focal_eval,
            evals=[(dval, 'val')],
            early_stopping_rounds=early_stopping_rounds,
            verbose_eval=False,
        )
        # Booster.predict uses *all* trees by default, including the rounds
        # trained after the best validation score. Restrict to the best
        # iteration so the metrics describe the early-stopped model.
        best_iter = getattr(bst, 'best_iteration', None)
        iteration_range = (0, best_iter + 1) if best_iter is not None else (0, 0)
        raw_preds = bst.predict(dval, iteration_range=iteration_range)
        y_pred = np.argmax(raw_preds.reshape(-1, n_classes), axis=1)
        model = bst  # raw Booster (not XGBClassifier) on the focal path
    else:
        # ── Standard classification path ───────────────────────────────
        if is_binary:
            params.update({'objective': 'binary:logistic', 'eval_metric': 'logloss'})
        else:
            params.update({'objective': 'multi:softprob', 'eval_metric': 'mlogloss',
                           'num_class': n_distinct})

        model = xgb.XGBClassifier(**params)
        model.fit(X_tr, y_tr, eval_set=[(X_va, y_va)], verbose=False)
        y_pred = model.predict(X_va)

    acc = float(accuracy_score(y_va, y_pred))
    f1_w = float(f1_score(y_va, y_pred, average='weighted', zero_division=0))
    f1_m = float(f1_score(y_va, y_pred, average='macro', zero_division=0))

    # Per-class F1 over the classes present in the validation split.
    unique_classes = sorted(set(y_va.tolist()))
    f1_per = f1_score(y_va, y_pred, labels=unique_classes, average=None, zero_division=0)
    per_class_f1 = {str(c): round(float(f), 4) for c, f in zip(unique_classes, f1_per)}

    metrics = {
        'val_accuracy': round(acc, 4),
        'val_f1_weighted': round(f1_w, 4),
        'val_f1_macro': round(f1_m, 4),
        'per_class_f1': per_class_f1,
        'train_size': len(X_tr), 'val_size': len(X_va),
    }
    # Record focal settings only when the focal objective actually trained
    # the model; binary targets fall through to the standard path.
    if use_focal:
        metrics['focal_loss'] = True
        metrics['focal_alpha'] = focal_loss_config.get('alpha', 0.25)
        metrics['focal_gamma'] = focal_loss_config.get('gamma', 2.0)

    return model, metrics


def cv_model(X, y, xgb_params, task, n_folds=5, seed=42):
    """Run manual k-fold cross-validation and summarise the fold scores.

    No sample_weight is used, because resampling has already been applied.
    Regression uses shuffled ``KFold`` and scores R2. Classification uses
    shuffled ``StratifiedKFold`` and scores accuracy. Each fold early-stops
    after 10 rounds, monitoring its own held-out fold. The focal-loss
    objective is never used here, so CV scores for focal models reflect the
    standard objective.

    Returns:
        dict with ``cv_r2_mean``/``cv_r2_std`` (regression) or
        ``cv_accuracy_mean``/``cv_accuracy_std`` (classification), plus
        ``cv_fold_scores`` rounded to 4 decimals.
    """
    if task == 'regression':
        from sklearn.model_selection import KFold
        splitter = KFold(n_splits=n_folds, shuffle=True, random_state=seed)
        # Identical for every fold, so built once.
        fold_params = dict(xgb_params, random_state=seed, verbosity=0,
                           early_stopping_rounds=10, objective='reg:squarederror',
                           eval_metric='rmse')
        scores = []
        for tr, va in splitter.split(X):
            reg = xgb.XGBRegressor(**fold_params)
            reg.fit(X[tr], y[tr], eval_set=[(X[va], y[va])], verbose=False)
            scores.append(r2_score(y[va], reg.predict(X[va])))
        return {'cv_r2_mean': round(np.mean(scores), 4),
                'cv_r2_std': round(np.std(scores), 4),
                'cv_fold_scores': [round(s, 4) for s in scores]}

    n_distinct = len(set(y.tolist()))
    fold_params = dict(xgb_params, random_state=seed, use_label_encoder=False,
                       verbosity=0, early_stopping_rounds=10)
    if task == 'binary' or n_distinct == 2:
        fold_params.update(objective='binary:logistic', eval_metric='logloss')
    else:
        fold_params.update(objective='multi:softprob', eval_metric='mlogloss',
                           num_class=n_distinct)

    splitter = StratifiedKFold(n_splits=n_folds, shuffle=True, random_state=seed)
    accs = []
    for tr, va in splitter.split(X, y):
        clf = xgb.XGBClassifier(**fold_params)
        clf.fit(X[tr], y[tr], eval_set=[(X[va], y[va])], verbose=False)
        accs.append(accuracy_score(y[va], clf.predict(X[va])))

    return {'cv_accuracy_mean': round(np.mean(accs), 4),
            'cv_accuracy_std': round(np.std(accs), 4),
            'cv_fold_scores': [round(s, 4) for s in accs]}


# Confirms the resampling/focal-loss/training utilities in this cell are defined.
print("Training utilities defined \u2713 (hybrid resampling + focal loss + per-class F1)")

## 7. Hyperparameter Sweep (Optuna Bayesian Optimization, GPU-Accelerated)

Uses **Optuna TPE** (Tree-structured Parzen Estimator) instead of brute-force
grid search. This explores the hyperparameter space ~10× more efficiently by
learning which regions are promising and pruning bad trials early.

| Model | Trials | Focal Loss | Notes |
|-------|--------|------------|-------|
| edge_ai | 80 | — | 5-class edge scoring |
| path_gnn | 60 | — | Binary path edge scoring |
| diploid_ai | 150 | — | Haplotype phasing (biggest search) |
| ul_routing | 60 | — | Regression (routing score) |
| sv_ai | 100 | ✓ | 5-class SV detection, 27 features |
| sv_ai_binary | 80 | — | Binary SV detector (stage 1) |
| sv_ai_subtype | 100 | ✓ | 4-class SV subtype (stage 2) |

**GPU utilization improvements**:
- `device='cuda'` + `tree_method='hist'` — XGBoost 2.x unified GPU API
- `max_bin=1024` (4× default) — builds finer histograms in GPU memory
- Optuna early pruning kills bad trials after 20 trees instead of waiting for 500

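**Resume support**: Optuna studies are backed by a SQLite DB at
`sweep_optuna.db` in the output dir. Re-running this cell automatically
resumes all studies from where they left off — no manual checkpointing needed.
Note that the output dir (`/content/trained_models_v2`) is on the Colab VM's
local disk, so this covers re-runs within one runtime session; after a runtime
reset the DB is gone unless you restore it first.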

In [None]:
import optuna
from optuna.samplers import TPESampler
from optuna.pruners import MedianPruner
from functools import partial

optuna.logging.set_verbosity(optuna.logging.WARNING)

# â”€â”€ Optuna search spaces per model â”€â”€â”€â”€â”€â”€â”€â”€â”€â”€â”€â”€â”€â”€â”€â”€â”€â”€â”€â”€â”€â”€â”€â”€â”€â”€â”€â”€â”€â”€â”€
# Each function defines the hyperparameter ranges for Optuna to explore.
# Ranges are wider than the old grid â€” Optuna's TPE sampler focuses on
# the promising regions automatically.

def edge_ai_objective(trial, X, y, base_params, task, focal_loss_config=None):
    hp = {
        'max_depth': trial.suggest_int('max_depth', 4, 12),
        'learning_rate': trial.suggest_float('learning_rate', 0.01, 0.3, log=True),
        'n_estimators': trial.suggest_int('n_estimators', 100, 800, step=50),
        'subsample': trial.suggest_float('subsample', 0.6, 1.0),
        'colsample_bytree': trial.suggest_float('colsample_bytree', 0.6, 1.0),
        'min_child_weight': trial.suggest_int('min_child_weight', 1, 10),
        'gamma': trial.suggest_float('gamma', 0.0, 0.5),
        'reg_alpha': trial.suggest_float('reg_alpha', 1e-3, 1.0, log=True),
        'reg_lambda': trial.suggest_float('reg_lambda', 0.1, 10.0, log=True),
    }
    params = dict(base_params, **hp)
    _, metrics = train_model(X, y, params, task, focal_loss_config=focal_loss_config)
    return metrics.get('val_f1_macro', 0)


def path_gnn_objective(trial, X, y, base_params, task, focal_loss_config=None):
    hp = {
        'max_depth': trial.suggest_int('max_depth', 3, 10),
        'learning_rate': trial.suggest_float('learning_rate', 0.01, 0.3, log=True),
        'n_estimators': trial.suggest_int('n_estimators', 50, 500, step=50),
        'subsample': trial.suggest_float('subsample', 0.7, 1.0),
        'colsample_bytree': trial.suggest_float('colsample_bytree', 0.7, 1.0),
        'min_child_weight': trial.suggest_int('min_child_weight', 1, 7),
    }
    params = dict(base_params, **hp)
    _, metrics = train_model(X, y, params, task, focal_loss_config=focal_loss_config)
    return metrics.get('val_f1_macro', 0)


def diploid_ai_objective(trial, X, y, base_params, task, focal_loss_config=None):
    hp = {
        'max_depth': trial.suggest_int('max_depth', 4, 14),
        'learning_rate': trial.suggest_float('learning_rate', 0.005, 0.2, log=True),
        'n_estimators': trial.suggest_int('n_estimators', 200, 1200, step=50),
        'subsample': trial.suggest_float('subsample', 0.5, 1.0),
        'colsample_bytree': trial.suggest_float('colsample_bytree', 0.5, 1.0),
        'min_child_weight': trial.suggest_int('min_child_weight', 1, 15),
        'gamma': trial.suggest_float('gamma', 0.0, 1.0),
        'reg_alpha': trial.suggest_float('reg_alpha', 1e-3, 5.0, log=True),
        'reg_lambda': trial.suggest_float('reg_lambda', 0.1, 10.0, log=True),
    }
    params = dict(base_params, **hp)
    _, metrics = train_model(X, y, params, task, focal_loss_config=focal_loss_config)
    return metrics.get('val_f1_macro', 0)


def ul_routing_objective(trial, X, y, base_params, task, focal_loss_config=None):
    hp = {
        'max_depth': trial.suggest_int('max_depth', 3, 10),
        'learning_rate': trial.suggest_float('learning_rate', 0.01, 0.3, log=True),
        'n_estimators': trial.suggest_int('n_estimators', 50, 500, step=50),
        'subsample': trial.suggest_float('subsample', 0.7, 1.0),
        'min_child_weight': trial.suggest_int('min_child_weight', 1, 7),
    }
    params = dict(base_params, **hp)
    _, metrics = train_model(X, y, params, task, focal_loss_config=focal_loss_config)
    return metrics.get('val_r2', 0)


def sv_ai_objective(trial, X, y, base_params, task, focal_loss_config=None):
    hp = {
        'max_depth': trial.suggest_int('max_depth', 4, 14),
        'learning_rate': trial.suggest_float('learning_rate', 0.01, 0.2, log=True),
        'n_estimators': trial.suggest_int('n_estimators', 100, 800, step=50),
        'subsample': trial.suggest_float('subsample', 0.6, 1.0),
        'colsample_bytree': trial.suggest_float('colsample_bytree', 0.6, 1.0),
        'min_child_weight': trial.suggest_int('min_child_weight', 1, 10),
        'gamma': trial.suggest_float('gamma', 0.0, 0.5),
        'reg_alpha': trial.suggest_float('reg_alpha', 1e-3, 1.0, log=True),
        'reg_lambda': trial.suggest_float('reg_lambda', 0.1, 10.0, log=True),
    }
    params = dict(base_params, **hp)
    _, metrics = train_model(X, y, params, task, focal_loss_config=focal_loss_config)
    return metrics.get('val_f1_macro', 0)


def sv_ai_binary_objective(trial, X, y, base_params, task, focal_loss_config=None):
    hp = {
        'max_depth': trial.suggest_int('max_depth', 3, 10),
        'learning_rate': trial.suggest_float('learning_rate', 0.01, 0.3, log=True),
        'n_estimators': trial.suggest_int('n_estimators', 50, 500, step=50),
        'subsample': trial.suggest_float('subsample', 0.6, 1.0),
        'colsample_bytree': trial.suggest_float('colsample_bytree', 0.6, 1.0),
        'min_child_weight': trial.suggest_int('min_child_weight', 1, 7),
        'scale_pos_weight': trial.suggest_float('scale_pos_weight', 1.0, 20.0),
    }
    params = dict(base_params, **hp)
    _, metrics = train_model(X, y, params, task, focal_loss_config=focal_loss_config)
    return metrics.get('val_f1_macro', 0)


def sv_ai_subtype_objective(trial, X, y, base_params, task, focal_loss_config=None):
    """Optuna objective for the ``sv_ai_subtype`` classifier (stage 2 of S2).

    Uses the same search space as ``sv_ai_objective``. It is kept as a
    separate function so the two stages can be tuned independently later.
    The sampled values are overlaid on ``base_params`` and the model is
    trained via ``train_model`` (defined elsewhere in this notebook).

    Returns:
        The ``'val_f1_macro'`` entry of the metrics returned by
        ``train_model``, or 0 when that key is absent.
    """
    suggest_int = trial.suggest_int
    suggest_float = trial.suggest_float
    search_space = {
        'max_depth': suggest_int('max_depth', 4, 14),
        'learning_rate': suggest_float('learning_rate', 0.01, 0.2, log=True),
        'n_estimators': suggest_int('n_estimators', 100, 800, step=50),
        'subsample': suggest_float('subsample', 0.6, 1.0),
        'colsample_bytree': suggest_float('colsample_bytree', 0.6, 1.0),
        'min_child_weight': suggest_int('min_child_weight', 1, 10),
        'gamma': suggest_float('gamma', 0.0, 0.5),
        'reg_alpha': suggest_float('reg_alpha', 1e-3, 1.0, log=True),
        'reg_lambda': suggest_float('reg_lambda', 0.1, 10.0, log=True),
    }
    merged = {**base_params, **search_space}
    _model, scores = train_model(X, y, merged, task,
                                 focal_loss_config=focal_loss_config)
    objective_value = scores.get('val_f1_macro', 0)
    return objective_value


# Per-model Optuna settings: the objective to maximize and the total number of
# COMPLETE trials to reach (a resumed study only runs the remainder).
OPTUNA_CONFIG = {
    'edge_ai':       {'objective_fn': edge_ai_objective,       'n_trials': 80},
    'path_gnn':      {'objective_fn': path_gnn_objective,      'n_trials': 60},
    'diploid_ai':    {'objective_fn': diploid_ai_objective,     'n_trials': 150},
    'ul_routing':    {'objective_fn': ul_routing_objective,     'n_trials': 60},
    'sv_ai':         {'objective_fn': sv_ai_objective,          'n_trials': 100},
    'sv_ai_binary':  {'objective_fn': sv_ai_binary_objective,   'n_trials': 80},
    'sv_ai_subtype': {'objective_fn': sv_ai_subtype_objective,  'n_trials': 100},
}

# ── Optuna storage (SQLite) for automatic resume ─────────────────────
DB_PATH = os.path.join(OUTPUT_DIR, 'sweep_optuna.db')
storage = f'sqlite:///{DB_PATH}'

sweep_results = {}
t_sweep_start = time.time()

for model_name, spec in MODEL_SPECS.items():
    cfg = OPTUNA_CONFIG.get(model_name)
    if not cfg:
        continue

    print(f"\n{'='*70}")
    print(f"  Optuna sweep: {model_name} ({cfg['n_trials']} trials)")
    print(f"{'='*70}")

    # Load data (with label_filter support for sv_ai_subtype)
    X, labels, label_dist = load_csvs(
        DATA_DIR, spec['csv_glob'], spec['features'],
        spec['label_col'], spec.get('label_transform'),
        label_filter=spec.get('label_filter'))

    if X is None:
        print(f"  No data, skipping")
        continue

    if spec['task'] == 'regression':
        y = labels.astype(np.float32)
    elif spec['task'] == 'binary':
        y = np.array([int(v) for v in labels], dtype=np.int32)
    else:
        le = LabelEncoder()
        y = le.fit_transform(labels)

    # Hybrid resampling when the largest/smallest class ratio exceeds 5.
    # NOTE(review): this runs before train_model's validation split (assuming
    # train_model splits internally; confirm). Oversampled duplicates can then
    # land in both train and val and inflate the score Optuna maximizes.
    resampled = False
    if spec['task'] in ('multiclass', 'binary'):
        counts = list(label_dist.values())
        imbalance = max(counts) / max(min(counts), 1)
        if imbalance > 5.0:
            rng = np.random.default_rng(42)
            X, y = hybrid_resample(X, y, rng=rng)
            resampled = True
            print(f"  Hybrid-resampled -> {len(y):,} samples")

    scaler = StandardScaler()
    X_scaled = scaler.fit_transform(X)

    # Build focal loss config if spec has it (multiclass only)
    focal_loss_cfg = None
    if spec.get('focal_loss') and spec['task'] == 'multiclass':
        focal_loss_cfg = {
            'alpha': spec.get('focal_alpha', 0.25),
            'gamma': spec.get('focal_gamma', 2.0),
            'n_classes': int(len(set(y.tolist()))),
        }
        print(f"  Focal loss enabled (alpha={focal_loss_cfg['alpha']}, gamma={focal_loss_cfg['gamma']})")

    # Create or load existing Optuna study (resume-safe).
    # NOTE(review): the sv_ai* objectives in this notebook never call
    # trial.report(), so MedianPruner has no intermediate values to prune on.
    direction = 'maximize'
    study = optuna.create_study(
        study_name=model_name,
        storage=storage,
        load_if_exists=True,
        direction=direction,
        sampler=TPESampler(seed=42),
        pruner=MedianPruner(n_startup_trials=10, n_warmup_steps=20),
    )

    completed = len([t for t in study.trials if t.state == optuna.trial.TrialState.COMPLETE])
    remaining = max(0, cfg['n_trials'] - completed)

    if remaining == 0:
        print(f"  âœ“ Already done ({completed} trials, best={study.best_value:.4f})")
    else:
        if completed > 0:
            print(f"  Resuming: {completed} done, {remaining} remaining (best so far: {study.best_value:.4f})")
        # Bind data into objective via functools.partial (avoids late-binding)
        objective = partial(cfg['objective_fn'], X=X_scaled, y=y,
                            base_params=spec['xgb_params'], task=spec['task'],
                            focal_loss_config=focal_loss_cfg)
        study.optimize(objective, n_trials=remaining, show_progress_bar=True)

    # Trials can all end up FAIL (e.g. a NaN objective value) or RUNNING (an
    # interrupted earlier session). study.best_trial would then raise
    # ValueError and abort the sweep for every remaining model, so skip this
    # model instead. The summary cell will list it as pending.
    n_complete = sum(t.state == optuna.trial.TrialState.COMPLETE for t in study.trials)
    if n_complete == 0:
        print(f"  No completed trials for {model_name}; skipping (re-run the cell to retry)")
        continue

    # Extract best results
    best_trial = study.best_trial
    best_params = best_trial.params
    best_score = best_trial.value

    # Re-run best params to get full metrics (including per-class F1)
    final_params = dict(spec['xgb_params'], **best_params)
    _, best_metrics = train_model(X_scaled, y, final_params, spec['task'],
                                   focal_loss_config=focal_loss_cfg)

    sweep_results[model_name] = {
        'status': 'complete',
        'best_params': best_params,
        'best_score': round(best_score, 4),
        'best_metrics': best_metrics,
        'n_trials': len(study.trials),        # all states, incl. FAIL/RUNNING
        'n_complete_trials': n_complete,
        'resampled': resampled,
        'focal_loss': bool(focal_loss_cfg),
    }

    print(f"\n  âœ“ Best: score={best_score:.4f}  ({len(study.trials)} trials)")
    print(f"    params={best_params}")
    if 'per_class_f1' in best_metrics:
        for cls_name, f1_val in best_metrics['per_class_f1'].items():
            print(f"        class {cls_name:8s}  F1={f1_val:.4f}")

t_sweep_total = time.time() - t_sweep_start
print(f"\n{'='*70}")
print(f"  Sweep complete! ({len(sweep_results)} models, {t_sweep_total:.0f}s total)")
print(f"  Optuna DB: {DB_PATH}")
print(f"{'='*70}")

### Sweep Summary


In [None]:
# ── Sweep summary table ──────────────────────────────────────────────
# After a kernel restart, `sweep_results` and `OPTUNA_CONFIG` no longer exist
# because both are created in the sweep cell. Look them up defensively,
# falling back to MODEL_SPECS for the model list, and rebuild the results
# from the Optuna DB.
_model_names = list(globals().get('OPTUNA_CONFIG') or MODEL_SPECS)
sweep_results = globals().get('sweep_results') or {}
if not sweep_results:
    _db = os.path.join(OUTPUT_DIR, 'sweep_optuna.db')
    if os.path.exists(_db):
        _storage = f'sqlite:///{_db}'
        for name in _model_names:
            try:
                study = optuna.load_study(study_name=name, storage=_storage)
                best = study.best_trial
            except KeyError:
                # No study with this name in the DB: the model was never swept.
                continue
            except ValueError:
                # The study exists but none of its trials completed.
                print(f"  {name}: study exists but has no completed trials")
                continue
            except Exception as exc:
                # Best-effort recovery: report and move on, don't hide it.
                print(f"  Could not load study {name!r}: {exc}")
                continue
            sweep_results[name] = {
                'status': 'complete',
                'best_params': best.params,
                'best_score': round(best.value, 4),
                'n_trials': len(study.trials),
            }
            print(f"  Loaded {name}: {len(study.trials)} trials")
        print()
    else:
        print("No sweep results found â€” run the sweep cell first.\n")

done = [m for m, sr in sweep_results.items() if sr.get('status') == 'complete']
pending = [m for m in _model_names if m not in done]

print(f"{'Model':<15} {'Best Score':<14} {'Trials':<10} {'Best Params'}")
print('-' * 90)
for name in _model_names:
    sr = sweep_results.get(name)
    if sr and sr.get('status') == 'complete':
        print(f"{name:<15} {sr['best_score']:<14.4f} {sr.get('n_trials','?'):<10} {sr['best_params']}")

if pending:
    print(f"\nâš  {len(pending)} model(s) not yet swept: {pending}")
    print("  Re-run the sweep cell to continue.")
else:
    print(f"\nâœ“ All {len(done)} models swept successfully!")

# Save finalized sweep results as JSON too (read back by the retrain cell)
with open(os.path.join(OUTPUT_DIR, 'sweep_results.json'), 'w') as f:
    json.dump(sweep_results, f, indent=2, default=str)
print(f"\nSweep results saved -> {OUTPUT_DIR}/sweep_results.json")

## 8. Retrain with Best Hyperparameters

Retrain every model with its sweep-winning hyperparameters, report 5-fold
cross-validation metrics, and overwrite the previously saved model files.


In [None]:
# ── Retrain all models with sweep-winning hyperparameters ─────────────
# After a kernel restart, `sweep_results` is undefined because it is created
# in the sweep cell. Look it up defensively, then reload it from the JSON
# written by the summary cell or, failing that, from the Optuna DB.
sweep_results = globals().get('sweep_results') or {}
if not sweep_results:
    _sr_path = os.path.join(OUTPUT_DIR, 'sweep_results.json')
    _db = os.path.join(OUTPUT_DIR, 'sweep_optuna.db')
    if os.path.exists(_sr_path):
        with open(_sr_path) as f:
            sweep_results = json.load(f)
        print(f"Loaded sweep results from {_sr_path}")
    elif os.path.exists(_db):
        _storage = f'sqlite:///{_db}'
        # Iterate MODEL_SPECS, not OPTUNA_CONFIG: the latter is defined in the
        # sweep cell and is also gone after a restart.
        for name in MODEL_SPECS:
            try:
                study = optuna.load_study(study_name=name, storage=_storage)
                best = study.best_trial
            except (KeyError, ValueError):
                # KeyError: no study of that name; ValueError: no completed trial.
                continue
            except Exception as exc:
                print(f"  Could not load study {name!r}: {exc}")
                continue
            sweep_results[name] = {
                'status': 'complete',
                'best_params': best.params,
                'best_score': round(best.value, 4),
            }
        print(f"Loaded {len(sweep_results)} model(s) from Optuna DB")

print("Retraining with sweep-optimized hyperparameters...\n")

final_models = {}
final_results = {}
t_start = time.time()

for model_name, spec in MODEL_SPECS.items():
    print(f"\n{'='*70}")
    print(f"  Retrain: {model_name}")
    print(f"{'='*70}")

    # Tolerate entries without 'best_params' (e.g. a hand-edited JSON).
    best_hp = (sweep_results.get(model_name) or {}).get('best_params')
    if best_hp:
        xgb_params = dict(spec['xgb_params'], **best_hp)
        print(f"  Using sweep-optimized: {best_hp}")
    else:
        # Copy, so downstream code can never mutate the shared spec dict.
        xgb_params = dict(spec['xgb_params'])
        print(f"  Using default params")

    X, labels, label_dist = load_csvs(
        DATA_DIR, spec['csv_glob'], spec['features'],
        spec['label_col'], spec.get('label_transform'),
        label_filter=spec.get('label_filter'))

    if X is None:
        print(f"  No data, skipping")
        continue

    label_classes = None
    if spec['task'] == 'regression':
        y = labels.astype(np.float32)
    elif spec['task'] == 'binary':
        y = np.array([int(v) for v in labels], dtype=np.int32)
    else:
        le = LabelEncoder()
        y = le.fit_transform(labels)
        # The model predicts encoded indices. Keep the index -> label mapping
        # so predictions can be decoded at inference time.
        label_classes = le.classes_.tolist()

    rebalance = 'none'
    if spec['task'] in ('multiclass', 'binary'):
        counts = list(label_dist.values())
        imbalance = max(counts) / max(min(counts), 1)
        if imbalance > 5.0:
            rng = np.random.default_rng(42)
            X, y = hybrid_resample(X, y, rng=rng)
            rebalance = 'hybrid'
            print(f"  Hybrid-resampled -> {len(y):,} samples")

    scaler = StandardScaler()
    X_scaled = scaler.fit_transform(X)

    # Build focal loss config if spec has it (multiclass only)
    focal_loss_cfg = None
    if spec.get('focal_loss') and spec['task'] == 'multiclass':
        focal_loss_cfg = {
            'alpha': spec.get('focal_alpha', 0.25),
            'gamma': spec.get('focal_gamma', 2.0),
            'n_classes': int(len(set(y.tolist()))),
        }
        print(f"  Focal loss enabled (alpha={focal_loss_cfg['alpha']}, gamma={focal_loss_cfg['gamma']})")

    model, metrics = train_model(X_scaled, y, xgb_params, spec['task'],
                                  focal_loss_config=focal_loss_cfg)
    metrics['rebalance_strategy'] = rebalance
    if best_hp:
        metrics['sweep_best_params'] = best_hp
    if label_classes is not None:
        metrics['label_classes'] = label_classes

    if spec['task'] == 'regression':
        print(f"  Val: RMSE={metrics['val_rmse']:.4f}  R2={metrics['val_r2']:.4f}")
    else:
        print(f"  Val: acc={metrics['val_accuracy']:.4f}  F1w={metrics['val_f1_weighted']:.4f}  F1m={metrics['val_f1_macro']:.4f}")
        if 'per_class_f1' in metrics:
            for cls, f1v in metrics['per_class_f1'].items():
                print(f"       class {cls:12s}  F1={f1v:.4f}")

    # NOTE(review): cv_model is not passed focal_loss_config, so for focal-loss
    # models the CV scores may come from a different objective than the saved
    # model (confirm cv_model's behaviour). CV also runs on the resampled data,
    # so oversampled duplicates can appear in both train and test folds.
    cv_metrics = cv_model(X_scaled, y, xgb_params, spec['task'])
    metrics.update(cv_metrics)
    metrics['label_distribution'] = {str(k): int(v) for k, v in label_dist.items()}
    metrics['num_samples'] = len(y)
    metrics['num_features'] = X.shape[1]

    if 'cv_accuracy_mean' in cv_metrics:
        print(f"  CV:  acc={cv_metrics['cv_accuracy_mean']:.4f} +/- {cv_metrics['cv_accuracy_std']:.4f}")
    else:
        print(f"  CV:  R2={cv_metrics['cv_r2_mean']:.4f} +/- {cv_metrics['cv_r2_std']:.4f}")

    final_models[model_name] = (model, scaler, metrics)
    final_results[model_name] = metrics

# ── Save models ──────────────────────────────────────────────────────
MODEL_FILENAMES = {
    'path_gnn':      'pathgnn_scorer.pkl',
    'diploid_ai':    'diploid_model.pkl',
    'ul_routing':    'ul_routing_model.pkl',
    'sv_ai':         'sv_detector_model.pkl',
    'sv_ai_binary':  'sv_binary_model.pkl',
    'sv_ai_subtype': 'sv_subtype_model.pkl',
}

for model_name, (model, scaler, metrics) in final_models.items():
    subdir = os.path.join(OUTPUT_DIR, SAVE_MAP[model_name])
    os.makedirs(subdir, exist_ok=True)

    if model_name == 'edge_ai':
        # The same model/scaler pair is written once per sequencing technology.
        for tech in TECH_LIST:
            with open(os.path.join(subdir, f'edgewarden_{tech}.pkl'), 'wb') as f:
                pickle.dump(model, f)
            with open(os.path.join(subdir, f'scaler_{tech}.pkl'), 'wb') as f:
                pickle.dump(scaler, f)
    else:
        model_filename = MODEL_FILENAMES[model_name]
        with open(os.path.join(subdir, model_filename), 'wb') as f:
            pickle.dump(model, f)
        # Every model is trained on StandardScaler output, so inference needs
        # the fitted scaler as well. Previously only edge_ai persisted it.
        scaler_filename = f'scaler_{model_name}.pkl'
        with open(os.path.join(subdir, scaler_filename), 'wb') as f:
            pickle.dump(scaler, f)
        metrics['scaler_file'] = scaler_filename

    with open(os.path.join(subdir, f'training_metadata_{model_name}.json'), 'w') as f:
        json.dump(metrics, f, indent=2, default=str)

    print(f"  Done: {model_name} -> {subdir}/{MODEL_FILENAMES.get(model_name, 'edgewarden_*.pkl')}")

# ── Final report ─────────────────────────────────────────────────────
t_total = time.time() - t_start
report = {
    'training_method': 'hybrid_resampling_gpu_sweep_optimized_v2.1',
    'device': DEVICE,
    'resampling_strategy': 'undersample_100k + oversample_to_median',
    'improvements': [
        'S1: SV-dense training data (10-50x baseline density)',
        'S2: Two-stage SV detection (binary + subtype)',
        'S3: Focal loss for SV models',
        'S4: 8 new SV-specific features (27 total)',
    ],
    'total_time_seconds': round(t_total, 1),
    'models': final_results,
}
with open(os.path.join(OUTPUT_DIR, 'training_report.json'), 'w') as f:
    json.dump(report, f, indent=2, default=str)

print(f"\nFinal report -> {OUTPUT_DIR}/training_report.json")
print(f"  Total retrain time: {t_total:.1f}s")

## 9. Save Models to Google Drive

Archives the retrained models and copies the tarball to Google Drive for persistence.
To use them locally, download the tarball from Drive and unpack it into your repo:
```bash
# From Google Drive, download trained_models_v2.tar.gz, then:
tar xzf trained_models_v2.tar.gz
cp -r trained_models_v2/* trained_models_10x/
```

In [None]:
# ── Save retrained models to Google Drive ─────────────────────────────
import shutil
import subprocess

# Create the tarball on the local SSD first (much faster than writing to
# Drive directly). tar runs via subprocess with check=True because the
# `!tar` shell magic ignored failures, which could silently upload a stale
# tarball left over from an earlier run.
local_tar = '/content/trained_models_v2.tar.gz'
subprocess.run(['tar', 'czf', local_tar, '-C', '/content', 'trained_models_v2'],
               check=True)

# Copy to Google Drive. copy2 returns the actual destination file path, which
# stays correct even if GDRIVE_OUTPUT is a directory.
drive_path = shutil.copy2(local_tar, GDRIVE_OUTPUT)

# Verify
drive_size_mb = os.path.getsize(drive_path) / (1024 * 1024)
print(f"âœ“ Saved to Google Drive: {GDRIVE_OUTPUT} ({drive_size_mb:.1f} MB)")
print(f"  Contains {len(final_models)} models with sweep-optimized hyperparameters")
print(f"  Training report: trained_models_v2/training_report.json")
print(f"  Sweep results:   trained_models_v2/sweep_results.json")
print(f"\n  Download from Drive â†’ My Drive/Colab Notebooks/trained_models_v2.tar.gz")