In [11]:
import os, subprocess
# Reduce fragmentation
os.environ['PYTORCH_CUDA_ALLOC_CONF'] = 'expandable_segments:True'

# Kill all other GPU users (keep current PID)
try:
    me = str(os.getpid())
    out = subprocess.check_output(
        ['nvidia-smi','--query-compute-apps=pid','--format=csv,noheader']
    ).decode().strip().splitlines()
    for pid in out:
        pid = pid.strip()
        if pid and pid != me:
            try: os.kill(int(pid), 9)
            except: pass
except Exception:
    pass

import subprocess
import sys
subprocess.check_call([sys.executable, '-m', 'pip', 'install', 'monai', 'scikit-image', 'pydicom', '-q'])
import numpy as np
import pandas as pd
from tqdm import tqdm
import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
from monai.networks.nets import ResNet
from monai.networks.layers import Norm
from sklearn.model_selection import StratifiedGroupKFold
from sklearn.metrics import log_loss
from scipy.special import expit as sigmoid
import gc
import SimpleITK as sitk
from scipy.ndimage import zoom
import pydicom
import copy
from torch.utils.checkpoint import checkpoint

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f'Using device: {device}')
torch.backends.cudnn.benchmark = True
# Safer kernels + TF32
try:
    torch.backends.cuda.enable_flash_sdp(False)
    torch.backends.cuda.enable_mem_efficient_sdp(False)
    torch.backends.cuda.enable_math_sdp(True)
except Exception:
    pass
torch.backends.cuda.matmul.allow_tf32 = True
torch.backends.cudnn.allow_tf32 = True

# Fixed-window normalization
def win_norm(v, c, w):
    lo, hi = c - w/2, c + w/2
    v = np.clip(v, lo, hi)
    return (v - lo) / (hi - lo + 1e-6)

# Precompute cached 2-channel volumes (soft + bone) with ROI crop
print('Checking/Precomputing 128^3 2-channel (soft+bone) fixed-window volumes with ROI...')
cache_dir = 'temp_3d_vols'
train_df = pd.read_csv('train.csv')
image_dir = 'train_images'
# Filter to imaged studies only
train_df = train_df[train_df['StudyInstanceUID'].apply(lambda u: os.path.isdir(os.path.join(image_dir, str(u))))].reset_index(drop=True)
label_cols = ['C1', 'C2', 'C3', 'C4', 'C5', 'C6', 'C7']
y = train_df[label_cols].values.astype(np.float32)
y_overall = y.max(axis=1).astype(np.float32)
precompute_needed = True
if os.path.exists(cache_dir):
    cached_files = set(f[:-4] for f in os.listdir(cache_dir) if f.endswith('.npy'))
    uids_set = set(train_df['StudyInstanceUID'])
    if cached_files == uids_set:
        print('Cache complete, skipping precompute.')
        precompute_needed = False
if precompute_needed:
    print('Precomputing...')
    os.system('rm -rf temp_3d_vols')
    os.makedirs('temp_3d_vols', exist_ok=True)
    for uid in tqdm(train_df['StudyInstanceUID']):
        uid_dir = os.path.join(image_dir, uid)
        if not os.path.exists(uid_dir):
            print(f'Skipping missing {uid}')
            continue
        series_reader = sitk.ImageSeriesReader()
        series_ids = series_reader.GetGDCMSeriesIDs(uid_dir)
        if not series_ids:
            print(f'No series for {uid}')
            continue
        # Prefer bone-kernel series
        best_dicom_names = None
        max_slices = 0
        bone_score = 0
        for sid in series_ids:
            dicom_names = series_reader.GetGDCMSeriesFileNames(uid_dir, sid)
            if len(dicom_names) < 50: continue  # Skip thin series if too few
            # Check first DICOM for kernel
            try:
                ds = pydicom.dcmread(dicom_names[0])
                kernel = ds.get((0x0018, 0x1210), '').value if hasattr(ds.get((0x0018, 0x1210), ''), 'value') else ''
                filter_type = ds.get((0x0028, 0x0060), '').value if hasattr(ds.get((0x0028, 0x0060), ''), 'value') else ''
                is_bone = ('BONE' in str(kernel).upper() or 'B70' in str(kernel) or 'B75' in str(kernel) or 'CB' in str(filter_type))
                score = 100 if is_bone else 0
            except:
                score = 0
            score += len(dicom_names)
            if score > bone_score:
                bone_score = score
                best_dicom_names = dicom_names
                max_slices = len(dicom_names)
        if best_dicom_names is None:
            continue
        reader = sitk.ImageSeriesReader()
        reader.SetFileNames(best_dicom_names)
        img = reader.Execute()
        # Resample to 1.0mm isotropic
        orig_spacing = np.array(img.GetSpacing())[::-1]  # (Z,Y,X)
        orig_size = np.array(img.GetSize())[::-1]  # (Z,Y,X)
        new_spacing = np.array([1.0, 1.0, 1.0])
        new_size = np.round(orig_size * orig_spacing / new_spacing).astype(int)
        resampler = sitk.ResampleImageFilter()
        resampler.SetOutputSpacing(new_spacing[::-1].tolist())
        resampler.SetSize(new_size[::-1].tolist())
        resampler.SetOutputOrigin(img.GetOrigin())
        resampler.SetOutputDirection(img.GetDirection())
        resampler.SetInterpolator(sitk.sitkLinear)
        resampled_img = resampler.Execute(img)
        vol_hu = sitk.GetArrayFromImage(resampled_img)  # (Z, Y, X)
        if vol_hu.size == 0:
            print(f'Empty volume for {uid}; skipping')
            continue
        # Cervical ROI crop using bone mask
        bone_mask = (vol_hu > 200).astype(np.float32)
        z_profile = bone_mask.sum(axis=(1,2))
        total_bone = z_profile.sum()
        if total_bone > 0:
            z_com = np.average(np.arange(len(z_profile)), weights=z_profile)
            start_z = max(0, int(z_com - 64))
            end_z = min(len(vol_hu), int(z_com + 64))
            vol_roi = vol_hu[start_z:end_z]
            if len(vol_roi) < 128:
                need = 128 - len(vol_roi)
                pad0, pad1 = need // 2, need - need // 2
                vol_roi = np.pad(vol_roi, ((pad0, pad1), (0,0), (0,0)), mode='constant', constant_values=0)
        else:
            # Fallback to central crop
            z_mid = len(vol_hu) // 2
            start_z = max(0, z_mid - 64)
            end_z = min(len(vol_hu), z_mid + 64)
            vol_roi = vol_hu[start_z:end_z]
            if len(vol_roi) < 128:
                need = 128 - len(vol_roi)
                pad0, pad1 = need // 2, need - need // 2
                vol_roi = np.pad(vol_roi, ((pad0, pad1), (0,0), (0,0)), mode='constant', constant_values=0)
        # Resize to 128^3
        target_size = (128, 128, 128)
        zoom_factors = [ts / s for ts, s in zip(target_size, vol_roi.shape)]
        vol_soft = zoom(win_norm(vol_roi, 40, 400), zoom_factors, order=1)
        vol_bone = zoom(win_norm(vol_roi, 300, 1500), zoom_factors, order=1)
        volume = np.stack([vol_soft, vol_bone], 0).astype(np.float32)  # (2, Z, Y, X)
        volume = np.ascontiguousarray(volume)
        np.save(os.path.join('temp_3d_vols', f'{uid}.npy'), volume)
# Enforce cache exists
have = train_df['StudyInstanceUID'].map(lambda u: os.path.exists(os.path.join(cache_dir, f'{u}.npy')))
train_df = train_df[have].reset_index(drop=True)
y = train_df[label_cols].values.astype(np.float32)
y_overall = y.max(axis=1).astype(np.float32)
np.save('oof_uids_3d.npy', train_df['StudyInstanceUID'].values)
print('Precompute complete. Filtered to', len(train_df), 'studies.')
class RSNA3DDataset(Dataset):
    def __init__(self, df, cache_dir, is_train=True):
        self.df = df.reset_index(drop=True)
        self.cache_dir = cache_dir
        self.is_train = is_train
        self.label_cols = label_cols
    def __len__(self):
        return len(self.df)
    def __getitem__(self, idx):
        row = self.df.iloc[idx]
        uid = row['StudyInstanceUID']
        volume_path = os.path.join(self.cache_dir, f'{uid}.npy')
        volume = np.ascontiguousarray(np.load(volume_path))
        if self.is_train:
            # existing flips/rot90
            if np.random.rand() < 0.5:
                volume = np.ascontiguousarray(np.flip(volume, axis=2))  # Y (axis=2 for (C,Z,Y,X))
            if np.random.rand() < 0.5:
                volume = np.ascontiguousarray(np.flip(volume, axis=3))  # X (axis=3)
            if np.random.rand() < 0.5:
                k = np.random.randint(1, 4)
                volume = np.ascontiguousarray(np.rot90(volume, k=k, axes=(2, 3)))  # Y, X
            # z-shift along Z (axis=1 for (C,Z,Y,X))
            if np.random.rand() < 0.5:
                dz = np.random.randint(-8, 9)
                volume = np.roll(volume, shift=dz, axis=1)
            # intensity jitter
            if np.random.rand() < 0.8:
                alpha = 1.0 + np.random.uniform(-0.10, 0.10)
                beta  = np.random.uniform(-0.05, 0.05)
                volume = np.clip(volume * alpha + beta, 0.0, 1.0)
            # tiny Gaussian noise
            if np.random.rand() < 0.5:
                noise = np.random.normal(0, 0.01, size=volume.shape).astype(np.float32)
                volume = np.clip(volume + noise, 0.0, 1.0)
        volume = np.ascontiguousarray(volume)
        volume = volume.copy()
        volume = torch.from_numpy(volume).float()
        labels7 = row[self.label_cols].values.astype(np.float32)
        any_label = np.float32(labels7.max())
        labels = torch.tensor(np.concatenate([labels7, [any_label]]), dtype=torch.float32)
        return {'image': volume, 'label': labels}

# Helper to replace BatchNorm with GroupNorm
def replace_bn_with_gn(module, default_groups=8):
    for name, child in list(module.named_children()):
        if isinstance(child, (nn.BatchNorm3d, nn.BatchNorm2d, nn.BatchNorm1d)):
            C = child.num_features
            g = min(default_groups, C)
            while g > 1 and C % g != 0:
                g -= 1
            setattr(module, name, nn.GroupNorm(g, C, eps=child.eps, affine=True))
        else:
            replace_bn_with_gn(child, default_groups)

# Helper to replace BatchNorm with InstanceNorm
def replace_bn_with_in(module):
    for name, child in list(module.named_children()):
        if isinstance(child, nn.BatchNorm3d):
            C = child.num_features
            inorm = nn.InstanceNorm3d(C, eps=child.eps, momentum=child.momentum, affine=True, track_running_stats=False)
            setattr(module, name, inorm)
        else:
            replace_bn_with_in(child)

# Checkpointed ResNet
class CheckpointedResNet(ResNet):
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x = self.conv1(x); x = self.bn1(x); x = self.act(x)
        if not self.no_max_pool: x = self.maxpool(x)
        x = checkpoint(self.layer1, x, use_reentrant=False)
        x = checkpoint(self.layer2, x, use_reentrant=False)
        x = checkpoint(self.layer3, x, use_reentrant=False)
        x = checkpoint(self.layer4, x, use_reentrant=False)
        x = self.avgpool(x); x = torch.flatten(x, 1); x = self.fc(x)
        return x

# Model: MONAI 3D ResNet18 with GroupNorm3d (affine=True) and dropout
def build_3d_resnet(use_gn=True):
    model = CheckpointedResNet(
        spatial_dims=3,
        n_input_channels=2,
        num_classes=8,                        # 8 outputs (C1..C7 + any)
        block='basic',
        layers=(2, 2, 2, 2),                 # ResNet18 depth
        block_inplanes=(48, 96, 192, 384),  # reduced for memory
        norm=Norm.BATCH,
    )
    if use_gn:
        replace_bn_with_gn(model, default_groups=8)   # recommended
    else:
        replace_bn_with_in(model)                     # InstanceNorm(affine=True)
    in_features = model.fc.in_features
    model.fc = nn.Sequential(nn.Dropout(0.2), nn.Linear(in_features, 8))
    model = model.to(device)  # move AFTER replacements
    return model

# Training params (ACCUM_STEPS=8 for memory safety)
N_FOLDS = 5
BATCH_SIZE = 1
VAL_BATCH_SIZE = 1
NUM_EPOCHS = 40
LR = 3e-4
PATIENCE = 12
SEED = 792
ACCUM_STEPS = 8
torch.manual_seed(SEED)
np.random.seed(SEED)

# Data
y_overall = y.max(axis=1).astype(int)
groups = train_df['StudyInstanceUID'].values
cache_dir = 'temp_3d_vols'

# Compute clamped pos_weight
col_sums7 = y.sum(axis=0)
pos_weight7 = (len(train_df) - col_sums7) / np.clip(col_sums7, 1, None)
pos_weight7 = np.clip(pos_weight7, 1, 4)
pos_any = (len(train_df) - y_overall.sum()) / max(y_overall.sum(), 1.0)
pos_any = np.clip(pos_any, 1, 4)
pos_weight = torch.tensor(np.concatenate([pos_weight7, [pos_any]]), dtype=torch.float32).to(device)

# SGKF
skf = StratifiedGroupKFold(n_splits=N_FOLDS, shuffle=True, random_state=SEED)

# OOF logits
oof_logits = np.zeros((len(train_df), 8), dtype=np.float32)
fold_scores = []

# LR lambda for warmup + cosine
def lr_lambda(epoch):
    if epoch < 5:
        return (epoch + 1) / 5.0
    else:
        return 0.5 * (1 + np.cos(np.pi * (epoch - 5) / max(1, NUM_EPOCHS - 5)))

for fold in range(1, N_FOLDS + 1):
    print(f'\n=== Fold {fold}/{N_FOLDS} ===')
    train_idx, val_idx = list(skf.split(train_df, y_overall, groups))[fold - 1]
    train_ds = RSNA3DDataset(train_df.iloc[train_idx], cache_dir, is_train=True)
    val_ds = RSNA3DDataset(train_df.iloc[val_idx], cache_dir, is_train=False)
    train_loader = DataLoader(train_ds, batch_size=BATCH_SIZE, shuffle=True, num_workers=0, pin_memory=True)
    val_loader = DataLoader(val_ds, batch_size=VAL_BATCH_SIZE, shuffle=False, num_workers=0, pin_memory=True)
    model = build_3d_resnet(use_gn=True)
    optimizer = torch.optim.AdamW(model.parameters(), lr=LR, weight_decay=1e-5)
    scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda)
    criterion = nn.BCEWithLogitsLoss(pos_weight=pos_weight)
    scaler = None  # No GradScaler for bfloat16
    best_val_loss = float('inf')
    patience_counter = 0
    best_model_state = None
    for epoch in range(NUM_EPOCHS):
        model.train()
        train_loss = 0.0
        optimizer.zero_grad(set_to_none=True)
        for batch_idx, batch in enumerate(train_loader):
            images = batch['image'].to(device, non_blocking=True)
            labels = batch['label'].to(device, non_blocking=True)
            with torch.amp.autocast('cuda', dtype=torch.bfloat16):
                logits = model(images)
                loss = criterion(logits, labels) / ACCUM_STEPS
            loss.backward()
            if (batch_idx + 1) % ACCUM_STEPS == 0 or (batch_idx + 1) == len(train_loader):
                torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
                optimizer.step()
                optimizer.zero_grad(set_to_none=True)
            train_loss += loss.item() * ACCUM_STEPS
        train_loss /= len(train_loader)
        model.eval()
        val_loss = 0.0
        with torch.no_grad():
            for batch in val_loader:
                images = batch['image'].to(device, non_blocking=True)
                labels = batch['label'].to(device, non_blocking=True)
                with torch.amp.autocast('cuda', dtype=torch.bfloat16):
                    logits = model(images)
                    loss = criterion(logits, labels)
                val_loss += loss.item()
        val_loss /= len(val_loader)
        scheduler.step()
        print(f'Epoch {epoch+1}: Train {train_loss:.4f}, Val {val_loss:.4f}')
        if val_loss < best_val_loss:
            best_val_loss = val_loss
            patience_counter = 0
            best_model_state = copy.deepcopy({k: v.detach().cpu() for k, v in model.state_dict().items()})
        else:
            patience_counter += 1
            if patience_counter >= PATIENCE:
                print(f'Early stopping at epoch {epoch+1}')
                break
    if best_model_state is not None:
        model.load_state_dict(best_model_state)
        torch.save(model.state_dict(), f'fold_{fold}_3d_resnet.pth')
    else:
        print(f'No best model for fold {fold}')
        continue
    # OOF with 3-way TTA: orig + flips on dims 3,4 (Y,X)
    model.eval()
    fold_oof_logits = []
    with torch.no_grad():
        for batch in val_loader:
            images = batch['image'].to(device, non_blocking=True)
            with torch.amp.autocast('cuda', dtype=torch.bfloat16):
                logits_list = [model(images)]
                # Flip Y (dim=3)
                logits_list.append(model(torch.flip(images, dims=[3])))
                # Flip X (dim=4)
                logits_list.append(model(torch.flip(images, dims=[4])))
                logits_avg = torch.stack(logits_list).mean(0)
            fold_oof_logits.append(logits_avg.cpu().numpy())
    fold_oof = np.concatenate(fold_oof_logits, axis=0)
    oof_logits[val_idx] = fold_oof
    # Fold WLL
    p8 = sigmoid(fold_oof)
    fold_y7 = y[val_idx]
    fold_y_any = y_overall[val_idx].astype(int)
    vert_losses = [log_loss(fold_y7[:, i], p8[:, i], labels=[0,1]) for i in range(7)]
    overall_loss = log_loss(fold_y_any, p8[:, 7], labels=[0,1])
    fold_score = np.average(vert_losses + [overall_loss], weights=[1]*7 + [2])
    fold_scores.append(fold_score)
    print(f'Fold {fold} full WLL: {fold_score:.4f}')
    del model, train_loader, val_loader, train_ds, val_ds
    gc.collect()
    torch.cuda.empty_cache()

# Save OOF
np.save('oof_logits_3d_2ch.npy', oof_logits)

# CV WLL
cv_wll = np.mean(fold_scores)
print(f'5-fold CV full WLL: {cv_wll:.4f} (target ~0.45, leakage-free)')
p8_oof = sigmoid(oof_logits)
vert_losses = [log_loss(y[:, i], p8_oof[:, i], labels=[0,1]) for i in range(7)]
overall_loss = log_loss(y_overall.astype(int), p8_oof[:, 7], labels=[0,1])
cv_wll = np.average(vert_losses + [overall_loss], weights=[1]*7 + [2])
print('Vertebrae-only OOF WLL:', np.mean(vert_losses))
print('''Improved 3D ResNet18 2ch with GroupNorm3d (affine=True), ROI crop, 1.0mm resample, bone series pref, clamped pos_weight, LR warmup+cosine, dropout=0.2, 8-heads, reduced channels, augs, bfloat16, checkpointing complete. OOF oof_logits_3d_2ch.npy (N,8). Next: execute 03 cell 8 gate vs 4-way 0.4197 (expect CV~0.44 gain>0.001 accept), if yes execute cell 1 5-way submission OOF<=0.41, submit_final_answer medal.''')

ERROR: pip's dependency resolver does not currently take into account all the packages that are installed. This behaviour is the source of the following dependency conflicts.
torchaudio 2.4.1+cu121 requires torch==2.4.1, but you have torch 2.8.0 which is incompatible.
datasets 2.21.0 requires fsspec[http]<=2024.6.1,>=2023.1.0, but you have fsspec 2025.9.0 which is incompatible.


Using device: cuda
Checking/Precomputing 128^3 2-channel (soft+bone) fixed-window volumes with ROI...
Cache complete, skipping precompute.
Precompute complete. Filtered to 202 studies.

=== Fold 1/5 ===


OutOfMemoryError: CUDA out of memory. Tried to allocate 384.00 MiB. GPU 0 has a total capacity of 23.72 GiB of which 398.12 MiB is free. Process 42479 has 3.88 GiB memory in use. Process 343136 has 565.00 MiB memory in use. Process 388383 has 6.64 GiB memory in use. Process 406842 has 2.90 GiB memory in use. Process 431889 has 539.00 MiB memory in use. Process 455298 has 1.20 GiB memory in use. Process 461173 has 741.00 MiB memory in use. Process 728162 has 589.00 MiB memory in use. Process 977910 has 833.00 MiB memory in use. Process 1198413 has 407.00 MiB memory in use. Process 1472380 has 2.82 GiB memory in use. Of the allocated memory 2.26 GiB is allocated by PyTorch, and 221.64 MiB is reserved by PyTorch but unallocated. If reserved but unallocated memory is large try setting PYTORCH_CUDA_ALLOC_CONF=expandable_segments:True to avoid fragmentation.  See documentation for Memory Management  (https://pytorch.org/docs/stable/notes/cuda.html#environment-variables)