In [1]:
import os
import numpy as np
import pandas as pd
from tqdm import tqdm
import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
import timm
import albumentations as A
from albumentations.pytorch import ToTensorV2
from sklearn.model_selection import StratifiedGroupKFold
from sklearn.metrics import log_loss
from scipy.special import expit as sigmoid
import gc
from itertools import chain
import torch.amp as amp
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f'Using device: {device}')

# Coronal MIP files already generated from full_mip[1].
train_df = pd.read_csv('train.csv')
coronal_dir = 'data/mips_coronal/train'

# Labels
label_cols = ['C1', 'C2', 'C3', 'C4', 'C5', 'C6', 'C7']
y = train_df[label_cols].values
y_overall = y.max(axis=1).astype(int)
groups = train_df['StudyInstanceUID'].values

# ImageNet-style normalization for replicated channels (average RGB means/stds)
mean = [0.485, 0.456, 0.406]
std = [0.229, 0.224, 0.225]

# Fixed augmentations: anatomy-safe, no VerticalFlip/RandomRotate90
train_transform = A.Compose([
    A.Resize(384, 384),
    A.HorizontalFlip(p=0.5),
    A.ShiftScaleRotate(shift_limit=0.05, scale_limit=0.1, rotate_limit=15, p=0.5),
    A.RandomBrightnessContrast(p=0.2),
    A.GaussianBlur(blur_limit=3, p=0.1),
    A.Normalize(mean=mean, std=std),
    ToTensorV2()
])

val_transform = A.Compose([
    A.Resize(384, 384),
    A.Normalize(mean=mean, std=std),
    ToTensorV2()
])

class CoronalDataset(Dataset):
    def __init__(self, df, mip_dir, transform):
        self.df = df.reset_index(drop=True)
        self.mip_dir = mip_dir
        self.transform = transform
        self.label_cols = label_cols

    def __len__(self):
        return len(self.df)

    def __getitem__(self, idx):
        row = self.df.iloc[idx]
        uid = row['StudyInstanceUID']
        img_path = os.path.join(self.mip_dir, f'{uid}.npy')
        img = np.load(img_path).astype(np.float32)  # (384, 384)
        img = np.stack([img] * 3, axis=-1)  # replicate to (384, 384, 3)
        augmented = self.transform(image=img)['image']
        labels = torch.tensor(row[self.label_cols].values.astype(np.float32))
        return augmented, labels

# Training parameters - BATCH_SIZE=1, ACCUM_STEPS=8 to avoid OOM
N_FOLDS = 5
BATCH_SIZE = 1
VAL_BATCH_SIZE = 1
NUM_EPOCHS = 15
LR = 1e-4
PATIENCE = 4
SEED = 789
ACCUM_STEPS = 8
torch.manual_seed(SEED)
np.random.seed(SEED)

# CV splitter - FIXED: Use StratifiedGroupKFold to prevent patient leakage
skf = StratifiedGroupKFold(n_splits=N_FOLDS, shuffle=True, random_state=SEED)

# OOF logits collection (will include HFlip TTA averaging)
oof_logits = np.zeros((len(train_df), 7), dtype=np.float32)
fold_scores = []

for fold in range(1, N_FOLDS + 1):
    print(f'\n=== Fold {fold}/{N_FOLDS} ===')
    train_idx, val_idx = list(skf.split(train_df, y_overall, groups))[fold-1]
    train_ds = CoronalDataset(train_df.iloc[train_idx], coronal_dir, train_transform)
    val_ds = CoronalDataset(train_df.iloc[val_idx], coronal_dir, val_transform)
    train_loader = DataLoader(train_ds, batch_size=BATCH_SIZE, shuffle=True, num_workers=2, pin_memory=True)
    val_loader = DataLoader(val_ds, batch_size=VAL_BATCH_SIZE, shuffle=False, num_workers=2, pin_memory=True)

    model = timm.create_model('regnety_004', pretrained=True, num_classes=7, in_chans=3,
                              drop_rate=0.3, drop_path_rate=0.15)
    try:
        model.set_grad_checkpointing(True)
    except:
        pass
    model = model.to(device)
    optimizer = torch.optim.AdamW(model.parameters(), lr=LR, weight_decay=1e-5)
    scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=NUM_EPOCHS)
    criterion = nn.BCEWithLogitsLoss()
    scaler = amp.GradScaler('cuda')

    best_val_loss = float('inf')
    patience_counter = 0
    best_model_state = None

    for epoch in range(NUM_EPOCHS):
        # Train with AMP, float16, channels_last, grad accum
        model.train()
        train_loss = 0.0
        optimizer.zero_grad(set_to_none=True)
        for batch_idx, (imgs, labels) in enumerate(train_loader):
            imgs = imgs.to(device, non_blocking=True).to(memory_format=torch.channels_last)
            labels = labels.to(device, non_blocking=True)
            with torch.cuda.amp.autocast(dtype=torch.float16):
                logits = model(imgs)
                loss = criterion(logits, labels) / ACCUM_STEPS
            scaler.scale(loss).backward()
            if (batch_idx + 1) % ACCUM_STEPS == 0:
                scaler.unscale_(optimizer)
                torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
                scaler.step(optimizer)
                scaler.update()
                optimizer.zero_grad(set_to_none=True)
            train_loss += loss.item() * ACCUM_STEPS
        train_loss /= len(train_loader)

        # Val (BCE loss only, no TTA here)
        model.eval()
        val_loss = 0.0
        with torch.no_grad():
            for imgs, lbls in val_loader:
                imgs = imgs.to(device, non_blocking=True).to(memory_format=torch.channels_last)
                lbls = lbls.to(device, non_blocking=True)
                with torch.cuda.amp.autocast(dtype=torch.float16):
                    logits = model(imgs)
                    loss = criterion(logits, lbls)
                val_loss += loss.item()
        val_loss /= len(val_loader)

        scheduler.step()
        print(f'Epoch {epoch+1}: Train {train_loss:.4f}, Val {val_loss:.4f}')

        if val_loss < best_val_loss:
            best_val_loss = val_loss
            patience_counter = 0
            best_model_state = model.state_dict().copy()
        else:
            patience_counter += 1
            if patience_counter >= PATIENCE:
                print(f'Early stopping at epoch {epoch+1}')
                break

    # Load best model
    if best_model_state is not None:
        model.load_state_dict(best_model_state)
        torch.save(model.state_dict(), f'fold_{fold}_regnet_coronal_fixed.pth')
    else:
        print(f'No best model for fold {fold}')
        continue

    # Collect true OOF logits with HFlip TTA for this fold's val
    model.eval()
    fold_oof_logits = []
    with torch.no_grad():
        for imgs, _ in val_loader:
            imgs = imgs.to(device, non_blocking=True).to(memory_format=torch.channels_last)
            with torch.cuda.amp.autocast(dtype=torch.float16):
                logits_orig = model(imgs)
                imgs_flip = torch.flip(imgs, dims=[3])  # HFlip
                logits_flip = model(imgs_flip)
                logits_avg = 0.5 * (logits_orig + logits_flip)
            fold_oof_logits.append(logits_avg.cpu().numpy())
    fold_oof = np.concatenate(fold_oof_logits, axis=0)
    oof_logits[val_idx] = fold_oof

    # Fold score (full WLL with patient_overall=max(p7)) - FIXED: Add labels=[0,1] to handle all-0 val sets
    p7 = sigmoid(fold_oof)
    fold_y7 = y[val_idx]
    fold_y_overall = fold_y7.max(axis=1).astype(int)
    p_overall = p7.max(axis=1)
    vert_losses = [log_loss(fold_y7[:, i], p7[:, i], labels=[0,1]) for i in range(7)]
    overall_loss = log_loss(fold_y_overall, p_overall, labels=[0,1])
    fold_score = np.average(vert_losses + [overall_loss], weights=[1]*7 + [2])
    fold_scores.append(fold_score)
    print(f'Fold {fold} full WLL: {fold_score:.4f}')

    # Cleanup
    del model, train_loader, val_loader, train_ds, val_ds, scaler
    gc.collect()
    torch.cuda.empty_cache()

# Save OOF logits (with TTA)
np.save('oof_logits_coronal_fixed.npy', oof_logits)

# Overall CV full WLL
cv_wll = np.mean(fold_scores)
print(f'5-fold CV full WLL: {cv_wll:.4f} (target ~0.45-0.48 standalone, leakage-free)')

# Also vertebrae-only for reference - FIXED: Add labels=[0,1]
p7_oof = sigmoid(oof_logits)
vert_losses = [log_loss(y[:, i], p7_oof[:, i], labels=[0,1]) for i in range(7)]
vert_wll = np.mean(vert_losses)
print(f'Vertebrae-only OOF WLL: {vert_wll:.4f}')

print('Leakage-fixed coronal training complete. OOF logits (HFlip TTA) saved as oof_logits_coronal_fixed.npy. CV scores now reliable (expect higher than leaky 0.4680). Next: similarly fix 10_axial_training.ipynb and 11_soft_tissue_axial_training.ipynb, retrain all three for new leakage-free OOF, then re-gate in 03_inference_tta.ipynb cell 6, establish reliable baseline, proceed to 6-channel training after mip_prep completes.')

Using device: cuda

=== Fold 1/5 ===


  original_init(self, **validated_kwargs)


  with torch.cuda.amp.autocast(dtype=torch.float16):
  with device_autocast_ctx, torch.cpu.amp.autocast(**cpu_autocast_kwargs), recompute_context:  # type: ignore[attr-defined]


  with torch.cuda.amp.autocast(dtype=torch.float16):


Epoch 1: Train 0.5982, Val 0.5517


Epoch 2: Train 0.4120, Val 0.4788


Epoch 3: Train 0.3542, Val 0.3646


Epoch 4: Train 0.3482, Val 0.3706


Epoch 5: Train 0.3441, Val 0.4074


Epoch 6: Train 0.3442, Val 0.3927


Epoch 7: Train 0.3477, Val 0.4275
Early stopping at epoch 7


  with torch.cuda.amp.autocast(dtype=torch.float16):


Fold 1 full WLL: 0.4812

=== Fold 2/5 ===


  with torch.cuda.amp.autocast(dtype=torch.float16):
  with device_autocast_ctx, torch.cpu.amp.autocast(**cpu_autocast_kwargs), recompute_context:  # type: ignore[attr-defined]


  with torch.cuda.amp.autocast(dtype=torch.float16):


Epoch 1: Train 0.5999, Val 0.5734


Epoch 2: Train 0.4182, Val 0.3791


Epoch 3: Train 0.3655, Val 0.3156


Epoch 4: Train 0.3526, Val 0.3366


Epoch 5: Train 0.3513, Val 0.3513


Epoch 6: Train 0.3540, Val 0.3232


Epoch 7: Train 0.3540, Val 0.3262
Early stopping at epoch 7


  with torch.cuda.amp.autocast(dtype=torch.float16):


Fold 2 full WLL: 0.4189

=== Fold 3/5 ===


  with torch.cuda.amp.autocast(dtype=torch.float16):
  with device_autocast_ctx, torch.cpu.amp.autocast(**cpu_autocast_kwargs), recompute_context:  # type: ignore[attr-defined]


  with torch.cuda.amp.autocast(dtype=torch.float16):


Epoch 1: Train 0.5723, Val 0.5143


Epoch 2: Train 0.4083, Val 0.4358


Epoch 3: Train 0.3509, Val 0.3852


Epoch 4: Train 0.3384, Val 0.3754


Epoch 5: Train 0.3413, Val 0.3816


Epoch 6: Train 0.3386, Val 0.3746


Epoch 7: Train 0.3408, Val 0.3896


Epoch 8: Train 0.3385, Val 0.3669


Epoch 9: Train 0.3366, Val 0.3903


Epoch 10: Train 0.3389, Val 0.3865


Epoch 11: Train 0.3411, Val 0.3662


Epoch 12: Train 0.3351, Val 0.3721


Epoch 13: Train 0.3333, Val 0.3817


Epoch 14: Train 0.3362, Val 0.3768


Epoch 15: Train 0.3348, Val 0.3753
Early stopping at epoch 15


  with torch.cuda.amp.autocast(dtype=torch.float16):


Fold 3 full WLL: 0.4784

=== Fold 4/5 ===


  with torch.cuda.amp.autocast(dtype=torch.float16):
  with device_autocast_ctx, torch.cpu.amp.autocast(**cpu_autocast_kwargs), recompute_context:  # type: ignore[attr-defined]


  with torch.cuda.amp.autocast(dtype=torch.float16):


Epoch 1: Train 0.6387, Val 0.6021


Epoch 2: Train 0.4325, Val 0.5117


Epoch 3: Train 0.3362, Val 0.4594


Epoch 4: Train 0.3315, Val 0.4059


Epoch 5: Train 0.3297, Val 0.5498


Epoch 6: Train 0.3261, Val 0.4828


Epoch 7: Train 0.3248, Val 0.4109


Epoch 8: Train 0.3244, Val 0.4259
Early stopping at epoch 8


  with torch.cuda.amp.autocast(dtype=torch.float16):


Fold 4 full WLL: 0.5227

=== Fold 5/5 ===


  with torch.cuda.amp.autocast(dtype=torch.float16):
  with device_autocast_ctx, torch.cpu.amp.autocast(**cpu_autocast_kwargs), recompute_context:  # type: ignore[attr-defined]


  with torch.cuda.amp.autocast(dtype=torch.float16):


Epoch 1: Train 0.6007, Val 0.5311


Epoch 2: Train 0.4201, Val 0.3605


Epoch 3: Train 0.3676, Val 0.3658


Epoch 4: Train 0.3550, Val 0.3247


Epoch 5: Train 0.3604, Val 0.3092


Epoch 6: Train 0.3554, Val 0.3333


Epoch 7: Train 0.3527, Val 0.3077


Epoch 8: Train 0.3600, Val 0.3095


Epoch 9: Train 0.3527, Val 0.3575


Epoch 10: Train 0.3559, Val 0.3312


Epoch 11: Train 0.3520, Val 0.3344
Early stopping at epoch 11


  with torch.cuda.amp.autocast(dtype=torch.float16):


Fold 5 full WLL: 0.4173
5-fold CV full WLL: 0.4637 (target ~0.45-0.48 standalone, leakage-free)
Vertebrae-only OOF WLL: 0.3738
Leakage-fixed coronal training complete. OOF logits (HFlip TTA) saved as oof_logits_coronal_fixed.npy. CV scores now reliable (expect higher than leaky 0.4680). Next: similarly fix 10_axial_training.ipynb and 11_soft_tissue_axial_training.ipynb, retrain all three for new leakage-free OOF, then re-gate in 03_inference_tta.ipynb cell 6, establish reliable baseline, proceed to 6-channel training after mip_prep completes.
