In [1]:
import json
import os
from datetime import datetime
from pathlib import Path

import pandas as pd
import torch
from torch.utils.data import DataLoader

from candidate_extractors.TemplateMatchExtractor import TemplateMatchExtractor
from candidate_extractors.TemplateMatchExtractor import generate_dataset_from_slides
from datasets.LabeledImageDataset import LabeledImageDataset
from models.resnet import Resnet50Model, Resnet18Model
from train import train_classifier
from utils import reduce_dataset, split_dataset

In [2]:
slides_root_dir = "data/whole-slides/gut"
annotations_root_dir = "data/annotations/json"
candidates_dataset_dir = "output/candidates"
model_output_dir = "output/models"

In [3]:
extractor = TemplateMatchExtractor()
generate_dataset_from_slides(slides_root_dir, extractor, candidates_dataset_dir)

Found cached candidates dataset output/candidates


In [4]:
def load_annotations(directory):
    slide_data = {"slide_name": [], "n_positive_annotations": []}
    for annotations_file in os.listdir(directory):
        with open(f"{directory}/{annotations_file}") as f:
            annotations = json.load(f)
            slide_name = Path(annotations_file).stem
            slide_data["slide_name"].append(slide_name)
            slide_data["n_positive_annotations"].append(len(annotations))
    return pd.DataFrame(slide_data)


def assign_categories(dataframe):
    q1, median, q3 = dataframe['n_positive_annotations'].quantile([0.25, 0.5, 0.75])

    def categorize_quartiles(n_annotations):
        if n_annotations <= q1:
            return "Low"
        elif q1 < n_annotations <= median:
            return "Medium"
        elif median < n_annotations <= q3:
            return "High"
        else:
            return "Very High"

    dataframe['category'] = dataframe['n_positive_annotations'].apply(lambda x: categorize_quartiles(x))
    return dataframe


def split_data(dataframe, train_portion=0.7):
    train_set = pd.DataFrame()
    test_set = pd.DataFrame()
    for category in dataframe['category'].unique():
        category_slides = dataframe[dataframe['category'] == category]
        train_samples = category_slides.sample(frac=train_portion)
        test_samples = category_slides.drop(train_samples.index)
        train_set = pd.concat([train_set, train_samples])
        test_set = pd.concat([test_set, test_samples])
    return train_set, test_set


slides_df = load_annotations(annotations_root_dir)
slides_df = assign_categories(slides_df)
train_slides, test_slides = split_data(slides_df)
print("Train Slides")
train_slides

Train Slides


Unnamed: 0,slide_name,n_positive_annotations,category
2,593433,3,Low
18,593451,1,Low
10,593441,4,Low
0,522021,1,Low
12,593445,100,Very High
1,522934,129,Very High
4,593435,91,Very High
19,593452,167,Very High
20,593453,18,Medium
3,593434,27,Medium


In [5]:
print("Test Slides")
test_slides

Test Slides


Unnamed: 0,slide_name,n_positive_annotations,category
13,593446,15,Low
15,593448,13,Low
5,593436,151,Very High
16,593449,226,Very High
14,593447,25,Medium
6,593437,87,High


In [6]:
from datasets.SlideSeperatedDataset import SlideSeperatedDataset

device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
print(f"Device: {device}")

batch_size = 256
dataset = SlideSeperatedDataset(candidates_dataset_dir, set(train_slides["slide_name"]))
dataset = reduce_dataset(dataset, discard_ratio=0.0)
train_dataset, validation_dataset = split_dataset(dataset, train_ratio=0.9)
# train_dataset = undersample_dataset(train_dataset)

train_loader = DataLoader(train_dataset,
                          batch_size=batch_size,
                          shuffle=True)
validation_loader = DataLoader(validation_dataset,
                               batch_size=batch_size,
                               shuffle=True, )

model = Resnet18Model(hidden_layers=1, units_per_layer=2048,
                      dropout=0.3, focal_alpha=0.9, focal_gamma=2.0)

print(f"Dataset: {len(train_dataset):,} training, {len(validation_dataset):,} validation")


Device: cuda:0
Dataset: 176,655 training, 19,629 validation


In [7]:
print(model)

Resnet18Model(
  (pretrained_model): ResNet(
    (conv1): Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
    (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (relu): ReLU(inplace=True)
    (maxpool): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
    (layer1): Sequential(
      (0): BasicBlock(
        (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
        (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (relu): ReLU(inplace=True)
        (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
        (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      )
      (1): BasicBlock(
        (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
        (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=

In [8]:
print(f"Training starts {datetime.now().isoformat()}")

Training starts 2025-02-01T23:56:30.762910


In [None]:

model = model.to(device)
model, model_metrics = train_classifier(model, train_loader, validation_loader, device,
                                        start_learning_rate=0.000075,
                                        max_epochs=20,
                                        checkpoint_every=None,
                                        eval_every=1)


Epoch 1 training:  84%|████████▎ | 578/691 [03:48<00:45,  2.46it/s]

In [10]:
os.makedirs(model_output_dir, exist_ok=True)
torch.save({
    "model": model,
    "train_slides": set(train_slides["slide_name"]),
    "test_slides": set(test_slides["slide_name"])
},
    f"{model_output_dir}/model.pickle"
)