# Hyperparameter Tuning (lightweight, staged random search)

This notebook runs a cheap hyperparameter search for the Transformer MIL model used in `transformer_2.ipynb`.

Goals:
- Run quick trials (3-5 epochs) across a randomized search space to find promising configurations
- Re-evaluate the top-k candidates for a medium number of epochs (10-15)
- Save the best hyperparameters to `notebooks/tuning_results/best_config.json`, plus a timestamped JSON with every trial's history and a CSV summary of all trials

Notes:
- This notebook duplicates only the pieces needed to run quick experiments: data loading, dataset construction, model, loss and training loop.
- Designed to be cheap: Stage 1 uses short runs to eliminate bad configurations before Stage 2 spends more epochs on the survivors.

In [None]:
# 1. Imports
import json
import gzip
import random
import math
from pathlib import Path
from datetime import datetime
import numpy as np
import pandas as pd
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
from torch.cuda.amp import autocast, GradScaler
from sklearn.metrics import roc_auc_score, average_precision_score
from tqdm import tqdm

# reproducibility: seed every RNG the notebook draws from
random.seed(42)
np.random.seed(42)
# torch has its own RNG: model initialisation, dropout and DataLoader
# shuffling all draw from it, so it must be seeded too.
torch.manual_seed(42)

# Paths
ROOT = Path('..')
DATA_FILE = ROOT / 'data' / 'dataset0.json.gz'
LABELS_FILE = ROOT / 'data' / 'data.info.labelled'
# NOTE(review): unlike the data paths, OUT_DIR is not built from ROOT, so it is
# relative to the current working directory (if the notebook runs from
# notebooks/, this is notebooks/notebooks/tuning_results) — confirm intent.
OUT_DIR = Path('notebooks') / 'tuning_results'
OUT_DIR.mkdir(parents=True, exist_ok=True)
print('Output ->', OUT_DIR)

In [None]:
# 2. Data loader (lightweight, matches transformer_2.ipynb format)

def load_data(data_file, labels_file):
    """Flatten the gzipped JSON-lines read file into one row per read.

    Each non-blank line of ``data_file`` is a JSON object of the form
    ``{transcript_id: {position: {sequence: [features, ...]}}}`` where each
    ``features`` list holds (dwell, std, mean) for positions -1, 0 and +1.

    Returns:
        (df, labels): ``df`` has one row per read with ``transcript_id``,
        ``transcript_position`` (int), ``sequence`` and the nine feature
        columns; ``labels`` is ``labels_file`` read with ``pd.read_csv``.
    """
    feature_names = ('dwell_-1', 'std_-1', 'mean_-1',
                     'dwell_0', 'std_0', 'mean_0',
                     'dwell_+1', 'std_+1', 'mean_+1')
    rows = []
    with gzip.open(data_file, 'rt', encoding='utf-8') as f:
        for line in f:
            if not line.strip():
                continue  # tolerate blank lines, e.g. a trailing newline
            data = json.loads(line)
            for transcript_id, positions in data.items():
                for transcript_position, sequences in positions.items():
                    for sequence, feature_list in sequences.items():
                        for features in feature_list:
                            row = {
                                'transcript_id': transcript_id,
                                'transcript_position': int(transcript_position),
                                # Keep the k-mer: bag construction looks up a
                                # 'sequence' column and otherwise encodes every
                                # read as the placeholder 'AAAAAAA'.
                                'sequence': sequence,
                            }
                            row.update({name: features[i] for i, name in enumerate(feature_names)})
                            rows.append(row)
    df = pd.DataFrame(rows)
    labels = pd.read_csv(labels_file)
    return df, labels

# Load every read plus the site labels. The dataset can be large, so only the
# row counts are printed rather than the frames themselves.
print('Loading data... (this may take a moment)')
df, labels = load_data(data_file=DATA_FILE, labels_file=LABELS_FILE)
print('reads:', len(df), 'labels:', len(labels))

In [None]:
# 3. Create bags and dataset (matching transformer_2's format)

BASE2IDX = {'A':0, 'C':1, 'G':2, 'T':3, 'U':3}
PAD_IDX = 4  # index used for padded reads; one past the four base indices


def seq_to_idx7(s: str):
    """Encode a 7-mer as an int64 array of base indices (A=0, C=1, G=2, T/U=3).

    The input is upper-cased and U is mapped to T. Strings that are not exactly
    7 characters are padded with 'A' or truncated, and unknown characters map
    to 0 ('A'), so the result always has shape (7,).
    """
    s = str(s).upper().replace('U', 'T')
    if len(s) != 7:
        # pad or truncate defensively
        s = (s + 'AAAAAAA')[:7]
    return np.array([BASE2IDX.get(ch, 0) for ch in s], dtype=np.int64)

num_cols = [
    'dwell_-1','std_-1','mean_-1','dwell_0','std_0','mean_0',
    'dwell_+1','std_+1','mean_+1'
]

site_key = ['transcript_id', 'transcript_position']


def create_bags(df, seq_col='sequence', label_col='label', min_reads=5, max_reads=50, labels_df=None):
    """Group reads into one MIL bag per (transcript_id, transcript_position) site.

    Args:
        df: per-read frame with the ``site_key`` columns, the ``num_cols``
            feature columns and (optionally) ``seq_col``.
        seq_col: column holding each read's 7-mer; if absent every read gets
            the placeholder 'AAAAAAA'.
        label_col: label column after merging labels; if absent every bag is
            labelled 0.
        min_reads: sites with fewer reads are dropped.
        max_reads: sites with more reads are randomly subsampled (global NumPy
            RNG) down to this many.
        labels_df: site-level labels merged on ``site_key``. Defaults to the
            notebook-level ``labels`` frame.

    Returns:
        List of dicts with keys 'features' (float32, n_reads x 9),
        'seq_idx' (int64, n_reads x 7), 'n_reads', 'label' and 'gene_id'
        (None when the merged frame has no gene_id column).

    Sites whose label is missing after the left merge (NaN) are skipped and
    counted; converting NaN with int() would otherwise raise ValueError.
    """
    if labels_df is None:
        labels_df = labels
    merged = df.merge(labels_df, on=site_key, how='left')
    bags = []
    n_unlabelled = 0
    for _, group in merged.groupby(site_key):
        if len(group) < min_reads:
            continue
        if label_col in group.columns:
            raw_label = group[label_col].iloc[0]
            if pd.isna(raw_label):
                n_unlabelled += 1
                continue
            label = int(raw_label)
        else:
            label = 0
        feats = group[num_cols].to_numpy(dtype=np.float32)
        if seq_col in group.columns:
            seqs = group[seq_col].astype(str).tolist()
        else:
            seqs = ['A' * 7] * len(group)
        seq_idx = np.vstack([seq_to_idx7(s) for s in seqs])
        if len(feats) > max_reads:
            idx = np.random.choice(len(feats), max_reads, replace=False)
            feats = feats[idx]
            seq_idx = seq_idx[idx]
        gene_id = group['gene_id'].iloc[0] if 'gene_id' in group.columns else None
        bags.append({'features': feats, 'seq_idx': seq_idx, 'n_reads': len(feats), 'label': label, 'gene_id': gene_id})
    print(f'Created {len(bags)} bags (min_reads={min_reads})')
    if n_unlabelled:
        print(f'Skipped {n_unlabelled} sites with no label')
    return bags

# Build the shared per-site bag list once (5..50 reads per site); every tuning
# trial below splits and wraps this same list.
bags = create_bags(df, max_reads=50, min_reads=5)

In [None]:
# Dataset used for training/validation
class RNA_MIL_Dataset(Dataset):
    """Fixed-size MIL bags of reads.

    Each item is ``(x_num, x_seq, mask, y)``:
      - x_num: float32 tensor (bag_size, n_numeric_features)
      - x_seq: int64 tensor (bag_size, seq_len) of base indices
      - mask:  float32 tensor (bag_size,), 1.0 for real reads, 0.0 for padding
      - y:     float32 scalar bag label

    Bags with fewer than ``bag_size`` reads are padded (zeros for numeric
    features, ``pad_idx`` for sequence positions). Larger bags are randomly
    subsampled on every ``__getitem__`` call when ``is_train`` is True, so
    each epoch sees a fresh subset. Otherwise they are truncated to their
    first ``bag_size`` reads.

    With ``oversample_positive`` and ``oversample_factor > 1`` (training
    only), every positive bag is repeated ``oversample_factor`` times.
    """
    def __init__(self, bags, bag_size=20, is_train=True, pad_idx=PAD_IDX, oversample_positive=False, oversample_factor=1):
        self.bags = bags
        self.bag_size = bag_size
        self.is_train = is_train
        if is_train and oversample_positive and oversample_factor>1:
            pos = [b for b in bags if b['label']==1]
            neg = [b for b in bags if b['label']==0]
            self.bags = neg + pos * oversample_factor
        self.proc = []
        self.labels = []
        for bag in self.bags:
            num = bag['features']
            seq = bag['seq_idx']
            n = bag['n_reads']
            if n == 0:
                continue
            resample = False
            if n < bag_size:
                pad_num = np.zeros((bag_size - n, num.shape[1]), dtype=np.float32)
                pad_seq = np.full((bag_size - n, seq.shape[1]), pad_idx, dtype=np.int64)
                num_fixed = np.vstack([num, pad_num])
                seq_fixed = np.vstack([seq, pad_seq])
                mask = np.zeros(bag_size, dtype=np.float32)
                mask[:n] = 1.0
            elif self.is_train and n > bag_size:
                # Keep every read; a fresh random subset is drawn per access.
                num_fixed = num.astype(np.float32)
                seq_fixed = seq.astype(np.int64)
                mask = np.ones(bag_size, dtype=np.float32)
                resample = True
            else:
                num_fixed = num[:bag_size].astype(np.float32)
                seq_fixed = seq[:bag_size].astype(np.int64)
                mask = np.ones(bag_size, dtype=np.float32)
            self.proc.append({'num': num_fixed, 'seq': seq_fixed, 'mask': mask, 'resample': resample})
            self.labels.append(float(bag['label']))
    def __len__(self):
        return len(self.proc)
    def __getitem__(self, idx):
        b = self.proc[idx]
        num, seq = b['num'], b['seq']
        if b['resample']:
            # Uses the global NumPy RNG. NOTE(review): with num_workers > 0,
            # each DataLoader worker may need its own seeding.
            pick = np.random.choice(len(num), self.bag_size, replace=False)
            num, seq = num[pick], seq[pick]
        x_num = torch.from_numpy(num)
        x_seq = torch.from_numpy(seq)
        mask = torch.from_numpy(b['mask'])
        y = torch.tensor(self.labels[idx], dtype=torch.float32)
        return x_num, x_seq, mask, y

print('Dataset class ready')

In [None]:
# 4. Model (copied/adapted from transformer_2.ipynb)

class SeqEmbCNN(nn.Module):
    """Encode each read's k-mer of base indices into a ``d_out``-dim vector.

    Bases are embedded, passed through parallel 1-D convolutions of different
    widths, max-pooled over the sequence axis, concatenated, projected to
    ``d_out`` and finally passed through GELU and LayerNorm.
    """
    def __init__(self, vocab=5, d_emb=8, kernel_sizes=(2,3,4,5), n_filters=32, d_out=64):
        super().__init__()
        # Rows at PAD_IDX are zero-initialised and never updated by the optimizer.
        self.emb = nn.Embedding(vocab, d_emb, padding_idx=PAD_IDX)
        conv_layers = [nn.Conv1d(d_emb, n_filters, width, padding=0) for width in kernel_sizes]
        self.convs = nn.ModuleList(conv_layers)
        self.proj = nn.Linear(len(kernel_sizes) * n_filters, d_out)
        self.norm = nn.LayerNorm(d_out)
    def forward(self, x_idx):
        # (N, L) -> (N, d_emb, L): Conv1d wants channels before length.
        channels_first = self.emb(x_idx).transpose(1,2)
        pooled = []
        for conv in self.convs:
            activ = F.gelu(conv(channels_first))
            # global max over every window position of this kernel width
            pooled.append(F.max_pool1d(activ, activ.shape[-1]).squeeze(-1))
        merged = self.proj(torch.cat(pooled, dim=1))
        return self.norm(F.gelu(merged))

class MultiHeadAttentionPooling(nn.Module):
    """Attention-based MIL pooling with several independent scoring heads.

    Every head scores each instance with its own small MLP
    (Linear -> Tanh -> Dropout -> Linear), softmaxes the scores over the
    unmasked instances and takes the weighted sum of ``h``. The per-head
    summaries are concatenated and fused back to ``d_model`` dims.
    """
    def __init__(self, d_model, n_heads=4, dropout=0.1):
        super().__init__()
        self.n_heads = n_heads
        heads = []
        for _ in range(n_heads):
            heads.append(nn.Sequential(
                nn.Linear(d_model, d_model),
                nn.Tanh(),
                nn.Dropout(dropout),
                nn.Linear(d_model, 1),
            ))
        self.attention_heads = nn.ModuleList(heads)
        self.fusion = nn.Sequential(
            nn.Linear(n_heads * d_model, d_model),
            nn.LayerNorm(d_model),
            nn.GELU(),
        )
    def forward(self, h, mask):
        """Return ``(bag_repr, avg_weights)``.

        Expects h as (B, K, d_model) and mask as (B, K) with 0 marking padding.
        bag_repr is (B, d_model); avg_weights is (B, K), the softmax weights
        averaged over heads.
        """
        padded = mask == 0
        summaries, head_weights = [], []
        for scorer in self.attention_heads:
            logits = scorer(h).squeeze(-1).masked_fill(padded, float('-inf'))
            w = F.softmax(logits, dim=1)
            head_weights.append(w)
            summaries.append((h * w.unsqueeze(-1)).sum(dim=1))
        bag_repr = self.fusion(torch.cat(summaries, dim=-1))
        avg_weights = torch.stack(head_weights, dim=0).mean(dim=0)
        return bag_repr, avg_weights

class TransformerMIL(nn.Module):
    """Transformer-based multiple-instance classifier over the reads of a site.

    Per read, numeric features (projected to 64 dims) and the base-index k-mer
    (SeqEmbCNN, 64 dims) are concatenated and fused to ``d_model``. A
    Transformer encoder mixes reads within a bag (padding masked out).
    Multi-head attention pooling then yields a bag embedding for the bag
    classifier, while a separate head scores each read individually.

    ``forward(x_num, mask, x_seq=None)`` returns
    ``(bag_logits, attention_weights, instance_probs)``:
    bag_logits is (B,), attention_weights is (B, K) and instance_probs is
    (B, K) sigmoid scores zeroed at padded positions.
    """
    def __init__(self, num_features=9, d_model=256, n_heads=8, n_layers=6, d_ff=1024, dropout=0.2, attn_pool_heads=4):
        super().__init__()
        num_dim = 64  # width of the projected numeric features
        seq_dim = 64  # width of the SeqEmbCNN output
        self.seq_encoder = SeqEmbCNN(vocab=5, d_emb=8, kernel_sizes=(2,3,4,5), n_filters=32, d_out=seq_dim)
        self.num_proj = nn.Sequential(nn.Linear(num_features, num_dim), nn.LayerNorm(num_dim), nn.GELU(), nn.Dropout(dropout))
        self.feature_fusion = nn.Sequential(
            nn.Linear(num_dim + seq_dim, d_model), nn.LayerNorm(d_model), nn.GELU(), nn.Dropout(dropout),
            nn.Linear(d_model, d_model), nn.LayerNorm(d_model),
        )
        encoder_layer = nn.TransformerEncoderLayer(d_model=d_model, nhead=n_heads, dim_feedforward=d_ff, dropout=dropout, activation='gelu', batch_first=True)
        self.transformer = nn.TransformerEncoder(encoder_layer, num_layers=n_layers)
        self.attention_pool = MultiHeadAttentionPooling(d_model, n_heads=attn_pool_heads, dropout=dropout)
        self.classifier = nn.Sequential(nn.Linear(d_model, d_model//2), nn.LayerNorm(d_model//2), nn.GELU(), nn.Dropout(dropout), nn.Linear(d_model//2, d_model//4), nn.LayerNorm(d_model//4), nn.GELU(), nn.Dropout(dropout), nn.Linear(d_model//4,1))
        self.instance_classifier = nn.Sequential(nn.Linear(d_model, d_model//2), nn.GELU(), nn.Dropout(dropout), nn.Linear(d_model//2,1))
    def encode_sequences(self, x_seq):
        """Encode (B, K, L) base indices to (B, K, seq_dim) read embeddings."""
        B, K, L = x_seq.shape
        seq_flat = x_seq.reshape(B*K, L)
        z_seq = self.seq_encoder(seq_flat)
        return z_seq.view(B, K, -1)
    def forward(self, x_num, mask, x_seq=None):
        B, K, _ = x_num.shape
        num_features = self.num_proj(x_num)
        if x_seq is not None:
            seq_features = self.encode_sequences(x_seq)
        else:
            # The fusion layer expects num_dim + seq_dim inputs. Without
            # sequences, fill the sequence slot with zeros instead of
            # failing on a shape mismatch.
            seq_dim = self.seq_encoder.proj.out_features
            seq_features = num_features.new_zeros(B, K, seq_dim)
        combined = torch.cat([num_features, seq_features], dim=-1)
        h = self.feature_fusion(combined)
        src_key_padding_mask = (mask == 0)  # True marks padded reads to ignore
        h = self.transformer(h, src_key_padding_mask=src_key_padding_mask)
        instance_logits = self.instance_classifier(h).squeeze(-1)
        instance_probs = torch.sigmoid(instance_logits) * mask
        bag_repr, attention_weights = self.attention_pool(h, mask)
        bag_logits = self.classifier(bag_repr).squeeze(-1)
        return bag_logits, attention_weights, instance_probs

print('Model definitions ready')

In [None]:
# 5. Loss and metrics
class FocalLoss(nn.Module):
    """Binary focal loss on raw logits, averaged over all elements.

    loss = alpha_t * (1 - p_t) ** gamma * BCE(logits, targets), where p_t is
    the predicted probability of the target class and alpha_t is ``alpha``
    for positives and ``1 - alpha`` for negatives.
    """
    def __init__(self, alpha=0.25, gamma=2.0):
        super().__init__()
        self.alpha = alpha
        self.gamma = gamma
    def forward(self, logits, targets):
        ce = F.binary_cross_entropy_with_logits(logits, targets, reduction='none')
        p = torch.sigmoid(logits)
        neg_targets = 1 - targets
        p_true = p * targets + (1 - p) * neg_targets
        modulator = (1 - p_true) ** self.gamma
        class_w = self.alpha * targets + (1 - self.alpha) * neg_targets
        return (class_w * modulator * ce).mean()

class MILLoss(nn.Module):
    """Bag-level focal loss plus a weighted noisy-OR term on positive bags.

    Returns ``(total, bag_loss, mil_pos)`` with
    ``total = bag_loss + alpha * mil_pos``. ``mil_pos`` is the mean over
    positive bags of ``-log(1 - prod_i(1 - p_i))``, i.e. it pushes at least
    one instance in each positive bag towards a high probability. It is 0
    when the batch holds no positive bag.
    """
    def __init__(self, alpha=0.3, focal_gamma=2.0, class_weight=None):
        super().__init__()
        self.alpha = alpha
        self.focal = FocalLoss(gamma=focal_gamma)
        self.class_weight = class_weight  # stored only; forward() does not use it
    def forward(self, bag_logits, instance_probs, labels, mask):
        bag_loss = self.focal(bag_logits, labels)
        is_pos = labels == 1
        if not is_pos.any():
            mil_pos = torch.tensor(0.0, device=bag_logits.device)
            return bag_loss + self.alpha * mil_pos, bag_loss, mil_pos
        pos_probs = instance_probs[is_pos]
        # 1e-8 terms keep the product and the log away from exact zero.
        all_negative = torch.prod(1 - pos_probs + 1e-8, dim=1)
        mil_pos = (-torch.log(1 - all_negative + 1e-8)).mean()
        return bag_loss + self.alpha * mil_pos, bag_loss, mil_pos


def summarize_metrics(y_true, y_prob):
    """Return ``{'roc_auc', 'pr_auc'}`` for binary labels and scores.

    Both metrics are NaN when ``y_true`` holds fewer than two distinct
    classes (including empty input), since they are undefined there.
    """
    truth = np.asarray(y_true)
    scores = np.asarray(y_prob)
    if len(np.unique(truth)) <= 1:
        return {'roc_auc': float('nan'), 'pr_auc': float('nan')}
    return {
        'roc_auc': roc_auc_score(truth, scores),
        'pr_auc': average_precision_score(truth, scores),
    }

print('Loss and metric utilities ready')

In [None]:
# 6. Training / evaluation loops (support scheduler)

def train_one_epoch(model, loader, optimizer, criterion, device, grad_scaler=None, mixed_precision=True, scheduler=None):
    """Run one training epoch and return ``(mean_batch_loss, metrics)``.

    Args:
        model: called as ``model(x_num, mask, x_seq)`` and returning
            ``(bag_logits, attention_weights, instance_probs)``.
        loader: yields ``(x_num, x_seq, mask, y)`` batches.
        optimizer: optimizer over ``model``'s parameters.
        criterion: called as ``criterion(bag_logits, instance_probs, y, mask)``;
            its first return value is the loss to minimise.
        device: device each batch is moved to.
        grad_scaler: ``GradScaler`` for mixed precision. The AMP path is used
            only when this is given AND ``mixed_precision`` is True.
        mixed_precision: run the forward pass under ``autocast``.
        scheduler: optional LR scheduler, stepped once per batch.

    Returns:
        Mean loss over batches (NaN for an empty loader) and the
        ``summarize_metrics`` dict over all training predictions.
    """
    model.train()
    use_amp = mixed_precision and grad_scaler is not None
    total_loss = 0.0
    n_batches = 0
    all_preds = []
    all_labels = []
    for x_num, x_seq, mask, y in tqdm(loader, desc='Train', leave=False):
        x_num = x_num.to(device)
        x_seq = x_seq.to(device)
        mask = mask.to(device)
        y = y.to(device)
        optimizer.zero_grad(set_to_none=True)
        if use_amp:
            with autocast():
                bag_logits, _, inst_probs = model(x_num, mask, x_seq)
                loss, _, _ = criterion(bag_logits, inst_probs, y, mask)
            grad_scaler.scale(loss).backward()
            # The gradients are still multiplied by the loss scale here. Unscale
            # first so the max-norm of 1.0 applies to the true gradients.
            grad_scaler.unscale_(optimizer)
            torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
            grad_scaler.step(optimizer)
            grad_scaler.update()
        else:
            bag_logits, _, inst_probs = model(x_num, mask, x_seq)
            loss, _, _ = criterion(bag_logits, inst_probs, y, mask)
            loss.backward()
            torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
            optimizer.step()
        if scheduler is not None:
            scheduler.step()
        # .float() so predictions are float32 even when autocast produced fp16
        probs = torch.sigmoid(bag_logits).detach().float().cpu().numpy()
        all_preds.extend(probs)
        all_labels.extend(y.detach().cpu().numpy())
        total_loss += float(loss.detach().cpu())
        n_batches += 1
    metrics = summarize_metrics(all_labels, all_preds)
    mean_loss = total_loss / n_batches if n_batches else float('nan')
    return mean_loss, metrics

@torch.no_grad()
def evaluate(model, loader, criterion, device):
    """Score ``model`` on ``loader`` in eval mode without gradient tracking.

    Batches are ``(x_num, x_seq, mask, y)``; ``criterion`` is called as in
    training. Returns ``(mean_batch_loss, metrics)``, where the loss is NaN
    for an empty loader and ``metrics`` comes from ``summarize_metrics``.
    """
    model.eval()
    total_loss = 0.0
    n_batches = 0
    all_preds = []
    all_labels = []
    for x_num, x_seq, mask, y in tqdm(loader, desc='Eval', leave=False):
        x_num = x_num.to(device)
        x_seq = x_seq.to(device)
        mask = mask.to(device)
        y = y.to(device)
        bag_logits, _, inst_probs = model(x_num, mask, x_seq)
        loss, _, _ = criterion(bag_logits, inst_probs, y, mask)
        probs = torch.sigmoid(bag_logits).float().cpu().numpy()
        all_preds.extend(probs)
        all_labels.extend(y.cpu().numpy())
        total_loss += float(loss.cpu())
        n_batches += 1
    metrics = summarize_metrics(all_labels, all_preds)
    mean_loss = total_loss / n_batches if n_batches else float('nan')
    return mean_loss, metrics

print('Training/eval loops ready')

In [None]:
# 7. Hyperparameter search: Stage 1 (quick random trials) then Stage 2 (refine)

def sample_config():
    """Draw one random Stage-1 configuration.

    ``lr`` is log-uniform in [1e-5, 1e-3] (NumPy RNG). The categorical
    choices use the ``random`` module. ``n_heads`` is derived from the sampled
    ``d_model`` (4 for 128, 8 for 256), the same rule the Optuna objective
    uses. The original drew it from an independent random choice, so it was
    not actually tied to ``d_model``.
    """
    d_model = int(random.choice([128, 256]))
    return {
        'lr': float(10**np.random.uniform(-5, -3)),            # 1e-5 .. 1e-3
        'd_model': d_model,
        'n_heads': 4 if d_model == 128 else 8,
        'n_layers': int(random.choice([4,6])),
        'dropout': float(random.choice([0.1, 0.2, 0.3])),
        'bag_size': int(random.choice([20, 40])),
        'warmup_steps': int(random.choice([500,1000]))
    }


def run_trial(config, epochs=3, batch_size=32, device=None, split_seed=42):
    """Train one configuration and return ``(best_val_score, history)``.

    Args:
        config: dict with lr, d_model, n_heads, n_layers, dropout, bag_size
            and warmup_steps (as produced by ``sample_config``).
        epochs: training epochs.
        batch_size: DataLoader batch size.
        device: torch device or device string; defaults to CUDA when available.
        split_seed: seed for the 80/10/10 split of the shared ``bags`` list.
            Reusing it across trials scores every config on the same
            validation bags, so the scores are comparable.

    The per-epoch score is ROC-AUC + PR-AUC on the validation split, with NaN
    metrics (single-class validation set) counted as 0. ``history`` has one
    dict per epoch. The final 10% of the split is held out and unused here.
    """
    if device is None:
        device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
    device = torch.device(device)
    use_amp = device.type == 'cuda'  # torch.cuda.amp autocast/GradScaler only help on CUDA

    # Shuffle a copy with a dedicated RNG. This leaves the global `bags` list
    # untouched and gives every trial the same split.
    local_bags = list(bags)
    random.Random(split_seed).shuffle(local_bags)
    n = len(local_bags)
    n_train = int(0.8*n)
    n_val = int(0.1*n)
    train_bags = local_bags[:n_train]
    val_bags = local_bags[n_train:n_train+n_val]

    train_ds = RNA_MIL_Dataset(train_bags, bag_size=config['bag_size'], is_train=True)
    val_ds = RNA_MIL_Dataset(val_bags, bag_size=config['bag_size'], is_train=False)
    train_loader = DataLoader(train_ds, batch_size=batch_size, shuffle=True)
    val_loader = DataLoader(val_ds, batch_size=batch_size, shuffle=False)

    model = TransformerMIL(num_features=9, d_model=config['d_model'], n_heads=config['n_heads'], n_layers=config['n_layers'], dropout=config['dropout']).to(device)
    optimizer = optim.AdamW(model.parameters(), lr=config['lr'], weight_decay=1e-5)
    warmup = config['warmup_steps']
    # Linear warm-up over `warmup` scheduler steps (one per batch), then constant LR.
    scheduler = optim.lr_scheduler.LambdaLR(optimizer, lambda step: min(1.0, (step+1)/warmup))
    scaler = GradScaler(enabled=use_amp)
    criterion = MILLoss(alpha=0.3, focal_gamma=2.0)

    best_val = -float('inf')
    history = []
    for ep in range(epochs):
        train_loss, _ = train_one_epoch(model, train_loader, optimizer, criterion, device, grad_scaler=scaler, mixed_precision=use_amp, scheduler=scheduler)
        val_loss, val_metrics = evaluate(model, val_loader, criterion, device)
        roc = val_metrics['roc_auc']
        pr = val_metrics['pr_auc']
        val_score = (0.0 if math.isnan(roc) else roc) + (0.0 if math.isnan(pr) else pr)
        history.append({'epoch': ep+1, 'train_loss': train_loss, 'val_loss': val_loss, **val_metrics, 'val_score': val_score})
        best_val = max(best_val, val_score)
    return best_val, history

# Stage 1: cheap random search over sample_config() draws
N_TRIALS = 8
stage1_results = []
print('Stage 1: running', N_TRIALS, 'cheap random trials (epochs=3)')
for trial_no in range(1, N_TRIALS + 1):
    cfg = sample_config()
    print(f'Trial {trial_no}/{N_TRIALS} cfg:', cfg)
    best_val, hist = run_trial(cfg, epochs=3)
    stage1_results.append(dict(config=cfg, best_val=best_val, history=hist))

# Rank Stage-1 trials by best validation score and keep up to three
stage1_results_sorted = sorted(stage1_results, key=lambda res: res['best_val'], reverse=True)
TOPK = min(3, len(stage1_results_sorted))
print('Top configs from stage1:')
for rank, res in enumerate(stage1_results_sorted[:TOPK], start=1):
    print(rank, res['config'], 'score=', res['best_val'])

# Stage 2: retrain each surviving config for more epochs
stage2_results = []
for rank, res in enumerate(stage1_results_sorted[:TOPK], start=1):
    cfg = res['config']
    print('\nRefining config', rank, cfg)
    best_val, hist = run_trial(cfg, epochs=12)
    stage2_results.append(dict(config=cfg, best_val=best_val, history=hist))

# Winner comes from Stage 2, falling back to Stage 1 if Stage 2 produced nothing
all_candidates = stage2_results or stage1_results_sorted
best_overall = max(all_candidates, key=lambda res: res['best_val'])
print('\nBest config found:', best_overall['config'], 'score=', best_overall['best_val'])

# Save results. The timestamp is built outside the f-string: nesting the same
# quote type inside an f-string is a SyntaxError before Python 3.12.
timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
results_path = OUT_DIR / f'tuning_results_{timestamp}.json'
with open(results_path, 'w', encoding='utf-8') as f:
    json.dump({'stage1': stage1_results, 'stage2': stage2_results, 'best': best_overall}, f, default=str)
print('Saved full results to', results_path)

# Save best config separately (easy to load from transformer_2 notebook)
best_cfg_path = OUT_DIR / 'best_config.json'
with open(best_cfg_path, 'w', encoding='utf-8') as f:
    json.dump(best_overall['config'], f, indent=2)
print('Saved best hyperparameters to', best_cfg_path)

# CSV summary: one row per trial across both stages, with every sampled hyperparameter
summary_keys = ('lr', 'd_model', 'n_heads', 'n_layers', 'dropout', 'bag_size', 'warmup_steps')
rows = []
for phase, results in (('stage1', stage1_results), ('stage2', stage2_results)):
    for r in results:
        row = {'phase': phase}
        row.update({k: r['config'].get(k) for k in summary_keys})
        row['best_val'] = r['best_val']
        rows.append(row)
summary_df = pd.DataFrame(rows)
summary_csv = OUT_DIR / 'tuning_summary.csv'
summary_df.to_csv(summary_csv, index=False)
print('Saved summary CSV to', summary_csv)

print('\nTuning complete. Next steps:')
print('- Open', best_cfg_path, 'and copy values into your training notebook')
print('- Or modify `transformer_2.ipynb` to load this JSON at the configuration section')

## How to use the best hyperparameters

In `transformer_2.ipynb`:
- At the top where the training config is created, you can add:

```python
import json
best = json.load(open('notebooks/tuning_results/best_config.json'))
config.LEARNING_RATE = best['lr']
config.EMBED_DIM = best['d_model']  # or d_model mapping if different naming
config.NUM_LAYERS = best['n_layers']
config.DROPOUT = best['dropout']
config.N_READS_PER_SITE = best['bag_size']
config.WARMUP_STEPS = best['warmup_steps']
```

Alternatively, copy the values over manually.

## 8. Optuna Hyperparameter Tuning

We use Optuna with a TPE sampler and a MedianPruner. The objective maximizes ROC-AUC + PR-AUC on the validation split. A trial is pruned early when its intermediate score falls below the median of earlier trials at the same epoch.
import optuna
from optuna.samplers import TPESampler
from optuna.pruners import MedianPruner

# Optuna study settings (tune N_TRIALS to the available compute budget)
N_TRIALS = 20             # total number of trials to run
MAX_EPOCHS = 12           # epoch cap per trial; pruning stops many trials sooner
STARTUP_TRIALS = 4        # trials that finish before the pruner starts acting
PRUNER_WARMUP_EPOCHS = 2  # epochs a trial must report before it can be pruned

study_dir = OUT_DIR
study_dir.mkdir(exist_ok=True, parents=True)


def objective(trial: optuna.trial.Trial) -> float:
    """Optuna objective: train one sampled config and return its best validation score.

    The per-epoch score is ROC-AUC + PR-AUC on the validation split (NaN
    metrics count as 0). Each epoch's score is reported to Optuna so the
    pruner can stop the trial early by raising ``optuna.TrialPruned``.
    """
    # Hyperparameter search space
    lr = trial.suggest_float('lr', 1e-5, 1e-3, log=True)
    d_model = trial.suggest_categorical('d_model', [128, 256])
    n_layers = trial.suggest_categorical('n_layers', [4, 6])
    dropout = trial.suggest_categorical('dropout', [0.1, 0.2, 0.3])
    bag_size = trial.suggest_categorical('bag_size', [20, 40])
    warmup_steps = trial.suggest_categorical('warmup_steps', [500, 1000])

    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    use_cuda = device.type == 'cuda'  # AMP and pinned memory only apply on CUDA

    # Fixed-seed 80/10/10 split on a copy of `bags`. Every trial is scored on
    # the same validation bags, which TPE and the median pruner rely on when
    # comparing trials. The shared list itself is never reordered.
    local_bags = bags.copy()
    random.Random(42).shuffle(local_bags)
    n = len(local_bags)
    n_train = int(0.8*n)
    n_val = int(0.1*n)
    train_bags = local_bags[:n_train]
    val_bags = local_bags[n_train:n_train+n_val]

    train_ds = RNA_MIL_Dataset(train_bags, bag_size=bag_size, is_train=True)
    val_ds = RNA_MIL_Dataset(val_bags, bag_size=bag_size, is_train=False)
    train_loader = DataLoader(train_ds, batch_size=32, shuffle=True, num_workers=0, pin_memory=use_cuda)
    val_loader = DataLoader(val_ds, batch_size=32, shuffle=False, num_workers=0, pin_memory=use_cuda)

    # Model and training setup
    model = TransformerMIL(num_features=9, d_model=d_model, n_heads=(4 if d_model==128 else 8), n_layers=n_layers, dropout=dropout).to(device)
    optimizer = optim.AdamW(model.parameters(), lr=lr, weight_decay=1e-5)
    scheduler = optim.lr_scheduler.LambdaLR(optimizer, lambda step: min(1.0, (step+1)/warmup_steps))
    scaler = GradScaler(enabled=use_cuda)
    criterion = MILLoss(alpha=0.3, focal_gamma=2.0)

    best_score = -float('inf')

    for epoch in range(1, MAX_EPOCHS+1):
        # one epoch train
        train_one_epoch(model, train_loader, optimizer, criterion, device, grad_scaler=scaler, mixed_precision=use_cuda, scheduler=scheduler)
        # evaluate
        _, val_metrics = evaluate(model, val_loader, criterion, device)
        val_auc = 0.0 if (math.isnan(val_metrics['roc_auc'])) else val_metrics['roc_auc']
        val_pr = 0.0 if (math.isnan(val_metrics['pr_auc'])) else val_metrics['pr_auc']
        score = val_auc + val_pr

        # Report intermediate value for pruning
        trial.report(score, step=epoch)
        if trial.should_prune() and epoch >= PRUNER_WARMUP_EPOCHS:
            raise optuna.TrialPruned()

        if score > best_score:
            best_score = score

    return best_score

# TPE proposes configurations; the median pruner stops trials that fall behind.
median_pruner = MedianPruner(n_startup_trials=STARTUP_TRIALS, n_warmup_steps=PRUNER_WARMUP_EPOCHS)
study = optuna.create_study(direction='maximize', sampler=TPESampler(seed=42), pruner=median_pruner)

print('Starting Optuna study...')
study.optimize(objective, n_trials=N_TRIALS, show_progress_bar=True)

print('Number of finished trials:', len(study.trials))
print('Best value:', study.best_value)
print('Best params:', study.best_params)

# Persist the winning params. n_heads is not a searched parameter (the
# objective derives it from d_model), so add it for the downstream notebook.
best_cfg = dict(study.best_params)
if 'n_heads' not in best_cfg:
    best_cfg['n_heads'] = 8 if best_cfg['d_model'] != 128 else 4
best_cfg_path = study_dir / 'best_config_optuna.json'
best_cfg_path.write_text(json.dumps(best_cfg, indent=2))
print('Saved best config to', best_cfg_path)

# Best-effort export of every trial: report a failure instead of aborting.
try:
    trials_df = study.trials_dataframe()
    trials_csv = study_dir / 'optuna_trials.csv'
    trials_df.to_csv(trials_csv, index=False)
    print('Saved trials to', trials_csv)
except Exception as e:
    print('Could not save trials dataframe:', e)

print('Best (Optuna):', study.best_value, '\nParams:', study.best_params)

In [None]:
import optuna
from optuna.samplers import TPESampler
from optuna.pruners import MedianPruner

# Optuna study settings (tune N_TRIALS to the available compute budget)
N_TRIALS = 20             # total number of trials to run
MAX_EPOCHS = 12           # epoch cap per trial; pruning stops many trials sooner
STARTUP_TRIALS = 4        # trials that finish before the pruner starts acting
PRUNER_WARMUP_EPOCHS = 2  # epochs a trial must report before it can be pruned

study_dir = OUT_DIR
study_dir.mkdir(exist_ok=True, parents=True)


def objective(trial: optuna.trial.Trial) -> float:
    """Optuna objective with a quiet (no progress bar) training loop.

    Samples a configuration, trains for up to MAX_EPOCHS epochs and returns
    the best per-epoch validation score (ROC-AUC + PR-AUC, NaN counted as 0).
    Each epoch's score is reported so the pruner can stop the trial early by
    raising ``optuna.TrialPruned``.
    """
    # Hyperparameter search space
    lr = trial.suggest_float('lr', 1e-5, 1e-3, log=True)
    d_model = trial.suggest_categorical('d_model', [128, 256])
    n_layers = trial.suggest_categorical('n_layers', [4, 6])
    dropout = trial.suggest_categorical('dropout', [0.1, 0.2, 0.3])
    bag_size = trial.suggest_categorical('bag_size', [20, 40])
    warmup_steps = trial.suggest_categorical('warmup_steps', [500, 1000])

    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    use_cuda = device.type == 'cuda'  # AMP and pinned memory only apply on CUDA

    # Fixed-seed 80/10/10 split on a copy of `bags`. Every trial is scored on
    # the same validation bags, which TPE and the median pruner rely on when
    # comparing trials. The shared list itself is never reordered.
    local_bags = bags.copy()
    random.Random(42).shuffle(local_bags)
    n = len(local_bags)
    n_train = int(0.8*n)
    n_val = int(0.1*n)
    train_bags = local_bags[:n_train]
    val_bags = local_bags[n_train:n_train+n_val]

    train_ds = RNA_MIL_Dataset(train_bags, bag_size=bag_size, is_train=True)
    val_ds = RNA_MIL_Dataset(val_bags, bag_size=bag_size, is_train=False)
    train_loader = DataLoader(train_ds, batch_size=32, shuffle=True, num_workers=0, pin_memory=use_cuda)
    val_loader = DataLoader(val_ds, batch_size=32, shuffle=False, num_workers=0, pin_memory=use_cuda)

    # Model and training setup
    model = TransformerMIL(num_features=9, d_model=d_model, n_heads=(4 if d_model==128 else 8), n_layers=n_layers, dropout=dropout).to(device)
    optimizer = optim.AdamW(model.parameters(), lr=lr, weight_decay=1e-5)
    scheduler = optim.lr_scheduler.LambdaLR(optimizer, lambda step: min(1.0, (step+1)/warmup_steps))
    scaler = GradScaler(enabled=use_cuda)
    criterion = MILLoss(alpha=0.3, focal_gamma=2.0)

    best_score = -float('inf')

    for epoch in range(1, MAX_EPOCHS+1):
        # Train one epoch (quietly - no progress bar)
        model.train()
        for x_num, x_seq, mask, y in train_loader:
            x_num, x_seq, mask, y = x_num.to(device), x_seq.to(device), mask.to(device), y.to(device)
            optimizer.zero_grad(set_to_none=True)
            with autocast(enabled=use_cuda):
                bag_logits, _, inst_probs = model(x_num, mask, x_seq)
                loss, _, _ = criterion(bag_logits, inst_probs, y, mask)
            scaler.scale(loss).backward()
            # Unscale before clipping so max_norm=1.0 applies to the real
            # gradients rather than the loss-scaled ones.
            scaler.unscale_(optimizer)
            torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
            scaler.step(optimizer)
            scaler.update()
            scheduler.step()

        # Evaluate (quietly)
        model.eval()
        all_preds, all_labels = [], []
        with torch.no_grad():
            for x_num, x_seq, mask, y in val_loader:
                x_num, x_seq, mask, y = x_num.to(device), x_seq.to(device), mask.to(device), y.to(device)
                bag_logits, _, _ = model(x_num, mask, x_seq)
                probs = torch.sigmoid(bag_logits).float().cpu().numpy()
                all_preds.extend(probs)
                all_labels.extend(y.cpu().numpy())

        val_metrics = summarize_metrics(all_labels, all_preds)
        val_auc = 0.0 if math.isnan(val_metrics['roc_auc']) else val_metrics['roc_auc']
        val_pr = 0.0 if math.isnan(val_metrics['pr_auc']) else val_metrics['pr_auc']
        score = val_auc + val_pr

        # Report intermediate value for pruning
        trial.report(score, step=epoch)
        if trial.should_prune() and epoch >= PRUNER_WARMUP_EPOCHS:
            raise optuna.TrialPruned()

        if score > best_score:
            best_score = score

    return best_score

# TPE proposes configurations; the median pruner stops trials that fall behind.
median_pruner = MedianPruner(n_startup_trials=STARTUP_TRIALS, n_warmup_steps=PRUNER_WARMUP_EPOCHS)
study = optuna.create_study(direction='maximize', sampler=TPESampler(seed=42), pruner=median_pruner)

print('Starting Optuna study...')
# Progress bar off to keep the cell output short; per-trial output comes from
# Optuna's logger at INFO verbosity instead.
optuna.logging.set_verbosity(optuna.logging.INFO)
study.optimize(objective, n_trials=N_TRIALS, show_progress_bar=False)

print('Number of finished trials:', len(study.trials))
print('Best value:', study.best_value)
print('Best params:', study.best_params)

# Persist the winning params. n_heads is not a searched parameter (the
# objective derives it from d_model), so add it for the downstream notebook.
best_cfg = dict(study.best_params)
if 'n_heads' not in best_cfg:
    best_cfg['n_heads'] = 8 if best_cfg['d_model'] != 128 else 4
best_cfg_path = study_dir / 'best_config_optuna.json'
best_cfg_path.write_text(json.dumps(best_cfg, indent=2))
print('Saved best config to', best_cfg_path)

# Best-effort export of every trial: report a failure instead of aborting.
try:
    trials_df = study.trials_dataframe()
    trials_csv = study_dir / 'optuna_trials.csv'
    trials_df.to_csv(trials_csv, index=False)
    print('Saved trials to', trials_csv)
except Exception as e:
    print('Could not save trials dataframe:', e)


In [None]:
# Final summary of the Optuna study run above
best_score, best_params = study.best_value, study.best_params
print('Best (Optuna) score:', best_score)
print('Best (Optuna) params:', best_params)