In [1]:
import json
import os
import random
from datetime import datetime
from pathlib import Path

import numpy as np
import pandas as pd
import torch
from torch.utils.data import DataLoader

from candidate_extractors.TemplateMatchExtractor import TemplateMatchExtractor
from candidate_extractors.TemplateMatchExtractor import generate_dataset_from_slides
from datasets.LabeledImageDataset import LabeledImageDataset
from datasets.SlideSeperatedDataset import SlideSeperatedDataset
from models.resnet import Resnet50Model
from train import train_classifier
from utils import reduce_dataset, split_dataset

In [2]:
# Input/output locations, relative to the notebook's working directory.
# slides_root_dir and candidates_dataset_dir feed the candidate-extraction step;
# annotations_root_dir is the folder read by load_annotations().
slides_root_dir = "data/whole-slides/gut"
annotations_root_dir = "data/annotations/json"
candidates_dataset_dir = "output/candidates"

# Seed every RNG this notebook can draw from so that "Restart & Run All"
# reproduces the same slide split and the same shuffling. pandas' .sample()
# falls back to NumPy's global RNG when no random_state is passed, and
# DataLoader(shuffle=True) draws from PyTorch's global RNG.
SEED = 42
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)

In [3]:
# Build the candidates dataset from the whole-slide images using a
# template-matching extractor.
# NOTE(review): presumably the dataset is written to candidates_dataset_dir, and
# the recorded output suggests an existing dataset there is reused rather than
# regenerated — confirm in generate_dataset_from_slides.
extractor = TemplateMatchExtractor()
generate_dataset_from_slides(slides_root_dir, extractor, candidates_dataset_dir)

Found cached candidates dataset output/candidates


In [4]:
def load_annotations(directory):
    """Summarise the per-slide annotation files found in ``directory``.

    Every regular file directly inside ``directory`` is parsed as JSON. The
    file stem becomes the slide name and ``len()`` of the parsed top-level
    object becomes that slide's positive-annotation count.

    Args:
        directory: Folder (``str`` or ``Path``) holding one annotation file
            per slide.

    Returns:
        ``pd.DataFrame`` with the columns ``slide_name`` and
        ``n_positive_annotations``, one row per file, ordered by file name.
        An empty folder yields an empty frame that still has both columns.

    Raises:
        json.JSONDecodeError: If a file does not contain valid JSON.
    """
    records = []
    # Sort the listing: directory iteration order is arbitrary, and the row
    # order (hence the index) feeds the random slide split further down.
    for annotations_path in sorted(Path(directory).iterdir()):
        # Skip sub-directories such as ``.ipynb_checkpoints``; open() fails on them.
        if not annotations_path.is_file():
            continue
        # JSON is UTF-8 by specification; do not depend on the locale default.
        with annotations_path.open(encoding="utf-8") as f:
            annotations = json.load(f)
        records.append({
            "slide_name": annotations_path.stem,
            "n_positive_annotations": len(annotations),
        })
    return pd.DataFrame(records, columns=["slide_name", "n_positive_annotations"])


def assign_categories(dataframe):
    """Label each slide with the quartile bucket of its annotation count.

    The 25th, 50th and 75th percentiles of ``n_positive_annotations`` are used
    as inclusive upper bounds for the buckets "Low", "Medium" and "High";
    anything above the 75th percentile is "Very High".

    Args:
        dataframe: Frame with an ``n_positive_annotations`` column.

    Returns:
        The same ``dataframe`` object, which is modified in place by adding
        (or overwriting) a ``category`` column.
    """
    counts = dataframe['n_positive_annotations']
    lower, middle, upper = counts.quantile([0.25, 0.5, 0.75])

    def bucket_for(count):
        # Guard clauses run from the lowest bound upwards, so each check only
        # needs the upper edge of its bucket.
        if count <= lower:
            return "Low"
        if count <= middle:
            return "Medium"
        if count <= upper:
            return "High"
        return "Very High"

    dataframe['category'] = counts.apply(bucket_for)
    return dataframe


def split_data(dataframe, train_portion=0.7, random_state=None):
    """Split slides into train and test sets, stratified by ``category``.

    Within every distinct ``category`` value a fraction ``train_portion`` of
    the rows is sampled for training; the remaining rows of that category go
    to the test set.

    Args:
        dataframe: Frame with a ``category`` column. Rows are removed from the
            test part by index label, so the index should be unique.
        train_portion: Fraction of each category sampled into the train set.
        random_state: Forwarded to ``DataFrame.sample``. Pass an int (or a
            NumPy random generator) for a reproducible split; ``None`` keeps
            the original behaviour of drawing from NumPy's global RNG.

    Returns:
        ``(train_set, test_set)`` as two DataFrames that keep the original
        index labels. For an input without rows both are empty frames with
        the input's columns.
    """
    train_parts, test_parts = [], []
    for category in dataframe['category'].unique():
        category_slides = dataframe[dataframe['category'] == category]
        train_samples = category_slides.sample(frac=train_portion, random_state=random_state)
        train_parts.append(train_samples)
        test_parts.append(category_slides.drop(train_samples.index))

    # pd.concat() rejects an empty list, so handle the no-category case here.
    if not train_parts:
        empty = dataframe.iloc[:0]
        return empty.copy(), empty.copy()

    # Concatenate once instead of growing a frame inside the loop.
    return pd.concat(train_parts), pd.concat(test_parts)


# Build the per-slide annotation table, bucket the slides by annotation count,
# then split them into train/test slides bucket by bucket.
slides_df = assign_categories(load_annotations(annotations_root_dir))
train_slides, test_slides = split_data(slides_df)
print("Train Slides")
train_slides

Train Slides


Unnamed: 0,slide_name,n_positive_annotations,category
18,593451,1,Low
0,522021,1,Low
2,593433,3,Low
10,593441,4,Low
12,593445,100,Very High
16,593449,226,Very High
19,593452,167,Very High
1,522934,129,Very High
20,593453,18,Medium
21,593454,20,Medium


In [5]:
# Display the held-out slides returned by split_data() for visual inspection.
print("Test Slides")
test_slides

Test Slides


Unnamed: 0,slide_name,n_positive_annotations,category
13,593446,15,Low
15,593448,13,Low
4,593435,91,Very High
5,593436,151,Very High
3,593434,27,Medium
9,593440,51,High


In [6]:
# Use the first CUDA device when PyTorch can see one; otherwise run on the CPU.
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
print(f"Device: {device}")

batch_size = 256

# Only the names of the training slides are handed to the dataset.
# NOTE(review): presumably this keeps candidates from the held-out test slides
# out of both the training and the validation split — confirm in SlideSeperatedDataset.
dataset = SlideSeperatedDataset(candidates_dataset_dir, set(train_slides["slide_name"]))
# NOTE(review): discard_ratio=0.0 looks like "discard nothing" — confirm in reduce_dataset.
dataset = reduce_dataset(dataset, discard_ratio=0.0)
train_dataset, validation_dataset = split_dataset(dataset, train_ratio=0.7)
# Undersampling is currently disabled (undersample_dataset is not imported above):
# train_dataset = undersample_dataset(train_dataset)

train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
validation_loader = DataLoader(validation_dataset, batch_size=batch_size, shuffle=True)

model = Resnet50Model(hidden_layers=1, units_per_layer=2048,
                      dropout=0.3, focal_alpha=0.9, focal_gamma=2.0)

print(f"Dataset: {len(train_dataset):,} training, {len(validation_dataset):,} validation")


Device: cuda:0
Dataset: 123,263 training, 52,828 validation


In [7]:
# Print the model's module tree to inspect its layers.
print(model)

Resnet50Model(
  (pretrained_model): ResNet(
    (conv1): Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
    (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (relu): ReLU(inplace=True)
    (maxpool): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
    (layer1): Sequential(
      (0): Bottleneck(
        (conv1): Conv2d(64, 64, kernel_size=(1, 1), stride=(1, 1), bias=False)
        (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
        (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (conv3): Conv2d(64, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)
        (bn3): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (relu): ReLU(inplace=True)
        (downsample): Sequent

In [8]:
# Log the wall-clock start time (ISO 8601, local time) ahead of the training run.
print(f"Training starts {datetime.now().isoformat()}")

Training starts 2025-02-01T19:48:53.941843


In [None]:

# Move the model to the selected device before training.
model = model.to(device)
# Evaluate on validation_loader: it is the only evaluation loader this notebook
# defines, so any other name here fails with NameError on a fresh kernel.
model, model_metrics = train_classifier(model, train_loader, validation_loader, device,
                                        start_learning_rate=0.000075,
                                        max_epochs=20,
                                        checkpoint_every=1,
                                        eval_every=1)


Epoch 1 training: 100%|██████████| 482/482 [04:58<00:00,  1.61it/s]


Train: 1/20: lr: 0.000075000000 loss:0.0016894085660984516


Epoch 1 testing: 100%|██████████| 207/207 [02:03<00:00,  1.68it/s]


Test: 1/20: loss:0.004046864758536967, accuracy:0.988413345410628, precision:0.24889003910743043, recall:0.3231078904991948, f1:0.25899840899840904, mcc:0.26708140195454216, ece:0.011586654571375409, epoch:0


Epoch 2 training: 100%|██████████| 482/482 [05:05<00:00,  1.58it/s]


Train: 2/20: lr: 0.000075000000 loss:0.0014295167671526199


Epoch 2 testing: 100%|██████████| 207/207 [02:06<00:00,  1.63it/s]


Test: 2/20: loss:0.0039480663304541095, accuracy:0.9827184743226214, precision:0.2036538609002377, recall:0.4327697262479871, f1:0.26282662587010414, mcc:0.28192447126790005, ece:0.017281525608620063, epoch:1


Epoch 3 training: 100%|██████████| 482/482 [05:02<00:00,  1.59it/s]


Train: 3/20: lr: 0.000075000000 loss:0.001031630865659413


Epoch 3 testing: 100%|██████████| 207/207 [02:05<00:00,  1.65it/s]


Test: 3/20: loss:0.00604984967324434, accuracy:0.9913949275362319, precision:0.25603864734299514, recall:0.22882447665056357, f1:0.22188587043659508, mcc:0.22854195882016307, ece:0.008605072455894616, epoch:2


Epoch 4 training: 100%|██████████| 482/482 [05:05<00:00,  1.58it/s]


Train: 4/20: lr: 0.000075000000 loss:0.0010169581879339927


Epoch 4 testing: 100%|██████████| 207/207 [02:06<00:00,  1.64it/s]


Test: 4/20: loss:0.004405340841456186, accuracy:0.9877380999264861, precision:0.24051069703243616, recall:0.33977455716586147, f1:0.2605283337167395, mcc:0.2695413517512471, ece:0.012261900041873256, epoch:3


Epoch 5 training: 100%|██████████| 482/482 [05:09<00:00,  1.56it/s]


Train: 5/20: lr: 0.000075000000 loss:0.0007856158584697342


Epoch 5 testing: 100%|██████████| 207/207 [02:08<00:00,  1.61it/s]


Test: 5/20: loss:0.004074031776177206, accuracy:0.9831902436462928, precision:0.19926002607162024, recall:0.3865654474350127, f1:0.24205287466157033, mcc:0.2590554355665164, ece:0.01680975630406992, epoch:4


Epoch 6 training: 100%|██████████| 482/482 [05:14<00:00,  1.53it/s]


Train: 6/20: lr: 0.000075000000 loss:0.000606028018896968


Epoch 6 testing: 100%|██████████| 207/207 [02:10<00:00,  1.59it/s]


Test: 6/20: loss:0.005824271554438339, accuracy:0.9866099558916194, precision:0.2152921555095468, recall:0.3549689440993789, f1:0.24916798177667743, mcc:0.2603237943015829, ece:0.013390044085591456, epoch:5


Epoch 7 training:  32%|███▏      | 156/482 [01:43<03:33,  1.52it/s]