In [None]:
# Expected dataset layout (simple + common):
#
# data/
#   train/
#     images/   (e.g., .png, .jpg)
#     masks/    (same filenames as images)
#   val/
#     images/
#     masks/


In [2]:
# Install requirements
pip install torch torchvision pillow numpy tqdm opencv-python scipy
pip install openslide-python


In [None]:
# Training + inference code (single file).
# The dataset loads (image, mask) pairs; masks are expected to be binary (0 or 255). Simple paired augmentations + resize are applied.
# Inference + "cell detections": turns a binary mask into approximate "cell detections" via connected components.
# This is a crude approach: for true cell detection use instance segmentation (e.g., HoVer-Net).

import os
import glob
import math
import random
from dataclasses import dataclass
from typing import Tuple, Dict, List

import numpy as np
from PIL import Image

import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader

from tqdm import tqdm

import cv2
from scipy import ndimage as ndi


# -----------------------------
# Config
# -----------------------------
@dataclass
class Config:
    """All tunable settings for training and inference, gathered in one place."""

    # ---- data ----
    data_dir: str = "data"          # root folder holding train/ and val/, each with images/ + masks/
    img_size: int = 512             # images and masks are resized to img_size x img_size
    batch_size: int = 4
    num_workers: int = 2            # DataLoader worker processes
    # ---- optimisation ----
    lr: float = 1e-4                # AdamW learning rate
    epochs: int = 25
    threshold: float = 0.5          # probability cut-off used to binarize predicted masks
    seed: int = 42                  # passed to set_seed() at the start of training
    # Evaluated once, when the class body runs.
    device: str = ("cuda" if torch.cuda.is_available() else "cpu")
    # ---- output ----
    out_dir: str = "runs/unet_cancer"
    save_name: str = "best_unet.pt"  # checkpoint filename inside out_dir


def set_seed(seed: int):
    """Seed every RNG the training code draws from.

    Covers Python's ``random``, NumPy's legacy global RNG, PyTorch's CPU RNG and
    the RNGs of all CUDA devices.
    """
    seeders = (random.seed, np.random.seed, torch.manual_seed, torch.cuda.manual_seed_all)
    for seeder in seeders:
        seeder(seed)


# -----------------------------
# Dataset
# -----------------------------
class HistopathSegDataset(Dataset):
    """
    Loads (image, mask) pairs. Mask is expected to be binary (0 or 255).
    Applies simple paired augmentations + resize.

    Pairing: each image is matched to the mask with the same filename stem
    (e.g. ``tile_001.jpg`` <-> ``tile_001.png``). Only regular files whose extension
    is in ``IMG_EXTS`` are considered, so stray files and sub-directories are ignored.
    A missing, extra, or ambiguous mask raises ``ValueError``; an empty image
    directory raises ``FileNotFoundError``. Index-based pairing of two sorted
    listings (the previous behaviour) would instead silently pair images with
    the wrong masks.

    Each item is a dict:
        "image": float32 tensor (3, img_size, img_size), values in [0, 1]
        "mask":  float32 tensor (1, img_size, img_size), values in {0, 1}
    """

    # Lower-case file extensions treated as image / mask files.
    IMG_EXTS = (".png", ".jpg", ".jpeg", ".tif", ".tiff", ".bmp", ".gif", ".webp")

    def __init__(self, img_dir: str, mask_dir: str, img_size: int, augment: bool = True):
        img_by_stem = self._index_by_stem(self._list_image_files(img_dir), img_dir)
        if not img_by_stem:
            raise FileNotFoundError(f"No image files found in {img_dir!r}.")
        mask_by_stem = self._index_by_stem(self._list_image_files(mask_dir), mask_dir)

        missing = sorted(set(img_by_stem) - set(mask_by_stem))
        extra = sorted(set(mask_by_stem) - set(img_by_stem))
        if missing or extra:
            raise ValueError(
                "Images and masks count mismatch: "
                f"{len(missing)} image(s) without a mask (e.g. {missing[:3]}), "
                f"{len(extra)} mask(s) without an image (e.g. {extra[:3]})."
            )

        # Images keep their sorted-by-path order; each mask follows its image.
        self.img_paths = sorted(img_by_stem.values())
        self.mask_paths = [mask_by_stem[self._stem(p)] for p in self.img_paths]
        self.img_size = img_size
        self.augment = augment

    @staticmethod
    def _stem(path: str) -> str:
        """Filename without its directory and final extension."""
        return os.path.splitext(os.path.basename(path))[0]

    @classmethod
    def _list_image_files(cls, directory: str) -> List[str]:
        """Sorted regular files in ``directory`` whose extension is in ``IMG_EXTS``."""
        return sorted(
            p for p in glob.glob(os.path.join(directory, "*"))
            if os.path.isfile(p) and os.path.splitext(p)[1].lower() in cls.IMG_EXTS
        )

    @classmethod
    def _index_by_stem(cls, paths: List[str], directory: str) -> Dict[str, str]:
        """Map stem -> path, rejecting two files that share a stem (e.g. a.png and a.jpg)."""
        index: Dict[str, str] = {}
        for p in paths:
            stem = cls._stem(p)
            if stem in index:
                raise ValueError(
                    f"Ambiguous files for {stem!r} in {directory!r}: {index[stem]!r} and {p!r}."
                )
            index[stem] = p
        return index

    def __len__(self):
        return len(self.img_paths)

    def _paired_augment(self, img: np.ndarray, mask: np.ndarray) -> Tuple[np.ndarray, np.ndarray]:
        """Apply the same random geometric transform to image and mask, plus colour jitter to the image only."""
        # Random flips
        if random.random() < 0.5:
            img = np.fliplr(img).copy()
            mask = np.fliplr(mask).copy()
        if random.random() < 0.5:
            img = np.flipud(img).copy()
            mask = np.flipud(mask).copy()

        # Random rotation (0, 90, 180, 270)
        k = random.randint(0, 3)
        if k:
            img = np.rot90(img, k).copy()
            mask = np.rot90(mask, k).copy()

        # Mild color jitter in HSV (helps stain variability a bit)
        if random.random() < 0.3:
            hsv = cv2.cvtColor(img, cv2.COLOR_RGB2HSV).astype(np.float32)
            hsv[..., 1] *= random.uniform(0.9, 1.1)  # saturation
            hsv[..., 2] *= random.uniform(0.9, 1.1)  # value/brightness
            hsv = np.clip(hsv, 0, 255).astype(np.uint8)
            img = cv2.cvtColor(hsv, cv2.COLOR_HSV2RGB)

        return img, mask

    def __getitem__(self, idx: int) -> Dict[str, torch.Tensor]:
        # Context managers close the file handles even for formats Pillow may keep open (e.g. TIFF).
        with Image.open(self.img_paths[idx]) as im:
            img = np.array(im.convert("RGB"))
        with Image.open(self.mask_paths[idx]) as im:
            mask = np.array(im.convert("L"))

        # Resize (nearest-neighbour for the mask so it stays label-valued)
        img = cv2.resize(img, (self.img_size, self.img_size), interpolation=cv2.INTER_AREA)
        mask = cv2.resize(mask, (self.img_size, self.img_size), interpolation=cv2.INTER_NEAREST)

        # Binarize mask (0/1)
        mask = (mask > 127).astype(np.float32)

        if self.augment:
            img, mask = self._paired_augment(img, mask)

        # Normalize to [0,1], CHW
        img = img.astype(np.float32) / 255.0
        img = np.transpose(img, (2, 0, 1))
        mask = np.expand_dims(mask, axis=0)

        return {
            "image": torch.tensor(img, dtype=torch.float32),
            "mask": torch.tensor(mask, dtype=torch.float32),
        }


# -----------------------------
# U-Net model
# -----------------------------
class DoubleConv(nn.Module):
    """Two consecutive (3x3 conv -> BatchNorm -> ReLU) stages.

    ``padding=1`` keeps the spatial size unchanged; only the channel count goes
    from ``in_ch`` to ``out_ch``. The layers are stored in ``self.net`` as an
    ``nn.Sequential`` (state_dict keys ``net.0.*`` ... ``net.4.*``).
    """
    def __init__(self, in_ch, out_ch):
        super().__init__()
        stages = []
        for stage_in in (in_ch, out_ch):
            stages += [
                # bias is redundant because BatchNorm immediately re-centres the output
                nn.Conv2d(stage_in, out_ch, 3, padding=1, bias=False),
                nn.BatchNorm2d(out_ch),
                nn.ReLU(inplace=True),
            ]
        self.net = nn.Sequential(*stages)

    def forward(self, x):
        return self.net(x)


class UNet(nn.Module):
    """
    U-Net for binary segmentation.

    Encoder: one DoubleConv per entry in ``features`` followed by 2x2 max-pooling.
    Bottleneck: DoubleConv to ``2 * features[-1]`` channels.
    Decoder: for each encoder stage (deepest first) a 2x2 stride-2 transposed conv,
    concatenation with the matching encoder output, then a DoubleConv.

    Args:
        in_ch: number of input channels.
        out_ch: number of output channels; the output is raw logits (no activation).
        features: channel width of each encoder stage, shallowest first.
    """
    def __init__(self, in_ch=3, out_ch=1, features=(64, 128, 256, 512)):
        super().__init__()
        self.downs = nn.ModuleList()
        self.ups = nn.ModuleList()
        self.pool = nn.MaxPool2d(2, 2)

        # Down path
        ch = in_ch
        for f in features:
            self.downs.append(DoubleConv(ch, f))
            ch = f

        # Bottleneck
        self.bottleneck = DoubleConv(features[-1], features[-1] * 2)

        # Up path: self.ups alternates [upsample, DoubleConv, upsample, DoubleConv, ...]
        rev = list(features)[::-1]
        up_ch = features[-1] * 2
        for f in rev:
            self.ups.append(nn.ConvTranspose2d(up_ch, f, kernel_size=2, stride=2))
            # After concatenation the tensor has f (skip) + f (upsampled) channels.
            # Using 2*f instead of up_ch keeps this correct for feature widths that
            # do not double per stage; for the default doubling widths they are equal,
            # so existing checkpoints still load.
            self.ups.append(DoubleConv(f * 2, f))
            up_ch = f

        self.final = nn.Conv2d(features[0], out_ch, kernel_size=1)

    def forward(self, x):
        skips = []
        for down in self.downs:
            x = down(x)
            skips.append(x)
            x = self.pool(x)

        x = self.bottleneck(x)
        skips = skips[::-1]

        for i in range(0, len(self.ups), 2):
            x = self.ups[i](x)
            skip = skips[i // 2]

            # Pooling floors odd sizes, so the upsampled map can be one pixel short.
            if x.shape[-2:] != skip.shape[-2:]:
                x = nn.functional.interpolate(x, size=skip.shape[-2:], mode="bilinear", align_corners=False)

            x = torch.cat([skip, x], dim=1)
            x = self.ups[i + 1](x)

        return self.final(x)


# -----------------------------
# Losses + metrics
# -----------------------------
def dice_loss(logits: torch.Tensor, targets: torch.Tensor, eps: float = 1e-6) -> torch.Tensor:
    """
    Soft Dice loss: ``1 - mean(Dice)`` over every (sample, channel) pair.

    Args:
        logits: raw model outputs; assumed (N, C, H, W) since dims 2 and 3 are
            reduced. They are turned into probabilities with a sigmoid.
        targets: same shape as ``logits``, presumably 0/1 masks.
        eps: added to the denominator only, to avoid division by zero.
    """
    p = torch.sigmoid(logits)
    spatial = (2, 3)
    overlap = (p * targets).sum(dim=spatial)
    totals = (p + targets).sum(dim=spatial)
    per_channel = (2 * overlap) / (totals + eps)
    return 1 - per_channel.mean()


@torch.no_grad()
def dice_score(logits: torch.Tensor, targets: torch.Tensor, threshold: float = 0.5, eps: float = 1e-6) -> float:
    """
    Hard Dice coefficient of thresholded predictions, averaged over (sample, channel).

    Args:
        logits: raw model outputs, assumed (N, C, H, W) since dims 2 and 3 are reduced.
        targets: same shape, presumably 0/1 masks.
        threshold: a pixel is predicted foreground when sigmoid(logit) > threshold.
        eps: smoothing term added to numerator and denominator.

    Returns:
        Mean Dice as a Python float. An empty prediction on an empty target scores
        1.0 (perfect agreement) instead of 0.0, so tissue-free patches predicted
        correctly do not drag the metric down.
    """
    preds = (torch.sigmoid(logits) > threshold).float()
    inter = (preds * targets).sum(dim=(2, 3))
    total = (preds + targets).sum(dim=(2, 3))
    # eps in both terms: the empty/empty case becomes eps/eps = 1, elsewhere the
    # value changes only by O(eps).
    dice = (2 * inter + eps) / (total + eps)
    return float(dice.mean().item())


# -----------------------------
# Train / Validate
# -----------------------------
def run_epoch(model, loader, optimizer, device, train=True) -> Tuple[float, float]:
    """
    Run one pass over ``loader``; optimize when ``train`` is True.

    Loss is ``0.5 * BCEWithLogits + 0.5 * soft Dice``. Each batch must be a dict
    with "image" and "mask" tensors.

    Returns:
        (mean loss, mean hard Dice), both averaged per sample. A smaller final
        batch therefore no longer counts as much as a full one. Returns (0.0, 0.0)
        for an empty loader.
    """
    # train(False) is exactly what model.eval() does
    model.train(mode=train)

    bce = nn.BCEWithLogitsLoss()
    loss_sum = 0.0
    dice_sum = 0.0
    n_samples = 0

    for batch in tqdm(loader, leave=False):
        imgs = batch["image"].to(device)
        masks = batch["mask"].to(device)
        batch_size = imgs.size(0)

        with torch.set_grad_enabled(train):
            logits = model(imgs)
            loss = 0.5 * bce(logits, masks) + 0.5 * dice_loss(logits, masks)

            if train:
                optimizer.zero_grad(set_to_none=True)
                loss.backward()
                optimizer.step()

        # loss and dice_score are means over the batch; re-weight them by batch
        # size so the epoch average is a true per-sample mean.
        loss_sum += loss.item() * batch_size
        dice_sum += dice_score(logits, masks) * batch_size
        n_samples += batch_size

    return loss_sum / max(n_samples, 1), dice_sum / max(n_samples, 1)


def train(cfg: Config):
    """
    Train a UNet on ``<data_dir>/train`` and keep the checkpoint with the best validation Dice.

    The checkpoint is written to ``<out_dir>/<save_name>`` as
    ``{"model": state_dict, "cfg": dict of the Config fields}`` and is overwritten
    whenever validation Dice improves.
    """
    set_seed(cfg.seed)
    os.makedirs(cfg.out_dir, exist_ok=True)

    train_ds = HistopathSegDataset(
        img_dir=os.path.join(cfg.data_dir, "train/images"),
        mask_dir=os.path.join(cfg.data_dir, "train/masks"),
        img_size=cfg.img_size,
        augment=True,
    )
    val_ds = HistopathSegDataset(
        img_dir=os.path.join(cfg.data_dir, "val/images"),
        mask_dir=os.path.join(cfg.data_dir, "val/masks"),
        img_size=cfg.img_size,
        augment=False,
    )

    # Pinned host memory only speeds up host->GPU copies. On a CPU-only setup it
    # brings no benefit and PyTorch warns about it, so only enable it for CUDA.
    pin = str(cfg.device).startswith("cuda")
    # NOTE(review): with num_workers > 0, platforms that start workers via "spawn"
    # (Windows, macOS) must be able to re-import the Dataset class. Classes defined
    # in a notebook cell can fail there; set num_workers=0 or move the class to a
    # .py module if the workers crash on startup.
    loader_kwargs = dict(batch_size=cfg.batch_size, num_workers=cfg.num_workers, pin_memory=pin)
    train_loader = DataLoader(train_ds, shuffle=True, **loader_kwargs)
    val_loader = DataLoader(val_ds, shuffle=False, **loader_kwargs)

    model = UNet().to(cfg.device)
    optimizer = optim.AdamW(model.parameters(), lr=cfg.lr)

    best_val_dice = -1.0
    best_path = os.path.join(cfg.out_dir, cfg.save_name)

    for epoch in range(cfg.epochs):
        tr_loss, tr_dice = run_epoch(model, train_loader, optimizer, cfg.device, train=True)
        va_loss, va_dice = run_epoch(model, val_loader, optimizer, cfg.device, train=False)

        print(f"Epoch {epoch+1:02d}/{cfg.epochs} | "
              f"train loss={tr_loss:.4f} dice={tr_dice:.4f} | "
              f"val loss={va_loss:.4f} dice={va_dice:.4f}")

        if va_dice > best_val_dice:
            best_val_dice = va_dice
            torch.save({"model": model.state_dict(), "cfg": cfg.__dict__}, best_path)
            print(f"  ✓ saved best model to: {best_path} (val dice={best_val_dice:.4f})")

    print("Training complete.")


# -----------------------------
# Inference + "cell detections"
# -----------------------------
@torch.no_grad()
def predict_mask(model_path: str, image_path: str, img_size: int = 512, threshold: float = 0.5) -> np.ndarray:
    """
    Run a saved UNet checkpoint on one image (on CPU) and return a binary mask.

    Preprocessing mirrors the validation dataset: RGB, INTER_AREA resize to
    ``img_size x img_size``, scale to [0, 1], CHW. The returned mask is at that
    resized resolution, not the original image size.

    Returns:
        uint8 array of shape (img_size, img_size) with 1 where sigmoid(logit) >= threshold.
    """
    # NOTE: torch.load unpickles the file — only load checkpoints you trust.
    checkpoint = torch.load(model_path, map_location="cpu")
    net = UNet()
    net.load_state_dict(checkpoint["model"])
    net.eval()

    rgb = np.array(Image.open(image_path).convert("RGB"))
    rgb_small = cv2.resize(rgb, (img_size, img_size), interpolation=cv2.INTER_AREA)
    chw = np.transpose(rgb_small.astype(np.float32) / 255.0, (2, 0, 1))
    batch = torch.tensor(chw, dtype=torch.float32)[None]

    probs = torch.sigmoid(net(batch))[0, 0].numpy()
    return (probs >= threshold).astype(np.uint8)


def mask_to_cell_detections(binary_mask: np.ndarray, min_area: int = 40) -> List[Dict]:
    """
    Turns a binary mask into approximate "cell detections" by connected components.
    This is a crude approach: for true cell detection use instance segmentation (e.g., HoVer-Net).

    Args:
        binary_mask: 2-D array; any non-zero pixel is foreground (scipy.ndimage.label
            with its default cross-shaped, i.e. 4-connected, structure).
        min_area: components with fewer pixels than this are dropped.

    Returns:
        One dict per kept component, in label order, with
        "bbox" [x1, y1, x2, y2] (inclusive pixel coordinates),
        "centroid" [cx, cy] (mean pixel position) and "area" (pixel count).
    """
    labeled, _ = ndi.label(binary_mask)
    detections = []
    # find_objects yields each label's bounding-box slices in one pass. Each
    # component is then inspected only inside its own box, instead of rescanning
    # the whole image once per label (O(n_labels * H * W)).
    for label_id, box in enumerate(ndi.find_objects(labeled), start=1):
        if box is None:
            continue
        rows, cols = box
        ys, xs = np.nonzero(labeled[box] == label_id)
        area = len(xs)
        if area < min_area:
            continue
        # Shift box-local coordinates back to full-image coordinates.
        ys = ys + rows.start
        xs = xs + cols.start
        x1, x2 = int(xs.min()), int(xs.max())
        y1, y2 = int(ys.min()), int(ys.max())
        cx, cy = float(xs.mean()), float(ys.mean())
        detections.append({
            "bbox": [x1, y1, x2, y2],
            "centroid": [cx, cy],
            "area": int(area),
        })
    return detections


def overlay_mask(image_path: str, binary_mask: np.ndarray, img_size: int = 512, alpha: float = 0.2) -> np.ndarray:
    """
    Blend a green mask over the image resized to ``img_size x img_size``.

    Args:
        image_path: image to load (converted to RGB).
        binary_mask: 2-D mask of shape (img_size, img_size), e.g. the output of
            ``predict_mask``; any non-zero pixel counts as foreground, so both 0/1
            and 0/255 masks work.
        img_size: side length the image is resized to; must match the mask.
        alpha: weight of the green layer; the image gets ``1 - alpha``. The default
            0.2 reproduces the previous fixed 0.8 / 0.2 blend.

    Returns:
        uint8 RGB array of shape (img_size, img_size, 3).

    Raises:
        ValueError: if the mask shape does not match the resized image.
    """
    img = np.array(Image.open(image_path).convert("RGB"))
    img = cv2.resize(img, (img_size, img_size), interpolation=cv2.INTER_AREA)

    mask = np.asarray(binary_mask)
    if mask.shape != img.shape[:2]:
        raise ValueError(f"Mask shape {mask.shape} does not match resized image shape {img.shape[:2]}.")

    mask_rgb = np.zeros_like(img)
    # Threshold instead of multiplying by 255: a uint8 mask already in 0/255 would overflow.
    mask_rgb[..., 1] = np.where(mask > 0, 255, 0).astype(np.uint8)  # green overlay channel
    return cv2.addWeighted(img, 1.0 - alpha, mask_rgb, alpha, 0)


if __name__ == "__main__":
    cfg = Config()

    # Toggle the stages to run. Both default to off, so running this cell only
    # builds the config (same as the previous commented-out version).
    RUN_TRAINING = False
    RUN_INFERENCE_EXAMPLE = False

    # 1) Train:
    if RUN_TRAINING:
        train(cfg)

    # 2) Inference example:
    if RUN_INFERENCE_EXAMPLE:
        model_path = os.path.join(cfg.out_dir, cfg.save_name)
        example_image = os.path.join(cfg.data_dir, "val", "images", "example.png")
        pred = predict_mask(model_path, example_image, img_size=cfg.img_size, threshold=cfg.threshold)
        dets = mask_to_cell_detections(pred, min_area=40)
        print("Detections:", dets[:5], "… total:", len(dets))

