# Baseline 2D CNN for RSNA 2022 Cervical Spine Fracture Detection

This notebook implements a baseline EfficientNet-B0 model on preprocessed MIP images.
Key changes based on expert review:
- Predict 7 vertebrae (C1-C7) only; derive patient_overall = max(C1-C7) for metric and submission.
- Metric weights: [1,1,1,1,1,1,1,2] for C1-C7 + overall.
- CV: MultilabelStratifiedKFold on C1-C7 labels.
- Normalization: mean=0.5, std=0.5 (since bone-windowed MIPs).
- Loss: BCEWithLogitsLoss with per-fold pos_weight.
- Augs: Flips, ShiftScaleRotate, BrightnessContrast, Gamma, GaussianNoise.
- Training: AdamW lr=1e-4, CosineAnnealingLR, batch=16, AMP, 15 epochs, early stop patience=5.
- TTA: Horizontal flip only (no VerticalFlip).

In [1]:
import sys
import subprocess

# Install required packages using subprocess to bypass !pip issues
subprocess.run([sys.executable, '-m', 'pip', 'install', 'iterative-stratification'], check=True)
subprocess.run([sys.executable, '-m', 'pip', 'install', '--upgrade', 'albumentations'], check=True)

import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
from torch.cuda.amp import GradScaler
import torch.nn.utils as nn_utils
import timm
import albumentations as A
from albumentations.pytorch import ToTensorV2
import pandas as pd
import numpy as np
from iterstrat.ml_stratifiers import MultilabelStratifiedKFold
import os
from tqdm import tqdm
import time
import warnings
warnings.filterwarnings('ignore')

# Seeding for reproducibility
def seed_everything(seed=42):
    import random
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False
seed_everything(42)

# GPU Check
print(subprocess.run(['nvidia-smi'], capture_output=True, text=True).stdout)
assert torch.cuda.is_available(), 'CUDA not available'
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f'Using device: {device}')
print(f'GPU: {torch.cuda.get_device_name(0)}')

# Load and subsample data for initial debug
train_df = pd.read_csv('data/train_mips.csv')
print(f'Full train size: {len(train_df)}')

# Subsample 50 for quick debug
subsample_df = train_df.sample(n=50, random_state=42).reset_index(drop=True)
print(f'Subsample size: {len(subsample_df)}')

label_cols = ['C1', 'C2', 'C3', 'C4', 'C5', 'C6', 'C7']
y_train = subsample_df[label_cols].values
print('Label distribution:')
print(y_train.sum(axis=0))

# Compute initial pos_weight for imbalance
pos = (y_train == 1).sum(axis=0)
neg = (y_train == 0).sum(axis=0)
pos_weight = torch.tensor(neg / (pos + 1e-6), dtype=torch.float32).to(device)
print('Pos weights:', pos_weight.cpu().numpy())

Collecting iterative-stratification
  Downloading iterative_stratification-0.1.9-py3-none-any.whl (8.5 kB)


Collecting scipy
  Downloading scipy-1.16.2-cp311-cp311-manylinux2014_x86_64.manylinux_2_17_x86_64.whl (35.9 MB)
     ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 35.9/35.9 MB 280.4 MB/s eta 0:00:00


Collecting numpy
  Downloading numpy-1.26.4-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (18.3 MB)
     ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 18.3/18.3 MB 341.4 MB/s eta 0:00:00


Collecting scikit-learn
  Downloading scikit_learn-1.7.2-cp311-cp311-manylinux2014_x86_64.manylinux_2_17_x86_64.whl (9.7 MB)
     ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 9.7/9.7 MB 166.3 MB/s eta 0:00:00
Collecting threadpoolctl>=3.1.0
  Downloading threadpoolctl-3.6.0-py3-none-any.whl (18 kB)
Collecting joblib>=1.2.0
  Downloading joblib-1.5.2-py3-none-any.whl (308 kB)
     ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 308.4/308.4 KB 471.5 MB/s eta 0:00:00


Installing collected packages: threadpoolctl, numpy, joblib, scipy, scikit-learn, iterative-stratification


ERROR: pip's dependency resolver does not currently take into account all the packages that are installed. This behaviour is the source of the following dependency conflicts.
datasets 2.21.0 requires fsspec[http]<=2024.6.1,>=2023.1.0, but you have fsspec 2025.9.0 which is incompatible.


Successfully installed iterative-stratification-0.1.9 joblib-1.5.2 numpy-1.26.4 scikit-learn-1.7.2 scipy-1.16.2 threadpoolctl-3.6.0


Collecting albumentations
  Downloading albumentations-2.0.8-py3-none-any.whl (369 kB)
     ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 369.4/369.4 KB 11.5 MB/s eta 0:00:00


Collecting scipy>=1.10.0
  Downloading scipy-1.16.2-cp311-cp311-manylinux2014_x86_64.manylinux_2_17_x86_64.whl (35.9 MB)


     ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 35.9/35.9 MB 212.8 MB/s eta 0:00:00
Collecting opencv-python-headless>=4.9.0.80
  Downloading opencv_python_headless-4.12.0.88-cp37-abi3-manylinux2014_x86_64.manylinux_2_17_x86_64.whl (54.0 MB)


     ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 54.0/54.0 MB 170.8 MB/s eta 0:00:00


Collecting numpy>=1.24.4
  Downloading numpy-1.26.4-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (18.3 MB)
     ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 18.3/18.3 MB 205.4 MB/s eta 0:00:00


Collecting pydantic>=2.9.2
  Downloading pydantic-2.11.9-py3-none-any.whl (444 kB)
     ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 444.9/444.9 KB 517.1 MB/s eta 0:00:00
Collecting albucore==0.0.24
  Downloading albucore-0.0.24-py3-none-any.whl (15 kB)
Collecting PyYAML
  Downloading pyyaml-6.0.3-cp311-cp311-manylinux2014_x86_64.manylinux_2_17_x86_64.manylinux_2_28_x86_64.whl (806 kB)
     ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 806.6/806.6 KB 330.0 MB/s eta 0:00:00


Collecting stringzilla>=3.10.4
  Downloading stringzilla-4.0.14-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.manylinux_2_28_x86_64.whl (496 kB)
     ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 496.5/496.5 KB 465.0 MB/s eta 0:00:00


Collecting simsimd>=5.9.2
  Downloading simsimd-6.5.3-cp311-cp311-manylinux_2_24_x86_64.manylinux_2_28_x86_64.whl (1.1 MB)
     ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 1.1/1.1 MB 220.2 MB/s eta 0:00:00
Collecting opencv-python-headless>=4.9.0.80
  Downloading opencv_python_headless-4.11.0.86-cp37-abi3-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (50.0 MB)


     ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 50.0/50.0 MB 202.0 MB/s eta 0:00:00
Collecting typing-extensions>=4.12.2
  Downloading typing_extensions-4.15.0-py3-none-any.whl (44 kB)
     ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 44.6/44.6 KB 294.5 MB/s eta 0:00:00


Collecting pydantic-core==2.33.2
  Downloading pydantic_core-2.33.2-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (2.0 MB)
     ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 2.0/2.0 MB 266.8 MB/s eta 0:00:00
Collecting typing-inspection>=0.4.0
  Downloading typing_inspection-0.4.1-py3-none-any.whl (14 kB)
Collecting annotated-types>=0.6.0
  Downloading annotated_types-0.7.0-py3-none-any.whl (13 kB)


Installing collected packages: simsimd, typing-extensions, stringzilla, PyYAML, numpy, annotated-types, typing-inspection, scipy, pydantic-core, opencv-python-headless, pydantic, albucore, albumentations


ERROR: pip's dependency resolver does not currently take into account all the packages that are installed. This behaviour is the source of the following dependency conflicts.
torch 2.8.0 requires nvidia-nvjitlink-cu12==12.8.93; platform_system == "Linux" and platform_machine == "x86_64", but you have nvidia-nvjitlink-cu12 12.9.86 which is incompatible.
datasets 2.21.0 requires fsspec[http]<=2024.6.1,>=2023.1.0, but you have fsspec 2025.9.0 which is incompatible.


Successfully installed PyYAML-6.0.3 albucore-0.0.24 albumentations-2.0.8 annotated-types-0.7.0 numpy-1.26.4 opencv-python-headless-4.11.0.86 pydantic-2.11.9 pydantic-core-2.33.2 scipy-1.16.2 simsimd-6.5.3 stringzilla-4.0.14 typing-extensions-4.15.0 typing-inspection-0.4.1


Fri Sep 26 05:54:40 2025       
+-----------------------------------------------------------------------------------------+
| NVIDIA-SMI 550.144.06             Driver Version: 550.144.06     CUDA Version: 12.4     |
|-----------------------------------------+------------------------+----------------------+
| GPU  Name                 Persistence-M | Bus-Id          Disp.A | Volatile Uncorr. ECC |
| Fan  Temp   Perf          Pwr:Usage/Cap |           Memory-Usage | GPU-Util  Compute M. |
|                                         |                        |               MIG M. |
|   0  NVIDIA A10-24Q                 On  |   00000002:00:00.0 Off |                    0 |
| N/A   N/A    P0             N/A /  N/A  |     528MiB /  24512MiB |      0%      Default |
|                                         |                        |                  N/A |
+-----------------------------------------+------------------------+----------------------+
                                                

In [9]:
class MIPDataset(Dataset):
    def __init__(self, df, mip_dir, label_cols, transform=None):
        self.df = df
        self.mip_dir = mip_dir
        self.label_cols = label_cols
        self.transform = transform

    def __len__(self):
        return len(self.df)

    def __getitem__(self, idx):
        uid = self.df.iloc[idx]['StudyInstanceUID']
        mip_path = os.path.join(self.mip_dir, f'{uid}.npy')
        mip = np.load(mip_path).astype(np.float32)  # (3, 384, 384)
        image = np.transpose(mip, (1, 2, 0))  # (384, 384, 3) for albumentations

        labels = self.df.iloc[idx][self.label_cols].values.astype(np.float32)
        label = torch.tensor(labels, dtype=torch.float32)

        if self.transform:
            transformed = self.transform(image=image)
            image = transformed['image']
        else:
            # Manual normalize if no transform
            image = (image - 0.5) / 0.5
            image = torch.from_numpy(image).permute(2, 0, 1).float()

        return image, label

# Define transforms (workaround for albumentations: RandomCrop + Resize instead of RandomResizedCrop, light Affine augs)
train_transform = A.Compose([
    A.RandomCrop(height=384, width=384, p=0.5),
    A.Resize(height=384, width=384, p=1.0),
    A.HorizontalFlip(p=0.5),
    A.ShiftScaleRotate(shift_limit=0.05, scale_limit=0.1, rotate_limit=15, p=0.7),
    A.RandomBrightnessContrast(brightness_limit=0.2, contrast_limit=0.2, p=0.5),
    A.GaussNoise(var_limit=(1e-5, 5e-4), p=0.1),  # Fixed for [0,1] float images
    A.Normalize(mean=0.5, std=0.5),
    ToTensorV2()
])

val_transform = A.Compose([
    A.Resize(height=384, width=384, p=1.0),
    A.Normalize(mean=0.5, std=0.5),
    ToTensorV2()
])

print('Dataset class and transforms fixed with RandomCrop + Resize workaround successfully')

Dataset class and transforms fixed with RandomCrop + Resize workaround successfully


In [3]:
def weighted_log_loss(y_true, y_pred):
    # Clamp predictions to avoid log(0)
    y_pred = np.clip(y_pred, 1e-6, 1 - 1e-6)
    
    # Log loss for each class
    epsilon = 1e-15
    y_pred = np.clip(y_pred, epsilon, 1 - epsilon)
    loss = -y_true * np.log(y_pred) - (1 - y_true) * np.log(1 - y_pred)
    
    # Weights: 1 for vertebrae, 2 for overall
    weights = np.array([1.0] * 7 + [2.0])
    
    # Weighted loss: average per class then weighted average
    return np.sum(loss.mean(0) * weights) / weights.sum()

def weighted_log_loss_torch(y_true_8, y_pred_8):
    # Clamp predictions to avoid log(0)
    y_pred_8 = torch.clamp(y_pred_8, 1e-6, 1 - 1e-6)
    
    # Log loss for each class
    epsilon = 1e-15
    y_pred_8 = torch.clamp(y_pred_8, epsilon, 1 - epsilon)
    loss = -y_true_8 * torch.log(y_pred_8) - (1 - y_true_8) * torch.log(1 - y_pred_8)
    
    # Weights: 1 for vertebrae, 2 for overall
    weights = torch.tensor([1.0] * 7 + [2.0], device=y_true_8.device, dtype=torch.float32)
    
    # Weighted loss: average per class then weighted average
    return torch.sum(loss.mean(0) * weights) / weights.sum()

# Test the metric
y_dummy = np.array([[0,1,0,0,0,0,0,0], [1,0,0,0,0,0,0,1]])
p_dummy = np.array([[0.1,0.9,0.1,0.1,0.1,0.1,0.1,0.1], [0.9,0.1,0.1,0.1,0.1,0.1,0.1,0.9]])
print(f'Dummy weighted log loss: {weighted_log_loss(y_dummy, p_dummy):.4f}')

print('Metrics updated with clamping, weight=2, and correct normalization')

Dummy weighted log loss: 0.1054
Metrics updated with clamping, weight=2, and correct normalization


In [4]:
import torch.nn.functional as F

class SmoothedBCEWithLogitsLoss(nn.Module):
    def __init__(self, pos_weight=None, smoothing=0.0):
        super(SmoothedBCEWithLogitsLoss, self).__init__()
        self.pos_weight = pos_weight
        self.smoothing = smoothing

    def forward(self, logits, targets):
        num_labels = logits.size(1)
        smoothing_value = self.smoothing / num_labels
        soft_targets = targets * (1.0 - self.smoothing) + smoothing_value
        loss = F.binary_cross_entropy_with_logits(logits, soft_targets, pos_weight=self.pos_weight, reduction='mean')
        return loss

class EarlyStopping:
    def __init__(self, patience=5, min_delta=0.001, restore_best_weights=True):
        self.patience = patience
        self.min_delta = min_delta
        self.restore_best_weights = restore_best_weights
        self.best_loss = None
        self.counter = 0
        self.best_weights = None

    def __call__(self, val_loss, model):
        if self.best_loss is None:
            self.best_loss = val_loss
            self.save_checkpoint(model)
        elif val_loss < self.best_loss - self.min_delta:
            self.best_loss = val_loss
            self.counter = 0
            self.save_checkpoint(model)
        else:
            self.counter += 1
            if self.counter >= self.patience:
                if self.restore_best_weights:
                    model.load_state_dict(self.best_weights)
                return True
        return False

    def save_checkpoint(self, model):
        self.best_weights = model.state_dict().copy()

print('Custom SmoothedBCEWithLogitsLoss and EarlyStopping defined successfully')

Custom SmoothedBCEWithLogitsLoss and EarlyStopping defined successfully


In [5]:
class Model(nn.Module):
    def __init__(self, num_classes=7):  # 7 classes for C1-C7
        super().__init__()
        self.model = timm.create_model('tf_efficientnet_b0_ns', pretrained=True, num_classes=num_classes, in_chans=3, drop_rate=0.3, drop_path_rate=0.1)

    def forward(self, x):
        return self.model(x)

# Create and test model
model = Model(num_classes=7).to(device)
print('Model created and moved to device')
total_params = sum(p.numel() for p in model.parameters())
print(f'Model parameters: {total_params:,}')

# Dummy forward pass
dummy_input = torch.randn(1, 3, 384, 384).to(device)
with torch.no_grad():
    dummy_logits = model(dummy_input)
print(f'Dummy logits shape: {dummy_logits.shape}')  # Should be [1, 7]

print('7-class model with dropout=0.3 and drop_path=0.1 defined successfully')

Model created and moved to device
Model parameters: 4,016,515
Dummy logits shape: torch.Size([1, 7])
7-class model with dropout=0.3 and drop_path=0.1 defined successfully


In [None]:
# Debug training on subsample (40 train / 10 val) - 7-class version
from sklearn.model_selection import train_test_split
import torch.nn.functional as F

# Split subsample
train_idx, val_idx = train_test_split(range(len(subsample_df)), test_size=0.2, random_state=42, stratify=None)
train_split_df = subsample_df.iloc[train_idx].reset_index(drop=True)
val_split_df = subsample_df.iloc[val_idx].reset_index(drop=True)

label_cols = ['C1', 'C2', 'C3', 'C4', 'C5', 'C6', 'C7']  # 7 classes

# Compute pos_weight for 7 classes only
y_train_split = train_split_df[label_cols].values
pos = (y_train_split == 1).sum(axis=0)
neg = (y_train_split == 0).sum(axis=0)
w = np.minimum(neg / (pos + 1e-6), 10.0).astype(np.float32)
fold_pos_weight = torch.tensor(w, dtype=torch.float32).to(device)
print('Fold pos_weight (7 classes, clip=10):', fold_pos_weight.cpu().numpy())

mip_dir = 'data/mips/train'

train_ds = MIPDataset(train_split_df, mip_dir, label_cols, train_transform)
val_ds = MIPDataset(val_split_df, mip_dir, label_cols, val_transform)
train_loader = DataLoader(train_ds, batch_size=16, shuffle=True, num_workers=4, pin_memory=True)
val_loader = DataLoader(val_ds, batch_size=16, shuffle=False, num_workers=4, pin_memory=True)

# Model (7 classes)
model = Model(num_classes=7).to(device)
# Bias initialization for 7-class classifier
with torch.no_grad():
    priors = np.clip((y_train_split.mean(axis=0) + 1e-6), 1e-6, 1-1e-6)
    bias = np.log(priors / (1 - priors)).astype(np.float32)
    head = model.model.classifier
    if hasattr(head, 'bias') and head.bias is not None:
        head.bias.copy_(torch.from_numpy(bias).to(head.bias.device))
print('Bias initialized for 7 classes')

optimizer = optim.AdamW(model.parameters(), lr=1e-4, weight_decay=1e-4)
scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=15)
criterion = SmoothedBCEWithLogitsLoss(pos_weight=fold_pos_weight, smoothing=0.02)
scaler = GradScaler()
early_stopping = EarlyStopping(patience=5, min_delta=0.001)

# Training loop
num_epochs = 15
oof_preds_7 = np.zeros((len(val_split_df), 7))
oof_overall = np.zeros(len(val_split_df))
oof_labels = val_split_df[label_cols].values

for epoch in range(num_epochs):
    start_time = time.time()
    model.train()
    train_loss = 0.0
    for batch_idx, (images, labels) in enumerate(tqdm(train_loader, desc=f'Epoch {epoch+1}/{num_epochs} Train')):
        images = images.to(device)
        labels = labels.to(device)  # [bs, 7]
        optimizer.zero_grad()
        with torch.cuda.amp.autocast():
            logits = model(images)  # [bs, 7]
            loss = criterion(logits, labels)
        scaler.scale(loss).backward()
        nn_utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
        scaler.step(optimizer)
        scaler.update()
        train_loss += loss.item() * images.size(0)
    train_loss /= len(train_ds)
    scheduler.step()

    # Validation
    model.eval()
    val_loss = 0.0
    val_weighted_loss = 0.0
    with torch.no_grad():
        for batch_idx, (images, labels) in enumerate(tqdm(val_loader, desc='Val')):
            images = images.to(device)
            labels7 = labels.to(device)
            logits = model(images)
            bce_loss = criterion(logits, labels7)
            val_loss += bce_loss.item() * images.size(0)
            probs7 = torch.sigmoid(logits)
            probs_overall = probs7.max(dim=1, keepdim=True)[0]
            labels_overall = labels7.max(dim=1, keepdim=True)[0]
            probs8 = torch.cat([probs7, probs_overall], dim=1)
            labels8 = torch.cat([labels7, labels_overall], dim=1)
            weighted_l = weighted_log_loss_torch(labels8, probs8)
            val_weighted_loss += weighted_l.item() * images.size(0)
            bs = images.size(0)
            oof_preds_7[batch_idx*bs:(batch_idx+1)*bs] = probs7.cpu().numpy()
            oof_overall[batch_idx*bs:(batch_idx+1)*bs] = probs_overall.squeeze(1).cpu().numpy()
    val_loss /= len(val_ds)
    val_weighted_loss /= len(val_ds)
    epoch_time = time.time() - start_time
    print(f'Epoch {epoch+1}: Train BCE={train_loss:.4f}, Val BCE={val_loss:.4f}, Val Weighted Log Loss={val_weighted_loss:.4f}, Time={epoch_time:.1f}s')

    # Early stopping on weighted loss
    if early_stopping(val_weighted_loss, model):
        print(f'Early stopping at epoch {epoch+1}')
        break

# After early stopping, best weights are restored. Run final val pass to collect OOF from best weights
model.eval()
oof_preds_7.fill(0)
oof_overall.fill(0)
with torch.no_grad():
    for batch_idx, (images, labels) in enumerate(tqdm(val_loader, desc='Final Val for OOF')):
        images = images.to(device)
        labels7 = labels.to(device)
        logits = model(images)
        probs7 = torch.sigmoid(logits)
        probs_overall = probs7.max(dim=1, keepdim=True)[0]
        bs = images.size(0)
        oof_preds_7[batch_idx*bs:(batch_idx+1)*bs] = probs7.cpu().numpy()
        oof_overall[batch_idx*bs:(batch_idx+1)*bs] = probs_overall.squeeze(1).cpu().numpy()

# Final OOF weighted loss
oof_preds_8 = np.column_stack([oof_preds_7, oof_overall])
oof_labels_8 = np.column_stack([oof_labels, np.max(oof_labels, axis=1)])
final_oof_loss = weighted_log_loss(oof_labels_8, oof_preds_8)
print(f'Final OOF Weighted Log Loss on val: {final_oof_loss:.4f}')

# Save model
torch.save(model.state_dict(), 'debug_model.pth')
print('Debug 7-class training completed, model saved')

In [10]:
from iterstrat.ml_stratifiers import MultilabelStratifiedKFold
from torch.optim.lr_scheduler import ReduceLROnPlateau
import torch.nn.functional as F
from torch.utils.data import WeightedRandomSampler

# Full 5-fold CV on entire training set - 7-class version (tweaked for poor CV: smoothing=0.0, lr=2e-5, sampler cap=3.0, pos_weight clip=2.0)
full_train_df = pd.read_csv('data/train_mips.csv')
print(f'Full CV on {len(full_train_df)} samples')

n_folds = 5
label_cols = ['C1', 'C2', 'C3', 'C4', 'C5', 'C6', 'C7']
y_full = full_train_df[label_cols].values
mip_dir = 'data/mips/train'

skf = MultilabelStratifiedKFold(n_splits=n_folds, shuffle=True, random_state=42)
splits = list(skf.split(full_train_df, y_full))

oof_preds_full = np.zeros((len(full_train_df), 7), dtype=np.float32)
oof_overall_full = np.zeros(len(full_train_df), dtype=np.float32)
fold_scores = []

for fold, (train_idx, val_idx) in enumerate(splits):
    print(f'\n=== Fold {fold+1}/{n_folds} ===')
    train_fold_df = full_train_df.iloc[train_idx].reset_index(drop=True)
    val_fold_df = full_train_df.iloc[val_idx].reset_index(drop=True)
    print(f'Train: {len(train_fold_df)}, Val: {len(val_fold_df)}')

    # Fold pos_weight (clip=2.0)
    y_train_fold = train_fold_df[label_cols].values
    pos = (y_train_fold == 1).sum(axis=0)
    neg = (y_train_fold == 0).sum(axis=0)
    w = np.minimum(neg / (pos + 1e-6), 2.0).astype(np.float32)
    fold_pos_weight = torch.tensor(w, dtype=torch.float32).to(device)
    print('Fold pos_weight (clip=2):', w)

    # WeightedRandomSampler: per-sample weight from class rarity (multi-label safe, cap=3.0)
    class_rarity = np.minimum(neg / (pos + 1e-6), 10.0).astype(np.float32)
    sample_weight = (y_train_fold * class_rarity).max(axis=1)
    sample_weight = np.where(sample_weight > 0, sample_weight, 1.0)
    sample_weight = np.sqrt(sample_weight)              # soften extremes
    sample_weight = np.clip(sample_weight, 1.0, 3.0)    # cap reduced to 3.0
    sampler = WeightedRandomSampler(
        weights=torch.from_numpy(sample_weight),
        num_samples=len(sample_weight),
        replacement=True
    )

    # Datasets/loaders
    train_ds = MIPDataset(train_fold_df, mip_dir, label_cols, train_transform)
    val_ds = MIPDataset(val_fold_df, mip_dir, label_cols, val_transform)
    train_loader = DataLoader(train_ds, batch_size=16, sampler=sampler, num_workers=4, pin_memory=True, drop_last=False)
    val_loader = DataLoader(val_ds, batch_size=16, shuffle=False, num_workers=4, pin_memory=True, drop_last=False)

    # Model with dropout
    model = Model(num_classes=7).to(device)

    # Bias init from priors
    with torch.no_grad():
        priors = np.clip(y_train_fold.mean(axis=0) + 1e-6, 1e-6, 1-1e-6)
        bias = np.log(priors / (1 - priors)).astype(np.float32)
        head = model.model.classifier
        if hasattr(head, 'bias') and head.bias is not None:
            head.bias.copy_(torch.from_numpy(bias).to(head.bias.device))
    print('Bias initialized for 7 classes')

    optimizer = optim.AdamW(model.parameters(), lr=2e-5, weight_decay=1e-4)
    scheduler = ReduceLROnPlateau(optimizer, mode='min', factor=0.5, patience=2, verbose=True, min_lr=1e-6)
    criterion = SmoothedBCEWithLogitsLoss(pos_weight=fold_pos_weight, smoothing=0.0)
    scaler = GradScaler()
    early_stopping = EarlyStopping(patience=5, min_delta=0.001)

    num_epochs = 15
    for epoch in range(num_epochs):
        model.train()
        train_loss = 0.0
        for images, labels in tqdm(train_loader, desc=f'Epoch {epoch+1}/{num_epochs} Train'):
            images = images.to(device)
            labels7 = labels.to(device)
            optimizer.zero_grad(set_to_none=True)
            with torch.cuda.amp.autocast():
                logits = model(images)
                loss = criterion(logits, labels7)
            scaler.scale(loss).backward()
            nn_utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
            scaler.step(optimizer)
            scaler.update()
            train_loss += loss.item() * images.size(0)
        train_loss /= len(train_ds)

        # Validation
        model.eval()
        val_bce = 0.0
        val_weighted_loss = 0.0
        with torch.no_grad():
            for images, labels in tqdm(val_loader, desc='Val'):
                images = images.to(device)
                labels7 = labels.to(device)
                logits = model(images)
                bce_loss = criterion(logits, labels7)
                val_bce += bce_loss.item() * images.size(0)

                probs7 = torch.sigmoid(logits)
                probs_overall = probs7.max(dim=1, keepdim=True)[0]
                labels_overall = labels7.max(dim=1, keepdim=True)[0]
                probs8 = torch.cat([probs7, probs_overall], dim=1)
                labels8 = torch.cat([labels7, labels_overall], dim=1)
                weighted_l = weighted_log_loss_torch(labels8, probs8)
                val_weighted_loss += weighted_l.item() * images.size(0)
        val_bce /= len(val_ds)
        val_weighted_loss /= len(val_ds)

        scheduler.step(val_weighted_loss)
        current_lr = optimizer.param_groups[0]['lr']
        print(f'Epoch {epoch+1}: Train BCE={train_loss:.4f}, Val BCE={val_bce:.4f}, Val WLL={val_weighted_loss:.4f}, LR={current_lr:.2e}')

        if early_stopping(val_weighted_loss, model):
            print(f'Early stopping at epoch {epoch+1}')
            break

    # Final val pass with best weights -> collect OOF (fixed indexing)
    model.eval()
    fold_oof_preds_7 = np.zeros((len(val_fold_df), 7), dtype=np.float32)
    fold_oof_overall = np.zeros(len(val_fold_df), dtype=np.float32)
    start = 0
    with torch.no_grad():
        for images, _ in tqdm(val_loader, desc='Final Val for OOF'):
            images = images.to(device)
            logits = model(images)
            probs7 = torch.sigmoid(logits).cpu().numpy()
            probs_overall = probs7.max(axis=1)
            bs = images.size(0)
            fold_oof_preds_7[start:start+bs] = probs7
            fold_oof_overall[start:start+bs] = probs_overall
            start += bs

    # Fold OOF score
    fold_oof_labels = val_fold_df[label_cols].values
    fold_oof_preds_8 = np.column_stack([fold_oof_preds_7, fold_oof_overall])
    fold_oof_labels_8 = np.column_stack([fold_oof_labels, np.max(fold_oof_labels, axis=1)])
    fold_score = weighted_log_loss(fold_oof_labels_8, fold_oof_preds_8)
    print(f'Fold {fold+1} OOF Weighted Log Loss: {fold_score:.4f}')
    fold_scores.append(fold_score)

    # Save into global OOF
    oof_preds_full[val_idx] = fold_oof_preds_7
    oof_overall_full[val_idx] = fold_oof_overall

    torch.save(model.state_dict(), f'fold_{fold+1}_model.pth')

# Overall CV
cv_mean = float(np.mean(fold_scores))
cv_std = float(np.std(fold_scores))
print(f'\nCV Mean Weighted Log Loss: {cv_mean:.4f} +/- {cv_std:.4f}')

# Full OOF score
oof_preds_8_full = np.column_stack([oof_preds_full, oof_overall_full])
oof_labels_8_full = np.column_stack([y_full, np.max(y_full, axis=1)])
full_oof_loss = weighted_log_loss(oof_labels_8_full, oof_preds_8_full)
print(f'Full OOF Weighted Log Loss: {full_oof_loss:.4f}')

# Save OOF
oof_df = full_train_df[['StudyInstanceUID']].copy()
oof_df[label_cols] = oof_preds_full
oof_df['patient_overall'] = oof_overall_full
oof_df.to_csv('oof_predictions.csv', index=False)
print('Full 7-class CV completed with tweaks, OOF saved to oof_predictions.csv')

Full CV on 202 samples

=== Fold 1/5 ===
Train: 162, Val: 40
Fold pos_weight (clip=2): [2. 2. 2. 2. 2. 2. 2.]


Bias initialized for 7 classes


Epoch 1/15 Train:   0%|          | 0/11 [00:00<?, ?it/s]

Epoch 1/15 Train:   9%|▉         | 1/11 [00:00<00:02,  3.46it/s]

Epoch 1/15 Train:  27%|██▋       | 3/11 [00:00<00:00,  8.55it/s]

Epoch 1/15 Train:  45%|████▌     | 5/11 [00:00<00:00, 11.83it/s]

Epoch 1/15 Train:  64%|██████▎   | 7/11 [00:00<00:00, 13.98it/s]

Epoch 1/15 Train:  82%|████████▏ | 9/11 [00:00<00:00, 15.40it/s]

Epoch 1/15 Train: 100%|██████████| 11/11 [00:00<00:00, 12.97it/s]




Val:   0%|          | 0/3 [00:00<?, ?it/s]

Val:  33%|███▎      | 1/3 [00:00<00:00,  4.77it/s]

Val: 100%|██████████| 3/3 [00:00<00:00, 10.29it/s]




Epoch 1: Train BCE=0.9521, Val BCE=0.5961, Val WLL=0.4606, LR=2.00e-05


Epoch 2/15 Train:   0%|          | 0/11 [00:00<?, ?it/s]

Epoch 2/15 Train:   9%|▉         | 1/11 [00:00<00:03,  3.33it/s]

Epoch 2/15 Train:  27%|██▋       | 3/11 [00:00<00:00,  8.40it/s]

Epoch 2/15 Train:  45%|████▌     | 5/11 [00:00<00:00, 11.68it/s]

Epoch 2/15 Train:  64%|██████▎   | 7/11 [00:00<00:00, 13.86it/s]

Epoch 2/15 Train:  82%|████████▏ | 9/11 [00:00<00:00, 15.31it/s]

Epoch 2/15 Train: 100%|██████████| 11/11 [00:00<00:00, 12.77it/s]




Val:   0%|          | 0/3 [00:00<?, ?it/s]

Val:  33%|███▎      | 1/3 [00:00<00:00,  5.06it/s]

Val: 100%|██████████| 3/3 [00:00<00:00, 10.78it/s]




Epoch 2: Train BCE=0.8939, Val BCE=0.5939, Val WLL=0.4423, LR=2.00e-05


Epoch 3/15 Train:   0%|          | 0/11 [00:00<?, ?it/s]

Epoch 3/15 Train:   9%|▉         | 1/11 [00:00<00:03,  3.32it/s]

Epoch 3/15 Train:  27%|██▋       | 3/11 [00:00<00:00,  8.41it/s]

Epoch 3/15 Train:  45%|████▌     | 5/11 [00:00<00:00, 11.71it/s]

Epoch 3/15 Train:  64%|██████▎   | 7/11 [00:00<00:00, 13.87it/s]

Epoch 3/15 Train:  82%|████████▏ | 9/11 [00:00<00:00, 15.32it/s]

Epoch 3/15 Train: 100%|██████████| 11/11 [00:00<00:00, 12.80it/s]




Val:   0%|          | 0/3 [00:00<?, ?it/s]

Val:  33%|███▎      | 1/3 [00:00<00:00,  4.74it/s]

Val: 100%|██████████| 3/3 [00:00<00:00, 10.31it/s]




Epoch 3: Train BCE=0.9399, Val BCE=0.5784, Val WLL=0.4463, LR=2.00e-05


Epoch 4/15 Train:   0%|          | 0/11 [00:00<?, ?it/s]

Epoch 4/15 Train:   9%|▉         | 1/11 [00:00<00:02,  3.37it/s]

Epoch 4/15 Train:  27%|██▋       | 3/11 [00:00<00:00,  8.53it/s]

Epoch 4/15 Train:  45%|████▌     | 5/11 [00:00<00:00, 11.80it/s]

Epoch 4/15 Train:  64%|██████▎   | 7/11 [00:00<00:00, 13.94it/s]

Epoch 4/15 Train:  82%|████████▏ | 9/11 [00:00<00:00, 15.39it/s]

Epoch 4/15 Train: 100%|██████████| 11/11 [00:00<00:00, 12.86it/s]




Val:   0%|          | 0/3 [00:00<?, ?it/s]

Val:  33%|███▎      | 1/3 [00:00<00:00,  4.75it/s]

Val: 100%|██████████| 3/3 [00:00<00:00, 10.25it/s]




Epoch 4: Train BCE=0.8447, Val BCE=0.5730, Val WLL=0.4541, LR=2.00e-05


Epoch 5/15 Train:   0%|          | 0/11 [00:00<?, ?it/s]

Epoch 5/15 Train:   9%|▉         | 1/11 [00:00<00:02,  3.41it/s]

Val:  33%|███▎      | 1/3 [00:00<00:00,  4.44it/s]

Val: 100%|██████████| 3/3 [00:00<00:00,  9.79it/s]




Epoch 5: Train BCE=0.8531, Val BCE=0.5979, Val WLL=0.4461, LR=1.00e-05


Epoch 6/15 Train:   0%|          | 0/11 [00:00<?, ?it/s]

Epoch 6/15 Train:   9%|▉         | 1/11 [00:00<00:02,  3.48it/s]

Epoch 6/15 Train:  27%|██▋       | 3/11 [00:00<00:00,  8.66it/s]

Epoch 6/15 Train:  45%|████▌     | 5/11 [00:00<00:00, 11.93it/s]

Epoch 6/15 Train:  64%|██████▎   | 7/11 [00:00<00:00, 14.07it/s]

Epoch 6/15 Train:  82%|████████▏ | 9/11 [00:00<00:00, 15.50it/s]

Epoch 6/15 Train: 100%|██████████| 11/11 [00:00<00:00, 13.07it/s]




Val:   0%|          | 0/3 [00:00<?, ?it/s]

Val:  33%|███▎      | 1/3 [00:00<00:00,  4.73it/s]

Val: 100%|██████████| 3/3 [00:00<00:00, 10.38it/s]




Epoch 6: Train BCE=0.7937, Val BCE=0.6803, Val WLL=0.4768, LR=1.00e-05


Epoch 7/15 Train:   0%|          | 0/11 [00:00<?, ?it/s]

Epoch 7/15 Train:   9%|▉         | 1/11 [00:00<00:03,  3.26it/s]

Epoch 7/15 Train:  27%|██▋       | 3/11 [00:00<00:00,  8.27it/s]

Epoch 7/15 Train:  45%|████▌     | 5/11 [00:00<00:00, 11.56it/s]

Epoch 7/15 Train:  64%|██████▎   | 7/11 [00:00<00:00, 13.74it/s]

Epoch 7/15 Train:  82%|████████▏ | 9/11 [00:00<00:00, 15.22it/s]

Epoch 7/15 Train: 100%|██████████| 11/11 [00:00<00:00, 12.66it/s]




Val:   0%|          | 0/3 [00:00<?, ?it/s]

Val:  33%|███▎      | 1/3 [00:00<00:00,  4.61it/s]

Val: 100%|██████████| 3/3 [00:00<00:00, 10.28it/s]




Epoch 7: Train BCE=0.8138, Val BCE=0.7125, Val WLL=0.4953, LR=1.00e-05
Early stopping at epoch 7


Final Val for OOF:   0%|          | 0/3 [00:00<?, ?it/s]

Final Val for OOF:  33%|███▎      | 1/3 [00:00<00:00,  4.48it/s]

Final Val for OOF: 100%|██████████| 3/3 [00:00<00:00, 10.02it/s]




Fold 1 OOF Weighted Log Loss: 0.4953

=== Fold 2/5 ===
Train: 162, Val: 40
Fold pos_weight (clip=2): [2. 2. 2. 2. 2. 2. 2.]
Bias initialized for 7 classes


Epoch 1/15 Train:   0%|          | 0/11 [00:00<?, ?it/s]

Epoch 1/15 Train:   9%|▉         | 1/11 [00:00<00:02,  3.35it/s]

Epoch 1/15 Train:  27%|██▋       | 3/11 [00:00<00:00,  8.22it/s]

Epoch 1/15 Train:  45%|████▌     | 5/11 [00:00<00:00, 11.52it/s]

Epoch 1/15 Train:  64%|██████▎   | 7/11 [00:00<00:00, 13.72it/s]

Epoch 1/15 Train:  82%|████████▏ | 9/11 [00:00<00:00, 15.20it/s]

Epoch 1/15 Train: 100%|██████████| 11/11 [00:00<00:00, 12.64it/s]




Val:   0%|          | 0/3 [00:00<?, ?it/s]

Val:  33%|███▎      | 1/3 [00:00<00:00,  4.75it/s]

Val: 100%|██████████| 3/3 [00:00<00:00, 10.40it/s]




Epoch 1: Train BCE=1.0084, Val BCE=0.7441, Val WLL=0.5396, LR=2.00e-05


Epoch 2/15 Train:   0%|          | 0/11 [00:00<?, ?it/s]

Epoch 2/15 Train:   9%|▉         | 1/11 [00:00<00:03,  3.13it/s]

Epoch 2/15 Train:  27%|██▋       | 3/11 [00:00<00:00,  8.11it/s]

Epoch 2/15 Train:  45%|████▌     | 5/11 [00:00<00:00, 11.40it/s]

Epoch 2/15 Train:  64%|██████▎   | 7/11 [00:00<00:00, 13.61it/s]

Epoch 2/15 Train:  82%|████████▏ | 9/11 [00:00<00:00, 15.06it/s]

Epoch 2/15 Train: 100%|██████████| 11/11 [00:00<00:00, 12.44it/s]




Val:   0%|          | 0/3 [00:00<?, ?it/s]

Val:  33%|███▎      | 1/3 [00:00<00:00,  4.65it/s]

Val: 100%|██████████| 3/3 [00:00<00:00, 10.23it/s]




Epoch 2: Train BCE=0.8670, Val BCE=0.7384, Val WLL=0.5237, LR=2.00e-05


Epoch 3/15 Train:   0%|          | 0/11 [00:00<?, ?it/s]

Epoch 3/15 Train:   9%|▉         | 1/11 [00:00<00:03,  3.23it/s]

Epoch 3/15 Train:  27%|██▋       | 3/11 [00:00<00:00,  8.24it/s]

Epoch 3/15 Train:  45%|████▌     | 5/11 [00:00<00:00, 11.52it/s]

Epoch 3/15 Train:  64%|██████▎   | 7/11 [00:00<00:00, 13.71it/s]

Epoch 3/15 Train:  82%|████████▏ | 9/11 [00:00<00:00, 15.20it/s]

Epoch 3/15 Train: 100%|██████████| 11/11 [00:00<00:00, 12.65it/s]




Val:   0%|          | 0/3 [00:00<?, ?it/s]

Val:  33%|███▎      | 1/3 [00:00<00:00,  4.66it/s]

Val: 100%|██████████| 3/3 [00:00<00:00, 10.33it/s]




Epoch 3: Train BCE=0.8928, Val BCE=0.7617, Val WLL=0.5541, LR=2.00e-05


Epoch 4/15 Train:   0%|          | 0/11 [00:00<?, ?it/s]

Epoch 4/15 Train:   9%|▉         | 1/11 [00:00<00:03,  3.23it/s]

Epoch 4/15 Train:  27%|██▋       | 3/11 [00:00<00:00,  8.24it/s]

Epoch 4/15 Train:  45%|████▌     | 5/11 [00:00<00:00, 11.53it/s]

Epoch 4/15 Train:  64%|██████▎   | 7/11 [00:00<00:00, 13.71it/s]

Epoch 4/15 Train:  82%|████████▏ | 9/11 [00:00<00:00, 15.20it/s]

Epoch 4/15 Train: 100%|██████████| 11/11 [00:00<00:00, 12.59it/s]




Val:   0%|          | 0/3 [00:00<?, ?it/s]

Val:  33%|███▎      | 1/3 [00:00<00:00,  4.69it/s]

Val: 100%|██████████| 3/3 [00:00<00:00, 10.10it/s]




Epoch 4: Train BCE=0.9240, Val BCE=0.8199, Val WLL=0.6144, LR=2.00e-05


Epoch 5/15 Train:   0%|          | 0/11 [00:00<?, ?it/s]

Epoch 5/15 Train:   9%|▉         | 1/11 [00:00<00:03,  3.31it/s]

Epoch 5/15 Train:  27%|██▋       | 3/11 [00:00<00:00,  8.44it/s]

Epoch 5/15 Train:  45%|████▌     | 5/11 [00:00<00:00, 11.73it/s]

Epoch 5/15 Train:  64%|██████▎   | 7/11 [00:00<00:00, 13.91it/s]

Epoch 5/15 Train:  82%|████████▏ | 9/11 [00:00<00:00, 15.35it/s]

Epoch 5/15 Train: 100%|██████████| 11/11 [00:00<00:00, 12.78it/s]




Val:   0%|          | 0/3 [00:00<?, ?it/s]

Val:  33%|███▎      | 1/3 [00:00<00:00,  4.48it/s]

Val: 100%|██████████| 3/3 [00:00<00:00, 10.02it/s]




Epoch 5: Train BCE=0.8947, Val BCE=0.7751, Val WLL=0.5643, LR=1.00e-05


Epoch 6/15 Train:   0%|          | 0/11 [00:00<?, ?it/s]

Epoch 6/15 Train:   9%|▉         | 1/11 [00:00<00:03,  3.13it/s]

Epoch 6/15 Train:  27%|██▋       | 3/11 [00:00<00:00,  8.07it/s]

Epoch 6/15 Train:  45%|████▌     | 5/11 [00:00<00:00, 11.37it/s]

Epoch 6/15 Train:  64%|██████▎   | 7/11 [00:00<00:00, 13.61it/s]

Epoch 6/15 Train:  82%|████████▏ | 9/11 [00:00<00:00, 15.13it/s]

Epoch 6/15 Train: 100%|██████████| 11/11 [00:00<00:00, 12.50it/s]




Val:   0%|          | 0/3 [00:00<?, ?it/s]

Val:  33%|███▎      | 1/3 [00:00<00:00,  4.66it/s]

Val: 100%|██████████| 3/3 [00:00<00:00, 10.34it/s]




Epoch 6: Train BCE=0.8611, Val BCE=0.7696, Val WLL=0.5425, LR=1.00e-05


Epoch 7/15 Train:   0%|          | 0/11 [00:00<?, ?it/s]

Epoch 7/15 Train:   9%|▉         | 1/11 [00:00<00:03,  3.26it/s]

Epoch 7/15 Train:  27%|██▋       | 3/11 [00:00<00:00,  8.30it/s]

Epoch 7/15 Train:  45%|████▌     | 5/11 [00:00<00:00, 11.61it/s]

Epoch 7/15 Train:  64%|██████▎   | 7/11 [00:00<00:00, 13.79it/s]

Epoch 7/15 Train:  82%|████████▏ | 9/11 [00:00<00:00, 15.26it/s]

Epoch 7/15 Train: 100%|██████████| 11/11 [00:00<00:00, 12.69it/s]




Val:   0%|          | 0/3 [00:00<?, ?it/s]

Val:  33%|███▎      | 1/3 [00:00<00:00,  4.47it/s]

Val: 100%|██████████| 3/3 [00:00<00:00, 10.06it/s]




Epoch 7: Train BCE=0.7570, Val BCE=0.7886, Val WLL=0.5586, LR=1.00e-05
Early stopping at epoch 7


Final Val for OOF:   0%|          | 0/3 [00:00<?, ?it/s]

Final Val for OOF:  33%|███▎      | 1/3 [00:00<00:00,  4.77it/s]

Final Val for OOF: 100%|██████████| 3/3 [00:00<00:00, 10.31it/s]




Fold 2 OOF Weighted Log Loss: 0.5586

=== Fold 3/5 ===
Train: 161, Val: 41
Fold pos_weight (clip=2): [2. 2. 2. 2. 2. 2. 2.]


Bias initialized for 7 classes


Epoch 1/15 Train:   0%|          | 0/11 [00:00<?, ?it/s]

Epoch 1/15 Train:   9%|▉         | 1/11 [00:00<00:02,  3.36it/s]

Epoch 1/15 Train:  27%|██▋       | 3/11 [00:00<00:00,  8.26it/s]

Epoch 1/15 Train:  45%|████▌     | 5/11 [00:00<00:00, 11.60it/s]

Epoch 1/15 Train:  64%|██████▎   | 7/11 [00:00<00:00, 13.76it/s]

Epoch 1/15 Train:  82%|████████▏ | 9/11 [00:00<00:00, 15.24it/s]

Epoch 1/15 Train: 100%|██████████| 11/11 [00:00<00:00, 12.63it/s]




Val:   0%|          | 0/3 [00:00<?, ?it/s]

Val:  33%|███▎      | 1/3 [00:00<00:00,  4.60it/s]

Val: 100%|██████████| 3/3 [00:00<00:00, 10.23it/s]




Epoch 1: Train BCE=0.9149, Val BCE=0.6756, Val WLL=0.5185, LR=2.00e-05


Epoch 2/15 Train:   0%|          | 0/11 [00:00<?, ?it/s]

Epoch 2/15 Train:   9%|▉         | 1/11 [00:00<00:03,  3.25it/s]

Epoch 2/15 Train:  27%|██▋       | 3/11 [00:00<00:00,  8.28it/s]

Epoch 2/15 Train:  45%|████▌     | 5/11 [00:00<00:00, 11.56it/s]

Epoch 2/15 Train:  64%|██████▎   | 7/11 [00:00<00:00, 13.75it/s]

Epoch 2/15 Train:  82%|████████▏ | 9/11 [00:00<00:00, 15.22it/s]

Epoch 2/15 Train: 100%|██████████| 11/11 [00:00<00:00, 12.59it/s]




Val:   0%|          | 0/3 [00:00<?, ?it/s]

Val:  33%|███▎      | 1/3 [00:00<00:00,  4.54it/s]

Val: 100%|██████████| 3/3 [00:00<00:00, 10.05it/s]




Epoch 2: Train BCE=0.8770, Val BCE=0.7068, Val WLL=0.5273, LR=2.00e-05


Epoch 3/15 Train:   0%|          | 0/11 [00:00<?, ?it/s]

Epoch 3/15 Train:   9%|▉         | 1/11 [00:00<00:02,  3.37it/s]

Epoch 3/15 Train:  27%|██▋       | 3/11 [00:00<00:00,  8.46it/s]

Epoch 3/15 Train:  45%|████▌     | 5/11 [00:00<00:00, 11.72it/s]

Epoch 3/15 Train:  64%|██████▎   | 7/11 [00:00<00:00, 13.87it/s]

Epoch 3/15 Train:  82%|████████▏ | 9/11 [00:00<00:00, 15.32it/s]

Epoch 3/15 Train: 100%|██████████| 11/11 [00:00<00:00, 12.71it/s]




Val:   0%|          | 0/3 [00:00<?, ?it/s]

Val:  33%|███▎      | 1/3 [00:00<00:00,  4.58it/s]

Val: 100%|██████████| 3/3 [00:00<00:00, 10.16it/s]




Epoch 3: Train BCE=0.9271, Val BCE=0.9454, Val WLL=0.8518, LR=2.00e-05


Epoch 4/15 Train:   0%|          | 0/11 [00:00<?, ?it/s]

Epoch 4/15 Train:   9%|▉         | 1/11 [00:00<00:03,  3.24it/s]

Epoch 4/15 Train:  27%|██▋       | 3/11 [00:00<00:00,  8.26it/s]

Epoch 4/15 Train:  45%|████▌     | 5/11 [00:00<00:00, 11.55it/s]

Epoch 4/15 Train:  64%|██████▎   | 7/11 [00:00<00:00, 13.72it/s]

Epoch 4/15 Train:  82%|████████▏ | 9/11 [00:00<00:00, 15.20it/s]

Epoch 4/15 Train: 100%|██████████| 11/11 [00:00<00:00, 12.65it/s]




Val:   0%|          | 0/3 [00:00<?, ?it/s]

Val:  33%|███▎      | 1/3 [00:00<00:00,  4.52it/s]

Val: 100%|██████████| 3/3 [00:00<00:00,  9.79it/s]




Epoch 4: Train BCE=0.9375, Val BCE=0.7483, Val WLL=0.5862, LR=1.00e-05


Epoch 5/15 Train:   0%|          | 0/11 [00:00<?, ?it/s]

Epoch 5/15 Train:   9%|▉         | 1/11 [00:00<00:02,  3.37it/s]

Epoch 5/15 Train:  27%|██▋       | 3/11 [00:00<00:00,  8.49it/s]

Epoch 5/15 Train:  45%|████▌     | 5/11 [00:00<00:00, 11.74it/s]

Epoch 5/15 Train:  64%|██████▎   | 7/11 [00:00<00:00, 13.91it/s]

Epoch 5/15 Train:  82%|████████▏ | 9/11 [00:00<00:00, 15.34it/s]

Epoch 5/15 Train: 100%|██████████| 11/11 [00:00<00:00, 12.76it/s]




Val:   0%|          | 0/3 [00:00<?, ?it/s]

Val:  33%|███▎      | 1/3 [00:00<00:00,  4.62it/s]

Val: 100%|██████████| 3/3 [00:00<00:00, 10.24it/s]




Epoch 5: Train BCE=0.8878, Val BCE=0.8117, Val WLL=0.6449, LR=1.00e-05


Epoch 6/15 Train:   0%|          | 0/11 [00:00<?, ?it/s]

Epoch 6/15 Train:   9%|▉         | 1/11 [00:00<00:03,  3.29it/s]

Epoch 6/15 Train:  27%|██▋       | 3/11 [00:00<00:00,  8.37it/s]

Epoch 6/15 Train:  45%|████▌     | 5/11 [00:00<00:00, 11.65it/s]

Epoch 6/15 Train:  64%|██████▎   | 7/11 [00:00<00:00, 13.82it/s]

Epoch 6/15 Train:  82%|████████▏ | 9/11 [00:00<00:00, 15.27it/s]

Epoch 6/15 Train: 100%|██████████| 11/11 [00:00<00:00, 12.68it/s]




Val:   0%|          | 0/3 [00:00<?, ?it/s]

Val:  33%|███▎      | 1/3 [00:00<00:00,  4.54it/s]

Val: 100%|██████████| 3/3 [00:00<00:00, 10.12it/s]




Epoch 6: Train BCE=0.7614, Val BCE=0.8740, Val WLL=0.7541, LR=1.00e-05
Early stopping at epoch 6


Final Val for OOF:   0%|          | 0/3 [00:00<?, ?it/s]

Final Val for OOF:  33%|███▎      | 1/3 [00:00<00:00,  4.63it/s]

Final Val for OOF: 100%|██████████| 3/3 [00:00<00:00,  9.97it/s]




Fold 3 OOF Weighted Log Loss: 0.7541

=== Fold 4/5 ===
Train: 161, Val: 41
Fold pos_weight (clip=2): [2. 2. 2. 2. 2. 2. 2.]
Bias initialized for 7 classes


Epoch 1/15 Train:   0%|          | 0/11 [00:00<?, ?it/s]

Epoch 1/15 Train:   9%|▉         | 1/11 [00:00<00:02,  3.38it/s]

Epoch 1/15 Train:  27%|██▋       | 3/11 [00:00<00:00,  8.45it/s]

Epoch 1/15 Train:  45%|████▌     | 5/11 [00:00<00:00, 11.70it/s]

Epoch 1/15 Train:  64%|██████▎   | 7/11 [00:00<00:00, 13.85it/s]

Epoch 1/15 Train:  82%|████████▏ | 9/11 [00:00<00:00, 15.30it/s]

Epoch 1/15 Train: 100%|██████████| 11/11 [00:00<00:00, 12.72it/s]




Val:   0%|          | 0/3 [00:00<?, ?it/s]

Val:  33%|███▎      | 1/3 [00:00<00:00,  4.44it/s]

Val: 100%|██████████| 3/3 [00:00<00:00,  9.85it/s]




Epoch 1: Train BCE=0.9652, Val BCE=0.8029, Val WLL=0.5358, LR=2.00e-05


Epoch 2/15 Train:   0%|          | 0/11 [00:00<?, ?it/s]

Epoch 2/15 Train:   9%|▉         | 1/11 [00:00<00:03,  3.24it/s]

Epoch 2/15 Train:  27%|██▋       | 3/11 [00:00<00:00,  8.27it/s]

Epoch 2/15 Train:  45%|████▌     | 5/11 [00:00<00:00, 11.53it/s]

Epoch 2/15 Train:  64%|██████▎   | 7/11 [00:00<00:00, 13.70it/s]

Epoch 2/15 Train:  82%|████████▏ | 9/11 [00:00<00:00, 15.18it/s]

Epoch 2/15 Train: 100%|██████████| 11/11 [00:00<00:00, 12.52it/s]




Val:   0%|          | 0/3 [00:00<?, ?it/s]

Val:  33%|███▎      | 1/3 [00:00<00:00,  4.61it/s]

Val: 100%|██████████| 3/3 [00:00<00:00, 10.17it/s]




Epoch 2: Train BCE=1.1448, Val BCE=0.7739, Val WLL=0.5007, LR=2.00e-05


Epoch 3/15 Train:   0%|          | 0/11 [00:00<?, ?it/s]

Epoch 3/15 Train:   9%|▉         | 1/11 [00:00<00:03,  3.18it/s]

Epoch 3/15 Train:  27%|██▋       | 3/11 [00:00<00:00,  8.14it/s]

Epoch 3/15 Train:  45%|████▌     | 5/11 [00:00<00:00, 11.41it/s]

Epoch 3/15 Train:  64%|██████▎   | 7/11 [00:00<00:00, 13.59it/s]

Epoch 3/15 Train:  82%|████████▏ | 9/11 [00:00<00:00, 15.10it/s]

Epoch 3/15 Train: 100%|██████████| 11/11 [00:00<00:00, 12.40it/s]




Val:   0%|          | 0/3 [00:00<?, ?it/s]

Val:  33%|███▎      | 1/3 [00:00<00:00,  4.77it/s]

Val: 100%|██████████| 3/3 [00:00<00:00, 10.53it/s]




Epoch 3: Train BCE=1.0423, Val BCE=0.8636, Val WLL=0.6080, LR=2.00e-05


Epoch 4/15 Train:   0%|          | 0/11 [00:00<?, ?it/s]

Epoch 4/15 Train:   9%|▉         | 1/11 [00:00<00:03,  3.20it/s]

Epoch 4/15 Train:  27%|██▋       | 3/11 [00:00<00:00,  8.17it/s]

Epoch 4/15 Train:  45%|████▌     | 5/11 [00:00<00:00, 11.44it/s]

Epoch 4/15 Train:  64%|██████▎   | 7/11 [00:00<00:00, 13.62it/s]

Epoch 4/15 Train:  82%|████████▏ | 9/11 [00:00<00:00, 15.10it/s]

Epoch 4/15 Train: 100%|██████████| 11/11 [00:00<00:00, 12.50it/s]




Val:   0%|          | 0/3 [00:00<?, ?it/s]

Val:  33%|███▎      | 1/3 [00:00<00:00,  4.48it/s]

Val: 100%|██████████| 3/3 [00:00<00:00,  9.94it/s]




Epoch 4: Train BCE=0.9690, Val BCE=0.7962, Val WLL=0.5385, LR=2.00e-05


Epoch 5/15 Train:   0%|          | 0/11 [00:00<?, ?it/s]

Epoch 5/15 Train:   9%|▉         | 1/11 [00:00<00:03,  3.13it/s]

Epoch 5/15 Train:  27%|██▋       | 3/11 [00:00<00:00,  8.08it/s]

Epoch 5/15 Train:  45%|████▌     | 5/11 [00:00<00:00, 11.35it/s]

Epoch 5/15 Train:  64%|██████▎   | 7/11 [00:00<00:00, 13.54it/s]

Epoch 5/15 Train:  82%|████████▏ | 9/11 [00:00<00:00, 15.02it/s]

Epoch 5/15 Train: 100%|██████████| 11/11 [00:00<00:00, 12.35it/s]




Val:   0%|          | 0/3 [00:00<?, ?it/s]

Val:  33%|███▎      | 1/3 [00:00<00:00,  4.72it/s]

Val: 100%|██████████| 3/3 [00:00<00:00, 10.36it/s]




Epoch 5: Train BCE=0.9636, Val BCE=0.7287, Val WLL=0.4896, LR=2.00e-05


Epoch 6/15 Train:   0%|          | 0/11 [00:00<?, ?it/s]

Epoch 6/15 Train:   9%|▉         | 1/11 [00:00<00:03,  3.29it/s]

Epoch 6/15 Train:  27%|██▋       | 3/11 [00:00<00:00,  8.33it/s]

Epoch 6/15 Train:  45%|████▌     | 5/11 [00:00<00:00, 11.51it/s]

Epoch 6/15 Train:  64%|██████▎   | 7/11 [00:00<00:00, 13.69it/s]

Epoch 6/15 Train:  82%|████████▏ | 9/11 [00:00<00:00, 15.15it/s]

Epoch 6/15 Train: 100%|██████████| 11/11 [00:00<00:00, 12.61it/s]




Val:   0%|          | 0/3 [00:00<?, ?it/s]

Val:  33%|███▎      | 1/3 [00:00<00:00,  4.52it/s]

Val: 100%|██████████| 3/3 [00:00<00:00,  9.80it/s]




Epoch 6: Train BCE=0.9222, Val BCE=0.7590, Val WLL=0.5292, LR=2.00e-05


Epoch 7/15 Train:   0%|          | 0/11 [00:00<?, ?it/s]

Epoch 7/15 Train:   9%|▉         | 1/11 [00:00<00:02,  3.36it/s]

Epoch 7/15 Train:  27%|██▋       | 3/11 [00:00<00:00,  8.42it/s]

Epoch 7/15 Train:  45%|████▌     | 5/11 [00:00<00:00, 11.67it/s]

Epoch 7/15 Train:  64%|██████▎   | 7/11 [00:00<00:00, 13.81it/s]

Epoch 7/15 Train:  82%|████████▏ | 9/11 [00:00<00:00, 15.24it/s]

Epoch 7/15 Train: 100%|██████████| 11/11 [00:00<00:00, 12.73it/s]




Val:   0%|          | 0/3 [00:00<?, ?it/s]

Val:  33%|███▎      | 1/3 [00:00<00:00,  4.81it/s]

Val: 100%|██████████| 3/3 [00:00<00:00, 10.23it/s]




Epoch 7: Train BCE=0.8456, Val BCE=0.7577, Val WLL=0.5407, LR=2.00e-05


Epoch 8/15 Train:   0%|          | 0/11 [00:00<?, ?it/s]

Epoch 8/15 Train:   9%|▉         | 1/11 [00:00<00:03,  3.33it/s]

Epoch 8/15 Train:  27%|██▋       | 3/11 [00:00<00:00,  8.43it/s]

Epoch 8/15 Train:  45%|████▌     | 5/11 [00:00<00:00, 11.70it/s]

Epoch 8/15 Train:  64%|██████▎   | 7/11 [00:00<00:00, 13.85it/s]

Epoch 8/15 Train:  82%|████████▏ | 9/11 [00:00<00:00, 15.29it/s]

Epoch 8/15 Train: 100%|██████████| 11/11 [00:00<00:00, 12.65it/s]




Val:   0%|          | 0/3 [00:00<?, ?it/s]

Val:  33%|███▎      | 1/3 [00:00<00:00,  4.73it/s]

Val: 100%|██████████| 3/3 [00:00<00:00, 10.54it/s]




Epoch 8: Train BCE=0.7383, Val BCE=0.7258, Val WLL=0.5402, LR=1.00e-05


Epoch 9/15 Train:   0%|          | 0/11 [00:00<?, ?it/s]

Epoch 9/15 Train:   9%|▉         | 1/11 [00:00<00:02,  3.44it/s]

Epoch 9/15 Train:  27%|██▋       | 3/11 [00:00<00:00,  8.50it/s]

Epoch 9/15 Train:  45%|████▌     | 5/11 [00:00<00:00, 11.76it/s]

Epoch 9/15 Train:  64%|██████▎   | 7/11 [00:00<00:00, 13.90it/s]

Epoch 9/15 Train:  82%|████████▏ | 9/11 [00:00<00:00, 15.34it/s]

Epoch 9/15 Train: 100%|██████████| 11/11 [00:00<00:00, 12.77it/s]




Val:   0%|          | 0/3 [00:00<?, ?it/s]

Val:  33%|███▎      | 1/3 [00:00<00:00,  4.75it/s]

Val: 100%|██████████| 3/3 [00:00<00:00, 10.37it/s]




Epoch 9: Train BCE=0.7579, Val BCE=0.7130, Val WLL=0.5287, LR=1.00e-05


Epoch 10/15 Train:   0%|          | 0/11 [00:00<?, ?it/s]

Epoch 10/15 Train:   9%|▉         | 1/11 [00:00<00:03,  2.60it/s]

Epoch 10/15 Train:  27%|██▋       | 3/11 [00:00<00:01,  7.13it/s]

Epoch 10/15 Train:  45%|████▌     | 5/11 [00:00<00:00, 10.41it/s]

Epoch 10/15 Train:  64%|██████▎   | 7/11 [00:00<00:00, 12.77it/s]

Epoch 10/15 Train:  82%|████████▏ | 9/11 [00:00<00:00, 14.47it/s]

Epoch 10/15 Train: 100%|██████████| 11/11 [00:00<00:00, 11.43it/s]




Val:   0%|          | 0/3 [00:00<?, ?it/s]

Val:  33%|███▎      | 1/3 [00:00<00:00,  4.56it/s]

Val: 100%|██████████| 3/3 [00:00<00:00, 10.15it/s]




Epoch 10: Train BCE=0.7978, Val BCE=0.7196, Val WLL=0.5433, LR=1.00e-05
Early stopping at epoch 10


Final Val for OOF:   0%|          | 0/3 [00:00<?, ?it/s]

Final Val for OOF:  33%|███▎      | 1/3 [00:00<00:00,  4.56it/s]

Final Val for OOF: 100%|██████████| 3/3 [00:00<00:00, 10.10it/s]




Fold 4 OOF Weighted Log Loss: 0.5433

=== Fold 5/5 ===
Train: 162, Val: 40
Fold pos_weight (clip=2): [2. 2. 2. 2. 2. 2. 2.]
Bias initialized for 7 classes


Epoch 1/15 Train:   0%|          | 0/11 [00:00<?, ?it/s]

Epoch 1/15 Train:   9%|▉         | 1/11 [00:00<00:03,  3.33it/s]

Epoch 1/15 Train:  27%|██▋       | 3/11 [00:00<00:00,  8.33it/s]

Epoch 1/15 Train:  45%|████▌     | 5/11 [00:00<00:00, 11.56it/s]

Epoch 1/15 Train:  64%|██████▎   | 7/11 [00:00<00:00, 13.75it/s]

Epoch 1/15 Train:  82%|████████▏ | 9/11 [00:00<00:00, 15.20it/s]

Epoch 1/15 Train: 100%|██████████| 11/11 [00:00<00:00, 12.69it/s]




Val:   0%|          | 0/3 [00:00<?, ?it/s]

Val:  33%|███▎      | 1/3 [00:00<00:00,  4.55it/s]

Val: 100%|██████████| 3/3 [00:00<00:00,  9.89it/s]




Epoch 1: Train BCE=0.9934, Val BCE=0.7128, Val WLL=0.6364, LR=2.00e-05


Epoch 2/15 Train:   0%|          | 0/11 [00:00<?, ?it/s]

Epoch 2/15 Train:   9%|▉         | 1/11 [00:00<00:03,  3.27it/s]

Epoch 2/15 Train:  27%|██▋       | 3/11 [00:00<00:00,  8.34it/s]

Epoch 2/15 Train:  45%|████▌     | 5/11 [00:00<00:00, 11.61it/s]

Epoch 2/15 Train:  64%|██████▎   | 7/11 [00:00<00:00, 13.78it/s]

Epoch 2/15 Train:  82%|████████▏ | 9/11 [00:00<00:00, 15.25it/s]

Epoch 2/15 Train: 100%|██████████| 11/11 [00:00<00:00, 12.68it/s]




Val:   0%|          | 0/3 [00:00<?, ?it/s]

Val:  33%|███▎      | 1/3 [00:00<00:00,  4.74it/s]

Val: 100%|██████████| 3/3 [00:00<00:00, 10.46it/s]




Epoch 2: Train BCE=0.9787, Val BCE=0.7305, Val WLL=0.6773, LR=2.00e-05


Epoch 3/15 Train:   0%|          | 0/11 [00:00<?, ?it/s]

Epoch 3/15 Train:   9%|▉         | 1/11 [00:00<00:03,  3.24it/s]

Epoch 3/15 Train:  27%|██▋       | 3/11 [00:00<00:00,  8.25it/s]

Epoch 3/15 Train:  45%|████▌     | 5/11 [00:00<00:00, 11.50it/s]

Epoch 3/15 Train:  64%|██████▎   | 7/11 [00:00<00:00, 13.66it/s]

Epoch 3/15 Train:  82%|████████▏ | 9/11 [00:00<00:00, 15.12it/s]

Epoch 3/15 Train: 100%|██████████| 11/11 [00:00<00:00, 12.52it/s]




Val:   0%|          | 0/3 [00:00<?, ?it/s]

Val:  33%|███▎      | 1/3 [00:00<00:00,  4.71it/s]

Val: 100%|██████████| 3/3 [00:00<00:00, 10.46it/s]




Epoch 3: Train BCE=0.8728, Val BCE=0.7273, Val WLL=0.6646, LR=2.00e-05


Epoch 4/15 Train:   0%|          | 0/11 [00:00<?, ?it/s]

Epoch 4/15 Train:   9%|▉         | 1/11 [00:00<00:03,  3.29it/s]

Epoch 4/15 Train:  27%|██▋       | 3/11 [00:00<00:00,  8.36it/s]

Epoch 4/15 Train:  45%|████▌     | 5/11 [00:00<00:00, 11.63it/s]

Epoch 4/15 Train:  64%|██████▎   | 7/11 [00:00<00:00, 13.78it/s]

Epoch 4/15 Train:  82%|████████▏ | 9/11 [00:00<00:00, 15.24it/s]

Epoch 4/15 Train: 100%|██████████| 11/11 [00:00<00:00, 12.69it/s]




Val:   0%|          | 0/3 [00:00<?, ?it/s]

Val:  33%|███▎      | 1/3 [00:00<00:00,  4.92it/s]

Val: 100%|██████████| 3/3 [00:00<00:00, 10.43it/s]




Epoch 4: Train BCE=0.8685, Val BCE=0.6943, Val WLL=0.5737, LR=2.00e-05


Epoch 5/15 Train:   0%|          | 0/11 [00:00<?, ?it/s]

Epoch 5/15 Train:   9%|▉         | 1/11 [00:00<00:03,  3.28it/s]

Epoch 5/15 Train:  27%|██▋       | 3/11 [00:00<00:00,  8.30it/s]

Epoch 5/15 Train:  45%|████▌     | 5/11 [00:00<00:00, 11.59it/s]

Epoch 5/15 Train:  64%|██████▎   | 7/11 [00:00<00:00, 13.75it/s]

Epoch 5/15 Train:  82%|████████▏ | 9/11 [00:00<00:00, 15.24it/s]

Epoch 5/15 Train: 100%|██████████| 11/11 [00:00<00:00, 12.72it/s]




Val:   0%|          | 0/3 [00:00<?, ?it/s]

Val:  33%|███▎      | 1/3 [00:00<00:00,  4.84it/s]

Val: 100%|██████████| 3/3 [00:00<00:00, 10.63it/s]




Epoch 5: Train BCE=0.8480, Val BCE=0.7244, Val WLL=0.5365, LR=2.00e-05


Epoch 6/15 Train:   0%|          | 0/11 [00:00<?, ?it/s]

Epoch 6/15 Train:   9%|▉         | 1/11 [00:00<00:03,  3.09it/s]

Epoch 6/15 Train:  27%|██▋       | 3/11 [00:00<00:01,  8.00it/s]

Epoch 6/15 Train:  45%|████▌     | 5/11 [00:00<00:00, 11.27it/s]

Epoch 6/15 Train:  64%|██████▎   | 7/11 [00:00<00:00, 13.48it/s]

Epoch 6/15 Train:  82%|████████▏ | 9/11 [00:00<00:00, 14.99it/s]

Epoch 6/15 Train: 100%|██████████| 11/11 [00:00<00:00, 12.33it/s]




Val:   0%|          | 0/3 [00:00<?, ?it/s]

Val:  33%|███▎      | 1/3 [00:00<00:00,  4.44it/s]

Val: 100%|██████████| 3/3 [00:00<00:00,  9.75it/s]




Epoch 6: Train BCE=0.8035, Val BCE=0.6788, Val WLL=0.5095, LR=2.00e-05


Epoch 7/15 Train:   0%|          | 0/11 [00:00<?, ?it/s]

Epoch 7/15 Train:   9%|▉         | 1/11 [00:00<00:03,  3.26it/s]

Epoch 7/15 Train:  27%|██▋       | 3/11 [00:00<00:00,  8.31it/s]

Epoch 7/15 Train:  45%|████▌     | 5/11 [00:00<00:00, 11.56it/s]

Epoch 7/15 Train:  64%|██████▎   | 7/11 [00:00<00:00, 13.73it/s]

Epoch 7/15 Train:  82%|████████▏ | 9/11 [00:00<00:00, 15.19it/s]

Epoch 7/15 Train: 100%|██████████| 11/11 [00:00<00:00, 12.61it/s]




Val:   0%|          | 0/3 [00:00<?, ?it/s]

Val:  33%|███▎      | 1/3 [00:00<00:00,  4.86it/s]

Val: 100%|██████████| 3/3 [00:00<00:00, 10.54it/s]




Epoch 7: Train BCE=0.7632, Val BCE=0.6306, Val WLL=0.4618, LR=2.00e-05


Epoch 8/15 Train:   0%|          | 0/11 [00:00<?, ?it/s]

Epoch 8/15 Train:   9%|▉         | 1/11 [00:00<00:03,  3.23it/s]

Epoch 8/15 Train:  27%|██▋       | 3/11 [00:00<00:00,  8.25it/s]

Epoch 8/15 Train:  45%|████▌     | 5/11 [00:00<00:00, 11.52it/s]

Epoch 8/15 Train:  64%|██████▎   | 7/11 [00:00<00:00, 13.72it/s]

Epoch 8/15 Train:  82%|████████▏ | 9/11 [00:00<00:00, 15.18it/s]

Epoch 8/15 Train: 100%|██████████| 11/11 [00:00<00:00, 12.63it/s]




Val:   0%|          | 0/3 [00:00<?, ?it/s]

Val:  33%|███▎      | 1/3 [00:00<00:00,  4.80it/s]

Val: 100%|██████████| 3/3 [00:00<00:00, 10.29it/s]




Epoch 8: Train BCE=0.7722, Val BCE=0.6650, Val WLL=0.4937, LR=2.00e-05


Epoch 9/15 Train:   0%|          | 0/11 [00:00<?, ?it/s]

Epoch 9/15 Train:   9%|▉         | 1/11 [00:00<00:03,  2.66it/s]

Epoch 9/15 Train:  27%|██▋       | 3/11 [00:00<00:01,  7.22it/s]

Epoch 9/15 Train:  45%|████▌     | 5/11 [00:00<00:00, 10.49it/s]

Epoch 9/15 Train:  64%|██████▎   | 7/11 [00:00<00:00, 12.83it/s]

Epoch 9/15 Train:  82%|████████▏ | 9/11 [00:00<00:00, 14.47it/s]

Epoch 9/15 Train: 100%|██████████| 11/11 [00:00<00:00, 11.58it/s]




Val:   0%|          | 0/3 [00:00<?, ?it/s]

Val:  33%|███▎      | 1/3 [00:00<00:00,  4.73it/s]

Val: 100%|██████████| 3/3 [00:00<00:00, 10.46it/s]




Epoch 9: Train BCE=0.7914, Val BCE=0.6021, Val WLL=0.4601, LR=2.00e-05


Epoch 10/15 Train:   0%|          | 0/11 [00:00<?, ?it/s]

Epoch 10/15 Train:   9%|▉         | 1/11 [00:00<00:03,  3.11it/s]

Epoch 10/15 Train:  27%|██▋       | 3/11 [00:00<00:01,  7.98it/s]

Epoch 10/15 Train:  45%|████▌     | 5/11 [00:00<00:00, 11.25it/s]

Epoch 10/15 Train:  64%|██████▎   | 7/11 [00:00<00:00, 13.46it/s]

Epoch 10/15 Train:  82%|████████▏ | 9/11 [00:00<00:00, 14.98it/s]

Epoch 10/15 Train: 100%|██████████| 11/11 [00:00<00:00, 12.32it/s]




Val:   0%|          | 0/3 [00:00<?, ?it/s]

Val:  33%|███▎      | 1/3 [00:00<00:00,  4.57it/s]

Val: 100%|██████████| 3/3 [00:00<00:00, 10.04it/s]




Epoch 10: Train BCE=0.7626, Val BCE=0.6183, Val WLL=0.5143, LR=2.00e-05


Epoch 11/15 Train:   0%|          | 0/11 [00:00<?, ?it/s]

Epoch 11/15 Train:   9%|▉         | 1/11 [00:00<00:03,  3.21it/s]

Epoch 11/15 Train:  27%|██▋       | 3/11 [00:00<00:00,  8.18it/s]

Epoch 11/15 Train:  45%|████▌     | 5/11 [00:00<00:00, 11.42it/s]

Epoch 11/15 Train:  64%|██████▎   | 7/11 [00:00<00:00, 13.61it/s]

Epoch 11/15 Train:  82%|████████▏ | 9/11 [00:00<00:00, 15.10it/s]

Epoch 11/15 Train: 100%|██████████| 11/11 [00:00<00:00, 12.51it/s]




Val:   0%|          | 0/3 [00:00<?, ?it/s]

Val:  33%|███▎      | 1/3 [00:00<00:00,  4.66it/s]

Val: 100%|██████████| 3/3 [00:00<00:00, 10.42it/s]




Epoch 11: Train BCE=0.7485, Val BCE=0.6391, Val WLL=0.5265, LR=2.00e-05


Epoch 12/15 Train:   0%|          | 0/11 [00:00<?, ?it/s]

Epoch 12/15 Train:   9%|▉         | 1/11 [00:00<00:03,  3.25it/s]

Epoch 12/15 Train:  27%|██▋       | 3/11 [00:00<00:00,  8.26it/s]

Epoch 12/15 Train:  45%|████▌     | 5/11 [00:00<00:00, 11.53it/s]

Epoch 12/15 Train:  64%|██████▎   | 7/11 [00:00<00:00, 13.71it/s]

Epoch 12/15 Train:  82%|████████▏ | 9/11 [00:00<00:00, 15.15it/s]

Epoch 12/15 Train: 100%|██████████| 11/11 [00:00<00:00, 12.56it/s]




Val:   0%|          | 0/3 [00:00<?, ?it/s]

Val:  33%|███▎      | 1/3 [00:00<00:00,  4.66it/s]

Val: 100%|██████████| 3/3 [00:00<00:00, 10.16it/s]




Epoch 12: Train BCE=0.7796, Val BCE=0.6318, Val WLL=0.5371, LR=1.00e-05


Epoch 13/15 Train:   0%|          | 0/11 [00:00<?, ?it/s]

Epoch 13/15 Train:   9%|▉         | 1/11 [00:00<00:02,  3.43it/s]

Epoch 13/15 Train:  27%|██▋       | 3/11 [00:00<00:00,  8.52it/s]

Epoch 13/15 Train:  45%|████▌     | 5/11 [00:00<00:00, 11.76it/s]

Epoch 13/15 Train:  64%|██████▎   | 7/11 [00:00<00:00, 13.87it/s]

Epoch 13/15 Train:  82%|████████▏ | 9/11 [00:00<00:00, 15.28it/s]

Epoch 13/15 Train: 100%|██████████| 11/11 [00:00<00:00, 12.74it/s]




Val:   0%|          | 0/3 [00:00<?, ?it/s]

Val:  33%|███▎      | 1/3 [00:00<00:00,  4.79it/s]

Val: 100%|██████████| 3/3 [00:00<00:00, 10.55it/s]




Epoch 13: Train BCE=0.7425, Val BCE=0.6769, Val WLL=0.5925, LR=1.00e-05


Epoch 14/15 Train:   0%|          | 0/11 [00:00<?, ?it/s]

Epoch 14/15 Train:   9%|▉         | 1/11 [00:00<00:02,  3.46it/s]

Epoch 14/15 Train:  27%|██▋       | 3/11 [00:00<00:00,  8.62it/s]

Epoch 14/15 Train:  45%|████▌     | 5/11 [00:00<00:00, 11.83it/s]

Epoch 14/15 Train:  64%|██████▎   | 7/11 [00:00<00:00, 13.91it/s]

Epoch 14/15 Train:  82%|████████▏ | 9/11 [00:00<00:00, 15.31it/s]

Epoch 14/15 Train: 100%|██████████| 11/11 [00:00<00:00, 12.84it/s]




Val:   0%|          | 0/3 [00:00<?, ?it/s]

Val:  33%|███▎      | 1/3 [00:00<00:00,  4.71it/s]

Val: 100%|██████████| 3/3 [00:00<00:00, 10.19it/s]




Epoch 14: Train BCE=0.7460, Val BCE=0.6634, Val WLL=0.5306, LR=1.00e-05
Early stopping at epoch 14


Final Val for OOF:   0%|          | 0/3 [00:00<?, ?it/s]

Final Val for OOF:  33%|███▎      | 1/3 [00:00<00:00,  4.53it/s]

Final Val for OOF: 100%|██████████| 3/3 [00:00<00:00, 10.11it/s]

Fold 5 OOF Weighted Log Loss: 0.5306

CV Mean Weighted Log Loss: 0.5764 +/- 0.0913
Full OOF Weighted Log Loss: 0.5771
Full 7-class CV completed with tweaks, OOF saved to oof_predictions.csv





In [11]:
# ConvNeXt-Tiny 5-fold CV using same pipeline/hparams as Cell 7, plus head warmup + EMA
from iterstrat.ml_stratifiers import MultilabelStratifiedKFold
from torch.optim.lr_scheduler import ReduceLROnPlateau
from torch.utils.data import WeightedRandomSampler
import torch.nn as nn
import torch.nn.functional as F
import timm

full_train_df = pd.read_csv('data/train_mips.csv')
label_cols = ['C1','C2','C3','C4','C5','C6','C7']
y_full = full_train_df[label_cols].values
mip_dir = 'data/mips/train'
n_folds = 5

skf = MultilabelStratifiedKFold(n_splits=n_folds, shuffle=True, random_state=42)
splits = list(skf.split(full_train_df, y_full))

def build_convnext(num_classes=7):
    m = timm.create_model('convnext_tiny', pretrained=True, num_classes=num_classes, in_chans=3,
                          drop_rate=0.3, drop_path_rate=0.1)
    return m

oof_preds_full = np.zeros((len(full_train_df), 7), dtype=np.float32)
oof_overall_full = np.zeros(len(full_train_df), dtype=np.float32)
fold_scores = []

for fold, (train_idx, val_idx) in enumerate(splits, 1):
    print(f'\n=== ConvNeXt Fold {fold}/{n_folds} ===')
    train_df_f = full_train_df.iloc[train_idx].reset_index(drop=True)
    val_df_f = full_train_df.iloc[val_idx].reset_index(drop=True)

    # pos_weight (clip=2.0)
    y_tr = train_df_f[label_cols].values
    pos = (y_tr == 1).sum(axis=0); neg = (y_tr == 0).sum(axis=0)
    w = np.minimum(neg / (pos + 1e-6), 2.0).astype(np.float32)
    fold_pos_weight = torch.tensor(w, dtype=torch.float32, device=device)

    # WeightedRandomSampler (cap=3.0)
    class_rarity = np.minimum(neg / (pos + 1e-6), 10.0).astype(np.float32)
    sample_weight = (y_tr * class_rarity).max(axis=1)
    sample_weight = np.where(sample_weight > 0, sample_weight, 1.0)
    sample_weight = np.sqrt(sample_weight)
    sample_weight = np.clip(sample_weight, 1.0, 3.0)
    sampler = WeightedRandomSampler(torch.from_numpy(sample_weight), len(sample_weight), replacement=True)

    train_ds = MIPDataset(train_df_f, mip_dir, label_cols, train_transform)
    val_ds = MIPDataset(val_df_f, mip_dir, label_cols, val_transform)
    train_loader = DataLoader(train_ds, batch_size=16, sampler=sampler, num_workers=4, pin_memory=True, drop_last=False)
    val_loader = DataLoader(val_ds, batch_size=16, shuffle=False, num_workers=4, pin_memory=True, drop_last=False)

    # Model
    model = build_convnext(7).to(device)

    # Bias init from priors
    with torch.no_grad():
        priors = np.clip(y_tr.mean(axis=0) + 1e-6, 1e-6, 1-1e-6)
        bias = np.log(priors / (1 - priors)).astype(np.float32)
        head = getattr(model, 'head', None)
        if head is not None and hasattr(head, 'fc') and head.fc.bias is not None:
            head.fc.bias.copy_(torch.from_numpy(bias).to(head.fc.bias.device))

    optimizer = optim.AdamW(model.parameters(), lr=2e-5, weight_decay=1e-4)
    scheduler = ReduceLROnPlateau(optimizer, mode='min', factor=0.5, patience=2, verbose=True, min_lr=1e-6)
    criterion = SmoothedBCEWithLogitsLoss(pos_weight=fold_pos_weight, smoothing=0.0)
    scaler = GradScaler()
    early_stopping = EarlyStopping(patience=5, min_delta=0.001, restore_best_weights=False)

    # EMA
    ema = timm.utils.ModelEmaV2(model, decay=0.995)

    num_epochs = 15
    warmup_epochs = 2
    best_wll = float('inf')
    best_state = None

    for epoch in range(num_epochs):
        # Head warmup
        if epoch < warmup_epochs:
            for p in model.parameters(): p.requires_grad = False
            if hasattr(model, 'head'):
                for p in model.head.parameters(): p.requires_grad = True
        else:
            for p in model.parameters(): p.requires_grad = True

        model.train()
        train_loss = 0.0
        for images, labels in tqdm(train_loader, desc=f'Epoch {epoch+1}/{num_epochs} Train'):
            images = images.to(device); labels7 = labels.to(device)
            optimizer.zero_grad(set_to_none=True)
            with torch.cuda.amp.autocast():
                logits = model(images)
                loss = criterion(logits, labels7)
            scaler.scale(loss).backward()
            nn_utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
            scaler.step(optimizer)
            scaler.update()
            ema.update(model)
            train_loss += loss.item() * images.size(0)
        train_loss /= len(train_ds)

        # Validation with EMA
        ema.module.eval()
        val_wll = 0.0
        with torch.no_grad():
            for images, labels in tqdm(val_loader, desc='Val'):
                images = images.to(device); labels7 = labels.to(device)
                logits = ema.module(images)
                probs7 = torch.sigmoid(logits)
                probs_overall = probs7.max(dim=1, keepdim=True)[0]
                labels_overall = labels7.max(dim=1, keepdim=True)[0]
                probs8 = torch.cat([probs7, probs_overall], dim=1)
                labels8 = torch.cat([labels7, labels_overall], dim=1)
                val_wll += weighted_log_loss_torch(labels8, probs8).item() * images.size(0)
        val_wll /= len(val_ds)
        scheduler.step(val_wll)
        print(f'Epoch {epoch+1}: TrainBCE={train_loss:.4f}, ValWLL={val_wll:.4f}, LR={optimizer.param_groups[0]["lr"]:.2e}')

        if val_wll < best_wll:
            best_wll = val_wll
            best_state = ema.module.state_dict()

        if early_stopping(val_wll, model):
            print(f'Early stopping at epoch {epoch+1}')
            break

    # Save best EMA weights
    torch.save(best_state, f'fold_{fold}_convnext.pth')
    print(f'ConvNeXt Fold {fold} best OOF WLL (val): {best_wll:.4f}')

    # OOF collection with best EMA
    ema.module.load_state_dict(best_state)
    ema.module.eval()
    fold_oof_preds_7 = np.zeros((len(val_df_f), 7), dtype=np.float32)
    fold_oof_overall = np.zeros(len(val_df_f), dtype=np.float32)
    start = 0
    with torch.no_grad():
        for images, _ in tqdm(val_loader, desc='Final Val for OOF'):
            images = images.to(device)
            logits = ema.module(images)
            probs7 = torch.sigmoid(logits).cpu().numpy()
            probs_overall = probs7.max(axis=1)
            bs = images.size(0)
            fold_oof_preds_7[start:start+bs] = probs7
            fold_oof_overall[start:start+bs] = probs_overall
            start += bs

    oof_preds_full[val_idx] = fold_oof_preds_7
    oof_overall_full[val_idx] = fold_oof_overall

    y_val = val_df_f[label_cols].values
    fold_preds_8 = np.column_stack([fold_oof_preds_7, fold_oof_overall])
    fold_labels_8 = np.column_stack([y_val, np.max(y_val, axis=1)])
    fold_score = weighted_log_loss(fold_labels_8, fold_preds_8)
    fold_scores.append(fold_score)
    print(f'ConvNeXt Fold {fold} OOF WLL: {fold_score:.4f}')

cv_mean = float(np.mean(fold_scores)); cv_std = float(np.std(fold_scores))
print(f'\nConvNeXt CV Mean WLL: {cv_mean:.4f} +/- {cv_std:.4f}')
conv_oof_df = full_train_df[['StudyInstanceUID']].copy()
conv_oof_df[label_cols] = oof_preds_full
conv_oof_df['patient_overall'] = oof_overall_full
conv_oof_df.to_csv('oof_predictions_convnext.csv', index=False)


=== ConvNeXt Fold 1/5 ===


Epoch 1/15 Train:   0%|          | 0/11 [00:00<?, ?it/s]

Epoch 1/15 Train:   9%|▉         | 1/11 [00:00<00:03,  2.97it/s]

Epoch 1/15 Train:  36%|███▋      | 4/11 [00:00<00:00, 10.48it/s]

Epoch 1/15 Train:  64%|██████▎   | 7/11 [00:00<00:00, 15.49it/s]

Epoch 1/15 Train:  91%|█████████ | 10/11 [00:00<00:00, 18.99it/s]

Epoch 1/15 Train: 100%|██████████| 11/11 [00:00<00:00, 14.75it/s]




Val:   0%|          | 0/3 [00:00<?, ?it/s]

Val:  33%|███▎      | 1/3 [00:00<00:00,  3.98it/s]

Val: 100%|██████████| 3/3 [00:00<00:00,  9.72it/s]

Val: 100%|██████████| 3/3 [00:00<00:00,  7.62it/s]




Epoch 1: TrainBCE=0.9976, ValWLL=0.4813, LR=2.00e-05


Epoch 2/15 Train:   0%|          | 0/11 [00:00<?, ?it/s]

Epoch 2/15 Train:   9%|▉         | 1/11 [00:00<00:03,  3.26it/s]

Epoch 2/15 Train:  36%|███▋      | 4/11 [00:00<00:00, 11.29it/s]

Epoch 2/15 Train:  64%|██████▎   | 7/11 [00:00<00:00, 16.43it/s]

Epoch 2/15 Train:  91%|█████████ | 10/11 [00:00<00:00, 19.78it/s]

Epoch 2/15 Train: 100%|██████████| 11/11 [00:00<00:00, 15.49it/s]




Val:   0%|          | 0/3 [00:00<?, ?it/s]

Val:  33%|███▎      | 1/3 [00:00<00:00,  3.87it/s]

Val: 100%|██████████| 3/3 [00:00<00:00,  9.55it/s]

Val: 100%|██████████| 3/3 [00:00<00:00,  7.32it/s]




Epoch 2: TrainBCE=0.9243, ValWLL=0.4808, LR=2.00e-05


Epoch 3/15 Train:   0%|          | 0/11 [00:00<?, ?it/s]

Epoch 3/15 Train:   9%|▉         | 1/11 [00:00<00:03,  2.78it/s]

Epoch 3/15 Train:  18%|█▊        | 2/11 [00:00<00:01,  4.79it/s]

Epoch 3/15 Train:  36%|███▋      | 4/11 [00:00<00:00,  7.02it/s]

Epoch 3/15 Train:  55%|█████▍    | 6/11 [00:00<00:00,  8.33it/s]

Epoch 3/15 Train:  73%|███████▎  | 8/11 [00:01<00:00,  9.06it/s]

Epoch 3/15 Train:  91%|█████████ | 10/11 [00:01<00:00,  9.43it/s]

Epoch 3/15 Train: 100%|██████████| 11/11 [00:01<00:00,  8.21it/s]




Val:   0%|          | 0/3 [00:00<?, ?it/s]

Val:  33%|███▎      | 1/3 [00:00<00:00,  4.08it/s]

Val: 100%|██████████| 3/3 [00:00<00:00,  9.90it/s]

Val: 100%|██████████| 3/3 [00:00<00:00,  7.73it/s]




Epoch 3: TrainBCE=0.8731, ValWLL=0.4801, LR=2.00e-05


Epoch 4/15 Train:   0%|          | 0/11 [00:00<?, ?it/s]

Epoch 4/15 Train:   9%|▉         | 1/11 [00:00<00:03,  2.78it/s]

Epoch 4/15 Train:  18%|█▊        | 2/11 [00:00<00:01,  4.82it/s]

Epoch 4/15 Train:  36%|███▋      | 4/11 [00:00<00:00,  7.18it/s]

Epoch 4/15 Train:  55%|█████▍    | 6/11 [00:00<00:00,  8.33it/s]

Epoch 4/15 Train:  73%|███████▎  | 8/11 [00:01<00:00,  8.97it/s]

Epoch 4/15 Train:  91%|█████████ | 10/11 [00:01<00:00,  9.35it/s]

Epoch 4/15 Train: 100%|██████████| 11/11 [00:01<00:00,  8.24it/s]




Val:   0%|          | 0/3 [00:00<?, ?it/s]

Val:  33%|███▎      | 1/3 [00:00<00:00,  3.91it/s]

Val: 100%|██████████| 3/3 [00:00<00:00,  9.62it/s]

Val: 100%|██████████| 3/3 [00:00<00:00,  7.54it/s]




Epoch 4: TrainBCE=0.7322, ValWLL=0.4764, LR=2.00e-05


Epoch 5/15 Train:   0%|          | 0/11 [00:00<?, ?it/s]

Epoch 5/15 Train:   9%|▉         | 1/11 [00:00<00:03,  2.79it/s]

Epoch 5/15 Train:  18%|█▊        | 2/11 [00:00<00:01,  4.84it/s]

Epoch 5/15 Train:  36%|███▋      | 4/11 [00:00<00:00,  7.20it/s]

Epoch 5/15 Train:  55%|█████▍    | 6/11 [00:00<00:00,  8.35it/s]

Epoch 5/15 Train:  73%|███████▎  | 8/11 [00:01<00:00,  8.99it/s]

Epoch 5/15 Train:  91%|█████████ | 10/11 [00:01<00:00,  9.37it/s]

Epoch 5/15 Train: 100%|██████████| 11/11 [00:01<00:00,  8.31it/s]




Val:   0%|          | 0/3 [00:00<?, ?it/s]

Val:  33%|███▎      | 1/3 [00:00<00:00,  3.99it/s]

Val: 100%|██████████| 3/3 [00:00<00:00,  9.72it/s]

Val: 100%|██████████| 3/3 [00:00<00:00,  7.60it/s]




Epoch 5: TrainBCE=0.7354, ValWLL=0.4742, LR=2.00e-05


Epoch 6/15 Train:   0%|          | 0/11 [00:00<?, ?it/s]

Epoch 6/15 Train:   9%|▉         | 1/11 [00:00<00:03,  2.80it/s]

Epoch 6/15 Train:  27%|██▋       | 3/11 [00:00<00:01,  6.03it/s]

Epoch 6/15 Train:  45%|████▌     | 5/11 [00:00<00:00,  7.61it/s]

Epoch 6/15 Train:  64%|██████▎   | 7/11 [00:00<00:00,  8.51it/s]

Epoch 6/15 Train:  82%|████████▏ | 9/11 [00:01<00:00,  9.05it/s]

Epoch 6/15 Train: 100%|██████████| 11/11 [00:01<00:00, 10.53it/s]

Epoch 6/15 Train: 100%|██████████| 11/11 [00:01<00:00,  8.29it/s]




Val:   0%|          | 0/3 [00:00<?, ?it/s]

Val:  33%|███▎      | 1/3 [00:00<00:00,  3.79it/s]

Val: 100%|██████████| 3/3 [00:00<00:00,  9.40it/s]

Val: 100%|██████████| 3/3 [00:00<00:00,  7.40it/s]




Epoch 6: TrainBCE=0.7172, ValWLL=0.4665, LR=2.00e-05


Epoch 7/15 Train:   0%|          | 0/11 [00:00<?, ?it/s]

Epoch 7/15 Train:   9%|▉         | 1/11 [00:00<00:03,  2.87it/s]

Epoch 7/15 Train:  18%|█▊        | 2/11 [00:00<00:01,  4.94it/s]

Epoch 7/15 Train:  36%|███▋      | 4/11 [00:00<00:00,  7.28it/s]

Epoch 7/15 Train:  55%|█████▍    | 6/11 [00:00<00:00,  8.41it/s]

Epoch 7/15 Train:  73%|███████▎  | 8/11 [00:01<00:00,  9.04it/s]

Epoch 7/15 Train:  82%|████████▏ | 9/11 [00:01<00:00,  9.22it/s]

Epoch 7/15 Train:  91%|█████████ | 10/11 [00:01<00:00,  9.39it/s]

Epoch 7/15 Train: 100%|██████████| 11/11 [00:01<00:00,  8.33it/s]




Val:   0%|          | 0/3 [00:00<?, ?it/s]

Val:  33%|███▎      | 1/3 [00:00<00:00,  4.28it/s]

Val: 100%|██████████| 3/3 [00:00<00:00, 10.17it/s]

Val: 100%|██████████| 3/3 [00:00<00:00,  7.93it/s]




Epoch 7: TrainBCE=0.6840, ValWLL=0.4602, LR=2.00e-05


Epoch 8/15 Train:   0%|          | 0/11 [00:00<?, ?it/s]

Epoch 8/15 Train:   9%|▉         | 1/11 [00:00<00:03,  2.82it/s]

Epoch 8/15 Train:  18%|█▊        | 2/11 [00:00<00:01,  4.88it/s]

Epoch 8/15 Train:  36%|███▋      | 4/11 [00:00<00:00,  7.22it/s]

Epoch 8/15 Train:  55%|█████▍    | 6/11 [00:00<00:00,  8.35it/s]

Epoch 8/15 Train:  73%|███████▎  | 8/11 [00:01<00:00,  8.99it/s]

Epoch 8/15 Train:  91%|█████████ | 10/11 [00:01<00:00,  9.36it/s]

Epoch 8/15 Train: 100%|██████████| 11/11 [00:01<00:00,  8.27it/s]




Val:   0%|          | 0/3 [00:00<?, ?it/s]

Val:  33%|███▎      | 1/3 [00:00<00:00,  3.86it/s]

Val: 100%|██████████| 3/3 [00:00<00:00,  9.52it/s]

Val: 100%|██████████| 3/3 [00:00<00:00,  7.44it/s]




Epoch 8: TrainBCE=0.7505, ValWLL=0.4555, LR=2.00e-05


Epoch 9/15 Train:   0%|          | 0/11 [00:00<?, ?it/s]

Epoch 9/15 Train:   9%|▉         | 1/11 [00:00<00:03,  2.74it/s]

Epoch 9/15 Train:  18%|█▊        | 2/11 [00:00<00:01,  4.78it/s]

Epoch 9/15 Train:  36%|███▋      | 4/11 [00:00<00:00,  7.15it/s]

Epoch 9/15 Train:  55%|█████▍    | 6/11 [00:00<00:00,  8.30it/s]

Epoch 9/15 Train:  73%|███████▎  | 8/11 [00:01<00:00,  8.95it/s]

Epoch 9/15 Train:  91%|█████████ | 10/11 [00:01<00:00,  9.33it/s]

Epoch 9/15 Train: 100%|██████████| 11/11 [00:01<00:00,  8.19it/s]




Val:   0%|          | 0/3 [00:00<?, ?it/s]

Val:  33%|███▎      | 1/3 [00:00<00:00,  4.01it/s]

Val: 100%|██████████| 3/3 [00:00<00:00,  9.75it/s]

Val: 100%|██████████| 3/3 [00:00<00:00,  7.67it/s]




Epoch 9: TrainBCE=0.7165, ValWLL=0.4530, LR=2.00e-05


Epoch 10/15 Train:   0%|          | 0/11 [00:00<?, ?it/s]

Epoch 10/15 Train:   9%|▉         | 1/11 [00:00<00:03,  2.78it/s]

Epoch 10/15 Train:  18%|█▊        | 2/11 [00:00<00:01,  4.82it/s]

Epoch 10/15 Train:  36%|███▋      | 4/11 [00:00<00:00,  7.17it/s]

Epoch 10/15 Train:  55%|█████▍    | 6/11 [00:00<00:00,  8.32it/s]

Epoch 10/15 Train:  73%|███████▎  | 8/11 [00:01<00:00,  8.97it/s]

Epoch 10/15 Train:  82%|████████▏ | 9/11 [00:01<00:00,  9.16it/s]

Epoch 10/15 Train:  91%|█████████ | 10/11 [00:01<00:00,  9.35it/s]

Epoch 10/15 Train: 100%|██████████| 11/11 [00:01<00:00,  8.25it/s]




Val:   0%|          | 0/3 [00:00<?, ?it/s]

Val:  33%|███▎      | 1/3 [00:00<00:00,  3.80it/s]

Val: 100%|██████████| 3/3 [00:00<00:00,  9.40it/s]

Val: 100%|██████████| 3/3 [00:00<00:00,  7.37it/s]




Epoch 10: TrainBCE=0.7058, ValWLL=0.4533, LR=2.00e-05


Epoch 11/15 Train:   0%|          | 0/11 [00:00<?, ?it/s]

Epoch 11/15 Train:   9%|▉         | 1/11 [00:00<00:03,  2.77it/s]

Epoch 11/15 Train:  18%|█▊        | 2/11 [00:00<00:01,  4.81it/s]

Epoch 11/15 Train:  36%|███▋      | 4/11 [00:00<00:00,  7.15it/s]

Epoch 11/15 Train:  55%|█████▍    | 6/11 [00:00<00:00,  8.30it/s]

Epoch 11/15 Train:  73%|███████▎  | 8/11 [00:01<00:00,  8.95it/s]

Epoch 11/15 Train:  91%|█████████ | 10/11 [00:01<00:00,  9.33it/s]

Epoch 11/15 Train: 100%|██████████| 11/11 [00:01<00:00,  8.21it/s]




Val:   0%|          | 0/3 [00:00<?, ?it/s]

Val:  33%|███▎      | 1/3 [00:00<00:00,  4.02it/s]

Val: 100%|██████████| 3/3 [00:00<00:00,  9.72it/s]

Val: 100%|██████████| 3/3 [00:00<00:00,  7.57it/s]




Epoch 11: TrainBCE=0.7146, ValWLL=0.4541, LR=2.00e-05


Epoch 12/15 Train:   0%|          | 0/11 [00:00<?, ?it/s]

Epoch 12/15 Train:   9%|▉         | 1/11 [00:00<00:03,  2.77it/s]

Epoch 12/15 Train:  18%|█▊        | 2/11 [00:00<00:01,  4.81it/s]

Epoch 12/15 Train:  36%|███▋      | 4/11 [00:00<00:00,  7.17it/s]

Epoch 12/15 Train:  55%|█████▍    | 6/11 [00:00<00:00,  8.31it/s]

Epoch 12/15 Train:  73%|███████▎  | 8/11 [00:01<00:00,  8.96it/s]

Epoch 12/15 Train:  91%|█████████ | 10/11 [00:01<00:00,  9.34it/s]

Epoch 12/15 Train: 100%|██████████| 11/11 [00:01<00:00,  8.25it/s]




Val:   0%|          | 0/3 [00:00<?, ?it/s]

Val:  33%|███▎      | 1/3 [00:00<00:00,  3.80it/s]

Val: 100%|██████████| 3/3 [00:00<00:00,  9.42it/s]

Val: 100%|██████████| 3/3 [00:00<00:00,  7.46it/s]




Epoch 12: TrainBCE=0.6921, ValWLL=0.4569, LR=1.00e-05


Epoch 13/15 Train:   0%|          | 0/11 [00:00<?, ?it/s]

Epoch 13/15 Train:   9%|▉         | 1/11 [00:00<00:03,  2.79it/s]

Epoch 13/15 Train:  18%|█▊        | 2/11 [00:00<00:01,  4.82it/s]

Epoch 13/15 Train:  36%|███▋      | 4/11 [00:00<00:00,  7.18it/s]

Epoch 13/15 Train:  55%|█████▍    | 6/11 [00:00<00:00,  8.33it/s]

Epoch 13/15 Train:  73%|███████▎  | 8/11 [00:01<00:00,  8.96it/s]

Epoch 13/15 Train:  91%|█████████ | 10/11 [00:01<00:00,  9.35it/s]

Epoch 13/15 Train: 100%|██████████| 11/11 [00:01<00:00,  8.22it/s]




Val:   0%|          | 0/3 [00:00<?, ?it/s]

Val:  33%|███▎      | 1/3 [00:00<00:00,  3.85it/s]

Val: 100%|██████████| 3/3 [00:00<00:00,  9.53it/s]

Val: 100%|██████████| 3/3 [00:00<00:00,  7.43it/s]




Epoch 13: TrainBCE=0.7125, ValWLL=0.4605, LR=1.00e-05


Epoch 14/15 Train:   0%|          | 0/11 [00:00<?, ?it/s]

Epoch 14/15 Train:   9%|▉         | 1/11 [00:00<00:03,  2.94it/s]

Epoch 14/15 Train:  18%|█▊        | 2/11 [00:00<00:01,  5.02it/s]

Epoch 14/15 Train:  36%|███▋      | 4/11 [00:00<00:00,  7.34it/s]

Epoch 14/15 Train:  55%|█████▍    | 6/11 [00:00<00:00,  8.43it/s]

Epoch 14/15 Train:  73%|███████▎  | 8/11 [00:01<00:00,  9.04it/s]

Epoch 14/15 Train:  82%|████████▏ | 9/11 [00:01<00:00,  9.24it/s]

Epoch 14/15 Train: 100%|██████████| 11/11 [00:01<00:00, 10.87it/s]

Epoch 14/15 Train: 100%|██████████| 11/11 [00:01<00:00,  8.36it/s]




Val:   0%|          | 0/3 [00:00<?, ?it/s]

Val:  33%|███▎      | 1/3 [00:00<00:00,  3.99it/s]

Val: 100%|██████████| 3/3 [00:00<00:00,  9.69it/s]

Val: 100%|██████████| 3/3 [00:00<00:00,  7.64it/s]




Epoch 14: TrainBCE=0.7058, ValWLL=0.4617, LR=1.00e-05
Early stopping at epoch 14
ConvNeXt Fold 1 best OOF WLL (val): 0.4530


Final Val for OOF:   0%|          | 0/3 [00:00<?, ?it/s]

Final Val for OOF:  33%|███▎      | 1/3 [00:00<00:00,  4.05it/s]

Final Val for OOF: 100%|██████████| 3/3 [00:00<00:00,  9.83it/s]

Final Val for OOF: 100%|██████████| 3/3 [00:00<00:00,  7.64it/s]




ConvNeXt Fold 1 OOF WLL: 0.4617

=== ConvNeXt Fold 2/5 ===


Epoch 1/15 Train:   0%|          | 0/11 [00:00<?, ?it/s]

Epoch 1/15 Train:   9%|▉         | 1/11 [00:00<00:02,  3.57it/s]

Epoch 1/15 Train:  36%|███▋      | 4/11 [00:00<00:00, 11.90it/s]

Epoch 1/15 Train:  64%|██████▎   | 7/11 [00:00<00:00, 16.92it/s]

Traceback (most recent call last):
  File "/usr/lib/python3.11/multiprocessing/queues.py", line 239, in _feed
    reader_close()
  File "/usr/lib/python3.11/multiprocessing/connection.py", line 182, in close
    self._close()
  File "/usr/lib/python3.11/multiprocessing/connection.py", line 365, in _close
    _close(self._handle)
OSError: [Errno 9] Bad file descriptor
Epoch 1/15 Train:  91%|█████████ | 10/11 [00:00<00:00, 10.95it/s]

Epoch 1/15 Train: 100%|██████████| 11/11 [00:00<00:00, 11.24it/s]




Val:   0%|          | 0/3 [00:00<?, ?it/s]

Val:  33%|███▎      | 1/3 [00:00<00:00,  4.26it/s]

Val: 100%|██████████| 3/3 [00:00<00:00, 10.10it/s]

Val: 100%|██████████| 3/3 [00:00<00:00,  7.74it/s]




Epoch 1: TrainBCE=0.7835, ValWLL=0.4658, LR=2.00e-05


Epoch 2/15 Train:   0%|          | 0/11 [00:00<?, ?it/s]

Epoch 2/15 Train:   9%|▉         | 1/11 [00:00<00:02,  3.74it/s]

Epoch 2/15 Train:  36%|███▋      | 4/11 [00:00<00:00, 12.02it/s]

Epoch 2/15 Train:  64%|██████▎   | 7/11 [00:00<00:00, 17.05it/s]

Epoch 2/15 Train:  91%|█████████ | 10/11 [00:00<00:00, 20.18it/s]

Epoch 2/15 Train: 100%|██████████| 11/11 [00:00<00:00, 16.18it/s]




Val:   0%|          | 0/3 [00:00<?, ?it/s]

Val:  33%|███▎      | 1/3 [00:00<00:00,  3.90it/s]

Val: 100%|██████████| 3/3 [00:00<00:00,  9.54it/s]

Val: 100%|██████████| 3/3 [00:00<00:00,  7.47it/s]




Epoch 2: TrainBCE=0.8640, ValWLL=0.4656, LR=2.00e-05


Epoch 3/15 Train:   0%|          | 0/11 [00:00<?, ?it/s]

Epoch 3/15 Train:   9%|▉         | 1/11 [00:00<00:03,  2.92it/s]

Epoch 3/15 Train:  27%|██▋       | 3/11 [00:00<00:01,  6.21it/s]

Epoch 3/15 Train:  45%|████▌     | 5/11 [00:00<00:00,  7.72it/s]

Epoch 3/15 Train:  64%|██████▎   | 7/11 [00:00<00:00,  8.68it/s]

Epoch 3/15 Train:  82%|████████▏ | 9/11 [00:01<00:00,  9.15it/s]

Epoch 3/15 Train: 100%|██████████| 11/11 [00:01<00:00, 10.65it/s]

Epoch 3/15 Train: 100%|██████████| 11/11 [00:01<00:00,  8.39it/s]




Val:   0%|          | 0/3 [00:00<?, ?it/s]

Val:  33%|███▎      | 1/3 [00:00<00:00,  4.02it/s]

Val: 100%|██████████| 3/3 [00:00<00:00,  9.71it/s]

Val: 100%|██████████| 3/3 [00:00<00:00,  7.62it/s]




Epoch 3: TrainBCE=0.8579, ValWLL=0.4641, LR=2.00e-05


Epoch 4/15 Train:   0%|          | 0/11 [00:00<?, ?it/s]

Epoch 4/15 Train:   9%|▉         | 1/11 [00:00<00:03,  3.00it/s]

Epoch 4/15 Train:  18%|█▊        | 2/11 [00:00<00:01,  5.08it/s]

Epoch 4/15 Train:  36%|███▋      | 4/11 [00:00<00:00,  7.38it/s]

Epoch 4/15 Train:  55%|█████▍    | 6/11 [00:00<00:00,  8.44it/s]

Epoch 4/15 Train:  64%|██████▎   | 7/11 [00:00<00:00,  8.79it/s]

Epoch 4/15 Train:  82%|████████▏ | 9/11 [00:01<00:00,  9.26it/s]

Epoch 4/15 Train:  91%|█████████ | 10/11 [00:01<00:00,  9.38it/s]

Epoch 4/15 Train: 100%|██████████| 11/11 [00:01<00:00,  8.37it/s]




Val:   0%|          | 0/3 [00:00<?, ?it/s]

Val:  33%|███▎      | 1/3 [00:00<00:00,  3.94it/s]

Val: 100%|██████████| 3/3 [00:00<00:00,  9.59it/s]

Val: 100%|██████████| 3/3 [00:00<00:00,  7.44it/s]




Epoch 4: TrainBCE=0.6812, ValWLL=0.4606, LR=2.00e-05


Epoch 5/15 Train:   0%|          | 0/11 [00:00<?, ?it/s]

Epoch 5/15 Train:   9%|▉         | 1/11 [00:00<00:03,  2.98it/s]

Epoch 5/15 Train:  18%|█▊        | 2/11 [00:00<00:01,  5.05it/s]

Epoch 5/15 Train:  36%|███▋      | 4/11 [00:00<00:00,  7.35it/s]

Epoch 5/15 Train:  55%|█████▍    | 6/11 [00:00<00:00,  8.43it/s]

Epoch 5/15 Train:  73%|███████▎  | 8/11 [00:01<00:00,  9.03it/s]

Epoch 5/15 Train:  82%|████████▏ | 9/11 [00:01<00:00,  9.21it/s]

Epoch 5/15 Train: 100%|██████████| 11/11 [00:01<00:00, 10.85it/s]

Epoch 5/15 Train: 100%|██████████| 11/11 [00:01<00:00,  8.34it/s]




Val:   0%|          | 0/3 [00:00<?, ?it/s]

Val:  33%|███▎      | 1/3 [00:00<00:00,  3.90it/s]

Val: 100%|██████████| 3/3 [00:00<00:00,  9.58it/s]

Val: 100%|██████████| 3/3 [00:00<00:00,  7.56it/s]




Epoch 5: TrainBCE=0.7132, ValWLL=0.4573, LR=2.00e-05


Epoch 6/15 Train:   0%|          | 0/11 [00:00<?, ?it/s]

Epoch 6/15 Train:   9%|▉         | 1/11 [00:00<00:03,  3.03it/s]

Epoch 6/15 Train:  18%|█▊        | 2/11 [00:00<00:01,  5.13it/s]

Epoch 6/15 Train:  36%|███▋      | 4/11 [00:00<00:00,  7.40it/s]

Epoch 6/15 Train:  55%|█████▍    | 6/11 [00:00<00:00,  8.46it/s]

Epoch 6/15 Train:  73%|███████▎  | 8/11 [00:01<00:00,  9.04it/s]

Epoch 6/15 Train:  82%|████████▏ | 9/11 [00:01<00:00,  9.22it/s]

Epoch 6/15 Train:  91%|█████████ | 10/11 [00:01<00:00,  9.37it/s]

Epoch 6/15 Train: 100%|██████████| 11/11 [00:01<00:00,  8.36it/s]




Val:   0%|          | 0/3 [00:00<?, ?it/s]

Val:  33%|███▎      | 1/3 [00:00<00:00,  3.82it/s]

Val: 100%|██████████| 3/3 [00:00<00:00,  9.40it/s]

Val: 100%|██████████| 3/3 [00:00<00:00,  7.34it/s]




Epoch 6: TrainBCE=0.7525, ValWLL=0.4575, LR=2.00e-05


Epoch 7/15 Train:   0%|          | 0/11 [00:00<?, ?it/s]

Epoch 7/15 Train:   9%|▉         | 1/11 [00:00<00:03,  3.02it/s]

Epoch 7/15 Train:  18%|█▊        | 2/11 [00:00<00:01,  5.11it/s]

Epoch 7/15 Train:  36%|███▋      | 4/11 [00:00<00:00,  7.38it/s]

Epoch 7/15 Train:  55%|█████▍    | 6/11 [00:00<00:00,  8.44it/s]

Epoch 7/15 Train:  73%|███████▎  | 8/11 [00:01<00:00,  9.03it/s]

Epoch 7/15 Train:  82%|████████▏ | 9/11 [00:01<00:00,  9.22it/s]

Epoch 7/15 Train:  91%|█████████ | 10/11 [00:01<00:00,  9.39it/s]

Epoch 7/15 Train: 100%|██████████| 11/11 [00:01<00:00,  8.36it/s]




Val:   0%|          | 0/3 [00:00<?, ?it/s]

Val:  33%|███▎      | 1/3 [00:00<00:00,  4.12it/s]

Val: 100%|██████████| 3/3 [00:00<00:00,  9.87it/s]

Val: 100%|██████████| 3/3 [00:00<00:00,  7.77it/s]




Epoch 7: TrainBCE=0.6954, ValWLL=0.4630, LR=2.00e-05


Epoch 8/15 Train:   0%|          | 0/11 [00:00<?, ?it/s]

Epoch 8/15 Train:   9%|▉         | 1/11 [00:00<00:03,  2.79it/s]

Epoch 8/15 Train:  18%|█▊        | 2/11 [00:00<00:01,  4.84it/s]

Epoch 8/15 Train:  36%|███▋      | 4/11 [00:00<00:00,  7.17it/s]

Epoch 8/15 Train:  45%|████▌     | 5/11 [00:00<00:00,  7.84it/s]

Epoch 8/15 Train:  55%|█████▍    | 6/11 [00:00<00:00,  8.39it/s]

Epoch 8/15 Train:  73%|███████▎  | 8/11 [00:01<00:00,  9.10it/s]

Epoch 8/15 Train:  82%|████████▏ | 9/11 [00:01<00:00,  9.28it/s]

Epoch 8/15 Train:  91%|█████████ | 10/11 [00:01<00:00,  9.45it/s]

Epoch 8/15 Train: 100%|██████████| 11/11 [00:01<00:00,  8.22it/s]




Val:   0%|          | 0/3 [00:00<?, ?it/s]

Val:  33%|███▎      | 1/3 [00:00<00:00,  4.14it/s]

Val: 100%|██████████| 3/3 [00:00<00:00,  9.87it/s]

Val: 100%|██████████| 3/3 [00:00<00:00,  7.81it/s]




Epoch 8: TrainBCE=0.7219, ValWLL=0.4612, LR=1.00e-05


Epoch 9/15 Train:   0%|          | 0/11 [00:00<?, ?it/s]

Epoch 9/15 Train:   9%|▉         | 1/11 [00:00<00:03,  2.90it/s]

Epoch 9/15 Train:  18%|█▊        | 2/11 [00:00<00:01,  4.96it/s]

Epoch 9/15 Train:  36%|███▋      | 4/11 [00:00<00:00,  7.27it/s]

Epoch 9/15 Train:  55%|█████▍    | 6/11 [00:00<00:00,  8.37it/s]

Epoch 9/15 Train:  73%|███████▎  | 8/11 [00:01<00:00,  8.98it/s]

Epoch 9/15 Train:  82%|████████▏ | 9/11 [00:01<00:00,  9.17it/s]

Epoch 9/15 Train:  91%|█████████ | 10/11 [00:01<00:00,  9.36it/s]

Epoch 9/15 Train: 100%|██████████| 11/11 [00:01<00:00,  8.31it/s]




Val:   0%|          | 0/3 [00:00<?, ?it/s]

Val:  33%|███▎      | 1/3 [00:00<00:00,  3.90it/s]

Val: 100%|██████████| 3/3 [00:00<00:00,  9.55it/s]

Val: 100%|██████████| 3/3 [00:00<00:00,  7.34it/s]




Epoch 9: TrainBCE=0.7551, ValWLL=0.4658, LR=1.00e-05


Epoch 10/15 Train:   0%|          | 0/11 [00:00<?, ?it/s]

Epoch 10/15 Train:   9%|▉         | 1/11 [00:00<00:03,  2.75it/s]

Epoch 10/15 Train:  18%|█▊        | 2/11 [00:00<00:01,  4.78it/s]

Epoch 10/15 Train:  27%|██▋       | 3/11 [00:00<00:01,  6.23it/s]

Epoch 10/15 Train:  45%|████▌     | 5/11 [00:00<00:00,  7.95it/s]

Epoch 10/15 Train:  55%|█████▍    | 6/11 [00:00<00:00,  8.44it/s]

Epoch 10/15 Train:  64%|██████▎   | 7/11 [00:00<00:00,  8.84it/s]

Epoch 10/15 Train:  73%|███████▎  | 8/11 [00:01<00:00,  9.12it/s]

Epoch 10/15 Train:  82%|████████▏ | 9/11 [00:01<00:00,  9.33it/s]

Epoch 10/15 Train:  91%|█████████ | 10/11 [00:01<00:00,  9.50it/s]

Epoch 10/15 Train: 100%|██████████| 11/11 [00:01<00:00,  8.10it/s]




Val:   0%|          | 0/3 [00:00<?, ?it/s]

Val:  33%|███▎      | 1/3 [00:00<00:00,  3.88it/s]

Val: 100%|██████████| 3/3 [00:00<00:00,  9.50it/s]

Val: 100%|██████████| 3/3 [00:00<00:00,  7.47it/s]




Epoch 10: TrainBCE=0.6664, ValWLL=0.4627, LR=1.00e-05
Early stopping at epoch 10
ConvNeXt Fold 2 best OOF WLL (val): 0.4573


Final Val for OOF:   0%|          | 0/3 [00:00<?, ?it/s]

Final Val for OOF:  33%|███▎      | 1/3 [00:00<00:00,  4.25it/s]

Final Val for OOF: 100%|██████████| 3/3 [00:00<00:00, 10.02it/s]

Final Val for OOF: 100%|██████████| 3/3 [00:00<00:00,  7.92it/s]




ConvNeXt Fold 2 OOF WLL: 0.4627

=== ConvNeXt Fold 3/5 ===


Epoch 1/15 Train:   0%|          | 0/11 [00:00<?, ?it/s]

Epoch 1/15 Train:   9%|▉         | 1/11 [00:00<00:02,  3.49it/s]

Epoch 1/15 Train:  36%|███▋      | 4/11 [00:00<00:00, 11.37it/s]

Epoch 1/15 Train:  64%|██████▎   | 7/11 [00:00<00:00, 15.81it/s]

Epoch 1/15 Train:  91%|█████████ | 10/11 [00:00<00:00, 19.10it/s]

Epoch 1/15 Train: 100%|██████████| 11/11 [00:00<00:00, 15.33it/s]




Val:   0%|          | 0/3 [00:00<?, ?it/s]

Val:  33%|███▎      | 1/3 [00:00<00:00,  4.07it/s]

Val: 100%|██████████| 3/3 [00:00<00:00,  9.53it/s]

Val: 100%|██████████| 3/3 [00:00<00:00,  7.56it/s]




Epoch 1: TrainBCE=0.8371, ValWLL=0.4723, LR=2.00e-05


Epoch 2/15 Train:   0%|          | 0/11 [00:00<?, ?it/s]

Epoch 2/15 Train:   9%|▉         | 1/11 [00:00<00:02,  3.96it/s]

Epoch 2/15 Train:  36%|███▋      | 4/11 [00:00<00:00, 12.47it/s]

Epoch 2/15 Train:  64%|██████▎   | 7/11 [00:00<00:00, 17.41it/s]

Epoch 2/15 Train:  91%|█████████ | 10/11 [00:00<00:00, 20.53it/s]

Epoch 2/15 Train: 100%|██████████| 11/11 [00:00<00:00, 16.66it/s]




Val:   0%|          | 0/3 [00:00<?, ?it/s]

Val:  33%|███▎      | 1/3 [00:00<00:00,  4.16it/s]

Val: 100%|██████████| 3/3 [00:00<00:00,  9.78it/s]

Val: 100%|██████████| 3/3 [00:00<00:00,  7.77it/s]




Epoch 2: TrainBCE=0.8233, ValWLL=0.4719, LR=2.00e-05


Epoch 3/15 Train:   0%|          | 0/11 [00:00<?, ?it/s]

Epoch 3/15 Train:   9%|▉         | 1/11 [00:00<00:03,  3.27it/s]

Epoch 3/15 Train:  18%|█▊        | 2/11 [00:00<00:01,  5.40it/s]

Epoch 3/15 Train:  36%|███▋      | 4/11 [00:00<00:00,  7.59it/s]

Epoch 3/15 Train:  55%|█████▍    | 6/11 [00:00<00:00,  8.68it/s]

Epoch 3/15 Train:  64%|██████▎   | 7/11 [00:00<00:00,  8.97it/s]

Epoch 3/15 Train:  73%|███████▎  | 8/11 [00:01<00:00,  9.22it/s]

Epoch 3/15 Train:  82%|████████▏ | 9/11 [00:01<00:00,  9.40it/s]

Epoch 3/15 Train:  91%|█████████ | 10/11 [00:01<00:00,  9.55it/s]

Epoch 3/15 Train: 100%|██████████| 11/11 [00:01<00:00,  8.51it/s]




Val:   0%|          | 0/3 [00:00<?, ?it/s]

Val:  33%|███▎      | 1/3 [00:00<00:00,  3.97it/s]

Val: 100%|██████████| 3/3 [00:00<00:00,  9.52it/s]

Val: 100%|██████████| 3/3 [00:00<00:00,  7.57it/s]




Epoch 3: TrainBCE=0.7388, ValWLL=0.4687, LR=2.00e-05


Epoch 4/15 Train:   0%|          | 0/11 [00:00<?, ?it/s]

Epoch 4/15 Train:   9%|▉         | 1/11 [00:00<00:03,  3.12it/s]

Epoch 4/15 Train:  18%|█▊        | 2/11 [00:00<00:01,  5.21it/s]

Epoch 4/15 Train:  36%|███▋      | 4/11 [00:00<00:00,  7.46it/s]

Epoch 4/15 Train:  55%|█████▍    | 6/11 [00:00<00:00,  8.49it/s]

Epoch 4/15 Train:  73%|███████▎  | 8/11 [00:01<00:00,  9.06it/s]

Epoch 4/15 Train:  82%|████████▏ | 9/11 [00:01<00:00,  9.22it/s]

Epoch 4/15 Train:  91%|█████████ | 10/11 [00:01<00:00,  9.38it/s]

Epoch 4/15 Train: 100%|██████████| 11/11 [00:01<00:00,  8.43it/s]




Val:   0%|          | 0/3 [00:00<?, ?it/s]

Val:  33%|███▎      | 1/3 [00:00<00:00,  4.19it/s]

Val: 100%|██████████| 3/3 [00:00<00:00,  9.79it/s]

Val: 100%|██████████| 3/3 [00:00<00:00,  7.75it/s]




Epoch 4: TrainBCE=0.7711, ValWLL=0.4612, LR=2.00e-05


Epoch 5/15 Train:   0%|          | 0/11 [00:00<?, ?it/s]

Epoch 5/15 Train:   9%|▉         | 1/11 [00:00<00:03,  3.00it/s]

Epoch 5/15 Train:  18%|█▊        | 2/11 [00:00<00:01,  5.08it/s]

Epoch 5/15 Train:  36%|███▋      | 4/11 [00:00<00:00,  7.36it/s]

Epoch 5/15 Train:  55%|█████▍    | 6/11 [00:00<00:00,  8.43it/s]

Epoch 5/15 Train:  73%|███████▎  | 8/11 [00:01<00:00,  9.02it/s]

Epoch 5/15 Train:  82%|████████▏ | 9/11 [00:01<00:00,  9.20it/s]

Epoch 5/15 Train:  91%|█████████ | 10/11 [00:01<00:00,  9.36it/s]

Epoch 5/15 Train: 100%|██████████| 11/11 [00:01<00:00,  8.39it/s]




Val:   0%|          | 0/3 [00:00<?, ?it/s]

Val:  33%|███▎      | 1/3 [00:00<00:00,  3.94it/s]

Val: 100%|██████████| 3/3 [00:00<00:00,  9.44it/s]

Val: 100%|██████████| 3/3 [00:00<00:00,  7.46it/s]




Epoch 5: TrainBCE=0.7093, ValWLL=0.4594, LR=2.00e-05


Epoch 6/15 Train:   0%|          | 0/11 [00:00<?, ?it/s]

Epoch 6/15 Train:   9%|▉         | 1/11 [00:00<00:03,  3.08it/s]

Epoch 6/15 Train:  18%|█▊        | 2/11 [00:00<00:01,  5.16it/s]

Epoch 6/15 Train:  36%|███▋      | 4/11 [00:00<00:00,  7.43it/s]

Epoch 6/15 Train:  55%|█████▍    | 6/11 [00:00<00:00,  8.48it/s]

Epoch 6/15 Train:  73%|███████▎  | 8/11 [00:01<00:00,  9.06it/s]

Epoch 6/15 Train:  82%|████████▏ | 9/11 [00:01<00:00,  9.24it/s]

Epoch 6/15 Train:  91%|█████████ | 10/11 [00:01<00:00,  9.40it/s]

Epoch 6/15 Train: 100%|██████████| 11/11 [00:01<00:00,  8.46it/s]




Val:   0%|          | 0/3 [00:00<?, ?it/s]

Val:  33%|███▎      | 1/3 [00:00<00:00,  4.04it/s]

Val: 100%|██████████| 3/3 [00:00<00:00,  9.57it/s]

Val: 100%|██████████| 3/3 [00:00<00:00,  7.62it/s]




Epoch 6: TrainBCE=0.7066, ValWLL=0.4541, LR=2.00e-05


Epoch 7/15 Train:   0%|          | 0/11 [00:00<?, ?it/s]

Epoch 7/15 Train:   9%|▉         | 1/11 [00:00<00:03,  2.86it/s]

Epoch 7/15 Train:  18%|█▊        | 2/11 [00:00<00:01,  4.90it/s]

Epoch 7/15 Train:  27%|██▋       | 3/11 [00:00<00:01,  6.36it/s]

Epoch 7/15 Train:  45%|████▌     | 5/11 [00:00<00:00,  8.03it/s]

Epoch 7/15 Train:  64%|██████▎   | 7/11 [00:00<00:00,  8.83it/s]

Epoch 7/15 Train:  73%|███████▎  | 8/11 [00:01<00:00,  9.08it/s]

Epoch 7/15 Train:  82%|████████▏ | 9/11 [00:01<00:00,  9.25it/s]

Epoch 7/15 Train:  91%|█████████ | 10/11 [00:01<00:00,  9.42it/s]

Epoch 7/15 Train: 100%|██████████| 11/11 [00:01<00:00,  8.24it/s]




Val:   0%|          | 0/3 [00:00<?, ?it/s]

Val:  33%|███▎      | 1/3 [00:00<00:00,  3.84it/s]

Val: 100%|██████████| 3/3 [00:00<00:00,  9.28it/s]

Val: 100%|██████████| 3/3 [00:00<00:00,  7.25it/s]




Epoch 7: TrainBCE=0.7613, ValWLL=0.4567, LR=2.00e-05


Epoch 8/15 Train:   0%|          | 0/11 [00:00<?, ?it/s]

Epoch 8/15 Train:   9%|▉         | 1/11 [00:00<00:03,  3.24it/s]

Epoch 8/15 Train:  18%|█▊        | 2/11 [00:00<00:01,  5.36it/s]

Epoch 8/15 Train:  36%|███▋      | 4/11 [00:00<00:00,  7.58it/s]

Epoch 8/15 Train:  55%|█████▍    | 6/11 [00:00<00:00,  8.57it/s]

Epoch 8/15 Train:  73%|███████▎  | 8/11 [00:01<00:00,  9.11it/s]

Epoch 8/15 Train:  82%|████████▏ | 9/11 [00:01<00:00,  9.28it/s]

Epoch 8/15 Train:  91%|█████████ | 10/11 [00:01<00:00,  9.44it/s]

Epoch 8/15 Train: 100%|██████████| 11/11 [00:01<00:00,  8.57it/s]




Val:   0%|          | 0/3 [00:00<?, ?it/s]

Val:  33%|███▎      | 1/3 [00:00<00:00,  4.17it/s]

Val: 100%|██████████| 3/3 [00:00<00:00,  9.78it/s]

Val: 100%|██████████| 3/3 [00:00<00:00,  7.56it/s]




Epoch 8: TrainBCE=0.7202, ValWLL=0.4612, LR=2.00e-05


Epoch 9/15 Train:   0%|          | 0/11 [00:00<?, ?it/s]

Epoch 9/15 Train:   9%|▉         | 1/11 [00:00<00:03,  2.96it/s]

Epoch 9/15 Train:  18%|█▊        | 2/11 [00:00<00:01,  5.05it/s]

Epoch 9/15 Train:  27%|██▋       | 3/11 [00:00<00:01,  6.52it/s]

Epoch 9/15 Train:  45%|████▌     | 5/11 [00:00<00:00,  8.13it/s]

Epoch 9/15 Train:  64%|██████▎   | 7/11 [00:00<00:00,  8.89it/s]

Epoch 9/15 Train:  82%|████████▏ | 9/11 [00:01<00:00,  9.27it/s]

Epoch 9/15 Train:  91%|█████████ | 10/11 [00:01<00:00,  9.39it/s]

Epoch 9/15 Train: 100%|██████████| 11/11 [00:01<00:00,  8.35it/s]




Val:   0%|          | 0/3 [00:00<?, ?it/s]

Val:  33%|███▎      | 1/3 [00:00<00:00,  4.12it/s]

Val: 100%|██████████| 3/3 [00:00<00:00,  9.71it/s]

Val: 100%|██████████| 3/3 [00:00<00:00,  7.45it/s]




Epoch 9: TrainBCE=0.6874, ValWLL=0.4647, LR=1.00e-05


Epoch 10/15 Train:   0%|          | 0/11 [00:00<?, ?it/s]

Epoch 10/15 Train:   9%|▉         | 1/11 [00:00<00:03,  2.86it/s]

Epoch 10/15 Train:  18%|█▊        | 2/11 [00:00<00:01,  4.92it/s]

Epoch 10/15 Train:  36%|███▋      | 4/11 [00:00<00:00,  7.23it/s]

Epoch 10/15 Train:  55%|█████▍    | 6/11 [00:00<00:00,  8.34it/s]

Epoch 10/15 Train:  73%|███████▎  | 8/11 [00:01<00:00,  8.95it/s]

Epoch 10/15 Train:  82%|████████▏ | 9/11 [00:01<00:00,  9.13it/s]

Epoch 10/15 Train:  91%|█████████ | 10/11 [00:01<00:00,  9.31it/s]

Epoch 10/15 Train: 100%|██████████| 11/11 [00:01<00:00,  8.28it/s]




Val:   0%|          | 0/3 [00:00<?, ?it/s]

Val:  33%|███▎      | 1/3 [00:00<00:00,  4.15it/s]

Val: 100%|██████████| 3/3 [00:00<00:00,  9.73it/s]

Val: 100%|██████████| 3/3 [00:00<00:00,  7.76it/s]




Epoch 10: TrainBCE=0.6629, ValWLL=0.4580, LR=1.00e-05


Epoch 11/15 Train:   0%|          | 0/11 [00:00<?, ?it/s]

Epoch 11/15 Train:   9%|▉         | 1/11 [00:00<00:03,  2.95it/s]

Epoch 11/15 Train:  18%|█▊        | 2/11 [00:00<00:01,  5.02it/s]

Epoch 11/15 Train:  27%|██▋       | 3/11 [00:00<00:01,  6.50it/s]

Epoch 11/15 Train:  45%|████▌     | 5/11 [00:00<00:00,  8.12it/s]

Epoch 11/15 Train:  55%|█████▍    | 6/11 [00:00<00:00,  8.59it/s]

Epoch 11/15 Train:  64%|██████▎   | 7/11 [00:00<00:00,  8.95it/s]

Epoch 11/15 Train:  82%|████████▏ | 9/11 [00:01<00:00,  9.39it/s]

Epoch 11/15 Train:  91%|█████████ | 10/11 [00:01<00:00,  9.52it/s]

Epoch 11/15 Train: 100%|██████████| 11/11 [00:01<00:00,  8.35it/s]




Val:   0%|          | 0/3 [00:00<?, ?it/s]

Val:  33%|███▎      | 1/3 [00:00<00:00,  4.17it/s]

Val: 100%|██████████| 3/3 [00:00<00:00,  9.72it/s]

Val: 100%|██████████| 3/3 [00:00<00:00,  7.56it/s]




Epoch 11: TrainBCE=0.7348, ValWLL=0.4560, LR=1.00e-05
Early stopping at epoch 11
ConvNeXt Fold 3 best OOF WLL (val): 0.4541


Final Val for OOF:   0%|          | 0/3 [00:00<?, ?it/s]

Final Val for OOF:  33%|███▎      | 1/3 [00:00<00:00,  4.23it/s]

Final Val for OOF: 100%|██████████| 3/3 [00:00<00:00,  9.85it/s]

Final Val for OOF: 100%|██████████| 3/3 [00:00<00:00,  7.62it/s]




ConvNeXt Fold 3 OOF WLL: 0.4560

=== ConvNeXt Fold 4/5 ===


Epoch 1/15 Train:   0%|          | 0/11 [00:00<?, ?it/s]

Epoch 1/15 Train:   9%|▉         | 1/11 [00:00<00:02,  3.84it/s]

Epoch 1/15 Train:  36%|███▋      | 4/11 [00:00<00:00, 12.35it/s]

Epoch 1/15 Train:  64%|██████▎   | 7/11 [00:00<00:00, 17.33it/s]

Epoch 1/15 Train:  91%|█████████ | 10/11 [00:00<00:00, 20.47it/s]

Epoch 1/15 Train: 100%|██████████| 11/11 [00:00<00:00, 16.41it/s]




Val:   0%|          | 0/3 [00:00<?, ?it/s]

Val:  33%|███▎      | 1/3 [00:00<00:00,  4.21it/s]

Val: 100%|██████████| 3/3 [00:00<00:00,  9.84it/s]

Val: 100%|██████████| 3/3 [00:00<00:00,  7.77it/s]




Epoch 1: TrainBCE=0.9231, ValWLL=0.4828, LR=2.00e-05


Epoch 2/15 Train:   0%|          | 0/11 [00:00<?, ?it/s]

Epoch 2/15 Train:   9%|▉         | 1/11 [00:00<00:02,  3.82it/s]

Epoch 2/15 Train:  36%|███▋      | 4/11 [00:00<00:00, 12.07it/s]

Epoch 2/15 Train:  64%|██████▎   | 7/11 [00:00<00:00, 17.06it/s]

Epoch 2/15 Train:  91%|█████████ | 10/11 [00:00<00:00, 20.06it/s]

Epoch 2/15 Train: 100%|██████████| 11/11 [00:00<00:00, 16.09it/s]




Val:   0%|          | 0/3 [00:00<?, ?it/s]

Val:  33%|███▎      | 1/3 [00:00<00:00,  4.18it/s]

Val: 100%|██████████| 3/3 [00:00<00:00,  9.79it/s]

Val: 100%|██████████| 3/3 [00:00<00:00,  7.64it/s]




Epoch 2: TrainBCE=0.9889, ValWLL=0.4823, LR=2.00e-05


Epoch 3/15 Train:   0%|          | 0/11 [00:00<?, ?it/s]

Epoch 3/15 Train:   9%|▉         | 1/11 [00:00<00:03,  3.23it/s]

Epoch 3/15 Train:  27%|██▋       | 3/11 [00:00<00:01,  6.55it/s]

Epoch 3/15 Train:  36%|███▋      | 4/11 [00:00<00:00,  7.37it/s]

Epoch 3/15 Train:  55%|█████▍    | 6/11 [00:00<00:00,  8.63it/s]

Epoch 3/15 Train:  73%|███████▎  | 8/11 [00:00<00:00,  9.16it/s]

Epoch 3/15 Train:  82%|████████▏ | 9/11 [00:01<00:00,  9.31it/s]

Epoch 3/15 Train:  91%|█████████ | 10/11 [00:01<00:00,  9.45it/s]

Epoch 3/15 Train: 100%|██████████| 11/11 [00:01<00:00,  8.57it/s]




Val:   0%|          | 0/3 [00:00<?, ?it/s]

Val:  33%|███▎      | 1/3 [00:00<00:00,  4.18it/s]

Val: 100%|██████████| 3/3 [00:00<00:00,  9.77it/s]

Val: 100%|██████████| 3/3 [00:00<00:00,  7.66it/s]




Epoch 3: TrainBCE=0.8517, ValWLL=0.4809, LR=2.00e-05


Epoch 4/15 Train:   0%|          | 0/11 [00:00<?, ?it/s]

Epoch 4/15 Train:   9%|▉         | 1/11 [00:00<00:03,  3.12it/s]

Epoch 4/15 Train:  18%|█▊        | 2/11 [00:00<00:01,  5.21it/s]

Epoch 4/15 Train:  27%|██▋       | 3/11 [00:00<00:01,  6.65it/s]

Epoch 4/15 Train:  36%|███▋      | 4/11 [00:00<00:00,  7.66it/s]

Epoch 4/15 Train:  45%|████▌     | 5/11 [00:00<00:00,  8.37it/s]

Epoch 4/15 Train:  55%|█████▍    | 6/11 [00:00<00:00,  8.85it/s]

Epoch 4/15 Train:  64%|██████▎   | 7/11 [00:00<00:00,  9.19it/s]

Epoch 4/15 Train:  82%|████████▏ | 9/11 [00:01<00:00,  9.53it/s]

Epoch 4/15 Train:  91%|█████████ | 10/11 [00:01<00:00,  9.61it/s]

Epoch 4/15 Train: 100%|██████████| 11/11 [00:01<00:00,  8.39it/s]




Val:   0%|          | 0/3 [00:00<?, ?it/s]

Val:  33%|███▎      | 1/3 [00:00<00:00,  4.04it/s]

Val: 100%|██████████| 3/3 [00:00<00:00,  9.56it/s]

Val: 100%|██████████| 3/3 [00:00<00:00,  7.49it/s]




Epoch 4: TrainBCE=0.6839, ValWLL=0.4692, LR=2.00e-05


Epoch 5/15 Train:   0%|          | 0/11 [00:00<?, ?it/s]

Epoch 5/15 Train:   9%|▉         | 1/11 [00:00<00:03,  3.12it/s]

Epoch 5/15 Train:  18%|█▊        | 2/11 [00:00<00:01,  5.22it/s]

Epoch 5/15 Train:  27%|██▋       | 3/11 [00:00<00:01,  6.68it/s]

Epoch 5/15 Train:  36%|███▋      | 4/11 [00:00<00:00,  7.69it/s]

Epoch 5/15 Train:  55%|█████▍    | 6/11 [00:00<00:00,  8.77it/s]

Epoch 5/15 Train:  64%|██████▎   | 7/11 [00:00<00:00,  9.07it/s]

Epoch 5/15 Train:  82%|████████▏ | 9/11 [00:01<00:00,  9.43it/s]

Epoch 5/15 Train:  91%|█████████ | 10/11 [00:01<00:00,  9.53it/s]

Epoch 5/15 Train: 100%|██████████| 11/11 [00:01<00:00,  8.42it/s]




Val:   0%|          | 0/3 [00:00<?, ?it/s]

Val:  33%|███▎      | 1/3 [00:00<00:00,  4.00it/s]

Val: 100%|██████████| 3/3 [00:00<00:00,  9.52it/s]

Val: 100%|██████████| 3/3 [00:00<00:00,  7.33it/s]




Epoch 5: TrainBCE=0.7282, ValWLL=0.4571, LR=2.00e-05


Epoch 6/15 Train:   0%|          | 0/11 [00:00<?, ?it/s]

Epoch 6/15 Train:   9%|▉         | 1/11 [00:00<00:03,  2.90it/s]

Epoch 6/15 Train:  18%|█▊        | 2/11 [00:00<00:01,  4.96it/s]

Epoch 6/15 Train:  27%|██▋       | 3/11 [00:00<00:01,  6.44it/s]

Epoch 6/15 Train:  36%|███▋      | 4/11 [00:00<00:00,  7.50it/s]

Epoch 6/15 Train:  45%|████▌     | 5/11 [00:00<00:00,  8.24it/s]

Epoch 6/15 Train:  64%|██████▎   | 7/11 [00:00<00:00,  9.06it/s]

Epoch 6/15 Train:  82%|████████▏ | 9/11 [00:01<00:00,  9.39it/s]

Epoch 6/15 Train:  91%|█████████ | 10/11 [00:01<00:00,  9.49it/s]

Epoch 6/15 Train: 100%|██████████| 11/11 [00:01<00:00,  8.28it/s]




Val:   0%|          | 0/3 [00:00<?, ?it/s]

Val:  33%|███▎      | 1/3 [00:00<00:00,  4.18it/s]

Val: 100%|██████████| 3/3 [00:00<00:00,  9.78it/s]

Val: 100%|██████████| 3/3 [00:00<00:00,  7.51it/s]




Epoch 6: TrainBCE=0.7153, ValWLL=0.4551, LR=2.00e-05


Epoch 7/15 Train:   0%|          | 0/11 [00:00<?, ?it/s]

Epoch 7/15 Train:   9%|▉         | 1/11 [00:00<00:03,  2.99it/s]

Epoch 7/15 Train:  18%|█▊        | 2/11 [00:00<00:01,  5.05it/s]

Epoch 7/15 Train:  27%|██▋       | 3/11 [00:00<00:01,  6.52it/s]

Epoch 7/15 Train:  45%|████▌     | 5/11 [00:00<00:00,  8.13it/s]

Epoch 7/15 Train:  55%|█████▍    | 6/11 [00:00<00:00,  8.59it/s]

Epoch 7/15 Train:  73%|███████▎  | 8/11 [00:01<00:00,  9.19it/s]

Epoch 7/15 Train:  82%|████████▏ | 9/11 [00:01<00:00,  9.34it/s]

Epoch 7/15 Train:  91%|█████████ | 10/11 [00:01<00:00,  9.48it/s]

Epoch 7/15 Train: 100%|██████████| 11/11 [00:01<00:00,  8.32it/s]




Val:   0%|          | 0/3 [00:00<?, ?it/s]

Val:  33%|███▎      | 1/3 [00:00<00:00,  3.76it/s]

Val: 100%|██████████| 3/3 [00:00<00:00,  9.16it/s]

Val: 100%|██████████| 3/3 [00:00<00:00,  7.24it/s]




Epoch 7: TrainBCE=0.7144, ValWLL=0.4428, LR=2.00e-05


Epoch 8/15 Train:   0%|          | 0/11 [00:00<?, ?it/s]

Epoch 8/15 Train:   9%|▉         | 1/11 [00:00<00:03,  3.10it/s]

Epoch 8/15 Train:  18%|█▊        | 2/11 [00:00<00:01,  5.21it/s]

Epoch 8/15 Train:  36%|███▋      | 4/11 [00:00<00:00,  7.46it/s]

Epoch 8/15 Train:  45%|████▌     | 5/11 [00:00<00:00,  8.09it/s]

Epoch 8/15 Train:  64%|██████▎   | 7/11 [00:00<00:00,  8.90it/s]

Epoch 8/15 Train:  82%|████████▏ | 9/11 [00:01<00:00,  9.29it/s]

Epoch 8/15 Train:  91%|█████████ | 10/11 [00:01<00:00,  9.42it/s]

Epoch 8/15 Train: 100%|██████████| 11/11 [00:01<00:00,  8.48it/s]




Val:   0%|          | 0/3 [00:00<?, ?it/s]

Val:  33%|███▎      | 1/3 [00:00<00:00,  4.04it/s]

Val: 100%|██████████| 3/3 [00:00<00:00,  9.58it/s]

Val: 100%|██████████| 3/3 [00:00<00:00,  7.61it/s]




Epoch 8: TrainBCE=0.6799, ValWLL=0.4398, LR=2.00e-05


Epoch 9/15 Train:   0%|          | 0/11 [00:00<?, ?it/s]

Epoch 9/15 Train:   9%|▉         | 1/11 [00:00<00:03,  3.04it/s]

Epoch 9/15 Train:  18%|█▊        | 2/11 [00:00<00:01,  5.12it/s]

Epoch 9/15 Train:  27%|██▋       | 3/11 [00:00<00:01,  6.58it/s]

Epoch 9/15 Train:  45%|████▌     | 5/11 [00:00<00:00,  8.17it/s]

Epoch 9/15 Train:  64%|██████▎   | 7/11 [00:00<00:00,  8.91it/s]

Epoch 9/15 Train:  73%|███████▎  | 8/11 [00:01<00:00,  9.15it/s]

Epoch 9/15 Train:  82%|████████▏ | 9/11 [00:01<00:00,  9.30it/s]

Epoch 9/15 Train:  91%|█████████ | 10/11 [00:01<00:00,  9.45it/s]

Epoch 9/15 Train: 100%|██████████| 11/11 [00:01<00:00,  8.43it/s]




Val:   0%|          | 0/3 [00:00<?, ?it/s]

Val:  33%|███▎      | 1/3 [00:00<00:00,  4.15it/s]

Val: 100%|██████████| 3/3 [00:00<00:00,  9.74it/s]

Val: 100%|██████████| 3/3 [00:00<00:00,  7.72it/s]




Epoch 9: TrainBCE=0.7162, ValWLL=0.4397, LR=2.00e-05


Epoch 10/15 Train:   0%|          | 0/11 [00:00<?, ?it/s]

Epoch 10/15 Train:   9%|▉         | 1/11 [00:00<00:03,  2.94it/s]

Epoch 10/15 Train:  18%|█▊        | 2/11 [00:00<00:01,  5.01it/s]

Epoch 10/15 Train:  27%|██▋       | 3/11 [00:00<00:01,  6.48it/s]

Epoch 10/15 Train:  36%|███▋      | 4/11 [00:00<00:00,  7.53it/s]

Epoch 10/15 Train:  55%|█████▍    | 6/11 [00:00<00:00,  8.68it/s]

Epoch 10/15 Train:  64%|██████▎   | 7/11 [00:00<00:00,  9.00it/s]

Epoch 10/15 Train:  82%|████████▏ | 9/11 [00:01<00:00,  9.38it/s]

Epoch 10/15 Train:  91%|█████████ | 10/11 [00:01<00:00,  9.50it/s]

Epoch 10/15 Train: 100%|██████████| 11/11 [00:01<00:00,  8.31it/s]




Val:   0%|          | 0/3 [00:00<?, ?it/s]

Val:  33%|███▎      | 1/3 [00:00<00:00,  3.89it/s]

Val: 100%|██████████| 3/3 [00:00<00:00,  9.34it/s]

Val: 100%|██████████| 3/3 [00:00<00:00,  7.19it/s]




Epoch 10: TrainBCE=0.6709, ValWLL=0.4470, LR=2.00e-05


Epoch 11/15 Train:   0%|          | 0/11 [00:00<?, ?it/s]

Epoch 11/15 Train:   9%|▉         | 1/11 [00:00<00:03,  2.98it/s]

Epoch 11/15 Train:  18%|█▊        | 2/11 [00:00<00:01,  5.05it/s]

Epoch 11/15 Train:  27%|██▋       | 3/11 [00:00<00:01,  6.52it/s]

Epoch 11/15 Train:  36%|███▋      | 4/11 [00:00<00:00,  7.56it/s]

Epoch 11/15 Train:  45%|████▌     | 5/11 [00:00<00:00,  8.29it/s]

Epoch 11/15 Train:  64%|██████▎   | 7/11 [00:00<00:00,  9.09it/s]

Epoch 11/15 Train:  82%|████████▏ | 9/11 [00:01<00:00,  9.45it/s]

Epoch 11/15 Train:  91%|█████████ | 10/11 [00:01<00:00,  9.51it/s]

Epoch 11/15 Train: 100%|██████████| 11/11 [00:01<00:00,  8.32it/s]




Val:   0%|          | 0/3 [00:00<?, ?it/s]

Val:  33%|███▎      | 1/3 [00:00<00:00,  4.16it/s]

Val: 100%|██████████| 3/3 [00:00<00:00,  9.80it/s]

Val: 100%|██████████| 3/3 [00:00<00:00,  7.70it/s]




Epoch 11: TrainBCE=0.6913, ValWLL=0.4410, LR=2.00e-05


Epoch 12/15 Train:   0%|          | 0/11 [00:00<?, ?it/s]

Epoch 12/15 Train:   9%|▉         | 1/11 [00:00<00:03,  3.03it/s]

Epoch 12/15 Train:  18%|█▊        | 2/11 [00:00<00:01,  5.12it/s]

Epoch 12/15 Train:  36%|███▋      | 4/11 [00:00<00:00,  7.39it/s]

Epoch 12/15 Train:  55%|█████▍    | 6/11 [00:00<00:00,  8.45it/s]

Epoch 12/15 Train:  64%|██████▎   | 7/11 [00:00<00:00,  8.79it/s]

Epoch 12/15 Train:  82%|████████▏ | 9/11 [00:01<00:00,  9.26it/s]

Epoch 12/15 Train:  91%|█████████ | 10/11 [00:01<00:00,  9.42it/s]

Epoch 12/15 Train: 100%|██████████| 11/11 [00:01<00:00,  8.39it/s]




Val:   0%|          | 0/3 [00:00<?, ?it/s]

Val:  33%|███▎      | 1/3 [00:00<00:00,  4.04it/s]

Val: 100%|██████████| 3/3 [00:00<00:00,  9.60it/s]

Val: 100%|██████████| 3/3 [00:00<00:00,  7.63it/s]




Epoch 12: TrainBCE=0.7061, ValWLL=0.4432, LR=1.00e-05


Epoch 13/15 Train:   0%|          | 0/11 [00:00<?, ?it/s]

Epoch 13/15 Train:   9%|▉         | 1/11 [00:00<00:03,  3.11it/s]

Epoch 13/15 Train:  18%|█▊        | 2/11 [00:00<00:01,  5.18it/s]

Epoch 13/15 Train:  36%|███▋      | 4/11 [00:00<00:00,  7.45it/s]

Epoch 13/15 Train:  55%|█████▍    | 6/11 [00:00<00:00,  8.48it/s]

Epoch 13/15 Train:  73%|███████▎  | 8/11 [00:01<00:00,  9.06it/s]

Epoch 13/15 Train:  82%|████████▏ | 9/11 [00:01<00:00,  9.25it/s]

Epoch 13/15 Train: 100%|██████████| 11/11 [00:01<00:00, 10.89it/s]

Epoch 13/15 Train: 100%|██████████| 11/11 [00:01<00:00,  8.46it/s]




Val:   0%|          | 0/3 [00:00<?, ?it/s]

Val:  33%|███▎      | 1/3 [00:00<00:00,  4.16it/s]

Val: 100%|██████████| 3/3 [00:00<00:00,  9.78it/s]

Val: 100%|██████████| 3/3 [00:00<00:00,  7.77it/s]




Epoch 13: TrainBCE=0.7528, ValWLL=0.4473, LR=1.00e-05
Early stopping at epoch 13
ConvNeXt Fold 4 best OOF WLL (val): 0.4397


Final Val for OOF:   0%|          | 0/3 [00:00<?, ?it/s]

Final Val for OOF:  33%|███▎      | 1/3 [00:00<00:00,  4.15it/s]

Final Val for OOF: 100%|██████████| 3/3 [00:00<00:00,  9.75it/s]

Final Val for OOF: 100%|██████████| 3/3 [00:00<00:00,  7.67it/s]




ConvNeXt Fold 4 OOF WLL: 0.4473

=== ConvNeXt Fold 5/5 ===


Epoch 1/15 Train:   0%|          | 0/11 [00:00<?, ?it/s]

Epoch 1/15 Train:   9%|▉         | 1/11 [00:00<00:02,  3.54it/s]

Epoch 1/15 Train:  36%|███▋      | 4/11 [00:00<00:00, 11.82it/s]

Epoch 1/15 Train:  64%|██████▎   | 7/11 [00:00<00:00, 16.87it/s]

Epoch 1/15 Train:  91%|█████████ | 10/11 [00:00<00:00, 20.08it/s]

Epoch 1/15 Train: 100%|██████████| 11/11 [00:00<00:00, 16.01it/s]




Val:   0%|          | 0/3 [00:00<?, ?it/s]

Val:  33%|███▎      | 1/3 [00:00<00:00,  4.04it/s]

Val: 100%|██████████| 3/3 [00:00<00:00,  9.74it/s]

Val: 100%|██████████| 3/3 [00:00<00:00,  7.66it/s]




Epoch 1: TrainBCE=0.8590, ValWLL=0.4527, LR=2.00e-05


Epoch 2/15 Train:   0%|          | 0/11 [00:00<?, ?it/s]

Epoch 2/15 Train:   9%|▉         | 1/11 [00:00<00:02,  3.79it/s]

Epoch 2/15 Train:  36%|███▋      | 4/11 [00:00<00:00, 12.13it/s]

Epoch 2/15 Train:  64%|██████▎   | 7/11 [00:00<00:00, 17.14it/s]

Epoch 2/15 Train:  91%|█████████ | 10/11 [00:00<00:00, 20.24it/s]

Epoch 2/15 Train: 100%|██████████| 11/11 [00:00<00:00, 16.39it/s]




Val:   0%|          | 0/3 [00:00<?, ?it/s]

Val:  33%|███▎      | 1/3 [00:00<00:00,  3.93it/s]

Val: 100%|██████████| 3/3 [00:00<00:00,  9.55it/s]

Val: 100%|██████████| 3/3 [00:00<00:00,  7.53it/s]




Epoch 2: TrainBCE=0.8010, ValWLL=0.4525, LR=2.00e-05


Epoch 3/15 Train:   0%|          | 0/11 [00:00<?, ?it/s]

Epoch 3/15 Train:   9%|▉         | 1/11 [00:00<00:03,  3.02it/s]

Epoch 3/15 Train:  27%|██▋       | 3/11 [00:00<00:01,  6.34it/s]

Epoch 3/15 Train:  36%|███▋      | 4/11 [00:00<00:00,  7.16it/s]

Epoch 3/15 Train:  55%|█████▍    | 6/11 [00:00<00:00,  8.43it/s]

Epoch 3/15 Train:  73%|███████▎  | 8/11 [00:01<00:00,  9.10it/s]

Epoch 3/15 Train:  82%|████████▏ | 9/11 [00:01<00:00,  9.27it/s]

Epoch 3/15 Train:  91%|█████████ | 10/11 [00:01<00:00,  9.43it/s]

Epoch 3/15 Train: 100%|██████████| 11/11 [00:01<00:00,  8.44it/s]




Val:   0%|          | 0/3 [00:00<?, ?it/s]

Val:  33%|███▎      | 1/3 [00:00<00:00,  4.04it/s]

Val: 100%|██████████| 3/3 [00:00<00:00,  9.73it/s]

Val: 100%|██████████| 3/3 [00:00<00:00,  7.62it/s]




Epoch 3: TrainBCE=0.8030, ValWLL=0.4470, LR=2.00e-05


Epoch 4/15 Train:   0%|          | 0/11 [00:00<?, ?it/s]

Epoch 4/15 Train:   9%|▉         | 1/11 [00:00<00:03,  2.85it/s]

Epoch 4/15 Train:  18%|█▊        | 2/11 [00:00<00:01,  4.92it/s]

Epoch 4/15 Train:  36%|███▋      | 4/11 [00:00<00:00,  7.24it/s]

Epoch 4/15 Train:  55%|█████▍    | 6/11 [00:00<00:00,  8.35it/s]

Epoch 4/15 Train:  73%|███████▎  | 8/11 [00:01<00:00,  8.96it/s]

Epoch 4/15 Train:  82%|████████▏ | 9/11 [00:01<00:00,  9.14it/s]

Epoch 4/15 Train:  91%|█████████ | 10/11 [00:01<00:00,  9.31it/s]

Epoch 4/15 Train: 100%|██████████| 11/11 [00:01<00:00,  8.25it/s]




Val:   0%|          | 0/3 [00:00<?, ?it/s]

Val:  33%|███▎      | 1/3 [00:00<00:00,  4.24it/s]

Val: 100%|██████████| 3/3 [00:00<00:00, 10.04it/s]

Val: 100%|██████████| 3/3 [00:00<00:00,  7.94it/s]




Epoch 4: TrainBCE=0.7233, ValWLL=0.4425, LR=2.00e-05


Epoch 5/15 Train:   0%|          | 0/11 [00:00<?, ?it/s]

Epoch 5/15 Train:   9%|▉         | 1/11 [00:00<00:03,  3.08it/s]

Epoch 5/15 Train:  18%|█▊        | 2/11 [00:00<00:01,  5.16it/s]

Epoch 5/15 Train:  36%|███▋      | 4/11 [00:00<00:00,  7.43it/s]

Epoch 5/15 Train:  55%|█████▍    | 6/11 [00:00<00:00,  8.47it/s]

Epoch 5/15 Train:  73%|███████▎  | 8/11 [00:01<00:00,  9.05it/s]

Epoch 5/15 Train:  82%|████████▏ | 9/11 [00:01<00:00,  9.23it/s]

Epoch 5/15 Train:  91%|█████████ | 10/11 [00:01<00:00,  9.40it/s]

Epoch 5/15 Train: 100%|██████████| 11/11 [00:01<00:00,  8.43it/s]




Val:   0%|          | 0/3 [00:00<?, ?it/s]

Val:  33%|███▎      | 1/3 [00:00<00:00,  4.03it/s]

Val: 100%|██████████| 3/3 [00:00<00:00,  9.70it/s]

Val: 100%|██████████| 3/3 [00:00<00:00,  7.65it/s]




Epoch 5: TrainBCE=0.6922, ValWLL=0.4439, LR=2.00e-05


Epoch 6/15 Train:   0%|          | 0/11 [00:00<?, ?it/s]

Epoch 6/15 Train:   9%|▉         | 1/11 [00:00<00:03,  3.05it/s]

Epoch 6/15 Train:  18%|█▊        | 2/11 [00:00<00:01,  5.15it/s]

Epoch 6/15 Train:  36%|███▋      | 4/11 [00:00<00:00,  7.42it/s]

Epoch 6/15 Train:  55%|█████▍    | 6/11 [00:00<00:00,  8.46it/s]

Epoch 6/15 Train:  73%|███████▎  | 8/11 [00:01<00:00,  9.04it/s]

Epoch 6/15 Train:  82%|████████▏ | 9/11 [00:01<00:00,  9.20it/s]

Epoch 6/15 Train:  91%|█████████ | 10/11 [00:01<00:00,  9.36it/s]

Epoch 6/15 Train: 100%|██████████| 11/11 [00:01<00:00,  8.38it/s]




Val:   0%|          | 0/3 [00:00<?, ?it/s]

Val:  33%|███▎      | 1/3 [00:00<00:00,  4.11it/s]

Val: 100%|██████████| 3/3 [00:00<00:00,  9.89it/s]

Val: 100%|██████████| 3/3 [00:00<00:00,  7.68it/s]




Epoch 6: TrainBCE=0.7386, ValWLL=0.4450, LR=2.00e-05


Epoch 7/15 Train:   0%|          | 0/11 [00:00<?, ?it/s]

Epoch 7/15 Train:   9%|▉         | 1/11 [00:00<00:03,  3.08it/s]

Epoch 7/15 Train:  18%|█▊        | 2/11 [00:00<00:01,  5.17it/s]

Epoch 7/15 Train:  36%|███▋      | 4/11 [00:00<00:00,  7.43it/s]

Epoch 7/15 Train:  55%|█████▍    | 6/11 [00:00<00:00,  8.47it/s]

Epoch 7/15 Train:  73%|███████▎  | 8/11 [00:01<00:00,  9.04it/s]

Epoch 7/15 Train:  82%|████████▏ | 9/11 [00:01<00:00,  9.23it/s]

Epoch 7/15 Train:  91%|█████████ | 10/11 [00:01<00:00,  9.40it/s]

Epoch 7/15 Train: 100%|██████████| 11/11 [00:01<00:00,  8.40it/s]




Val:   0%|          | 0/3 [00:00<?, ?it/s]

Val:  33%|███▎      | 1/3 [00:00<00:00,  4.18it/s]

Val: 100%|██████████| 3/3 [00:00<00:00,  9.93it/s]

Val: 100%|██████████| 3/3 [00:00<00:00,  7.67it/s]




Epoch 7: TrainBCE=0.6723, ValWLL=0.4452, LR=1.00e-05


Epoch 8/15 Train:   0%|          | 0/11 [00:00<?, ?it/s]

Epoch 8/15 Train:   9%|▉         | 1/11 [00:00<00:03,  2.93it/s]

Epoch 8/15 Train:  18%|█▊        | 2/11 [00:00<00:01,  5.01it/s]

Epoch 8/15 Train:  36%|███▋      | 4/11 [00:00<00:00,  7.32it/s]

Epoch 8/15 Train:  55%|█████▍    | 6/11 [00:00<00:00,  8.36it/s]

Epoch 8/15 Train:  73%|███████▎  | 8/11 [00:01<00:00,  8.98it/s]

Epoch 8/15 Train:  82%|████████▏ | 9/11 [00:01<00:00,  9.18it/s]

Epoch 8/15 Train:  91%|█████████ | 10/11 [00:01<00:00,  9.36it/s]

Epoch 8/15 Train: 100%|██████████| 11/11 [00:01<00:00,  8.34it/s]




Val:   0%|          | 0/3 [00:00<?, ?it/s]

Val:  33%|███▎      | 1/3 [00:00<00:00,  4.27it/s]

Val: 100%|██████████| 3/3 [00:00<00:00, 10.08it/s]

Val: 100%|██████████| 3/3 [00:00<00:00,  7.97it/s]




Epoch 8: TrainBCE=0.7276, ValWLL=0.4442, LR=1.00e-05


Epoch 9/15 Train:   0%|          | 0/11 [00:00<?, ?it/s]

Epoch 9/15 Train:   9%|▉         | 1/11 [00:00<00:03,  3.15it/s]

Epoch 9/15 Train:  18%|█▊        | 2/11 [00:00<00:01,  5.27it/s]

Epoch 9/15 Train:  36%|███▋      | 4/11 [00:00<00:00,  7.52it/s]

Epoch 9/15 Train:  45%|████▌     | 5/11 [00:00<00:00,  8.15it/s]

Epoch 9/15 Train:  64%|██████▎   | 7/11 [00:00<00:00,  8.94it/s]

Epoch 9/15 Train:  82%|████████▏ | 9/11 [00:01<00:00,  9.34it/s]

Epoch 9/15 Train:  91%|█████████ | 10/11 [00:01<00:00,  9.47it/s]

Epoch 9/15 Train: 100%|██████████| 11/11 [00:01<00:00,  8.47it/s]




Val:   0%|          | 0/3 [00:00<?, ?it/s]

Val:  33%|███▎      | 1/3 [00:00<00:00,  3.99it/s]

Val: 100%|██████████| 3/3 [00:00<00:00,  9.67it/s]

Val: 100%|██████████| 3/3 [00:00<00:00,  7.50it/s]




Epoch 9: TrainBCE=0.7007, ValWLL=0.4458, LR=1.00e-05
Early stopping at epoch 9
ConvNeXt Fold 5 best OOF WLL (val): 0.4425


Final Val for OOF:   0%|          | 0/3 [00:00<?, ?it/s]

Final Val for OOF:  33%|███▎      | 1/3 [00:00<00:00,  4.01it/s]

Final Val for OOF: 100%|██████████| 3/3 [00:00<00:00,  9.67it/s]

Final Val for OOF: 100%|██████████| 3/3 [00:00<00:00,  7.48it/s]

ConvNeXt Fold 5 OOF WLL: 0.4458

ConvNeXt CV Mean WLL: 0.4547 +/- 0.0071





In [None]:
# Compute CV score from existing OOF predictions (previous run)
full_train_df = pd.read_csv('data/train_mips.csv')
oof_df = pd.read_csv('oof_predictions.csv')
label_cols = ['C1', 'C2', 'C3', 'C4', 'C5', 'C6', 'C7']
oof_preds_full = oof_df[label_cols].values
y_full = full_train_df[label_cols].values
oof_overall_full = oof_df['patient_overall'].values
oof_preds_8_full = np.column_stack([oof_preds_full, oof_overall_full])
oof_labels_8_full = np.column_stack([y_full, np.max(y_full, axis=1)])
full_oof_loss = weighted_log_loss(oof_labels_8_full, oof_preds_8_full)
print(f'Full OOF Weighted Log Loss from previous run: {full_oof_loss:.4f}')

# Also print per-fold scores if available, but since interrupted, use full OOF
print('OOF computation completed')

In [None]:
# Preprocessing sanity check: Visualize MIPs for 3 random studies
import matplotlib.pyplot as plt
import numpy as np
import os
import random

mip_dir = 'data/mips/train'
train_df = pd.read_csv('data/train_mips.csv')
uids = train_df['StudyInstanceUID'].tolist()
random_uids = random.sample(uids, 3)

fig, axs = plt.subplots(3, 3, figsize=(12, 12))
for i, uid in enumerate(random_uids):
    mip_path = os.path.join(mip_dir, f'{uid}.npy')
    if os.path.exists(mip_path):
        mip = np.load(mip_path)  # (3, 384, 384): 0=sagittal, 1=coronal, 2=axial
        for j, view in enumerate(['Sagittal', 'Coronal', 'Axial']):
            img = mip[j]
            axs[i, j].imshow(img, cmap='gray')
            axs[i, j].set_title(f'{uid[:8]} - {view}')
            axs[i, j].axis('off')
    else:
        print(f'Missing {mip_path}')

plt.tight_layout()
plt.show()

# Check values: should be in [0,1]
for uid in random_uids:
    mip_path = os.path.join(mip_dir, f'{uid}.npy')
    if os.path.exists(mip_path):
        mip = np.load(mip_path)
        print(f'{uid[:8]} min/max: {mip.min():.3f}/{mip.max():.3f}, mean/std: {mip.mean():.3f}/{mip.std():.3f}')

print('Visualization completed. Check: Sagittal upright? Bone window contrast? Values [0,1]? No artifacts?')

In [12]:
# Compute OOF logits for ConvNeXt temperature scaling
import torch
import numpy as np
import pandas as pd
from tqdm import tqdm
from torch.utils.data import DataLoader

# Reuse MIPDataset and val_transform from earlier cells
full_train_df = pd.read_csv('data/train_mips.csv')
label_cols = ['C1','C2','C3','C4','C5','C6','C7']
y_full = full_train_df[label_cols].values
mip_dir = 'data/mips/train'
n_folds = 5

from iterstrat.ml_stratifiers import MultilabelStratifiedKFold
skf = MultilabelStratifiedKFold(n_splits=n_folds, shuffle=True, random_state=42)
splits = list(skf.split(full_train_df, y_full))

def build_convnext(num_classes=7):
    import timm
    m = timm.create_model('convnext_tiny', pretrained=False, num_classes=num_classes, in_chans=3,
                          drop_rate=0.3, drop_path_rate=0.1)
    return m.to(device).eval()

oof_logits_full = np.zeros((len(full_train_df), 7), dtype=np.float32)

for fold, (train_idx, val_idx) in enumerate(splits, 1):
    print(f'Computing OOF logits for ConvNeXt Fold {fold}')
    val_df_f = full_train_df.iloc[val_idx].reset_index(drop=True)
    val_ds = MIPDataset(val_df_f, mip_dir, label_cols, val_transform)
    val_loader = DataLoader(val_ds, batch_size=16, shuffle=False, num_workers=4, pin_memory=True)

    ckpt = f'fold_{fold}_convnext.pth'
    model = build_convnext(7)
    sd = torch.load(ckpt, map_location='cpu')
    model.load_state_dict(sd, strict=True)

    model.eval()
    fold_logits = []
    with torch.no_grad():
        for images, _ in tqdm(val_loader, desc=f'Fold {fold} OOF logits'):
            images = images.to(device)
            logits = model(images)
            fold_logits.append(logits.cpu().numpy())
    fold_logits = np.concatenate(fold_logits, axis=0)
    oof_logits_full[val_idx] = fold_logits
    print(f'Fold {fold} OOF logits shape: {fold_logits.shape}')

# Save OOF logits
oof_logits_df = full_train_df[['StudyInstanceUID']].copy()
for i, col in enumerate(label_cols):
    oof_logits_df[col] = oof_logits_full[:, i]
oof_logits_df.to_csv('oof_logits_convnext.csv', index=False)
print('OOF logits saved to oof_logits_convnext.csv')

# Quick check: Compute probs from logits and verify matches oof_predictions_convnext.csv
oof_probs_from_logits = 1 / (1 + np.exp(-oof_logits_full))
oof_df = pd.read_csv('oof_predictions_convnext.csv')
oof_probs_saved = oof_df[label_cols].values
print(f'Probs match (max diff): {np.max(np.abs(oof_probs_from_logits - oof_probs_saved)):.6f}')

Computing OOF logits for ConvNeXt Fold 1


Fold 1 OOF logits:   0%|          | 0/3 [00:00<?, ?it/s]

Fold 1 OOF logits:  33%|███▎      | 1/3 [00:00<00:00,  4.26it/s]

Fold 1 OOF logits: 100%|██████████| 3/3 [00:00<00:00, 10.06it/s]

Fold 1 OOF logits: 100%|██████████| 3/3 [00:00<00:00,  7.79it/s]




Fold 1 OOF logits shape: (40, 7)
Computing OOF logits for ConvNeXt Fold 2


Fold 2 OOF logits:   0%|          | 0/3 [00:00<?, ?it/s]

Fold 2 OOF logits:  33%|███▎      | 1/3 [00:00<00:00,  4.03it/s]

Fold 2 OOF logits: 100%|██████████| 3/3 [00:00<00:00,  9.78it/s]

Fold 2 OOF logits: 100%|██████████| 3/3 [00:00<00:00,  7.67it/s]




Fold 2 OOF logits shape: (40, 7)
Computing OOF logits for ConvNeXt Fold 3


Fold 3 OOF logits:   0%|          | 0/3 [00:00<?, ?it/s]

Fold 3 OOF logits:  33%|███▎      | 1/3 [00:00<00:00,  4.15it/s]

Fold 3 OOF logits: 100%|██████████| 3/3 [00:00<00:00,  9.73it/s]

Fold 3 OOF logits: 100%|██████████| 3/3 [00:00<00:00,  7.60it/s]




Fold 3 OOF logits shape: (41, 7)
Computing OOF logits for ConvNeXt Fold 4


Fold 4 OOF logits:   0%|          | 0/3 [00:00<?, ?it/s]

Fold 4 OOF logits:  33%|███▎      | 1/3 [00:00<00:00,  4.05it/s]

Fold 4 OOF logits: 100%|██████████| 3/3 [00:00<00:00,  9.56it/s]

Fold 4 OOF logits: 100%|██████████| 3/3 [00:00<00:00,  7.41it/s]




Fold 4 OOF logits shape: (41, 7)
Computing OOF logits for ConvNeXt Fold 5


Fold 5 OOF logits:   0%|          | 0/3 [00:00<?, ?it/s]

Fold 5 OOF logits:  33%|███▎      | 1/3 [00:00<00:00,  4.14it/s]

Fold 5 OOF logits: 100%|██████████| 3/3 [00:00<00:00,  9.87it/s]

Fold 5 OOF logits: 100%|██████████| 3/3 [00:00<00:00,  7.62it/s]

Fold 5 OOF logits shape: (40, 7)
OOF logits saved to oof_logits_convnext.csv
Probs match (max diff): 0.000000





In [13]:
# Fit per-class temperature scaling for ConvNeXt using OOF logits
import numpy as np
import pandas as pd
from scipy.optimize import minimize_scalar
from scipy.special import expit as sigmoid

# Load data
oof_logits_df = pd.read_csv('oof_logits_convnext.csv')
train_df = pd.read_csv('data/train_mips.csv')
label_cols = ['C1','C2','C3','C4','C5','C6','C7']
oof_logits = oof_logits_df[label_cols].values  # (202, 7)
y_true = train_df[label_cols].values  # (202, 7)

def binary_log_loss(y_true, y_pred):
    y_pred = np.clip(y_pred, 1e-6, 1-1e-6)
    return -np.mean(y_true * np.log(y_pred) + (1 - y_true) * np.log(1 - y_pred))

# Optimize T per class (minimize log loss for each class independently)
temperatures = np.ones(7)
for i, col in enumerate(label_cols):
    def loss_func(T):
        logits_scaled = oof_logits[:, i] / T
        probs = sigmoid(logits_scaled)
        return binary_log_loss(y_true[:, i], probs)
    res = minimize_scalar(loss_func, bounds=(0.5, 2.0), method='bounded')
    temperatures[i] = res.x
    print(f'{col} temperature: {temperatures[i]:.4f}, min loss: {res.fun:.4f}')

# Save temperatures
np.save('temperatures_convnext.npy', temperatures)
print('Temperatures saved to temperatures_convnext.npy')

# Quick check: Compute calibrated OOF WLL (only on 7 classes, weights=1 each)
logits_calib = oof_logits / temperatures
probs_calib = sigmoid(logits_calib)
wll_calib = np.mean([binary_log_loss(y_true[:, i], probs_calib[:, i]) for i in range(7)])
wll_raw = np.mean([binary_log_loss(y_true[:, i], sigmoid(oof_logits[:, i])) for i in range(7)])
print(f'Raw OOF BCE (avg): {wll_raw:.4f}, Calibrated: {wll_calib:.4f} (delta: {wll_calib - wll_raw:.4f})')

C1 temperature: 0.6592, min loss: 0.3298
C2 temperature: 0.5818, min loss: 0.4225
C3 temperature: 0.8497, min loss: 0.1850
C4 temperature: 0.6826, min loss: 0.2157
C5 temperature: 0.6529, min loss: 0.3548
C6 temperature: 0.6875, min loss: 0.4806
C7 temperature: 0.5915, min loss: 0.4465
Temperatures saved to temperatures_convnext.npy
Raw OOF BCE (avg): 0.3736, Calibrated: 0.3478 (delta: -0.0258)


In [14]:
# ConvNeXt-Tiny v2 5-fold CV: stronger reg (drop=0.4, path=0.2), seed=123 for diversity
from iterstrat.ml_stratifiers import MultilabelStratifiedKFold
from torch.optim.lr_scheduler import ReduceLROnPlateau
from torch.utils.data import WeightedRandomSampler
import torch.nn as nn
import torch.nn.functional as F
import timm

full_train_df = pd.read_csv('data/train_mips.csv')
label_cols = ['C1','C2','C3','C4','C5','C6','C7']
y_full = full_train_df[label_cols].values
mip_dir = 'data/mips/train'
n_folds = 5

skf = MultilabelStratifiedKFold(n_splits=n_folds, shuffle=True, random_state=123)
splits = list(skf.split(full_train_df, y_full))

def build_convnext_v2(num_classes=7):
    m = timm.create_model('convnext_tiny', pretrained=True, num_classes=num_classes, in_chans=3,
                          drop_rate=0.4, drop_path_rate=0.2)
    return m

oof_preds_full_v2 = np.zeros((len(full_train_df), 7), dtype=np.float32)
oof_overall_full_v2 = np.zeros(len(full_train_df), dtype=np.float32)
fold_scores_v2 = []

for fold, (train_idx, val_idx) in enumerate(splits, 1):
    print(f'\n=== ConvNeXt v2 Fold {fold}/{n_folds} ===')
    train_df_f = full_train_df.iloc[train_idx].reset_index(drop=True)
    val_df_f = full_train_df.iloc[val_idx].reset_index(drop=True)

    # pos_weight (clip=2.0)
    y_tr = train_df_f[label_cols].values
    pos = (y_tr == 1).sum(axis=0); neg = (y_tr == 0).sum(axis=0)
    w = np.minimum(neg / (pos + 1e-6), 2.0).astype(np.float32)
    fold_pos_weight = torch.tensor(w, dtype=torch.float32, device=device)

    # WeightedRandomSampler (cap=3.0)
    class_rarity = np.minimum(neg / (pos + 1e-6), 10.0).astype(np.float32)
    sample_weight = (y_tr * class_rarity).max(axis=1)
    sample_weight = np.where(sample_weight > 0, sample_weight, 1.0)
    sample_weight = np.sqrt(sample_weight)
    sample_weight = np.clip(sample_weight, 1.0, 3.0)
    sampler = WeightedRandomSampler(torch.from_numpy(sample_weight), len(sample_weight), replacement=True)

    train_ds = MIPDataset(train_df_f, mip_dir, label_cols, train_transform)
    val_ds = MIPDataset(val_df_f, mip_dir, label_cols, val_transform)
    train_loader = DataLoader(train_ds, batch_size=16, sampler=sampler, num_workers=4, pin_memory=True, drop_last=False)
    val_loader = DataLoader(val_ds, batch_size=16, shuffle=False, num_workers=4, pin_memory=True, drop_last=False)

    # Model v2
    model = build_convnext_v2(7).to(device)

    # Bias init from priors
    with torch.no_grad():
        priors = np.clip(y_tr.mean(axis=0) + 1e-6, 1e-6, 1-1e-6)
        bias = np.log(priors / (1 - priors)).astype(np.float32)
        head = getattr(model, 'head', None)
        if head is not None and hasattr(head, 'fc') and head.fc.bias is not None:
            head.fc.bias.copy_(torch.from_numpy(bias).to(head.fc.bias.device))

    optimizer = optim.AdamW(model.parameters(), lr=2e-5, weight_decay=1e-4)
    scheduler = ReduceLROnPlateau(optimizer, mode='min', factor=0.5, patience=2, verbose=True, min_lr=1e-6)
    criterion = SmoothedBCEWithLogitsLoss(pos_weight=fold_pos_weight, smoothing=0.0)
    scaler = GradScaler()
    early_stopping = EarlyStopping(patience=5, min_delta=0.001, restore_best_weights=False)

    # EMA
    ema = timm.utils.ModelEmaV2(model, decay=0.995)

    num_epochs = 15
    warmup_epochs = 2
    best_wll = float('inf')
    best_state = None

    for epoch in range(num_epochs):
        # Head warmup
        if epoch < warmup_epochs:
            for p in model.parameters(): p.requires_grad = False
            if hasattr(model, 'head'):
                for p in model.head.parameters(): p.requires_grad = True
        else:
            for p in model.parameters(): p.requires_grad = True

        model.train()
        train_loss = 0.0
        for images, labels in tqdm(train_loader, desc=f'Epoch {epoch+1}/{num_epochs} Train'):
            images = images.to(device); labels7 = labels.to(device)
            optimizer.zero_grad(set_to_none=True)
            with torch.cuda.amp.autocast():
                logits = model(images)
                loss = criterion(logits, labels7)
            scaler.scale(loss).backward()
            nn_utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
            scaler.step(optimizer)
            scaler.update()
            ema.update(model)
            train_loss += loss.item() * images.size(0)
        train_loss /= len(train_ds)

        # Validation with EMA
        ema.module.eval()
        val_wll = 0.0
        with torch.no_grad():
            for images, labels in tqdm(val_loader, desc='Val'):
                images = images.to(device); labels7 = labels.to(device)
                logits = ema.module(images)
                probs7 = torch.sigmoid(logits)
                probs_overall = probs7.max(dim=1, keepdim=True)[0]
                labels_overall = labels7.max(dim=1, keepdim=True)[0]
                probs8 = torch.cat([probs7, probs_overall], dim=1)
                labels8 = torch.cat([labels7, labels_overall], dim=1)
                val_wll += weighted_log_loss_torch(labels8, probs8).item() * images.size(0)
        val_wll /= len(val_ds)
        scheduler.step(val_wll)
        print(f'Epoch {epoch+1}: TrainBCE={train_loss:.4f}, ValWLL={val_wll:.4f}, LR={optimizer.param_groups[0]["lr"]:.2e}')

        if val_wll < best_wll:
            best_wll = val_wll
            best_state = ema.module.state_dict()

        if early_stopping(val_wll, model):
            print(f'Early stopping at epoch {epoch+1}')
            break

    # Save best EMA weights v2
    torch.save(best_state, f'fold_{fold}_convnext_v2.pth')
    print(f'ConvNeXt v2 Fold {fold} best OOF WLL (val): {best_wll:.4f}')

    # OOF collection with best EMA
    ema.module.load_state_dict(best_state)
    ema.module.eval()
    fold_oof_preds_7 = np.zeros((len(val_df_f), 7), dtype=np.float32)
    fold_oof_overall = np.zeros(len(val_df_f), dtype=np.float32)
    start = 0
    with torch.no_grad():
        for images, _ in tqdm(val_loader, desc='Final Val for OOF'):
            images = images.to(device)
            logits = ema.module(images)
            probs7 = torch.sigmoid(logits).cpu().numpy()
            probs_overall = probs7.max(axis=1)
            bs = images.size(0)
            fold_oof_preds_7[start:start+bs] = probs7
            fold_oof_overall[start:start+bs] = probs_overall
            start += bs

    oof_preds_full_v2[val_idx] = fold_oof_preds_7
    oof_overall_full_v2[val_idx] = fold_oof_overall

    y_val = val_df_f[label_cols].values
    fold_preds_8 = np.column_stack([fold_oof_preds_7, fold_oof_overall])
    fold_labels_8 = np.column_stack([y_val, np.max(y_val, axis=1)])
    fold_score = weighted_log_loss(fold_labels_8, fold_preds_8)
    fold_scores_v2.append(fold_score)
    print(f'ConvNeXt v2 Fold {fold} OOF WLL: {fold_score:.4f}')

cv_mean_v2 = float(np.mean(fold_scores_v2)); cv_std_v2 = float(np.std(fold_scores_v2))
print(f'\nConvNeXt v2 CV Mean WLL: {cv_mean_v2:.4f} +/- {cv_std_v2:.4f}')
conv_oof_df_v2 = full_train_df[['StudyInstanceUID']].copy()
conv_oof_df_v2[label_cols] = oof_preds_full_v2
conv_oof_df_v2['patient_overall'] = oof_overall_full_v2
conv_oof_df_v2.to_csv('oof_predictions_convnext_v2.csv', index=False)


=== ConvNeXt v2 Fold 1/5 ===


Epoch 1/15 Train:   0%|          | 0/11 [00:00<?, ?it/s]

Epoch 1/15 Train:   9%|▉         | 1/11 [00:00<00:02,  3.67it/s]

Epoch 1/15 Train:  36%|███▋      | 4/11 [00:00<00:00, 11.86it/s]

Epoch 1/15 Train:  64%|██████▎   | 7/11 [00:00<00:00, 16.93it/s]

Epoch 1/15 Train:  91%|█████████ | 10/11 [00:00<00:00, 20.11it/s]

Epoch 1/15 Train: 100%|██████████| 11/11 [00:00<00:00, 16.02it/s]




Val:   0%|          | 0/3 [00:00<?, ?it/s]

Val:  33%|███▎      | 1/3 [00:00<00:00,  4.00it/s]

Val: 100%|██████████| 3/3 [00:00<00:00,  9.54it/s]

Val: 100%|██████████| 3/3 [00:00<00:00,  7.34it/s]




Epoch 1: TrainBCE=0.8332, ValWLL=0.4408, LR=2.00e-05


Epoch 2/15 Train:   0%|          | 0/11 [00:00<?, ?it/s]

Epoch 2/15 Train:   9%|▉         | 1/11 [00:00<00:02,  3.53it/s]

Epoch 2/15 Train:  36%|███▋      | 4/11 [00:00<00:00, 11.60it/s]

Epoch 2/15 Train:  64%|██████▎   | 7/11 [00:00<00:00, 16.62it/s]

Epoch 2/15 Train:  91%|█████████ | 10/11 [00:00<00:00, 19.92it/s]

Epoch 2/15 Train: 100%|██████████| 11/11 [00:00<00:00, 15.61it/s]




Val:   0%|          | 0/3 [00:00<?, ?it/s]

Val:  33%|███▎      | 1/3 [00:00<00:00,  4.00it/s]

Val: 100%|██████████| 3/3 [00:00<00:00,  9.56it/s]

Val: 100%|██████████| 3/3 [00:00<00:00,  7.54it/s]




Epoch 2: TrainBCE=0.9736, ValWLL=0.4406, LR=2.00e-05


Epoch 3/15 Train:   0%|          | 0/11 [00:00<?, ?it/s]

Epoch 3/15 Train:   9%|▉         | 1/11 [00:00<00:03,  2.96it/s]

Epoch 3/15 Train:  27%|██▋       | 3/11 [00:00<00:01,  6.26it/s]

Epoch 3/15 Train:  45%|████▌     | 5/11 [00:00<00:00,  7.76it/s]

Epoch 3/15 Train:  64%|██████▎   | 7/11 [00:00<00:00,  8.67it/s]

Epoch 3/15 Train:  82%|████████▏ | 9/11 [00:01<00:00,  9.16it/s]

Epoch 3/15 Train: 100%|██████████| 11/11 [00:01<00:00, 10.65it/s]

Epoch 3/15 Train: 100%|██████████| 11/11 [00:01<00:00,  8.38it/s]




Val:   0%|          | 0/3 [00:00<?, ?it/s]

Val:  33%|███▎      | 1/3 [00:00<00:00,  3.72it/s]

Val: 100%|██████████| 3/3 [00:00<00:00,  9.12it/s]

Val: 100%|██████████| 3/3 [00:00<00:00,  7.19it/s]




Epoch 3: TrainBCE=0.8330, ValWLL=0.4404, LR=2.00e-05


Epoch 4/15 Train:   0%|          | 0/11 [00:00<?, ?it/s]

Epoch 4/15 Train:   9%|▉         | 1/11 [00:00<00:03,  2.78it/s]

Epoch 4/15 Train:  18%|█▊        | 2/11 [00:00<00:01,  4.82it/s]

Epoch 4/15 Train:  36%|███▋      | 4/11 [00:00<00:00,  7.18it/s]

Val:  33%|███▎      | 1/3 [00:00<00:00,  3.77it/s]

Val: 100%|██████████| 3/3 [00:00<00:00,  9.22it/s]

Val: 100%|██████████| 3/3 [00:00<00:00,  7.24it/s]




Epoch 4: TrainBCE=0.7717, ValWLL=0.4425, LR=2.00e-05


Epoch 5/15 Train:   0%|          | 0/11 [00:00<?, ?it/s]

Epoch 5/15 Train:   9%|▉         | 1/11 [00:00<00:03,  2.89it/s]

Epoch 5/15 Train:  27%|██▋       | 3/11 [00:00<00:01,  6.13it/s]

Epoch 5/15 Train:  45%|████▌     | 5/11 [00:00<00:00,  7.69it/s]

Epoch 5/15 Train:  55%|█████▍    | 6/11 [00:00<00:00,  8.20it/s]

Epoch 5/15 Train:  64%|██████▎   | 7/11 [00:00<00:00,  8.60it/s]

Epoch 5/15 Train:  82%|████████▏ | 9/11 [00:01<00:00,  9.19it/s]

Epoch 5/15 Train: 100%|██████████| 11/11 [00:01<00:00, 10.85it/s]

Epoch 5/15 Train: 100%|██████████| 11/11 [00:01<00:00,  8.26it/s]




Val:   0%|          | 0/3 [00:00<?, ?it/s]

Val:  33%|███▎      | 1/3 [00:00<00:00,  3.77it/s]

Val: 100%|██████████| 3/3 [00:00<00:00,  9.24it/s]

Val: 100%|██████████| 3/3 [00:00<00:00,  7.09it/s]




Epoch 5: TrainBCE=0.7410, ValWLL=0.4469, LR=2.00e-05


Epoch 6/15 Train:   0%|          | 0/11 [00:00<?, ?it/s]

Epoch 6/15 Train:   9%|▉         | 1/11 [00:00<00:03,  2.89it/s]

Epoch 6/15 Train:  18%|█▊        | 2/11 [00:00<00:01,  4.94it/s]

Epoch 6/15 Train:  36%|███▋      | 4/11 [00:00<00:00,  7.26it/s]

Epoch 6/15 Train:  55%|█████▍    | 6/11 [00:00<00:00,  8.37it/s]

Epoch 6/15 Train:  73%|███████▎  | 8/11 [00:01<00:00,  8.98it/s]

Epoch 6/15 Train:  82%|████████▏ | 9/11 [00:01<00:00,  9.18it/s]

Epoch 6/15 Train: 100%|██████████| 11/11 [00:01<00:00, 10.85it/s]

Epoch 6/15 Train: 100%|██████████| 11/11 [00:01<00:00,  8.27it/s]




Val:   0%|          | 0/3 [00:00<?, ?it/s]

Val:  33%|███▎      | 1/3 [00:00<00:00,  3.98it/s]

Val: 100%|██████████| 3/3 [00:00<00:00,  9.50it/s]

Val: 100%|██████████| 3/3 [00:00<00:00,  7.36it/s]




Epoch 6: TrainBCE=0.7314, ValWLL=0.4412, LR=1.00e-05
Early stopping at epoch 6
ConvNeXt v2 Fold 1 best OOF WLL (val): 0.4404


Final Val for OOF:   0%|          | 0/3 [00:00<?, ?it/s]

Final Val for OOF:  33%|███▎      | 1/3 [00:00<00:00,  4.13it/s]

Final Val for OOF: 100%|██████████| 3/3 [00:00<00:00,  9.77it/s]

Final Val for OOF: 100%|██████████| 3/3 [00:00<00:00,  7.58it/s]




ConvNeXt v2 Fold 1 OOF WLL: 0.4412

=== ConvNeXt v2 Fold 2/5 ===


Epoch 1/15 Train:   0%|          | 0/11 [00:00<?, ?it/s]

Epoch 1/15 Train:   9%|▉         | 1/11 [00:00<00:02,  3.52it/s]

Epoch 1/15 Train:  36%|███▋      | 4/11 [00:00<00:00, 11.62it/s]

Epoch 1/15 Train:  64%|██████▎   | 7/11 [00:00<00:00, 16.70it/s]

Epoch 1/15 Train:  91%|█████████ | 10/11 [00:00<00:00, 19.95it/s]

Epoch 1/15 Train: 100%|██████████| 11/11 [00:00<00:00, 15.70it/s]




Val:   0%|          | 0/3 [00:00<?, ?it/s]

Val:  33%|███▎      | 1/3 [00:00<00:00,  3.79it/s]

Val: 100%|██████████| 3/3 [00:00<00:00,  9.23it/s]

Val: 100%|██████████| 3/3 [00:00<00:00,  7.10it/s]




Epoch 1: TrainBCE=0.8659, ValWLL=0.4444, LR=2.00e-05


Epoch 2/15 Train:   0%|          | 0/11 [00:00<?, ?it/s]

Epoch 2/15 Train:   9%|▉         | 1/11 [00:00<00:03,  3.32it/s]

Epoch 2/15 Train:  36%|███▋      | 4/11 [00:00<00:00, 11.20it/s]

Epoch 2/15 Train:  64%|██████▎   | 7/11 [00:00<00:00, 16.22it/s]

Epoch 2/15 Train:  91%|█████████ | 10/11 [00:00<00:00, 19.60it/s]

Epoch 2/15 Train: 100%|██████████| 11/11 [00:00<00:00, 15.28it/s]




Val:   0%|          | 0/3 [00:00<?, ?it/s]

Val:  33%|███▎      | 1/3 [00:00<00:00,  3.90it/s]

Val: 100%|██████████| 3/3 [00:00<00:00,  9.40it/s]

Val: 100%|██████████| 3/3 [00:00<00:00,  7.37it/s]




Epoch 2: TrainBCE=0.8371, ValWLL=0.4443, LR=2.00e-05


Epoch 3/15 Train:   0%|          | 0/11 [00:00<?, ?it/s]

Epoch 3/15 Train:   9%|▉         | 1/11 [00:00<00:03,  2.94it/s]

Epoch 3/15 Train:  27%|██▋       | 3/11 [00:00<00:01,  6.25it/s]

Epoch 3/15 Train:  36%|███▋      | 4/11 [00:00<00:00,  7.14it/s]

Epoch 3/15 Train:  55%|█████▍    | 6/11 [00:00<00:00,  8.47it/s]

Epoch 3/15 Train:  73%|███████▎  | 8/11 [00:01<00:00,  9.07it/s]

Epoch 3/15 Train:  82%|████████▏ | 9/11 [00:01<00:00,  9.26it/s]

Epoch 3/15 Train:  91%|█████████ | 10/11 [00:01<00:00,  9.42it/s]

Epoch 3/15 Train: 100%|██████████| 11/11 [00:01<00:00,  8.39it/s]




Val:   0%|          | 0/3 [00:00<?, ?it/s]

Val:  33%|███▎      | 1/3 [00:00<00:00,  4.00it/s]

Val: 100%|██████████| 3/3 [00:00<00:00,  9.54it/s]

Val: 100%|██████████| 3/3 [00:00<00:00,  7.51it/s]




Epoch 3: TrainBCE=0.8624, ValWLL=0.4445, LR=2.00e-05


Epoch 4/15 Train:   0%|          | 0/11 [00:00<?, ?it/s]

Epoch 4/15 Train:   9%|▉         | 1/11 [00:00<00:03,  2.68it/s]

Epoch 4/15 Train:  18%|█▊        | 2/11 [00:00<00:01,  4.68it/s]

Epoch 4/15 Train:  36%|███▋      | 4/11 [00:00<00:00,  7.05it/s]

Epoch 4/15 Train:  55%|█████▍    | 6/11 [00:00<00:00,  8.21it/s]

Epoch 4/15 Train:  73%|███████▎  | 8/11 [00:01<00:00,  8.87it/s]

Epoch 4/15 Train:  82%|████████▏ | 9/11 [00:01<00:00,  9.08it/s]

Epoch 4/15 Train:  91%|█████████ | 10/11 [00:01<00:00,  9.27it/s]

Epoch 4/15 Train: 100%|██████████| 11/11 [00:01<00:00,  8.08it/s]




Val:   0%|          | 0/3 [00:00<?, ?it/s]

Val:  33%|███▎      | 1/3 [00:00<00:00,  4.08it/s]

Val: 100%|██████████| 3/3 [00:00<00:00,  9.65it/s]

Val: 100%|██████████| 3/3 [00:00<00:00,  7.63it/s]




Epoch 4: TrainBCE=0.7707, ValWLL=0.4456, LR=2.00e-05


Epoch 5/15 Train:   0%|          | 0/11 [00:00<?, ?it/s]

Epoch 5/15 Train:   9%|▉         | 1/11 [00:00<00:03,  2.66it/s]

Epoch 5/15 Train:  18%|█▊        | 2/11 [00:00<00:01,  4.67it/s]

Epoch 5/15 Train:  36%|███▋      | 4/11 [00:00<00:00,  7.03it/s]

Epoch 5/15 Train:  55%|█████▍    | 6/11 [00:00<00:00,  8.21it/s]

Epoch 5/15 Train:  73%|███████▎  | 8/11 [00:01<00:00,  8.87it/s]

Epoch 5/15 Train:  82%|████████▏ | 9/11 [00:01<00:00,  9.09it/s]

Epoch 5/15 Train: 100%|██████████| 11/11 [00:01<00:00, 10.76it/s]

Epoch 5/15 Train: 100%|██████████| 11/11 [00:01<00:00,  8.08it/s]




Val:   0%|          | 0/3 [00:00<?, ?it/s]

Val:  33%|███▎      | 1/3 [00:00<00:00,  3.94it/s]

Val: 100%|██████████| 3/3 [00:00<00:00,  9.48it/s]

Val: 100%|██████████| 3/3 [00:00<00:00,  7.45it/s]




Epoch 5: TrainBCE=0.7099, ValWLL=0.4471, LR=1.00e-05


Epoch 6/15 Train:   0%|          | 0/11 [00:00<?, ?it/s]

Epoch 6/15 Train:   9%|▉         | 1/11 [00:00<00:03,  2.80it/s]

Epoch 6/15 Train:  18%|█▊        | 2/11 [00:00<00:01,  4.82it/s]

Epoch 6/15 Train:  36%|███▋      | 4/11 [00:00<00:00,  7.16it/s]

Epoch 6/15 Train:  55%|█████▍    | 6/11 [00:00<00:00,  8.29it/s]

Epoch 6/15 Train:  73%|███████▎  | 8/11 [00:01<00:00,  8.93it/s]

Epoch 6/15 Train:  82%|████████▏ | 9/11 [00:01<00:00,  9.14it/s]

Epoch 6/15 Train:  91%|█████████ | 10/11 [00:01<00:00,  9.33it/s]

Epoch 6/15 Train: 100%|██████████| 11/11 [00:01<00:00,  8.20it/s]




Val:   0%|          | 0/3 [00:00<?, ?it/s]

Val:  33%|███▎      | 1/3 [00:00<00:00,  3.56it/s]

Val: 100%|██████████| 3/3 [00:00<00:00,  8.86it/s]

Val: 100%|██████████| 3/3 [00:00<00:00,  6.89it/s]




Epoch 6: TrainBCE=0.6711, ValWLL=0.4488, LR=1.00e-05
Early stopping at epoch 6
ConvNeXt v2 Fold 2 best OOF WLL (val): 0.4443


Final Val for OOF:   0%|          | 0/3 [00:00<?, ?it/s]

Final Val for OOF:  33%|███▎      | 1/3 [00:00<00:00,  3.92it/s]

Final Val for OOF: 100%|██████████| 3/3 [00:00<00:00,  9.46it/s]

Final Val for OOF: 100%|██████████| 3/3 [00:00<00:00,  7.21it/s]




ConvNeXt v2 Fold 2 OOF WLL: 0.4488

=== ConvNeXt v2 Fold 3/5 ===


Epoch 1/15 Train:   0%|          | 0/11 [00:00<?, ?it/s]

Epoch 1/15 Train:   9%|▉         | 1/11 [00:00<00:02,  3.57it/s]

Epoch 1/15 Train:  36%|███▋      | 4/11 [00:00<00:00, 11.72it/s]

Epoch 1/15 Train:  64%|██████▎   | 7/11 [00:00<00:00, 16.73it/s]

Epoch 1/15 Train:  91%|█████████ | 10/11 [00:00<00:00, 19.96it/s]

Epoch 1/15 Train: 100%|██████████| 11/11 [00:00<00:00, 15.86it/s]




Val:   0%|          | 0/3 [00:00<?, ?it/s]

Val:  33%|███▎      | 1/3 [00:00<00:00,  3.66it/s]

Val: 100%|██████████| 3/3 [00:00<00:00,  9.16it/s]

Val: 100%|██████████| 3/3 [00:00<00:00,  7.08it/s]




Epoch 1: TrainBCE=0.8736, ValWLL=0.4456, LR=2.00e-05


Epoch 2/15 Train:   0%|          | 0/11 [00:00<?, ?it/s]

Epoch 2/15 Train:   9%|▉         | 1/11 [00:00<00:02,  3.55it/s]

Epoch 2/15 Train:  36%|███▋      | 4/11 [00:00<00:00, 11.50it/s]

Epoch 2/15 Train:  64%|██████▎   | 7/11 [00:00<00:00, 16.20it/s]

Epoch 2/15 Train:  91%|█████████ | 10/11 [00:00<00:00, 19.34it/s]

Epoch 2/15 Train: 100%|██████████| 11/11 [00:00<00:00, 15.36it/s]




Val:   0%|          | 0/3 [00:00<?, ?it/s]

Val:  33%|███▎      | 1/3 [00:00<00:00,  3.87it/s]

Val: 100%|██████████| 3/3 [00:00<00:00,  9.47it/s]

Val: 100%|██████████| 3/3 [00:00<00:00,  7.25it/s]




Epoch 2: TrainBCE=0.8123, ValWLL=0.4456, LR=2.00e-05


Epoch 3/15 Train:   0%|          | 0/11 [00:00<?, ?it/s]

Epoch 3/15 Train:   9%|▉         | 1/11 [00:00<00:03,  2.98it/s]

Epoch 3/15 Train:  27%|██▋       | 3/11 [00:00<00:01,  6.29it/s]

Epoch 3/15 Train:  45%|████▌     | 5/11 [00:00<00:00,  7.78it/s]

Epoch 3/15 Train:  64%|██████▎   | 7/11 [00:00<00:00,  8.67it/s]

Epoch 3/15 Train:  82%|████████▏ | 9/11 [00:01<00:00,  9.26it/s]

Epoch 3/15 Train: 100%|██████████| 11/11 [00:01<00:00, 10.73it/s]

Epoch 3/15 Train: 100%|██████████| 11/11 [00:01<00:00,  8.43it/s]




Val:   0%|          | 0/3 [00:00<?, ?it/s]

Val:  33%|███▎      | 1/3 [00:00<00:00,  3.47it/s]

Val: 100%|██████████| 3/3 [00:00<00:00,  8.84it/s]

Val: 100%|██████████| 3/3 [00:00<00:00,  6.87it/s]




Epoch 3: TrainBCE=0.8468, ValWLL=0.4460, LR=2.00e-05


Epoch 4/15 Train:   0%|          | 0/11 [00:00<?, ?it/s]

Epoch 4/15 Train:   9%|▉         | 1/11 [00:00<00:03,  2.81it/s]

Epoch 4/15 Train:  18%|█▊        | 2/11 [00:00<00:01,  4.84it/s]

Epoch 4/15 Train:  27%|██▋       | 3/11 [00:00<00:01,  6.33it/s]

Epoch 4/15 Train:  45%|████▌     | 5/11 [00:00<00:00,  8.01it/s]

Epoch 4/15 Train:  64%|██████▎   | 7/11 [00:00<00:00,  8.81it/s]

Epoch 4/15 Train:  82%|████████▏ | 9/11 [00:01<00:00,  9.22it/s]

Epoch 4/15 Train:  91%|█████████ | 10/11 [00:01<00:00,  9.37it/s]

Epoch 4/15 Train: 100%|██████████| 11/11 [00:01<00:00,  8.22it/s]




Val:   0%|          | 0/3 [00:00<?, ?it/s]

Val:  33%|███▎      | 1/3 [00:00<00:00,  3.89it/s]

Val: 100%|██████████| 3/3 [00:00<00:00,  9.50it/s]

Val: 100%|██████████| 3/3 [00:00<00:00,  7.35it/s]




Epoch 4: TrainBCE=0.7269, ValWLL=0.4462, LR=1.00e-05


Epoch 5/15 Train:   0%|          | 0/11 [00:00<?, ?it/s]

Epoch 5/15 Train:   9%|▉         | 1/11 [00:00<00:03,  2.63it/s]

Epoch 5/15 Train:  18%|█▊        | 2/11 [00:00<00:01,  4.61it/s]

Epoch 5/15 Train:  36%|███▋      | 4/11 [00:00<00:01,  6.99it/s]

Epoch 5/15 Train:  55%|█████▍    | 6/11 [00:00<00:00,  8.17it/s]

Epoch 5/15 Train:  73%|███████▎  | 8/11 [00:01<00:00,  8.85it/s]

Epoch 5/15 Train:  82%|████████▏ | 9/11 [00:01<00:00,  9.06it/s]

Epoch 5/15 Train:  91%|█████████ | 10/11 [00:01<00:00,  9.25it/s]

Epoch 5/15 Train: 100%|██████████| 11/11 [00:01<00:00,  8.07it/s]




Val:   0%|          | 0/3 [00:00<?, ?it/s]

Val:  33%|███▎      | 1/3 [00:00<00:00,  3.80it/s]

Val: 100%|██████████| 3/3 [00:00<00:00,  9.36it/s]

Val: 100%|██████████| 3/3 [00:00<00:00,  7.20it/s]




Epoch 5: TrainBCE=0.7611, ValWLL=0.4474, LR=1.00e-05


Epoch 6/15 Train:   0%|          | 0/11 [00:00<?, ?it/s]

Epoch 6/15 Train:   9%|▉         | 1/11 [00:00<00:03,  2.77it/s]

Epoch 6/15 Train:  18%|█▊        | 2/11 [00:00<00:01,  4.82it/s]

Epoch 6/15 Train:  36%|███▋      | 4/11 [00:00<00:00,  7.16it/s]

Epoch 6/15 Train:  55%|█████▍    | 6/11 [00:00<00:00,  8.29it/s]

Epoch 6/15 Train:  73%|███████▎  | 8/11 [00:01<00:00,  8.93it/s]

Epoch 6/15 Train:  82%|████████▏ | 9/11 [00:01<00:00,  9.12it/s]

Epoch 6/15 Train:  91%|█████████ | 10/11 [00:01<00:00,  9.31it/s]

Epoch 6/15 Train: 100%|██████████| 11/11 [00:01<00:00,  8.13it/s]




Val:   0%|          | 0/3 [00:00<?, ?it/s]

Val:  33%|███▎      | 1/3 [00:00<00:00,  4.00it/s]

Val: 100%|██████████| 3/3 [00:00<00:00,  9.69it/s]

Val: 100%|██████████| 3/3 [00:00<00:00,  7.53it/s]




Epoch 6: TrainBCE=0.7329, ValWLL=0.4467, LR=1.00e-05
Early stopping at epoch 6
ConvNeXt v2 Fold 3 best OOF WLL (val): 0.4456


Final Val for OOF:   0%|          | 0/3 [00:00<?, ?it/s]

Final Val for OOF:  33%|███▎      | 1/3 [00:00<00:00,  3.98it/s]

Final Val for OOF: 100%|██████████| 3/3 [00:00<00:00,  9.64it/s]

Final Val for OOF: 100%|██████████| 3/3 [00:00<00:00,  7.54it/s]




ConvNeXt v2 Fold 3 OOF WLL: 0.4467

=== ConvNeXt v2 Fold 4/5 ===


Epoch 1/15 Train:   0%|          | 0/11 [00:00<?, ?it/s]

Epoch 1/15 Train:   9%|▉         | 1/11 [00:00<00:02,  3.65it/s]

Epoch 1/15 Train:  36%|███▋      | 4/11 [00:00<00:00, 11.87it/s]

Epoch 1/15 Train:  64%|██████▎   | 7/11 [00:00<00:00, 16.85it/s]

Epoch 1/15 Train:  91%|█████████ | 10/11 [00:00<00:00, 20.06it/s]

Epoch 1/15 Train: 100%|██████████| 11/11 [00:00<00:00, 15.88it/s]




Val:   0%|          | 0/3 [00:00<?, ?it/s]

Val:  33%|███▎      | 1/3 [00:00<00:00,  3.82it/s]

Val: 100%|██████████| 3/3 [00:00<00:00,  9.39it/s]

Val: 100%|██████████| 3/3 [00:00<00:00,  7.34it/s]




Epoch 1: TrainBCE=0.9226, ValWLL=0.4712, LR=2.00e-05


Epoch 2/15 Train:   0%|          | 0/11 [00:00<?, ?it/s]

Epoch 2/15 Train:   9%|▉         | 1/11 [00:00<00:02,  3.62it/s]

Epoch 2/15 Train:  36%|███▋      | 4/11 [00:00<00:00, 11.69it/s]

Epoch 2/15 Train:  64%|██████▎   | 7/11 [00:00<00:00, 16.71it/s]

Epoch 2/15 Train:  91%|█████████ | 10/11 [00:00<00:00, 19.90it/s]

Epoch 2/15 Train: 100%|██████████| 11/11 [00:00<00:00, 15.75it/s]




Val:   0%|          | 0/3 [00:00<?, ?it/s]

Val:  33%|███▎      | 1/3 [00:00<00:00,  4.06it/s]

Val: 100%|██████████| 3/3 [00:00<00:00,  9.74it/s]

Val: 100%|██████████| 3/3 [00:00<00:00,  7.68it/s]




Epoch 2: TrainBCE=0.9887, ValWLL=0.4708, LR=2.00e-05


Epoch 3/15 Train:   0%|          | 0/11 [00:00<?, ?it/s]

Epoch 3/15 Train:   9%|▉         | 1/11 [00:00<00:03,  3.05it/s]

Epoch 3/15 Train:  27%|██▋       | 3/11 [00:00<00:01,  6.38it/s]

Epoch 3/15 Train:  45%|████▌     | 5/11 [00:00<00:00,  7.86it/s]

Epoch 3/15 Train:  64%|██████▎   | 7/11 [00:00<00:00,  8.72it/s]

Epoch 3/15 Train:  82%|████████▏ | 9/11 [00:01<00:00,  9.24it/s]

Epoch 3/15 Train:  91%|█████████ | 10/11 [00:01<00:00,  9.38it/s]

Epoch 3/15 Train: 100%|██████████| 11/11 [00:01<00:00,  8.52it/s]




Val:   0%|          | 0/3 [00:00<?, ?it/s]

Val:  33%|███▎      | 1/3 [00:00<00:00,  3.93it/s]

Val: 100%|██████████| 3/3 [00:00<00:00,  9.53it/s]

Val: 100%|██████████| 3/3 [00:00<00:00,  7.44it/s]




Epoch 3: TrainBCE=0.7322, ValWLL=0.4697, LR=2.00e-05


Epoch 4/15 Train:   0%|          | 0/11 [00:00<?, ?it/s]

Epoch 4/15 Train:   9%|▉         | 1/11 [00:00<00:03,  2.65it/s]

Epoch 4/15 Train:  18%|█▊        | 2/11 [00:00<00:01,  4.65it/s]

Epoch 4/15 Train:  27%|██▋       | 3/11 [00:00<00:01,  6.15it/s]

Epoch 4/15 Train:  45%|████▌     | 5/11 [00:00<00:00,  7.89it/s]

Epoch 4/15 Train:  64%|██████▎   | 7/11 [00:00<00:00,  8.74it/s]

Epoch 4/15 Train:  82%|████████▏ | 9/11 [00:01<00:00,  9.20it/s]

Epoch 4/15 Train: 100%|██████████| 11/11 [00:01<00:00, 10.72it/s]

Epoch 4/15 Train: 100%|██████████| 11/11 [00:01<00:00,  8.09it/s]




Val:   0%|          | 0/3 [00:00<?, ?it/s]

Val:  33%|███▎      | 1/3 [00:00<00:00,  3.97it/s]

Val: 100%|██████████| 3/3 [00:00<00:00,  9.63it/s]

Val: 100%|██████████| 3/3 [00:00<00:00,  7.51it/s]




Epoch 4: TrainBCE=0.7058, ValWLL=0.4663, LR=2.00e-05


Epoch 5/15 Train:   0%|          | 0/11 [00:00<?, ?it/s]

Epoch 5/15 Train:   9%|▉         | 1/11 [00:00<00:03,  2.88it/s]

Epoch 5/15 Train:  18%|█▊        | 2/11 [00:00<00:01,  4.95it/s]

Epoch 5/15 Train:  36%|███▋      | 4/11 [00:00<00:00,  7.25it/s]

Epoch 5/15 Train:  55%|█████▍    | 6/11 [00:00<00:00,  8.35it/s]

Epoch 5/15 Train:  73%|███████▎  | 8/11 [00:01<00:00,  8.95it/s]

Epoch 5/15 Train:  82%|████████▏ | 9/11 [00:01<00:00,  9.13it/s]

Epoch 5/15 Train:  91%|█████████ | 10/11 [00:01<00:00,  9.32it/s]

Epoch 5/15 Train: 100%|██████████| 11/11 [00:01<00:00,  8.21it/s]




Val:   0%|          | 0/3 [00:00<?, ?it/s]

Val:  33%|███▎      | 1/3 [00:00<00:00,  3.96it/s]

Val: 100%|██████████| 3/3 [00:00<00:00,  9.57it/s]

Val: 100%|██████████| 3/3 [00:00<00:00,  7.52it/s]




Epoch 5: TrainBCE=0.7344, ValWLL=0.4657, LR=2.00e-05


Epoch 6/15 Train:   0%|          | 0/11 [00:00<?, ?it/s]

Epoch 6/15 Train:   9%|▉         | 1/11 [00:00<00:03,  2.90it/s]

Epoch 6/15 Train:  18%|█▊        | 2/11 [00:00<00:01,  4.96it/s]

Epoch 6/15 Train:  36%|███▋      | 4/11 [00:00<00:00,  7.27it/s]

Epoch 6/15 Train:  55%|█████▍    | 6/11 [00:00<00:00,  8.37it/s]

Epoch 6/15 Train:  73%|███████▎  | 8/11 [00:01<00:00,  8.97it/s]

Epoch 6/15 Train:  82%|████████▏ | 9/11 [00:01<00:00,  9.15it/s]

Epoch 6/15 Train:  91%|█████████ | 10/11 [00:01<00:00,  9.34it/s]

Epoch 6/15 Train: 100%|██████████| 11/11 [00:01<00:00,  8.29it/s]




Val:   0%|          | 0/3 [00:00<?, ?it/s]

Val:  33%|███▎      | 1/3 [00:00<00:00,  3.75it/s]

Val: 100%|██████████| 3/3 [00:00<00:00,  9.27it/s]

Val: 100%|██████████| 3/3 [00:00<00:00,  7.29it/s]




Epoch 6: TrainBCE=0.7644, ValWLL=0.4659, LR=2.00e-05


Epoch 7/15 Train:   0%|          | 0/11 [00:00<?, ?it/s]

Epoch 7/15 Train:   9%|▉         | 1/11 [00:00<00:03,  2.72it/s]

Epoch 7/15 Train:  18%|█▊        | 2/11 [00:00<00:01,  4.73it/s]

Epoch 7/15 Train:  36%|███▋      | 4/11 [00:00<00:00,  7.07it/s]

Epoch 7/15 Train:  45%|████▌     | 5/11 [00:00<00:00,  7.77it/s]

Epoch 7/15 Train:  64%|██████▎   | 7/11 [00:00<00:00,  8.70it/s]

Epoch 7/15 Train:  82%|████████▏ | 9/11 [00:01<00:00,  9.13it/s]

Epoch 7/15 Train:  91%|█████████ | 10/11 [00:01<00:00,  9.30it/s]

Epoch 7/15 Train: 100%|██████████| 11/11 [00:01<00:00,  8.12it/s]




Val:   0%|          | 0/3 [00:00<?, ?it/s]

Val:  33%|███▎      | 1/3 [00:00<00:00,  3.78it/s]

Val: 100%|██████████| 3/3 [00:00<00:00,  9.28it/s]

Val: 100%|██████████| 3/3 [00:00<00:00,  7.21it/s]




Epoch 7: TrainBCE=0.7296, ValWLL=0.4583, LR=2.00e-05


Epoch 8/15 Train:   0%|          | 0/11 [00:00<?, ?it/s]

Epoch 8/15 Train:   9%|▉         | 1/11 [00:00<00:03,  2.87it/s]

Epoch 8/15 Train:  18%|█▊        | 2/11 [00:00<00:01,  4.92it/s]

Epoch 8/15 Train:  27%|██▋       | 3/11 [00:00<00:01,  6.36it/s]

Epoch 8/15 Train:  45%|████▌     | 5/11 [00:00<00:00,  8.03it/s]

Epoch 8/15 Train:  64%|██████▎   | 7/11 [00:00<00:00,  8.81it/s]

Epoch 8/15 Train:  82%|████████▏ | 9/11 [00:01<00:00,  9.21it/s]

Epoch 8/15 Train:  91%|█████████ | 10/11 [00:01<00:00,  9.36it/s]

Epoch 8/15 Train: 100%|██████████| 11/11 [00:01<00:00,  6.10it/s]

Epoch 8/15 Train: 100%|██████████| 11/11 [00:01<00:00,  6.60it/s]




Val:   0%|          | 0/3 [00:00<?, ?it/s]

Val:  33%|███▎      | 1/3 [00:00<00:00,  4.04it/s]

Val: 100%|██████████| 3/3 [00:00<00:00,  9.76it/s]

Val: 100%|██████████| 3/3 [00:00<00:00,  7.53it/s]




Epoch 8: TrainBCE=0.7187, ValWLL=0.4538, LR=2.00e-05


Epoch 9/15 Train:   0%|          | 0/11 [00:00<?, ?it/s]

Epoch 9/15 Train:   9%|▉         | 1/11 [00:00<00:03,  2.80it/s]

Epoch 9/15 Train:  18%|█▊        | 2/11 [00:00<00:01,  4.83it/s]

Epoch 9/15 Train:  27%|██▋       | 3/11 [00:00<00:01,  6.32it/s]

Epoch 9/15 Train:  45%|████▌     | 5/11 [00:00<00:00,  7.97it/s]

Epoch 9/15 Train:  64%|██████▎   | 7/11 [00:00<00:00,  8.79it/s]

Epoch 9/15 Train:  73%|███████▎  | 8/11 [00:01<00:00,  9.05it/s]

Epoch 9/15 Train:  82%|████████▏ | 9/11 [00:01<00:00,  9.25it/s]

Epoch 9/15 Train:  91%|█████████ | 10/11 [00:01<00:00,  9.43it/s]

Epoch 9/15 Train: 100%|██████████| 11/11 [00:01<00:00,  8.16it/s]




Val:   0%|          | 0/3 [00:00<?, ?it/s]

Val:  33%|███▎      | 1/3 [00:00<00:00,  3.48it/s]

Val: 100%|██████████| 3/3 [00:00<00:00,  8.87it/s]

Val: 100%|██████████| 3/3 [00:00<00:00,  6.82it/s]




Epoch 9: TrainBCE=0.7108, ValWLL=0.4523, LR=2.00e-05


Epoch 10/15 Train:   0%|          | 0/11 [00:00<?, ?it/s]

Epoch 10/15 Train:   9%|▉         | 1/11 [00:00<00:03,  2.85it/s]

Epoch 10/15 Train:  18%|█▊        | 2/11 [00:00<00:01,  4.89it/s]

Epoch 10/15 Train:  27%|██▋       | 3/11 [00:00<00:01,  6.37it/s]

Epoch 10/15 Train:  45%|████▌     | 5/11 [00:00<00:00,  8.03it/s]

Epoch 10/15 Train:  64%|██████▎   | 7/11 [00:00<00:00,  8.82it/s]

Epoch 10/15 Train:  82%|████████▏ | 9/11 [00:01<00:00,  9.24it/s]

Epoch 10/15 Train:  91%|█████████ | 10/11 [00:01<00:00,  9.38it/s]

Epoch 10/15 Train: 100%|██████████| 11/11 [00:01<00:00,  8.23it/s]




Val:   0%|          | 0/3 [00:00<?, ?it/s]

Val:  33%|███▎      | 1/3 [00:00<00:00,  3.95it/s]

Val: 100%|██████████| 3/3 [00:00<00:00,  9.55it/s]

Val: 100%|██████████| 3/3 [00:00<00:00,  7.31it/s]




Epoch 10: TrainBCE=0.6971, ValWLL=0.4525, LR=2.00e-05


Epoch 11/15 Train:   0%|          | 0/11 [00:00<?, ?it/s]

Epoch 11/15 Train:   9%|▉         | 1/11 [00:00<00:03,  2.71it/s]

Epoch 11/15 Train:  18%|█▊        | 2/11 [00:00<00:01,  4.73it/s]

Epoch 11/15 Train:  36%|███▋      | 4/11 [00:00<00:00,  7.08it/s]

Epoch 11/15 Train:  55%|█████▍    | 6/11 [00:00<00:00,  8.21it/s]

Epoch 11/15 Train:  64%|██████▎   | 7/11 [00:00<00:00,  8.59it/s]

Epoch 11/15 Train:  82%|████████▏ | 9/11 [00:01<00:00,  9.12it/s]

Epoch 11/15 Train:  91%|█████████ | 10/11 [00:01<00:00,  9.29it/s]

Epoch 11/15 Train: 100%|██████████| 11/11 [00:01<00:00,  8.09it/s]




Val:   0%|          | 0/3 [00:00<?, ?it/s]

Val:  33%|███▎      | 1/3 [00:00<00:00,  3.94it/s]

Val: 100%|██████████| 3/3 [00:00<00:00,  9.54it/s]

Val: 100%|██████████| 3/3 [00:00<00:00,  7.43it/s]




Epoch 11: TrainBCE=0.7084, ValWLL=0.4533, LR=2.00e-05


Epoch 12/15 Train:   0%|          | 0/11 [00:00<?, ?it/s]

Epoch 12/15 Train:   9%|▉         | 1/11 [00:00<00:03,  2.85it/s]

Epoch 12/15 Train:  18%|█▊        | 2/11 [00:00<00:01,  4.88it/s]

Epoch 12/15 Train:  27%|██▋       | 3/11 [00:00<00:01,  6.37it/s]

Epoch 12/15 Train:  45%|████▌     | 5/11 [00:00<00:00,  8.03it/s]

Epoch 12/15 Train:  64%|██████▎   | 7/11 [00:00<00:00,  8.82it/s]

Epoch 12/15 Train:  82%|████████▏ | 9/11 [00:01<00:00,  9.22it/s]

Epoch 12/15 Train:  91%|█████████ | 10/11 [00:01<00:00,  9.36it/s]

Epoch 12/15 Train: 100%|██████████| 11/11 [00:01<00:00,  8.19it/s]




Val:   0%|          | 0/3 [00:00<?, ?it/s]

Val:  33%|███▎      | 1/3 [00:00<00:00,  3.95it/s]

Val: 100%|██████████| 3/3 [00:00<00:00,  9.59it/s]

Val: 100%|██████████| 3/3 [00:00<00:00,  7.49it/s]




Epoch 12: TrainBCE=0.7428, ValWLL=0.4553, LR=1.00e-05


Epoch 13/15 Train:   0%|          | 0/11 [00:00<?, ?it/s]

Epoch 13/15 Train:   9%|▉         | 1/11 [00:00<00:03,  2.94it/s]

Epoch 13/15 Train:  18%|█▊        | 2/11 [00:00<00:01,  5.00it/s]

Epoch 13/15 Train:  27%|██▋       | 3/11 [00:00<00:01,  6.47it/s]

Epoch 13/15 Train:  45%|████▌     | 5/11 [00:00<00:00,  8.10it/s]

Epoch 13/15 Train:  64%|██████▎   | 7/11 [00:00<00:00,  8.86it/s]

Epoch 13/15 Train:  82%|████████▏ | 9/11 [00:01<00:00,  9.27it/s]

Epoch 13/15 Train:  91%|█████████ | 10/11 [00:01<00:00,  9.41it/s]

Epoch 13/15 Train: 100%|██████████| 11/11 [00:01<00:00,  8.23it/s]




Val:   0%|          | 0/3 [00:00<?, ?it/s]

Val:  33%|███▎      | 1/3 [00:00<00:00,  3.90it/s]

Val: 100%|██████████| 3/3 [00:00<00:00,  9.47it/s]

Val: 100%|██████████| 3/3 [00:00<00:00,  7.39it/s]




Epoch 13: TrainBCE=0.7132, ValWLL=0.4575, LR=1.00e-05


Epoch 14/15 Train:   0%|          | 0/11 [00:00<?, ?it/s]

Epoch 14/15 Train:   9%|▉         | 1/11 [00:00<00:03,  2.60it/s]

Epoch 14/15 Train:  18%|█▊        | 2/11 [00:00<00:01,  4.58it/s]

Epoch 14/15 Train:  36%|███▋      | 4/11 [00:00<00:01,  6.95it/s]

Epoch 14/15 Train:  55%|█████▍    | 6/11 [00:00<00:00,  8.14it/s]

Epoch 14/15 Train:  73%|███████▎  | 8/11 [00:01<00:00,  8.81it/s]

Epoch 14/15 Train:  82%|████████▏ | 9/11 [00:01<00:00,  9.04it/s]

Epoch 14/15 Train:  91%|█████████ | 10/11 [00:01<00:00,  9.25it/s]

Epoch 14/15 Train: 100%|██████████| 11/11 [00:01<00:00,  7.97it/s]




Val:   0%|          | 0/3 [00:00<?, ?it/s]

Val:  33%|███▎      | 1/3 [00:00<00:00,  3.85it/s]

Val: 100%|██████████| 3/3 [00:00<00:00,  9.45it/s]

Val: 100%|██████████| 3/3 [00:00<00:00,  7.26it/s]




Epoch 14: TrainBCE=0.7218, ValWLL=0.4600, LR=1.00e-05
Early stopping at epoch 14
ConvNeXt v2 Fold 4 best OOF WLL (val): 0.4523


Final Val for OOF:   0%|          | 0/3 [00:00<?, ?it/s]

Final Val for OOF:  33%|███▎      | 1/3 [00:00<00:00,  3.81it/s]

Final Val for OOF: 100%|██████████| 3/3 [00:00<00:00,  9.32it/s]

Final Val for OOF: 100%|██████████| 3/3 [00:00<00:00,  7.25it/s]




ConvNeXt v2 Fold 4 OOF WLL: 0.4600

=== ConvNeXt v2 Fold 5/5 ===


Epoch 1/15 Train:   0%|          | 0/11 [00:00<?, ?it/s]

Epoch 1/15 Train:   9%|▉         | 1/11 [00:00<00:02,  3.35it/s]

Epoch 1/15 Train:  36%|███▋      | 4/11 [00:00<00:00, 11.42it/s]

Epoch 1/15 Train:  64%|██████▎   | 7/11 [00:00<00:00, 16.48it/s]

Epoch 1/15 Train:  91%|█████████ | 10/11 [00:00<00:00, 19.70it/s]

Epoch 1/15 Train: 100%|██████████| 11/11 [00:00<00:00, 15.42it/s]




Val:   0%|          | 0/3 [00:00<?, ?it/s]

Val:  33%|███▎      | 1/3 [00:00<00:00,  3.97it/s]

Val: 100%|██████████| 3/3 [00:00<00:00,  9.62it/s]

Val: 100%|██████████| 3/3 [00:00<00:00,  7.36it/s]




Epoch 1: TrainBCE=0.7533, ValWLL=0.4664, LR=2.00e-05


Epoch 2/15 Train:   0%|          | 0/11 [00:00<?, ?it/s]

Epoch 2/15 Train:   9%|▉         | 1/11 [00:00<00:03,  3.15it/s]

Epoch 2/15 Train:  36%|███▋      | 4/11 [00:00<00:00, 10.91it/s]

Epoch 2/15 Train:  64%|██████▎   | 7/11 [00:00<00:00, 15.90it/s]

Epoch 2/15 Train:  91%|█████████ | 10/11 [00:00<00:00, 19.30it/s]

Epoch 2/15 Train: 100%|██████████| 11/11 [00:00<00:00, 15.07it/s]




Val:   0%|          | 0/3 [00:00<?, ?it/s]

Val:  33%|███▎      | 1/3 [00:00<00:00,  3.64it/s]

Val: 100%|██████████| 3/3 [00:00<00:00,  9.07it/s]

Val: 100%|██████████| 3/3 [00:00<00:00,  6.99it/s]




Epoch 2: TrainBCE=0.7468, ValWLL=0.4662, LR=2.00e-05


Epoch 3/15 Train:   0%|          | 0/11 [00:00<?, ?it/s]

Epoch 3/15 Train:   9%|▉         | 1/11 [00:00<00:03,  2.85it/s]

Epoch 3/15 Train:  27%|██▋       | 3/11 [00:00<00:01,  6.10it/s]

Epoch 3/15 Train:  45%|████▌     | 5/11 [00:00<00:00,  7.64it/s]

Epoch 3/15 Train:  64%|██████▎   | 7/11 [00:00<00:00,  8.55it/s]

Epoch 3/15 Train:  82%|████████▏ | 9/11 [00:01<00:00,  9.03it/s]

Epoch 3/15 Train:  91%|█████████ | 10/11 [00:01<00:00,  9.20it/s]

Epoch 3/15 Train: 100%|██████████| 11/11 [00:01<00:00,  8.23it/s]




Val:   0%|          | 0/3 [00:00<?, ?it/s]

Val:  33%|███▎      | 1/3 [00:00<00:00,  3.96it/s]

Val: 100%|██████████| 3/3 [00:00<00:00,  9.60it/s]

Val: 100%|██████████| 3/3 [00:00<00:00,  7.27it/s]




Epoch 3: TrainBCE=0.8195, ValWLL=0.4636, LR=2.00e-05


Epoch 4/15 Train:   0%|          | 0/11 [00:00<?, ?it/s]

Epoch 4/15 Train:   9%|▉         | 1/11 [00:00<00:03,  2.76it/s]

Epoch 4/15 Train:  18%|█▊        | 2/11 [00:00<00:01,  4.79it/s]

Epoch 4/15 Train:  27%|██▋       | 3/11 [00:00<00:01,  6.27it/s]

Epoch 4/15 Train:  45%|████▌     | 5/11 [00:00<00:00,  7.96it/s]

Epoch 4/15 Train:  64%|██████▎   | 7/11 [00:00<00:00,  8.77it/s]

Epoch 4/15 Train:  82%|████████▏ | 9/11 [00:01<00:00,  9.19it/s]

Epoch 4/15 Train:  91%|█████████ | 10/11 [00:01<00:00,  9.34it/s]

Epoch 4/15 Train: 100%|██████████| 11/11 [00:01<00:00,  8.14it/s]




Val:   0%|          | 0/3 [00:00<?, ?it/s]

Val:  33%|███▎      | 1/3 [00:00<00:00,  4.01it/s]

Val: 100%|██████████| 3/3 [00:00<00:00,  9.63it/s]

Val: 100%|██████████| 3/3 [00:00<00:00,  7.44it/s]




Epoch 4: TrainBCE=0.6920, ValWLL=0.4603, LR=2.00e-05


Epoch 5/15 Train:   0%|          | 0/11 [00:00<?, ?it/s]

Epoch 5/15 Train:   9%|▉         | 1/11 [00:00<00:03,  2.68it/s]

Epoch 5/15 Train:  18%|█▊        | 2/11 [00:00<00:01,  4.68it/s]

Epoch 5/15 Train:  27%|██▋       | 3/11 [00:00<00:01,  6.17it/s]

Epoch 5/15 Train:  45%|████▌     | 5/11 [00:00<00:00,  7.89it/s]

Epoch 5/15 Train:  55%|█████▍    | 6/11 [00:00<00:00,  8.39it/s]

Epoch 5/15 Train:  64%|██████▎   | 7/11 [00:00<00:00,  8.81it/s]

Epoch 5/15 Train:  73%|███████▎  | 8/11 [00:01<00:00,  9.12it/s]

Epoch 5/15 Train:  82%|████████▏ | 9/11 [00:01<00:00,  9.33it/s]

Epoch 5/15 Train:  91%|█████████ | 10/11 [00:01<00:00,  9.51it/s]

Epoch 5/15 Train: 100%|██████████| 11/11 [00:01<00:00,  8.07it/s]




Val:   0%|          | 0/3 [00:00<?, ?it/s]

Val:  33%|███▎      | 1/3 [00:00<00:00,  3.87it/s]

Val: 100%|██████████| 3/3 [00:00<00:00,  9.42it/s]

Val: 100%|██████████| 3/3 [00:00<00:00,  7.36it/s]




Epoch 5: TrainBCE=0.7406, ValWLL=0.4539, LR=2.00e-05


Epoch 6/15 Train:   0%|          | 0/11 [00:00<?, ?it/s]

Epoch 6/15 Train:   9%|▉         | 1/11 [00:00<00:03,  2.94it/s]

Epoch 6/15 Train:  18%|█▊        | 2/11 [00:00<00:01,  4.98it/s]

Epoch 6/15 Train:  27%|██▋       | 3/11 [00:00<00:01,  6.45it/s]

Epoch 6/15 Train:  36%|███▋      | 4/11 [00:00<00:00,  7.47it/s]

Epoch 6/15 Train:  55%|█████▍    | 6/11 [00:00<00:00,  8.64it/s]

Epoch 6/15 Train:  64%|██████▎   | 7/11 [00:00<00:00,  8.97it/s]

Epoch 6/15 Train:  82%|████████▏ | 9/11 [00:01<00:00,  9.35it/s]

Epoch 6/15 Train:  91%|█████████ | 10/11 [00:01<00:00,  9.48it/s]

Epoch 6/15 Train: 100%|██████████| 11/11 [00:01<00:00,  8.24it/s]




Val:   0%|          | 0/3 [00:00<?, ?it/s]

Val:  33%|███▎      | 1/3 [00:00<00:00,  3.93it/s]

Val: 100%|██████████| 3/3 [00:00<00:00,  9.53it/s]

Val: 100%|██████████| 3/3 [00:00<00:00,  7.42it/s]




Epoch 6: TrainBCE=0.7575, ValWLL=0.4560, LR=2.00e-05


Epoch 7/15 Train:   0%|          | 0/11 [00:00<?, ?it/s]

Epoch 7/15 Train:   9%|▉         | 1/11 [00:00<00:03,  2.88it/s]

Epoch 7/15 Train:  18%|█▊        | 2/11 [00:00<00:01,  4.93it/s]

Epoch 7/15 Train:  27%|██▋       | 3/11 [00:00<00:01,  6.42it/s]

Epoch 7/15 Train:  36%|███▋      | 4/11 [00:00<00:00,  7.47it/s]

Epoch 7/15 Train:  45%|████▌     | 5/11 [00:00<00:00,  8.22it/s]

Epoch 7/15 Train:  55%|█████▍    | 6/11 [00:00<00:00,  8.75it/s]

Epoch 7/15 Train:  73%|███████▎  | 8/11 [00:01<00:00,  9.36it/s]

Epoch 7/15 Train:  82%|████████▏ | 9/11 [00:01<00:00,  9.37it/s]

Epoch 7/15 Train:  91%|█████████ | 10/11 [00:01<00:00,  9.47it/s]

Epoch 7/15 Train: 100%|██████████| 11/11 [00:01<00:00,  8.17it/s]




Val:   0%|          | 0/3 [00:00<?, ?it/s]

Val:  33%|███▎      | 1/3 [00:00<00:00,  3.97it/s]

Val: 100%|██████████| 3/3 [00:00<00:00,  9.58it/s]

Val: 100%|██████████| 3/3 [00:00<00:00,  7.31it/s]




Epoch 7: TrainBCE=0.7063, ValWLL=0.4501, LR=2.00e-05


Epoch 8/15 Train:   0%|          | 0/11 [00:00<?, ?it/s]

Epoch 8/15 Train:   9%|▉         | 1/11 [00:00<00:03,  2.83it/s]

Epoch 8/15 Train:  18%|█▊        | 2/11 [00:00<00:01,  4.87it/s]

Epoch 8/15 Train:  27%|██▋       | 3/11 [00:00<00:01,  6.35it/s]

Epoch 8/15 Train:  36%|███▋      | 4/11 [00:00<00:00,  7.41it/s]

Epoch 8/15 Train:  45%|████▌     | 5/11 [00:00<00:00,  8.17it/s]

Epoch 8/15 Train:  64%|██████▎   | 7/11 [00:00<00:00,  9.02it/s]

Epoch 8/15 Train:  82%|████████▏ | 9/11 [00:01<00:00,  9.38it/s]

Epoch 8/15 Train:  91%|█████████ | 10/11 [00:01<00:00,  9.49it/s]

Epoch 8/15 Train: 100%|██████████| 11/11 [00:01<00:00,  8.18it/s]




Val:   0%|          | 0/3 [00:00<?, ?it/s]

Val:  33%|███▎      | 1/3 [00:00<00:00,  3.78it/s]

Val: 100%|██████████| 3/3 [00:00<00:00,  9.27it/s]

Val: 100%|██████████| 3/3 [00:00<00:00,  7.10it/s]




Epoch 8: TrainBCE=0.6779, ValWLL=0.4496, LR=2.00e-05


Epoch 9/15 Train:   0%|          | 0/11 [00:00<?, ?it/s]

Epoch 9/15 Train:   9%|▉         | 1/11 [00:00<00:03,  2.68it/s]

Epoch 9/15 Train:  18%|█▊        | 2/11 [00:00<00:01,  4.68it/s]

Epoch 9/15 Train:  27%|██▋       | 3/11 [00:00<00:01,  6.18it/s]

Epoch 9/15 Train:  36%|███▋      | 4/11 [00:00<00:00,  7.27it/s]

Epoch 9/15 Train:  45%|████▌     | 5/11 [00:00<00:00,  8.07it/s]

Epoch 9/15 Train:  55%|█████▍    | 6/11 [00:00<00:00,  8.63it/s]

Epoch 9/15 Train:  73%|███████▎  | 8/11 [00:01<00:00,  9.26it/s]

Epoch 9/15 Train:  82%|████████▏ | 9/11 [00:01<00:00,  9.41it/s]

Epoch 9/15 Train:  91%|█████████ | 10/11 [00:01<00:00,  9.54it/s]

Epoch 9/15 Train: 100%|██████████| 11/11 [00:01<00:00,  8.04it/s]




Val:   0%|          | 0/3 [00:00<?, ?it/s]

Val:  33%|███▎      | 1/3 [00:00<00:00,  4.00it/s]

Val: 100%|██████████| 3/3 [00:00<00:00,  9.61it/s]

Val: 100%|██████████| 3/3 [00:00<00:00,  7.38it/s]




Epoch 9: TrainBCE=0.7977, ValWLL=0.4507, LR=2.00e-05


Epoch 10/15 Train:   0%|          | 0/11 [00:00<?, ?it/s]

Epoch 10/15 Train:   9%|▉         | 1/11 [00:00<00:03,  2.83it/s]

Epoch 10/15 Train:  18%|█▊        | 2/11 [00:00<00:01,  4.87it/s]

Epoch 10/15 Train:  27%|██▋       | 3/11 [00:00<00:01,  6.35it/s]

Epoch 10/15 Train:  45%|████▌     | 5/11 [00:00<00:00,  8.02it/s]

Epoch 10/15 Train:  55%|█████▍    | 6/11 [00:00<00:00,  8.50it/s]

Epoch 10/15 Train:  64%|██████▎   | 7/11 [00:00<00:00,  8.89it/s]

Epoch 10/15 Train:  73%|███████▎  | 8/11 [00:01<00:00,  9.18it/s]

Epoch 10/15 Train:  82%|████████▏ | 9/11 [00:01<00:00,  9.38it/s]

Epoch 10/15 Train:  91%|█████████ | 10/11 [00:01<00:00,  9.54it/s]

Epoch 10/15 Train: 100%|██████████| 11/11 [00:01<00:00,  8.17it/s]




Val:   0%|          | 0/3 [00:00<?, ?it/s]

Val:  33%|███▎      | 1/3 [00:00<00:00,  3.91it/s]

Val: 100%|██████████| 3/3 [00:00<00:00,  9.47it/s]

Val: 100%|██████████| 3/3 [00:00<00:00,  7.46it/s]




Epoch 10: TrainBCE=0.7171, ValWLL=0.4576, LR=2.00e-05


Epoch 11/15 Train:   0%|          | 0/11 [00:00<?, ?it/s]

Epoch 11/15 Train:   9%|▉         | 1/11 [00:00<00:03,  2.73it/s]

Epoch 11/15 Train:  18%|█▊        | 2/11 [00:00<00:01,  4.75it/s]

Epoch 11/15 Train:  27%|██▋       | 3/11 [00:00<00:01,  6.24it/s]

Epoch 11/15 Train:  36%|███▋      | 4/11 [00:00<00:00,  7.32it/s]

Epoch 11/15 Train:  45%|████▌     | 5/11 [00:00<00:00,  8.10it/s]

Epoch 11/15 Train:  64%|██████▎   | 7/11 [00:00<00:00,  8.94it/s]

Epoch 11/15 Train:  82%|████████▏ | 9/11 [00:01<00:00,  9.33it/s]

Epoch 11/15 Train:  91%|█████████ | 10/11 [00:01<00:00,  9.47it/s]

Epoch 11/15 Train: 100%|██████████| 11/11 [00:01<00:00,  8.07it/s]




Val:   0%|          | 0/3 [00:00<?, ?it/s]

Val:  33%|███▎      | 1/3 [00:00<00:00,  3.97it/s]

Val: 100%|██████████| 3/3 [00:00<00:00,  9.60it/s]

Val: 100%|██████████| 3/3 [00:00<00:00,  7.51it/s]




Epoch 11: TrainBCE=0.7157, ValWLL=0.4587, LR=1.00e-05


Epoch 12/15 Train:   0%|          | 0/11 [00:00<?, ?it/s]

Epoch 12/15 Train:   9%|▉         | 1/11 [00:00<00:03,  2.74it/s]

Epoch 12/15 Train:  18%|█▊        | 2/11 [00:00<00:01,  4.77it/s]

Epoch 12/15 Train:  27%|██▋       | 3/11 [00:00<00:01,  6.26it/s]

Epoch 12/15 Train:  36%|███▋      | 4/11 [00:00<00:00,  7.34it/s]

Epoch 12/15 Train:  55%|█████▍    | 6/11 [00:00<00:00,  8.56it/s]

Epoch 12/15 Train:  73%|███████▎  | 8/11 [00:01<00:00,  9.13it/s]

Epoch 12/15 Train:  82%|████████▏ | 9/11 [00:01<00:00,  9.28it/s]

Epoch 12/15 Train:  91%|█████████ | 10/11 [00:01<00:00,  9.40it/s]

Epoch 12/15 Train: 100%|██████████| 11/11 [00:01<00:00,  8.16it/s]




Val:   0%|          | 0/3 [00:00<?, ?it/s]

Val:  33%|███▎      | 1/3 [00:00<00:00,  3.78it/s]

Val: 100%|██████████| 3/3 [00:00<00:00,  9.31it/s]

Val: 100%|██████████| 3/3 [00:00<00:00,  7.26it/s]




Epoch 12: TrainBCE=0.7393, ValWLL=0.4573, LR=1.00e-05
Early stopping at epoch 12
ConvNeXt v2 Fold 5 best OOF WLL (val): 0.4496


Final Val for OOF:   0%|          | 0/3 [00:00<?, ?it/s]

Final Val for OOF:  33%|███▎      | 1/3 [00:00<00:00,  3.88it/s]

Final Val for OOF: 100%|██████████| 3/3 [00:00<00:00,  9.44it/s]

Final Val for OOF: 100%|██████████| 3/3 [00:00<00:00,  7.42it/s]

ConvNeXt v2 Fold 5 OOF WLL: 0.4573

ConvNeXt v2 CV Mean WLL: 0.4508 +/- 0.0069





In [15]:
# Compute OOF logits for ConvNeXt v2 temperature scaling
import torch
import numpy as np
import pandas as pd
from tqdm import tqdm
from torch.utils.data import DataLoader

# Reuse MIPDataset and val_transform from earlier cells
full_train_df = pd.read_csv('data/train_mips.csv')
label_cols = ['C1','C2','C3','C4','C5','C6','C7']
y_full = full_train_df[label_cols].values
mip_dir = 'data/mips/train'
n_folds = 5

from iterstrat.ml_stratifiers import MultilabelStratifiedKFold
skf = MultilabelStratifiedKFold(n_splits=n_folds, shuffle=True, random_state=123)
splits = list(skf.split(full_train_df, y_full))

def build_convnext_v2(num_classes=7):
    import timm
    m = timm.create_model('convnext_tiny', pretrained=False, num_classes=num_classes, in_chans=3,
                          drop_rate=0.4, drop_path_rate=0.2)
    return m.to(device).eval()

oof_logits_full_v2 = np.zeros((len(full_train_df), 7), dtype=np.float32)

for fold, (train_idx, val_idx) in enumerate(splits, 1):
    print(f'Computing OOF logits for ConvNeXt v2 Fold {fold}')
    val_df_f = full_train_df.iloc[val_idx].reset_index(drop=True)
    val_ds = MIPDataset(val_df_f, mip_dir, label_cols, val_transform)
    val_loader = DataLoader(val_ds, batch_size=16, shuffle=False, num_workers=4, pin_memory=True)

    ckpt = f'fold_{fold}_convnext_v2.pth'
    model = build_convnext_v2(7)
    sd = torch.load(ckpt, map_location='cpu')
    model.load_state_dict(sd, strict=True)

    model.eval()
    fold_logits = []
    with torch.no_grad():
        for images, _ in tqdm(val_loader, desc=f'Fold {fold} OOF logits v2'):
            images = images.to(device)
            logits = model(images)
            fold_logits.append(logits.cpu().numpy())
    fold_logits = np.concatenate(fold_logits, axis=0)
    oof_logits_full_v2[val_idx] = fold_logits
    print(f'Fold {fold} OOF logits v2 shape: {fold_logits.shape}')

# Save OOF logits v2
oof_logits_df_v2 = full_train_df[['StudyInstanceUID']].copy()
for i, col in enumerate(label_cols):
    oof_logits_df_v2[col] = oof_logits_full_v2[:, i]
oof_logits_df_v2.to_csv('oof_logits_convnext_v2.csv', index=False)
print('OOF logits v2 saved to oof_logits_convnext_v2.csv')

# Quick check: Compute probs from logits and verify matches oof_predictions_convnext_v2.csv
oof_probs_from_logits_v2 = 1 / (1 + np.exp(-oof_logits_full_v2))
oof_df_v2 = pd.read_csv('oof_predictions_convnext_v2.csv')
oof_probs_saved_v2 = oof_df_v2[label_cols].values
print(f'Probs v2 match (max diff): {np.max(np.abs(oof_probs_from_logits_v2 - oof_probs_saved_v2)):.6f}')

Computing OOF logits for ConvNeXt v2 Fold 1


Fold 1 OOF logits v2:   0%|          | 0/3 [00:00<?, ?it/s]

Fold 1 OOF logits v2:  33%|███▎      | 1/3 [00:00<00:00,  4.04it/s]

Fold 1 OOF logits v2: 100%|██████████| 3/3 [00:00<00:00,  9.64it/s]

Fold 1 OOF logits v2: 100%|██████████| 3/3 [00:00<00:00,  7.40it/s]




Fold 1 OOF logits v2 shape: (41, 7)
Computing OOF logits for ConvNeXt v2 Fold 2


Fold 2 OOF logits v2:   0%|          | 0/3 [00:00<?, ?it/s]

Fold 2 OOF logits v2:  33%|███▎      | 1/3 [00:00<00:00,  4.03it/s]

Fold 2 OOF logits v2: 100%|██████████| 3/3 [00:00<00:00,  9.58it/s]

Fold 2 OOF logits v2: 100%|██████████| 3/3 [00:00<00:00,  7.52it/s]




Fold 2 OOF logits v2 shape: (41, 7)
Computing OOF logits for ConvNeXt v2 Fold 3


Fold 3 OOF logits v2:   0%|          | 0/3 [00:00<?, ?it/s]

Fold 3 OOF logits v2:  33%|███▎      | 1/3 [00:00<00:00,  4.08it/s]

Fold 3 OOF logits v2: 100%|██████████| 3/3 [00:00<00:00,  9.82it/s]

Fold 3 OOF logits v2: 100%|██████████| 3/3 [00:00<00:00,  7.62it/s]




Fold 3 OOF logits v2 shape: (40, 7)
Computing OOF logits for ConvNeXt v2 Fold 4


Fold 4 OOF logits v2:   0%|          | 0/3 [00:00<?, ?it/s]

Fold 4 OOF logits v2:  33%|███▎      | 1/3 [00:00<00:00,  4.06it/s]

Fold 4 OOF logits v2: 100%|██████████| 3/3 [00:00<00:00,  9.74it/s]

Fold 4 OOF logits v2: 100%|██████████| 3/3 [00:00<00:00,  7.63it/s]




Fold 4 OOF logits v2 shape: (40, 7)
Computing OOF logits for ConvNeXt v2 Fold 5


Fold 5 OOF logits v2:   0%|          | 0/3 [00:00<?, ?it/s]

Fold 5 OOF logits v2:  33%|███▎      | 1/3 [00:00<00:00,  3.85it/s]

Fold 5 OOF logits v2: 100%|██████████| 3/3 [00:00<00:00,  9.48it/s]

Fold 5 OOF logits v2: 100%|██████████| 3/3 [00:00<00:00,  7.32it/s]

Fold 5 OOF logits v2 shape: (40, 7)
OOF logits v2 saved to oof_logits_convnext_v2.csv
Probs v2 match (max diff): 0.000000





In [16]:
# Fit per-class temperature scaling for ConvNeXt ensemble (v1 + v2) using averaged OOF logits
import numpy as np
import pandas as pd
from scipy.optimize import minimize_scalar
from scipy.special import expit as sigmoid

# Load data
oof_logits_df_v1 = pd.read_csv('oof_logits_convnext.csv')
oof_logits_df_v2 = pd.read_csv('oof_logits_convnext_v2.csv')
train_df = pd.read_csv('data/train_mips.csv')
label_cols = ['C1','C2','C3','C4','C5','C6','C7']
oof_logits_v1 = oof_logits_df_v1[label_cols].values  # (202, 7)
oof_logits_v2 = oof_logits_df_v2[label_cols].values  # (202, 7)
ensemble_oof_logits = (oof_logits_v1 + oof_logits_v2) / 2.0
y_true = train_df[label_cols].values  # (202, 7)

def binary_log_loss(y_true, y_pred):
    y_pred = np.clip(y_pred, 1e-6, 1-1e-6)
    return -np.mean(y_true * np.log(y_pred) + (1 - y_true) * np.log(1 - y_pred))

# Optimize T per class on ensemble logits
temperatures_ensemble = np.ones(7)
for i, col in enumerate(label_cols):
    def loss_func(T):
        logits_scaled = ensemble_oof_logits[:, i] / T
        probs = sigmoid(logits_scaled)
        return binary_log_loss(y_true[:, i], probs)
    res = minimize_scalar(loss_func, bounds=(0.5, 2.0), method='bounded')
    temperatures_ensemble[i] = res.x
    print(f'{col} ensemble temperature: {temperatures_ensemble[i]:.4f}, min loss: {res.fun:.4f}')

# Save ensemble temperatures
np.save('temperatures_ensemble.npy', temperatures_ensemble)
print('Ensemble temperatures saved to temperatures_ensemble.npy')

# Quick check: Compute calibrated ensemble OOF BCE (7 classes, weights=1)
logits_calib_ens = ensemble_oof_logits / temperatures_ensemble
probs_calib_ens = sigmoid(logits_calib_ens)
wll_calib_ens = np.mean([binary_log_loss(y_true[:, i], probs_calib_ens[:, i]) for i in range(7)])
wll_raw_ens = np.mean([binary_log_loss(y_true[:, i], sigmoid(ensemble_oof_logits[:, i])) for i in range(7)])
print(f'Ensemble Raw OOF BCE (avg): {wll_raw_ens:.4f}, Calibrated: {wll_calib_ens:.4f} (delta: {wll_calib_ens - wll_raw_ens:.4f})')

C1 ensemble temperature: 0.7869, min loss: 0.3267
C2 ensemble temperature: 0.6527, min loss: 0.4224
C3 ensemble temperature: 0.8730, min loss: 0.1846
C4 ensemble temperature: 0.7786, min loss: 0.2157
C5 ensemble temperature: 0.7154, min loss: 0.3491
C6 ensemble temperature: 0.6866, min loss: 0.4584
C7 ensemble temperature: 0.6223, min loss: 0.4495
Ensemble temperatures saved to temperatures_ensemble.npy
Ensemble Raw OOF BCE (avg): 0.3609, Calibrated: 0.3438 (delta: -0.0171)


In [18]:
# RegNetY-004 5-fold CV: diverse third backbone, stronger reg (drop=0.4, path=0.2), seed=456
from iterstrat.ml_stratifiers import MultilabelStratifiedKFold
from torch.optim.lr_scheduler import ReduceLROnPlateau
from torch.utils.data import WeightedRandomSampler
import torch.nn as nn
import torch.nn.functional as F
import timm

full_train_df = pd.read_csv('data/train_mips.csv')
label_cols = ['C1','C2','C3','C4','C5','C6','C7']
y_full = full_train_df[label_cols].values
mip_dir = 'data/mips/train'
n_folds = 5

skf = MultilabelStratifiedKFold(n_splits=n_folds, shuffle=True, random_state=456)
splits = list(skf.split(full_train_df, y_full))

def build_regnet(num_classes=7):
    m = timm.create_model('regnety_004', pretrained=True, num_classes=num_classes, in_chans=3,
                          drop_rate=0.4, drop_path_rate=0.2)
    return m

oof_preds_full_reg = np.zeros((len(full_train_df), 7), dtype=np.float32)
oof_overall_full_reg = np.zeros(len(full_train_df), dtype=np.float32)
fold_scores_reg = []

for fold, (train_idx, val_idx) in enumerate(splits, 1):
    print(f'\n=== RegNetY Fold {fold}/{n_folds} ===')
    train_df_f = full_train_df.iloc[train_idx].reset_index(drop=True)
    val_df_f = full_train_df.iloc[val_idx].reset_index(drop=True)

    # pos_weight (clip=2.0)
    y_tr = train_df_f[label_cols].values
    pos = (y_tr == 1).sum(axis=0); neg = (y_tr == 0).sum(axis=0)
    w = np.minimum(neg / (pos + 1e-6), 2.0).astype(np.float32)
    fold_pos_weight = torch.tensor(w, dtype=torch.float32, device=device)

    # WeightedRandomSampler (cap=3.0)
    class_rarity = np.minimum(neg / (pos + 1e-6), 10.0).astype(np.float32)
    sample_weight = (y_tr * class_rarity).max(axis=1)
    sample_weight = np.where(sample_weight > 0, sample_weight, 1.0)
    sample_weight = np.sqrt(sample_weight)
    sample_weight = np.clip(sample_weight, 1.0, 3.0)
    sampler = WeightedRandomSampler(torch.from_numpy(sample_weight), len(sample_weight), replacement=True)

    train_ds = MIPDataset(train_df_f, mip_dir, label_cols, train_transform)
    val_ds = MIPDataset(val_df_f, mip_dir, label_cols, val_transform)
    train_loader = DataLoader(train_ds, batch_size=16, sampler=sampler, num_workers=4, pin_memory=True, drop_last=False)
    val_loader = DataLoader(val_ds, batch_size=16, shuffle=False, num_workers=4, pin_memory=True, drop_last=False)

    # Model
    model = build_regnet(7).to(device)

    # Bias init from priors
    with torch.no_grad():
        priors = np.clip(y_tr.mean(axis=0) + 1e-6, 1e-6, 1-1e-6)
        bias = np.log(priors / (1 - priors)).astype(np.float32)
        head = getattr(model, 'head', None)
        if head is not None and hasattr(head, 'fc') and head.fc.bias is not None:
            head.fc.bias.copy_(torch.from_numpy(bias).to(head.fc.bias.device))

    optimizer = optim.AdamW(model.parameters(), lr=2e-5, weight_decay=1e-4)
    scheduler = ReduceLROnPlateau(optimizer, mode='min', factor=0.5, patience=2, verbose=True, min_lr=1e-6)
    criterion = SmoothedBCEWithLogitsLoss(pos_weight=fold_pos_weight, smoothing=0.0)
    scaler = GradScaler()
    early_stopping = EarlyStopping(patience=5, min_delta=0.001, restore_best_weights=False)

    # EMA
    ema = timm.utils.ModelEmaV2(model, decay=0.995)

    num_epochs = 15
    warmup_epochs = 2
    best_wll = float('inf')
    best_state = None

    for epoch in range(num_epochs):
        # Head warmup
        if epoch < warmup_epochs:
            for p in model.parameters(): p.requires_grad = False
            if hasattr(model, 'head'):
                for p in model.head.parameters(): p.requires_grad = True
        else:
            for p in model.parameters(): p.requires_grad = True

        model.train()
        train_loss = 0.0
        for images, labels in tqdm(train_loader, desc=f'Epoch {epoch+1}/{num_epochs} Train'):
            images = images.to(device); labels7 = labels.to(device)
            optimizer.zero_grad(set_to_none=True)
            with torch.cuda.amp.autocast():
                logits = model(images)
                loss = criterion(logits, labels7)
            scaler.scale(loss).backward()
            nn_utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
            scaler.step(optimizer)
            scaler.update()
            ema.update(model)
            train_loss += loss.item() * images.size(0)
        train_loss /= len(train_ds)

        # Validation with EMA
        ema.module.eval()
        val_wll = 0.0
        with torch.no_grad():
            for images, labels in tqdm(val_loader, desc='Val'):
                images = images.to(device); labels7 = labels.to(device)
                logits = ema.module(images)
                probs7 = torch.sigmoid(logits)
                probs_overall = probs7.max(dim=1, keepdim=True)[0]
                labels_overall = labels7.max(dim=1, keepdim=True)[0]
                probs8 = torch.cat([probs7, probs_overall], dim=1)
                labels8 = torch.cat([labels7, labels_overall], dim=1)
                val_wll += weighted_log_loss_torch(labels8, probs8).item() * images.size(0)
        val_wll /= len(val_ds)
        scheduler.step(val_wll)
        print(f'Epoch {epoch+1}: TrainBCE={train_loss:.4f}, ValWLL={val_wll:.4f}, LR={optimizer.param_groups[0]["lr"]:.2e}')

        if val_wll < best_wll:
            best_wll = val_wll
            best_state = ema.module.state_dict()

        if early_stopping(val_wll, model):
            print(f'Early stopping at epoch {epoch+1}')
            break

    # Save best EMA weights
    torch.save(best_state, f'fold_{fold}_regnet.pth')
    print(f'RegNetY Fold {fold} best OOF WLL (val): {best_wll:.4f}')

    # OOF collection with best EMA
    ema.module.load_state_dict(best_state)
    ema.module.eval()
    fold_oof_preds_7 = np.zeros((len(val_df_f), 7), dtype=np.float32)
    fold_oof_overall = np.zeros(len(val_df_f), dtype=np.float32)
    start = 0
    with torch.no_grad():
        for images, _ in tqdm(val_loader, desc='Final Val for OOF'):
            images = images.to(device)
            logits = ema.module(images)
            probs7 = torch.sigmoid(logits).cpu().numpy()
            probs_overall = probs7.max(axis=1)
            bs = images.size(0)
            fold_oof_preds_7[start:start+bs] = probs7
            fold_oof_overall[start:start+bs] = probs_overall
            start += bs

    oof_preds_full_reg[val_idx] = fold_oof_preds_7
    oof_overall_full_reg[val_idx] = fold_oof_overall

    y_val = val_df_f[label_cols].values
    fold_preds_8 = np.column_stack([fold_oof_preds_7, fold_oof_overall])
    fold_labels_8 = np.column_stack([y_val, np.max(y_val, axis=1)])
    fold_score = weighted_log_loss(fold_labels_8, fold_preds_8)
    fold_scores_reg.append(fold_score)
    print(f'RegNetY Fold {fold} OOF WLL: {fold_score:.4f}')

cv_mean_reg = float(np.mean(fold_scores_reg)); cv_std_reg = float(np.std(fold_scores_reg))
print(f'\nRegNetY CV Mean WLL: {cv_mean_reg:.4f} +/- {cv_std_reg:.4f}')
reg_oof_df = full_train_df[['StudyInstanceUID']].copy()
reg_oof_df[label_cols] = oof_preds_full_reg
reg_oof_df['patient_overall'] = oof_overall_full_reg
reg_oof_df.to_csv('oof_predictions_regnet.csv', index=False)


=== RegNetY Fold 1/5 ===


Epoch 1/15 Train:   0%|          | 0/11 [00:00<?, ?it/s]

Epoch 1/15 Train:   9%|▉         | 1/11 [00:00<00:03,  2.82it/s]

Epoch 1/15 Train:  45%|████▌     | 5/11 [00:00<00:00, 13.08it/s]

Epoch 1/15 Train:  82%|████████▏ | 9/11 [00:00<00:00, 20.55it/s]

Epoch 1/15 Train: 100%|██████████| 11/11 [00:00<00:00, 15.55it/s]




Val:   0%|          | 0/3 [00:00<?, ?it/s]

Val:  33%|███▎      | 1/3 [00:00<00:00,  4.01it/s]

Val: 100%|██████████| 3/3 [00:00<00:00,  8.68it/s]




Epoch 1: TrainBCE=0.8361, ValWLL=0.4585, LR=2.00e-05


Epoch 2/15 Train:   0%|          | 0/11 [00:00<?, ?it/s]

Epoch 2/15 Train:   9%|▉         | 1/11 [00:00<00:03,  3.09it/s]

Epoch 2/15 Train:  45%|████▌     | 5/11 [00:00<00:00, 13.80it/s]

Epoch 2/15 Train:  82%|████████▏ | 9/11 [00:00<00:00, 21.35it/s]

Epoch 2/15 Train: 100%|██████████| 11/11 [00:00<00:00, 17.12it/s]




Val:   0%|          | 0/3 [00:00<?, ?it/s]

Val:  33%|███▎      | 1/3 [00:00<00:00,  4.63it/s]

Val: 100%|██████████| 3/3 [00:00<00:00, 10.37it/s]




Epoch 2: TrainBCE=0.7832, ValWLL=0.4590, LR=2.00e-05


Epoch 3/15 Train:   0%|          | 0/11 [00:00<?, ?it/s]

Epoch 3/15 Train:   9%|▉         | 1/11 [00:00<00:03,  2.54it/s]

Epoch 3/15 Train:  27%|██▋       | 3/11 [00:00<00:01,  6.89it/s]

Epoch 3/15 Train:  45%|████▌     | 5/11 [00:00<00:00, 10.26it/s]

Epoch 3/15 Train:  64%|██████▎   | 7/11 [00:00<00:00, 12.78it/s]

Epoch 3/15 Train:  82%|████████▏ | 9/11 [00:00<00:00, 14.71it/s]

Epoch 3/15 Train: 100%|██████████| 11/11 [00:00<00:00, 13.66it/s]

Epoch 3/15 Train: 100%|██████████| 11/11 [00:01<00:00, 10.55it/s]




Val:   0%|          | 0/3 [00:00<?, ?it/s]

Val:  33%|███▎      | 1/3 [00:00<00:00,  4.63it/s]

Val: 100%|██████████| 3/3 [00:00<00:00, 10.29it/s]




Epoch 3: TrainBCE=0.8814, ValWLL=0.4595, LR=2.00e-05


Epoch 4/15 Train:   0%|          | 0/11 [00:00<?, ?it/s]

Epoch 4/15 Train:   9%|▉         | 1/11 [00:00<00:03,  2.70it/s]

Epoch 4/15 Train:  27%|██▋       | 3/11 [00:00<00:01,  7.27it/s]

Epoch 4/15 Train:  45%|████▌     | 5/11 [00:00<00:00, 10.65it/s]

Epoch 4/15 Train:  64%|██████▎   | 7/11 [00:00<00:00, 13.11it/s]

Epoch 4/15 Train:  82%|████████▏ | 9/11 [00:00<00:00, 14.89it/s]

Epoch 4/15 Train: 100%|██████████| 11/11 [00:00<00:00, 11.61it/s]

Val:  33%|███▎      | 1/3 [00:00<00:00,  4.47it/s]

Val: 100%|██████████| 3/3 [00:00<00:00, 10.13it/s]




Epoch 4: TrainBCE=0.7838, ValWLL=0.4602, LR=1.00e-05


Epoch 5/15 Train:   0%|          | 0/11 [00:00<?, ?it/s]

Epoch 5/15 Train:   9%|▉         | 1/11 [00:00<00:03,  2.79it/s]

Epoch 5/15 Train:  27%|██▋       | 3/11 [00:00<00:01,  7.33it/s]

Epoch 5/15 Train:  45%|████▌     | 5/11 [00:00<00:00, 10.77it/s]

Epoch 5/15 Train:  64%|██████▎   | 7/11 [00:00<00:00, 13.28it/s]

Epoch 5/15 Train:  82%|████████▏ | 9/11 [00:00<00:00, 15.07it/s]

Epoch 5/15 Train: 100%|██████████| 11/11 [00:00<00:00, 11.81it/s]




Val:   0%|          | 0/3 [00:00<?, ?it/s]

Val:  33%|███▎      | 1/3 [00:00<00:00,  4.74it/s]

Val: 100%|██████████| 3/3 [00:00<00:00, 10.42it/s]




Epoch 5: TrainBCE=0.7548, ValWLL=0.4607, LR=1.00e-05


Epoch 6/15 Train:   0%|          | 0/11 [00:00<?, ?it/s]

Epoch 6/15 Train:   9%|▉         | 1/11 [00:00<00:03,  2.80it/s]

Epoch 6/15 Train:  27%|██▋       | 3/11 [00:00<00:01,  7.33it/s]

Epoch 6/15 Train:  45%|████▌     | 5/11 [00:00<00:00, 10.73it/s]

Epoch 6/15 Train:  64%|██████▎   | 7/11 [00:00<00:00, 13.22it/s]

Epoch 6/15 Train:  82%|████████▏ | 9/11 [00:00<00:00, 14.73it/s]

Epoch 6/15 Train: 100%|██████████| 11/11 [00:00<00:00, 11.55it/s]




Val:   0%|          | 0/3 [00:00<?, ?it/s]

Val:  33%|███▎      | 1/3 [00:00<00:00,  4.20it/s]

Val: 100%|██████████| 3/3 [00:00<00:00,  9.63it/s]




Epoch 6: TrainBCE=0.8241, ValWLL=0.4607, LR=1.00e-05
Early stopping at epoch 6
RegNetY Fold 1 best OOF WLL (val): 0.4585


Final Val for OOF:   0%|          | 0/3 [00:00<?, ?it/s]

Final Val for OOF:  33%|███▎      | 1/3 [00:00<00:00,  4.59it/s]

Final Val for OOF: 100%|██████████| 3/3 [00:00<00:00, 10.20it/s]




RegNetY Fold 1 OOF WLL: 0.4607

=== RegNetY Fold 2/5 ===


Epoch 1/15 Train:   0%|          | 0/11 [00:00<?, ?it/s]

Epoch 1/15 Train:   9%|▉         | 1/11 [00:00<00:03,  3.30it/s]

Epoch 1/15 Train:  45%|████▌     | 5/11 [00:00<00:00, 13.90it/s]

Epoch 1/15 Train:  82%|████████▏ | 9/11 [00:00<00:00, 21.42it/s]

Epoch 1/15 Train: 100%|██████████| 11/11 [00:00<00:00, 16.24it/s]




Val:   0%|          | 0/3 [00:00<?, ?it/s]

Val:  33%|███▎      | 1/3 [00:00<00:00,  4.47it/s]

Val: 100%|██████████| 3/3 [00:00<00:00,  9.15it/s]




Epoch 1: TrainBCE=0.8001, ValWLL=0.4858, LR=2.00e-05


Epoch 2/15 Train:   0%|          | 0/11 [00:00<?, ?it/s]

Epoch 2/15 Train:   9%|▉         | 1/11 [00:00<00:03,  3.23it/s]

Epoch 2/15 Train:  36%|███▋      | 4/11 [00:00<00:00, 11.65it/s]

Epoch 2/15 Train:  73%|███████▎  | 8/11 [00:00<00:00, 20.16it/s]

Epoch 2/15 Train: 100%|██████████| 11/11 [00:00<00:00, 17.16it/s]




Val:   0%|          | 0/3 [00:00<?, ?it/s]

Val:  33%|███▎      | 1/3 [00:00<00:00,  4.66it/s]

Val: 100%|██████████| 3/3 [00:00<00:00, 10.46it/s]




Epoch 2: TrainBCE=0.7749, ValWLL=0.4852, LR=2.00e-05


Epoch 3/15 Train:   0%|          | 0/11 [00:00<?, ?it/s]

Epoch 3/15 Train:   9%|▉         | 1/11 [00:00<00:03,  2.81it/s]

Epoch 3/15 Train:  27%|██▋       | 3/11 [00:00<00:01,  7.57it/s]

Epoch 3/15 Train:  45%|████▌     | 5/11 [00:00<00:00, 11.04it/s]

Epoch 3/15 Train:  64%|██████▎   | 7/11 [00:00<00:00, 13.25it/s]

Epoch 3/15 Train:  82%|████████▏ | 9/11 [00:00<00:00, 14.90it/s]

Epoch 3/15 Train: 100%|██████████| 11/11 [00:00<00:00, 14.45it/s]

Epoch 3/15 Train: 100%|██████████| 11/11 [00:00<00:00, 11.16it/s]




Val:   0%|          | 0/3 [00:00<?, ?it/s]

Val:  33%|███▎      | 1/3 [00:00<00:00,  4.45it/s]

Val: 100%|██████████| 3/3 [00:00<00:00,  9.99it/s]




Epoch 3: TrainBCE=0.7982, ValWLL=0.4806, LR=2.00e-05


Epoch 4/15 Train:   0%|          | 0/11 [00:00<?, ?it/s]

Epoch 4/15 Train:   9%|▉         | 1/11 [00:00<00:03,  2.80it/s]

Epoch 4/15 Train:  27%|██▋       | 3/11 [00:00<00:01,  7.30it/s]

Epoch 4/15 Train:  45%|████▌     | 5/11 [00:00<00:00, 10.74it/s]

Epoch 4/15 Train:  64%|██████▎   | 7/11 [00:00<00:00, 13.18it/s]

Epoch 4/15 Train:  82%|████████▏ | 9/11 [00:00<00:00, 14.93it/s]

Epoch 4/15 Train: 100%|██████████| 11/11 [00:00<00:00, 11.63it/s]




Val:   0%|          | 0/3 [00:00<?, ?it/s]

Val:  33%|███▎      | 1/3 [00:00<00:00,  4.26it/s]

Val: 100%|██████████| 3/3 [00:00<00:00,  9.60it/s]




Epoch 4: TrainBCE=0.8829, ValWLL=0.4799, LR=2.00e-05


Epoch 5/15 Train:   0%|          | 0/11 [00:00<?, ?it/s]

Epoch 5/15 Train:   9%|▉         | 1/11 [00:00<00:03,  3.06it/s]

Epoch 5/15 Train:  27%|██▋       | 3/11 [00:00<00:01,  7.81it/s]

Epoch 5/15 Train:  45%|████▌     | 5/11 [00:00<00:00, 11.20it/s]

Epoch 5/15 Train:  64%|██████▎   | 7/11 [00:00<00:00, 13.57it/s]

Epoch 5/15 Train:  82%|████████▏ | 9/11 [00:00<00:00, 15.24it/s]

Epoch 5/15 Train: 100%|██████████| 11/11 [00:00<00:00, 12.09it/s]




Val:   0%|          | 0/3 [00:00<?, ?it/s]

Val:  33%|███▎      | 1/3 [00:00<00:00,  4.38it/s]

Val: 100%|██████████| 3/3 [00:00<00:00,  9.89it/s]




Epoch 5: TrainBCE=0.8489, ValWLL=0.4807, LR=2.00e-05


Epoch 6/15 Train:   0%|          | 0/11 [00:00<?, ?it/s]

Epoch 6/15 Train:   9%|▉         | 1/11 [00:00<00:03,  2.91it/s]

Epoch 6/15 Train:  27%|██▋       | 3/11 [00:00<00:01,  7.67it/s]

Epoch 6/15 Train:  45%|████▌     | 5/11 [00:00<00:00, 11.09it/s]

Epoch 6/15 Train:  64%|██████▎   | 7/11 [00:00<00:00, 13.53it/s]

Epoch 6/15 Train:  82%|████████▏ | 9/11 [00:00<00:00, 15.24it/s]

Epoch 6/15 Train: 100%|██████████| 11/11 [00:00<00:00, 11.91it/s]




Val:   0%|          | 0/3 [00:00<?, ?it/s]

Val:  33%|███▎      | 1/3 [00:00<00:00,  4.66it/s]

Val: 100%|██████████| 3/3 [00:00<00:00, 10.48it/s]




Epoch 6: TrainBCE=0.8515, ValWLL=0.4822, LR=2.00e-05


Epoch 7/15 Train:   0%|          | 0/11 [00:00<?, ?it/s]

Epoch 7/15 Train:   9%|▉         | 1/11 [00:00<00:03,  3.03it/s]

Epoch 7/15 Train:  27%|██▋       | 3/11 [00:00<00:01,  7.75it/s]

Epoch 7/15 Train:  45%|████▌     | 5/11 [00:00<00:00, 11.10it/s]

Epoch 7/15 Train:  64%|██████▎   | 7/11 [00:00<00:00, 13.48it/s]

Epoch 7/15 Train:  82%|████████▏ | 9/11 [00:00<00:00, 15.17it/s]

Epoch 7/15 Train: 100%|██████████| 11/11 [00:00<00:00, 12.04it/s]




Val:   0%|          | 0/3 [00:00<?, ?it/s]

Val:  33%|███▎      | 1/3 [00:00<00:00,  4.55it/s]

Val: 100%|██████████| 3/3 [00:00<00:00, 10.18it/s]




Epoch 7: TrainBCE=0.8359, ValWLL=0.4837, LR=1.00e-05


Epoch 8/15 Train:   0%|          | 0/11 [00:00<?, ?it/s]

Epoch 8/15 Train:   9%|▉         | 1/11 [00:00<00:03,  2.97it/s]

Epoch 8/15 Train:  27%|██▋       | 3/11 [00:00<00:01,  7.71it/s]

Epoch 8/15 Train:  45%|████▌     | 5/11 [00:00<00:00, 11.09it/s]

Epoch 8/15 Train:  64%|██████▎   | 7/11 [00:00<00:00, 13.49it/s]

Epoch 8/15 Train:  82%|████████▏ | 9/11 [00:00<00:00, 15.20it/s]

Epoch 8/15 Train: 100%|██████████| 11/11 [00:00<00:00, 11.92it/s]




Val:   0%|          | 0/3 [00:00<?, ?it/s]

Val:  33%|███▎      | 1/3 [00:00<00:00,  4.23it/s]

Val: 100%|██████████| 3/3 [00:00<00:00,  9.61it/s]




Epoch 8: TrainBCE=0.8064, ValWLL=0.4836, LR=1.00e-05
Early stopping at epoch 8
RegNetY Fold 2 best OOF WLL (val): 0.4799


Final Val for OOF:   0%|          | 0/3 [00:00<?, ?it/s]

Final Val for OOF:  33%|███▎      | 1/3 [00:00<00:00,  4.79it/s]

Final Val for OOF: 100%|██████████| 3/3 [00:00<00:00, 10.82it/s]




RegNetY Fold 2 OOF WLL: 0.4836

=== RegNetY Fold 3/5 ===


Epoch 1/15 Train:   0%|          | 0/11 [00:00<?, ?it/s]

Epoch 1/15 Train:   9%|▉         | 1/11 [00:00<00:02,  3.36it/s]

Epoch 1/15 Train:  36%|███▋      | 4/11 [00:00<00:00, 11.88it/s]

Epoch 1/15 Train:  73%|███████▎  | 8/11 [00:00<00:00, 20.27it/s]

Epoch 1/15 Train: 100%|██████████| 11/11 [00:00<00:00, 17.15it/s]




Val:   0%|          | 0/3 [00:00<?, ?it/s]

Val:  33%|███▎      | 1/3 [00:00<00:00,  4.57it/s]

Val: 100%|██████████| 3/3 [00:00<00:00, 10.17it/s]




Epoch 1: TrainBCE=0.8465, ValWLL=0.4637, LR=2.00e-05


Epoch 2/15 Train:   0%|          | 0/11 [00:00<?, ?it/s]

Epoch 2/15 Train:   9%|▉         | 1/11 [00:00<00:02,  3.44it/s]

Epoch 2/15 Train:  36%|███▋      | 4/11 [00:00<00:00, 12.16it/s]

Epoch 2/15 Train:  73%|███████▎  | 8/11 [00:00<00:00, 20.49it/s]

Epoch 2/15 Train: 100%|██████████| 11/11 [00:00<00:00, 16.97it/s]




Val:   0%|          | 0/3 [00:00<?, ?it/s]

Val:  33%|███▎      | 1/3 [00:00<00:00,  4.57it/s]

Val: 100%|██████████| 3/3 [00:00<00:00, 10.18it/s]




Epoch 2: TrainBCE=0.7981, ValWLL=0.4641, LR=2.00e-05


Epoch 3/15 Train:   0%|          | 0/11 [00:00<?, ?it/s]

Epoch 3/15 Train:   9%|▉         | 1/11 [00:00<00:03,  3.11it/s]

Epoch 3/15 Train:  27%|██▋       | 3/11 [00:00<00:01,  7.76it/s]

Epoch 3/15 Train:  45%|████▌     | 5/11 [00:00<00:00, 11.17it/s]

Epoch 3/15 Train:  64%|██████▎   | 7/11 [00:00<00:00, 13.58it/s]

Epoch 3/15 Train:  82%|████████▏ | 9/11 [00:00<00:00, 15.28it/s]

Epoch 3/15 Train: 100%|██████████| 11/11 [00:00<00:00, 12.08it/s]




Val:   0%|          | 0/3 [00:00<?, ?it/s]

Val:  33%|███▎      | 1/3 [00:00<00:00,  4.51it/s]

Val: 100%|██████████| 3/3 [00:00<00:00,  9.97it/s]




Epoch 3: TrainBCE=0.8080, ValWLL=0.4644, LR=2.00e-05


Epoch 4/15 Train:   0%|          | 0/11 [00:00<?, ?it/s]

Epoch 4/15 Train:   9%|▉         | 1/11 [00:00<00:03,  2.90it/s]

Epoch 4/15 Train:  27%|██▋       | 3/11 [00:00<00:01,  7.46it/s]

Epoch 4/15 Train:  45%|████▌     | 5/11 [00:00<00:00, 10.82it/s]

Epoch 4/15 Train:  64%|██████▎   | 7/11 [00:00<00:00, 13.08it/s]

Epoch 4/15 Train:  82%|████████▏ | 9/11 [00:00<00:00, 14.85it/s]

Epoch 4/15 Train: 100%|██████████| 11/11 [00:00<00:00, 11.76it/s]




Val:   0%|          | 0/3 [00:00<?, ?it/s]

Val:  33%|███▎      | 1/3 [00:00<00:00,  4.47it/s]

Val: 100%|██████████| 3/3 [00:00<00:00, 10.21it/s]




Epoch 4: TrainBCE=0.9154, ValWLL=0.4649, LR=1.00e-05


Epoch 5/15 Train:   0%|          | 0/11 [00:00<?, ?it/s]

Epoch 5/15 Train:   9%|▉         | 1/11 [00:00<00:03,  2.91it/s]

Epoch 5/15 Train:  27%|██▋       | 3/11 [00:00<00:01,  7.50it/s]

Epoch 5/15 Train:  45%|████▌     | 5/11 [00:00<00:00, 10.86it/s]

Epoch 5/15 Train:  64%|██████▎   | 7/11 [00:00<00:00, 13.26it/s]

Epoch 5/15 Train:  82%|████████▏ | 9/11 [00:00<00:00, 15.00it/s]

Epoch 5/15 Train: 100%|██████████| 11/11 [00:00<00:00, 11.82it/s]




Val:   0%|          | 0/3 [00:00<?, ?it/s]

Val:  33%|███▎      | 1/3 [00:00<00:00,  4.75it/s]

Val: 100%|██████████| 3/3 [00:00<00:00, 10.34it/s]




Epoch 5: TrainBCE=0.8425, ValWLL=0.4621, LR=1.00e-05


Epoch 6/15 Train:   0%|          | 0/11 [00:00<?, ?it/s]

Epoch 6/15 Train:   9%|▉         | 1/11 [00:00<00:03,  3.03it/s]

Epoch 6/15 Train:  27%|██▋       | 3/11 [00:00<00:01,  7.69it/s]

Epoch 6/15 Train:  45%|████▌     | 5/11 [00:00<00:00, 11.08it/s]

Epoch 6/15 Train:  64%|██████▎   | 7/11 [00:00<00:00, 13.50it/s]

Epoch 6/15 Train:  82%|████████▏ | 9/11 [00:00<00:00, 15.25it/s]

Epoch 6/15 Train: 100%|██████████| 11/11 [00:00<00:00, 12.08it/s]




Val:   0%|          | 0/3 [00:00<?, ?it/s]

Val:  33%|███▎      | 1/3 [00:00<00:00,  4.55it/s]

Val: 100%|██████████| 3/3 [00:00<00:00, 10.39it/s]




Epoch 6: TrainBCE=0.7311, ValWLL=0.4618, LR=1.00e-05


Epoch 7/15 Train:   0%|          | 0/11 [00:00<?, ?it/s]

Epoch 7/15 Train:   9%|▉         | 1/11 [00:00<00:03,  3.01it/s]

Epoch 7/15 Train:  27%|██▋       | 3/11 [00:00<00:01,  7.67it/s]

Epoch 7/15 Train:  45%|████▌     | 5/11 [00:00<00:00, 11.07it/s]

Epoch 7/15 Train:  64%|██████▎   | 7/11 [00:00<00:00, 13.52it/s]

Epoch 7/15 Train:  82%|████████▏ | 9/11 [00:00<00:00, 15.27it/s]

Epoch 7/15 Train: 100%|██████████| 11/11 [00:00<00:00, 16.59it/s]

Epoch 7/15 Train: 100%|██████████| 11/11 [00:00<00:00, 11.99it/s]




Val:   0%|          | 0/3 [00:00<?, ?it/s]

Val:  33%|███▎      | 1/3 [00:00<00:00,  4.24it/s]

Val: 100%|██████████| 3/3 [00:00<00:00,  9.60it/s]




Epoch 7: TrainBCE=0.8437, ValWLL=0.4615, LR=1.00e-05


Epoch 8/15 Train:   0%|          | 0/11 [00:00<?, ?it/s]

Epoch 8/15 Train:   9%|▉         | 1/11 [00:00<00:03,  2.96it/s]

Epoch 8/15 Train:  27%|██▋       | 3/11 [00:00<00:01,  7.65it/s]

Epoch 8/15 Train:  45%|████▌     | 5/11 [00:00<00:00, 11.01it/s]

Epoch 8/15 Train:  64%|██████▎   | 7/11 [00:00<00:00, 13.42it/s]

Epoch 8/15 Train:  82%|████████▏ | 9/11 [00:00<00:00, 15.13it/s]

Epoch 8/15 Train: 100%|██████████| 11/11 [00:00<00:00, 11.89it/s]




Val:   0%|          | 0/3 [00:00<?, ?it/s]

Val:  33%|███▎      | 1/3 [00:00<00:00,  4.58it/s]

Val: 100%|██████████| 3/3 [00:00<00:00, 10.15it/s]




Epoch 8: TrainBCE=0.8886, ValWLL=0.4604, LR=1.00e-05


Epoch 9/15 Train:   0%|          | 0/11 [00:00<?, ?it/s]

Epoch 9/15 Train:   9%|▉         | 1/11 [00:00<00:03,  2.92it/s]

Epoch 9/15 Train:  27%|██▋       | 3/11 [00:00<00:01,  7.59it/s]

Epoch 9/15 Train:  45%|████▌     | 5/11 [00:00<00:00, 10.98it/s]

Epoch 9/15 Train:  64%|██████▎   | 7/11 [00:00<00:00, 13.41it/s]

Epoch 9/15 Train:  82%|████████▏ | 9/11 [00:00<00:00, 15.17it/s]

Epoch 9/15 Train: 100%|██████████| 11/11 [00:00<00:00, 16.52it/s]

Epoch 9/15 Train: 100%|██████████| 11/11 [00:00<00:00, 11.81it/s]




Val:   0%|          | 0/3 [00:00<?, ?it/s]

Val:  33%|███▎      | 1/3 [00:00<00:00,  4.51it/s]

Val: 100%|██████████| 3/3 [00:00<00:00, 10.12it/s]




Epoch 9: TrainBCE=0.8428, ValWLL=0.4598, LR=1.00e-05


Epoch 10/15 Train:   0%|          | 0/11 [00:00<?, ?it/s]

Epoch 10/15 Train:   9%|▉         | 1/11 [00:00<00:03,  2.91it/s]

Epoch 10/15 Train:  27%|██▋       | 3/11 [00:00<00:01,  7.50it/s]

Epoch 10/15 Train:  45%|████▌     | 5/11 [00:00<00:00, 10.90it/s]

Epoch 10/15 Train:  64%|██████▎   | 7/11 [00:00<00:00, 13.38it/s]

Epoch 10/15 Train:  82%|████████▏ | 9/11 [00:01<00:00,  8.30it/s]

Epoch 10/15 Train: 100%|██████████| 11/11 [00:01<00:00,  9.04it/s]




Val:   0%|          | 0/3 [00:00<?, ?it/s]

Val:  33%|███▎      | 1/3 [00:00<00:00,  4.66it/s]

Val: 100%|██████████| 3/3 [00:00<00:00, 10.25it/s]




Epoch 10: TrainBCE=0.8041, ValWLL=0.4597, LR=1.00e-05


Epoch 11/15 Train:   0%|          | 0/11 [00:00<?, ?it/s]

Epoch 11/15 Train:   9%|▉         | 1/11 [00:00<00:03,  3.09it/s]

Epoch 11/15 Train:  27%|██▋       | 3/11 [00:00<00:01,  7.71it/s]

Epoch 11/15 Train:  45%|████▌     | 5/11 [00:00<00:00, 11.13it/s]

Epoch 11/15 Train:  64%|██████▎   | 7/11 [00:00<00:00, 13.53it/s]

Epoch 11/15 Train:  82%|████████▏ | 9/11 [00:00<00:00, 15.09it/s]

Epoch 11/15 Train: 100%|██████████| 11/11 [00:00<00:00, 12.00it/s]




Val:   0%|          | 0/3 [00:00<?, ?it/s]

Val:  33%|███▎      | 1/3 [00:00<00:00,  4.51it/s]

Val: 100%|██████████| 3/3 [00:00<00:00, 10.01it/s]




Epoch 11: TrainBCE=0.7612, ValWLL=0.4599, LR=1.00e-05


Epoch 12/15 Train:   0%|          | 0/11 [00:00<?, ?it/s]

Epoch 12/15 Train:   9%|▉         | 1/11 [00:00<00:03,  2.84it/s]

Epoch 12/15 Train:  27%|██▋       | 3/11 [00:00<00:01,  7.42it/s]

Epoch 12/15 Train:  45%|████▌     | 5/11 [00:00<00:00, 10.84it/s]

Epoch 12/15 Train:  64%|██████▎   | 7/11 [00:00<00:00, 13.31it/s]

Epoch 12/15 Train:  82%|████████▏ | 9/11 [00:00<00:00, 15.09it/s]

Epoch 12/15 Train: 100%|██████████| 11/11 [00:00<00:00, 11.80it/s]




Val:   0%|          | 0/3 [00:00<?, ?it/s]

Val:  33%|███▎      | 1/3 [00:00<00:00,  4.56it/s]

Val: 100%|██████████| 3/3 [00:00<00:00, 10.13it/s]




Epoch 12: TrainBCE=0.8026, ValWLL=0.4601, LR=1.00e-05


Epoch 13/15 Train:   0%|          | 0/11 [00:00<?, ?it/s]

Epoch 13/15 Train:   9%|▉         | 1/11 [00:00<00:03,  3.09it/s]

Epoch 13/15 Train:  27%|██▋       | 3/11 [00:00<00:01,  7.88it/s]

Epoch 13/15 Train:  45%|████▌     | 5/11 [00:00<00:00, 11.29it/s]

Epoch 13/15 Train:  64%|██████▎   | 7/11 [00:00<00:00, 13.69it/s]

Epoch 13/15 Train:  82%|████████▏ | 9/11 [00:00<00:00, 15.40it/s]

Epoch 13/15 Train: 100%|██████████| 11/11 [00:00<00:00, 12.16it/s]




Val:   0%|          | 0/3 [00:00<?, ?it/s]

Val:  33%|███▎      | 1/3 [00:00<00:00,  4.55it/s]

Val: 100%|██████████| 3/3 [00:00<00:00, 10.14it/s]




Epoch 13: TrainBCE=0.8166, ValWLL=0.4607, LR=5.00e-06
Early stopping at epoch 13
RegNetY Fold 3 best OOF WLL (val): 0.4597


Final Val for OOF:   0%|          | 0/3 [00:00<?, ?it/s]

Final Val for OOF:  33%|███▎      | 1/3 [00:00<00:00,  4.65it/s]

Final Val for OOF: 100%|██████████| 3/3 [00:00<00:00, 10.30it/s]




RegNetY Fold 3 OOF WLL: 0.4607

=== RegNetY Fold 4/5 ===


Epoch 1/15 Train:   0%|          | 0/11 [00:00<?, ?it/s]

Epoch 1/15 Train:   9%|▉         | 1/11 [00:00<00:03,  3.31it/s]

Epoch 1/15 Train:  45%|████▌     | 5/11 [00:00<00:00, 13.98it/s]

Epoch 1/15 Train:  82%|████████▏ | 9/11 [00:00<00:00, 21.36it/s]

Epoch 1/15 Train: 100%|██████████| 11/11 [00:00<00:00, 17.01it/s]




Val:   0%|          | 0/3 [00:00<?, ?it/s]

Val:  33%|███▎      | 1/3 [00:00<00:00,  3.72it/s]

Val: 100%|██████████| 3/3 [00:00<00:00,  8.60it/s]




Epoch 1: TrainBCE=0.7964, ValWLL=0.4678, LR=2.00e-05


Epoch 2/15 Train:   0%|          | 0/11 [00:00<?, ?it/s]

Epoch 2/15 Train:   9%|▉         | 1/11 [00:00<00:03,  2.97it/s]

Epoch 2/15 Train:  36%|███▋      | 4/11 [00:00<00:00, 10.96it/s]

Epoch 2/15 Train:  73%|███████▎  | 8/11 [00:00<00:00, 19.33it/s]

Epoch 2/15 Train: 100%|██████████| 11/11 [00:00<00:00, 16.21it/s]




Val:   0%|          | 0/3 [00:00<?, ?it/s]

Val:  33%|███▎      | 1/3 [00:00<00:00,  4.47it/s]

Val: 100%|██████████| 3/3 [00:00<00:00, 10.21it/s]




Epoch 2: TrainBCE=0.8354, ValWLL=0.4688, LR=2.00e-05


Epoch 3/15 Train:   0%|          | 0/11 [00:00<?, ?it/s]

Epoch 3/15 Train:   9%|▉         | 1/11 [00:00<00:03,  2.88it/s]

Epoch 3/15 Train:  27%|██▋       | 3/11 [00:00<00:01,  7.21it/s]

Epoch 3/15 Train:  45%|████▌     | 5/11 [00:00<00:00, 10.57it/s]

Epoch 3/15 Train:  64%|██████▎   | 7/11 [00:00<00:00, 13.06it/s]

Epoch 3/15 Train:  82%|████████▏ | 9/11 [00:00<00:00, 14.87it/s]

Epoch 3/15 Train: 100%|██████████| 11/11 [00:00<00:00, 11.61it/s]




Val:   0%|          | 0/3 [00:00<?, ?it/s]

Val:  33%|███▎      | 1/3 [00:00<00:00,  4.27it/s]

Val: 100%|██████████| 3/3 [00:00<00:00,  9.94it/s]




Epoch 3: TrainBCE=0.7657, ValWLL=0.4673, LR=2.00e-05


Epoch 4/15 Train:   0%|          | 0/11 [00:00<?, ?it/s]

Epoch 4/15 Train:   9%|▉         | 1/11 [00:00<00:03,  3.04it/s]

Epoch 4/15 Train:  27%|██▋       | 3/11 [00:00<00:01,  7.75it/s]

Epoch 4/15 Train:  45%|████▌     | 5/11 [00:00<00:00, 11.15it/s]

Epoch 4/15 Train:  64%|██████▎   | 7/11 [00:00<00:00, 13.56it/s]

Epoch 4/15 Train:  82%|████████▏ | 9/11 [00:00<00:00, 15.26it/s]

Epoch 4/15 Train: 100%|██████████| 11/11 [00:00<00:00, 12.06it/s]




Val:   0%|          | 0/3 [00:00<?, ?it/s]

Val:  33%|███▎      | 1/3 [00:00<00:00,  4.54it/s]

Val: 100%|██████████| 3/3 [00:00<00:00, 10.28it/s]




Epoch 4: TrainBCE=0.7697, ValWLL=0.4689, LR=2.00e-05


Epoch 5/15 Train:   0%|          | 0/11 [00:00<?, ?it/s]

Epoch 5/15 Train:   9%|▉         | 1/11 [00:00<00:03,  3.11it/s]

Epoch 5/15 Train:  27%|██▋       | 3/11 [00:00<00:01,  7.81it/s]

Epoch 5/15 Train:  45%|████▌     | 5/11 [00:00<00:00, 11.20it/s]

Epoch 5/15 Train:  64%|██████▎   | 7/11 [00:00<00:00, 13.60it/s]

Epoch 5/15 Train:  82%|████████▏ | 9/11 [00:00<00:00, 15.31it/s]

Epoch 5/15 Train: 100%|██████████| 11/11 [00:00<00:00, 12.08it/s]




Val:   0%|          | 0/3 [00:00<?, ?it/s]

Val:  33%|███▎      | 1/3 [00:00<00:00,  4.57it/s]

Val: 100%|██████████| 3/3 [00:00<00:00, 10.39it/s]




Epoch 5: TrainBCE=0.7886, ValWLL=0.4683, LR=2.00e-05


Epoch 6/15 Train:   0%|          | 0/11 [00:00<?, ?it/s]

Epoch 6/15 Train:   9%|▉         | 1/11 [00:00<00:03,  2.88it/s]

Epoch 6/15 Train:  27%|██▋       | 3/11 [00:00<00:01,  7.59it/s]

Epoch 6/15 Train:  45%|████▌     | 5/11 [00:00<00:00, 11.00it/s]

Epoch 6/15 Train:  64%|██████▎   | 7/11 [00:00<00:00, 13.44it/s]

Epoch 6/15 Train:  82%|████████▏ | 9/11 [00:00<00:00, 15.17it/s]

Epoch 6/15 Train: 100%|██████████| 11/11 [00:00<00:00, 12.01it/s]




Val:   0%|          | 0/3 [00:00<?, ?it/s]

Val:  33%|███▎      | 1/3 [00:00<00:00,  4.52it/s]

Val: 100%|██████████| 3/3 [00:00<00:00, 10.10it/s]




Epoch 6: TrainBCE=0.8572, ValWLL=0.4673, LR=2.00e-05
Early stopping at epoch 6
RegNetY Fold 4 best OOF WLL (val): 0.4673


Final Val for OOF:   0%|          | 0/3 [00:00<?, ?it/s]

Final Val for OOF:  33%|███▎      | 1/3 [00:00<00:00,  4.30it/s]

Final Val for OOF: 100%|██████████| 3/3 [00:00<00:00,  9.65it/s]




RegNetY Fold 4 OOF WLL: 0.4673

=== RegNetY Fold 5/5 ===


Epoch 1/15 Train:   0%|          | 0/11 [00:00<?, ?it/s]

Epoch 1/15 Train:   9%|▉         | 1/11 [00:00<00:03,  3.16it/s]

Epoch 1/15 Train:  36%|███▋      | 4/11 [00:00<00:00, 11.18it/s]

Epoch 1/15 Train:  73%|███████▎  | 8/11 [00:00<00:00, 19.68it/s]

Epoch 1/15 Train: 100%|██████████| 11/11 [00:00<00:00, 16.51it/s]




Val:   0%|          | 0/3 [00:00<?, ?it/s]

Val:  33%|███▎      | 1/3 [00:00<00:00,  4.25it/s]

Val: 100%|██████████| 3/3 [00:00<00:00,  9.43it/s]




Epoch 1: TrainBCE=0.8339, ValWLL=0.4836, LR=2.00e-05


Epoch 2/15 Train:   0%|          | 0/11 [00:00<?, ?it/s]

Epoch 2/15 Train:   9%|▉         | 1/11 [00:00<00:03,  3.31it/s]

Epoch 2/15 Train:  45%|████▌     | 5/11 [00:00<00:00, 14.32it/s]

Epoch 2/15 Train:  82%|████████▏ | 9/11 [00:00<00:00, 21.93it/s]

Epoch 2/15 Train: 100%|██████████| 11/11 [00:00<00:00, 17.52it/s]




Val:   0%|          | 0/3 [00:00<?, ?it/s]

Val:  33%|███▎      | 1/3 [00:00<00:00,  4.61it/s]

Val: 100%|██████████| 3/3 [00:00<00:00, 10.44it/s]




Epoch 2: TrainBCE=0.8859, ValWLL=0.4873, LR=2.00e-05


Epoch 3/15 Train:   0%|          | 0/11 [00:00<?, ?it/s]

Epoch 3/15 Train:   9%|▉         | 1/11 [00:00<00:03,  3.09it/s]

Epoch 3/15 Train:  27%|██▋       | 3/11 [00:00<00:01,  7.75it/s]

Epoch 3/15 Train:  45%|████▌     | 5/11 [00:00<00:00, 11.14it/s]

Epoch 3/15 Train:  64%|██████▎   | 7/11 [00:00<00:00, 13.55it/s]

Epoch 3/15 Train:  82%|████████▏ | 9/11 [00:00<00:00, 15.34it/s]

Epoch 3/15 Train: 100%|██████████| 11/11 [00:00<00:00, 12.16it/s]




Val:   0%|          | 0/3 [00:00<?, ?it/s]

Val:  33%|███▎      | 1/3 [00:00<00:00,  4.43it/s]

Val: 100%|██████████| 3/3 [00:00<00:00, 10.08it/s]




Epoch 3: TrainBCE=0.8435, ValWLL=0.4844, LR=2.00e-05


Epoch 4/15 Train:   0%|          | 0/11 [00:00<?, ?it/s]

Epoch 4/15 Train:   9%|▉         | 1/11 [00:00<00:03,  3.06it/s]

Epoch 4/15 Train:  27%|██▋       | 3/11 [00:00<00:01,  7.81it/s]

Epoch 4/15 Train:  45%|████▌     | 5/11 [00:00<00:00, 11.22it/s]

Epoch 4/15 Train:  64%|██████▎   | 7/11 [00:00<00:00, 13.64it/s]

Epoch 4/15 Train:  82%|████████▏ | 9/11 [00:00<00:00, 15.37it/s]

Epoch 4/15 Train: 100%|██████████| 11/11 [00:00<00:00, 12.05it/s]




Val:   0%|          | 0/3 [00:00<?, ?it/s]

Val:  33%|███▎      | 1/3 [00:00<00:00,  4.63it/s]

Val: 100%|██████████| 3/3 [00:00<00:00, 10.53it/s]




Epoch 4: TrainBCE=0.8327, ValWLL=0.4853, LR=1.00e-05


Epoch 5/15 Train:   0%|          | 0/11 [00:00<?, ?it/s]

Epoch 5/15 Train:   9%|▉         | 1/11 [00:00<00:03,  3.09it/s]

Epoch 5/15 Train:  27%|██▋       | 3/11 [00:00<00:01,  7.80it/s]

Epoch 5/15 Train:  45%|████▌     | 5/11 [00:00<00:00, 11.18it/s]

Epoch 5/15 Train:  64%|██████▎   | 7/11 [00:00<00:00, 13.56it/s]

Epoch 5/15 Train:  82%|████████▏ | 9/11 [00:00<00:00, 15.24it/s]

Epoch 5/15 Train: 100%|██████████| 11/11 [00:00<00:00, 12.12it/s]




Val:   0%|          | 0/3 [00:00<?, ?it/s]

Val:  33%|███▎      | 1/3 [00:00<00:00,  4.31it/s]

Val: 100%|██████████| 3/3 [00:00<00:00,  9.71it/s]




Epoch 5: TrainBCE=0.8430, ValWLL=0.4851, LR=1.00e-05


Epoch 6/15 Train:   0%|          | 0/11 [00:00<?, ?it/s]

Epoch 6/15 Train:   9%|▉         | 1/11 [00:00<00:03,  2.89it/s]

Epoch 6/15 Train:  27%|██▋       | 3/11 [00:00<00:01,  7.65it/s]

Epoch 6/15 Train:  45%|████▌     | 5/11 [00:00<00:00, 11.04it/s]

Epoch 6/15 Train:  64%|██████▎   | 7/11 [00:00<00:00, 13.48it/s]

Epoch 6/15 Train:  82%|████████▏ | 9/11 [00:00<00:00, 15.22it/s]

Epoch 6/15 Train: 100%|██████████| 11/11 [00:00<00:00, 11.94it/s]




Val:   0%|          | 0/3 [00:00<?, ?it/s]

Val:  33%|███▎      | 1/3 [00:00<00:00,  4.54it/s]

Val: 100%|██████████| 3/3 [00:00<00:00, 10.22it/s]




Epoch 6: TrainBCE=0.7752, ValWLL=0.4840, LR=1.00e-05
Early stopping at epoch 6
RegNetY Fold 5 best OOF WLL (val): 0.4836


Final Val for OOF:   0%|          | 0/3 [00:00<?, ?it/s]

Final Val for OOF:  33%|███▎      | 1/3 [00:00<00:00,  4.68it/s]

Final Val for OOF: 100%|██████████| 3/3 [00:00<00:00, 10.21it/s]

RegNetY Fold 5 OOF WLL: 0.4840

RegNetY CV Mean WLL: 0.4713 +/- 0.0105





In [19]:
# Compute OOF logits for RegNetY temperature scaling
import torch
import numpy as np
import pandas as pd
from tqdm import tqdm
from torch.utils.data import DataLoader

# Reuse MIPDataset and val_transform from earlier cells
full_train_df = pd.read_csv('data/train_mips.csv')
label_cols = ['C1','C2','C3','C4','C5','C6','C7']
y_full = full_train_df[label_cols].values
mip_dir = 'data/mips/train'
n_folds = 5

from iterstrat.ml_stratifiers import MultilabelStratifiedKFold
skf = MultilabelStratifiedKFold(n_splits=n_folds, shuffle=True, random_state=456)
splits = list(skf.split(full_train_df, y_full))

def build_regnet(num_classes=7):
    import timm
    m = timm.create_model('regnety_004', pretrained=False, num_classes=num_classes, in_chans=3,
                          drop_rate=0.4, drop_path_rate=0.2)
    return m.to(device).eval()

oof_logits_full_reg = np.zeros((len(full_train_df), 7), dtype=np.float32)

for fold, (train_idx, val_idx) in enumerate(splits, 1):
    print(f'Computing OOF logits for RegNetY Fold {fold}')
    val_df_f = full_train_df.iloc[val_idx].reset_index(drop=True)
    val_ds = MIPDataset(val_df_f, mip_dir, label_cols, val_transform)
    val_loader = DataLoader(val_ds, batch_size=16, shuffle=False, num_workers=4, pin_memory=True)

    ckpt = f'fold_{fold}_regnet.pth'
    model = build_regnet(7)
    sd = torch.load(ckpt, map_location='cpu')
    model.load_state_dict(sd, strict=True)

    model.eval()
    fold_logits = []
    with torch.no_grad():
        for images, _ in tqdm(val_loader, desc=f'Fold {fold} OOF logits reg'):
            images = images.to(device)
            logits = model(images)
            fold_logits.append(logits.cpu().numpy())
    fold_logits = np.concatenate(fold_logits, axis=0)
    oof_logits_full_reg[val_idx] = fold_logits
    print(f'Fold {fold} OOF logits reg shape: {fold_logits.shape}')

# Save OOF logits reg
oof_logits_df_reg = full_train_df[['StudyInstanceUID']].copy()
for i, col in enumerate(label_cols):
    oof_logits_df_reg[col] = oof_logits_full_reg[:, i]
oof_logits_df_reg.to_csv('oof_logits_regnet.csv', index=False)
print('OOF logits reg saved to oof_logits_regnet.csv')

# Quick check: Compute probs from logits and verify matches oof_predictions_regnet.csv
oof_probs_from_logits_reg = 1 / (1 + np.exp(-oof_logits_full_reg))
oof_df_reg = pd.read_csv('oof_predictions_regnet.csv')
oof_probs_saved_reg = oof_df_reg[label_cols].values
print(f'Probs reg match (max diff): {np.max(np.abs(oof_probs_from_logits_reg - oof_probs_saved_reg)):.6f}')

Computing OOF logits for RegNetY Fold 1


Fold 1 OOF logits reg:   0%|          | 0/3 [00:00<?, ?it/s]

Fold 1 OOF logits reg:  33%|███▎      | 1/3 [00:00<00:00,  4.38it/s]

Fold 1 OOF logits reg: 100%|██████████| 3/3 [00:00<00:00, 10.17it/s]




Fold 1 OOF logits reg shape: (41, 7)
Computing OOF logits for RegNetY Fold 2


Fold 2 OOF logits reg:   0%|          | 0/3 [00:00<?, ?it/s]

Fold 2 OOF logits reg:  33%|███▎      | 1/3 [00:00<00:00,  4.35it/s]

Fold 2 OOF logits reg: 100%|██████████| 3/3 [00:00<00:00,  9.64it/s]




Fold 2 OOF logits reg shape: (40, 7)
Computing OOF logits for RegNetY Fold 3


Fold 3 OOF logits reg:   0%|          | 0/3 [00:00<?, ?it/s]

Fold 3 OOF logits reg:  33%|███▎      | 1/3 [00:00<00:00,  4.59it/s]

Fold 3 OOF logits reg: 100%|██████████| 3/3 [00:00<00:00, 10.41it/s]




Fold 3 OOF logits reg shape: (41, 7)
Computing OOF logits for RegNetY Fold 4


Fold 4 OOF logits reg:   0%|          | 0/3 [00:00<?, ?it/s]

Fold 4 OOF logits reg:  33%|███▎      | 1/3 [00:00<00:00,  4.30it/s]

Fold 4 OOF logits reg: 100%|██████████| 3/3 [00:00<00:00,  9.73it/s]




Fold 4 OOF logits reg shape: (40, 7)
Computing OOF logits for RegNetY Fold 5


Fold 5 OOF logits reg:   0%|          | 0/3 [00:00<?, ?it/s]

Fold 5 OOF logits reg:  33%|███▎      | 1/3 [00:00<00:00,  4.63it/s]

Fold 5 OOF logits reg: 100%|██████████| 3/3 [00:00<00:00, 10.41it/s]

Fold 5 OOF logits reg shape: (40, 7)
OOF logits reg saved to oof_logits_regnet.csv
Probs reg match (max diff): 0.000000





In [28]:
import torch, numpy as np, pandas as pd
from torch.utils.data import DataLoader
from tqdm import tqdm
from iterstrat.ml_stratifiers import MultilabelStratifiedKFold
from scipy.special import expit as sigmoid

full_train_df = pd.read_csv('data/train_mips.csv')
label_cols = ['C1','C2','C3','C4','C5','C6','C7']
y_full = full_train_df[label_cols].values
mip_dir = 'data/mips/train'
device = torch.device('cuda')

def oof_tta(model_builder, ckpt_tmpl, n_folds, seed):
    skf = MultilabelStratifiedKFold(n_splits=n_folds, shuffle=True, random_state=seed)
    splits = list(skf.split(full_train_df, y_full))
    oof_logits = np.zeros((len(full_train_df), 7), dtype=np.float32)
    for fold,(tr_idx, val_idx) in enumerate(splits, 1):
        val_df = full_train_df.iloc[val_idx].reset_index(drop=True)
        val_ds = MIPDataset(val_df, mip_dir, label_cols, val_transform)
        val_loader = DataLoader(val_ds, batch_size=32, shuffle=False, num_workers=4, pin_memory=True)

        model = model_builder(7).to(device).eval()
        sd = torch.load(ckpt_tmpl.format(fold), map_location='cpu')
        model.load_state_dict(sd, strict=True)

        fold_logits = []
        with torch.no_grad():
            for images, _ in tqdm(val_loader, desc=f'TTA OOF fold {fold}'):
                images = images.to(device)
                with torch.cuda.amp.autocast():
                    lo = model(images)
                    lo_f = model(images.flip(-1))  # HFlip along width
                    lo = 0.5*(lo + lo_f)
                fold_logits.append(lo.float().cpu().numpy())
        fold_logits = np.concatenate(fold_logits, axis=0)
        oof_logits[val_idx] = fold_logits
    return oof_logits

def build_convnext_v1(num_classes=7):
    import timm
    return timm.create_model('convnext_tiny', pretrained=False, num_classes=num_classes, in_chans=3,
                             drop_rate=0.3, drop_path_rate=0.1)
def build_convnext_v2(num_classes=7):
    import timm
    return timm.create_model('convnext_tiny', pretrained=False, num_classes=num_classes, in_chans=3,
                             drop_rate=0.4, drop_path_rate=0.2)
def build_regnet(num_classes=7):
    import timm
    return timm.create_model('regnety_004', pretrained=False, num_classes=num_classes, in_chans=3,
                             drop_rate=0.4, drop_path_rate=0.2)

print('=== ConvNeXt v1 TTA OOF ===')
v1_tta = oof_tta(build_convnext_v1, 'fold_{}_convnext.pth', n_folds=5, seed=42)
np.save('oof_logits_convnext_tta.npy', v1_tta)

print('=== ConvNeXt v2 TTA OOF ===')
v2_tta = oof_tta(build_convnext_v2, 'fold_{}_convnext_v2.pth', n_folds=5, seed=123)
np.save('oof_logits_convnext_v2_tta.npy', v2_tta)

print('=== RegNet TTA OOF ===')
reg_tta = oof_tta(build_regnet, 'fold_{}_regnet.pth', n_folds=5, seed=456)
np.save('oof_logits_regnet_tta.npy', reg_tta)

=== ConvNeXt v1 TTA OOF ===


TTA OOF fold 1:   0%|          | 0/2 [00:00<?, ?it/s]

TTA OOF fold 1:  50%|█████     | 1/2 [00:00<00:00,  2.41it/s]

TTA OOF fold 1: 100%|██████████| 2/2 [00:00<00:00,  4.07it/s]




TTA OOF fold 2:   0%|          | 0/2 [00:00<?, ?it/s]

TTA OOF fold 2:  50%|█████     | 1/2 [00:00<00:00,  2.69it/s]

TTA OOF fold 2: 100%|██████████| 2/2 [00:00<00:00,  4.43it/s]




TTA OOF fold 3:   0%|          | 0/2 [00:00<?, ?it/s]

TTA OOF fold 3:  50%|█████     | 1/2 [00:00<00:00,  2.68it/s]

TTA OOF fold 3: 100%|██████████| 2/2 [00:00<00:00,  4.38it/s]




TTA OOF fold 4:   0%|          | 0/2 [00:00<?, ?it/s]

TTA OOF fold 4:  50%|█████     | 1/2 [00:00<00:00,  2.61it/s]

TTA OOF fold 4: 100%|██████████| 2/2 [00:00<00:00,  4.30it/s]




TTA OOF fold 5:   0%|          | 0/2 [00:00<?, ?it/s]

TTA OOF fold 5:  50%|█████     | 1/2 [00:00<00:00,  2.77it/s]

TTA OOF fold 5: 100%|██████████| 2/2 [00:00<00:00,  4.62it/s]




=== ConvNeXt v2 TTA OOF ===


TTA OOF fold 1:   0%|          | 0/2 [00:00<?, ?it/s]

TTA OOF fold 1:  50%|█████     | 1/2 [00:00<00:00,  2.73it/s]

TTA OOF fold 1: 100%|██████████| 2/2 [00:00<00:00,  4.51it/s]




TTA OOF fold 2:   0%|          | 0/2 [00:00<?, ?it/s]

TTA OOF fold 2:  50%|█████     | 1/2 [00:00<00:00,  2.77it/s]

TTA OOF fold 3:  50%|█████     | 1/2 [00:00<00:00,  2.69it/s]

TTA OOF fold 3: 100%|██████████| 2/2 [00:00<00:00,  4.46it/s]




TTA OOF fold 4:   0%|          | 0/2 [00:00<?, ?it/s]

TTA OOF fold 4:  50%|█████     | 1/2 [00:00<00:00,  2.76it/s]

TTA OOF fold 4: 100%|██████████| 2/2 [00:00<00:00,  4.61it/s]




TTA OOF fold 5:   0%|          | 0/2 [00:00<?, ?it/s]

TTA OOF fold 5:  50%|█████     | 1/2 [00:00<00:00,  2.75it/s]

TTA OOF fold 5: 100%|██████████| 2/2 [00:00<00:00,  4.60it/s]




=== RegNet TTA OOF ===


TTA OOF fold 1:   0%|          | 0/2 [00:00<?, ?it/s]

TTA OOF fold 1:  50%|█████     | 1/2 [00:00<00:00,  3.03it/s]

TTA OOF fold 1: 100%|██████████| 2/2 [00:00<00:00,  4.75it/s]




TTA OOF fold 2:   0%|          | 0/2 [00:00<?, ?it/s]

TTA OOF fold 2:  50%|█████     | 1/2 [00:00<00:00,  3.32it/s]

TTA OOF fold 2: 100%|██████████| 2/2 [00:00<00:00,  5.05it/s]




TTA OOF fold 3:   0%|          | 0/2 [00:00<?, ?it/s]

TTA OOF fold 3:  50%|█████     | 1/2 [00:00<00:00,  3.42it/s]

TTA OOF fold 3: 100%|██████████| 2/2 [00:00<00:00,  5.62it/s]




TTA OOF fold 4:   0%|          | 0/2 [00:00<?, ?it/s]

TTA OOF fold 4:  50%|█████     | 1/2 [00:00<00:00,  3.34it/s]

TTA OOF fold 4: 100%|██████████| 2/2 [00:00<00:00,  5.60it/s]




TTA OOF fold 5:   0%|          | 0/2 [00:00<?, ?it/s]

TTA OOF fold 5:  50%|█████     | 1/2 [00:00<00:00,  3.37it/s]

TTA OOF fold 5: 100%|██████████| 2/2 [00:00<00:00,  5.65it/s]




In [29]:
import numpy as np, pandas as pd
from scipy.optimize import minimize
from scipy.special import expit as sigmoid
from sklearn.metrics import log_loss

label_cols = ['C1','C2','C3','C4','C5','C6','C7']
ids = pd.read_csv('data/train_mips.csv')[['StudyInstanceUID']]
train = pd.read_csv('train.csv')
y_df = ids.merge(train, on='StudyInstanceUID', how='left')
y7 = y_df[label_cols].values.astype(float)
y_overall = y7.max(axis=1, keepdims=True)
y8 = np.hstack([y7, y_overall])

v1 = np.load('oof_logits_convnext_tta.npy')
v2 = np.load('oof_logits_convnext_v2_tta.npy')
reg = np.load('oof_logits_regnet_tta.npy')
stack = np.stack([v1, v2, reg], axis=2)  # [N,7,3]

# Per-class weights with RegNet cap<=0.3
bounds = [(0,1),(0,1),(0,0.3)]
constraints = [{'type':'eq','fun': lambda w: np.sum(w)-1.0}]
w_opt = np.zeros((7,3), dtype=float)

def obj_w(w, logits3, y):
    blend = np.sum(logits3 * w, axis=1)
    return log_loss(y, sigmoid(blend), labels=[0,1])

for c in range(7):
    res = minimize(obj_w, x0=np.array([0.5,0.5,0.0]), args=(stack[:,c,:], y7[:,c]),
                   method='SLSQP', bounds=bounds, constraints=constraints)
    w = res.x
    w_opt[c] = w
np.save('weights_threeway_tta.npy', w_opt)

# Blend logits and choose overall rule on UNCALIBRATED probs
blend = np.sum(stack * w_opt[None,:,:], axis=2)      # [N,7]
p7_uncal = sigmoid(blend)
p_overall_max = p7_uncal.max(axis=1, keepdims=True)
p_overall_union = 1 - np.prod(1 - p7_uncal, axis=1, keepdims=True)

def wll(y_true8, p8):
    p8 = np.clip(p8, 1e-6, 1-1e-6)
    losses = [log_loss(y_true8[:,i], p8[:,i]) for i in range(8)]
    return np.average(losses, weights=np.array([1]*7+[2], float))

wll_max = wll(y8, np.hstack([p7_uncal, p_overall_max]))
wll_union = wll(y8, np.hstack([p7_uncal, p_overall_union]))
rule = 'union' if wll_union < wll_max else 'max'
open('overall_rule_tta.txt','w').write(rule)
print(f'TTA OOF WLL max={wll_max:.4f}, union={wll_union:.4f}, chosen={rule}')

# Refit per-class temperatures on blended logits (7 only)
def fit_temp(logits, y):
    def loss(z):
        t = np.exp(z)
        p = sigmoid(logits / t)
        p = np.clip(p, 1e-6, 1-1e-6)
        return log_loss(y, p, labels=[0,1])
    r = minimize(loss, x0=0.0, method='L-BFGS-B', bounds=[(-3,3)])
    return float(np.exp(r.x[0]))
temps = np.array([fit_temp(blend[:,k], y7[:,k]) for k in range(7)], dtype=np.float32)
np.save('temperatures_weighted_tta.npy', temps)

# Final TTA OOF check (7 calibrated + overall from UNCALIBRATED as per rule)
p7_cal = sigmoid(blend / temps)
p_overall = p_overall_union if rule=='union' else p_overall_max
p8_final = np.hstack([p7_cal, p_overall])
print(f'Final TTA OOF WLL: {wll(y8, p8_final):.4f}')


TTA OOF WLL max=0.4507, union=0.4468, chosen=union
Final TTA OOF WLL: 0.4423


In [31]:
# NEW CELL: Learned patient_overall from blended TTA OOF logits + alpha blend with base
import numpy as np, pandas as pd, joblib
from sklearn.linear_model import LogisticRegression
from iterstrat.ml_stratifiers import MultilabelStratifiedKFold
from scipy.special import expit as sigmoid
from scipy.optimize import minimize_scalar
from sklearn.metrics import log_loss

label_cols = ['C1','C2','C3','C4','C5','C6','C7']

# Labels aligned to OOF rows
ids = pd.read_csv('data/train_mips.csv')[['StudyInstanceUID']]
train = pd.read_csv('train.csv')
y_df = ids.merge(train, on='StudyInstanceUID', how='left')
y7 = y_df[label_cols].values.astype(float)
y_overall = y7.max(axis=1).astype(int)
y8 = np.hstack([y7, y_overall[:,None]])

# Load TTA OOF logits and artifacts
v1 = np.load('oof_logits_convnext_tta.npy')       # [N,7]
v2 = np.load('oof_logits_convnext_v2_tta.npy')    # [N,7]
reg= np.load('oof_logits_regnet_tta.npy')         # [N,7]
W  = np.load('weights_threeway_tta.npy')          # [7,3]
T7 = np.load('temperatures_weighted_tta.npy')     # [7]
rule = open('overall_rule_tta.txt').read().strip()

# Blended logits (features)
stack = np.stack([v1, v2, reg], axis=2)           # [N,7,3]
X = np.sum(stack * W[None,:,:], axis=2)           # [N,7] blended logits

# Probs for later WLL computation (keep parity: overall base from UNCALIBRATED)
p7_uncal = sigmoid(X)
p7_cal   = sigmoid(X / T7)

def wll_8(y8_true, p7_calib, poverall):
    p8 = np.hstack([p7_calib, poverall.reshape(-1,1)])
    p8 = np.clip(p8, 1e-6, 1-1e-6)
    losses = [log_loss(y8_true[:,i], p8[:,i]) for i in range(8)]
    return np.average(losses, weights=np.array([1]*7+[2], float))

# 5-fold CV to get strictly OOF LR predictions
skf = MultilabelStratifiedKFold(n_splits=5, shuffle=True, random_state=42)
p_lr_oof = np.zeros(len(X), dtype=np.float32)
for tr_idx, va_idx in skf.split(X, y7):
    lr = LogisticRegression(penalty='l2', C=1.0, class_weight='balanced',
                            solver='lbfgs', max_iter=1000, random_state=42)
    lr.fit(X[tr_idx], y_overall[tr_idx])
    p_lr_oof[va_idx] = lr.predict_proba(X[va_idx])[:,1].astype(np.float32)

# Base overall from uncalibrated probs
base_max   = p7_uncal.max(axis=1)
base_union = 1.0 - np.prod(1.0 - p7_uncal, axis=1)
base_used  = base_union if rule=='union' else base_max

# Tune alpha on full OOF (minimize weighted 8-class log loss)
def obj(alpha):
    poverall = alpha * p_lr_oof + (1.0 - alpha) * base_used
    return wll_8(y8, p7_cal, poverall)

res = minimize_scalar(obj, bounds=(0,1), method='bounded')
alpha = float(res.x)

# Refit LR on all data for inference
lr_full = LogisticRegression(penalty='l2', C=1.0, class_weight='balanced',
                             solver='lbfgs', max_iter=1000, random_state=42)
lr_full.fit(X, y_overall)

# Save artifacts
joblib.dump(lr_full, 'overall_regressor_tta.pkl')
np.save('overall_lr_alpha_tta.npy', np.array([alpha], dtype=np.float32))
with open('overall_lr_base_tta.txt','w') as f: f.write(rule)

print(f'Overall LR stacker trained. base={rule}, alpha={alpha:.4f}, OOF WLL={obj(alpha):.4f}')
print('Saved: overall_regressor_tta.pkl, overall_lr_alpha_tta.npy, overall_lr_base_tta.txt')

Overall LR stacker trained. base=union, alpha=1.0000, OOF WLL=0.4211
Saved: overall_regressor_tta.pkl, overall_lr_alpha_tta.npy, overall_lr_base_tta.txt


In [34]:
# 4-Way Ensemble: Add Swin-T OOF Logits to Three-Way TTA Blend (cap Swin <=0.4)
import numpy as np, pandas as pd
from scipy.optimize import minimize
from scipy.special import expit as sigmoid
from sklearn.metrics import log_loss
from sklearn.linear_model import LogisticRegression
import joblib
from iterstrat.ml_stratifiers import MultilabelStratifiedKFold
from scipy.optimize import minimize_scalar

label_cols = ['C1','C2','C3','C4','C5','C6','C7']

# Load aligned labels
ids = pd.read_csv('data/train_mips.csv')[['StudyInstanceUID']]
train = pd.read_csv('train.csv')
y_df = ids.merge(train, on='StudyInstanceUID', how='left')
y7 = y_df[label_cols].values.astype(float)
y_overall = y7.max(axis=1).astype(int)
y8 = np.hstack([y7, y_overall[:,None]])

def wll(y_true8, p8):
    p8 = np.clip(p8, 1e-6, 1-1e-6)
    losses = [log_loss(y_true8[:,i], p8[:,i]) for i in range(8)]
    return np.average(losses, weights=np.array([1]*7+[2], float))

# Load TTA OOF logits [N,7] for three models + Swin
v1 = np.load('oof_logits_convnext_tta.npy')
v2 = np.load('oof_logits_convnext_v2_tta.npy')
reg = np.load('oof_logits_regnet_tta.npy')
swin = np.load('oof_logits_swin_tta.npy')
stack4 = np.stack([v1, v2, reg, swin], axis=2)  # [N,7,4]

# Optimize per-class weights (sum=1, >=0, Swin<=0.4) via SLSQP on BCE
bounds = [(0,1),(0,1),(0,1),(0,0.4)]
constraints = [{'type':'eq','fun': lambda w: np.sum(w)-1.0}]
w_opt4 = np.zeros((7,4), dtype=float)

def obj_w(w, logits4, y):
    blend = np.sum(logits4 * w, axis=1)
    return log_loss(y, sigmoid(blend), labels=[0,1])

for c in range(7):
    res = minimize(obj_w, x0=np.array([0.25]*4),
                   args=(stack4[:,c,:], y7[:,c]),
                   method='SLSQP', bounds=bounds, constraints=constraints,
                   options={'disp':False, 'maxiter':200})
    w_opt4[c] = res.x
    print(f'{label_cols[c]} weights (v1,v2,reg,swin): {res.x.round(3)}')

np.save('weights_fourway_tta.npy', w_opt4)

# Blend logits [N,7]
blend4 = np.sum(stack4 * w_opt4[None,:,:], axis=2)
p7_uncal4 = sigmoid(blend4)

# Overall rule on uncalibrated (union vs max)
p_overall_max4 = p7_uncal4.max(axis=1, keepdims=True)
p_overall_union4 = 1 - np.prod(1 - p7_uncal4, axis=1, keepdims=True)
p8_max4 = np.hstack([p7_uncal4, p_overall_max4])
p8_union4 = np.hstack([p7_uncal4, p_overall_union4])
wll_max4 = wll(y8, p8_max4)
wll_union4 = wll(y8, p8_union4)
print(f'4-way uncal WLL max={wll_max4:.4f}, union={wll_union4:.4f}')
rule4 = 'union' if wll_union4 < wll_max4 else 'max'
with open('overall_rule_fourway_tta.txt','w') as f: f.write(rule4)
p_overall_base4 = p_overall_union4 if rule4=='union' else p_overall_max4

# Refit temperatures on 4-way blend (7 classes)
def fit_temp(logits, y):
    def loss(z):
        t = np.exp(z)
        return log_loss(y, sigmoid(logits/t), labels=[0,1])
    from scipy.optimize import minimize
    r = minimize(loss, x0=0.0, method='L-BFGS-B', bounds=[(-3,3)])
    return float(np.exp(r.x[0]))
temps4 = np.array([fit_temp(blend4[:,k], y7[:,k]) for k in range(7)], dtype=np.float32)
np.save('temperatures_fourway_tta.npy', temps4)

# Calibrated probs
p7_cal4 = sigmoid(blend4 / temps4)
p8_cal4 = np.hstack([p7_cal4, p_overall_base4])
wll_cal4 = wll(y8, p8_cal4)
print(f'4-way calibrated OOF WLL: {wll_cal4:.4f}')

# Learned overall stacker: LR on 7 raw blended logits (not cal), tune alpha
# Features for LR: 7 raw blended logits
X_lr = blend4  # [N,7] raw blended logits

# 5-fold CV for OOF LR preds
skf_lr = MultilabelStratifiedKFold(n_splits=5, shuffle=True, random_state=42)
p_lr_oof = np.zeros(len(X_lr), dtype=np.float32)
for tr_idx, va_idx in skf_lr.split(X_lr, y7):
    lr = LogisticRegression(penalty='l2', C=1.0, class_weight='balanced',
                            solver='lbfgs', max_iter=1000, random_state=42)
    lr.fit(X_lr[tr_idx], y_overall[tr_idx])
    p_lr_oof[va_idx] = lr.predict_proba(X_lr[va_idx])[:,1].astype(np.float32)

# Tune alpha: min WLL with alpha*LR + (1-alpha)*base
def obj_alpha(alpha):
    p_overall = alpha * p_lr_oof + (1-alpha) * p_overall_base4.squeeze()
    p8_alpha = np.hstack([p7_cal4, p_overall.reshape(-1,1)])
    return wll(y8, p8_alpha)
res_alpha = minimize_scalar(obj_alpha, bounds=(0,1), method='bounded')
alpha4 = float(res_alpha.x)
print(f'4-way LR alpha: {alpha4:.4f}, WLL: {res_alpha.fun:.4f}')

# Refit LR on full data
lr_full4 = LogisticRegression(penalty='l2', C=1.0, class_weight='balanced',
                              solver='lbfgs', max_iter=1000, random_state=42)
lr_full4.fit(X_lr, y_overall)

# Save 4-way artifacts
joblib.dump(lr_full4, 'overall_regressor_fourway_tta.pkl')
np.save('overall_lr_alpha_fourway_tta.npy', np.array([alpha4]))
with open('overall_lr_base_fourway_tta.txt','w') as f: f.write(rule4)
np.save('oof_logits_fourway_blend.npy', blend4)  # for later use

# Final 4-way OOF WLL
p_overall_final4 = alpha4 * lr_full4.predict_proba(X_lr)[:,1] + (1-alpha4) * p_overall_base4.squeeze()
p8_final4 = np.hstack([p7_cal4, p_overall_final4.reshape(-1,1)])
final_wll4 = wll(y8, p8_final4)
print(f'Final 4-way OOF WLL: {final_wll4:.4f}')

# Save 4-way OOF probs for reference
oof_df4 = y_df[['StudyInstanceUID']].copy()
oof_df4[label_cols] = p7_cal4
oof_df4['patient_overall'] = p8_final4[:,-1]
oof_df4.to_csv('oof_predictions_fourway.csv', index=False)
print('4-way ensemble completed. Target <0.40 WLL achieved? Update inference next.')

C1 weights (v1,v2,reg,swin): [0.    0.    0.859 0.141]
C2 weights (v1,v2,reg,swin): [0.   0.   0.85 0.15]
C3 weights (v1,v2,reg,swin): [0.    0.142 0.858 0.   ]
C4 weights (v1,v2,reg,swin): [0.    0.    0.889 0.111]
C5 weights (v1,v2,reg,swin): [0.015 0.    0.985 0.   ]


C6 weights (v1,v2,reg,swin): [0.086 0.028 0.79  0.096]
C7 weights (v1,v2,reg,swin): [0.  0.  0.6 0.4]
4-way uncal WLL max=0.4599, union=0.4210
4-way calibrated OOF WLL: 0.4209


4-way LR alpha: 1.0000, WLL: 0.4154
Final 4-way OOF WLL: 0.4144
4-way ensemble completed. Target <0.40 WLL achieved? Update inference next.


In [30]:
import numpy as np, pandas as pd, os, random

def mip_stats(uids, root):
    rows=[]
    for uid in uids:
        p = os.path.join(root, f'{uid}.npy')
        if not os.path.exists(p): continue
        x = np.load(p).astype(np.float32)  # [3,H,W], expected in [0,1]
        for c,name in enumerate(['sag','cor','ax']):
            ch = x[c]
            rows.append(dict(uid=uid, ch=name, min=float(ch.min()), max=float(ch.max()),
                             mean=float(ch.mean()), std=float(ch.std())))
    return pd.DataFrame(rows)

train_df = pd.read_csv('data/train_mips.csv')
test_df  = pd.read_csv('data/test_mips.csv')
tr_ids = random.sample(train_df['StudyInstanceUID'].tolist(), k=min(5, len(train_df)))
te_ids = random.sample(test_df['StudyInstanceUID'].tolist(),  k=min(5, len(test_df)))

st_tr = mip_stats(tr_ids, 'data/mips/train')
st_te = mip_stats(te_ids, 'data/mips/test')
print('Train stats:'); print(st_tr)
print('Test stats:');  print(st_te)
print('Train channel means:'); print(st_tr.groupby('ch')[['min','max','mean','std']].mean().round(4))
print('Test channel means:');  print(st_te.groupby('ch')[['min','max','mean','std']].mean().round(4))
print('Check ranges within [0,1] and similar train/test means/stds per channel.')


Train stats:
                          uid   ch  min  max      mean       std
0   1.2.826.0.1.3680043.18906  sag  0.0  1.0  0.501914  0.290227
1   1.2.826.0.1.3680043.18906  cor  0.3  1.0  0.604927  0.269572
2   1.2.826.0.1.3680043.18906   ax  0.0  1.0  0.488586  0.252714
3   1.2.826.0.1.3680043.30020  sag  0.0  1.0  0.599643  0.296430
4   1.2.826.0.1.3680043.30020  cor  0.3  1.0  0.639518  0.265669
5   1.2.826.0.1.3680043.30020   ax  0.0  1.0  0.538895  0.361779
6    1.2.826.0.1.3680043.5812  sag  0.0  1.0  0.558129  0.287552
7    1.2.826.0.1.3680043.5812  cor  0.3  1.0  0.595411  0.250127
8    1.2.826.0.1.3680043.5812   ax  0.0  1.0  0.467655  0.317560
9   1.2.826.0.1.3680043.30475  sag  0.3  1.0  0.570425  0.282860
10  1.2.826.0.1.3680043.30475  cor  0.3  1.0  0.632678  0.258619
11  1.2.826.0.1.3680043.30475   ax  0.3  1.0  0.508638  0.231480
12   1.2.826.0.1.3680043.4842  sag  0.0  1.0  0.598569  0.323805
13   1.2.826.0.1.3680043.4842  cor  0.3  1.0  0.637620  0.285206
14   1.2.826

In [20]:
# Fit per-class temperature scaling for three-way ensemble (ConvNeXt v1 + v2 + RegNetY) using averaged OOF logits
import numpy as np
import pandas as pd
from scipy.optimize import minimize_scalar
from scipy.special import expit as sigmoid

# Load data
oof_logits_df_v1 = pd.read_csv('oof_logits_convnext.csv')
oof_logits_df_v2 = pd.read_csv('oof_logits_convnext_v2.csv')
oof_logits_df_reg = pd.read_csv('oof_logits_regnet.csv')
train_df = pd.read_csv('data/train_mips.csv')
label_cols = ['C1','C2','C3','C4','C5','C6','C7']
oof_logits_v1 = oof_logits_df_v1[label_cols].values  # (202, 7)
oof_logits_v2 = oof_logits_df_v2[label_cols].values  # (202, 7)
oof_logits_reg = oof_logits_df_reg[label_cols].values  # (202, 7)
threeway_oof_logits = (oof_logits_v1 + oof_logits_v2 + oof_logits_reg) / 3.0
y_true = train_df[label_cols].values  # (202, 7)

def binary_log_loss(y_true, y_pred):
    y_pred = np.clip(y_pred, 1e-6, 1-1e-6)
    return -np.mean(y_true * np.log(y_pred) + (1 - y_true) * np.log(1 - y_pred))

# Optimize T per class on three-way ensemble logits
temperatures_threeway = np.ones(7)
for i, col in enumerate(label_cols):
    def loss_func(T):
        logits_scaled = threeway_oof_logits[:, i] / T
        probs = sigmoid(logits_scaled)
        return binary_log_loss(y_true[:, i], probs)
    res = minimize_scalar(loss_func, bounds=(0.5, 2.0), method='bounded')
    temperatures_threeway[i] = res.x
    print(f'{col} threeway temperature: {temperatures_threeway[i]:.4f}, min loss: {res.fun:.4f}')

# Save three-way temperatures
np.save('temperatures_threeway.npy', temperatures_threeway)
print('Three-way ensemble temperatures saved to temperatures_threeway.npy')

# Quick check: Compute calibrated three-way OOF BCE (7 classes, weights=1)
logits_calib_three = threeway_oof_logits / temperatures_threeway
probs_calib_three = sigmoid(logits_calib_three)
wll_calib_three = np.mean([binary_log_loss(y_true[:, i], probs_calib_three[:, i]) for i in range(7)])
wll_raw_three = np.mean([binary_log_loss(y_true[:, i], sigmoid(threeway_oof_logits[:, i])) for i in range(7)])
print(f'Three-way Raw OOF BCE (avg): {wll_raw_three:.4f}, Calibrated: {wll_calib_three:.4f} (delta: {wll_calib_three - wll_raw_three:.4f})')

C1 threeway temperature: 0.8611, min loss: 0.3242
C2 threeway temperature: 0.7581, min loss: 0.4152
C3 threeway temperature: 0.9181, min loss: 0.1832
C4 threeway temperature: 0.8476, min loss: 0.2136
C5 threeway temperature: 0.8093, min loss: 0.3460
C6 threeway temperature: 0.8063, min loss: 0.4542
C7 threeway temperature: 0.7463, min loss: 0.4466
Three-way ensemble temperatures saved to temperatures_threeway.npy
Three-way Raw OOF BCE (avg): 0.3479, Calibrated: 0.3404 (delta: -0.0075)


In [24]:
# Learn per-class ensemble weights, choose overall rule, refit temperatures for three-way ensemble
import numpy as np
import pandas as pd
from scipy.optimize import minimize
from scipy.special import expit as sigmoid
from sklearn.metrics import log_loss

# Load OOF logits and true labels
oof_logits_df_v1 = pd.read_csv('oof_logits_convnext.csv')
oof_logits_df_v2 = pd.read_csv('oof_logits_convnext_v2.csv')
oof_logits_df_reg = pd.read_csv('oof_logits_regnet.csv')
train_df = pd.read_csv('data/train_mips.csv')
label_cols = ['C1','C2','C3','C4','C5','C6','C7']
oof_logits_v1 = oof_logits_df_v1[label_cols].values
oof_logits_v2 = oof_logits_df_v2[label_cols].values
oof_logits_reg = oof_logits_df_reg[label_cols].values
y_true_7 = train_df[label_cols].values
y_true_overall = np.max(y_true_7, axis=1, keepdims=True)
y_true_8 = np.hstack([y_true_7, y_true_overall])

# Stack OOF logits [N,7,3]
oof_logits_stack = np.stack([oof_logits_v1, oof_logits_v2, oof_logits_reg], axis=2)

def compute_wll_8(probs_8):
    # Weighted log loss: weights [1]*7 + [2] for overall
    weights = np.array([1.0]*7 + [2.0])
    losses = [log_loss(y_true_8[:, i], probs_8[:, i]) for i in range(8)]
    return np.average(losses, weights=weights)

def binary_log_loss(y_true, y_pred):
    y_pred = np.clip(y_pred, 1e-6, 1-1e-6)
    return -np.mean(y_true * np.log(y_pred) + (1 - y_true) * np.log(1 - y_pred))

# 1. Learn per-class weights (w1,w2,w3 >=0, sum=1) via SLSQP to minimize BCE for that class
from scipy.optimize import minimize
weights_per_class = np.zeros((7, 3))
for i in range(7):
    def objective(w):
        w = np.clip(w, 0, None)
        w = w / w.sum() if w.sum() > 0 else np.array([1/3]*3)
        blended_logits = np.sum(oof_logits_stack[:, i, :] * w[None, :], axis=1)
        probs = sigmoid(blended_logits)
        return binary_log_loss(y_true_7[:, i], probs)
    # Initial guess equal weights
    res = minimize(objective, x0=np.array([1/3]*3), method='SLSQP', bounds=[(0,None)]*3,
                   constraints={'type': 'eq', 'fun': lambda w:  w.sum() - 1})
    weights_per_class[i] = res.x / res.x.sum() if res.x.sum() > 0 else np.array([1/3]*3)
    print(f'{label_cols[i]} weights (v1,v2,reg): {weights_per_class[i]}')

# Save per-class weights
np.save('weights_threeway.npy', weights_per_class)
print('Per-class weights saved to weights_threeway.npy')

# 2. Blend OOF logits with learned weights [N,7]
blended_logits = np.sum(oof_logits_stack * weights_per_class[None, :, :], axis=2)
probs_7_blended = sigmoid(blended_logits)

# 3. Choose overall rule: max vs union by WLL on blended probs
probs_overall_max = np.max(probs_7_blended, axis=1, keepdims=True)
probs_overall_union = 1 - np.prod(1 - probs_7_blended, axis=1, keepdims=True)
probs_8_max = np.hstack([probs_7_blended, probs_overall_max])
probs_8_union = np.hstack([probs_7_blended, probs_overall_union])
wll_max = compute_wll_8(probs_8_max)
wll_union = compute_wll_8(probs_8_union)
print(f'Blended OOF WLL max: {wll_max:.4f}, union: {wll_union:.4f}')
if wll_union < wll_max:
    overall_rule = 'union'
    probs_8_chosen = probs_8_union
else:
    overall_rule = 'max'
    probs_8_chosen = probs_8_max
print(f'Chosen overall rule: {overall_rule}')
with open('overall_rule.txt', 'w') as f:
    f.write(overall_rule)

# 4. Refit per-class temperatures on blended logits (7 classes only)
temperatures_weighted = np.ones(7)
for i, col in enumerate(label_cols):
    def loss_func(T):
        logits_scaled = blended_logits[:, i] / T
        probs = sigmoid(logits_scaled)
        return binary_log_loss(y_true_7[:, i], probs)
    from scipy.optimize import minimize_scalar
    res = minimize_scalar(loss_func, bounds=(0.5, 2.0), method='bounded')
    temperatures_weighted[i] = res.x
    print(f'{col} weighted temperature: {temperatures_weighted[i]:.4f}, min loss: {res.fun:.4f}')

# Save weighted temperatures
np.save('temperatures_weighted.npy', temperatures_weighted)
print('Weighted temperatures saved to temperatures_weighted.npy')

# Quick check: Calibrated blended OOF BCE (7 classes)
logits_calib_weighted = blended_logits / temperatures_weighted
probs_calib_weighted = sigmoid(logits_calib_weighted)
wll_calib_weighted = np.mean([binary_log_loss(y_true_7[:, i], probs_calib_weighted[:, i]) for i in range(7)])
wll_raw_weighted = np.mean([binary_log_loss(y_true_7[:, i], sigmoid(blended_logits[:, i])) for i in range(7)])
print(f'Weighted Raw OOF BCE (avg): {wll_raw_weighted:.4f}, Calibrated: {wll_calib_weighted:.4f} (delta: {wll_calib_weighted - wll_raw_weighted:.4f})')

# Final blended WLL with chosen rule and calibration (for 8 classes, but calibrate only 7)
probs_overall_final = probs_overall_max if overall_rule == 'max' else probs_overall_union
probs_8_final = np.hstack([probs_calib_weighted, probs_overall_final])
final_wll = compute_wll_8(probs_8_final)
print(f'Final blended calibrated WLL (8 classes): {final_wll:.4f}')

C1 weights (v1,v2,reg): [0.02372475 0.02156091 0.95471434]
C2 weights (v1,v2,reg): [2.60208521e-18 0.00000000e+00 1.00000000e+00]
C3 weights (v1,v2,reg): [3.98986399e-17 1.39092491e-01 8.60907509e-01]
C4 weights (v1,v2,reg): [0. 0. 1.]
C5 weights (v1,v2,reg): [1.44583883e-02 7.58941521e-18 9.85541612e-01]
C6 weights (v1,v2,reg): [0.07699479 0.09759244 0.82541277]
C7 weights (v1,v2,reg): [1.99021326e-02 3.46944695e-18 9.80097867e-01]
Per-class weights saved to weights_threeway.npy
Blended OOF WLL max: 0.4682, union: 0.4222
Chosen overall rule: union
C1 weighted temperature: 1.0008, min loss: 0.3229
C2 weighted temperature: 0.9821, min loss: 0.4113
C3 weighted temperature: 0.9902, min loss: 0.1823
C4 weighted temperature: 0.9842, min loss: 0.2116
C5 weighted temperature: 0.9948, min loss: 0.3446
C6 weighted temperature: 0.9928, min loss: 0.4532
C7 weighted temperature: 0.9924, min loss: 0.4454
Weighted temperatures saved to temperatures_weighted.npy
Weighted Raw OOF BCE (avg): 0.3388, Ca

In [25]:
import numpy as np, pandas as pd
from scipy.special import expit as sigmoid

label_cols = ['C1','C2','C3','C4','C5','C6','C7']

# Load labels aligned to OOF rows
ids = pd.read_csv('data/train_mips.csv')[['StudyInstanceUID']]
train = pd.read_csv('train.csv')
y_df = ids.merge(train, on='StudyInstanceUID', how='left')
y_true_7 = y_df[label_cols].values.astype(float)
y_true_overall = y_true_7.max(axis=1, keepdims=True)
y_true_8 = np.hstack([y_true_7, y_true_overall])

# Load OOF logits in order [v1, v2, reg]
v1 = pd.read_csv('oof_logits_convnext.csv')[label_cols].values
v2 = pd.read_csv('oof_logits_convnext_v2.csv')[label_cols].values
reg= pd.read_csv('oof_logits_regnet.csv')[label_cols].values
stack = np.stack([v1, v2, reg], axis=2)  # [N,7,3]

# Load blending artifacts
W = np.load('weights_threeway.npy')          # [7,3]
T = np.load('temperatures_weighted.npy')     # [7]
overall_rule = open('overall_rule.txt').read().strip()

# Blend and compute probabilities
blended_logits = np.sum(stack * W[None,:,:], axis=2)     # [N,7]
probs7_uncal = sigmoid(blended_logits)
probs7_cal = sigmoid(blended_logits / T)

# Overall rule — NOTE: matches your OOF (overall from UNCALIBRATED probs)
if overall_rule == 'union':
    overall = 1 - np.prod(1 - probs7_uncal, axis=1, keepdims=True)
else:
    overall = probs7_uncal.max(axis=1, keepdims=True)

probs8 = np.hstack([probs7_cal, overall])  # 7 calibrated + overall from uncalibrated

def weighted_log_loss(y_true, y_pred):
    y_pred = np.clip(y_pred, 1e-6, 1-1e-6)
    losses = -y_true*np.log(y_pred) - (1-y_true)*np.log(1-y_pred)
    w = np.array([1]*7 + [2], dtype=float)
    return float(np.sum(losses.mean(0)*w) / w.sum())

wll = weighted_log_loss(y_true_8, probs8)
print(f'OOF WLL: {wll:.4f} (expect ~0.4222)')

OOF WLL: 0.4222 (expect ~0.4222)


In [26]:
import numpy as np, pandas as pd
from scipy.special import expit as sigmoid
from scipy.optimize import minimize
from sklearn.metrics import log_loss

label_cols = ['C1','C2','C3','C4','C5','C6','C7']

# Load OOF logits (order: ConvNeXt v1, ConvNeXt v2, RegNet)
v1 = pd.read_csv('oof_logits_convnext.csv')[label_cols].values
v2 = pd.read_csv('oof_logits_convnext_v2.csv')[label_cols].values
reg = pd.read_csv('oof_logits_regnet.csv')[label_cols].values
stack = np.stack([v1, v2, reg], axis=2)  # [N,7,3]

# Load aligned labels
ids = pd.read_csv('data/train_mips.csv')[['StudyInstanceUID']]
train = pd.read_csv('train.csv')
y_df = ids.merge(train, on='StudyInstanceUID', how='left')
y7 = y_df[label_cols].values.astype(float)
y_overall = y7.max(axis=1, keepdims=True)
y8 = np.hstack([y7, y_overall])

def wll(y_true8, p8):
    p8 = np.clip(p8, 1e-6, 1-1e-6)
    losses = -y_true8*np.log(p8) - (1-y_true8)*np.log(1-p8)
    w = np.array([1]*7 + [2], float)
    return float(np.sum(losses.mean(0) * w) / w.sum())

# Optimize per-class weights with RegNet cap
bounds = [(0,1), (0,1), (0,0.3)]
constraints = [{'type':'eq','fun': lambda w: np.sum(w)-1.0}]
w_opt = np.zeros((7,3), dtype=float)

def obj(w, logits3, y):
    blend = np.sum(logits3 * w, axis=1)
    return log_loss(y, sigmoid(blend), labels=[0,1])

for c in range(7):
    res = minimize(obj, x0=np.array([1/3,1/3,1/3]),
                   args=(stack[:,c,:], y7[:,c]),
                   method='SLSQP', bounds=bounds, constraints=constraints)
    w_opt[c] = res.x

np.save('weights_threeway_capped.npy', w_opt)

# Refit per-class temperatures on capped blend
blend = np.sum(stack * w_opt[None,:,:], axis=2)  # [N,7]
def fit_temp(logits, y):
    # log-param for stability; T in ~[0.05,20]
    def loss(z):
        t = np.exp(z)
        p = sigmoid(logits / t)
        p = np.clip(p, 1e-6, 1-1e-6)
        return float(log_loss(y, p, labels=[0,1]))
    r = minimize(loss, x0=0.0, method='L-BFGS-B', bounds=[(-3,3)])
    return float(np.exp(r.x[0]))
temps = np.array([fit_temp(blend[:,k], y7[:,k]) for k in range(7)], dtype=np.float32)
np.save('temperatures_weighted_capped.npy', temps)

# Evaluate OOF with overall from UNCALIBRATED probs (to match inference)
probs7_uncal = sigmoid(blend)
probs7_cal = sigmoid(blend / temps)
with open('overall_rule.txt','r') as f:
    rule = f.read().strip()
if rule == 'union':
    overall = 1 - np.prod(1 - probs7_uncal, axis=1, keepdims=True)
else:
    overall = probs7_uncal.max(axis=1, keepdims=True)
p8 = np.hstack([probs7_cal, overall])
print(f'OOF WLL (capped): {wll(y8, p8):.4f}')


OOF WLL (capped): 0.4422


In [27]:
import numpy as np, pandas as pd
from scipy.optimize import minimize
from scipy.special import expit as sigmoid
from sklearn.metrics import log_loss

label_cols = ['C1','C2','C3','C4','C5','C6','C7']
v1 = pd.read_csv('oof_logits_convnext.csv')[label_cols].values
v2 = pd.read_csv('oof_logits_convnext_v2.csv')[label_cols].values
tw = 0.5*(v1+v2)

ids = pd.read_csv('data/train_mips.csv')[['StudyInstanceUID']]
train = pd.read_csv('train.csv')
y_df = ids.merge(train, on='StudyInstanceUID', how='left')
y7 = y_df[label_cols].values.astype(float)
y_overall = y7.max(axis=1, keepdims=True)
y8 = np.hstack([y7, y_overall])

def fit_temp(logits, y):
    def loss(z):
        t = np.exp(z)
        return log_loss(y, sigmoid(logits/t), labels=[0,1])
    r = minimize(loss, x0=0.0, method='L-BFGS-B', bounds=[(-3,3)])
    return float(np.exp(r.x[0]))
temps_tw = np.array([fit_temp(tw[:,k], y7[:,k]) for k in range(7)], dtype=np.float32)
np.save('temperatures_twoway.npy', temps_tw)

# Evaluate with overall from UNCALIBRATED to match inference
probs7_uncal = sigmoid(tw)
probs7_cal = sigmoid(tw / temps_tw)
with open('overall_rule.txt','r') as f:
    rule = f.read().strip()
if rule == 'union':
    overall = 1 - np.prod(1 - probs7_uncal, axis=1, keepdims=True)
else:
    overall = probs7_uncal.max(axis=1, keepdims=True)
p8 = np.hstack([probs7_cal, overall])

def wll(y_true8, p8):
    p8 = np.clip(p8, 1e-6, 1-1e-6)
    losses = -y_true8*np.log(p8) - (1-y_true8)*np.log(1-p8)
    w = np.array([1]*7 + [2], float)
    return float(np.sum(losses.mean(0) * w) / w.sum())

print(f'2-way OOF WLL: {wll(y8, p8):.4f}')

2-way OOF WLL: 0.4665


In [33]:
# 4-way Ensemble: Add Patch OOF Logits to Three-Way TTA Blend
import numpy as np, pandas as pd
from scipy.optimize import minimize
from scipy.special import expit as sigmoid
from sklearn.metrics import log_loss
from sklearn.linear_model import LogisticRegression
import joblib

label_cols = ['C1','C2','C3','C4','C5','C6','C7']

# Load aligned labels
ids = pd.read_csv('data/train_mips.csv')[['StudyInstanceUID']]
train = pd.read_csv('train.csv')
y_df = ids.merge(train, on='StudyInstanceUID', how='left')
y7 = y_df[label_cols].values.astype(float)
y_overall = y7.max(axis=1, keepdims=True)
y8 = np.hstack([y7, y_overall])

def wll(y_true8, p8):
    p8 = np.clip(p8, 1e-6, 1-1e-6)
    losses = [log_loss(y_true8[:,i], p8[:,i]) for i in range(8)]
    return np.average(losses, weights=np.array([1]*7+[2], float))

# Load TTA OOF logits [N,7] for three models + patch
v1 = np.load('oof_logits_convnext_tta.npy')
v2 = np.load('oof_logits_convnext_v2_tta.npy')
reg = np.load('oof_logits_regnet_tta.npy')
patch = np.load('oof_logits_patch.npy')
stack4 = np.stack([v1, v2, reg, patch], axis=2)  # [N,7,4]

# Optimize per-class weights (sum=1, >=0, reg<=0.3, patch<=0.4)
bounds = [(0,1),(0,1),(0,0.3),(0,0.4)]
constraints = [{'type':'eq','fun': lambda w: np.sum(w)-1.0}]
w_opt4 = np.zeros((7,4), dtype=float)

def obj_w(w, logits4, y):
    blend = np.sum(logits4 * w, axis=1)
    return log_loss(y, sigmoid(blend), labels=[0,1])

for c in range(7):
    res = minimize(obj_w, x0=np.array([0.3,0.3,0.2,0.2]),
                   args=(stack4[:,c,:], y7[:,c]),
                   method='SLSQP', bounds=bounds, constraints=constraints,
                   options={'disp':False, 'maxiter':100})
    w_opt4[c] = res.x
    print(f'{label_cols[c]} weights (v1,v2,reg,patch): {res.x.round(3)}')

np.save('weights_fourway_tta.npy', w_opt4)

# Blend logits [N,7]
blend4 = np.sum(stack4 * w_opt4[None,:,:], axis=2)
p7_uncal4 = sigmoid(blend4)

# Overall rule on uncalibrated (union vs max)
p_overall_max4 = p7_uncal4.max(axis=1, keepdims=True)
p_overall_union4 = 1 - np.prod(1 - p7_uncal4, axis=1, keepdims=True)
p8_max4 = np.hstack([p7_uncal4, p_overall_max4])
p8_union4 = np.hstack([p7_uncal4, p_overall_union4])
wll_max4 = wll(y8, p8_max4)
wll_union4 = wll(y8, p8_union4)
print(f'4-way uncal WLL max={wll_max4:.4f}, union={wll_union4:.4f}')
rule4 = 'union' if wll_union4 < wll_max4 else 'max'
with open('overall_rule_fourway_tta.txt','w') as f: f.write(rule4)
p_overall_base4 = p_overall_union4 if rule4=='union' else p_overall_max4

# Refit temperatures on 4-way blend (7 classes)
def fit_temp(logits, y):
    def loss(z):
        t = np.exp(z)
        return log_loss(y, sigmoid(logits/t), labels=[0,1])
    from scipy.optimize import minimize
    r = minimize(loss, x0=0.0, method='L-BFGS-B', bounds=[(-3,3)])
    return float(np.exp(r.x[0]))
temps4 = np.array([fit_temp(blend4[:,k], y7[:,k]) for k in range(7)], dtype=np.float32)
np.save('temperatures_fourway_tta.npy', temps4)

# Calibrated probs
p7_cal4 = sigmoid(blend4 / temps4)
p8_cal4 = np.hstack([p7_cal4, p_overall_base4])
wll_cal4 = wll(y8, p8_cal4)
print(f'4-way calibrated OOF WLL: {wll_cal4:.4f}')

# Learned overall stacker: LR on 7 cal + overall base, tune alpha
from sklearn.linear_model import LogisticRegression
from iterstrat.ml_stratifiers import MultilabelStratifiedKFold
from scipy.optimize import minimize_scalar

# Features for LR: 7 cal + base overall
X_lr = np.hstack([p7_cal4, p_overall_base4])  # [N,8]

# 5-fold CV for OOF LR preds (leakage-free)
skf_lr = MultilabelStratifiedKFold(n_splits=5, shuffle=True, random_state=42)
p_lr_oof = np.zeros(len(X_lr), dtype=np.float32)
for tr_idx, va_idx in skf_lr.split(X_lr, y7):
    lr = LogisticRegression(penalty='l2', C=1.0, class_weight='balanced',
                            solver='lbfgs', max_iter=1000, random_state=42)
    lr.fit(X_lr[tr_idx], y_overall[tr_idx])
    p_lr_oof[va_idx] = lr.predict_proba(X_lr[va_idx])[:,1]

# Tune alpha: min WLL with alpha*LR + (1-alpha)*base
def obj_alpha(alpha):
    p_overall = alpha * p_lr_oof + (1-alpha) * p_overall_base4.squeeze()
    p8_alpha = np.hstack([p7_cal4, p_overall.reshape(-1,1)])
    return wll(y8, p8_alpha)
res_alpha = minimize_scalar(obj_alpha, bounds=(0,1), method='bounded')
alpha4 = float(res_alpha.x)
print(f'4-way LR alpha: {alpha4:.4f}, WLL: {res_alpha.fun:.4f}')

# Refit LR on full data
lr_full4 = LogisticRegression(penalty='l2', C=1.0, class_weight='balanced',
                              solver='lbfgs', max_iter=1000, random_state=42)
lr_full4.fit(X_lr, y_overall)

# Save 4-way artifacts
joblib.dump(lr_full4, 'overall_regressor_fourway_tta.pkl')
np.save('overall_lr_alpha_fourway_tta.npy', np.array([alpha4]))
with open('overall_lr_base_fourway_tta.txt','w') as f: f.write(rule4)
np.save('oof_logits_fourway_blend.npy', blend4)  # for later use

# Final 4-way OOF WLL
p_overall_final4 = alpha4 * lr_full4.predict_proba(X_lr)[:,1] + (1-alpha4) * p_overall_base4.squeeze()
p8_final4 = np.hstack([p7_cal4, p_overall_final4.reshape(-1,1)])
final_wll4 = wll(y8, p8_final4)
print(f'Final 4-way OOF WLL: {final_wll4:.4f}')

# Save 4-way OOF probs for reference
oof_df4 = y_df[['StudyInstanceUID']].copy()
oof_df4[label_cols] = p7_cal4
oof_df4['patient_overall'] = p8_final4[:,-1]
oof_df4.to_csv('oof_predictions_fourway.csv', index=False)
print('4-way ensemble completed. If WLL <0.4211, proceed to test inference for patch model.')

C1 weights (v1,v2,reg,patch): [0.62  0.022 0.3   0.057]
C2 weights (v1,v2,reg,patch): [0.761 0.13  0.032 0.078]
C3 weights (v1,v2,reg,patch): [0.78  0.147 0.    0.073]
C4 weights (v1,v2,reg,patch): [0.85 0.   0.   0.15]


C5 weights (v1,v2,reg,patch): [0.543 0.346 0.041 0.071]
C6 weights (v1,v2,reg,patch): [0.278 0.584 0.093 0.045]
C7 weights (v1,v2,reg,patch): [0.656 0.    0.3   0.044]
4-way uncal WLL max=0.4468, union=0.3945
4-way calibrated OOF WLL: 0.3944


4-way LR alpha: 0.0000, WLL: 0.3944
Final 4-way OOF WLL: 0.3944
4-way ensemble completed. If WLL <0.4211, proceed to test inference for patch model.


In [35]:
# Additive Swin Integration: X4 = X3 + s * swin_logits, coordinate descent s_i [0,0.3], re-fit T7/LR/alpha/postproc
import numpy as np
from scipy.optimize import minimize_scalar
from sklearn.linear_model import LogisticRegression
from iterstrat.ml_stratifiers import MultilabelStratifiedKFold
from scipy.special import expit as sigmoid
from sklearn.metrics import log_loss

# Load 3-way blended logits X3 [N,7]
v1 = np.load('oof_logits_convnext_tta.npy')
v2 = np.load('oof_logits_convnext_v2_tta.npy')
reg = np.load('oof_logits_regnet_tta.npy')
W = np.load('weights_threeway_tta.npy')  # [7,3]
stack3 = np.stack([v1, v2, reg], axis=2)  # [N,7,3]
X3 = np.sum(stack3 * W[None,:,:], axis=2)  # [N,7] X3 = blended 3-way logits

# Load Swin OOF logits [N,7]
swin = np.load('oof_logits_swin_tta.npy')

# Labels for wll
label_cols = ['C1','C2','C3','C4','C5','C6','C7']
ids = pd.read_csv('data/train_mips.csv')[['StudyInstanceUID']]
train = pd.read_csv('train.csv')
y_df = ids.merge(train, on='StudyInstanceUID', how='left')
y7 = y_df[label_cols].values.astype(float)
y_overall = y7.max(axis=1).astype(int)
y8 = np.hstack([y7, y_overall[:,None]])

def wll(y_true8, p8):
    p8 = np.clip(p8, 1e-6, 1-1e-6)
    losses = [log_loss(y_true8[:,i], p8[:,i], labels=[0,1]) for i in range(8)]
    return float(np.average(losses, weights=np.array([1]*7+[2], float)))

# From prior postproc (cell 3 in 03_inference_tta): lam=0.1162, gamma=0
lam = 0.1162
gamma = 0.0

def smooth_chain(p7, lam):
    p_smooth = p7.copy()
    for i in range(1, 6):
        p_smooth[:,i] = (1 - 2*lam) * p7[:,i] + lam * (p7[:,i-1] + p7[:,i+1])
    p_smooth[:,0] = (1 - lam) * p7[:,0] + lam * p7[:,1]
    p_smooth[:,6] = (1 - lam) * p7[:,6] + lam * p7[:,5]
    return p_smooth

# Coordinate descent for s [7], init 0.10, 3 passes, s_i in [0,0.3] step 0.05
s = np.full(7, 0.10)
for pass_num in range(3):
    print(f'Additive Swin pass {pass_num+1}/3')
    for i in range(7):
        best_s_i = s[i]
        best_score = float('inf')
        for s_i in np.arange(0, 0.31, 0.05):
            temp_s = s.copy()
            temp_s[i] = s_i
            X4 = X3 + temp_s[None,:] * swin  # [N,7]
            p4_uncal = sigmoid(X4)
            p4_cal = sigmoid(X4 / T7)  # approximate with prior T7 for speed in coord descent
            union4 = 1 - np.prod(1 - p4_uncal, axis=1)
            max4 = p4_uncal.max(axis=1)
            base4 = gamma * union4 + (1 - gamma) * max4
            # Quick LR refit on X4 for p_lr4_oof (approx, use prior alpha=1.0 for speed)
            skf = MultilabelStratifiedKFold(n_splits=5, shuffle=True, random_state=42)
            p_lr4_oof = np.zeros(len(X4), dtype=np.float32)
            for tr_idx, va_idx in skf.split(X4, y7):
                lr_temp = LogisticRegression(penalty='l2', C=1.0, class_weight='balanced', solver='lbfgs', max_iter=1000, random_state=42)
                lr_temp.fit(X4[tr_idx], y_overall[tr_idx])
                p_lr4_oof[va_idx] = lr_temp.predict_proba(X4[va_idx])[:,1].astype(np.float32)
            p_overall4 = 1.0 * p_lr4_oof + 0.0 * base4  # alpha=1.0 approx
            p7_smooth4 = smooth_chain(p4_cal, lam)
            p84 = np.hstack([p7_smooth4, p_overall4[:,None]])
            score = wll(y8, p84)
            if score < best_score:
                best_score = score
                best_s_i = s_i
        s[i] = best_s_i
        print(f'  Class {i} ({label_cols[i]}): s_i updated to {best_s_i:.3f}, score={best_score:.4f}')

# Full refit after coord descent: T7_new per class
X4 = X3 + s[None,:] * swin
T7_new = np.ones(7)
for c in range(7):
    def obj_t(t):
        p_c = sigmoid(X4[:,c] / t)
        return log_loss(y7[:,c], p_c, labels=[0,1])
    T7_new[c] = minimize_scalar(obj_t, bounds=(0.5, 2.0), method='bounded').x
    print(f'T7_new[{label_cols[c]}]: {T7_new[c]:.4f}')

# Refit LR on X4, alpha_new on p_overall4
skf = MultilabelStratifiedKFold(n_splits=5, shuffle=True, random_state=42)
p_lr4_oof = np.zeros(len(X4), dtype=np.float32)
for tr_idx, va_idx in skf.split(X4, y7):
    lr_full_temp = LogisticRegression(penalty='l2', C=1.0, class_weight='balanced', solver='lbfgs', max_iter=1000, random_state=42)
    lr_full_temp.fit(X4[tr_idx], y_overall[tr_idx])
    p_lr4_oof[va_idx] = lr_full_temp.predict_proba(X4[va_idx])[:,1].astype(np.float32)

def obj_alpha(alpha):
    p_overall = alpha * p_lr4_oof + (1 - alpha) * base4
    p7_smooth4 = smooth_chain(sigmoid(X4 / T7_new), lam)
    p84 = np.hstack([p7_smooth4, p_overall[:,None]])
    return wll(y8, p84)

res_alpha = minimize_scalar(obj_alpha, bounds=(0,1), method='bounded')
alpha_new = res_alpha.x
print(f'Alpha new: {alpha_new:.4f}, WLL: {res_alpha.fun:.4f}')

# Save additive Swin artifacts
np.save('addswin_s.npy', s)
np.save('temperatures_addswin_tta.npy', T7_new)
lr_full = LogisticRegression(penalty='l2', C=1.0, class_weight='balanced', solver='lbfgs', max_iter=1000, random_state=42)
lr_full.fit(X4, y_overall)
joblib.dump(lr_full, 'overall_regressor_addswin.pkl')
np.save('overall_lr_alpha_addswin.npy', np.array([alpha_new]))

# Gate: compute full OOF wll with new postproc vs best 3-way 0.4198
p4_uncal = sigmoid(X4)
p4_cal = sigmoid(X4 / T7_new)
union4 = 1 - np.prod(1 - p4_uncal, axis=1)
max4 = p4_uncal.max(axis=1)
base4 = gamma * union4 + (1 - gamma) * max4
p_overall_new = alpha_new * lr_full.predict_proba(X4)[:,1] + (1 - alpha_new) * base4
p7_smooth_new = smooth_chain(p4_cal, lam)
p8_new = np.hstack([p7_smooth_new, p_overall_new[:,None]])
for i in range(8):
    p8_new[:,i] = np.clip(0.005 + 0.92 * p8_new[:,i], 1e-6, 1-1e-6)  # uniform approx for gate
score_new = wll(y8, p8_new)
print(f'Additive Swin OOF WLL: {score_new:.4f} (3-way baseline 0.4198, gain {0.4198 - score_new:.4f})')

# Simple fold check (reuse skf from 3-way)
skf = MultilabelStratifiedKFold(n_splits=5, shuffle=True, random_state=42)
fold_improves = 0
for tr_idx, va_idx in skf.split(X4, y7):
    # 3-way baseline on val
    X3_val = np.sum(stack3[va_idx] * W[None,:,:], axis=2)
    p3_uncal = sigmoid(X3_val)
    p3_cal = sigmoid(X3_val / T7)
    base3 = gamma * (1 - np.prod(1 - p3_uncal,1)) + (1 - gamma) * p3_uncal.max(1)
    p_lr3_oof_val = lr.predict_proba(X3_val)[:,1]  # approx, since LR on X3
    p_overall3 = 1.0 * p_lr3_oof_val + 0.0 * base3  # alpha=1.0
    p7_smooth3 = smooth_chain(p3_cal, lam)
    p83 = np.hstack([p7_smooth3, p_overall3[:,None]])
    baseline_fold = wll(y8[va_idx], p83[va_idx])
    # New on val
    X4_val = X3[va_idx] + s[None,:] * swin[va_idx]
    p4_uncal_val = sigmoid(X4_val)
    p4_cal_val = sigmoid(X4_val / T7_new)
    union4_val = 1 - np.prod(1 - p4_uncal_val,1)
    max4_val = p4_uncal_val.max(1)
    base4_val = gamma * union4_val + (1 - gamma) * max4_val
    p_lr4_oof_val = lr_full.predict_proba(X4_val)[:,1]
    p_overall4_val = alpha_new * p_lr4_oof_val + (1 - alpha_new) * base4_val
    p7_smooth4_val = smooth_chain(p4_cal_val, lam)
    p84_val = np.hstack([p7_smooth4_val, p_overall4_val[:,None]])
    for i in range(8):
        p84_val[:,i] = np.clip(0.005 + 0.92 * p84_val[:,i], 1e-6, 1-1e-6)
    new_fold = wll(y8[va_idx], p84_val[va_idx])
    if new_fold < baseline_fold:
        fold_improves += 1

mean_improve = 0.4198 - score_new
print(f'Gate: mean improve={mean_improve:.4f}, improves={fold_improves}/5')

if mean_improve >= 0.005 and fold_improves >= 4:
    print('Gate passed: Proceed to 4-way inference with additive Swin.')
    gate_pass = True
    # Copy artifacts to 03_inference_tta for postproc re-tune
    # In 03, load X4_test = X3_test + s * swin_test, T7_new, lr_full, alpha_new
    # Then run postproc tuning cell 3 with these, save postproc_addswin.npz
else:
    print('Gate failed: Stick to 3-way, pivot to holdout/multi-view.')
    gate_pass = False

print('Additive Swin integration complete. If gate pass, update 03_inference_tta cell 0 for test Swin logits + add, re-tune postproc, submit. Else, create 06_holdout.ipynb for Phase 1 pivot.')

Additive Swin pass 1/3
  Class 0 (C1): s_i updated to 0.050, score=0.4206


  Class 1 (C2): s_i updated to 0.000, score=0.4203
  Class 2 (C3): s_i updated to 0.100, score=0.4203


  Class 3 (C4): s_i updated to 0.100, score=0.4203
  Class 4 (C5): s_i updated to 0.000, score=0.4199


  Class 5 (C6): s_i updated to 0.000, score=0.4195
  Class 6 (C7): s_i updated to 0.200, score=0.4193
Additive Swin pass 2/3


  Class 0 (C1): s_i updated to 0.100, score=0.4193
  Class 1 (C2): s_i updated to 0.000, score=0.4193


  Class 2 (C3): s_i updated to 0.050, score=0.4193
  Class 3 (C4): s_i updated to 0.150, score=0.4192


  Class 4 (C5): s_i updated to 0.000, score=0.4192
  Class 5 (C6): s_i updated to 0.000, score=0.4192


  Class 6 (C7): s_i updated to 0.200, score=0.4192
Additive Swin pass 3/3
  Class 0 (C1): s_i updated to 0.100, score=0.4192


  Class 1 (C2): s_i updated to 0.000, score=0.4192
  Class 2 (C3): s_i updated to 0.050, score=0.4192


  Class 3 (C4): s_i updated to 0.150, score=0.4192
  Class 4 (C5): s_i updated to 0.000, score=0.4192


  Class 5 (C6): s_i updated to 0.000, score=0.4192
  Class 6 (C7): s_i updated to 0.200, score=0.4192
T7_new[C1]: 0.9999
T7_new[C2]: 0.7776


T7_new[C3]: 0.9586
T7_new[C4]: 1.0473
T7_new[C5]: 0.8476
T7_new[C6]: 0.8188
T7_new[C7]: 0.9065
Alpha new: 0.8638, WLL: 0.4179
Additive Swin OOF WLL: 0.4171 (3-way baseline 0.4198, gain 0.0027)


IndexError: index 42 is out of bounds for axis 0 with size 40