In [None]:
"""
Train a CNN for Kaggle's Histopathologic Cancer Detection.
- Dataset: https://www.kaggle.com/competitions/histopathologic-cancer-detection
- Task: predict the probability that the CENTER 32x32 region contains tumor tissue.
- Metric: ROC-AUC on the test set (submit probabilities).

This script is Kaggle-ready:
  * Reads train_labels.csv and images from /kaggle/input/histopathologic-cancer-detection/
  * Stratified split, weighted loss, mixed precision, early stopping
  * Optional center-emphasis mask (Gaussian) to bias the model toward the 32x32 center
  * Supports a simple CNN or a pretrained ResNet18 backbone
  * Writes submission.csv with id, label (probabilities)

Run (Kaggle Notebook / terminal cell):
  !python train_cancer_cnn.py --epochs 8 --model resnet18 --img-size 128 --batch-size 256 --center-emphasis 1 --tta 4

Tip: Increase epochs (e.g., 15–20) and tune LR/augmentations for better scores.
"""
from __future__ import annotations
import argparse
import os
import math
import random
import sys
import time
from dataclasses import dataclass

import numpy as np
import pandas as pd
from PIL import Image

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader, WeightedRandomSampler
from torchvision import transforms, models

from sklearn.metrics import roc_auc_score
from sklearn.model_selection import train_test_split

import gc
import pdb
from tqdm import tqdm

In [None]:
# -----------------------------
# Utilities
# -----------------------------

def set_seed(seed: int = 42):
    """Seed every RNG this notebook relies on and pin cuDNN to deterministic mode.

    Args:
        seed: Value handed to Python's ``random``, NumPy, and PyTorch
            (CPU generator and all CUDA devices).
    """
    # Same seeding order as always: stdlib, NumPy, PyTorch CPU, PyTorch CUDA.
    for seeder in (random.seed, np.random.seed, torch.manual_seed, torch.cuda.manual_seed_all):
        seeder(seed)

    # Deterministic kernels on, autotuner off.
    cudnn = torch.backends.cudnn
    cudnn.deterministic = True
    cudnn.benchmark = False


In [None]:
# -----------------------------
# Dataset
# -----------------------------
class HistopathDataset(Dataset):
    """Dataset of ``<id>.tif`` images listed in a DataFrame.

    Every row of ``df`` needs an ``id`` column; the image is read from
    ``img_root/<id>.tif`` and converted to RGB. If ``df`` also has a ``label``
    column, items are ``(image, label)`` with the label as a float32 scalar
    tensor; otherwise items are ``(image, id)``.
    """

    def __init__(self, df: pd.DataFrame, img_root: str, transform=None):
        # Positional index so that iloc[idx] lines up with DataLoader indices.
        self.df = df.reset_index(drop=True)
        self.img_root = img_root
        self.transform = transform
        self.has_label = 'label' in self.df.columns

    def __len__(self):
        return len(self.df)

    def __getitem__(self, idx):
        record = self.df.iloc[idx]
        sample_id = record['id']
        path = os.path.join(self.img_root, f"{sample_id}.tif")

        # The file handle is closed on exit; the converted image is what we keep.
        with Image.open(path) as handle:
            image = handle.convert('RGB')
            if self.transform is not None:
                image = self.transform(image)

        if not self.has_label:
            return image, sample_id
        return image, torch.tensor(record['label'], dtype=torch.float32)

In [None]:
class CenterEmphasis:
    """Multiply an image tensor by a fixed 2D Gaussian gain mask peaking at the center.

    The mask is built once for a ``size`` x ``size`` image and rescaled so its
    values lie in ``[min_gain, 1]``. It was written for ``ToTensor()`` output
    (CHW, values in [0, 1]); the multiplication itself does not check the value
    range of what it receives.

    Args:
        size: Side length the mask is precomputed for.
        sigma_frac: Gaussian sigma as a fraction of ``size``.
        min_gain: Gain applied where the Gaussian is smallest (the corners).
    """

    def __init__(self, size: int, sigma_frac: float = 0.30, min_gain: float = 0.6):
        self.size = size
        self.sigma = sigma_frac * size
        self.min_gain = min_gain
        self.mask = self._make_mask(size)

    def _make_mask(self, s):
        """Return an ``s`` x ``s`` gain map (HxW) with values in ``[min_gain, 1]``."""
        # Pixel coordinates centered on zero.
        xs = torch.linspace(-(s-1)/2, (s-1)/2, s)
        yy, xx = torch.meshgrid(xs, xs, indexing='ij')
        rr2 = (xx**2 + yy**2)
        gauss = torch.exp(-rr2 / (2 * (self.sigma**2)))
        # Min-max normalize to [0, 1]; the epsilon guards a constant mask (s == 1).
        gauss = (gauss - gauss.min()) / (gauss.max() - gauss.min() + 1e-8)
        gauss = (1 - self.min_gain) * gauss + self.min_gain  # in [min_gain,1]
        return gauss  # HxW

    def __call__(self, x: torch.Tensor):
        """Apply the gain mask to ``x``; non-tensor inputs are returned unchanged.

        For tensors with three or more dims (``CxHxW`` or ``BxCxHxW``) whose
        spatial size differs from the precomputed mask in either height or
        width, the mask is bilinearly resized to ``x.shape[-2:]`` first.
        """
        if not torch.is_tensor(x):
            return x
        mask = self.mask.to(x.device)
        # Compare BOTH spatial dims: checking only the width misses non-square
        # inputs and height-only mismatches, which then fail on broadcast.
        if x.ndim >= 3 and tuple(x.shape[-2:]) != tuple(mask.shape):
            # Index instead of squeeze() so singleton spatial dims survive.
            mask = F.interpolate(
                mask[None, None], size=x.shape[-2:], mode='bilinear', align_corners=False
            )[0, 0]
        return x * mask.unsqueeze(0)


def build_transforms(img_size: int, center_emphasis: bool, training: bool):
    """Compose the preprocessing pipeline for one split.

    Args:
        img_size: Output side length of the crop/resize stage.
        center_emphasis: When true, append ``CenterEmphasis`` as the last step.
        training: Selects random augmentations (True) or a deterministic
            resize + center crop (False).

    Returns:
        A ``transforms.Compose`` that ends in a normalized tensor, optionally
        followed by the center-emphasis mask.
    """
    if training:
        spatial = [
            transforms.RandomResizedCrop(img_size, scale=(0.85, 1.0), ratio=(0.9, 1.1)),
            transforms.RandomHorizontalFlip(),
            transforms.RandomVerticalFlip(),
            transforms.RandomRotation(15),
            transforms.ColorJitter(brightness=0.15, contrast=0.15, saturation=0.10, hue=0.02),
        ]
    else:
        spatial = [transforms.Resize(img_size), transforms.CenterCrop(img_size)]

    # ImageNet channel statistics, matching the pretrained backbone option.
    to_normalized_tensor = [
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
    ]

    # NOTE(review): the mask is applied AFTER Normalize, while CenterEmphasis
    # documents a [0,1] input. On normalized values the gain pulls border pixels
    # toward the normalization mean rather than toward black -- confirm intent.
    emphasis = (
        [CenterEmphasis(size=img_size, sigma_frac=0.30, min_gain=0.6)]
        if center_emphasis
        else []
    )

    return transforms.Compose(spatial + to_normalized_tensor + emphasis)

In [None]:
# -----------------------------
# Focal Loss implementation
# -----------------------------
class FocalLoss(nn.Module):
    """Binary focal loss computed from raw logits.

    Per-element loss is ``alpha * (1 - p_t) ** gamma * BCE``, where ``p_t`` is
    the predicted probability of the true class (targets equal to 1 count as
    the positive class).

    Args:
        alpha: Constant multiplier on the focal weight.
        gamma: Focusing exponent.
        reduction: ``'mean'`` or ``'sum'``; any other value returns the
            unreduced per-element loss.
    """

    def __init__(self, alpha=1.0, gamma=2.0, reduction='mean'):
        super().__init__()
        self.alpha, self.gamma, self.reduction = alpha, gamma, reduction

    def forward(self, inputs, targets):
        """Return the focal loss for logits ``inputs`` and labels ``targets`` (same shape)."""
        per_element_bce = F.binary_cross_entropy_with_logits(inputs, targets, reduction='none')

        # Probability the model assigns to whichever class is the true one.
        p_positive = torch.sigmoid(inputs)
        p_true = torch.where(targets == 1, p_positive, 1 - p_positive)

        loss = (self.alpha * (1 - p_true) ** self.gamma) * per_element_bce

        if self.reduction == 'mean':
            return loss.mean()
        if self.reduction == 'sum':
            return loss.sum()
        return loss

In [None]:
# -----------------------------
# Models
# -----------------------------
class SimpleCNN(nn.Module):
    """Four-stage VGG-style CNN that emits one logit per image.

    Each stage is two ``Conv3x3 -> BatchNorm -> ReLU`` units at widths
    32, 64, 128, 256. The first three stages end in a 2x2 max-pool; the last
    ends in global average pooling. The head is Flatten -> Dropout(0.2) ->
    Linear(256, 1). ``forward`` returns shape ``[B]``.
    """

    def __init__(self, in_ch=3):
        super().__init__()
        widths = (32, 64, 128, 256)

        # Kept as ONE flat Sequential, built in the original order, so module
        # indices (and therefore state_dict keys such as "features.21.weight")
        # stay compatible with previously saved checkpoints.
        layers = []
        channels = in_ch
        for stage, width in enumerate(widths):
            for _ in range(2):
                layers.extend([
                    nn.Conv2d(channels, width, 3, padding=1),
                    nn.BatchNorm2d(width),
                    nn.ReLU(inplace=True),
                ])
                channels = width
            is_last_stage = stage == len(widths) - 1
            layers.append(nn.AdaptiveAvgPool2d(1) if is_last_stage else nn.MaxPool2d(2))
        self.features = nn.Sequential(*layers)

        self.head = nn.Sequential(
            nn.Flatten(),
            nn.Dropout(0.2),
            nn.Linear(256, 1),
        )

    def forward(self, x):
        logits = self.head(self.features(x))
        return logits.squeeze(1)


def build_model(name: str = "resnet18") -> nn.Module:
    """Construct the network selected by ``name`` (case-insensitive).

    Args:
        name: ``"resnet18"`` for a torchvision ResNet-18 whose final layer is
            replaced by a single-output ``Linear``, or ``"simple_cnn"`` for
            ``SimpleCNN``.

    Raises:
        ValueError: For any other name.
    """
    key = name.lower()

    if key == 'simple_cnn':
        return SimpleCNN()
    if key != 'resnet18':
        raise ValueError(f"Unknown model: {key}")

    # Use the ImageNet weights enum when this torchvision exposes it.
    try:
        pretrained = models.ResNet18_Weights.IMAGENET1K_V1
    except Exception:
        pretrained = None

    backbone = models.resnet18(weights=pretrained)
    # Swap the 1000-class classifier for a single logit.
    backbone.fc = nn.Linear(backbone.fc.in_features, 1)
    return backbone

In [None]:
# -----------------------------
# Training / Validation
# -----------------------------

def make_loaders(cfg: Config):
    """Build the train/val/test DataLoaders and the positive-class loss weight.

    Reads ``cfg.labels_csv`` under ``cfg.data_dir``, drops duplicate ids, makes
    a stratified train/validation split, and lists ``*.tif`` files in the test
    directory (sorted, so the prediction order is reproducible).

    Args:
        cfg: Run configuration; uses the path fields plus ``val_size``,
            ``seed``, ``img_size``, ``center_emphasis``, ``batch_size``,
            ``num_workers`` and ``use_weighted_sampler``.

    Returns:
        ``(train_loader, val_loader, test_loader, pos_weight)`` where
        ``pos_weight`` is a 1-element float32 tensor holding
        ``n_negative / n_positive`` of the training split.
    """
    labels_path = os.path.join(cfg.data_dir, cfg.labels_csv)
    train_img_root = os.path.join(cfg.data_dir, cfg.train_dir)
    test_img_root = os.path.join(cfg.data_dir, cfg.test_dir)

    df = pd.read_csv(labels_path)
    # Basic sanity checks
    df = df.drop_duplicates(subset=['id'])

    train_df, val_df = train_test_split(
        df, test_size=cfg.val_size, random_state=cfg.seed, stratify=df['label'])

    t_train = build_transforms(cfg.img_size, bool(cfg.center_emphasis), training=True)
    t_val = build_transforms(cfg.img_size, bool(cfg.center_emphasis), training=False)

    ds_train = HistopathDataset(train_df, img_root=train_img_root, transform=t_train)
    ds_val = HistopathDataset(val_df, img_root=train_img_root, transform=t_val)

    # Pinned host memory only helps host->GPU copies; skip it on CPU-only runs.
    loader_kwargs = dict(
        batch_size=cfg.batch_size,
        num_workers=cfg.num_workers,
        pin_memory=torch.cuda.is_available(),
    )

    if cfg.use_weighted_sampler:
        # Inverse class frequency per sample. Mapping from the counts keeps this
        # correct whatever label values are present, instead of assuming that
        # positions 0/1 of a sorted count array correspond to labels 0/1.
        class_counts = train_df['label'].value_counts()
        samples_weight = (1.0 / train_df['label'].map(class_counts)).to_numpy(dtype=float)
        sampler = WeightedRandomSampler(samples_weight, num_samples=len(samples_weight), replacement=True)
        train_loader = DataLoader(ds_train, sampler=sampler, **loader_kwargs)
    else:
        train_loader = DataLoader(ds_train, shuffle=True, **loader_kwargs)

    val_loader = DataLoader(ds_val, shuffle=False, **loader_kwargs)

    # Test loader (ids only). os.listdir order is arbitrary, so sort the ids.
    test_ids = pd.DataFrame({
        'id': sorted(os.path.splitext(p)[0] for p in os.listdir(test_img_root) if p.endswith('.tif'))
    })
    ds_test = HistopathDataset(test_ids, img_root=test_img_root, transform=t_val)
    test_loader = DataLoader(ds_test, shuffle=False, **loader_kwargs)

    # pos_weight for BCEWithLogitsLoss
    # NOTE(review): this is returned even when the weighted sampler is enabled;
    # fit() passes it to the loss, so class imbalance is then compensated twice
    # -- confirm that is intended.
    n_pos = (train_df['label'] == 1).sum()
    n_neg = (train_df['label'] == 0).sum()
    pos_weight = torch.tensor([n_neg / max(1, n_pos)], dtype=torch.float32)

    return train_loader, val_loader, test_loader, pos_weight


class EarlyStopper:
    """Track a higher-is-better metric and raise ``stop`` after too many stale steps.

    Args:
        patience: Number of consecutive non-improving steps that sets ``stop``.
        min_delta: A value must exceed ``best + min_delta`` to count as an
            improvement.
    """

    def __init__(self, patience=3, min_delta=0.0):
        self.patience = patience
        self.min_delta = min_delta
        self.counter = 0
        self.best = -np.inf
        self.stop = False

    def step(self, value):
        """Record one metric value, updating ``best``, ``counter`` and ``stop``."""
        improved = value > self.best + self.min_delta
        if improved:
            self.best, self.counter = value, 0
            return

        # No improvement (this includes NaN, which never compares greater).
        self.counter += 1
        if self.counter >= self.patience:
            self.stop = True


def train_one_epoch(model, loader, criterion, optimizer, scaler, device):
    """Run one optimization pass over ``loader`` and return the mean sample loss.

    Args:
        model: Network producing one logit per sample.
        loader: Yields ``(inputs, float targets)`` batches.
        criterion: Loss taking ``(logits[B], targets[B])``.
        optimizer: Optimizer stepped through ``scaler``.
        scaler: Gradient scaler used for loss scaling and the optimizer step.
        device: ``torch.device`` the batches are moved to.

    Returns:
        Loss averaged over the samples actually processed this epoch.
    """
    model.train()

    # fp16 autocast is only meaningful together with loss scaling, which is a
    # CUDA feature here. On CPU run in full precision rather than in unscaled
    # half precision (the dtype is irrelevant when autocast is disabled).
    use_amp = device.type == 'cuda'
    autocast_device = 'cuda' if use_amp else 'cpu'
    autocast_dtype = torch.float16 if use_amp else torch.bfloat16

    running_loss = 0.0
    n_seen = 0
    for x, y in tqdm(loader, desc='train'):
        x = x.to(device, non_blocking=True)
        y = y.to(device, non_blocking=True)
        optimizer.zero_grad(set_to_none=True)
        with torch.autocast(device_type=autocast_device, dtype=autocast_dtype, enabled=use_amp):
            # Flatten so [B, 1] (ResNet head) and [B] (SimpleCNN) both match y.
            logits = model(x).view(-1)
            loss = criterion(logits, y)
        scaler.scale(loss).backward()
        scaler.step(optimizer)
        scaler.update()

        batch_size = x.size(0)
        running_loss += loss.item() * batch_size
        n_seen += batch_size

    # Average over samples seen, which stays correct if the loader drops or
    # resamples items; max() guards an empty loader.
    return running_loss / max(1, n_seen)


def evaluate(model, loader, criterion, device):
    """Compute mean loss and ROC-AUC of ``model`` over ``loader`` without gradients.

    Args:
        model: Network producing one logit per sample.
        loader: Yields ``(inputs, float targets)`` batches.
        criterion: Loss taking ``(logits[B], targets[B])``.
        device: ``torch.device`` the batches are moved to.

    Returns:
        ``(mean_loss, auc)``. ``auc`` is NaN when scikit-learn rejects the
        inputs with ``ValueError``; the reason is printed. Other exceptions
        propagate instead of being silently turned into NaN.
    """
    model.eval()
    running_loss = 0.0
    all_probs = []
    all_targets = []
    with torch.no_grad():
        for x, y in tqdm(loader, desc='val'):
            x = x.to(device, non_blocking=True)
            y = y.to(device, non_blocking=True)
            # Flatten so [B, 1] and [B] model outputs both match y.
            logits = model(x).view(-1)
            loss = criterion(logits, y)
            running_loss += loss.item() * x.size(0)
            all_probs.append(torch.sigmoid(logits).detach().cpu().numpy())
            all_targets.append(y.detach().cpu().numpy())
    all_probs = np.concatenate(all_probs)
    all_targets = np.concatenate(all_targets)

    try:
        auc = roc_auc_score(all_targets, all_probs)
    except ValueError as err:
        # Report why the metric is undefined so a NaN in the log is explained.
        print(f"  ! ROC-AUC is undefined for this validation pass: {err}")
        auc = float('nan')

    # Average over the samples that were actually evaluated.
    return running_loss / len(all_targets), auc


def cosine_lr(optimizer, base_lr, epoch, max_epoch, warmup=1):
    """Set and return the learning rate for ``epoch`` (0-based).

    Linear warmup over the first ``warmup`` epochs, then half-cosine decay
    from ``base_lr`` over the remaining ``max_epoch - warmup`` epochs. The
    value is written into every param group of ``optimizer``.
    """
    if epoch >= warmup:
        # Fraction of the decay phase completed so far.
        progress = (epoch - warmup) / max(1, (max_epoch - warmup))
        lr = 0.5 * base_lr * (1 + math.cos(math.pi * progress))
    else:
        lr = base_lr * (epoch + 1) / max(1, warmup)

    for group in optimizer.param_groups:
        group.update(lr=lr)
    return lr


# -----------------------------
# Fit loop / Inference
# -----------------------------

def fit(cfg: Config):
    """Train, keep the best-AUC checkpoint, then write ``submission.csv``.

    Steps: seed RNGs, build loaders and model, optionally resume from
    ``<out_dir>/best_<model>.pt``, train with a warmup + cosine learning-rate
    schedule and early stopping on validation AUC, reload the best checkpoint,
    and predict the test set with test-time augmentation.

    Args:
        cfg: Run configuration.
    """
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    print(f"Using device: {device}")

    set_seed(cfg.seed)
    # Checkpoint and submission are written here; create it before training so
    # a long run cannot fail at its first save.
    os.makedirs(cfg.out_dir, exist_ok=True)
    train_loader, val_loader, test_loader, pos_weight = make_loaders(cfg)

    # .to(device) is the only placement needed (the old extra ``model.cuda()``
    # was guarded by an undefined name and raised NameError).
    model = build_model(cfg.model).to(device)

    criterion = nn.BCEWithLogitsLoss(pos_weight=pos_weight.to(device))
    # criterion = FocalLoss(alpha=1.0, gamma=2.0)
    optimizer = torch.optim.AdamW(model.parameters(), lr=cfg.lr, weight_decay=cfg.weight_decay)
    # Loss scaling is only active for CUDA training; disable it explicitly elsewhere.
    scaler = torch.amp.GradScaler('cuda', enabled=device.type == 'cuda')

    stopper = EarlyStopper(patience=3, min_delta=1e-4)

    best_auc = -np.inf
    best_path = os.path.join(cfg.out_dir, f"best_{cfg.model}.pt")

    start_epoch = 0
    # -----------------------------
    # Resume training if checkpoint exists
    # -----------------------------
    if os.path.exists(best_path):
        print(f"Found checkpoint {best_path}, loading...")
        # NOTE: weights_only=False unpickles arbitrary objects; only resume
        # from checkpoints this notebook produced.
        ckpt = torch.load(best_path, map_location=device, weights_only=False)
        model.load_state_dict(ckpt['model'])
        if 'optimizer' in ckpt:
            optimizer.load_state_dict(ckpt['optimizer'])
        # A scaler saved while disabled has an empty state, which an enabled
        # scaler refuses to load; skip it in that case.
        if ckpt.get('scaler'):
            scaler.load_state_dict(ckpt['scaler'])
        if 'epoch' in ckpt:
            start_epoch = ckpt['epoch'] + 1
        if 'best_auc' in ckpt:
            best_auc = ckpt['best_auc']
            # Early stopping should measure progress against the restored best.
            stopper.best = best_auc
        print(f"Resumed from epoch {start_epoch}, best AUC {best_auc:.6f}")
    else:
        print(f"Not found checkpoint")

    for epoch in range(start_epoch, cfg.epochs):
        lr_now = cosine_lr(optimizer, cfg.lr, epoch, cfg.epochs, warmup=cfg.warmup_epochs)
        t0 = time.time()
        tr_loss = train_one_epoch(model, train_loader, criterion, optimizer, scaler, device)
        va_loss, va_auc = evaluate(model, val_loader, criterion, device)
        dt = time.time() - t0
        print(f"Epoch {epoch+1}/{cfg.epochs} | lr {lr_now:.3e} | train {tr_loss:.4f} | val {va_loss:.4f} | AUC {va_auc:.6f} | {dt:.1f}s")
        if va_auc > best_auc:
            best_auc = va_auc
            torch.save({
                'model': model.state_dict(),
                'optimizer': optimizer.state_dict(),
                'scaler': scaler.state_dict(),
                'cfg': vars(cfg),
                'epoch': epoch,
                # Store a plain Python float rather than a NumPy scalar.
                'best_auc': float(best_auc)
            }, best_path)
            print(f"  * Saved new best model to {best_path}")
        stopper.step(va_auc)
        if stopper.stop:
            print("Early stopping triggered.")
            break

    # Load best. No checkpoint exists if no epoch ran or no epoch produced a
    # comparable AUC; fall back to the in-memory weights instead of crashing.
    if os.path.exists(best_path):
        ckpt = torch.load(best_path, map_location=device, weights_only=False)
        model.load_state_dict(ckpt['model'])
    else:
        print(f"No checkpoint at {best_path}; predicting with the current weights.")

    # Inference with optional TTA
    print('start prediction')
    sub = predict_submission(model, test_loader, device, tta=cfg.tta)
    sub_path = os.path.join(cfg.out_dir, "submission.csv")
    sub.to_csv(sub_path, index=False)
    print(f"Wrote {sub_path}")


@torch.no_grad()
def predict_submission(model: nn.Module, test_loader: DataLoader, device, tta: int = 4) -> pd.DataFrame:
    """Predict test probabilities, averaging logits over flip-based TTA views.

    Args:
        model: Trained network; switched to eval mode here.
        test_loader: Yields ``(image_batch, ids)`` pairs.
        device: Device the image batches are moved to.
        tta: Number of views to average: 1 = identity, 2 adds a horizontal
            flip, 3 adds a vertical flip, 4 adds both flips combined. Values
            above 4 behave like 4.

    Returns:
        DataFrame with columns ``id`` and ``label`` (sigmoid probability).
    """
    model.eval()

    # (smallest ``tta`` that enables the view, dims to flip), in application order.
    extra_views = ((2, [-1]), (3, [-2]), (4, [-2, -1]))

    ids_all, probs_all = [], []
    for x, ids in tqdm(test_loader, desc='test'):
        x = x.to(device, non_blocking=True)
        views = [x] + [torch.flip(x, dims=dims) for threshold, dims in extra_views if tta >= threshold]

        # Average in logit space, then squash once.
        logits_sum = 0.0
        for view in views:
            logits_sum += model(view)
        mean_logits = logits_sum / len(views)

        probs_all.append(torch.sigmoid(mean_logits).detach().cpu().numpy())
        ids_all.extend(list(ids))

    # Flatten so [B, 1] and [B] model outputs both give one value per id.
    flat_probs = np.concatenate(probs_all).reshape(-1)
    return pd.DataFrame({"id": ids_all, "label": flat_probs})

In [None]:
# -----------------------------
# Main
# -----------------------------

def parse_args():
    """Parse command-line overrides; every default comes from ``Config``.

    Returns:
        The ``argparse.Namespace`` produced from ``sys.argv``.
    """
    # (flag, value type, Config attribute supplying the default), in CLI order.
    options = (
        ('--data-dir', str, 'data_dir'),
        ('--out-dir', str, 'out_dir'),
        ('--model', str, 'model'),
        ('--img-size', int, 'img_size'),
        ('--batch-size', int, 'batch_size'),
        ('--epochs', int, 'epochs'),
        ('--lr', float, 'lr'),
        ('--weight-decay', float, 'weight_decay'),
        ('--val-size', float, 'val_size'),
        ('--seed', int, 'seed'),
        ('--center-emphasis', int, 'center_emphasis'),
        ('--weighted-sampler', int, 'use_weighted_sampler'),
        ('--tta', int, 'tta'),
        ('--num-workers', int, 'num_workers'),
    )

    parser = argparse.ArgumentParser()
    for flag, value_type, config_attr in options:
        parser.add_argument(flag, type=value_type, default=getattr(Config, config_attr))
    return parser.parse_args()


def main():
    """Entry point: turn parsed CLI arguments into a ``Config`` and run ``fit``."""
    args = parse_args()

    # Config field -> parsed value. Only ``use_weighted_sampler`` differs in
    # name from its flag (``--weighted-sampler``).
    overrides = {
        'data_dir': args.data_dir,
        'out_dir': args.out_dir,
        'model': args.model,
        'img_size': args.img_size,
        'batch_size': args.batch_size,
        'epochs': args.epochs,
        'lr': args.lr,
        'weight_decay': args.weight_decay,
        'val_size': args.val_size,
        'seed': args.seed,
        'center_emphasis': args.center_emphasis,
        'use_weighted_sampler': args.weighted_sampler,
        'tta': args.tta,
        'num_workers': args.num_workers,
    }
    fit(Config(**overrides))


@dataclass
class Config:
    """Run configuration; the defaults target the Kaggle notebook environment."""

    # --- Paths: train_dir, test_dir and labels_csv are joined onto data_dir ---
    data_dir: str = "/kaggle/input/histopathologic-cancer-detection"
    train_dir: str = "train"
    test_dir: str = "test"
    labels_csv: str = "train_labels.csv"
    out_dir: str = "/kaggle/working"  # checkpoint and submission.csv go here

    # --- Model and optimization ---
    model: str = "simple_cnn"  # one of: "resnet18", "simple_cnn"
    img_size: int = 128
    batch_size: int = 256
    epochs: int = 20
    lr: float = 3e-4
    weight_decay: float = 1e-4
    warmup_epochs: int = 1  # linear warmup epochs before cosine decay

    # --- Data loading and split ---
    num_workers: int = 4
    val_size: float = 0.1  # validation share of the labeled data
    seed: int = 42

    # --- Switches (0/1 integers) and inference ---
    center_emphasis: int = 1  # apply the Gaussian center mask
    use_weighted_sampler: int = 0  # inverse-frequency sampling for training
    tta: int = 4  # number of flip views averaged at test time

if __name__ == "__main__":
    # Inside a Jupyter kernel, sys.argv carries the kernel's own launch flags
    # (e.g. "-f <connection file>"), which argparse would reject, so drop them
    # there. When run as a plain script, keep the arguments so flags such as
    # "--epochs 8" actually take effect instead of being discarded.
    if "ipykernel" in sys.modules:
        del sys.argv[1:]
    main()