## Lymph Node Metastasis Detection Pipeline

Welcome to the notebook for The Ants machine learning pipeline. This notebook contains a simple machine learning model trained using images from the [PCam](https://github.com/basveeling/pcam) dataset. This model was used to extract areas of interest to image from patient lymph node slides.

In [None]:
import h5py
import numpy as np
from torch.utils.data import DataLoader, Dataset
from PIL import Image
import torch
import torchvision.transforms as T
import torch.nn as nn
from tqdm import tqdm
import os
import csv
from glob import glob
from pathlib import Path
from collections import defaultdict
import pandas as pd
from PIL import ImageDraw, ImageFont

## Custom H5 Dataset

In [None]:
class H5PatchDataset(Dataset):
    """Dataset of image patches and binary labels stored in HDF5 files.

    Images are read from ``h5_path_x[img_key]`` and labels from
    ``h5_path_y[label_key]``.

    Expected layout:
        x: (N, H, W, C) uint8
        y: one 0/1 label per image; any single-element entry is accepted,
           e.g. a dataset of shape (N,) or (N, 1, 1, 1).

    The HDF5 files are opened lazily on the first ``__getitem__`` call and
    kept open until ``close()`` is called or the dataset is garbage
    collected.

    Args:
        h5_path_x: Path of the HDF5 file holding the images.
        h5_path_y: Path of the HDF5 file holding the labels.
        transform: Optional callable applied to the PIL image. When omitted
            the image is converted with ``T.ToTensor()``.
        img_key: Name of the image dataset inside ``h5_path_x``.
        label_key: Name of the label dataset inside ``h5_path_y``.
    """

    def __init__(
        self, h5_path_x, h5_path_y, transform=None, img_key="x", label_key="y"
    ):
        self.h5_path_x = h5_path_x
        self.h5_path_y = h5_path_y
        self.transform = transform
        self.img_key = img_key
        self.label_key = label_key

        # File handles and length are populated lazily.
        self._h5_img = None
        self._h5_label = None
        self._length = None

    def __len__(self):
        """Return the number of labels, reading it from disk at most once."""
        if self._length is None:
            with h5py.File(self.h5_path_y, "r") as fl:
                self._length = int(fl[self.label_key].shape[0])
        return self._length

    def _open(self):
        """Open both HDF5 files read-only and cache the dataset length."""
        self._h5_img = h5py.File(self.h5_path_x, "r")
        try:
            self._h5_label = h5py.File(self.h5_path_y, "r")
        except Exception:
            # Do not keep a half-initialised state around.
            self._h5_img.close()
            self._h5_img = None
            raise
        self._length = int(self._h5_label[self.label_key].shape[0])

    def close(self):
        """Close any open HDF5 handles; they are reopened on next access."""
        for attr in ("_h5_img", "_h5_label"):
            handle = getattr(self, attr, None)
            if handle is not None:
                handle.close()
            setattr(self, attr, None)

    def __del__(self):
        try:
            self.close()
        except Exception:
            # Best effort only: the interpreter may be shutting down.
            pass

    def __getitem__(self, idx):
        """Return ``(image_tensor, label)`` for sample ``idx``.

        The label is a float32 scalar tensor.

        Raises:
            RuntimeError: If the image or the label cannot be read.
        """
        if self._h5_img is None or self._h5_label is None:
            self._open()

        try:
            img = self._h5_img[self.img_key][idx]
        except Exception as e:
            raise RuntimeError(
                f"Failed reading image index {idx} with key '{self.img_key}'"
            ) from e
        try:
            # .item() extracts the scalar from any single-element entry
            # without relying on implicit array-to-int conversion.
            label = int(np.asarray(self._h5_label[self.label_key][idx]).item())
        except Exception as e:
            raise RuntimeError(
                f"Failed reading label index {idx} with key '{self.label_key}'"
            ) from e

        pil = Image.fromarray(np.asarray(img).astype("uint8"))

        if self.transform is not None:
            img_tensor = self.transform(pil)
        else:
            img_tensor = T.ToTensor()(pil)

        return img_tensor, torch.tensor(label, dtype=torch.float32)

## Convolutional Neural Network

The code below creates a convolutional neural network to work with the PCam 96 x 96 images. It consists of three convolutional layers with ReLU activations (batch normalization follows the first two), each followed by average pooling. Two fully connected layers produce the final output.

In [None]:
class CNN(nn.Module):
    """Small three-stage convolutional classifier for 96 x 96 patches.

    Each stage is conv -> (batch norm for stages 1 and 2) -> ReLU -> 2x
    average pooling, which halves the spatial size: 96 -> 48 -> 24 -> 12.
    The flattened 12 x 12 feature maps go through two linear layers. There
    is no activation between the two linear layers.

    Args:
        in_channels: Number of channels of the input image.
        hidden_channels: Sequence of three ints, the output channels of
            conv1, conv2 and conv3.
        out_features: Size of the final output (number of logits).
    """

    def __init__(self, in_channels, hidden_channels, out_features):
        super().__init__()
        c1, c2, c3 = hidden_channels[0], hidden_channels[1], hidden_channels[2]
        fc_hidden = 12 * 12

        # NOTE: the pooling attributes are called "max_pool*" but hold
        # average pooling layers; the names are kept so existing code that
        # accesses them keeps working.

        # 96 x 96
        self.conv1 = nn.Conv2d(in_channels, c1, kernel_size=3, padding=1)
        self.bn1 = nn.BatchNorm2d(c1)
        self.relu1 = nn.ReLU(inplace=True)
        self.max_pool1 = nn.AvgPool2d(2)
        # 48 x 48
        self.conv2 = nn.Conv2d(c1, c2, kernel_size=5, padding=2)
        # bn2 normalises the output of conv2, so it must be sized with c2.
        # When c1 == c2 the parameter shapes are the same as before.
        self.bn2 = nn.BatchNorm2d(c2)
        self.relu2 = nn.ReLU()
        self.max_pool2 = nn.AvgPool2d(2)
        # 24 x 24
        self.conv3 = nn.Conv2d(c2, c3, kernel_size=5, padding=2)
        self.relu3 = nn.ReLU()
        self.max_pool3 = nn.AvgPool2d(2)
        # 12 x 12
        self.fc1 = nn.Linear(12 * 12 * c3, fc_hidden)
        self.fc = nn.Linear(fc_hidden, out_features)

    def forward(self, x):
        """Return logits of shape (batch, out_features) for a 96 x 96 batch."""
        x = self.max_pool1(self.relu1(self.bn1(self.conv1(x))))
        x = self.max_pool2(self.relu2(self.bn2(self.conv2(x))))
        x = self.max_pool3(self.relu3(self.conv3(x)))
        # Flatten everything except the batch dimension.
        x = x.view(x.size(0), -1)
        x = self.fc1(x)
        x = self.fc(x)
        return x

## Helper Functions

In [None]:
def get_transforms(img_size=96, train=True):
    """Build the torchvision preprocessing pipeline for image patches.

    Both pipelines resize to ``img_size`` x ``img_size``, convert to a
    tensor and normalise per channel. With ``train=True`` random flips, a
    random rotation and colour jitter are inserted between the resize and
    the tensor conversion.

    Args:
        img_size: Side length passed to ``T.Resize``.
        train: Whether to include the random augmentations.

    Returns:
        A ``T.Compose`` instance.
    """
    head = [T.Resize((img_size, img_size))]
    # NOTE(review): these mean/std values look like the usual ImageNet
    # statistics — confirm they are intended for this data.
    tail = [
        T.ToTensor(),
        T.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
    ]

    if not train:
        return T.Compose(head + tail)

    augmentations = [
        T.RandomHorizontalFlip(),
        T.RandomVerticalFlip(),
        T.RandomRotation(90),
        T.ColorJitter(brightness=0.2, contrast=0.2, saturation=0.2, hue=0.02),
    ]
    return T.Compose(head + augmentations + tail)


def eval_accuracy(data_loader, cnn, device=torch.device("cpu")):
    """Return the fraction of correctly classified samples in ``data_loader``.

    The model is switched to eval mode (and left in it). A sample is
    predicted positive when its logit is greater than zero. Batches are
    moved to ``device``; the model itself is not moved, so the caller must
    make sure it already lives there.

    Args:
        data_loader: Iterable yielding ``(inputs, labels)`` batches.
        cnn: Model producing one logit per sample in dimension 1.
        device: Device the batches are moved to.

    Raises:
        ZeroDivisionError: If ``data_loader`` yields no samples.
    """
    cnn.eval()

    n_correct = 0
    n_seen = 0

    for inputs, targets in tqdm(data_loader, desc="evaluation"):
        inputs = inputs.to(device)
        targets = targets.to(device).long()
        with torch.no_grad():
            logits = cnn(inputs).squeeze(1)
            predicted = (logits > 0).long()
            n_correct += (predicted == targets).sum().item()
            n_seen += targets.size(0)

    return n_correct / n_seen

## Load Datasets

Run the code below to load the PCam dataset into dataloaders. The code will automatically select `cuda` if available, otherwise the CPU will be used for training.

In [None]:
# args
h5_img = "../data/camelyonpatch_split_training_x.h5"
h5_label = "../data/camelyonpatch_split_training_y.h5"
test_fraction = 0.1  # share of the patches held out for evaluation
split_seed = 0  # fixed seed so the hold-out split is reproducible

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")

transform_train = get_transforms()
transform_test = get_transforms(train=False)

# Two views of the same files: one with training augmentation, one without.
augmented_view = H5PatchDataset(
    h5_path_x=h5_img,
    h5_path_y=h5_label,
    transform=transform_train,
)
plain_view = H5PatchDataset(
    h5_path_x=h5_img,
    h5_path_y=h5_label,
    transform=transform_test,
)

# Both views read the same files, so evaluating on all of `plain_view` would
# score the model on the very samples it was trained on. Split the indices
# into disjoint train / test subsets instead.
n_total = len(augmented_view)
shuffled_indices = torch.randperm(
    n_total, generator=torch.Generator().manual_seed(split_seed)
).tolist()
n_test = max(1, int(n_total * test_fraction))

train_data = torch.utils.data.Subset(augmented_view, shuffled_indices[n_test:])
test_data = torch.utils.data.Subset(plain_view, shuffled_indices[:n_test])

train_loader = DataLoader(train_data, batch_size=64, shuffle=True)
test_loader = DataLoader(test_data, batch_size=1000)

## Model Selection

In the code below you can select the model you would like to use. The first model is a custom model that can be adjusted in the `CNN` class and has parameters for input channels, hidden channels, and output features that can be adjusted as desired. The second is the pretrained `resnet18` model optimized for image recognition.

In [None]:
# Hyperparameters of the custom CNN (only used by the commented-out lines below).
in_channels = 3
hidden_channels = [64, 64, 64]
out_features = 1

# Uncomment to use our custom model:
# cnn = CNN(in_channels, hidden_channels, out_features)
# optimizer = torch.optim.SGD(cnn.parameters(), lr=0.001, weight_decay=1e-4)
# criterion = nn.BCEWithLogitsLoss()

# Uncomment to use resnet18 (pretrained model)
from torchvision import models

cnn = models.resnet18(weights=models.ResNet18_Weights.DEFAULT)
# Replace the classification head with a single-logit layer *before* moving
# the model, so the new layer ends up on the same device as the backbone.
cnn.fc = nn.Linear(cnn.fc.in_features, 1)
cnn = cnn.to(device)
optimizer = torch.optim.AdamW(cnn.parameters(), lr=1e-4, weight_decay=1e-4)
criterion = nn.BCEWithLogitsLoss()

## Training & Evaluation

The code block below trains the model using PyTorch for a specified number of epochs. Adjust the value of `epochs` as desired. Accuracy for the training set is stored in `train_acc` and accuracy for the test set is stored in `test_acc`.

In [None]:
epochs = 3

train_loss = []  # one loss value per training batch
train_acc = []  # one accuracy (in percent) per epoch
test_acc = []  # one accuracy (in percent) per epoch

cnn.to(device)

for epoch in range(epochs):
    cnn.train()

    for x_batch, y_batch in tqdm(train_loader, desc="train", leave=False):
        x_batch, y_batch = x_batch.to(device), y_batch.float().to(device)

        optimizer.zero_grad()

        y_pred = cnn(x_batch)

        # The model emits one logit per sample in dimension 1.
        loss = criterion(y_pred.squeeze(1), y_batch)
        train_loss.append(loss.item())

        loss.backward()
        optimizer.step()

    # Evaluate on the training device instead of forcing the model onto the
    # CPU every epoch. eval_accuracy leaves the model in eval mode;
    # cnn.train() at the top of the loop restores training mode.
    # NOTE: train_loader applies whatever transform it was built with, so the
    # training accuracy is measured on augmented samples if augmentation is on.
    train_accuracy = 100 * eval_accuracy(train_loader, cnn, device)
    test_accuracy = 100 * eval_accuracy(test_loader, cnn, device)

    train_acc.append(train_accuracy)
    test_acc.append(test_accuracy)

    print(f"Epoch: {epoch + 1}")
    print(f"Train accuracy: {train_accuracy:.2f}%")
    print(f"Test accuracy: {test_accuracy:.2f}%")

# Leave the model on the CPU, as the original loop did, so a state dict
# saved afterwards does not depend on the training device.
cnn.to("cpu")

print(train_acc)
print(test_acc)

## Saving

If you would like to save the model so that it can be loaded again later, run the following code:

In [None]:
# Directory that receives the checkpoint; created if it does not exist yet.
output_dir = Path("../results_notebook")
output_dir.mkdir(parents=True, exist_ok=True)

save_path = os.path.join(output_dir, "best.pth")

# `epoch` is the zero-based index of the last iteration of the training loop.
checkpoint = {
    "model_state": cnn.state_dict(),
    "optimizer_state": optimizer.state_dict(),
    "epoch": epoch,
}
torch.save(checkpoint, save_path)

## TIFF Dataset



In [None]:
class TiffPatchDataset(Dataset):
    """Non-overlapping square patches cut from a list of image files.

    At construction every file is opened once to read its size, and an index
    of ``(image_index, x, y)`` patch origins is built on a regular grid with
    stride ``patch_size``. Partial patches at the right and bottom edges are
    dropped. Patches are ordered by image, then by x, then by y.

    Args:
        pathes: Paths of the image files.
        patch_size: Side length of each patch in pixels; must be positive.
        transforms: Callable applied to each RGB PIL patch.

    Raises:
        ValueError: If ``patch_size`` is not positive.
    """

    def __init__(self, pathes, patch_size, transforms):
        if patch_size <= 0:
            raise ValueError(f"patch_size must be positive, got {patch_size}")

        # Materialise so that generators can be indexed later on.
        self.tiff_pathes = list(pathes)
        self.patch_size = patch_size
        self.transforms = transforms

        self.index = []
        self.image_sizes = []

        for i, f in enumerate(self.tiff_pathes):
            with Image.open(f) as img:
                w, h = img.size
            self.image_sizes.append((w, h))

            x_vals = range(0, w - self.patch_size + 1, self.patch_size)
            y_vals = range(0, h - self.patch_size + 1, self.patch_size)

            self.index.extend((i, x, y) for x in x_vals for y in y_vals)

    def __len__(self):
        return len(self.index)

    def __getitem__(self, index):
        """Return ``(transformed_patch, meta)`` for patch number ``index``.

        ``meta`` holds the source path, the patch origin and size, and the
        index of the source image.

        Raises:
            RuntimeError: If the patch no longer fits inside the image, e.g.
                because the file changed after the index was built.
        """
        image, x, y = self.index[index]
        x2 = x + self.patch_size
        y2 = y + self.patch_size

        tiff_path = self.tiff_pathes[image]
        # The context manager releases the file handle for every item.
        with Image.open(tiff_path) as img:
            w, h = img.size
            if x2 > w or y2 > h:
                raise RuntimeError(
                    f"Patch ({x}, {y}, {x2}, {y2}) exceeds image size "
                    f"({w}, {h}) of '{tiff_path}'"
                )
            # Crop first and convert only the patch; converting the whole
            # image for every patch gives the same pixels at far higher cost.
            patch = img.crop((x, y, x2, y2))
            try:
                patch = patch.convert("RGB")
            except (ValueError, OSError):
                # Best effort: keep the patch in its original mode.
                print("Issue with RGB conversion")

        patch_trans = self.transforms(patch)

        meta = {
            "image_path": tiff_path,
            "x": int(x),
            "y": int(y),
            "patch_w": self.patch_size,
            "patch_h": self.patch_size,
            "img_idx": int(image),
        }
        return patch_trans, meta

In [None]:
def collate_fn(batch):
    """Collate ``(tensor, meta)`` samples into a stacked tensor and a meta list.

    The tensors are stacked along a new leading batch dimension; the meta
    objects are returned unchanged, in order, as a plain list.
    """
    tensors = []
    records = []
    for sample in batch:
        tensors.append(sample[0])
        records.append(sample[1])
    return torch.stack(tensors, dim=0), records

## Using the Model on TIFF Files


In [None]:
# NOTE(review): the saving cell above writes "../results_notebook/best.pth",
# while this cell reads from a "tiff_preds_notebook" sub-directory — make sure
# the checkpoint has been placed at the path selected here.
# save_path = "../results_notebook/tiff_preds_notebook_resnet/best.pth"
save_path = "../results_notebook/tiff_preds_notebook/best.pth"
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")


in_channels = 3
hidden_channels = [64, 64, 64]
out_features = 1

# The architecture built here must match the one the checkpoint was saved from.
# Uncomment to use our custom model:
model = CNN(in_channels, hidden_channels, out_features)

# Uncomment to use resnet18 (pretrained model)
# from torchvision import models
# model = models.resnet18(weights=models.ResNet18_Weights.DEFAULT).to(device)
# model.fc = nn.Linear(model.fc.in_features, 1)

# map_location lets a checkpoint saved on a GPU be loaded on a CPU-only machine.
checkpoint = torch.load(save_path, map_location=device)
model.load_state_dict(checkpoint["model_state"])
# Inputs are moved to `device` during inference, so the model must live there too.
model.to(device)
model.eval()

In [None]:
tiff_dir = "../tiff_files"
# Collect every .tif / .tiff file below tiff_dir, in a stable order.
tiff_paths = sorted(
    glob(os.path.join(tiff_dir, "**", "*.tif"), recursive=True)
    + glob(os.path.join(tiff_dir, "**", "*.tiff"), recursive=True)
)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

transform = get_transforms(img_size=96, train=False)

dataset = TiffPatchDataset(tiff_paths, 96, transform)
print(f"Total patches: {len(dataset)}")
# Sanity check of the tensor shape; skipped when no patch was found.
if len(dataset) > 0:
    print(dataset[0][0].shape)

# dataloader = DataLoader(dataset, batch_size=64, shuffle=False, num_workers=4, collate_fn=collate_fn, pin_memory=True)
dataloader = DataLoader(
    dataset,
    batch_size=64,
    shuffle=False,
    num_workers=0,
    collate_fn=collate_fn,
    pin_memory=False,
)


output_dir = "../results_notebook/tiff_preds_notebook"
output_csv = "patches_preds"
threshold = 0.7  # minimum probability for a patch to be marked positive

os.makedirs(output_dir, exist_ok=True)
csv_path = os.path.join(output_dir, output_csv)

# The batches are moved to `device` below, so the model has to be there too.
model.to(device)
model.eval()

fieldnames = ["image_path", "x", "y", "probability", "prediction"]
results = []
with torch.no_grad():
    for imgs, metas in tqdm(dataloader, desc="inference"):
        imgs = imgs.to(device)
        outputs = model(imgs).squeeze(1)

        probs = torch.sigmoid(outputs).cpu().tolist()

        results.extend(
            {
                "image_path": m["image_path"],
                "x": m["x"],
                "y": m["y"],
                "probability": float(p),
                "prediction": int(p >= threshold),
            }
            for m, p in zip(metas, probs)
        )

with open(csv_path, "w", newline="", encoding="utf-8") as f:
    writer = csv.DictWriter(f, fieldnames=fieldnames)
    writer.writeheader()
    writer.writerows(results)

print("Saved per-patch predictions to:", csv_path)

## Visualize Outputs
The code below takes the CSV file produced by the prediction code above and highlights the patches where the model predicted the presence of metastases.

In [None]:
patch_size = 96  # side length in pixels of each highlighted patch
font_size = 30  # size of the probability label


def draw_box(img_path, boxes, out_path, width=3):
    """Draw one coloured, labelled square per box on an image and save it.

    The outline colour encodes the probability: yellow below 0.8, orange
    below 0.9 and red otherwise. Boxes without a probability are drawn in
    yellow and get no text label.

    Args:
        img_path: Path of the image to annotate.
        boxes: Iterable of dicts with keys ``"x"`` and ``"y"`` (top-left
            corner) and optionally ``"probability"``.
        out_path: Path of the annotated output image; its directory is
            created when missing.
        width: Outline thickness in pixels; the outline grows outwards.
    """
    # Read the pixels, then release the source file right away.
    with Image.open(img_path) as src:
        im = src.convert("RGB")
    draw = ImageDraw.Draw(im)

    try:
        font = ImageFont.truetype("../arial.ttf", font_size)
    except OSError:
        # The font file is optional; fall back to Pillow's built-in font.
        font = ImageFont.load_default()

    for b in boxes:
        x = int(b["x"])
        y = int(b["y"])
        p = b.get("probability", None)

        if p is None or p < 0.8:
            draw_color = (255, 255, 0)
        elif p < 0.9:
            draw_color = (255, 165, 0)
        else:
            draw_color = (255, 0, 0)

        rect = (x, y, x + patch_size, y + patch_size)

        # Nested 1-pixel rectangles, each one pixel further out.
        for i in range(width):
            draw.rectangle(
                (rect[0] - i, rect[1] - i, rect[2] + i, rect[3] + i),
                outline=draw_color,
            )

        if p is not None:
            text = f"{p:.2f}"
            text_pos = (x + (patch_size // 2) - 20, y + (patch_size // 2) - 10)
            draw.text(text_pos, text, fill=draw_color, font=font)

    os.makedirs(os.path.dirname(out_path) or ".", exist_ok=True)
    im.save(out_path)
    print("Saved:", out_path)
    # `display` is provided by the notebook environment.
    display(im)


output_dir = "../results_notebook/tiff_preds_notebook"
output_csv = "patches_preds"
csv_path = os.path.join(output_dir, output_csv)
df = pd.read_csv(csv_path)

# Keep only the patches predicted positive and group them by source image.
positives = df[df["prediction"].astype(int) == 1]

grouped = defaultdict(list)
for path, x, y, prob in zip(
    positives["image_path"],
    positives["x"],
    positives["y"],
    positives["probability"],
):
    grouped[path].append({"x": int(x), "y": int(y), "probability": float(prob)})

print(f"{len(grouped)} files with positive patches found")

for img_path, boxes in grouped.items():
    if not os.path.exists(img_path):
        print(f"WARNING: image not found: {img_path} — skipping")
        continue
    # Build the output name from the stem. Chaining
    # .replace(".tif", ...).replace(".tiff", ...) turned "slide.tiff" into
    # "slide_annot.pngf", because ".tif" also matches inside ".tiff".
    out_path = os.path.join(output_dir, Path(img_path).stem + "_annot.png")
    draw_box(img_path, boxes, out_path, width=3)