In [1]:
import os
import sys
import cv2
import numpy as np
import pandas as pd
from time import time
from pathlib import Path
from collections import defaultdict

# Torch and Torchvision
import torch
from torch.utils.data import DataLoader
from torchvision import transforms as T
from torch.utils.tensorboard import SummaryWriter

# Own functions and logging of git state
repo_path = "/home/ubuntu/source/dermx-experiments/"
sys.path.append(repo_path)
from dermx.utils import make_dir, save_json, create_and_save_class_map, get_git_info
from dermx.models import DermXModelGuidedAttention
from dermx.model_utils import freeze_layers
from dermx.datasets import DermXAttnDataset
from dermx.train_utils import (
    iterate_through_attn_dataloader,
    print_history_result,
    str_optimizer_map,
    str_scheduler_map,
)
from dermx.losses import str_loss_map

## Hyperparameters

In [2]:
# Diagnosis classes, kept in alphabetical order.
diagnoses = sorted(
    [
        "Acne",
        "Actinic keratosis",
        "Psoriasis",
        "Seborrheic dermatitis",
        "Viral warts",
        "Vitiligo",
    ]
)

# Skin-lesion characteristics, kept in alphabetical order.
# Currently excluded: "Macule", "Thrombosed capillaries", "Telangiectasia",
# "Nodule", "Cyst", "Leukotrichia".
characteristics = sorted(
    [
        "Plaque",
        "Papule",
        "Dermatoglyph disruption",
        "Pustule",
        "Scar",
        "Closed comedo",
        "Open comedo",
        "Patch",
        "Sun damage",
        "Scale",
    ]
)

# Hyperparameters are grouped by concern and merged (in this order) into the
# single `hps` dict that the rest of the notebook reads and that is saved to
# JSON alongside the experiment.

# Model architecture
_model_hps = {
    "backbone_name": "resnet50",
    "use_pretrained": True,
    "freeze": False,
    "unfreeze_layers": [],
    "unfreeze_bn_layers": True,
    "num_units_in_dx_clf": 64,
    "num_units_in_cx_clf": 64,
    "dropout": 0.2,
    "use_skip_connection": True,
    "use_downsampled_masks": False,
    "use_batchnorm": False,
    "return_gradcam_attributes": True,
}

# Training, loss and optimizer. The "dx" / "cx" / "attn" keys name the three
# loss terms; the same keys index the factors, kwargs and class weights.
_training_hps = {
    "loss": {"dx": "ce", "cx": "bce", "attn": "softsesp"},
    "loss_factors": {"dx": 1.0, "cx": 1.0, "attn": 10.0},
    "loss_kwargs": {"dx": {}, "cx": {}, "attn": {"activation": None}},
    "class_weights": {"dx": None, "cx": None, "attn": None},
    "optimizer": "adamw",
    "optimizer_kwargs": {},
    "learning_rate": 0.0005,
    "lr_scheduler": "cosine_annealing_warm_restarts",
    "lr_scheduler_kwargs": {"T_0": 3, "T_mult": 3, "eta_min": 1e-5},
    "batch_size": 32,
    "dataaug_color_jitter_kwargs": {
        "brightness": 0.35,
        "contrast": 0.2,
        "saturation": 0.2,
        "hue": 0.15,
    },
    "dataaug_random_affine_kwargs": {
        "degrees": 10,
        "scale": (0.85, 1.15),
        "translate": (0.15, 0.15),
    },
}

# Data. "target_size" is intentionally not set here; it is added per backbone
# further down in this cell.
_data_hps = {
    "diagnoses": diagnoses,
    "characteristics": characteristics,
    "num_segm_classes": len(characteristics),
    "num_dx_classes": len(diagnoses),
    "norm_mean": [0.485, 0.456, 0.406],
    "norm_std": [0.229, 0.224, 0.225],
}

# IO locations, checkpointing criterion and the git state of the code used
_io_hps = {
    "experiment_root_dir": "/home/ubuntu/store/MedIA/experiments_rebuttal/resnet",
    "label_map_dir": "/home/ubuntu/store/dermx-folds",
    "image_dir": "/home/ubuntu/store/dermx_cleaned_images",
    "mask_dir": "/home/ubuntu/store/dermx_cleaned_masks/fusion_masks",
    "mask_label_fusion_rule": "fuzzy_relative",
    "checkpoint_metric": "loss",
    "git": get_git_info(repo_path),
}

hps = {**_model_hps, **_training_hps, **_data_hps, **_io_hps}

# Set parameters according to backbones.
# Each entry is (Grad-CAM target layer, input size, downsampled mask size).
_backbone_settings = {
    "efficientnet_b0": ("backbone.blocks.6.0.conv_pwl", (224, 224), (7, 7)),
    "efficientnet_b1": ("backbone.blocks.6.1.conv_pwl", (240, 240), (8, 8)),
    "efficientnet_b2": ("backbone.blocks.6.1.conv_pwl", (260, 260), (9, 9)),
    "resnet50": ("backbone.layer4.2.conv3", (256, 256), (9, 9)),
}

# Fail here with a clear message instead of a bare KeyError several cells later
# when "target_size" / "gradcam_target_layer" turn out to be missing.
if hps["backbone_name"] not in _backbone_settings:
    raise ValueError(
        f"Unsupported backbone_name {hps['backbone_name']!r}; "
        f"expected one of {sorted(_backbone_settings)}"
    )

_target_layer, _target_size, _mask_size = _backbone_settings[hps["backbone_name"]]
hps["gradcam_target_layer"] = _target_layer
hps["target_size"] = _target_size
# Masks are only downsampled when explicitly requested; None disables it.
hps["mask_downsample_size"] = _mask_size if hps["use_downsampled_masks"] else None

# Calculate number of epochs.
# With warm restarts, train for a whole number of cosine cycles: cycle i lasts
# T_0 * T_mult**i epochs, so the total is the sum over all cycles.
if hps["lr_scheduler"] == "cosine_annealing_warm_restarts":
    num_lr_scheduler_cycles = 4
    _t_0 = hps["lr_scheduler_kwargs"]["T_0"]
    # T_mult is optional for the PyTorch scheduler and defaults to 1 there.
    _t_mult = hps["lr_scheduler_kwargs"].get("T_mult", 1)
    hps["num_epochs"] = sum(
        _t_0 * _t_mult ** cycle_idx for cycle_idx in range(num_lr_scheduler_cycles)
    )
elif "num_epochs" not in hps:
    # Any other scheduler (or none) has no implied length; without this check the
    # notebook would only fail with a KeyError once the training cell starts.
    raise ValueError(
        f"hps['num_epochs'] must be set explicitly when lr_scheduler is "
        f"{hps['lr_scheduler']!r}"
    )

## Experiment directories

In [3]:
# One directory per run, named after the backbone and the current Unix time.
run_name = f"{hps['backbone_name']}_{int(time())}"
experiment_dir = make_dir(hps["experiment_root_dir"], run_name)

# Results for the selected cross-validation fold go into their own sub-directory.
fold_idx = "9"
experiment_fold_dir = make_dir(experiment_dir, f"fold_{fold_idx}")

# Store the hyperparameters next to the results of this run.
save_json(experiment_dir, "hyperparameters.json", hps)

## Parameters that can be set outside the training loop

In [4]:
# Use the GPU when one is available, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")

# Transforms
# "rgb" transforms are given to the dataset separately from the "shared" ones;
# presumably the former touch only the image and the latter both image and
# masks -- confirm against DermXAttnDataset.
normalization_kwargs = {"mean": hps["norm_mean"], "std": hps["norm_std"]}

rgb_transforms_train = [
    T.ColorJitter(**hps["dataaug_color_jitter_kwargs"]),
    T.Normalize(**normalization_kwargs),
]
shared_transforms_train = [
    T.RandomHorizontalFlip(),
    T.RandomVerticalFlip(),
    T.RandomAffine(**hps["dataaug_random_affine_kwargs"]),
    T.Resize(hps["target_size"]),
]

# Evaluation uses no augmentation: only normalization and resizing.
rgb_transforms_test = [T.Normalize(**normalization_kwargs)]
shared_transforms_test = [T.Resize(hps["target_size"])]

# Metrics
checkpoint_metric = hps["checkpoint_metric"]

# One entry per loss term, followed by the total loss and the score metrics.
metric_names = [*hps["loss"], "loss", "dx_acc", "cx_f1", "soft_iou"]

# Scores improve upwards; every loss improves downwards.
score_metrics = {"dx_acc", "cx_f1", "soft_iou"}
metric_larger_is_better = {name: (name in score_metrics) for name in metric_names}

# Best value seen so far, initialised to the worst possible value per direction.
checkpoint_metric_value = {
    name: (0.0 if larger_is_better else np.inf)
    for name, larger_is_better in metric_larger_is_better.items()
}

## Label and class maps

In [5]:
# Metadata CSVs for the train and test split of the selected fold.
folds_dir = os.path.join(hps["label_map_dir"], f"fold_{fold_idx}")
train_csv_path = os.path.join(folds_dir, f"metadata_fold_{fold_idx}_train.csv")
test_csv_path = os.path.join(folds_dir, f"metadata_fold_{fold_idx}_test.csv")
train_df = pd.read_csv(train_csv_path)
test_df = pd.read_csv(test_csv_path)

# Class maps for diagnoses ("dx") and characteristics ("char"); the helper is
# also given the experiment directory to save them in.
class_map_dx = create_and_save_class_map(hps["diagnoses"], "dx", experiment_dir)
class_map_char = create_and_save_class_map(
    hps["characteristics"], "char", experiment_dir
)

## Training initialization

### Dataset and Dataloaders

In [6]:
# Settings that are identical for the train and the test dataset.
_shared_dataset_kwargs = {
    "image_dir": hps["image_dir"],
    "mask_dir": hps["mask_dir"],
    "mask_label_fusion_rule": hps["mask_label_fusion_rule"],
    "mask_downsample_size": hps["mask_downsample_size"],
}

# Train: augmented and reshuffled every epoch.
data_dermx_train = DermXAttnDataset(
    train_df,
    class_map_dx,
    class_map_char,
    rgb_transforms=rgb_transforms_train,
    shared_transforms=shared_transforms_train,
    **_shared_dataset_kwargs,
)
dataloader_train = DataLoader(
    data_dermx_train, batch_size=hps["batch_size"], shuffle=True, num_workers=8
)

# Test: no augmentation and a fixed sample order, so evaluation batches and the
# exported predictions are in the same order from run to run.
data_dermx_test = DermXAttnDataset(
    test_df,
    class_map_dx,
    class_map_char,
    rgb_transforms=rgb_transforms_test,
    shared_transforms=shared_transforms_test,
    **_shared_dataset_kwargs,
)
dataloader_test = DataLoader(
    data_dermx_test, batch_size=hps["batch_size"], shuffle=False, num_workers=8
)

### Model

In [7]:
# Constructor arguments, collected from the hyperparameters.
model_kwargs = {
    "num_dx_classes": hps["num_dx_classes"],
    "num_cx_classes": hps["num_segm_classes"],
    "target_layer": hps["gradcam_target_layer"],
    "backbone_name": hps["backbone_name"],
    "pretrained": hps["use_pretrained"],
    "return_gradcam_attributes": hps["return_gradcam_attributes"],
    # Interpolation of the Grad-CAM attributes is switched off exactly when
    # downsampled masks are used.
    "interpolate_gradcam_attributes": not hps["use_downsampled_masks"],
    "num_units_in_dx_clf": hps["num_units_in_dx_clf"],
    "num_units_in_cx_clf": hps["num_units_in_cx_clf"],
    "dropout": hps["dropout"],
    "use_batchnorm": hps["use_batchnorm"],
    "use_skip_connection": hps["use_skip_connection"],
    "device": device,
}
model = DermXModelGuidedAttention(**model_kwargs).to(device)

# Optional freezing; the two "unfreeze" settings are passed as the names and
# batch-norm layers that freeze_layers should ignore.
if hps["freeze"]:
    freeze_layers(
        model,
        ignore_names=hps["unfreeze_layers"],
        ignore_bn=hps["unfreeze_bn_layers"],
    )

In [8]:
# Display the model's module tree as the cell output to sanity-check the architecture.
model

DermXModelGuidedAttention(
  (backbone): FeatureListNet(
    (conv1): Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
    (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (act1): ReLU(inplace=True)
    (maxpool): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
    (layer1): Sequential(
      (0): Bottleneck(
        (conv1): Conv2d(64, 64, kernel_size=(1, 1), stride=(1, 1), bias=False)
        (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (act1): ReLU(inplace=True)
        (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
        (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (act2): ReLU(inplace=True)
        (conv3): Conv2d(64, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)
        (bn3): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_ru

### Optimizer

In [9]:
# Look up the optimizer class by name and hand it only the trainable parameters.
Optimizer = str_optimizer_map[hps["optimizer"]]
trainable_parameters = (p for p in model.parameters() if p.requires_grad)
optimizer = Optimizer(
    trainable_parameters, lr=hps["learning_rate"], **hps["optimizer_kwargs"]
)

# The learning-rate scheduler is optional; None means "no scheduler".
scheduler = None
if hps["lr_scheduler"] is not None:
    Scheduler = str_scheduler_map[hps["lr_scheduler"]]
    scheduler = Scheduler(optimizer, **hps["lr_scheduler_kwargs"])

### Loss

In [10]:
# One loss object per loss term: the configured loss class is instantiated with
# that term's class weights (positional) and its extra keyword arguments.
criterion = {
    term: str_loss_map[loss_id](hps["class_weights"][term], **hps["loss_kwargs"][term])
    for term, loss_id in hps["loss"].items()
}

### Metrics

In [11]:
# TensorBoard event files live inside the fold directory.
tb_dir = os.path.join(experiment_fold_dir, "tensorboard")
tb_writer = SummaryWriter(log_dir=tb_dir)
# tb_writer.add_graph(model, torch.zeros((1, 3) + hps["target_size"]).to(device))

# Per-epoch metric history; lists are created on first append.
history = defaultdict(list)

## Training

In [None]:
import warnings

# NOTE: this silences every warning for the rest of the kernel session, not just
# during training.
warnings.filterwarnings("ignore")

# Arguments shared by the train and the evaluation pass of each epoch.
iterate_through_dataloader_common_kwargs = {
    "model": model,
    "optimizer": optimizer,
    "criterion": criterion,
    "criterion_weight": hps["loss_factors"],
    "n_total_epochs": hps["num_epochs"],
    "metric_names": metric_names,
    "scheduler": scheduler,
    "device": device,
}

checkpoint_larger_is_better = metric_larger_is_better[checkpoint_metric]

# try/finally: the TensorBoard writer is flushed and closed even when training is
# interrupted or fails part-way through.
try:
    for epoch in range(1, hps["num_epochs"] + 1):
        start_time = time()
        running_metrics_train = iterate_through_attn_dataloader(
            dataloader_train,
            epoch=epoch,
            train_mode=True,
            **iterate_through_dataloader_common_kwargs,
        )
        # Gradients stay enabled during evaluation when Grad-CAM attributes are
        # requested from the model.
        running_metrics_test = iterate_through_attn_dataloader(
            dataloader_test,
            epoch=epoch,
            train_mode=False,
            force_enable_grad=hps["return_gradcam_attributes"],
            **iterate_through_dataloader_common_kwargs,
        )

        # Write and report history
        history["epoch"].append(epoch)
        for name in metric_names:
            history[f"{name}_train"].append(running_metrics_train[name])
            history[f"{name}_val"].append(running_metrics_test[name])
            tb_writer.add_scalar(f"{name}/train", running_metrics_train[name], epoch)
            tb_writer.add_scalar(f"{name}/test", running_metrics_test[name], epoch)
        if hps["lr_scheduler"] is not None:
            last_lr = scheduler.get_last_lr()[0]
            history["lr"].append(last_lr)
            tb_writer.add_scalar("lr", last_lr, epoch)
        # Rewritten every epoch so a partial run still leaves a usable history.
        save_json(experiment_fold_dir, "history.json", history)
        print_history_result(history, start_time)

        # Model checkpoint: save whenever the monitored validation metric
        # strictly improves on the best value seen so far.
        latest_checkpoint_metric_value = history[f"{checkpoint_metric}_val"][-1]
        best_checkpoint_metric_value = checkpoint_metric_value[checkpoint_metric]
        if checkpoint_larger_is_better:
            improved = latest_checkpoint_metric_value > best_checkpoint_metric_value
        else:
            improved = latest_checkpoint_metric_value < best_checkpoint_metric_value
        if improved:
            torch.save(
                model.state_dict(),
                os.path.join(
                    experiment_fold_dir,
                    f"model_with_best_{checkpoint_metric}_at_epoch_{epoch}.pt",
                ),
            )
            print(
                "Model got better and was saved, with "
                f"{checkpoint_metric}_val = {latest_checkpoint_metric_value:.3f}"
            )
            checkpoint_metric_value[checkpoint_metric] = latest_checkpoint_metric_value

        # Always keep the weights of the final epoch as well.
        if epoch == hps["num_epochs"]:
            torch.save(
                model.state_dict(),
                os.path.join(experiment_fold_dir, "model_at_train_end.pt"),
            )
finally:
    tb_writer.close()

Epoch 1/120: 100%|██████████| 16/16 [01:29<00:00,  5.57s/it]


Epoch 1: dx_train: 1.630 -- dx_val: 1.647 -- cx_train: 0.680 -- cx_val: 0.688 -- attn_train: 0.505 -- attn_val: 0.478 -- loss_train: 1.630 -- loss_val: 1.647 -- dx_acc_train: 0.369 -- dx_acc_val: 0.299 -- cx_f1_train: 0.210 -- cx_f1_val: 0.313 -- soft_iou_train: 0.101 -- soft_iou_val: 0.149 -- lr: 0.000 -- completed in 1.6 mins
Model got better and was saved, with loss_val = 1.647


Epoch 2/120: 100%|██████████| 16/16 [01:25<00:00,  5.32s/it]


Epoch 2: dx_train: 1.346 -- dx_val: 1.239 -- cx_train: 0.685 -- cx_val: 0.936 -- attn_train: 0.495 -- attn_val: 0.467 -- loss_train: 1.346 -- loss_val: 1.239 -- dx_acc_train: 0.489 -- dx_acc_val: 0.533 -- cx_f1_train: 0.262 -- cx_f1_val: 0.278 -- soft_iou_train: 0.126 -- soft_iou_val: 0.192 -- lr: 0.000 -- completed in 1.5 mins
Model got better and was saved, with loss_val = 1.239


Epoch 3/120: 100%|██████████| 16/16 [01:32<00:00,  5.77s/it]


Epoch 3: dx_train: 1.051 -- dx_val: 0.953 -- cx_train: 0.731 -- cx_val: 0.776 -- attn_train: 0.494 -- attn_val: 0.479 -- loss_train: 1.051 -- loss_val: 0.953 -- dx_acc_train: 0.647 -- dx_acc_val: 0.673 -- cx_f1_train: 0.266 -- cx_f1_val: 0.271 -- soft_iou_train: 0.129 -- soft_iou_val: 0.167 -- lr: 0.001 -- completed in 1.6 mins
Model got better and was saved, with loss_val = 0.953


Epoch 4/120: 100%|██████████| 16/16 [01:26<00:00,  5.42s/it]


Epoch 4: dx_train: 1.105 -- dx_val: 4.110 -- cx_train: 0.776 -- cx_val: 1.138 -- attn_train: 0.489 -- attn_val: 0.477 -- loss_train: 1.105 -- loss_val: 4.110 -- dx_acc_train: 0.619 -- dx_acc_val: 0.281 -- cx_f1_train: 0.276 -- cx_f1_val: 0.284 -- soft_iou_train: 0.144 -- soft_iou_val: 0.148 -- lr: 0.000 -- completed in 1.5 mins


Epoch 5/120: 100%|██████████| 16/16 [01:24<00:00,  5.28s/it]


Epoch 5: dx_train: 1.109 -- dx_val: 1.799 -- cx_train: 0.746 -- cx_val: 1.001 -- attn_train: 0.482 -- attn_val: 0.449 -- loss_train: 1.109 -- loss_val: 1.799 -- dx_acc_train: 0.606 -- dx_acc_val: 0.501 -- cx_f1_train: 0.289 -- cx_f1_val: 0.288 -- soft_iou_train: 0.157 -- soft_iou_val: 0.198 -- lr: 0.000 -- completed in 1.5 mins


Epoch 6/120: 100%|██████████| 16/16 [01:21<00:00,  5.07s/it]


Epoch 6: dx_train: 0.966 -- dx_val: 2.812 -- cx_train: 0.783 -- cx_val: 1.259 -- attn_train: 0.484 -- attn_val: 0.444 -- loss_train: 0.966 -- loss_val: 2.812 -- dx_acc_train: 0.648 -- dx_acc_val: 0.511 -- cx_f1_train: 0.320 -- cx_f1_val: 0.355 -- soft_iou_train: 0.139 -- soft_iou_val: 0.194 -- lr: 0.000 -- completed in 1.4 mins


Epoch 7/120: 100%|██████████| 16/16 [01:21<00:00,  5.09s/it]


Epoch 7: dx_train: 0.899 -- dx_val: 0.986 -- cx_train: 0.849 -- cx_val: 0.922 -- attn_train: 0.487 -- attn_val: 0.474 -- loss_train: 0.899 -- loss_val: 0.986 -- dx_acc_train: 0.679 -- dx_acc_val: 0.664 -- cx_f1_train: 0.334 -- cx_f1_val: 0.346 -- soft_iou_train: 0.131 -- soft_iou_val: 0.163 -- lr: 0.000 -- completed in 1.4 mins


Epoch 8/120: 100%|██████████| 16/16 [01:25<00:00,  5.36s/it]


Epoch 8: dx_train: 0.806 -- dx_val: 1.054 -- cx_train: 0.842 -- cx_val: 0.863 -- attn_train: 0.487 -- attn_val: 0.482 -- loss_train: 0.806 -- loss_val: 1.054 -- dx_acc_train: 0.724 -- dx_acc_val: 0.632 -- cx_f1_train: 0.329 -- cx_f1_val: 0.362 -- soft_iou_train: 0.133 -- soft_iou_val: 0.134 -- lr: 0.000 -- completed in 1.5 mins


Epoch 9/120: 100%|██████████| 16/16 [01:25<00:00,  5.37s/it]


Epoch 9: dx_train: 0.562 -- dx_val: 0.932 -- cx_train: 0.889 -- cx_val: 1.064 -- attn_train: 0.489 -- attn_val: 0.476 -- loss_train: 0.562 -- loss_val: 0.932 -- dx_acc_train: 0.802 -- dx_acc_val: 0.732 -- cx_f1_train: 0.319 -- cx_f1_val: 0.346 -- soft_iou_train: 0.136 -- soft_iou_val: 0.150 -- lr: 0.000 -- completed in 1.5 mins
Model got better and was saved, with loss_val = 0.932


Epoch 10/120: 100%|██████████| 16/16 [01:24<00:00,  5.28s/it]


Epoch 10: dx_train: 0.488 -- dx_val: 0.681 -- cx_train: 0.941 -- cx_val: 1.040 -- attn_train: 0.485 -- attn_val: 0.470 -- loss_train: 0.488 -- loss_val: 0.681 -- dx_acc_train: 0.825 -- dx_acc_val: 0.807 -- cx_f1_train: 0.332 -- cx_f1_val: 0.354 -- soft_iou_train: 0.142 -- soft_iou_val: 0.169 -- lr: 0.000 -- completed in 1.5 mins
Model got better and was saved, with loss_val = 0.681


Epoch 11/120: 100%|██████████| 16/16 [01:22<00:00,  5.15s/it]


Epoch 11: dx_train: 0.398 -- dx_val: 0.732 -- cx_train: 0.959 -- cx_val: 1.039 -- attn_train: 0.486 -- attn_val: 0.470 -- loss_train: 0.398 -- loss_val: 0.732 -- dx_acc_train: 0.875 -- dx_acc_val: 0.760 -- cx_f1_train: 0.328 -- cx_f1_val: 0.351 -- soft_iou_train: 0.140 -- soft_iou_val: 0.162 -- lr: 0.000 -- completed in 1.4 mins


Epoch 12/120: 100%|██████████| 16/16 [01:24<00:00,  5.31s/it]


Epoch 12: dx_train: 0.384 -- dx_val: 0.728 -- cx_train: 0.970 -- cx_val: 1.038 -- attn_train: 0.489 -- attn_val: 0.473 -- loss_train: 0.384 -- loss_val: 0.728 -- dx_acc_train: 0.857 -- dx_acc_val: 0.788 -- cx_f1_train: 0.326 -- cx_f1_val: 0.347 -- soft_iou_train: 0.131 -- soft_iou_val: 0.160 -- lr: 0.001 -- completed in 1.5 mins


Epoch 13/120: 100%|██████████| 16/16 [01:24<00:00,  5.26s/it]


Epoch 13: dx_train: 0.503 -- dx_val: 2.369 -- cx_train: 0.971 -- cx_val: 1.360 -- attn_train: 0.487 -- attn_val: 0.448 -- loss_train: 0.503 -- loss_val: 2.369 -- dx_acc_train: 0.827 -- dx_acc_val: 0.514 -- cx_f1_train: 0.323 -- cx_f1_val: 0.388 -- soft_iou_train: 0.136 -- soft_iou_val: 0.190 -- lr: 0.000 -- completed in 1.5 mins


Epoch 14/120: 100%|██████████| 16/16 [01:30<00:00,  5.63s/it]


Epoch 14: dx_train: 0.904 -- dx_val: 1.539 -- cx_train: 0.888 -- cx_val: 0.940 -- attn_train: 0.481 -- attn_val: 0.482 -- loss_train: 0.904 -- loss_val: 1.539 -- dx_acc_train: 0.673 -- dx_acc_val: 0.520 -- cx_f1_train: 0.341 -- cx_f1_val: 0.287 -- soft_iou_train: 0.141 -- soft_iou_val: 0.147 -- lr: 0.000 -- completed in 1.6 mins


Epoch 15/120: 100%|██████████| 16/16 [01:24<00:00,  5.30s/it]


Epoch 15: dx_train: 0.951 -- dx_val: 1.174 -- cx_train: 0.819 -- cx_val: 1.012 -- attn_train: 0.489 -- attn_val: 0.471 -- loss_train: 0.951 -- loss_val: 1.174 -- dx_acc_train: 0.678 -- dx_acc_val: 0.546 -- cx_f1_train: 0.313 -- cx_f1_val: 0.272 -- soft_iou_train: 0.127 -- soft_iou_val: 0.180 -- lr: 0.000 -- completed in 1.5 mins


Epoch 16/120: 100%|██████████| 16/16 [01:25<00:00,  5.34s/it]


Epoch 16: dx_train: 0.764 -- dx_val: 1.298 -- cx_train: 0.861 -- cx_val: 1.003 -- attn_train: 0.487 -- attn_val: 0.479 -- loss_train: 0.764 -- loss_val: 1.298 -- dx_acc_train: 0.740 -- dx_acc_val: 0.648 -- cx_f1_train: 0.327 -- cx_f1_val: 0.335 -- soft_iou_train: 0.126 -- soft_iou_val: 0.156 -- lr: 0.000 -- completed in 1.5 mins


Epoch 17/120: 100%|██████████| 16/16 [01:21<00:00,  5.09s/it]


Epoch 17: dx_train: 0.588 -- dx_val: 1.136 -- cx_train: 0.922 -- cx_val: 1.076 -- attn_train: 0.483 -- attn_val: 0.472 -- loss_train: 0.588 -- loss_val: 1.136 -- dx_acc_train: 0.791 -- dx_acc_val: 0.741 -- cx_f1_train: 0.338 -- cx_f1_val: 0.335 -- soft_iou_train: 0.130 -- soft_iou_val: 0.149 -- lr: 0.000 -- completed in 1.4 mins


Epoch 18/120: 100%|██████████| 16/16 [01:24<00:00,  5.27s/it]


Epoch 18: dx_train: 0.640 -- dx_val: 1.350 -- cx_train: 0.936 -- cx_val: 1.171 -- attn_train: 0.482 -- attn_val: 0.486 -- loss_train: 0.640 -- loss_val: 1.350 -- dx_acc_train: 0.772 -- dx_acc_val: 0.577 -- cx_f1_train: 0.327 -- cx_f1_val: 0.305 -- soft_iou_train: 0.142 -- soft_iou_val: 0.158 -- lr: 0.000 -- completed in 1.5 mins


Epoch 19/120: 100%|██████████| 16/16 [01:22<00:00,  5.15s/it]


Epoch 19: dx_train: 0.640 -- dx_val: 1.129 -- cx_train: 0.918 -- cx_val: 1.062 -- attn_train: 0.481 -- attn_val: 0.466 -- loss_train: 0.640 -- loss_val: 1.129 -- dx_acc_train: 0.786 -- dx_acc_val: 0.698 -- cx_f1_train: 0.322 -- cx_f1_val: 0.339 -- soft_iou_train: 0.142 -- soft_iou_val: 0.164 -- lr: 0.000 -- completed in 1.4 mins


Epoch 20/120: 100%|██████████| 16/16 [01:20<00:00,  5.01s/it]


Epoch 20: dx_train: 0.510 -- dx_val: 0.987 -- cx_train: 0.953 -- cx_val: 1.019 -- attn_train: 0.484 -- attn_val: 0.473 -- loss_train: 0.510 -- loss_val: 0.987 -- dx_acc_train: 0.827 -- dx_acc_val: 0.698 -- cx_f1_train: 0.321 -- cx_f1_val: 0.372 -- soft_iou_train: 0.134 -- soft_iou_val: 0.158 -- lr: 0.000 -- completed in 1.4 mins


Epoch 21/120: 100%|██████████| 16/16 [01:19<00:00,  4.99s/it]


Epoch 21: dx_train: 0.551 -- dx_val: 1.057 -- cx_train: 1.003 -- cx_val: 1.151 -- attn_train: 0.487 -- attn_val: 0.469 -- loss_train: 0.551 -- loss_val: 1.057 -- dx_acc_train: 0.830 -- dx_acc_val: 0.670 -- cx_f1_train: 0.336 -- cx_f1_val: 0.291 -- soft_iou_train: 0.125 -- soft_iou_val: 0.168 -- lr: 0.000 -- completed in 1.4 mins


Epoch 22/120: 100%|██████████| 16/16 [01:25<00:00,  5.37s/it]


Epoch 22: dx_train: 0.402 -- dx_val: 0.785 -- cx_train: 1.038 -- cx_val: 1.092 -- attn_train: 0.489 -- attn_val: 0.458 -- loss_train: 0.402 -- loss_val: 0.785 -- dx_acc_train: 0.870 -- dx_acc_val: 0.692 -- cx_f1_train: 0.318 -- cx_f1_val: 0.345 -- soft_iou_train: 0.133 -- soft_iou_val: 0.178 -- lr: 0.000 -- completed in 1.5 mins


Epoch 23/120: 100%|██████████| 16/16 [01:25<00:00,  5.32s/it]


Epoch 23: dx_train: 0.362 -- dx_val: 1.116 -- cx_train: 1.115 -- cx_val: 1.233 -- attn_train: 0.486 -- attn_val: 0.471 -- loss_train: 0.362 -- loss_val: 1.116 -- dx_acc_train: 0.878 -- dx_acc_val: 0.654 -- cx_f1_train: 0.332 -- cx_f1_val: 0.330 -- soft_iou_train: 0.132 -- soft_iou_val: 0.164 -- lr: 0.000 -- completed in 1.5 mins


Epoch 24/120: 100%|██████████| 16/16 [01:24<00:00,  5.26s/it]


Epoch 24: dx_train: 0.370 -- dx_val: 0.919 -- cx_train: 1.064 -- cx_val: 1.219 -- attn_train: 0.483 -- attn_val: 0.466 -- loss_train: 0.370 -- loss_val: 0.919 -- dx_acc_train: 0.887 -- dx_acc_val: 0.741 -- cx_f1_train: 0.332 -- cx_f1_val: 0.363 -- soft_iou_train: 0.148 -- soft_iou_val: 0.146 -- lr: 0.000 -- completed in 1.5 mins


Epoch 25/120: 100%|██████████| 16/16 [01:26<00:00,  5.44s/it]


Epoch 25: dx_train: 0.302 -- dx_val: 1.102 -- cx_train: 1.127 -- cx_val: 1.172 -- attn_train: 0.488 -- attn_val: 0.478 -- loss_train: 0.302 -- loss_val: 1.102 -- dx_acc_train: 0.898 -- dx_acc_val: 0.711 -- cx_f1_train: 0.310 -- cx_f1_val: 0.337 -- soft_iou_train: 0.138 -- soft_iou_val: 0.152 -- lr: 0.000 -- completed in 1.5 mins


Epoch 26/120: 100%|██████████| 16/16 [01:24<00:00,  5.29s/it]


Epoch 26: dx_train: 0.355 -- dx_val: 0.702 -- cx_train: 1.135 -- cx_val: 1.245 -- attn_train: 0.482 -- attn_val: 0.459 -- loss_train: 0.355 -- loss_val: 0.702 -- dx_acc_train: 0.887 -- dx_acc_val: 0.751 -- cx_f1_train: 0.330 -- cx_f1_val: 0.367 -- soft_iou_train: 0.141 -- soft_iou_val: 0.159 -- lr: 0.000 -- completed in 1.5 mins


Epoch 27/120: 100%|██████████| 16/16 [01:23<00:00,  5.21s/it]


Epoch 27: dx_train: 0.342 -- dx_val: 1.310 -- cx_train: 1.143 -- cx_val: 1.338 -- attn_train: 0.485 -- attn_val: 0.479 -- loss_train: 0.342 -- loss_val: 1.310 -- dx_acc_train: 0.891 -- dx_acc_val: 0.723 -- cx_f1_train: 0.325 -- cx_f1_val: 0.329 -- soft_iou_train: 0.145 -- soft_iou_val: 0.145 -- lr: 0.000 -- completed in 1.5 mins


Epoch 28/120: 100%|██████████| 16/16 [01:28<00:00,  5.55s/it]


Epoch 28: dx_train: 0.260 -- dx_val: 0.799 -- cx_train: 1.126 -- cx_val: 1.286 -- attn_train: 0.487 -- attn_val: 0.478 -- loss_train: 0.260 -- loss_val: 0.799 -- dx_acc_train: 0.915 -- dx_acc_val: 0.717 -- cx_f1_train: 0.339 -- cx_f1_val: 0.347 -- soft_iou_train: 0.130 -- soft_iou_val: 0.143 -- lr: 0.000 -- completed in 1.6 mins


Epoch 29/120: 100%|██████████| 16/16 [01:21<00:00,  5.08s/it]


Epoch 29: dx_train: 0.216 -- dx_val: 0.885 -- cx_train: 1.176 -- cx_val: 1.383 -- attn_train: 0.484 -- attn_val: 0.481 -- loss_train: 0.216 -- loss_val: 0.885 -- dx_acc_train: 0.936 -- dx_acc_val: 0.738 -- cx_f1_train: 0.338 -- cx_f1_val: 0.321 -- soft_iou_train: 0.138 -- soft_iou_val: 0.157 -- lr: 0.000 -- completed in 1.4 mins


Epoch 30/120: 100%|██████████| 16/16 [01:31<00:00,  5.69s/it]


Epoch 30: dx_train: 0.199 -- dx_val: 1.013 -- cx_train: 1.214 -- cx_val: 1.408 -- attn_train: 0.483 -- attn_val: 0.473 -- loss_train: 0.199 -- loss_val: 1.013 -- dx_acc_train: 0.922 -- dx_acc_val: 0.726 -- cx_f1_train: 0.335 -- cx_f1_val: 0.318 -- soft_iou_train: 0.146 -- soft_iou_val: 0.166 -- lr: 0.000 -- completed in 1.6 mins


Epoch 31/120: 100%|██████████| 16/16 [01:21<00:00,  5.09s/it]


Epoch 31: dx_train: 0.202 -- dx_val: 0.879 -- cx_train: 1.230 -- cx_val: 1.327 -- attn_train: 0.486 -- attn_val: 0.470 -- loss_train: 0.202 -- loss_val: 0.879 -- dx_acc_train: 0.942 -- dx_acc_val: 0.772 -- cx_f1_train: 0.329 -- cx_f1_val: 0.348 -- soft_iou_train: 0.144 -- soft_iou_val: 0.156 -- lr: 0.000 -- completed in 1.4 mins


Epoch 32/120: 100%|██████████| 16/16 [01:23<00:00,  5.24s/it]


Epoch 32: dx_train: 0.124 -- dx_val: 0.948 -- cx_train: 1.215 -- cx_val: 1.325 -- attn_train: 0.482 -- attn_val: 0.471 -- loss_train: 0.124 -- loss_val: 0.948 -- dx_acc_train: 0.959 -- dx_acc_val: 0.729 -- cx_f1_train: 0.336 -- cx_f1_val: 0.348 -- soft_iou_train: 0.151 -- soft_iou_val: 0.163 -- lr: 0.000 -- completed in 1.5 mins


Epoch 33/120: 100%|██████████| 16/16 [01:23<00:00,  5.22s/it]


Epoch 33: dx_train: 0.119 -- dx_val: 0.813 -- cx_train: 1.259 -- cx_val: 1.323 -- attn_train: 0.488 -- attn_val: 0.469 -- loss_train: 0.119 -- loss_val: 0.813 -- dx_acc_train: 0.971 -- dx_acc_val: 0.751 -- cx_f1_train: 0.338 -- cx_f1_val: 0.359 -- soft_iou_train: 0.134 -- soft_iou_val: 0.159 -- lr: 0.000 -- completed in 1.5 mins


Epoch 34/120: 100%|██████████| 16/16 [01:22<00:00,  5.19s/it]


Epoch 34: dx_train: 0.121 -- dx_val: 0.911 -- cx_train: 1.292 -- cx_val: 1.371 -- attn_train: 0.486 -- attn_val: 0.471 -- loss_train: 0.121 -- loss_val: 0.911 -- dx_acc_train: 0.973 -- dx_acc_val: 0.760 -- cx_f1_train: 0.337 -- cx_f1_val: 0.336 -- soft_iou_train: 0.141 -- soft_iou_val: 0.156 -- lr: 0.000 -- completed in 1.5 mins


Epoch 35/120: 100%|██████████| 16/16 [01:25<00:00,  5.32s/it]
Epoch 36/120: 100%|██████████| 16/16 [01:22<00:00,  5.14s/it]


Epoch 36: dx_train: 0.105 -- dx_val: 1.002 -- cx_train: 1.296 -- cx_val: 1.388 -- attn_train: 0.487 -- attn_val: 0.463 -- loss_train: 0.105 -- loss_val: 1.002 -- dx_acc_train: 0.971 -- dx_acc_val: 0.723 -- cx_f1_train: 0.332 -- cx_f1_val: 0.354 -- soft_iou_train: 0.140 -- soft_iou_val: 0.165 -- lr: 0.000 -- completed in 1.4 mins


Epoch 37/120: 100%|██████████| 16/16 [01:23<00:00,  5.19s/it]


Epoch 37: dx_train: 0.073 -- dx_val: 0.984 -- cx_train: 1.302 -- cx_val: 1.362 -- attn_train: 0.488 -- attn_val: 0.469 -- loss_train: 0.073 -- loss_val: 0.984 -- dx_acc_train: 0.980 -- dx_acc_val: 0.754 -- cx_f1_train: 0.344 -- cx_f1_val: 0.341 -- soft_iou_train: 0.139 -- soft_iou_val: 0.159 -- lr: 0.000 -- completed in 1.5 mins


Epoch 38/120: 100%|██████████| 16/16 [01:24<00:00,  5.29s/it]


Epoch 38: dx_train: 0.089 -- dx_val: 0.959 -- cx_train: 1.333 -- cx_val: 1.404 -- attn_train: 0.484 -- attn_val: 0.468 -- loss_train: 0.089 -- loss_val: 0.959 -- dx_acc_train: 0.971 -- dx_acc_val: 0.760 -- cx_f1_train: 0.343 -- cx_f1_val: 0.351 -- soft_iou_train: 0.144 -- soft_iou_val: 0.163 -- lr: 0.000 -- completed in 1.5 mins


Epoch 39/120: 100%|██████████| 16/16 [01:23<00:00,  5.19s/it]


Epoch 39: dx_train: 0.071 -- dx_val: 1.074 -- cx_train: 1.330 -- cx_val: 1.459 -- attn_train: 0.486 -- attn_val: 0.468 -- loss_train: 0.071 -- loss_val: 1.074 -- dx_acc_train: 0.984 -- dx_acc_val: 0.738 -- cx_f1_train: 0.339 -- cx_f1_val: 0.334 -- soft_iou_train: 0.140 -- soft_iou_val: 0.165 -- lr: 0.001 -- completed in 1.5 mins


Epoch 40/120: 100%|██████████| 16/16 [01:22<00:00,  5.15s/it]


Epoch 40: dx_train: 0.222 -- dx_val: 1.791 -- cx_train: 1.318 -- cx_val: 1.564 -- attn_train: 0.487 -- attn_val: 0.487 -- loss_train: 0.222 -- loss_val: 1.791 -- dx_acc_train: 0.915 -- dx_acc_val: 0.611 -- cx_f1_train: 0.336 -- cx_f1_val: 0.274 -- soft_iou_train: 0.140 -- soft_iou_val: 0.151 -- lr: 0.000 -- completed in 1.4 mins


Epoch 41/120: 100%|██████████| 16/16 [01:25<00:00,  5.36s/it]


Epoch 41: dx_train: 0.470 -- dx_val: 1.799 -- cx_train: 1.183 -- cx_val: 1.429 -- attn_train: 0.479 -- attn_val: 0.445 -- loss_train: 0.470 -- loss_val: 1.799 -- dx_acc_train: 0.846 -- dx_acc_val: 0.626 -- cx_f1_train: 0.331 -- cx_f1_val: 0.365 -- soft_iou_train: 0.151 -- soft_iou_val: 0.185 -- lr: 0.000 -- completed in 1.5 mins


Epoch 42/120:  50%|█████     | 8/16 [00:51<00:19,  2.39s/it]

## Save validation predictions

In [None]:
# NOTE: predictions come from the weights currently in memory; uncomment the
# next line to evaluate a saved checkpoint instead.
# model.load_state_dict(torch.load(f"{experiment_fold_dir}/model_at_train_end.pt"))
model.eval()

# Output columns: filename, diagnoses, then actual / predicted flag per characteristic.
dict_keys = (
    ['filename', 'predicted_dx', 'actual_dx']
    + [f'actual_{cx}' for cx in characteristics]
    + [f'predicted_{cx}' for cx in characteristics]
)
# Values are collected in plain lists and converted to arrays once at the end.
# This avoids re-copying every column on every batch and avoids concatenating
# strings onto an empty float array.
collected = {key: [] for key in dict_keys}

grad_cam_dir = f'{experiment_fold_dir}/gradcams'
os.makedirs(grad_cam_dir, exist_ok=True)

for batch in dataloader_test:
    images, dx_labels, cx_labels, _mask_labels, filenames = batch
    image_names = [Path(filename).name for filename in filenames]

    # No torch.no_grad() here: the same forward pass is run with gradients
    # enabled during evaluation in the training loop when Grad-CAM attributes
    # are requested.
    predicted_dx, predicted_cx, predicted_attributes = model(images.to(device))
    predicted_dx_idx = np.argmax(predicted_dx.detach().cpu().numpy(), axis=1)
    # Threshold the sigmoid outputs at 0.5 to get 0/1 characteristic flags.
    predicted_cx = np.round(torch.sigmoid(predicted_cx).detach().cpu().numpy())
    predicted_attributes = predicted_attributes.detach().cpu().numpy()
    # assumes cx_labels is indexable as (sample, characteristic) -- TODO confirm
    cx_labels = np.asarray(cx_labels, dtype=np.float64)

    collected['filename'].extend(image_names)
    collected['predicted_dx'].extend(diagnoses[idx] for idx in predicted_dx_idx)
    collected['actual_dx'].extend(diagnoses[dx] for dx in dx_labels)

    for cx_idx, cx_name in enumerate(characteristics):
        collected[f'predicted_{cx_name}'].extend(
            predicted_cx[:, cx_idx].astype(np.float64)
        )
        collected[f'actual_{cx_name}'].extend(cx_labels[:, cx_idx])

        # One attribute map per (image, characteristic), scaled by 255 for saving.
        for image_idx, image_name in enumerate(image_names):
            mask_save_name = f'{grad_cam_dir}/{image_name}_{cx_name}.png'
            cv2.imwrite(mask_save_name, predicted_attributes[image_idx, cx_idx, ::] * 255)

validation_dict = {key: np.asarray(values) for key, values in collected.items()}

In [None]:
# Tabulate the collected predictions (one row per image, indexed by filename),
# save them next to the fold's other artifacts and show the table.
predictions_csv_path = f'{experiment_fold_dir}/final_model_predictions.csv'
df = pd.DataFrame(validation_dict).set_index('filename')
df.to_csv(predictions_csv_path)
df