In [1]:
import os
import numpy as np
import pandas as pd
from tqdm import tqdm
import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
import timm
import albumentations as A
from albumentations.pytorch import ToTensorV2
from sklearn.model_selection import StratifiedGroupKFold
from sklearn.metrics import log_loss
from scipy.special import expit as sigmoid
import gc
from itertools import chain
import torch.amp as amp
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f'Using device: {device}')

# 6-channel multi-MIP files from mip_prep.py (bone_sag, bone_cor, bone_ax, soft_ax, soft_sag, soft_cor)
train_df = pd.read_csv('data/train_mips_multi.csv')
multi_dir = 'data/mips_multi/train'

# Labels
label_cols = ['C1', 'C2', 'C3', 'C4', 'C5', 'C6', 'C7']
y = train_df[label_cols].values
y_overall = y.max(axis=1).astype(int)
groups = train_df['StudyInstanceUID'].values

# Normalization for 6-channel: replicate ImageNet avg RGB (bone/soft MIPs [0,1])
mean = [0.485] * 6
std = [0.229] * 6

# Fixed augmentations: apply to (H,W,C=6)
train_transform = A.Compose([
    A.Resize(384, 384),
    A.HorizontalFlip(p=0.5),
    A.ShiftScaleRotate(shift_limit=0.05, scale_limit=0.1, rotate_limit=15, p=0.5),
    A.RandomBrightnessContrast(p=0.2),
    A.GaussianBlur(blur_limit=3, p=0.1),
    A.Normalize(mean=mean, std=std),
    ToTensorV2()  # (C,H,W)
])

val_transform = A.Compose([
    A.Resize(384, 384),
    A.Normalize(mean=mean, std=std),
    ToTensorV2()
])

class MultiChannelDataset(Dataset):
    def __init__(self, df, mip_dir, transform):
        self.df = df.reset_index(drop=True)
        self.mip_dir = mip_dir
        self.transform = transform
        self.label_cols = label_cols

    def __len__(self):
        return len(self.df)

    def __getitem__(self, idx):
        row = self.df.iloc[idx]
        uid = row['StudyInstanceUID']
        img_path = os.path.join(self.mip_dir, f'{uid}.npy')
        mips6 = np.load(img_path).astype(np.float32)  # (6,384,384)
        # Transpose to (H,W,C=6) for albumentations
        img_hwc = np.transpose(mips6, (1,2,0))
        augmented = self.transform(image=img_hwc)['image']  # (C,H,W)
        labels = torch.tensor(row[self.label_cols].values.astype(np.float32))
        return augmented, labels

# Training parameters - BATCH_SIZE=1, ACCUM_STEPS=8 to avoid OOM with in_chans=6
N_FOLDS = 5
BATCH_SIZE = 1
VAL_BATCH_SIZE = 1
NUM_EPOCHS = 15
LR = 1e-4
PATIENCE = 4
SEED = 791
ACCUM_STEPS = 8
torch.manual_seed(SEED)
np.random.seed(SEED)

# CV splitter - FIXED: Use StratifiedGroupKFold to prevent patient leakage
skf = StratifiedGroupKFold(n_splits=N_FOLDS, shuffle=True, random_state=SEED)

# OOF logits collection (will include HFlip TTA averaging)
oof_logits = np.zeros((len(train_df), 7), dtype=np.float32)
fold_scores = []

for fold in range(1, N_FOLDS + 1):
    print(f'\n=== Fold {fold}/{N_FOLDS} ===')
    train_idx, val_idx = list(skf.split(train_df, y_overall, groups))[fold-1]
    train_ds = MultiChannelDataset(train_df.iloc[train_idx], multi_dir, train_transform)
    val_ds = MultiChannelDataset(train_df.iloc[val_idx], multi_dir, val_transform)
    train_loader = DataLoader(train_ds, batch_size=BATCH_SIZE, shuffle=True, num_workers=2, pin_memory=True)
    val_loader = DataLoader(val_ds, batch_size=VAL_BATCH_SIZE, shuffle=False, num_workers=2, pin_memory=True)

    model = timm.create_model('regnety_004', pretrained=True, num_classes=7, in_chans=6,
                              drop_rate=0.3, drop_path_rate=0.15)
    try:
        model.set_grad_checkpointing(True)
    except:
        pass
    model = model.to(device)
    optimizer = torch.optim.AdamW(model.parameters(), lr=LR, weight_decay=1e-5)
    scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=NUM_EPOCHS)
    criterion = nn.BCEWithLogitsLoss()
    scaler = amp.GradScaler('cuda')

    best_val_loss = float('inf')
    patience_counter = 0
    best_model_state = None

    for epoch in range(NUM_EPOCHS):
        # Train with AMP, float16, channels_last, grad accum
        model.train()
        train_loss = 0.0
        optimizer.zero_grad(set_to_none=True)
        for batch_idx, (imgs, labels) in enumerate(train_loader):
            imgs = imgs.to(device, non_blocking=True).to(memory_format=torch.channels_last)
            labels = labels.to(device, non_blocking=True)
            with torch.cuda.amp.autocast(dtype=torch.float16):
                logits = model(imgs)
                loss = criterion(logits, labels) / ACCUM_STEPS
            scaler.scale(loss).backward()
            if (batch_idx + 1) % ACCUM_STEPS == 0:
                scaler.unscale_(optimizer)
                torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
                scaler.step(optimizer)
                scaler.update()
                optimizer.zero_grad(set_to_none=True)
            train_loss += loss.item() * ACCUM_STEPS
        train_loss /= len(train_loader)

        # Val (BCE loss only, no TTA here)
        model.eval()
        val_loss = 0.0
        with torch.no_grad():
            for imgs, lbls in val_loader:
                imgs = imgs.to(device, non_blocking=True).to(memory_format=torch.channels_last)
                lbls = lbls.to(device, non_blocking=True)
                with torch.cuda.amp.autocast(dtype=torch.float16):
                    logits = model(imgs)
                    loss = criterion(logits, lbls)
                val_loss += loss.item()
        val_loss /= len(val_loader)

        scheduler.step()
        print(f'Epoch {epoch+1}: Train {train_loss:.4f}, Val {val_loss:.4f}')

        if val_loss < best_val_loss:
            best_val_loss = val_loss
            patience_counter = 0
            best_model_state = model.state_dict().copy()
        else:
            patience_counter += 1
            if patience_counter >= PATIENCE:
                print(f'Early stopping at epoch {epoch+1}')
                break

    # Load best model
    if best_model_state is not None:
        model.load_state_dict(best_model_state)
        torch.save(model.state_dict(), f'fold_{fold}_regnet_multi_fixed.pth')
    else:
        print(f'No best model for fold {fold}')
        continue

    # Collect true OOF logits with HFlip TTA for this fold's val
    model.eval()
    fold_oof_logits = []
    with torch.no_grad():
        for imgs, _ in val_loader:
            imgs = imgs.to(device, non_blocking=True).to(memory_format=torch.channels_last)
            with torch.cuda.amp.autocast(dtype=torch.float16):
                logits_orig = model(imgs)
                imgs_flip = torch.flip(imgs, dims=[3])  # HFlip on W dim (dim=3 in C,H,W)
                logits_flip = model(imgs_flip)
                logits_avg = 0.5 * (logits_orig + logits_flip)
            fold_oof_logits.append(logits_avg.cpu().numpy())
    fold_oof = np.concatenate(fold_oof_logits, axis=0)
    oof_logits[val_idx] = fold_oof

    # Fold score (full WLL with patient_overall=max(p7)) - FIXED: Add labels=[0,1]
    p7 = sigmoid(fold_oof)
    fold_y7 = y[val_idx]
    fold_y_overall = fold_y7.max(axis=1).astype(int)
    p_overall = p7.max(axis=1)
    vert_losses = [log_loss(fold_y7[:, i], p7[:, i], labels=[0,1]) for i in range(7)]
    overall_loss = log_loss(fold_y_overall, p_overall, labels=[0,1])
    fold_score = np.average(vert_losses + [overall_loss], weights=[1]*7 + [2])
    fold_scores.append(fold_score)
    print(f'Fold {fold} full WLL: {fold_score:.4f}')

    # Cleanup
    del model, train_loader, val_loader, train_ds, val_ds, scaler
    gc.collect()
    torch.cuda.empty_cache()

# Save OOF logits (with TTA)
np.save('oof_logits_multi_fixed.npy', oof_logits)

# Overall CV full WLL
cv_wll = np.mean(fold_scores)
print(f'5-fold CV full WLL: {cv_wll:.4f} (target ~0.45-0.47 standalone, leakage-free)')

# Also vertebrae-only for reference - FIXED: Add labels=[0,1]
p7_oof = sigmoid(oof_logits)
vert_losses = [log_loss(y[:, i], p7_oof[:, i], labels=[0,1]) for i in range(7)]
vert_wll = np.mean(vert_losses)
print(f'Vertebrae-only OOF WLL: {vert_wll:.4f}')

print('Leakage-fixed 6-channel multi-MIP training complete. OOF logits (HFlip TTA) saved as oof_logits_multi_fixed.npy. CV scores reliable (expect ~0.45 diversity from bone/soft MIPs). Next: re-execute cell 6 in 03_inference_tta.ipynb to gate "Multi" (add gate_model(np.load("oof_logits_multi_fixed.npy"), "Multi")), if accepted (gain >0.001), update cell 1 for 5-way inference (add def predict_multi(ckpt_pattern): timm regnety_004 in_chans=6 load fold pth TTA HFlip mean over folds; multi_logits = predict_multi("fold_{}_regnet_multi_fixed.pth"); X5 = X4_swin + s_multi * multi_logits; refit full postproc on X5 OOF stack), execute cell 1 to generate submission.csv (OOF <=0.41 target), submit_final_answer.')

Using device: cuda

=== Fold 1/5 ===


  original_init(self, **validated_kwargs)


  with torch.cuda.amp.autocast(dtype=torch.float16):


  with device_autocast_ctx, torch.cpu.amp.autocast(**cpu_autocast_kwargs), recompute_context:  # type: ignore[attr-defined]


  with torch.cuda.amp.autocast(dtype=torch.float16):


Epoch 1: Train 0.5994, Val 0.5266


Epoch 2: Train 0.4272, Val 0.4216


Epoch 3: Train 0.3597, Val 0.3515


Epoch 4: Train 0.3533, Val 0.4317


Epoch 5: Train 0.3532, Val 0.3466


Epoch 6: Train 0.3570, Val 0.3610


Epoch 7: Train 0.3566, Val 0.3833


Epoch 8: Train 0.3516, Val 0.4016


Epoch 9: Train 0.3542, Val 0.3289


Epoch 10: Train 0.3496, Val 0.3492


Epoch 11: Train 0.3517, Val 0.3133


Epoch 12: Train 0.3550, Val 0.3500


Epoch 13: Train 0.3468, Val 0.3318


Epoch 14: Train 0.3516, Val 0.3192


Epoch 15: Train 0.3499, Val 0.3669
Early stopping at epoch 15


  with torch.cuda.amp.autocast(dtype=torch.float16):


Fold 1 full WLL: 0.4407

=== Fold 2/5 ===


  with torch.cuda.amp.autocast(dtype=torch.float16):
  with device_autocast_ctx, torch.cpu.amp.autocast(**cpu_autocast_kwargs), recompute_context:  # type: ignore[attr-defined]


  with torch.cuda.amp.autocast(dtype=torch.float16):


Epoch 1: Train 0.5797, Val 0.5703


Epoch 2: Train 0.4135, Val 0.4245


Epoch 3: Train 0.3526, Val 0.4891


Epoch 4: Train 0.3461, Val 0.4081


Epoch 5: Train 0.3468, Val 0.3873


Epoch 6: Train 0.3404, Val 0.4392


Epoch 7: Train 0.3420, Val 0.3753


Epoch 8: Train 0.3424, Val 0.4060


Epoch 9: Train 0.3376, Val 0.3615


Epoch 10: Train 0.3450, Val 0.3825


Epoch 11: Train 0.3397, Val 0.3775


Epoch 12: Train 0.3448, Val 0.3657


Epoch 13: Train 0.3350, Val 0.3855
Early stopping at epoch 13


  with torch.cuda.amp.autocast(dtype=torch.float16):


Fold 2 full WLL: 0.4629

=== Fold 3/5 ===


  with torch.cuda.amp.autocast(dtype=torch.float16):
  with device_autocast_ctx, torch.cpu.amp.autocast(**cpu_autocast_kwargs), recompute_context:  # type: ignore[attr-defined]


  with torch.cuda.amp.autocast(dtype=torch.float16):


Epoch 1: Train 0.5643, Val 0.6486


Epoch 2: Train 0.4127, Val 0.5934


Epoch 3: Train 0.3415, Val 0.5458


Epoch 4: Train 0.3311, Val 0.4944


Epoch 5: Train 0.3318, Val 0.4209


Epoch 6: Train 0.3310, Val 0.5273


Epoch 7: Train 0.3234, Val 0.4183


Epoch 8: Train 0.3292, Val 0.4898


Epoch 9: Train 0.3299, Val 0.4271


Epoch 10: Train 0.3218, Val 0.4367


Epoch 11: Train 0.3211, Val 0.4193
Early stopping at epoch 11


  with torch.cuda.amp.autocast(dtype=torch.float16):


Fold 3 full WLL: 0.5620

=== Fold 4/5 ===


  with torch.cuda.amp.autocast(dtype=torch.float16):
  with device_autocast_ctx, torch.cpu.amp.autocast(**cpu_autocast_kwargs), recompute_context:  # type: ignore[attr-defined]


  with torch.cuda.amp.autocast(dtype=torch.float16):


Epoch 1: Train 0.5952, Val 0.5228


Epoch 2: Train 0.4290, Val 0.4616


Epoch 3: Train 0.3545, Val 0.3641


Epoch 4: Train 0.3540, Val 0.3506


Epoch 5: Train 0.3466, Val 0.4262


Epoch 6: Train 0.3411, Val 0.3574


Epoch 7: Train 0.3496, Val 0.4639


Epoch 8: Train 0.3507, Val 0.3734
Early stopping at epoch 8


  with torch.cuda.amp.autocast(dtype=torch.float16):


Fold 4 full WLL: 0.4733

=== Fold 5/5 ===


  with torch.cuda.amp.autocast(dtype=torch.float16):
  with device_autocast_ctx, torch.cpu.amp.autocast(**cpu_autocast_kwargs), recompute_context:  # type: ignore[attr-defined]


  with torch.cuda.amp.autocast(dtype=torch.float16):


Epoch 1: Train 0.6012, Val 0.5740


Epoch 2: Train 0.4309, Val 0.4391


Epoch 3: Train 0.3561, Val 0.3313


Epoch 4: Train 0.3580, Val 0.4733


Epoch 5: Train 0.3484, Val 0.3116


Epoch 6: Train 0.3436, Val 0.3851


Epoch 7: Train 0.3552, Val 0.3414


Epoch 8: Train 0.3478, Val 0.3366


Epoch 9: Train 0.3513, Val 0.3296
Early stopping at epoch 9


  with torch.cuda.amp.autocast(dtype=torch.float16):


Fold 5 full WLL: 0.4141
5-fold CV full WLL: 0.4706 (target ~0.45-0.47 standalone, leakage-free)
Vertebrae-only OOF WLL: 0.3769
Leakage-fixed 6-channel multi-MIP training complete. OOF logits (HFlip TTA) saved as oof_logits_multi_fixed.npy. CV scores reliable (expect ~0.45 diversity from bone/soft MIPs). Next: re-execute cell 6 in 03_inference_tta.ipynb to gate "Multi" (add gate_model(np.load("oof_logits_multi_fixed.npy"), "Multi")), if accepted (gain >0.001), update cell 1 for 5-way inference (add def predict_multi(ckpt_pattern): timm regnety_004 in_chans=6 load fold pth TTA HFlip mean over folds; multi_logits = predict_multi("fold_{}_regnet_multi_fixed.pth"); X5 = X4_swin + s_multi * multi_logits; refit full postproc on X5 OOF stack), execute cell 1 to generate submission.csv (OOF <=0.41 target), submit_final_answer.
