In [None]:
## 1.2. Imports & Basic Setup
import os
import numpy as np
import pandas as pd
import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader, WeightedRandomSampler
from torch.optim.lr_scheduler import CosineAnnealingLR, OneCycleLR
from sklearn.model_selection import StratifiedGroupKFold
from sklearn.metrics import roc_auc_score
import timm
import albumentations as A
from albumentations.pytorch import ToTensorV2
from tqdm.auto import tqdm
import gc
import importlib
import random

import utils_preproc
importlib.reload(utils_preproc) # Force reload to pick up changes
from utils_preproc import load_and_preprocess

# --- Determinism ---
def seed_everything(seed):
    """Seed every RNG used by this notebook and pin cuDNN to deterministic mode.

    Args:
        seed: Integer seed handed to Python's ``random``, NumPy and PyTorch
            (CPU plus current/all CUDA devices). It is also written to the
            ``PYTHONHASHSEED`` environment variable as a string.
    """
    os.environ['PYTHONHASHSEED'] = str(seed)

    # Each of these callables accepts the seed as its single argument.
    seeders = (
        random.seed,
        np.random.seed,
        torch.manual_seed,
        torch.cuda.manual_seed,
        torch.cuda.manual_seed_all,  # if use multi-GPU
    )
    for apply_seed in seeders:
        apply_seed(seed)

    # Deterministic kernels on, autotuner off.
    cudnn = torch.backends.cudnn
    cudnn.deterministic = True
    cudnn.benchmark = False

# Configuration: Pivoting to a full 5-fold CV run based on successful baseline
class CFG:
    """Central configuration namespace: paths, model choice and training hyperparameters."""

    # ---- Execution control ----
    # NOTE(review): run_single_fold / target_fold are not read by the cells visible here.
    run_single_fold = False
    target_fold = 0
    seed = 42

    # ---- Paths (everything hangs off data_dir) ----
    data_dir = '.'
    train_path = os.path.join(data_dir, 'train')
    train_labels_path = os.path.join(data_dir, 'train_labels.csv')

    # ---- Preprocessing (forwarded to load_and_preprocess) ----
    preprocess_transform_type = 'asinh'
    clip_percentiles = (0.1, 99.9)

    # ---- Model (forwarded to timm.create_model) ----
    model_name = 'tf_efficientnet_b2_ns'
    img_size = 256
    in_channels = 3
    num_classes = 1

    # ---- Training ----
    n_epochs = 15
    batch_size = 32
    n_folds = 5

    # ---- Optimizer & scheduler ----
    lr = 3e-4
    weight_decay = 1e-6
    scheduler_type = 'OneCycleLR'
    one_cycle_pct_start = 0.3
    grad_clip_norm = 1.0  # values <= 0 switch gradient clipping off in train_fn

    # ---- Loss & early stopping ----
    use_sampler = False
    loss_type = 'BCE'  # one of 'BCE', 'BCE_weighted', 'Focal'
    pos_weight_val = 2.0  # only consulted when loss_type == 'BCE_weighted'
    patience = 4  # epochs without AUC improvement before stopping

    # Prefer the GPU when one is available.
    device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')

# --- Apply Seed ---
seed_everything(CFG.seed)

# Log the environment and headline settings, one line each, so the run output is self-describing.
for banner_line in (
    f"Using device: {CFG.device}",
    f"Torch: {torch.__version__}, Timm: {timm.__version__}, Albumentations: {A.__version__}",
    "--- STARTING FULL 5-FOLD CV RUN ---",
    f"Model: {CFG.model_name}, Img Size: {CFG.img_size}",
    f"Loss Type: {CFG.loss_type}",
    f"LR: {CFG.lr}, Epochs: {CFG.n_epochs}",
):
    print(banner_line)

In [None]:
# 2. EDA & Data Preparation

## 2.1. Load Labels and Prepare for CV

df = pd.read_csv(CFG.train_labels_path)

# StratifiedGroupKFold needs a grouping key: the first three characters of each id
# (chosen for a more granular split, as per expert advice).
df['group'] = df['id'].str[:3]

print("Train labels dataframe:")
print(df.head())
print(f"\nShape: {df.shape}")
print(f"\nNumber of unique groups: {df['group'].nunique()}")

print("\nTarget distribution:")
print(df['target'].value_counts(normalize=True))

# Negative/positive ratio, parked on CFG so later cells do not depend on this
# cell's local variables (avoids cell-order bugs).
target_counts = df['target'].value_counts()
neg_count, pos_count = target_counts[0], target_counts[1]
pos_weight_value = neg_count / pos_count
CFG.calculated_pos_weight = float(pos_weight_value)
print(f"\nCalculated positive class weight: {CFG.calculated_pos_weight:.2f}")
print("Stored in CFG.calculated_pos_weight")

def get_train_file_path(image_id):
    """Return the ``.npy`` path for ``image_id`` under ``CFG.train_path``.

    Files are looked up in a sub-directory named after the id's first character.
    """
    bucket = image_id[0]
    return "{}/{}/{}.npy".format(CFG.train_path, bucket, image_id)

# Resolve every id to its on-disk location once, up front.
df['file_path'] = [get_train_file_path(image_id) for image_id in df['id']]

print("\nDataframe with file paths:")
print(df.head())

In [None]:
## 2.2. Dataset & Augmentations

def get_transforms(*, data):
    """Build the Albumentations pipeline for one dataset split.

    Args:
        data: ``'train'`` (resize + random horizontal flip + tensor conversion) or
            ``'valid'`` (resize + tensor conversion).

    Returns:
        An ``A.Compose`` whose last step is ``ToTensorV2``.

    Raises:
        ValueError: If ``data`` is neither ``'train'`` nor ``'valid'``. The function
            used to fall through and return ``None`` in that case, which made
            ``SETIDataset`` silently skip resizing and tensor conversion.
    """
    if data not in ('train', 'valid'):
        raise ValueError(
            f"Unknown data split for transforms: {data!r}. Expected 'train' or 'valid'."
        )

    steps = [A.Resize(CFG.img_size, CFG.img_size)]
    if data == 'train':
        # Per expert advice, HorizontalFlip is enabled for the full CV run.
        steps.append(A.HorizontalFlip(p=0.5))
    steps.append(ToTensorV2())
    return A.Compose(steps)

class SETIDataset(Dataset):
    """Map-style dataset yielding ``(image, label)`` for each row of ``df``.

    Args:
        df: DataFrame providing ``file_path`` and ``target`` columns.
        transform: Optional callable invoked as ``transform(image=...)``; the
            transformed image is read from the ``'image'`` key of its result.
    """

    def __init__(self, df, transform=None):
        self.df = df
        self.transform = transform
        # Pull the two columns out as arrays so items are addressed by position.
        self.file_paths = df['file_path'].values
        self.labels = df['target'].values

    def __len__(self):
        return len(self.df)

    def __getitem__(self, idx):
        # Centralized preprocessing, parameterized from CFG.
        image = load_and_preprocess(
            self.file_paths[idx],
            transform_type=CFG.preprocess_transform_type,
            clip_percentiles=CFG.clip_percentiles,
        )

        # A 3-D array whose leading axis has length 3 is treated as CHW and moved to
        # HWC, the layout Albumentations expects.
        looks_channel_first = image.ndim == 3 and image.shape[0] == 3
        if looks_channel_first:
            image = np.transpose(image, (1, 2, 0))

        # Whatever arrived, it must be (H, W, 3) by now.
        assert image.ndim == 3 and image.shape[2] == 3, f"Unexpected image shape: {image.shape}. Expected (H, W, 3)."

        if self.transform:
            image = self.transform(image=image)['image']

        label = torch.tensor(self.labels[idx]).float()
        return image, label

In [None]:
# 3. Model & Training Functions

## 3.1. Model Definition

class SETIModel(nn.Module):
    """Wrapper around a timm backbone configured from ``CFG``.

    Args:
        model_name: timm architecture name. ``None`` (the default) resolves to
            ``CFG.model_name`` when the model is constructed. The default used to be
            bound to ``CFG.model_name`` at class-definition time, so editing CFG and
            re-running only its cell left this class building the old architecture
            while checkpoint filenames used the new name.
        pretrained: Forwarded to ``timm.create_model``.
    """

    def __init__(self, model_name=None, pretrained=True):
        super().__init__()
        if model_name is None:
            model_name = CFG.model_name
        self.model = timm.create_model(
            model_name,
            pretrained=pretrained,
            in_chans=CFG.in_channels,
            num_classes=CFG.num_classes,
        )

    def forward(self, x):
        """Return the backbone's output for ``x`` unchanged."""
        return self.model(x)

## 3.2. Loss Functions
# As per expert advice, adding FocalLoss for ablation experiments.
class FocalLoss(nn.Module):
    """Binary focal loss computed on raw logits.

    Per element the loss is ``alpha_t * (1 - p_t) ** gamma * BCE``, where ``p_t`` is the
    predicted probability of the true class and
    ``alpha_t = alpha * target + (1 - alpha) * (1 - target)``.

    The previous version multiplied every element by ``alpha`` regardless of its
    class, which only rescaled the loss and gave no class balancing.

    Args:
        alpha: Weight of the positive class; negatives receive ``1 - alpha``.
        gamma: Focusing exponent applied to ``1 - p_t``.
        reduction: ``'mean'`` or ``'sum'``; any other value returns the unreduced,
            element-wise loss.
    """

    def __init__(self, alpha=0.25, gamma=2.0, reduction='mean'):
        super().__init__()
        self.alpha = alpha
        self.gamma = gamma
        self.reduction = reduction

    def forward(self, inputs, targets):
        """Return the focal loss of logits ``inputs`` against same-shaped ``targets``."""
        bce_loss = nn.BCEWithLogitsLoss(reduction='none')(inputs, targets)
        # exp(-BCE) recovers p_t, the probability assigned to the true class.
        pt = torch.exp(-bce_loss)
        # Class-dependent balancing factor: alpha for positives, 1 - alpha for negatives.
        alpha_t = self.alpha * targets + (1 - self.alpha) * (1 - targets)
        focal_loss = alpha_t * (1 - pt) ** self.gamma * bce_loss

        if self.reduction == 'mean':
            return focal_loss.mean()
        if self.reduction == 'sum':
            return focal_loss.sum()
        return focal_loss

## 3.3. Training & Validation Functions (AMP DISABLED for deterministic run)

def train_fn(train_loader, model, criterion, optimizer, scheduler, epoch, device):
    """Run one full-precision (no AMP) training epoch.

    Args:
        train_loader: Iterable of ``(images, labels)`` batches.
        model: Network to optimize; switched to train mode here.
        criterion: Loss called as ``criterion(logits, labels)``.
        optimizer: Optimizer stepped once per batch.
        scheduler: LR scheduler or ``None``; stepped per batch when
            ``CFG.scheduler_type == 'OneCycleLR'``.
        epoch: Zero-based epoch index (shown one-based in the progress bar).
        device: Device the batches are moved to.

    Returns:
        Mean of the per-batch losses for the epoch.
    """
    model.train()
    batch_losses = []

    progress = tqdm(enumerate(train_loader), total=len(train_loader), desc=f'Epoch {epoch+1}')
    for step, (images, labels) in progress:
        images = images.to(device)
        # Add a trailing axis so labels line up with the model's single-logit output.
        labels = labels.to(device).unsqueeze(1)

        # No AMP for this run
        logits = model(images)
        loss = criterion(logits, labels)

        # One-off sanity dump for the very first batch of the first epoch.
        is_first_batch = epoch == 0 and step == 0
        if is_first_batch:
            first_five = torch.sigmoid(logits[:5].detach()).cpu().numpy().flatten()
            print("\n  First batch diagnostics:")
            print(f"    Loss: {loss.item():.4f}")
            print(f"    Labels mean: {labels.float().mean().item():.4f}")
            print(f"    Sigmoid preds (first 5): {first_five}")

        batch_losses.append(loss.item())
        loss.backward()

        # Clip gradients only when a positive max-norm is configured.
        if CFG.grad_clip_norm > 0:
            torch.nn.utils.clip_grad_norm_(model.parameters(), CFG.grad_clip_norm)

        optimizer.step()
        optimizer.zero_grad()

        if CFG.scheduler_type == 'OneCycleLR':
            scheduler.step()

        if torch.cuda.is_available():
            mem = torch.cuda.memory_reserved() / 1E9
        else:
            mem = 0
        if scheduler:
            current_lr = scheduler.get_last_lr()[0]
        else:
            current_lr = optimizer.param_groups[0]['lr']
        progress.set_postfix(loss=f'{np.mean(batch_losses):.4f}', lr=f'{current_lr:.2e}', mem_gb=f'{mem:.2f}')

    return np.mean(batch_losses)

def valid_fn(valid_loader, model, criterion, device):
    """Evaluate ``model`` on ``valid_loader`` with gradients disabled (no AMP).

    Args:
        valid_loader: Iterable of ``(images, labels)`` batches.
        model: Network to evaluate; switched to eval mode here.
        criterion: Loss called as ``criterion(logits, labels)``.
        device: Device the batches are moved to.

    Returns:
        Tuple ``(mean_loss, val_auc, predictions, targets)``. ``predictions`` are the
        sigmoid of the model outputs and ``targets`` the labels, both flattened NumPy
        arrays in loader order.
    """
    model.eval()
    batch_losses, pred_chunks, target_chunks = [], [], []

    progress = tqdm(enumerate(valid_loader), total=len(valid_loader), desc='Validating')
    with torch.no_grad():
        for _, (images, labels) in progress:
            images = images.to(device)
            labels = labels.to(device).unsqueeze(1)

            # No AMP for this run
            logits = model(images)
            batch_losses.append(criterion(logits, labels).item())

            # Collect probabilities and labels on the CPU for the AUC computed below.
            pred_chunks.append(torch.sigmoid(logits).cpu().numpy())
            target_chunks.append(labels.cpu().numpy())

            if torch.cuda.is_available():
                mem = torch.cuda.memory_reserved() / 1E9
            else:
                mem = 0
            progress.set_postfix(loss=f'{np.mean(batch_losses):.4f}', mem_gb=f'{mem:.2f}')

    predictions = np.concatenate(pred_chunks).flatten()
    targets = np.concatenate(target_chunks).flatten()

    # Quick look at the spread of the predicted probabilities.
    print(f"  Validation preds stats: Min={predictions.min():.4f}, Mean={predictions.mean():.4f}, Max={predictions.max():.4f}")

    val_auc = roc_auc_score(targets, predictions)
    return np.mean(batch_losses), val_auc, predictions, targets

In [None]:
# 4. Main Training Loop (Simplified for Ablation)
def _build_criterion(loss_type):
    """Return the loss module selected by ``loss_type`` and print which one was chosen.

    Raises:
        ValueError: If ``loss_type`` is not ``'BCE'``, ``'BCE_weighted'`` or ``'Focal'``.
    """
    if loss_type == 'BCE':
        criterion = nn.BCEWithLogitsLoss()
        print("Using plain BCEWithLogitsLoss.")
        return criterion
    if loss_type == 'BCE_weighted':
        pos_weight_tensor = torch.tensor(CFG.pos_weight_val, device=CFG.device)
        criterion = nn.BCEWithLogitsLoss(pos_weight=pos_weight_tensor)
        print(f"Using BCEWithLogitsLoss with pos_weight: {CFG.pos_weight_val:.2f}")
        return criterion
    if loss_type == 'Focal':
        criterion = FocalLoss(alpha=0.25, gamma=2.0)
        print("Using FocalLoss (alpha=0.25, gamma=2.0).")
        return criterion
    raise ValueError(f"Unknown or unsupported loss_type for this run: {loss_type}")


def run_fold(fold, df):
    """Train and validate one CV fold, checkpointing the epoch with the best AUC.

    Args:
        fold: Value of ``df['fold']`` to hold out for validation.
        df: DataFrame with ``id``, ``target``, ``file_path`` and ``fold`` columns.

    Returns:
        Tuple ``(best_score, fold_oof_df)``: the best validation AUC seen, and a frame
        with ``id`` / ``target`` / ``preds`` for the held-out rows at that epoch
        (``None`` if no epoch beat the initial score of 0).

    Raises:
        ValueError: If ``fold`` leaves the training or validation split empty, or if
            ``CFG.loss_type`` is not supported.

    Side effects:
        Deletes any existing checkpoint for this fold and any ``oof_predictions.csv``
        in the working directory, then writes the best checkpoint to
        ``{CFG.model_name}_fold{fold}_best.pth``.
    """
    print(f"========== FOLD {fold} TRAINING ==========")

    # Split first, so that a bad fold id fails here, before anything on disk is
    # deleted, instead of crashing later inside the sampler or np.concatenate.
    train_df = df[df['fold'] != fold].reset_index(drop=True)
    valid_df = df[df['fold'] == fold].reset_index(drop=True)
    if train_df.empty or valid_df.empty:
        raise ValueError(
            f"Fold {fold} yields an empty split "
            f"(train rows: {len(train_df)}, valid rows: {len(valid_df)})."
        )

    # --- Clean up stale artifacts before run ---
    model_path = f'{CFG.model_name}_fold{fold}_best.pth'
    if os.path.exists(model_path):
        print(f"Removing stale model checkpoint: {model_path}")
        os.remove(model_path)
    # NOTE(review): this removes the combined OOF file on every fold, including
    # single-fold recovery runs; confirm that is intended.
    if os.path.exists('oof_predictions.csv'):
        print("Removing stale oof_predictions.csv")
        os.remove('oof_predictions.csv')

    print(f"Fold {fold} Train Target Distribution:\n{train_df['target'].value_counts(normalize=True)}")
    print(f"Fold {fold} Valid Target Distribution:\n{valid_df['target'].value_counts(normalize=True)}")

    # Create datasets
    train_dataset = SETIDataset(train_df, transform=get_transforms(data='train'))
    valid_dataset = SETIDataset(valid_df, transform=get_transforms(data='valid'))

    # --- Dataloaders (simplified for deterministic run) ---
    def seed_worker(worker_id):
        # Derive the NumPy / random seeds of a worker from its torch seed.
        worker_seed = torch.initial_seed() % 2**32
        np.random.seed(worker_seed)
        random.seed(worker_seed)

    # A dedicated, seeded generator fixes the shuffle order of the training loader.
    g = torch.Generator()
    g.manual_seed(CFG.seed)

    # Sampler is disabled, so shuffle=True
    train_loader = DataLoader(train_dataset, batch_size=CFG.batch_size, shuffle=True, num_workers=0, pin_memory=True, worker_init_fn=seed_worker, generator=g)
    valid_loader = DataLoader(valid_dataset, batch_size=CFG.batch_size * 2, shuffle=False, num_workers=0, pin_memory=True)

    # Init model, optimizer, scheduler
    model = SETIModel().to(CFG.device)
    optimizer = torch.optim.AdamW(model.parameters(), lr=CFG.lr, weight_decay=CFG.weight_decay)

    if CFG.scheduler_type == 'OneCycleLR':
        scheduler = OneCycleLR(optimizer, max_lr=CFG.lr, epochs=CFG.n_epochs, steps_per_epoch=len(train_loader), pct_start=CFG.one_cycle_pct_start)
    else:
        scheduler = None

    # --- Loss Function Ablation ---
    criterion = _build_criterion(CFG.loss_type)

    best_score = 0.
    patience_counter = 0
    fold_oof_df = None

    for epoch in range(CFG.n_epochs):
        train_loss = train_fn(train_loader, model, criterion, optimizer, scheduler, epoch, CFG.device)
        valid_loss, val_auc, predictions, _ = valid_fn(valid_loader, model, criterion, CFG.device)

        print(f"Epoch {epoch+1}: Train Loss={train_loss:.4f}, Valid Loss={valid_loss:.4f}, Valid AUC={val_auc:.4f}")

        if val_auc > best_score:
            best_score = val_auc
            patience_counter = 0
            print(f"==> New best score: {best_score:.4f}. Saving model and OOF preds.")
            torch.save(model.state_dict(), model_path)
            # Snapshot the held-out predictions that belong to this checkpoint.
            temp_df = valid_df.copy()
            temp_df['preds'] = predictions
            fold_oof_df = temp_df[['id', 'target', 'preds']]
        else:
            patience_counter += 1
            print(f"Score not improved. Patience: {patience_counter}/{CFG.patience}")

        if patience_counter >= CFG.patience:
            print("Early stopping triggered.")
            break

    if fold_oof_df is not None:
        # Reload the checkpoint and re-score it as a sanity check on what was saved.
        print("\nVerifying best model checkpoint...")
        model.load_state_dict(torch.load(model_path, map_location=CFG.device))
        _, recomputed_auc, _, _ = valid_fn(valid_loader, model, criterion, CFG.device)
        print(f"  Best recorded AUC: {best_score:.4f}")
        print(f"  Recomputed AUC on loaded model: {recomputed_auc:.4f}")

    # Drop the references and release cached GPU memory before the next fold.
    del model, train_loader, valid_loader, train_dataset, valid_dataset, optimizer, scheduler, criterion
    gc.collect()
    torch.cuda.empty_cache()

    return best_score, fold_oof_df

# --- Prepare CV Folds (Load or Create) ---
folds_csv_path = 'folds.csv'

# Start from a frame without a 'fold' column so re-running this cell is safe
# (a leftover column from a previous run broke the merge below).
if 'fold' in df.columns:
    print("Dropping existing 'fold' column from main dataframe to prevent merge error.")
    df = df.drop(columns=['fold'])

# Attempt to load and merge folds
if os.path.exists(folds_csv_path):
    print("Loading folds from folds.csv")
    folds_df = pd.read_csv(folds_csv_path)
    if 'fold' not in folds_df.columns:
        print("WARNING: 'fold' column missing in folds.csv. Recreating folds.")
        os.remove(folds_csv_path) # Delete bad file, will trigger recreation below
    else:
        # Merge only the fold assignment: any other column shared with df would be
        # suffixed (_x / _y) by merge and break later lookups such as df['target'].
        df = df.merge(folds_df[['id', 'fold']], on='id', how='left')
        if df['fold'].isnull().any():
            print("WARNING: Mismatch or nulls found in folds after merge. Recreating folds.")
            df = df.drop(columns=['fold']) # Drop the newly merged, bad column
            os.remove(folds_csv_path) # Delete bad file

# If folds.csv didn't exist or was deleted, create it
if not os.path.exists(folds_csv_path):
    print("Creating new folds and saving to folds.csv")
    df['fold'] = -1
    fold_col = df.columns.get_loc('fold')
    skf = StratifiedGroupKFold(n_splits=CFG.n_folds, shuffle=True, random_state=CFG.seed)
    for fold, (train_idx, val_idx) in enumerate(skf.split(df, df['target'], df['group'])):
        # split() yields positional indices, so assign positionally; label-based
        # .loc is only equivalent while df carries a default RangeIndex.
        df.iloc[val_idx, fold_col] = fold
    df[['id', 'fold']].to_csv(folds_csv_path, index=False)

df['fold'] = df['fold'].astype(int)

# --- Run Training ---
all_oof_dfs = []
fold_scores = []

# RECOVERY MODE: Manually specify folds to run to recover from previous failure.
folds_to_run = [1]
print(f"RECOVERY MODE: Will only run training for folds: {folds_to_run}")

for fold in folds_to_run:
    score, oof_df_fold = run_fold(fold, df)
    fold_scores.append(score)
    # A fold that never improved returns no OOF frame; nothing to collect then.
    if oof_df_fold is None:
        continue
    all_oof_dfs.append(oof_df_fold)

# --- Summarize Results ---
if not all_oof_dfs:
    print("\nTraining did not produce any valid OOF predictions.")
else:
    oof_df = pd.concat(all_oof_dfs).reset_index(drop=True)
    # Note: This OOF score will only be for the folds that were run in this session.
    oof_auc = roc_auc_score(oof_df['target'], oof_df['preds'])

    print("\n========== PARTIAL CV SUMMARY ==========")
    print(f"Fold scores (best epoch) for folds {folds_to_run}: {fold_scores}")
    print(f"Mean Fold Score for this run: {np.mean(fold_scores):.4f}")
    print(f"OOF AUC for this run: {oof_auc:.4f}")

    oof_df.to_csv('oof_predictions_recovery.csv', index=False)
    print("\nOOF predictions for this run saved to oof_predictions_recovery.csv")

In [None]:
import os
# Debugging aid: show what is in the working directory (listdir defaults to '.').
print(os.listdir())

In [None]:
# 5. Verify OOF Score

import pandas as pd
from sklearn.metrics import roc_auc_score

# NOTE(review): the training cell visible in this notebook writes
# 'oof_predictions_recovery.csv', and run_fold deletes 'oof_predictions.csv';
# confirm which file this check is meant to read.
try:
    oof_df = pd.read_csv('oof_predictions.csv')
except FileNotFoundError:
    print("'oof_predictions.csv' not found. Cannot verify score.")
else:
    # Re-score the saved predictions independently of the training session.
    oof_auc = roc_auc_score(oof_df['target'], oof_df['preds'])
    print(f"Verified Overall OOF AUC from 'oof_predictions.csv': {oof_auc:.4f}")

In [None]:
# 6. Manual Verification of Fold 0
# --- Recreate validation set for the fold being verified ---
fold_to_verify = 0
print(f"Manually verifying the score for the last completed run (Fold {fold_to_verify})...")

valid_idx = df[df['fold'] == fold_to_verify].index
valid_df = df.loc[valid_idx].reset_index(drop=True)
valid_dataset = SETIDataset(valid_df, transform=get_transforms(data='valid'))
valid_loader = DataLoader(valid_dataset, batch_size=CFG.batch_size * 2, shuffle=False, num_workers=0, pin_memory=True)

# --- Load the best model from the last run ---
# Every weight is overwritten by the checkpoint, so skip loading pretrained weights.
model = SETIModel(pretrained=False).to(CFG.device)
model_path = f'{CFG.model_name}_fold{fold_to_verify}_best.pth'
try:
    # map_location lets a checkpoint saved on another device load onto CFG.device.
    checkpoint = torch.load(model_path, map_location=CFG.device)
except FileNotFoundError:
    print(f"Model checkpoint not found at: {model_path}")
else:
    # Only the checkpoint read is guarded: a missing data file during validation
    # must surface as its own error, not be reported as a missing checkpoint.
    model.load_state_dict(checkpoint)
    del checkpoint
    print(f"Successfully loaded model from: {model_path}")

    # --- Re-run validation ---
    criterion = nn.BCEWithLogitsLoss() # The criterion doesn't affect AUC calculation during validation
    _, val_auc, _, _ = valid_fn(valid_loader, model, criterion, CFG.device)
    print("\n========== MANUAL VERIFICATION RESULT ==========")
    print(f"Recomputed AUC for Fold {fold_to_verify}: {val_auc:.4f}")

# --- Clean up ---
del model, valid_loader, valid_dataset, valid_df
gc.collect()
torch.cuda.empty_cache()