## Lymph Node Metastasis Detection Pipeline

Welcome to the notebook for The Ants' machine learning pipeline. This notebook contains a simple machine learning model trained on images from the [PCam](https://github.com/basveeling/pcam) dataset. The model was used to identify areas of interest to image on patient lymph node slides.

In [None]:
import h5py
import numpy as np
from torch.utils.data import DataLoader, Dataset
from PIL import Image
import torch
import torchvision.transforms as T    
import torch.nn as nn
from torchvision import models
from tqdm import tqdm
import os

## Custom H5 Dataset

In [None]:
class H5PatchDataset(Dataset):
    """Patch dataset backed by two HDF5 files: one for images, one for labels.

    Images are read from ``h5_path_x[img_key]`` and labels from
    ``h5_path_y[label_key]``. Expected layout (assumed, not validated here):
    x: (N, H, W, C) uint8
    y: one 0/1 label per image

    Both files are opened lazily on the first ``__getitem__`` call and stay
    open until ``close()`` is called. Open handles are left out of the pickled
    state, so a copy sent to another process reopens the files itself.

    Args:
        h5_path_x: path of the HDF5 file holding the images.
        h5_path_y: path of the HDF5 file holding the labels.
        transform: optional callable applied to the PIL image; when omitted
            the image is converted with ``T.ToTensor()``.
        img_key: name of the image dataset inside ``h5_path_x``.
        label_key: name of the label dataset inside ``h5_path_y``.
    """

    def __init__(self, h5_path_x, h5_path_y, transform=None, img_key="x", label_key="y"):
        self.h5_path_x = h5_path_x
        self.h5_path_y = h5_path_y
        self.transform = transform
        self.img_key = img_key
        self.label_key = label_key

        # Lazily populated: file handles and the cached number of samples.
        self._h5_img = None
        self._h5_label = None
        self._length = None

    def __len__(self):
        """Return the number of labels, reading the label file at most once."""
        if self._length is None:
            with h5py.File(self.h5_path_y, "r") as fl:
                self._length = int(fl[self.label_key].shape[0])
        return int(self._length)

    def __getstate__(self):
        """Return the instance state without the open file handles."""
        state = self.__dict__.copy()
        state["_h5_img"] = None
        state["_h5_label"] = None
        return state

    def _ensure_open(self):
        """Open both HDF5 files on first use and cache the dataset length."""
        if self._h5_img is None:
            self._h5_img = h5py.File(self.h5_path_x, "r")
            self._h5_label = h5py.File(self.h5_path_y, "r")
            self._length = int(self._h5_label[self.label_key].shape[0])

    def close(self):
        """Close any open file handles; safe to call more than once."""
        for attr in ("_h5_img", "_h5_label"):
            handle = getattr(self, attr)
            if handle is not None:
                handle.close()
                setattr(self, attr, None)

    def __getitem__(self, idx):
        """Return ``(image_tensor, label)`` for sample ``idx``.

        The label is returned as a float32 scalar tensor.

        Raises:
            RuntimeError: if the image or the label cannot be read.
        """
        self._ensure_open()

        try:
            img = self._h5_img[self.img_key][idx]
        except Exception as e:
            raise RuntimeError(
                f"Failed reading image index {idx} with key '{self.img_key}'"
            ) from e
        try:
            # .item() accepts a scalar as well as a single-element array of any
            # rank, so labels stored with trailing singleton axes also work.
            label = int(np.asarray(self._h5_label[self.label_key][idx]).item())
        except Exception as e:
            raise RuntimeError(
                f"Failed reading label index {idx} with key '{self.label_key}'"
            ) from e

        if not isinstance(img, np.ndarray):
            img = np.array(img)

        pil = Image.fromarray(img.astype("uint8"))

        if self.transform is not None:
            img_tensor = self.transform(pil)
        else:
            img_tensor = T.ToTensor()(pil)

        return img_tensor, torch.tensor(label, dtype=torch.float32)


## Custom Model

In [None]:
class CNN(nn.Module):
    """Three-stage convolutional network followed by two linear layers.

    The first linear layer is sized for a 12 x 12 feature map, which is what
    three 2x poolings produce from a 96 x 96 input (the convolutions are
    padded so they keep the spatial size).

    Args:
        in_channels: number of channels of the input image.
        hidden_channels: sequence of three ints, the output channels of
            conv1, conv2 and conv3.
        out_features: size of the final output vector.
    """

    def __init__(self, in_channels, hidden_channels, out_features):
        super().__init__()
        # NOTE: the ``max_pool*`` layers are average pools; the attribute
        # names are kept so existing references to them keep working.
        # 96 x 96
        self.conv1 = nn.Conv2d(in_channels, hidden_channels[0], kernel_size=3, padding=1)
        self.bn1 = nn.BatchNorm2d(hidden_channels[0])
        self.relu1 = nn.ReLU(inplace=True)
        self.max_pool1 = nn.AvgPool2d(2)
        # 48 x 48
        self.conv2 = nn.Conv2d(hidden_channels[0], hidden_channels[1], kernel_size=5, padding=2)
        # Must match conv2's output channels, not conv1's.
        self.bn2 = nn.BatchNorm2d(hidden_channels[1])
        self.relu2 = nn.ReLU()
        self.max_pool2 = nn.AvgPool2d(2)
        # 24 x 24
        self.conv3 = nn.Conv2d(hidden_channels[1], hidden_channels[2], kernel_size=5, padding=2)
        self.relu3 = nn.ReLU()
        self.max_pool3 = nn.AvgPool2d(2)
        # 12 x 12
        self.fc1 = nn.Linear(12*12*hidden_channels[2], 12*12)
        self.fc = nn.Linear(144, out_features)

    def forward(self, x):
        """Return the raw (un-activated) outputs, shape ``(batch, out_features)``."""
        x = self.conv1(x)
        x = self.bn1(x)
        x = self.relu1(x)
        x = self.max_pool1(x)
        x = self.conv2(x)
        x = self.bn2(x)
        x = self.relu2(x)
        x = self.max_pool2(x)
        x = self.conv3(x)
        x = self.relu3(x)
        x = self.max_pool3(x)
        # Flatten everything except the batch dimension for the linear layers.
        x = x.view(x.size(0), -1)
        x = self.fc1(x)
        x = self.fc(x)
        return x



## Helper Functions

In [None]:
def get_transforms(img_size=96, train=True):
    """Build the torchvision preprocessing pipeline.

    Both pipelines resize to ``img_size`` x ``img_size``, convert to a tensor
    and normalise per channel. With ``train=True`` random flips, a random
    rotation and colour jitter are inserted between the resize and the tensor
    conversion.
    """
    resize = T.Resize((img_size, img_size))

    if train:
        augmentations = [
            T.RandomHorizontalFlip(),
            T.RandomVerticalFlip(),
            T.RandomRotation(90),
            T.ColorJitter(brightness=0.2, contrast=0.2, saturation=0.2, hue=0.02),
        ]
    else:
        augmentations = []

    to_normalised_tensor = [
        T.ToTensor(),
        T.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
    ]

    return T.Compose([resize, *augmentations, *to_normalised_tensor])
    
def eval_accuracy(data_loader, cnn, device=torch.device('cpu')):
    """Return the fraction of samples in ``data_loader`` that ``cnn`` classifies correctly.

    A sample is predicted positive when the model's output is greater than 0.
    The model is switched to eval mode and is NOT moved: the caller must make
    sure ``cnn`` already lives on ``device``.

    Args:
        data_loader: iterable yielding ``(inputs, labels)`` batches.
        cnn: model whose output has a size-1 second dimension.
        device: device the batches are moved to before the forward pass.

    Returns:
        Accuracy in ``[0, 1]``; ``0.0`` if the loader yields no samples.
    """
    cnn.eval()

    correct = 0
    total = 0

    # No gradients are needed anywhere in the evaluation pass.
    with torch.no_grad():
        for X, y in tqdm(data_loader):
            X, y = X.to(device), y.to(device).long()
            preds = (cnn(X).squeeze(1) > 0).long()
            correct += (preds == y).sum().item()
            total += y.size(0)

    # Guard against an empty loader instead of dividing by zero.
    if total == 0:
        return 0.0
    return correct / total

## Model Arguments

In [None]:
# args
# Paths are built with os.path.join so they resolve on any OS, not only on
# Windows (the result on Windows is unchanged).
h5_img = os.path.join("data", "camelyonpatch_split_training_x.h5")
h5_label = os.path.join("data", "camelyonpatch_split_training_y.h5")
# NOTE(review): the "test" set below reads the SAME files as the training set,
# so the reported test accuracy is measured on training data. Point these two
# variables at a held-out split to get a real test score.
h5_test_img = h5_img
h5_test_label = h5_label
in_channels = 3
hidden_channels = [64, 64, 64]
out_features = 1
epochs = 1
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")

# Training uses random augmentation; evaluation only resizes and normalises.
transform_train = get_transforms()
transform_test = get_transforms(train=False)

train_data = H5PatchDataset(
        h5_path_x=h5_img,
        h5_path_y=h5_label,
        transform=transform_train,
    )
test_data = H5PatchDataset(
        h5_path_x=h5_test_img,
        h5_path_y=h5_test_label,
        transform=transform_test,
    )

train_loader = DataLoader(train_data, batch_size=64, shuffle=True)
test_loader = DataLoader(test_data, batch_size=1000)

## Model Selection

In [None]:
# Custom model (active by default; comment out when switching to resnet18):
cnn = CNN(
    in_channels=in_channels,
    hidden_channels=hidden_channels,
    out_features=out_features,
)

# Uncomment to use resnet18 (pretrained model)
# cnn = models.resnet18(weights=models.ResNet18_Weights.DEFAULT).to(device)
# cnn.fc = nn.Linear(cnn.fc.in_features, 1)

# BCEWithLogitsLoss applies the sigmoid itself, so the model outputs raw logits.
criterion = nn.BCEWithLogitsLoss()
optimizer = torch.optim.SGD(
    params=cnn.parameters(),
    lr=0.001,
    weight_decay=1e-4,
)

## Training & Evaluation

In [None]:
# Per-step training loss (plain floats) and per-epoch accuracies in percent.
train_loss = []
train_acc = []
test_acc = []

# Move the model once; training and evaluation both run on `device`.
cnn.to(device)

for epoch in range(epochs):

    # eval_accuracy leaves the model in eval mode, so re-enable training mode.
    cnn.train()

    for x_batch, y_batch in tqdm(train_loader, desc="train", leave=False):

        x_batch, y_batch = x_batch.to(device), y_batch.float().to(device)

        optimizer.zero_grad()

        y_pred = cnn(x_batch)

        loss = criterion(y_pred.squeeze(1), y_batch)
        # Store a Python float, not the tensor, so the history does not keep
        # device memory and autograd references alive.
        train_loss.append(loss.item())

        loss.backward()
        optimizer.step()

    # Evaluate on the same device the model was trained on instead of moving
    # the whole model to the CPU and back every epoch.
    train_accuracy = 100*eval_accuracy(train_loader, cnn, device)
    test_accuracy = 100*eval_accuracy(test_loader, cnn, device)

    train_acc.append(train_accuracy)
    test_acc.append(test_accuracy)

    print(f"Epoch: {epoch+1}")
    print('Train accuracy: {:.00f}%'.format(train_accuracy))
    print('Test accuracy: {:.00f}%'.format(test_accuracy))

# Leave the model on the CPU, as before, so later cells see CPU weights.
cnn.to('cpu')

print(train_acc)
print(test_acc)

## Saving

If you would like to save the model so that it can be loaded again later, run the following code:

In [None]:
output_dir = "results"
# torch.save does not create missing directories, so make sure it exists.
os.makedirs(output_dir, exist_ok=True)
save_path = os.path.join(output_dir, "best.pth")
# The checkpoint holds model weights, optimizer state and the index of the
# last epoch that ran (`epoch` comes from the training loop above).
torch.save({
    'model_state': cnn.state_dict(),
    'optimizer_state': optimizer.state_dict(),
    'epoch': epoch,
}, save_path)