In [1]:
import torch, timm, os
import pandas as pd
import numpy as np
from torch.utils.data import Dataset, DataLoader
import albumentations as A
from albumentations.pytorch import ToTensorV2
from tqdm import tqdm
from scipy.special import expit as sigmoid
import joblib
from scipy.optimize import minimize_scalar
from itertools import product
from sklearn.metrics import log_loss

# Device
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
print(f'Using device: {device}')

# Holdout labels: the seven per-vertebra targets C1..C7 plus an "overall" target
# defined as their row-wise max, giving an (N, 8) target matrix.
holdout = pd.read_csv('data/holdout.csv')
label_cols = [f'C{k}' for k in range(1, 8)]
y_holdout = holdout.loc[:, label_cols].to_numpy().astype(float)
y_overall_holdout = y_holdout.max(axis=1).astype(int)
y8_holdout = np.column_stack((y_holdout, y_overall_holdout))
print(f'Holdout size: {len(holdout)}')

# Dataset for holdout (load from train mips dir)
class HoldoutDataset(Dataset):
    """Yield ``(uid, image)`` pairs for the holdout studies.

    Each study is read from ``<mip_dir>/<StudyInstanceUID>.npy``, cast to float32,
    transposed with axes ``(1, 2, 0)`` (the file's first axis becomes the last —
    presumably channels-first on disk to channels-last for albumentations; confirm),
    and passed through ``transform(image=...)['image']``.

    Args:
        df: DataFrame with a ``StudyInstanceUID`` column. Its index is reset so
            items are addressed by position.
        mip_dir: Directory containing the per-study ``.npy`` files.
        transform: Callable accepting ``image=`` and returning a dict with key
            ``'image'``.
    """

    def __init__(self, df, mip_dir, transform):
        self.df = df.reset_index(drop=True)
        self.mip_dir = mip_dir
        self.transform = transform
        # Read the UIDs column-wise once. A per-item row lookup (df.iloc[idx]) builds a
        # whole row Series, which is slower and upcasts an integer UID to float when all
        # other columns are numeric (e.g. 7 -> '7.0.npy').
        self._uids = self.df['StudyInstanceUID'].tolist()

    def __len__(self):
        return len(self.df)

    def __getitem__(self, idx):
        uid = self._uids[idx]
        mip = np.load(os.path.join(self.mip_dir, f'{uid}.npy')).astype(np.float32)
        img = np.transpose(mip, (1, 2, 0))
        img = self.transform(image=img)['image']
        return uid, img

# Holdout pipeline: resize to 384x384, normalise (mean=0.5, std=0.5), convert to a tensor.
# shuffle=False keeps batches in `holdout` row order so predictions align with labels.
holdout_transform = A.Compose([
    A.Resize(height=384, width=384),
    A.Normalize(mean=0.5, std=0.5),
    ToTensorV2(),
])
holdout_ds = HoldoutDataset(df=holdout, mip_dir='data/mips/train', transform=holdout_transform)
holdout_loader = DataLoader(
    holdout_ds,
    batch_size=1,
    shuffle=False,
    num_workers=4,
    pin_memory=True,
)

# Model builders from 03_inference_tta. Architecture and dropout settings mirror the
# training configs; weights come from fold checkpoints, so pretrained=False.
def _timm_eval_model(arch, drop_rate, drop_path_rate):
    """Create a 7-class, 3-channel timm ``arch`` on ``device`` in eval mode."""
    net = timm.create_model(
        arch,
        pretrained=False,
        num_classes=7,
        in_chans=3,
        drop_rate=drop_rate,
        drop_path_rate=drop_path_rate,
    )
    return net.to(device).eval()


def build_convnext_v1():
    """ConvNeXt-Tiny with drop_rate=0.3, drop_path_rate=0.1."""
    return _timm_eval_model('convnext_tiny', drop_rate=0.3, drop_path_rate=0.1)


def build_convnext_v2():
    """ConvNeXt-Tiny with drop_rate=0.4, drop_path_rate=0.2."""
    return _timm_eval_model('convnext_tiny', drop_rate=0.4, drop_path_rate=0.2)


def build_regnet():
    """RegNetY-004 with drop_rate=0.4, drop_path_rate=0.2."""
    return _timm_eval_model('regnety_004', drop_rate=0.4, drop_path_rate=0.2)

# Predict function adapted for holdout
def predict_backbone_holdout(ckpt_pattern, build_fn, n_folds=5, loader=None):
    """Fold-averaged logits with horizontal-flip TTA for one backbone.

    For each fold ``f`` in ``1..n_folds`` the checkpoint ``ckpt_pattern.format(f)``
    is loaded (strictly) into a fresh model from ``build_fn``. Each batch is scored
    twice: once as-is and once flipped along tensor dim 3. Those two logits are
    averaged, and the per-fold results are then averaged over the folds actually found.

    Args:
        ckpt_pattern: Path template with one ``{}`` placeholder for the fold number.
        build_fn: Zero-argument callable returning the model (already on ``device``).
        n_folds: Number of folds to look for (default 5).
        loader: DataLoader yielding ``(uids, images)``; defaults to the module-level
            ``holdout_loader``.

    Returns:
        np.ndarray: Mean over folds of the concatenated per-batch logits, in loader order.

    Raises:
        RuntimeError: If none of the fold checkpoints exist.
    """
    if loader is None:
        loader = holdout_loader
    fold_outputs = []
    for fold in range(1, n_folds + 1):
        ckpt = ckpt_pattern.format(fold)
        if not os.path.exists(ckpt):
            # Missing folds are skipped; the average is over the folds that exist.
            print(f'Skip missing {ckpt}')
            continue
        model = build_fn()
        state_dict = torch.load(ckpt, map_location='cpu', weights_only=True)
        model.load_state_dict(state_dict, strict=True)
        batch_logits = []
        with torch.no_grad():
            for _uids, images in tqdm(loader, desc=f'{ckpt_pattern} F{fold} holdout'):
                images = images.to(device)
                # HFlip TTA: average logits of the batch and its dim-3 (width) flip.
                logits = 0.5 * (model(images) + model(torch.flip(images, dims=[3])))
                batch_logits.append(logits.cpu().numpy())
        fold_outputs.append(np.concatenate(batch_logits, axis=0))
        # Free this fold's weights before the next model is built, so two models never
        # coexist in device memory.
        del model, state_dict
    if not fold_outputs:
        raise RuntimeError(f'No checkpoints for {ckpt_pattern}')
    return np.mean(np.stack(fold_outputs, 0), 0)

# Compute 3-way logits on holdout (fold-averaged, HFlip-TTA logits per backbone).
print('Computing 3-way logits on holdout...')
logits_v1 = predict_backbone_holdout('fold_{}_convnext.pth', build_convnext_v1)
logits_v2 = predict_backbone_holdout('fold_{}_convnext_v2.pth', build_convnext_v2)
logits_reg = predict_backbone_holdout('fold_{}_regnet.pth', build_regnet)

# Load previously saved blending / calibration artifacts.
W = np.load('weights_threeway_tta.npy')        # 2-D; broadcast against the (N, 7, 3) logit stack below
T7 = np.load('temperatures_weighted_tta.npy')  # temperature(s) dividing the blended logits
with open('overall_rule_tta.txt', encoding='utf-8') as rule_file:  # close the handle deterministically
    rule = rule_file.read().strip()
# NOTE: joblib.load unpickles arbitrary objects; only load artifacts this project produced.
lr = joblib.load('overall_regressor_tta.pkl')
# np.ravel accepts a 0-d or a 1-element array on disk.
alpha = float(np.ravel(np.load('overall_lr_alpha_tta.npy'))[0])

# Blend: per-class weighted sum over the three backbones -> (N, 7) logits.
stack = np.stack([logits_v1, logits_v2, logits_reg], axis=2)
X_holdout = np.sum(stack * W[None,:,:], axis=2)
p7_uncal = sigmoid(X_holdout)
p7_cal = sigmoid(X_holdout / T7)  # temperature-scaled per-vertebra probabilities
# The overall-label aggregates are built from the UNcalibrated probabilities.
union = 1 - np.prod(1 - p7_uncal, axis=1)  # noisy-OR (assumes independent columns)
max_prob = p7_uncal.max(axis=1)
base_overall = union if rule=='union' else max_prob  # any rule other than 'union' -> max
p_lr_holdout = lr.predict_proba(X_holdout)[:,1]
p_overall_holdout = alpha * p_lr_holdout + (1 - alpha) * base_overall
p8_holdout = np.hstack([p7_cal, p_overall_holdout[:,None]])

# Metric for holdout
def wll_holdout(y_true8, p8, overall_weight=2.0, eps=1e-6):
    """Weighted mean of per-column binary log losses.

    Every column except the last gets weight 1. The last (overall) column gets
    ``overall_weight``: for the 8-column layout used here, that is 7 vertebra columns
    at 1 plus the overall column at 2. Probabilities are clipped to ``[eps, 1 - eps]``.

    This equals averaging sklearn ``log_loss(y[:, i], p[:, i], labels=[0, 1])`` per
    column, but is vectorised in NumPy. That matters because the shrink grid search
    calls it thousands of times.

    Args:
        y_true8: (N, K) array of 0/1 targets.
        p8: (N, K) array of predicted positive-class probabilities.
        overall_weight: Weight of the last column (default 2.0).
        eps: Clipping margin (default 1e-6).

    Returns:
        float: Weighted average of the K column losses.
    """
    y = np.asarray(y_true8, dtype=float)
    p = np.clip(np.asarray(p8, dtype=float), eps, 1 - eps)
    # Per-column mean binary cross-entropy; log1p(-p) is the accurate log(1 - p).
    col_losses = -np.mean(y * np.log(p) + (1 - y) * np.log1p(-p), axis=0)
    weights = np.ones(col_losses.shape[0])
    weights[-1] = overall_weight
    return float(np.average(col_losses, weights=weights))

# Reference score of the raw blend (calibrated per-vertebra probabilities plus the
# blended overall column) before any smoothing, gamma re-mix or shrink.
baseline_wll_holdout = wll_holdout(y8_holdout, p8_holdout)
print(f'Baseline WLL on holdout: {baseline_wll_holdout:.4f}')

# Smooth chain function
def smooth_chain(p7, lam):
    """Blend each column with its immediate neighbours along axis 1.

    An interior column ``i`` becomes ``(1 - 2*lam) * p[i] + lam * (p[i-1] + p[i+1])``.
    The two end columns have one neighbour each and become
    ``(1 - lam) * p[end] + lam * p[neighbour]``. Values are always read from the
    unsmoothed input, so the update is not cumulative, and the input is not modified.

    Works for any number of columns >= 2; the pipeline uses 7 (C1..C7). With fewer
    than two columns there is nothing to blend, and a copy is returned.

    Args:
        p7: (N, K) array of probabilities.
        lam: Neighbour weight; the interior self-weight is ``1 - 2*lam``.

    Returns:
        np.ndarray: Smoothed (N, K) array.
    """
    p7 = np.asarray(p7)
    p_smooth = p7.copy()
    if p7.shape[1] < 2:
        return p_smooth
    p_smooth[:, 1:-1] = (1 - 2 * lam) * p7[:, 1:-1] + lam * (p7[:, :-2] + p7[:, 2:])
    p_smooth[:, 0] = (1 - lam) * p7[:, 0] + lam * p7[:, 1]
    p_smooth[:, -1] = (1 - lam) * p7[:, -1] + lam * p7[:, -2]
    return p_smooth

# Tune lam on holdout
def obj_lam_holdout(lam):
    """Holdout WLL as a function of the chain-smoothing strength ``lam``.

    Only the calibrated per-vertebra columns (``p7_cal``) are smoothed. The overall
    column is the module-level ``p_overall_holdout``, read at call time.
    """
    smoothed = smooth_chain(p7_cal, lam)
    candidate = np.column_stack((smoothed, p_overall_holdout))
    return wll_holdout(y8_holdout, candidate)

# Bounded 1-D search for the smoothing strength on [0, 0.35]. The upper bound keeps
# the interior self-weight (1 - 2*lam) at or above 0.3.
res_lam = minimize_scalar(obj_lam_holdout, bounds=(0, 0.35), method='bounded')
lam_holdout = res_lam.x
print(f'Tuned lambda on holdout: {lam_holdout:.4f}, WLL: {res_lam.fun:.4f}')

# Tune gamma on holdout: grid-search (step 0.05 over [0, 1]) the mix between the two
# overall aggregates, base = gamma * union + (1 - gamma) * max_prob.
# The per-vertebra columns are fixed at the lambda-smoothed calibrated probabilities,
# and the overall column is still blended with the LR output via alpha.
# The strict '<' keeps the first (smallest) gamma among ties.
gammas = np.arange(0, 1.01, 0.05)
best_gamma_holdout = 0.0
best_wll_holdout = float('inf')
p7_smooth_holdout = smooth_chain(p7_cal, lam_holdout)
for gamma in gammas:
    base = gamma * union + (1 - gamma) * max_prob
    p_overall = alpha * p_lr_holdout + (1 - alpha) * base
    score = wll_holdout(y8_holdout, np.hstack([p7_smooth_holdout, p_overall[:,None]]))
    if score < best_wll_holdout:
        best_wll_holdout = score
        best_gamma_holdout = gamma
print(f'Tuned gamma on holdout: {best_gamma_holdout:.4f}, WLL: {best_wll_holdout:.4f}')

# Rebuild the 8-column prediction using the tuned lambda (smoothed vertebra columns) and
# the tuned gamma (overall column). This is the reference for the shrink search below.
# Note: this rebinds p_overall_holdout.
base_holdout = best_gamma_holdout * union + (1 - best_gamma_holdout) * max_prob
p_overall_holdout = alpha * p_lr_holdout + (1 - alpha) * base_holdout
p8_baseline_holdout = np.column_stack((p7_smooth_holdout, p_overall_holdout))
baseline_wll_after_smooth_gamma = wll_holdout(y8_holdout, p8_baseline_holdout)
print(f'Baseline WLL after smoothing/gamma on holdout: {baseline_wll_after_smooth_gamma:.4f}')

# Grouped shrink on holdout (no folds): p -> clip(b + a * p, 1e-6, 1 - 1e-6), with one
# (a, b) pair shared by the seven vertebra columns and a separate pair for the overall column.
# NOTE(review): these parameters are tuned and scored on the same holdout rows, so the
# reported WLL is optimistic; very small gains are likely noise.
a_verts_range = np.arange(0.92, 1.00, 0.01)
b_verts_range = np.arange(0, 0.031, 0.005)
a_overall_range = np.arange(0.95, 1.005, 0.005)
b_overall_range = np.arange(0, 0.021, 0.0025)


def _grouped_shrink_wll(a_v, b_v, a_o, b_o):
    """Holdout WLL of ``p8_baseline_holdout`` after the grouped affine shrink."""
    a8 = np.full(8, a_v)
    a8[7] = a_o
    b8 = np.full(8, b_v)
    b8[7] = b_o
    p8_shrunk = np.clip(b8[None, :] + a8[None, :] * p8_baseline_holdout, 1e-6, 1 - 1e-6)
    return wll_holdout(y8_holdout, p8_shrunk)


# The weighted WLL is a sum of per-column terms: (a_v, b_v) only touches columns 0-6 and
# (a_o, b_o) only column 7. The joint 4-D grid therefore separates into two 2-D searches
# with the same optimum, costing len(verts grid) + len(overall grid) scorings instead of
# their product.
# Step 1: the vertebra pair, with the overall column left unshrunk (a=1, b=0).
best_verts_wll = float('inf')
best_verts = None
for a_v, b_v in product(a_verts_range, b_verts_range):
    score = _grouped_shrink_wll(a_v, b_v, 1.0, 0.0)
    if score < best_verts_wll:
        best_verts_wll = score
        best_verts = (a_v, b_v)

# Step 2: the overall pair, with the vertebra pair fixed at its optimum. These are full
# joint scores, so the minimum is the grouped-shrink WLL.
best_final_wll_holdout = float('inf')
best_params = None
for a_o, b_o in product(a_overall_range, b_overall_range):
    score = _grouped_shrink_wll(best_verts[0], best_verts[1], a_o, b_o)
    if score < best_final_wll_holdout:
        best_final_wll_holdout = score
        best_params = (best_verts[0], best_verts[1], a_o, b_o)

best_a_holdout = np.full(8, best_params[0])
best_a_holdout[7] = best_params[2]
best_b_holdout = np.full(8, best_params[1])
best_b_holdout[7] = best_params[3]
print(f'Best grouped shrink on holdout: a_verts={best_params[0]:.4f}, b_verts={best_params[1]:.4f}, a_overall={best_params[2]:.4f}, b_overall={best_params[3]:.4f}, WLL: {best_final_wll_holdout:.4f}')

# Choose the shrink to save. The candidates are the grouped shrink found above, the uniform
# fallback (a=0.92, b=0.005 on all 8 columns) and no shrink (a=1, b=0). Every candidate is
# scored on the same holdout, so the chosen one is never worse than a rejected one.
# Previously the uniform fallback was adopted without being scored and could be worse
# than both alternatives.
# NOTE(review): downstream code should read a/b from the saved file rather than assume
# uniform whenever grouped is False.
uniform_a = np.full(8, 0.92)
uniform_b = np.full(8, 0.005)
uniform_wll_holdout = wll_holdout(
    y8_holdout,
    np.clip(uniform_b[None,:] + uniform_a[None,:] * p8_baseline_holdout, 1e-6, 1-1e-6),
)
print(f'Uniform shrink WLL on holdout: {uniform_wll_holdout:.4f}')

if best_final_wll_holdout < min(baseline_wll_after_smooth_gamma, uniform_wll_holdout):
    use_grouped_holdout = True
    final_a = best_a_holdout
    final_b = best_b_holdout
    final_wll_holdout = best_final_wll_holdout
    print('Grouped shrink improves holdout WLL: Use it.')
elif uniform_wll_holdout <= baseline_wll_after_smooth_gamma:
    use_grouped_holdout = False
    final_a = uniform_a
    final_b = uniform_b
    final_wll_holdout = uniform_wll_holdout
    print('Grouped shrink does not improve holdout WLL: Revert to uniform.')
else:
    # Identity shrink: the score equals the unshrunk baseline (wll_holdout clips identically).
    use_grouped_holdout = False
    final_a = np.ones(8)
    final_b = np.zeros(8)
    final_wll_holdout = baseline_wll_after_smooth_gamma
    print('Neither grouped nor uniform shrink improves holdout WLL: Use no shrink (a=1, b=0).')

# Persist the holdout-tuned post-processing parameters; the saved array names are
# lam, gamma, a, b, alpha, grouped.
postproc_params = dict(
    lam=lam_holdout,
    gamma=best_gamma_holdout,
    a=final_a,
    b=final_b,
    alpha=alpha,
    grouped=use_grouped_holdout,
)
np.savez('postproc_holdout.npz', **postproc_params)
print(f'Holdout-tuned postproc saved: WLL {final_wll_holdout:.4f} (baseline {baseline_wll_holdout:.4f}). Next: edit 03_inference_tta cell 0 to load postproc_holdout.npz, re-execute inference/submit for LB validation.')

  from .autonotebook import tqdm as notebook_tqdm


Using device: cuda
Holdout size: 41
Computing 3-way logits on holdout...


fold_{}_convnext.pth F1 holdout:   0%|          | 0/41 [00:00<?, ?it/s]

fold_{}_convnext.pth F1 holdout:   2%|▏         | 1/41 [00:00<00:10,  3.94it/s]

fold_{}_convnext.pth F1 holdout:  29%|██▉       | 12/41 [00:00<00:00, 40.93it/s]

fold_{}_convnext.pth F1 holdout:  54%|█████▎    | 22/41 [00:00<00:00, 60.33it/s]

fold_{}_convnext.pth F1 holdout:  78%|███████▊  | 32/41 [00:00<00:00, 72.83it/s]

fold_{}_convnext.pth F1 holdout: 100%|██████████| 41/41 [00:00<00:00, 60.40it/s]




fold_{}_convnext.pth F2 holdout:   0%|          | 0/41 [00:00<?, ?it/s]

fold_{}_convnext.pth F2 holdout:   5%|▍         | 2/41 [00:00<00:01, 19.89it/s]

fold_{}_convnext.pth F2 holdout:  29%|██▉       | 12/41 [00:00<00:00, 66.79it/s]

fold_{}_convnext.pth F2 holdout:  54%|█████▎    | 22/41 [00:00<00:00, 81.33it/s]

fold_{}_convnext.pth F2 holdout:  78%|███████▊  | 32/41 [00:00<00:00, 88.23it/s]

fold_{}_convnext.pth F2 holdout: 100%|██████████| 41/41 [00:00<00:00, 78.45it/s]




fold_{}_convnext.pth F3 holdout:   0%|          | 0/41 [00:00<?, ?it/s]

fold_{}_convnext.pth F3 holdout:   5%|▍         | 2/41 [00:00<00:02, 18.88it/s]

fold_{}_convnext.pth F4 holdout:   0%|          | 0/41 [00:00<?, ?it/s]

fold_{}_convnext.pth F4 holdout:   5%|▍         | 2/41 [00:00<00:02, 18.68it/s]

fold_{}_convnext.pth F4 holdout:  29%|██▉       | 12/41 [00:00<00:00, 64.98it/s]

fold_{}_convnext.pth F4 holdout:  54%|█████▎    | 22/41 [00:00<00:00, 79.90it/s]

fold_{}_convnext.pth F4 holdout:  78%|███████▊  | 32/41 [00:00<00:00, 87.38it/s]

fold_{}_convnext.pth F4 holdout: 100%|██████████| 41/41 [00:00<00:00, 77.14it/s]




fold_{}_convnext.pth F5 holdout:   0%|          | 0/41 [00:00<?, ?it/s]

fold_{}_convnext.pth F5 holdout:   5%|▍         | 2/41 [00:00<00:02, 19.08it/s]

fold_{}_convnext.pth F5 holdout:  32%|███▏      | 13/41 [00:00<00:00, 67.57it/s]

fold_{}_convnext.pth F5 holdout:  56%|█████▌    | 23/41 [00:00<00:00, 80.91it/s]

fold_{}_convnext.pth F5 holdout:  80%|████████  | 33/41 [00:00<00:00, 87.51it/s]

fold_{}_convnext.pth F5 holdout: 100%|██████████| 41/41 [00:00<00:00, 76.81it/s]




fold_{}_convnext_v2.pth F1 holdout:   0%|          | 0/41 [00:00<?, ?it/s]

fold_{}_convnext_v2.pth F1 holdout:   5%|▍         | 2/41 [00:00<00:02, 19.18it/s]

fold_{}_convnext_v2.pth F1 holdout:  32%|███▏      | 13/41 [00:00<00:00, 67.78it/s]

fold_{}_convnext_v2.pth F1 holdout:  56%|█████▌    | 23/41 [00:00<00:00, 81.35it/s]

fold_{}_convnext_v2.pth F1 holdout:  80%|████████  | 33/41 [00:00<00:00, 87.91it/s]

fold_{}_convnext_v2.pth F1 holdout: 100%|██████████| 41/41 [00:00<00:00, 77.26it/s]




fold_{}_convnext_v2.pth F2 holdout:   0%|          | 0/41 [00:00<?, ?it/s]

fold_{}_convnext_v2.pth F2 holdout:   5%|▍         | 2/41 [00:00<00:02, 18.56it/s]

fold_{}_convnext_v2.pth F2 holdout:  29%|██▉       | 12/41 [00:00<00:00, 64.45it/s]

fold_{}_convnext_v2.pth F2 holdout:  54%|█████▎    | 22/41 [00:00<00:00, 79.47it/s]

fold_{}_convnext_v2.pth F2 holdout:  78%|███████▊  | 32/41 [00:00<00:00, 86.86it/s]

fold_{}_convnext_v2.pth F2 holdout: 100%|██████████| 41/41 [00:00<00:00, 76.72it/s]




fold_{}_convnext_v2.pth F3 holdout:   0%|          | 0/41 [00:00<?, ?it/s]

fold_{}_convnext_v2.pth F3 holdout:   5%|▍         | 2/41 [00:00<00:02, 18.60it/s]

fold_{}_convnext_v2.pth F3 holdout:  29%|██▉       | 12/41 [00:00<00:00, 64.85it/s]

fold_{}_convnext_v2.pth F3 holdout:  54%|█████▎    | 22/41 [00:00<00:00, 79.80it/s]

fold_{}_convnext_v2.pth F3 holdout:  78%|███████▊  | 32/41 [00:00<00:00, 87.25it/s]

fold_{}_convnext_v2.pth F3 holdout: 100%|██████████| 41/41 [00:00<00:00, 77.22it/s]




fold_{}_convnext_v2.pth F4 holdout:   0%|          | 0/41 [00:00<?, ?it/s]

fold_{}_convnext_v2.pth F4 holdout:   5%|▍         | 2/41 [00:00<00:02, 18.56it/s]

fold_{}_convnext_v2.pth F4 holdout:  29%|██▉       | 12/41 [00:00<00:00, 64.66it/s]

fold_{}_convnext_v2.pth F4 holdout:  54%|█████▎    | 22/41 [00:00<00:00, 79.54it/s]

fold_{}_convnext_v2.pth F4 holdout:  78%|███████▊  | 32/41 [00:00<00:00, 86.79it/s]

fold_{}_convnext_v2.pth F4 holdout: 100%|██████████| 41/41 [00:00<00:00, 76.27it/s]




fold_{}_convnext_v2.pth F5 holdout:   0%|          | 0/41 [00:00<?, ?it/s]

fold_{}_convnext_v2.pth F5 holdout:   2%|▏         | 1/41 [00:00<00:04,  9.84it/s]

fold_{}_convnext_v2.pth F5 holdout:  27%|██▋       | 11/41 [00:00<00:00, 60.77it/s]

fold_{}_convnext_v2.pth F5 holdout:  51%|█████     | 21/41 [00:00<00:00, 75.40it/s]

fold_{}_convnext_v2.pth F5 holdout:  76%|███████▌  | 31/41 [00:00<00:00, 83.37it/s]

fold_{}_convnext_v2.pth F5 holdout: 100%|██████████| 41/41 [00:00<00:00, 88.51it/s]

fold_{}_convnext_v2.pth F5 holdout: 100%|██████████| 41/41 [00:00<00:00, 73.60it/s]




fold_{}_regnet.pth F1 holdout:   0%|          | 0/41 [00:00<?, ?it/s]

fold_{}_regnet.pth F1 holdout:   2%|▏         | 1/41 [00:00<00:09,  4.33it/s]

fold_{}_regnet.pth F1 holdout:  20%|█▉        | 8/41 [00:00<00:01, 29.13it/s]

fold_{}_regnet.pth F1 holdout:  37%|███▋      | 15/41 [00:00<00:00, 43.24it/s]

fold_{}_regnet.pth F1 holdout:  54%|█████▎    | 22/41 [00:00<00:00, 51.93it/s]

fold_{}_regnet.pth F1 holdout:  71%|███████   | 29/41 [00:00<00:00, 57.44it/s]

fold_{}_regnet.pth F1 holdout:  90%|█████████ | 37/41 [00:00<00:00, 62.02it/s]

fold_{}_regnet.pth F1 holdout: 100%|██████████| 41/41 [00:00<00:00, 49.08it/s]




fold_{}_regnet.pth F2 holdout:   0%|          | 0/41 [00:00<?, ?it/s]

fold_{}_regnet.pth F2 holdout:   2%|▏         | 1/41 [00:00<00:04,  9.04it/s]

fold_{}_regnet.pth F2 holdout:  20%|█▉        | 8/41 [00:00<00:00, 42.46it/s]

fold_{}_regnet.pth F2 holdout:  37%|███▋      | 15/41 [00:00<00:00, 53.82it/s]

fold_{}_regnet.pth F2 holdout:  54%|█████▎    | 22/41 [00:00<00:00, 59.32it/s]

fold_{}_regnet.pth F2 holdout:  71%|███████   | 29/41 [00:00<00:00, 62.47it/s]

fold_{}_regnet.pth F2 holdout:  88%|████████▊ | 36/41 [00:00<00:00, 64.94it/s]

fold_{}_regnet.pth F2 holdout: 100%|██████████| 41/41 [00:00<00:00, 56.23it/s]




fold_{}_regnet.pth F3 holdout:   0%|          | 0/41 [00:00<?, ?it/s]

fold_{}_regnet.pth F3 holdout:   2%|▏         | 1/41 [00:00<00:04,  9.91it/s]

fold_{}_regnet.pth F3 holdout:  20%|█▉        | 8/41 [00:00<00:00, 43.47it/s]

fold_{}_regnet.pth F3 holdout:  37%|███▋      | 15/41 [00:00<00:00, 54.06it/s]

fold_{}_regnet.pth F3 holdout:  54%|█████▎    | 22/41 [00:00<00:00, 59.24it/s]

fold_{}_regnet.pth F3 holdout:  71%|███████   | 29/41 [00:00<00:00, 62.07it/s]

fold_{}_regnet.pth F3 holdout:  88%|████████▊ | 36/41 [00:00<00:00, 64.46it/s]

fold_{}_regnet.pth F3 holdout: 100%|██████████| 41/41 [00:00<00:00, 56.29it/s]




fold_{}_regnet.pth F4 holdout:   0%|          | 0/41 [00:00<?, ?it/s]

fold_{}_regnet.pth F4 holdout:   2%|▏         | 1/41 [00:00<00:04,  8.98it/s]

fold_{}_regnet.pth F4 holdout:  20%|█▉        | 8/41 [00:00<00:00, 42.12it/s]

fold_{}_regnet.pth F4 holdout:  37%|███▋      | 15/41 [00:00<00:00, 53.50it/s]

fold_{}_regnet.pth F4 holdout:  54%|█████▎    | 22/41 [00:00<00:00, 59.15it/s]

fold_{}_regnet.pth F4 holdout:  71%|███████   | 29/41 [00:00<00:00, 62.17it/s]

fold_{}_regnet.pth F4 holdout:  88%|████████▊ | 36/41 [00:00<00:00, 64.61it/s]

fold_{}_regnet.pth F4 holdout: 100%|██████████| 41/41 [00:00<00:00, 55.97it/s]




fold_{}_regnet.pth F5 holdout:   0%|          | 0/41 [00:00<?, ?it/s]

fold_{}_regnet.pth F5 holdout:   2%|▏         | 1/41 [00:00<00:04,  9.61it/s]

fold_{}_regnet.pth F5 holdout:  20%|█▉        | 8/41 [00:00<00:00, 42.95it/s]

fold_{}_regnet.pth F5 holdout:  37%|███▋      | 15/41 [00:00<00:00, 53.37it/s]

fold_{}_regnet.pth F5 holdout:  54%|█████▎    | 22/41 [00:00<00:00, 58.92it/s]

fold_{}_regnet.pth F5 holdout:  71%|███████   | 29/41 [00:00<00:00, 62.08it/s]

fold_{}_regnet.pth F5 holdout:  88%|████████▊ | 36/41 [00:00<00:00, 64.45it/s]

fold_{}_regnet.pth F5 holdout: 100%|██████████| 41/41 [00:00<00:00, 56.58it/s]




Baseline WLL on holdout: 0.4344
Tuned lambda on holdout: 0.0000, WLL: 0.4344


Tuned gamma on holdout: 1.0000, WLL: 0.4344
Baseline WLL after smoothing/gamma on holdout: 0.4344


Best grouped shrink on holdout: a_verts=0.9900, b_verts=0.0050, a_overall=0.9850, b_overall=0.0200, WLL: 0.4343
Grouped shrink improves holdout WLL: Use it.
Holdout-tuned postproc saved: WLL 0.4343 (baseline 0.4344). Next: edit 03_inference_tta cell 0 to load postproc_holdout.npz, re-execute inference/submit for LB validation.
