In [2]:
import os
import numpy as np
import pandas as pd
from tqdm import tqdm
import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
import timm
import albumentations as A
from albumentations.pytorch import ToTensorV2
from sklearn.model_selection import StratifiedGroupKFold
from sklearn.metrics import log_loss
from scipy.special import expit as sigmoid
import gc
from itertools import chain
import torch.amp as amp
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f'Using device: {device}')

# Soft-tissue axial MIP files generated with C=40 W=400 window.
train_df = pd.read_csv('train.csv')
axial_dir = 'data/mips_axial_soft/train'

# Labels
label_cols = ['C1', 'C2', 'C3', 'C4', 'C5', 'C6', 'C7']
y = train_df[label_cols].values
y_overall = y.max(axis=1).astype(int)
groups = train_df['StudyInstanceUID'].values

# ImageNet-style normalization for replicated channels (average RGB means/stds)
mean = [0.485, 0.456, 0.406]
std = [0.229, 0.224, 0.225]

# Fixed augmentations: anatomy-safe, no VerticalFlip/RandomRotate90
train_transform = A.Compose([
    A.Resize(384, 384),
    A.HorizontalFlip(p=0.5),
    A.ShiftScaleRotate(shift_limit=0.05, scale_limit=0.1, rotate_limit=15, p=0.5),
    A.RandomBrightnessContrast(p=0.2),
    A.GaussianBlur(blur_limit=3, p=0.1),
    A.Normalize(mean=mean, std=std),
    ToTensorV2()
])

val_transform = A.Compose([
    A.Resize(384, 384),
    A.Normalize(mean=mean, std=std),
    ToTensorV2()
])

class AxialDataset(Dataset):
    def __init__(self, df, mip_dir, transform):
        self.df = df.reset_index(drop=True)
        self.mip_dir = mip_dir
        self.transform = transform
        self.label_cols = label_cols

    def __len__(self):
        return len(self.df)

    def __getitem__(self, idx):
        row = self.df.iloc[idx]
        uid = row['StudyInstanceUID']
        img_path = os.path.join(self.mip_dir, f'{uid}.npy')
        img = np.load(img_path).astype(np.float32)  # (384, 384)
        img = np.stack([img] * 3, axis=-1)  # replicate to (384, 384, 3)
        augmented = self.transform(image=img)['image']
        labels = torch.tensor(row[self.label_cols].values.astype(np.float32))
        return augmented, labels

# Training parameters - BATCH_SIZE=1, ACCUM_STEPS=8 to avoid OOM
N_FOLDS = 5
BATCH_SIZE = 1
VAL_BATCH_SIZE = 1
NUM_EPOCHS = 15
LR = 1e-4
PATIENCE = 4
SEED = 790
ACCUM_STEPS = 8
torch.manual_seed(SEED)
np.random.seed(SEED)

# CV splitter - FIXED: Use StratifiedGroupKFold to prevent patient leakage
skf = StratifiedGroupKFold(n_splits=N_FOLDS, shuffle=True, random_state=SEED)

# OOF logits collection (will include HFlip TTA averaging)
oof_logits = np.zeros((len(train_df), 7), dtype=np.float32)
fold_scores = []

for fold in range(1, N_FOLDS + 1):
    print(f'\n=== Fold {fold}/{N_FOLDS} ===')
    train_idx, val_idx = list(skf.split(train_df, y_overall, groups))[fold-1]
    train_ds = AxialDataset(train_df.iloc[train_idx], axial_dir, train_transform)
    val_ds = AxialDataset(train_df.iloc[val_idx], axial_dir, val_transform)
    train_loader = DataLoader(train_ds, batch_size=BATCH_SIZE, shuffle=True, num_workers=2, pin_memory=True)
    val_loader = DataLoader(val_ds, batch_size=VAL_BATCH_SIZE, shuffle=False, num_workers=2, pin_memory=True)

    model = timm.create_model('regnety_004', pretrained=True, num_classes=7, in_chans=3,
                              drop_rate=0.3, drop_path_rate=0.15)
    try:
        model.set_grad_checkpointing(True)
    except:
        pass
    model = model.to(device)
    optimizer = torch.optim.AdamW(model.parameters(), lr=LR, weight_decay=1e-5)
    scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=NUM_EPOCHS)
    criterion = nn.BCEWithLogitsLoss()
    scaler = amp.GradScaler('cuda')

    best_val_loss = float('inf')
    patience_counter = 0
    best_model_state = None

    for epoch in range(NUM_EPOCHS):
        # Train with AMP, float16, channels_last, grad accum
        model.train()
        train_loss = 0.0
        optimizer.zero_grad(set_to_none=True)
        for batch_idx, (imgs, labels) in enumerate(train_loader):
            imgs = imgs.to(device, non_blocking=True).to(memory_format=torch.channels_last)
            labels = labels.to(device, non_blocking=True)
            with torch.cuda.amp.autocast(dtype=torch.float16):
                logits = model(imgs)
                loss = criterion(logits, labels) / ACCUM_STEPS
            scaler.scale(loss).backward()
            if (batch_idx + 1) % ACCUM_STEPS == 0:
                scaler.unscale_(optimizer)
                torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
                scaler.step(optimizer)
                scaler.update()
                optimizer.zero_grad(set_to_none=True)
            train_loss += loss.item() * ACCUM_STEPS
        train_loss /= len(train_loader)

        # Val (BCE loss only, no TTA here)
        model.eval()
        val_loss = 0.0
        with torch.no_grad():
            for imgs, lbls in val_loader:
                imgs = imgs.to(device, non_blocking=True).to(memory_format=torch.channels_last)
                lbls = lbls.to(device, non_blocking=True)
                with torch.cuda.amp.autocast(dtype=torch.float16):
                    logits = model(imgs)
                    loss = criterion(logits, lbls)
                val_loss += loss.item()
        val_loss /= len(val_loader)

        scheduler.step()
        print(f'Epoch {epoch+1}: Train {train_loss:.4f}, Val {val_loss:.4f}')

        if val_loss < best_val_loss:
            best_val_loss = val_loss
            patience_counter = 0
            best_model_state = model.state_dict().copy()
        else:
            patience_counter += 1
            if patience_counter >= PATIENCE:
                print(f'Early stopping at epoch {epoch+1}')
                break

    # Load best model
    if best_model_state is not None:
        model.load_state_dict(best_model_state)
        torch.save(model.state_dict(), f'fold_{fold}_regnet_axial_soft_fixed.pth')
    else:
        print(f'No best model for fold {fold}')
        continue

    # Collect true OOF logits with HFlip TTA for this fold's val
    model.eval()
    fold_oof_logits = []
    with torch.no_grad():
        for imgs, _ in val_loader:
            imgs = imgs.to(device, non_blocking=True).to(memory_format=torch.channels_last)
            with torch.cuda.amp.autocast(dtype=torch.float16):
                logits_orig = model(imgs)
                imgs_flip = torch.flip(imgs, dims=[3])  # HFlip
                logits_flip = model(imgs_flip)
                logits_avg = 0.5 * (logits_orig + logits_flip)
            fold_oof_logits.append(logits_avg.cpu().numpy())
    fold_oof = np.concatenate(fold_oof_logits, axis=0)
    oof_logits[val_idx] = fold_oof

    # Fold score (full WLL with patient_overall=max(p7)) - FIXED: Add labels=[0,1] to handle all-0 val sets
    p7 = sigmoid(fold_oof)
    fold_y7 = y[val_idx]
    fold_y_overall = fold_y7.max(axis=1).astype(int)
    p_overall = p7.max(axis=1)
    vert_losses = [log_loss(fold_y7[:, i], p7[:, i], labels=[0,1]) for i in range(7)]
    overall_loss = log_loss(fold_y_overall, p_overall, labels=[0,1])
    fold_score = np.average(vert_losses + [overall_loss], weights=[1]*7 + [2])
    fold_scores.append(fold_score)
    print(f'Fold {fold} full WLL: {fold_score:.4f}')

    # Cleanup
    del model, train_loader, val_loader, train_ds, val_ds, scaler
    gc.collect()
    torch.cuda.empty_cache()

# Save OOF logits (with TTA)
np.save('oof_logits_axial_soft_fixed.npy', oof_logits)

# Overall CV full WLL
cv_wll = np.mean(fold_scores)
print(f'5-fold CV full WLL: {cv_wll:.4f} (target ~0.45-0.48 standalone, leakage-free)')

# Also vertebrae-only for reference - FIXED: Add labels=[0,1]
p7_oof = sigmoid(oof_logits)
vert_losses = [log_loss(y[:, i], p7_oof[:, i], labels=[0,1]) for i in range(7)]
vert_wll = np.mean(vert_losses)
print(f'Vertebrae-only OOF WLL: {vert_wll:.4f}')

print('Leakage-fixed soft-tissue axial training complete. OOF logits (HFlip TTA) saved as oof_logits_axial_soft_fixed.npy. CV scores now reliable (expect higher than leaky 0.4759). Next: re-execute cell 6 in 03_inference_tta.ipynb with all fixed OOFs for true gating (baseline ~0.42 leakage-free, Swin accepted, check if axial/soft now >0.001 gain), await mip_prep finish, then edit 12_multi_channel_training.ipynb cell 0 (in_chans=6 load (6,H,W).npy transpose (H,W,6) norm [0.485]*6 B=1 VAL=1 ACCUM=8 checkpointing float16 channels_last SEED=791 fixed skf HFlip TTA dims=[3] save fold_regnet_multi_fixed.pth oof_logits_multi_fixed.npy ~2h), execute run_all, gate Multi, if passes update cell 1 in 03 for 5-way inference (add predict_multi in_chans=6 TTA X5 blend refit postproc), execute cell 1, submit.')

Using device: cuda

=== Fold 1/5 ===


  original_init(self, **validated_kwargs)


Epoch 1: Train 0.6049, Val 0.6117


Epoch 2: Train 0.4475, Val 0.5296


Epoch 3: Train 0.3755, Val 0.4277


Epoch 4: Train 0.3565, Val 0.4229


Epoch 5: Train 0.3821, Val 0.3593


Epoch 6: Train 0.3614, Val 0.3492


Epoch 7: Train 0.3296, Val 0.3502


Epoch 8: Train 0.3439, Val 0.3533


Epoch 9: Train 0.3341, Val 0.3452


Epoch 10: Train 0.3417, Val 0.3485


Epoch 11: Train 0.3319, Val 0.3461


Epoch 12: Train 0.3294, Val 0.3473


Epoch 13: Train 0.3421, Val 0.3490
Early stopping at epoch 13


Fold 1 full WLL: 0.5134

=== Fold 2/5 ===


Epoch 1: Train 0.6767, Val 0.6465


Epoch 2: Train 0.4824, Val 0.5237


Epoch 3: Train 0.4222, Val 0.4197


Epoch 4: Train 0.3875, Val 0.3618


Epoch 5: Train 0.3571, Val 0.3191


Epoch 6: Train 0.3509, Val 0.3741


Epoch 7: Train 0.3356, Val 0.3565


Epoch 8: Train 0.3470, Val 0.3021


Epoch 9: Train 0.3319, Val 0.3170


Epoch 10: Train 0.3637, Val 0.3059


Epoch 11: Train 0.3463, Val 0.3024


Epoch 12: Train 0.3587, Val 0.2992


Epoch 13: Train 0.3311, Val 0.3039


Epoch 14: Train 0.3405, Val 0.3034


Epoch 15: Train 0.3485, Val 0.3000


Fold 2 full WLL: 0.4558

=== Fold 3/5 ===


Epoch 1: Train 0.6382, Val 0.5957


Epoch 2: Train 0.4746, Val 0.5163


Epoch 3: Train 0.3883, Val 0.4677


Epoch 4: Train 0.3757, Val 0.3972


Epoch 5: Train 0.3479, Val 0.3717


Epoch 6: Train 0.3301, Val 0.3698


Epoch 7: Train 0.3494, Val 0.3701


Epoch 8: Train 0.3395, Val 0.3695


Epoch 9: Train 0.3376, Val 0.3737


Epoch 10: Train 0.3390, Val 0.3805


Epoch 11: Train 0.3397, Val 0.3774


Epoch 12: Train 0.3323, Val 0.3715
Early stopping at epoch 12


Fold 3 full WLL: 0.4771

=== Fold 4/5 ===


Epoch 1: Train 0.6504, Val 0.6075


Epoch 2: Train 0.4949, Val 0.4939


Epoch 3: Train 0.4324, Val 0.4497


Epoch 4: Train 0.3752, Val 0.3388


Epoch 5: Train 0.3483, Val 0.3202


Epoch 6: Train 0.3469, Val 0.2965


Epoch 7: Train 0.3528, Val 0.2985


Epoch 8: Train 0.3756, Val 0.2995


Epoch 9: Train 0.3338, Val 0.2965


Epoch 10: Train 0.3523, Val 0.2895


Epoch 11: Train 0.3435, Val 0.2899


Epoch 12: Train 0.3549, Val 0.2847


Epoch 13: Train 0.3673, Val 0.2892


Epoch 14: Train 0.3348, Val 0.2926


Epoch 15: Train 0.3350, Val 0.2913


ValueError: y_true contains only one label (0). Please provide the list of all expected class labels explicitly through the labels argument.