In [1]:

import random
import time
from pathlib import Path

import numpy as np
import torch

from datasets.SlideSeperatedCSVDataset import SlideSeperatedCSVDataset
from models.resnet import Resnet18BinaryClassifier
from stopper.EarlyStopper import EarlyStopper
from train import kfold_grid_search
from utils import reduce_dataset

# --- Configuration ------------------------------------------------------------
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')

# Data / output locations. slides_root_dir, labels_root_dir and model_output_dir
# are not read in this cell; they are kept for other cells that may use them.
slides_root_dir = "data/whole-slides/gut"
labels_root_dir = "data/labels"
candidates_dataset_dir = "output/candidates"
model_output_dir = "output/models"
grid_search_output_dir = "output/grid-search"

# Fix every global RNG up front so model init and data shuffling are repeatable
# under Restart & Run All. NOTE(review): confirm kfold_grid_search draws its
# randomness (fold split, shuffling) from these global generators.
random_seed = 42
random.seed(random_seed)
np.random.seed(random_seed)
torch.manual_seed(random_seed)  # seeds the CPU and all CUDA generators

# Backbone whose pre-computed features are loaded from the candidates CSV.
PretrainedModelClass = Resnet18BinaryClassifier
pretrained_model_name = PretrainedModelClass.get_pretrained_model_name()
pretrained_output_size = PretrainedModelClass.pretrained_output_size
features_csv_file_name = f"{pretrained_model_name}_{pretrained_output_size}_features.csv"

# NOTE(review): this single EarlyStopper instance is handed to kfold_grid_search
# and so is shared by every fold and every hyper-parameter combination. If it
# keeps internal state (best loss / patience counter) and the callee does not
# reset it between folds, earlier folds will affect when later ones stop —
# verify in train.py.
early_stopper = EarlyStopper(patience=3, min_delta=1e-4)

# --- Data ---------------------------------------------------------------------
print(f"{pretrained_model_name}: {pretrained_output_size} features")
full_dataset = SlideSeperatedCSVDataset(f"{candidates_dataset_dir}/{features_csv_file_name}")
# discard_ratio=0.0 presumably keeps every example; raise it for quicker experiments.
dataset = reduce_dataset(full_dataset, discard_ratio=0.0)
print(f"{len(dataset):,} examples")

# --- Grid search --------------------------------------------------------------
# A millisecond timestamp gives each run its own checkpoint file; make sure the
# target directory exists before training starts rather than failing at the end.
checkpoint_file_path = f"{grid_search_output_dir}/grid-search-{int(time.time() * 1000)}.json"
Path(checkpoint_file_path).parent.mkdir(parents=True, exist_ok=True)

# Only focal_alpha is swept here; every other hyper-parameter list has one value.
# In the saved run, focal_alpha=0.1 gave precision/recall/F1 = 0 at ~0.99
# accuracy, i.e. the classifier appears to predict only the negative class.
# NOTE(review): the reported 'test_epoch' (~0.00175) is close to 'test_loss'
# rather than an epoch count — check how kfold_grid_search aggregates it.
kfold_grid_search(dataset, in_features=pretrained_output_size, device=device,
                  checkpoint_file_path=checkpoint_file_path,
                  batch_size=4096,
                  max_epochs=50,
                  undersample=False,
                  hidden_layer_combs=[1],
                  unit_combs=[1024],
                  learning_rate_combs=[0.00030],
                  dropout_combs=[0.25],
                  threshold_combs=[0.5],
                  focal_alpha_combs=[0.1, 0.5, 0.9],
                  early_stopper=early_stopper
                  )


Resnet18: 512 features
196,496 examples
(0/3) (hidden_layers=1, units=1024, dropout=0.25, threshold=0.5, learning_rate=0.0003, weight_decay=0.0, focal_alpha=0.1, focal_gamma=2.0)


Fold 1:  14%|█▍        | 7/50 [00:25<02:36,  3.65s/it]
Testing: 100%|██████████| 10/10 [00:00<00:00, 12.71it/s]
Fold 2:  14%|█▍        | 7/50 [00:27<02:46,  3.86s/it]
Testing: 100%|██████████| 10/10 [00:00<00:00, 13.51it/s]
Fold 3:  14%|█▍        | 7/50 [00:26<02:43,  3.79s/it]
Testing: 100%|██████████| 10/10 [00:01<00:00,  9.26it/s]
Fold 4:  14%|█▍        | 7/50 [00:26<02:45,  3.85s/it]
Testing: 100%|██████████| 10/10 [00:00<00:00, 12.71it/s]
Fold 5:  14%|█▍        | 7/50 [00:26<02:43,  3.80s/it]
Testing: 100%|██████████| 10/10 [00:00<00:00, 11.96it/s]


{'test_loss': 0.001735282384324819, 'test_accuracy': 0.9900007966223567, 'test_precision': 0.0, 'test_recall': 0.0, 'test_f1': 0.0, 'test_mcc': 0.0, 'test_epoch': 0.001748440086375922}
(1/3) (hidden_layers=1, units=1024, dropout=0.25, threshold=0.5, learning_rate=0.0003, weight_decay=0.0, focal_alpha=0.5, focal_gamma=2.0)


Fold 1:  18%|█▊        | 9/50 [00:33<02:34,  3.78s/it]
Testing: 100%|██████████| 10/10 [00:00<00:00, 13.01it/s]
Fold 2:  26%|██▌       | 13/50 [00:47<02:16,  3.68s/it]
Testing: 100%|██████████| 10/10 [00:00<00:00, 13.39it/s]
Fold 3:  12%|█▏        | 6/50 [00:23<02:54,  3.96s/it]
Testing: 100%|██████████| 10/10 [00:00<00:00, 15.02it/s]
Fold 4:   0%|          | 0/50 [00:00<?, ?it/s]


KeyboardInterrupt: 