In [1]:
import json
import os
import random
from datetime import datetime
from pathlib import Path

import numpy as np
import pandas as pd
import torch
from torch.utils.data import DataLoader

from candidate_extractors.TemplateMatchExtractor import TemplateMatchExtractor
from candidate_extractors.TemplateMatchExtractor import generate_dataset_from_slides
from datasets.SlideSeperatedCSVDataset import SlideSeperatedCSVDataset
from datasets.SlideSeperatedImageDataset import SlideSeperatedImageDataset
from models.mlp import MLPBinaryClassifier
from models.resnet import Resnet18BinaryClassifier
from train import train_classifier
from utils import extract_features_from_dataset
from utils import plot_model_metrics
from utils import reduce_dataset, split_dataset


In [2]:
# Input/output locations (relative to the notebook's working directory).
slides_root_dir = "data/whole-slides/gut"
annotations_root_dir = "data/annotations/json"
candidates_dataset_dir = "output/candidates"
model_output_dir = "output/models"

# Backbone whose extracted features feed the classifier head; its name and
# output width together determine which features CSV is read for training.
PretrainedModelClass = Resnet18BinaryClassifier
pretrained_model_name = PretrainedModelClass.get_pretrained_model_name()
pretrained_feature_count = PretrainedModelClass.pretrained_output_size
features_csv_file_name = f"{pretrained_model_name}_{pretrained_feature_count}_features.csv"
print(f"{pretrained_model_name}: {pretrained_feature_count} features")

Resnet18: 512 features


In [3]:
# Build the candidates dataset from the whole slides using the template-matching
# extractor; results go to candidates_dataset_dir. The logged output shows an
# existing cached dataset is reused rather than regenerated.
extractor = TemplateMatchExtractor()
generate_dataset_from_slides(slides_root_dir, extractor, candidates_dataset_dir)

Found cached candidates dataset output/candidates


In [4]:
# Extract features with the backbone chosen in the config cell, so that the CSV
# read later via features_csv_file_name (derived from PretrainedModelClass) is
# the one produced here. Hard-coding a specific class would silently desync the
# two when the backbone is changed.
extract_features_from_dataset(candidates_dataset_dir,
                              [PretrainedModelClass])

Device: cuda:0
Found cached output/candidates/Resnet18_512_features.csv


In [5]:
def load_annotations(directory):
    """Build a per-slide table of positive-annotation counts.

    Each file in ``directory`` is parsed as JSON; the slide name is the file
    stem and the count is ``len()`` of the parsed object.

    Args:
        directory: Path (str or ``Path``) of the directory holding one
            annotations file per slide.

    Returns:
        DataFrame with columns ``slide_name`` and ``n_positive_annotations``,
        one row per annotations file, ordered by file name.
    """
    records = []
    # Sort so the row order (and therefore any seeded split built on top of it)
    # does not depend on the filesystem's arbitrary listing order.
    for annotations_path in sorted(Path(directory).iterdir(), key=lambda p: p.name):
        # Sub-directories and hidden files (e.g. ".DS_Store") are not
        # annotation files and would make json.load fail.
        if not annotations_path.is_file() or annotations_path.name.startswith("."):
            continue
        with open(annotations_path, encoding="utf-8") as f:
            annotations = json.load(f)
        records.append({
            "slide_name": annotations_path.stem,
            "n_positive_annotations": len(annotations),
        })
    # Explicit columns keep the schema when the directory has no files.
    return pd.DataFrame(records, columns=["slide_name", "n_positive_annotations"])


def assign_categories(dataframe):
    """Label each slide by the quartile of its positive-annotation count.

    Quartile edges are computed from ``dataframe['n_positive_annotations']``:
    ``<= Q1`` -> "Low", ``<= median`` -> "Medium", ``<= Q3`` -> "High",
    otherwise "Very High".

    Args:
        dataframe: Table with an ``n_positive_annotations`` column.

    Returns:
        A copy of ``dataframe`` with an added ``category`` column. The input
        frame is left unmodified.
    """
    q1, median, q3 = dataframe["n_positive_annotations"].quantile([0.25, 0.5, 0.75])

    def categorize_quartiles(n_annotations):
        # Checked in ascending order, so each lower bound is implied by the
        # previous test having failed.
        if n_annotations <= q1:
            return "Low"
        if n_annotations <= median:
            return "Medium"
        if n_annotations <= q3:
            return "High"
        return "Very High"

    categorized = dataframe.copy()
    categorized["category"] = categorized["n_positive_annotations"].map(categorize_quartiles)
    return categorized


def split_data(dataframe, train_portion=0.7, random_state=None):
    """Split slides into train/test sets, stratified by ``category``.

    Within every category, ``train_portion`` of the rows (rounded by pandas)
    are sampled into the training set; the remaining rows form the test set.

    Args:
        dataframe: Table with a ``category`` column (one row per slide).
        train_portion: Fraction of each category that goes to training.
        random_state: Optional seed (int) or NumPy RandomState/Generator.
            ``None`` (the default) draws from NumPy's global RNG, as before.

    Returns:
        ``(train_set, test_set)`` DataFrames preserving the original index.
        Both are empty DataFrames if ``dataframe`` has no rows.
    """
    # Turn an int seed into one shared generator, so the per-category draws
    # are reproducible without all categories reusing the same seed.
    if isinstance(random_state, (int, np.integer)) and not isinstance(random_state, bool):
        random_state = np.random.RandomState(random_state)

    train_parts, test_parts = [], []
    for category in dataframe["category"].unique():
        category_slides = dataframe[dataframe["category"] == category]
        train_samples = category_slides.sample(frac=train_portion, random_state=random_state)
        train_parts.append(train_samples)
        test_parts.append(category_slides.drop(train_samples.index))

    # Concatenate once instead of growing frames inside the loop.
    train_set = pd.concat(train_parts) if train_parts else pd.DataFrame()
    test_set = pd.concat(test_parts) if test_parts else pd.DataFrame()
    return train_set, test_set


slides_df = assign_categories(load_annotations(annotations_root_dir))

# Seed NumPy right before the split: pandas' `sample` draws from NumPy's global
# RNG when no random_state is given. Seeding here, rather than only at the top,
# keeps the split stable even if earlier cells consume RNG state differently
# depending on whether their caches were hit.
# NOTE: the split is only reproducible if load_annotations returns rows in a
# stable order.
SPLIT_SEED = 42
np.random.seed(SPLIT_SEED)
train_slides, test_slides = split_data(slides_df)

# Guard against slide-level leakage: every slide must land in exactly one split.
assert set(train_slides.get("slide_name", [])).isdisjoint(test_slides.get("slide_name", []))
assert len(train_slides) + len(test_slides) == len(slides_df)

print("Train Slides")
train_slides

Train Slides


Unnamed: 0,slide_name,n_positive_annotations,category
15,593448,13,Low
10,593441,4,Low
0,522021,1,Low
2,593433,3,Low
1,522934,129,Very High
12,593445,100,Very High
16,593449,226,Very High
4,593435,91,Very High
20,593453,18,Medium
14,593447,25,Medium


In [6]:
# Slides held out from training; shown in the same form as the training table.
print("Test Slides")
test_slides

Test Slides


Unnamed: 0,slide_name,n_positive_annotations,category
13,593446,15,Low
18,593451,1,Low
5,593436,151,Very High
19,593452,167,Very High
8,593439,27,Medium
11,593444,30,High


In [7]:

device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
print(f"Device: {device}")

# Seed every RNG right before the stochastic steps (candidate reduction/split,
# weight initialisation, training-set shuffling), so the run is repeatable
# regardless of which earlier cells hit their caches.
# NOTE(review): which RNG reduce_dataset/split_dataset use is not visible here,
# so all three are seeded.
# CUDA kernels may still be non-deterministic.
TRAIN_SEED = 42
random.seed(TRAIN_SEED)
np.random.seed(TRAIN_SEED)
torch.manual_seed(TRAIN_SEED)

batch_size = 4096
# Feature vectors for candidates from the training slides only; test_slides stay
# untouched. Alternative: SlideSeperatedImageDataset with Resnet18BinaryClassifier
# for end-to-end image training.
dataset = SlideSeperatedCSVDataset(f"{candidates_dataset_dir}/{features_csv_file_name}",
                                   set(train_slides["slide_name"]))
dataset = reduce_dataset(dataset, discard_ratio=0.0)
# NOTE(review): validation candidates presumably come from the same slides as the
# training candidates, so validation metrics are optimistic for unseen slides.
train_dataset, validation_dataset = split_dataset(dataset, train_ratio=0.9)

train_loader = DataLoader(train_dataset,
                          batch_size=batch_size,
                          shuffle=True)
# Shuffling evaluation data has no benefit. The logged metrics look
# batch-averaged (NOTE(review): confirm in train_classifier); if so, a shuffled
# order would make them depend on batch composition from epoch to epoch.
validation_loader = DataLoader(validation_dataset,
                               batch_size=batch_size,
                               shuffle=False)

model = MLPBinaryClassifier(in_features=PretrainedModelClass.pretrained_output_size, hidden_layers=2,
                            units_per_layer=PretrainedModelClass.pretrained_output_size,
                            dropout=0.3, focal_alpha=0.85, focal_gamma=4.0)

print(f"Dataset: {len(train_dataset):,} training, {len(validation_dataset):,} validation")


Device: cuda:0
Dataset: 111,977 training, 12,442 validation


In [8]:
# Show the classifier's layer stack (its nn.Module repr).
print(model)

MLPBinaryClassifier(
  (model): Sequential(
    (0): Linear(in_features=512, out_features=512, bias=True)
    (1): ReLU()
    (2): Dropout(p=0.3, inplace=False)
    (3): Linear(in_features=512, out_features=512, bias=True)
    (4): ReLU()
    (5): Dropout(p=0.3, inplace=False)
    (6): Linear(in_features=512, out_features=512, bias=True)
    (7): ReLU()
    (8): Dropout(p=0.3, inplace=False)
    (9): Linear(in_features=512, out_features=1, bias=True)
    (10): Sigmoid()
  )
)


In [9]:
# Wall-clock timestamp taken before training; paired with the "Training ends"
# print after training, it gives the run duration.
print(f"Training starts {datetime.now().isoformat()}")

Training starts 2025-02-04T23:58:04.674577


In [10]:

model = model.to(device)
# Train the MLP head; validation_loader is evaluated every `eval_every` epochs
# (the per-epoch "Test" log lines, 4 batches, come from it).
# NOTE(review): the name lr_warmup_steps suggests a warm-up, but in the logged run
# the lr stays at start_learning_rate through epoch 19 and reads
# min_learning_rate from epoch 20 on. Confirm the schedule in train.train_classifier.
model, model_metrics = train_classifier(model, train_loader, validation_loader, device,
                                        start_learning_rate=0.000050,
                                        min_learning_rate=0.000025,
                                        lr_warmup_steps=20,
                                        max_epochs=30,
                                        checkpoint_every=1,
                                        eval_every=1)


Epoch 1 training: 100%|██████████| 28/28 [00:01<00:00, 16.83it/s]


Train: 1/30: lr: 0.000050000000 loss:0.002735402127395251


Epoch 1 testing: 100%|██████████| 4/4 [00:00<00:00, 15.67it/s]


Test: 1/30: loss:0.001641034905333072, accuracy:0.99298095703125, precision:0.0, recall:0.0, f1:0.0, mcc:0.0, ece:0.00701904296875, epoch:1


Epoch 2 training: 100%|██████████| 28/28 [00:01<00:00, 19.72it/s]


Train: 2/30: lr: 0.000050000000 loss:0.0018909183314203151


Epoch 2 testing: 100%|██████████| 4/4 [00:00<00:00, 23.92it/s]


Test: 2/30: loss:0.0014327362732728943, accuracy:0.99298095703125, precision:0.0, recall:0.0, f1:0.0, mcc:0.0, ece:0.00701904296875, epoch:2


Epoch 3 training: 100%|██████████| 28/28 [00:01<00:00, 18.48it/s]


Train: 3/30: lr: 0.000050000000 loss:0.0016403938560480519


Epoch 3 testing: 100%|██████████| 4/4 [00:00<00:00, 21.50it/s]


Test: 3/30: loss:0.00143946404568851, accuracy:0.9898562740969967, precision:0.0, recall:0.0, f1:0.0, mcc:0.0, ece:0.010143725899979472, epoch:3


Epoch 4 training: 100%|██████████| 28/28 [00:01<00:00, 18.75it/s]


Train: 4/30: lr: 0.000050000000 loss:0.0014200537739920297


Epoch 4 testing: 100%|██████████| 4/4 [00:00<00:00, 15.45it/s]


Test: 4/30: loss:0.0012280839728191495, accuracy:0.9790411614752434, precision:0.13208037178786222, recall:0.21957431760837987, f1:0.1641565450527715, mcc:0.16009175301989495, ece:0.020958838518708944, epoch:4


Epoch 5 training: 100%|██████████| 28/28 [00:01<00:00, 15.71it/s]


Train: 5/30: lr: 0.000050000000 loss:0.001319202031092053


Epoch 5 testing: 100%|██████████| 4/4 [00:00<00:00, 18.29it/s]


Test: 5/30: loss:0.0010982887324644253, accuracy:0.9739752435064934, precision:0.20833333333333331, recall:0.6111197339246119, f1:0.30867108283215666, mcc:0.3454390591951983, ece:0.026024756021797657, epoch:5


Epoch 6 training: 100%|██████████| 28/28 [00:01<00:00, 14.94it/s]


Train: 6/30: lr: 0.000050000000 loss:0.0012986362188322736


Epoch 6 testing: 100%|██████████| 4/4 [00:00<00:00, 21.33it/s]


Test: 6/30: loss:0.001102386275306344, accuracy:0.9697146725344967, precision:0.2313695879408043, recall:0.6944444444444444, f1:0.33745591570306366, mcc:0.38315388540307893, ece:0.030285327462479472, epoch:6


Epoch 7 training: 100%|██████████| 28/28 [00:01<00:00, 18.49it/s]


Train: 7/30: lr: 0.000050000000 loss:0.0012496020582537831


Epoch 7 testing: 100%|██████████| 4/4 [00:00<00:00, 15.04it/s]


Test: 7/30: loss:0.000996049217064865, accuracy:0.9666755973518669, precision:0.1147165387894288, recall:0.42938443670150983, f1:0.18083668070931713, mcc:0.2129437249722829, ece:0.033324402756989, epoch:7


Epoch 8 training: 100%|██████████| 28/28 [00:01<00:00, 15.34it/s]


Train: 8/30: lr: 0.000050000000 loss:0.0012287849177872495


Epoch 8 testing: 100%|██████████| 4/4 [00:00<00:00, 17.04it/s]


Test: 8/30: loss:0.0009775288635864854, accuracy:0.9702766715706169, precision:0.12930920399655926, recall:0.4025083612040134, f1:0.19471945132574225, mcc:0.21920593918424736, ece:0.029723328072577715, epoch:8


Epoch 9 training: 100%|██████████| 28/28 [00:02<00:00, 11.76it/s]


Train: 9/30: lr: 0.000050000000 loss:0.0012079225000759055


Epoch 9 testing: 100%|██████████| 4/4 [00:00<00:00, 16.79it/s]


Test: 9/30: loss:0.0010473898146301508, accuracy:0.9708378779423701, precision:0.180351022448959, recall:0.6766369047619047, f1:0.2841108756896502, mcc:0.3389211824390591, ece:0.029162121936678886, epoch:9


Epoch 10 training: 100%|██████████| 28/28 [00:02<00:00, 12.63it/s]


Train: 10/30: lr: 0.000050000000 loss:0.0011827225805193717


Epoch 10 testing: 100%|██████████| 4/4 [00:00<00:00, 11.98it/s]


Test: 10/30: loss:0.001048139572958462, accuracy:0.9541031478287337, precision:0.10338394204793447, recall:0.4673669094851361, f1:0.16842438736828064, mcc:0.20947890351160392, ece:0.045896852388978004, epoch:10


Epoch 11 training: 100%|██████████| 28/28 [00:02<00:00, 11.55it/s]


Train: 11/30: lr: 0.000050000000 loss:0.0011808320060871275


Epoch 11 testing: 100%|██████████| 4/4 [00:00<00:00, 15.37it/s]


Test: 11/30: loss:0.0010322316957172006, accuracy:0.9718144404423701, precision:0.22241236496269073, recall:0.6737307408039115, f1:0.332043624193746, mcc:0.3752618275255189, ece:0.028185559902340174, epoch:11


Epoch 12 training: 100%|██████████| 28/28 [00:02<00:00, 11.03it/s]


Train: 12/30: lr: 0.000050000000 loss:0.0011456013016868383


Epoch 12 testing: 100%|██████████| 4/4 [00:00<00:00, 13.92it/s]


Test: 12/30: loss:0.0009904627513606101, accuracy:0.9639900504768669, precision:0.10913805179840981, recall:0.4563847610722611, f1:0.17591968790148216, mcc:0.21384112793186685, ece:0.036009949631989, epoch:12


Epoch 13 training: 100%|██████████| 28/28 [00:02<00:00, 10.42it/s]


Train: 13/30: lr: 0.000050000000 loss:0.0011360710154154471


Epoch 13 testing: 100%|██████████| 4/4 [00:00<00:00, 11.34it/s]


Test: 13/30: loss:0.0011361215147189796, accuracy:0.9711065911627434, precision:0.22472797303803566, recall:0.5958308126638435, f1:0.3172784281907717, mcc:0.34853350186882526, ece:0.02889340929687023, epoch:13


Epoch 14 training: 100%|██████████| 28/28 [00:02<00:00, 13.46it/s]


Train: 14/30: lr: 0.000050000000 loss:0.001104607550327533


Epoch 14 testing: 100%|██████████| 4/4 [00:00<00:00, 21.33it/s]


Test: 14/30: loss:0.0010150254383916035, accuracy:0.9676030146611201, precision:0.17270399935494274, recall:0.7085649461145774, f1:0.2769699891943509, mcc:0.33882240279161224, ece:0.032396985217928886, epoch:14


Epoch 15 training: 100%|██████████| 28/28 [00:01<00:00, 14.37it/s]


Train: 15/30: lr: 0.000050000000 loss:0.0010997194247985525


Epoch 15 testing: 100%|██████████| 4/4 [00:00<00:00, 16.24it/s]


Test: 15/30: loss:0.0009596779927960597, accuracy:0.9655397093141234, precision:0.10126589050147734, recall:0.501072501072501, f1:0.16824998800982355, mcc:0.21533688337516488, ece:0.034460290684364736, epoch:15


Epoch 16 training: 100%|██████████| 28/28 [00:02<00:00, 12.65it/s]


Train: 16/30: lr: 0.000050000000 loss:0.0010846745405745292


Epoch 16 testing: 100%|██████████| 4/4 [00:00<00:00, 19.69it/s]


Test: 16/30: loss:0.0012179831974208355, accuracy:0.964979295606737, precision:0.13771336659663866, recall:0.4584881756756757, f1:0.21078662006542098, mcc:0.23814649033127977, ece:0.03502070438116789, epoch:16


Epoch 17 training: 100%|██████████| 28/28 [00:01<00:00, 16.86it/s]


Train: 17/30: lr: 0.000050000000 loss:0.0010604683006282098


Epoch 17 testing: 100%|██████████| 4/4 [00:00<00:00, 18.88it/s]


Test: 17/30: loss:0.0010031050478573889, accuracy:0.9675300895393669, precision:0.16519410265546214, recall:0.7097812097812097, f1:0.26604019998140643, mcc:0.3302186468491323, ece:0.03246990963816643, epoch:17


Epoch 18 training: 100%|██████████| 28/28 [00:01<00:00, 14.33it/s]


Train: 18/30: lr: 0.000050000000 loss:0.0010680915480146983


Epoch 18 testing: 100%|██████████| 4/4 [00:00<00:00, 13.70it/s]


Test: 18/30: loss:0.0009477081621298566, accuracy:0.9776492428469967, precision:0.14685709810239886, recall:0.41223290598290596, f1:0.21652797673141005, mcc:0.23848457121611505, ece:0.022350757149979472, epoch:18


Epoch 19 training: 100%|██████████| 28/28 [00:02<00:00, 10.57it/s]


Train: 19/30: lr: 0.000050000000 loss:0.0010421376874936478


Epoch 19 testing: 100%|██████████| 4/4 [00:00<00:00, 10.53it/s]


Test: 19/30: loss:0.0009615892922738567, accuracy:0.9719000481939934, precision:0.13222451519916142, recall:0.4258928571428572, f1:0.2011245842988006, mcc:0.22875891863356085, ece:0.028099951799958944, epoch:19


Epoch 20 training: 100%|██████████| 28/28 [00:02<00:00, 10.13it/s]


Train: 20/30: lr: 0.000025000000 loss:0.0010459919784417643


Epoch 20 testing: 100%|██████████| 4/4 [00:00<00:00, 13.93it/s]


Test: 20/30: loss:0.0009902508318191394, accuracy:0.9708870231331169, precision:0.17956912878787878, recall:0.6933687200956937, f1:0.28103323054750784, mcc:0.3392793589680974, ece:0.029112976044416428, epoch:20


Epoch 21 training: 100%|██████████| 28/28 [00:02<00:00, 12.38it/s]


Train: 21/30: lr: 0.000025000000 loss:0.0010137998802487605


Epoch 21 testing: 100%|██████████| 4/4 [00:00<00:00, 14.29it/s]


Test: 21/30: loss:0.0012227226106915623, accuracy:0.9709115957284903, precision:0.23384298862712624, recall:0.601984126984127, f1:0.3341840360394658, mcc:0.3617759901711327, ece:0.029088404029607773, epoch:21


Epoch 22 training: 100%|██████████| 28/28 [00:02<00:00, 12.55it/s]


Train: 22/30: lr: 0.000025000000 loss:0.0010009429430023634


Epoch 22 testing: 100%|██████████| 4/4 [00:00<00:00, 11.02it/s]


Test: 22/30: loss:0.0010117861384060234, accuracy:0.9740481686282467, precision:0.2562454506231022, recall:0.5919123321500231, f1:0.32950482920311774, mcc:0.3626191256433808, ece:0.025951831368729472, epoch:22


Epoch 23 training: 100%|██████████| 28/28 [00:02<00:00,  9.97it/s]


Train: 23/30: lr: 0.000025000000 loss:0.0010069489612111024


Epoch 23 testing: 100%|██████████| 4/4 [00:00<00:00,  7.14it/s]


Test: 23/30: loss:0.0011684346973197535, accuracy:0.9679454456676136, precision:0.19852347860957337, recall:0.5993589743589743, f1:0.2974690688632414, mcc:0.33217763355012264, ece:0.03205455373972654, epoch:23


Epoch 24 training: 100%|██████████| 28/28 [00:02<00:00, 10.91it/s]


Train: 24/30: lr: 0.000025000000 loss:0.0009852697964691157


Epoch 24 testing: 100%|██████████| 4/4 [00:00<00:00, 13.12it/s]


Test: 24/30: loss:0.0010343853064114228, accuracy:0.9715338372564934, precision:0.13326259689922482, recall:0.4468248736541419, f1:0.20519687549816673, mcc:0.23313998135520392, ece:0.028466162737458944, epoch:24


Epoch 25 training: 100%|██████████| 28/28 [00:02<00:00, 10.38it/s]


Train: 25/30: lr: 0.000025000000 loss:0.0009862955485004932


Epoch 25 testing: 100%|██████████| 4/4 [00:00<00:00, 14.59it/s]


Test: 25/30: loss:0.0011101903655799106, accuracy:0.9744754147219967, precision:0.29072779541566995, recall:0.6953703703703703, f1:0.3918600354714751, mcc:0.4279857320538054, ece:0.02552458504214883, epoch:25


Epoch 26 training: 100%|██████████| 28/28 [00:02<00:00, 11.16it/s]


Train: 26/30: lr: 0.000025000000 loss:0.0009778140000500052


Epoch 26 testing: 100%|██████████| 4/4 [00:00<00:00, 14.18it/s]


Test: 26/30: loss:0.00221591311856173, accuracy:0.961683397169237, precision:0.15710138129856277, recall:0.5704545454545454, f1:0.24418943747071442, mcc:0.2834393485502644, ece:0.03831660374999046, epoch:26


Epoch 27 training: 100%|██████████| 28/28 [00:03<00:00,  9.26it/s]


Train: 27/30: lr: 0.000025000000 loss:0.0009749616770672479


Epoch 27 testing: 100%|██████████| 4/4 [00:00<00:00, 10.63it/s]


Test: 27/30: loss:0.0011438362998887897, accuracy:0.9662364612926136, precision:0.16615454693448206, recall:0.5740079365079365, f1:0.25764527948193594, mcc:0.2965413903255768, ece:0.03376353904604912, epoch:27


Epoch 28 training: 100%|██████████| 28/28 [00:03<00:00,  9.33it/s]


Train: 28/30: lr: 0.000025000000 loss:0.0009577268688839727


Epoch 28 testing: 100%|██████████| 4/4 [00:00<00:00, 10.16it/s]


Test: 28/30: loss:0.0010022634087363258, accuracy:0.9730716061282467, precision:0.12734848484848485, recall:0.46426762940735183, f1:0.19922920798929838, mcc:0.23275382505547362, ece:0.02692839433439076, epoch:28


Epoch 29 training: 100%|██████████| 28/28 [00:02<00:00, 10.76it/s]


Train: 29/30: lr: 0.000025000000 loss:0.0009637923711644751


Epoch 29 testing: 100%|██████████| 4/4 [00:00<00:00, 13.47it/s]


Test: 29/30: loss:0.0009605345403542742, accuracy:0.9650767933238636, precision:0.13014756216686024, recall:0.46230811403508765, f1:0.20184590690208667, mcc:0.2360062415516005, ece:0.03492320701479912, epoch:29


Epoch 30 training: 100%|██████████| 28/28 [00:02<00:00, 10.74it/s]


Train: 30/30: lr: 0.000025000000 loss:0.0009437822071569306


Epoch 30 testing: 100%|██████████| 4/4 [00:00<00:00, 11.80it/s]

Test: 30/30: loss:0.001146105263615027, accuracy:0.9700571035409903, precision:0.14244908097367115, recall:0.4384920634920635, f1:0.21488076743824763, mcc:0.23851560320880855, ece:0.029942896217107773, epoch:30





In [None]:


print(f"Training ends {datetime.now().isoformat()}")
# train_classifier was given validation_loader (candidates from the training
# slides), so its "test_*" metrics are validation metrics. The held-out
# test_slides are not scored in this notebook, so the values are labelled as
# validation. Only the final epoch is printed; the best epoch may differ
# (see the plot below).
for metric in ["accuracy", "precision", "recall", "f1", "mcc"]:
    if len(model_metrics[f"test_{metric}"]) > 0:
        print(f"Validation {metric} (last epoch):", model_metrics[f"test_{metric}"][-1])

plot_model_metrics(model_metrics)

Training ends 2025-02-04T23:59:20.901146
Test accuracy: 0.9700571035409903
Test precision: 0.14244908097367115
Test recall: 0.4384920634920635
Test f1: 0.21488076743824763
Test mcc: 0.23851560320880855


In [24]:
os.makedirs(model_output_dir, exist_ok=True)
# Full-module pickle, kept for existing loaders. It can only be unpickled where
# the model's class is importable under the same module path.
torch.save(model, f"{model_output_dir}/model.pickle")
# Portable weights-only checkpoint. Reload it by constructing the same
# architecture and calling `load_state_dict`.
torch.save(model.state_dict(), f"{model_output_dir}/model_state_dict.pt")
# Record which slides were used for training vs. held out, so the test set can be
# reconstructed for later evaluation.
torch.save(
    {
        "train_slides": set(train_slides["slide_name"]),
        "test_slides": set(test_slides["slide_name"]),
    },
    f"{model_output_dir}/data-split.pickle",
)