In [1]:
import json
import os
from datetime import datetime
from pathlib import Path

import pandas as pd

from candidate_extractors.TemplateMatchExtractor import TemplateMatchExtractor
from train import train_classifier

In [2]:
def load_annotations(directory):
    """Summarise the per-slide annotation files found in ``directory``.

    Every ``*.json`` file directly inside ``directory`` is parsed, and the
    length of the parsed document is recorded as that slide's number of
    positive annotations. Entries that are not JSON files (sub-directories
    such as ``.ipynb_checkpoints``, stray ``.DS_Store`` files, ...) are
    skipped instead of aborting the load, and files are visited in sorted
    order so the resulting frame does not depend on the arbitrary order in
    which the filesystem lists them.

    Args:
        directory: Folder (``str`` or ``Path``) holding one JSON annotation
            file per slide.

    Returns:
        ``pd.DataFrame`` with one row per annotation file and the columns
        ``slide_name`` (file stem + ``.svs``) and ``n_positive_annotations``
        (``len`` of the parsed JSON document).
    """
    slide_data = {"slide_name": [], "n_positive_annotations": []}
    annotation_files = sorted(
        path for path in Path(directory).iterdir()
        if path.is_file() and path.suffix.lower() == ".json"
    )
    for annotations_path in annotation_files:
        with open(annotations_path) as f:
            annotations = json.load(f)
        # The slide file shares the annotation file's stem.
        slide_data["slide_name"].append(f"{annotations_path.stem}.svs")
        slide_data["n_positive_annotations"].append(len(annotations))
    return pd.DataFrame(slide_data)


def assign_categories(dataframe):
    """Label each slide with the quartile bucket of its annotation count.

    The 25th, 50th and 75th percentiles of ``n_positive_annotations`` serve as
    inclusive upper bounds for "Low", "Medium" and "High" respectively; any
    count above the 75th percentile is labelled "Very High".

    Note: the ``category`` column is written onto the frame that was passed
    in, and that same frame object is returned.
    """
    counts = dataframe['n_positive_annotations']
    lower_bound, middle_bound, upper_bound = counts.quantile([0.25, 0.5, 0.75])

    def categorize_quartiles(n_annotations):
        # Checked from the lowest bucket upwards; the first bound that holds wins.
        if n_annotations <= lower_bound:
            return "Low"
        if n_annotations <= middle_bound:
            return "Medium"
        if n_annotations <= upper_bound:
            return "High"
        return "Very High"

    dataframe['category'] = counts.apply(categorize_quartiles)
    return dataframe


def split_data(dataframe, train_portion=0.7, random_state=None):
    """Split slides into train and test sets, stratified by ``category``.

    Within every distinct value of the ``category`` column, a fraction
    ``train_portion`` of the rows is sampled for training and the remaining
    rows of that category go to the test set. Rows are removed from the test
    side by index label, so the frame's index is assumed to be unique.

    Args:
        dataframe: Frame with a ``category`` column.
        train_portion: Fraction of each category sampled into the train set.
        random_state: Forwarded to ``DataFrame.sample`` for every category.
            Pass an int to make the split reproducible across kernel
            restarts; the default ``None`` keeps the previous unseeded
            behaviour.

    Returns:
        ``(train_set, test_set)`` as two DataFrames that keep the original
        index labels. For a frame without rows, two empty frames with the
        input's columns are returned.
    """
    train_parts = []
    test_parts = []
    for category in dataframe['category'].unique():
        category_slides = dataframe[dataframe['category'] == category]
        train_samples = category_slides.sample(frac=train_portion,
                                               random_state=random_state)
        train_parts.append(train_samples)
        test_parts.append(category_slides.drop(train_samples.index))

    if not train_parts:
        # pd.concat refuses an empty list; hand back empty frames instead.
        empty = dataframe.iloc[0:0]
        return empty.copy(), empty.copy()

    # Concatenate once: avoids quadratic re-copying and the deprecated
    # dtype handling of concatenating onto an empty DataFrame().
    return pd.concat(train_parts), pd.concat(test_parts)


# Input/output locations, relative to the notebook's working directory.
slides_root_dir = "data/whole-slides/gut"
annotations_root_dir = "data/annotations/json"
training_candidates_dir = "output/candidates/training"

# Count annotations per slide, bucket the slides by quartile, then split
# every bucket into a train and a test part.
slides_df = assign_categories(load_annotations(annotations_root_dir))
train_slides, test_slides = split_data(slides_df)
train_slides

Unnamed: 0,slide_name,n_positive_annotations,category
15,593448.svs,13,Low
2,593433.svs,3,Low
10,593441.svs,4,Low
13,593446.svs,15,Low
5,593436.svs,151,Very High
19,593452.svs,167,Very High
12,593445.svs,100,Very High
1,522934.svs,129,Very High
8,593439.svs,27,Medium
14,593447.svs,25,Medium


In [3]:
# Display the held-out slides returned by split_data.
test_slides

Unnamed: 0,slide_name,n_positive_annotations,category
0,522021.svs,1,Low
18,593451.svs,1,Low
4,593435.svs,91,Very High
16,593449.svs,226,Very High
3,593434.svs,27,Medium
6,593437.svs,87,High


In [4]:
# Run candidate extraction on every training slide; the extractor is
# configured to write into training_candidates_dir.
candidate_extractor = TemplateMatchExtractor(output_dir=training_candidates_dir, verbose=True)
for slide_name in train_slides["slide_name"]:
    slide_filepath = os.path.join(slides_root_dir, slide_name)
    candidate_extractor.extract_candidates(slide_filepath)

Extracting candidates from 593448


Progress: 90it [00:06, 13.99it/s]                        


Extracting candidates from 593433


Progress: 75it [00:03, 19.07it/s]                        


Extracting candidates from 593441


Progress: 301it [00:15, 19.11it/s]                         


Extracting candidates from 593446


Progress: 232it [00:13, 17.39it/s]                         


Extracting candidates from 593436


Progress: 48it [00:10,  4.45it/s]                        


Extracting candidates from 593452


Progress: 609it [02:54,  3.50it/s]                         


Extracting candidates from 593445


Progress: 450it [00:21, 20.89it/s]                         


Extracting candidates from 522934


Progress: 690it [01:20,  8.61it/s]                         


Extracting candidates from 593439


Progress: 310it [00:37,  8.30it/s]                         


Extracting candidates from 593447


Progress: 232it [00:11, 20.23it/s]                         


Extracting candidates from 593454


Progress: 152it [00:14, 10.72it/s]                         


Extracting candidates from 593453


Progress: 1224it [00:45, 26.72it/s]                          


Extracting candidates from 593440


Progress: 576it [01:41,  5.69it/s]                         


Extracting candidates from 593444


Progress: 102it [00:09, 10.55it/s]                       


Extracting candidates from 593438


Progress: 665it [01:23,  7.96it/s]                         


Extracting candidates from 593450


Progress: 95it [00:07, 11.91it/s]                        


In [5]:


import torch
from torch.utils.data import DataLoader

from datasets.LabeledImageDataset import LabeledImageDataset
from models.resnet import Resnet50Model
from utils import reduce_dataset, split_dataset

# Use the first CUDA GPU when one is available, otherwise fall back to CPU.
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
print(f"Device: {device}")

batch_size = 256
dataset = LabeledImageDataset(training_candidates_dir)
# presumably discard_ratio=0.0 keeps every sample — confirm in utils.reduce_dataset
dataset = reduce_dataset(dataset, discard_ratio=0.0)
train_dataset, validation_dataset = split_dataset(dataset, train_ratio=0.7)

# Only the training loader is shuffled.
train_loader = DataLoader(train_dataset,
                          batch_size=batch_size,
                          shuffle=True)
# NOTE: despite its name, test_loader iterates the validation split.
# It is deliberately not shuffled: evaluation gains nothing from a random
# order, and a fixed order keeps the batch composition (and therefore any
# per-batch aggregated metric) identical from one evaluation to the next.
test_loader = DataLoader(validation_dataset,
                         batch_size=batch_size,
                         shuffle=False)

model = Resnet50Model(hidden_layers=1, units_per_layer=2048,
                      dropout=0.3, focal_alpha=0.9, focal_gamma=2.0)

print(f"Dataset: {len(train_dataset):,} training, {len(validation_dataset):,} validation")


Device: cuda:0
Dataset: 127,494 training, 54,641 validation


In [6]:
# Print the model's module tree for inspection.
print(model)

Resnet50Model(
  (pretrained_model): ResNet(
    (conv1): Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
    (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (relu): ReLU(inplace=True)
    (maxpool): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
    (layer1): Sequential(
      (0): Bottleneck(
        (conv1): Conv2d(64, 64, kernel_size=(1, 1), stride=(1, 1), bias=False)
        (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
        (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (conv3): Conv2d(64, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)
        (bn3): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (relu): ReLU(inplace=True)
        (downsample): Sequent

In [7]:
# Record the local wall-clock time (ISO 8601) at which training is launched.
print(f"Training starts {datetime.now().isoformat()}")

Training starts 2025-01-31T21:49:12.663434


In [None]:

# Move the model onto the selected device before handing it to the trainer.
model = model.to(device)
# test_loader wraps validation_dataset, so the metrics logged as "Test" are
# computed on the validation split.
# presumably checkpoint_every / eval_every are counted in epochs — confirm in train.train_classifier
model, model_metrics = train_classifier(model, train_loader, test_loader, device,
                                        start_learning_rate=0.000075,
                                        max_epochs=5,
                                        checkpoint_every=1,
                                        eval_every=1)


Epoch 1 training: 100%|██████████| 499/499 [04:59<00:00,  1.67it/s]


Train: 1/5: lr: 0.000075000000 loss:0.0038263302034290624


Epoch 1 testing: 100%|██████████| 214/214 [02:02<00:00,  1.75it/s]


Test: 1/5: loss:0.0031014878126557604, accuracy:0.9917944842754941, precision:0.25404984423676014, recall:0.22967289719626172, f1:0.21437472185135736, mcc:0.22439971237068035, ece:0.008205515708965815, epoch:0


Epoch 2 training:  47%|████▋     | 233/499 [02:00<02:15,  1.96it/s]