In [1]:

import os
from datetime import datetime

import pandas as pd
import torch
from torch.utils.data import DataLoader

from datasets.SlideSeperatedCSVDataset import SlideSeperatedCSVDataset
from extractors.TemplateMatchExtractor import TemplateMatchExtractor, generate_dataset_from_slides
from labelers.GroundTruthLabeler import GroundTruthLabeler
from models.mlp import MLPBinaryClassifier
from models.resnet import Resnet18BinaryClassifier
from train import train_classifier
from utils import extract_features_from_dataset
from utils import plot_model_metrics
from utils import reduce_dataset, split_dataset


In [2]:
# Input locations: the whole-slide images and the ground-truth label files.
slides_root_dir = "data/whole-slides/gut"
labels_root_dir = "data/labels"
# Output locations: the candidates dataset directory and the saved model artifacts.
candidates_dataset_dir = "output/candidates"
model_output_dir = "output/models"
# Pretrained backbone class; its name and output size determine the feature CSV
# file name below and the classifier input width used later in the notebook.
PretrainedModelClass = Resnet18BinaryClassifier
# File name has the form "<pretrained model name>_<output size>_features.csv".
features_csv_file_name = f"{PretrainedModelClass.get_pretrained_model_name()}_{PretrainedModelClass.pretrained_output_size}_features.csv"
print(f"{PretrainedModelClass.get_pretrained_model_name()}: {PretrainedModelClass.pretrained_output_size} features")

Resnet18: 512 features


In [3]:
# Build the ground-truth labeler from the slide annotation JSON and the patch
# classification CSV located under labels_root_dir.
ground_truth_labeler = GroundTruthLabeler(f"{labels_root_dir}/slide-annotations/all.json",
                                          f"{labels_root_dir}/patch-classifications.csv")
# The extractor is constructed with the labeler and then used to generate the
# candidates dataset from the slides directory into candidates_dataset_dir.
# NOTE(review): the recorded output reports a cached dataset being found, so this
# call presumably skips regeneration when one already exists — confirm in
# generate_dataset_from_slides.
extractor = TemplateMatchExtractor(ground_truth_labeler)
generate_dataset_from_slides(slides_root_dir, extractor, candidates_dataset_dir)

Found cached candidates dataset output/candidates


In [4]:
# Extract features with the configured pretrained model class rather than a
# hard-coded one, so the CSV produced here is the same file that
# features_csv_file_name (derived from PretrainedModelClass) points to when the
# training dataset is loaded later.
extract_features_from_dataset(candidates_dataset_dir,
                              [PretrainedModelClass])

Device: cuda:0
Found cached output/candidates/Resnet18_512_features.csv


In [5]:
def assign_categories(dataframe):
    """Label each row with a quartile-based category of its annotation count.

    The quartiles (25%, 50%, 75%) of the ``n_gt_positive_regions`` column are
    computed over the whole frame and each row is assigned:

    * ``"Low"``       -- count <= first quartile
    * ``"Medium"``    -- first quartile < count <= median
    * ``"High"``      -- median < count <= third quartile
    * ``"Very High"`` -- anything else (including counts that are NaN)

    Args:
        dataframe: DataFrame with an ``n_gt_positive_regions`` column.

    Returns:
        A new DataFrame equal to ``dataframe`` plus a ``category`` column.
        The input frame is left unmodified, so re-running the calling cell
        (or sharing the input with other code) has no hidden side effects.
    """
    q1, median, q3 = dataframe['n_gt_positive_regions'].quantile([0.25, 0.5, 0.75])

    def categorize_quartiles(n_annotations):
        # Branches are checked in ascending order, so each test only needs the
        # upper bound; the lower bound is implied by the previous branch failing.
        if n_annotations <= q1:
            return "Low"
        if n_annotations <= median:
            return "Medium"
        if n_annotations <= q3:
            return "High"
        return "Very High"

    categories = dataframe['n_gt_positive_regions'].apply(categorize_quartiles)
    # assign() returns a copy, keeping the caller's DataFrame untouched.
    return dataframe.assign(category=categories)


def split_data(dataframe, train_portion=0.7, random_state=None):
    """Split rows into train and test sets, stratified by the ``category`` column.

    For every distinct category, ``train_portion`` of its rows are sampled
    (without replacement) into the training set and the remaining rows go to
    the test set. Very small categories may contribute no training rows
    because the sampled count is rounded.

    Args:
        dataframe: DataFrame with a ``category`` column. Index labels are
            expected to be unique, since the test rows are obtained by
            dropping the sampled index labels.
        train_portion: Fraction of each category sampled into the train set.
        random_state: Optional seed (or random state object accepted by
            ``DataFrame.sample``) to make the split reproducible. ``None``
            keeps the previous unseeded behaviour.

    Returns:
        ``(train_set, test_set)`` as two DataFrames. For an input without any
        category both are empty frames that keep the input's columns.
    """
    train_parts = []
    test_parts = []
    for category in dataframe['category'].unique():
        category_slides = dataframe[dataframe['category'] == category]
        train_samples = category_slides.sample(frac=train_portion, random_state=random_state)
        train_parts.append(train_samples)
        test_parts.append(category_slides.drop(train_samples.index))

    if not train_parts:
        # pd.concat rejects an empty list; return empty frames that still
        # carry the expected columns so callers can index them safely.
        return dataframe.iloc[0:0].copy(), dataframe.iloc[0:0].copy()

    # Concatenate once instead of growing the frames inside the loop.
    return pd.concat(train_parts), pd.concat(test_parts)


# Per-slide summary of ground-truth positive regions, taken from the labeler.
slides_df = ground_truth_labeler.positive_regions_summary
# Attach the quartile-based 'category' column that the split is stratified on.
slides_df = assign_categories(slides_df)
# Split the slides into train/test within each category, using the default
# train portion.
# NOTE(review): no random seed is passed or set here, so the split presumably
# differs between runs — pin a seed if reproducibility matters.
train_slides, test_slides = split_data(slides_df)
print("Train Slides")
train_slides

Train Slides


Unnamed: 0,slide_name,n_gt_positive_regions,category
4,593453,19,Medium
1,593439,31,Medium
7,593454,22,Medium
0,593444,31,Medium
16,593451,1,Low
21,522021,3,Low
2,593446,15,Low
3,593441,4,Low
5,593438,92,High
12,593450,55,High


In [6]:
# Display the held-out slides produced by the train/test split.
print("Test Slides")
test_slides

Test Slides


Unnamed: 0,slide_name,n_gt_positive_regions,category
10,593434,27,Medium
14,593447,25,Medium
9,593448,13,Low
11,593433,3,Low
20,593437,94,High
8,593436,170,Very High
19,522934,182,Very High


In [7]:

# Use the first CUDA device when one is available, otherwise fall back to CPU.
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
print(f"Device: {device}")

batch_size = 4096
# Alternative input (raw candidate images instead of cached features); the class
# would need to be imported before enabling this line:
# dataset = SlideSeperatedImageDataset(candidates_dataset_dir, set(train_slides["slide_name"]))
# Load the cached feature rows, restricted to the training slides only.
dataset = SlideSeperatedCSVDataset(f"{candidates_dataset_dir}/{features_csv_file_name}",
                                   set(train_slides["slide_name"]))
dataset = reduce_dataset(dataset, discard_ratio=0.0)
train_dataset, validation_dataset = split_dataset(dataset, train_ratio=0.9)
# Optional class re-balancing step (helper would need to be imported first):
# train_dataset = undersample_dataset(train_dataset)

train_loader = DataLoader(train_dataset,
                          batch_size=batch_size,
                          shuffle=True)
# Evaluation gains nothing from shuffling; a fixed order keeps the validation
# batches identical from epoch to epoch, so metric changes reflect the model
# rather than batch composition.
validation_loader = DataLoader(validation_dataset,
                               batch_size=batch_size,
                               shuffle=False)

# Alternative end-to-end model operating on images instead of cached features:
# model = Resnet18BinaryClassifier(hidden_layers=1, units_per_layer=2048,
#                       dropout=0.3, focal_alpha=0.9, focal_gamma=2.0)
# MLP head whose input and hidden widths both match the pretrained feature size.
model = MLPBinaryClassifier(in_features=PretrainedModelClass.pretrained_output_size, hidden_layers=2,
                            units_per_layer=PretrainedModelClass.pretrained_output_size,
                            dropout=0.3, focal_alpha=0.85, focal_gamma=4.0)

print(f"Dataset: {len(train_dataset):,} training, {len(validation_dataset):,} validation")


Device: cuda:0
Dataset: 144,276 training, 16,031 validation


In [8]:
# Print the module structure to verify layer sizes before training.
print(model)

MLPBinaryClassifier(
  (model): Sequential(
    (0): Linear(in_features=512, out_features=512, bias=True)
    (1): ReLU()
    (2): Dropout(p=0.3, inplace=False)
    (3): Linear(in_features=512, out_features=512, bias=True)
    (4): ReLU()
    (5): Dropout(p=0.3, inplace=False)
    (6): Linear(in_features=512, out_features=512, bias=True)
    (7): ReLU()
    (8): Dropout(p=0.3, inplace=False)
    (9): Linear(in_features=512, out_features=1, bias=True)
    (10): Sigmoid()
  )
)


In [9]:
# Record the wall-clock start time; a matching end timestamp is printed after training.
print(f"Training starts {datetime.now().isoformat()}")

Training starts 2025-02-06T17:51:30.526412


In [10]:

# Move the model to the selected device before training.
model = model.to(device)
# train_classifier returns the model together with a metrics mapping that the
# reporting cell indexes by "test_<metric>".
# NOTE(review): the exact meaning of the learning-rate arguments, lr_warmup_steps,
# checkpoint_every and eval_every is defined in train.train_classifier — the
# recorded log shows one evaluation per epoch, with lr 0.00005 through epoch 19
# and 0.000025 from epoch 20 onward.
model, model_metrics = train_classifier(model, train_loader, validation_loader, device,
                                        start_learning_rate=0.000050,
                                        min_learning_rate=0.000025,
                                        lr_warmup_steps=20,
                                        max_epochs=30,
                                        checkpoint_every=1,
                                        eval_every=1)


Epoch 1 training: 100%|██████████| 36/36 [00:02<00:00, 17.53it/s]


Train: 1/30: lr: 0.000050000000 loss:0.002970276662381366


Epoch 1 testing: 100%|██████████| 4/4 [00:00<00:00, 17.81it/s]


Test: 1/30: loss:0.001959073852049187, accuracy:0.9921426408723367, precision:0.0, recall:0.0, f1:0.0, mcc:0.0, ece:0.007857359130866826, epoch:1


Epoch 2 training: 100%|██████████| 36/36 [00:01<00:00, 19.13it/s]


Train: 2/30: lr: 0.000050000000 loss:0.001882268604175705


Epoch 2 testing: 100%|██████████| 4/4 [00:00<00:00, 18.22it/s]


Test: 2/30: loss:0.0016228199528995901, accuracy:0.9921944465606632, precision:0.0, recall:0.0, f1:0.0, mcc:0.0, ece:0.007805553381331265, epoch:2


Epoch 3 training: 100%|██████████| 36/36 [00:01<00:00, 18.71it/s]


Train: 3/30: lr: 0.000050000000 loss:0.0016703772449141575


Epoch 3 testing: 100%|██████████| 4/4 [00:00<00:00, 12.49it/s]


Test: 3/30: loss:0.0013741978036705405, accuracy:0.9921599094351122, precision:0.0, recall:0.0, f1:0.0, mcc:0.0, ece:0.007840090547688305, epoch:3


Epoch 4 training: 100%|██████████| 36/36 [00:02<00:00, 15.24it/s]


Train: 4/30: lr: 0.000050000000 loss:0.0014614377966305863


Epoch 4 testing: 100%|██████████| 4/4 [00:00<00:00, 12.65it/s]


Test: 4/30: loss:0.0012666729744523764, accuracy:0.9916520767150765, precision:0.4369047619047619, recall:0.10099605622408736, f1:0.1601190476190476, mcc:0.20342666267220472, ece:0.008347923401743174, epoch:4


Epoch 5 training: 100%|██████████| 36/36 [00:03<00:00, 11.84it/s]


Train: 5/30: lr: 0.000050000000 loss:0.0013798181381490496


Epoch 5 testing: 100%|██████████| 4/4 [00:00<00:00,  9.66it/s]


Test: 5/30: loss:0.0011944971629418433, accuracy:0.989193401902301, precision:0.25334975369458124, recall:0.1985971835845679, f1:0.2223347968028819, mcc:0.21877224391382452, ece:0.01080659800209105, epoch:5


Epoch 6 training: 100%|██████████| 36/36 [00:03<00:00, 11.85it/s]


Train: 6/30: lr: 0.000050000000 loss:0.001332792529461181


Epoch 6 testing: 100%|██████████| 4/4 [00:00<00:00, 10.58it/s]


Test: 6/30: loss:0.001158578583272174, accuracy:0.9884667362148928, precision:0.2644789081885856, recall:0.2448613928329953, f1:0.25187718095712863, mcc:0.24751558502713156, ece:0.011533263605087996, epoch:6


Epoch 7 training: 100%|██████████| 36/36 [00:03<00:00, 11.61it/s]


Train: 7/30: lr: 0.000050000000 loss:0.0012884968471351182


Epoch 7 testing: 100%|██████████| 4/4 [00:00<00:00, 11.12it/s]


Test: 7/30: loss:0.0011431673920014873, accuracy:0.9864410636834591, precision:0.2500652657447083, recall:0.3668105099940294, f1:0.29564071009635523, mcc:0.2953541803205114, ece:0.013558936305344105, epoch:7


Epoch 8 training: 100%|██████████| 36/36 [00:02<00:00, 12.23it/s]


Train: 8/30: lr: 0.000050000000 loss:0.0012495459296688852


Epoch 8 testing: 100%|██████████| 4/4 [00:00<00:00, 13.82it/s]


Test: 8/30: loss:0.0011366091930540279, accuracy:0.9844626310258315, precision:0.23245732544175673, recall:0.44067132285926114, f1:0.3021816695737417, mcc:0.31169580833108745, ece:0.015537369064986706, epoch:8


Epoch 9 training: 100%|██████████| 36/36 [00:03<00:00, 11.25it/s]


Train: 9/30: lr: 0.000050000000 loss:0.0012342899911648904


Epoch 9 testing: 100%|██████████| 4/4 [00:00<00:00,  8.35it/s]


Test: 9/30: loss:0.0011112412175862119, accuracy:0.9850545236524846, precision:0.239725158319348, recall:0.4352156432748538, f1:0.307594696969697, mcc:0.3151916689436142, ece:0.014945476315915585, epoch:9


Epoch 10 training: 100%|██████████| 36/36 [00:03<00:00, 10.80it/s]


Train: 10/30: lr: 0.000050000000 loss:0.0012059850867242655


Epoch 10 testing: 100%|██████████| 4/4 [00:00<00:00, 10.63it/s]


Test: 10/30: loss:0.0011753258586395532, accuracy:0.9769634437408162, precision:0.19082461378572432, recall:0.5808945105820106, f1:0.2855030080213904, mcc:0.32299163574371675, ece:0.023036555852741003, epoch:10


Epoch 11 training: 100%|██████████| 36/36 [00:03<00:00, 10.69it/s]


Train: 11/30: lr: 0.000050000000 loss:0.0011939658886856502


Epoch 11 testing: 100%|██████████| 4/4 [00:00<00:00,  8.63it/s]


Test: 11/30: loss:0.0010625259310472757, accuracy:0.9873888452455, precision:0.28789198606271776, recall:0.4087749363327674, f1:0.3354385850333674, mcc:0.33568502852221155, ece:0.012611154466867447, epoch:11


Epoch 12 training: 100%|██████████| 36/36 [00:03<00:00, 10.27it/s]


Train: 12/30: lr: 0.000050000000 loss:0.0011434647943436478


Epoch 12 testing: 100%|██████████| 4/4 [00:00<00:00, 10.35it/s]


Test: 12/30: loss:0.0010609611199470237, accuracy:0.9830035434634234, precision:0.23450823068870336, recall:0.5000627285110043, f1:0.316621866431649, mcc:0.33346247561074704, ece:0.01699645654298365, epoch:12


Epoch 13 training: 100%|██████████| 36/36 [00:03<00:00, 10.06it/s]


Train: 13/30: lr: 0.000050000000 loss:0.0011309799916085063


Epoch 13 testing: 100%|██████████| 4/4 [00:00<00:00, 10.20it/s]


Test: 13/30: loss:0.0010334258258808404, accuracy:0.9858859910896173, precision:0.26301509252705024, recall:0.45308924485125857, f1:0.32977903650500917, mcc:0.33690199150938915, ece:0.01411400898359716, epoch:13


Epoch 14 training: 100%|██████████| 36/36 [00:03<00:00, 10.19it/s]


Train: 14/30: lr: 0.000050000000 loss:0.0011282011417077025


Epoch 14 testing: 100%|██████████| 4/4 [00:00<00:00,  8.86it/s]


Test: 14/30: loss:0.0010208929888904095, accuracy:0.9832269422453163, precision:0.23893606738470036, recall:0.5112835686600221, f1:0.3219612018797952, mcc:0.3399012988555477, ece:0.016773057403042912, epoch:14


Epoch 15 training: 100%|██████████| 36/36 [00:03<00:00, 10.22it/s]


Train: 15/30: lr: 0.000050000000 loss:0.0010952682350762188


Epoch 15 testing: 100%|██████████| 4/4 [00:00<00:00, 10.09it/s]


Test: 15/30: loss:0.001007825689157471, accuracy:0.9874648660572652, precision:0.3021075581395349, recall:0.4328670032229543, f1:0.352530644259573, mcc:0.35384834313959357, ece:0.012535133864730597, epoch:15


Epoch 16 training: 100%|██████████| 36/36 [00:03<00:00,  9.49it/s]


Train: 16/30: lr: 0.000050000000 loss:0.0010842404283014024


Epoch 16 testing: 100%|██████████| 4/4 [00:00<00:00,  9.84it/s]


Test: 16/30: loss:0.001009016137686558, accuracy:0.9851627986825408, precision:0.26221749265227523, recall:0.48783554952735986, f1:0.3405015166514205, mcc:0.3505550378233494, ece:0.014837201219052076, epoch:16


Epoch 17 training: 100%|██████████| 36/36 [00:03<00:00,  9.60it/s]


Train: 17/30: lr: 0.000050000000 loss:0.001049328841165536


Epoch 17 testing: 100%|██████████| 4/4 [00:00<00:00, 10.03it/s]


Test: 17/30: loss:0.0010079772328026593, accuracy:0.9818968715560713, precision:0.23490426351872135, recall:0.5703018084066471, f1:0.33180079783340655, mcc:0.3579903769012942, ece:0.018103128764778376, epoch:17


Epoch 18 training: 100%|██████████| 36/36 [00:03<00:00,  9.47it/s]


Train: 18/30: lr: 0.000050000000 loss:0.001033821882123852


Epoch 18 testing: 100%|██████████| 4/4 [00:00<00:00, 10.12it/s]


Test: 18/30: loss:0.0009819643892114982, accuracy:0.9856856170580918, precision:0.27677044404542084, recall:0.5105956404552067, f1:0.35721658934748995, mcc:0.36839556242173577, ece:0.014314383268356323, epoch:18


Epoch 19 training: 100%|██████████| 36/36 [00:03<00:00, 11.14it/s]


Train: 19/30: lr: 0.000050000000 loss:0.0010153607630248491


Epoch 19 testing: 100%|██████████| 4/4 [00:00<00:00,  7.99it/s]


Test: 19/30: loss:0.0009800207044463605, accuracy:0.983692198744949, precision:0.24893297380585516, recall:0.5270575197045785, f1:0.33653385492127885, mcc:0.35411773777235034, ece:0.01630780124105513, epoch:19


Epoch 20 training: 100%|██████████| 36/36 [00:03<00:00, 10.02it/s]


Train: 20/30: lr: 0.000025000000 loss:0.000985942323394637


Epoch 20 testing: 100%|██████████| 4/4 [00:00<00:00,  7.83it/s]


Test: 20/30: loss:0.0009635615715524182, accuracy:0.9851950529008315, precision:0.2713446510820727, recall:0.5244317942230655, f1:0.3557849516990476, mcc:0.36950189106165054, ece:0.014804947189986706, epoch:20


Epoch 21 training: 100%|██████████| 36/36 [00:03<00:00,  9.81it/s]


Train: 21/30: lr: 0.000025000000 loss:0.0009647443673909745


Epoch 21 testing: 100%|██████████| 4/4 [00:00<00:00, 10.66it/s]


Test: 21/30: loss:0.0009677717462182045, accuracy:0.9854772039317142, precision:0.27672418977901503, recall:0.5087928921568627, f1:0.35732220379493684, mcc:0.3679577819957771, ece:0.014522796031087637, epoch:21


Epoch 22 training: 100%|██████████| 36/36 [00:04<00:00,  8.53it/s]


Train: 22/30: lr: 0.000025000000 loss:0.0009480527078267187


Epoch 22 testing: 100%|██████████| 4/4 [00:00<00:00,  9.69it/s]


Test: 22/30: loss:0.0009615275339456275, accuracy:0.9846987325559795, precision:0.26554062419674135, recall:0.5306490384615384, f1:0.3526471168711555, mcc:0.36781922772663433, ece:0.0153012671507895, epoch:22


Epoch 23 training: 100%|██████████| 36/36 [00:04<00:00,  8.16it/s]


Train: 23/30: lr: 0.000025000000 loss:0.0009486933510440091


Epoch 23 testing: 100%|██████████| 4/4 [00:00<00:00,  9.17it/s]


Test: 23/30: loss:0.0009558977617416531, accuracy:0.9847413087763826, precision:0.2587746952501051, recall:0.5060384554266645, f1:0.3418365606463152, mcc:0.3546825756159613, ece:0.015258691040799022, epoch:23


Epoch 24 training: 100%|██████████| 36/36 [00:04<00:00,  8.12it/s]


Train: 24/30: lr: 0.000025000000 loss:0.000945483727264218


Epoch 24 testing: 100%|██████████| 4/4 [00:00<00:00,  9.33it/s]


Test: 24/30: loss:0.0009627771360101178, accuracy:0.983285694494306, precision:0.24299061584847875, recall:0.5367063492063492, f1:0.3344944486820917, mcc:0.3539264933521237, ece:0.016714305616915226, epoch:24


Epoch 25 training: 100%|██████████| 36/36 [00:04<00:00,  7.84it/s]


Train: 25/30: lr: 0.000025000000 loss:0.0009326068102382123


Epoch 25 testing: 100%|██████████| 4/4 [00:00<00:00, 11.33it/s]


Test: 25/30: loss:0.0009513191907899454, accuracy:0.9856245819018418, precision:0.2817512343172249, recall:0.5092936609639196, f1:0.36053577672873505, mcc:0.3709930434934752, ece:0.01437541819177568, epoch:25


Epoch 26 training: 100%|██████████| 36/36 [00:03<00:00,  9.41it/s]


Train: 26/30: lr: 0.000025000000 loss:0.0009225748265938213


Epoch 26 testing: 100%|██████████| 4/4 [00:00<00:00,  8.05it/s]


Test: 26/30: loss:0.0009462883608648553, accuracy:0.9859090158399846, precision:0.27503290330929464, recall:0.49421747967479673, f1:0.3524374234075316, mcc:0.3616914917325139, ece:0.014090984361246228, epoch:26


Epoch 27 training: 100%|██████████| 36/36 [00:03<00:00,  9.23it/s]


Train: 27/30: lr: 0.000025000000 loss:0.0009111630093280433


Epoch 27 testing: 100%|██████████| 4/4 [00:00<00:00,  8.86it/s]


Test: 27/30: loss:0.0009543160413159057, accuracy:0.982919483556806, precision:0.23502298487560094, recall:0.5273114908821249, f1:0.32466690806496634, mcc:0.3444439034533793, ece:0.017080516554415226, epoch:27


Epoch 28 training: 100%|██████████| 36/36 [00:03<00:00,  9.39it/s]


Train: 28/30: lr: 0.000025000000 loss:0.0009319266844411484


Epoch 28 testing: 100%|██████████| 4/4 [00:00<00:00, 10.59it/s]


Test: 28/30: loss:0.0009482354653300717, accuracy:0.9865055721200408, precision:0.29263329263329263, recall:0.5052853192559075, f1:0.3691056910569106, mcc:0.3774356590706469, ece:0.01349442801438272, epoch:28


Epoch 29 training: 100%|██████████| 36/36 [00:03<00:00,  9.24it/s]


Train: 29/30: lr: 0.000025000000 loss:0.0009357534209913057


Epoch 29 testing: 100%|██████████| 4/4 [00:00<00:00,  7.31it/s]


Test: 29/30: loss:0.0009491039963904768, accuracy:0.984262256994306, precision:0.2563203463203463, recall:0.5181776556776556, f1:0.3369746369794462, mcc:0.3539951268089094, ece:0.015737742884084582, epoch:29


Epoch 30 training: 100%|██████████| 36/36 [00:03<00:00,  9.61it/s]


Train: 30/30: lr: 0.000025000000 loss:0.0008825884675995136


Epoch 30 testing: 100%|██████████| 4/4 [00:00<00:00, 10.69it/s]

Test: 30/30: loss:0.0009608395484974608, accuracy:0.9869731115269336, precision:0.29821974492789427, recall:0.48428571428571426, f1:0.3674032928197981, mcc:0.3729301641563404, ece:0.013026888482272625, epoch:30





In [11]:


# Timestamp the end of the run, then report the most recent value of each test metric.
print(f"Training ends {datetime.now().isoformat()}")
for metric in ("accuracy", "precision", "recall", "f1", "mcc"):
    # Nothing to report for a metric that has no recorded values.
    if len(model_metrics[f"test_{metric}"]) == 0:
        continue
    print(f"Test {metric}:", model_metrics[f"test_{metric}"][-1])

# Plot the collected metric histories.
plot_model_metrics(model_metrics)

Training ends 2025-02-06T17:53:26.513654
Test accuracy: 0.9869731115269336
Test precision: 0.29821974492789427
Test recall: 0.48428571428571426
Test f1: 0.3674032928197981
Test mcc: 0.3729301641563404


In [12]:
# Make sure the output directory exists before writing any artifact.
os.makedirs(model_output_dir, exist_ok=True)
# NOTE(review): this saves the whole model object rather than a state_dict, so
# loading it presumably requires the model class to be importable at load time
# — confirm that downstream loaders expect this format.
torch.save(model, f"{model_output_dir}/model.pickle")
# Persist which slide names were used for training and which were held out, so
# the same split can be recovered alongside the saved model.
torch.save({
    "train_slides": set(train_slides["slide_name"]),
    "test_slides": set(test_slides["slide_name"])
},
    f"{model_output_dir}/data-split.pickle"
)