In [1]:
# Step up one directory (the captured output shows the resulting working directory).
# Presumably this is so the project-local imports used later (`utils`, `datasets`,
# `baseline_experiment`) resolve — TODO confirm.
# NOTE: this is not idempotent — re-running the cell moves up one more level.
%cd ..

/fs01/home/abbasgln/codes/medAI/projects/tta


In [2]:
import os
from dotenv import load_dotenv
# Loading environment variables
load_dotenv()

import torch
import torch.nn as nn
import typing as tp
import numpy as np
import torch.optim as optim
from torch.utils.data import DataLoader
from dataclasses import dataclass, field
import logging
import wandb

import medAI
from medAI.utils.setup import BasicExperiment, BasicExperimentConfig

from utils.metrics import MetricCalculator, brier_score, expected_calibration_error

from timm.optim.optim_factory import create_optimizer

from einops import rearrange, repeat
from tqdm.auto import tqdm
import matplotlib.pyplot as plt
import timm

from copy import copy, deepcopy
import pandas as pd

from datasets.datasets import ExactNCT2013RFImagePatches
from medAI.datasets.nct2013 import (
    KFoldCohortSelectionOptions,
    LeaveOneCenterOutCohortSelectionOptions, 
    PatchOptions
)


In [3]:
# Center held out by the leave-one-center-out split; the same tag is reused in the
# checkpoint path and in the W&B run name.
LEAVE_OUT = "JH"

## Data Baseline

In [4]:
from baseline_experiment import BaselineConfig

# Experiment configuration: per-patch min-max normalisation on, batches of 64,
# and the center named by LEAVE_OUT held out of the cohort.
config = BaselineConfig(
    batch_size=64,
    instance_norm=True,
    cohort_selection_config=LeaveOneCenterOutCohortSelectionOptions(
        leave_out=f"{LEAVE_OUT}"
    ),
)

from baseline_experiment import BaselineConfig
from torchvision.transforms import v2 as T
from torchvision.tv_tensors import Image as TVImage

class Transform:
    """Convert a raw dataset item into a ``(patch, label, item)`` triple.

    The item is expected to be a mapping with at least a ``"patch"`` entry
    (array-like supporting ``.min()`` / ``.max()``) and a ``"grade"`` entry.

    Args:
        augment: When true, apply a random translation (up to 10% per axis)
            and random horizontal/vertical flips after resizing.
        instance_norm: Whether to min-max normalise each patch to [0, 1].
            ``None`` (default) defers to the notebook-level
            ``config.instance_norm`` at call time, which is the original
            behaviour.
    """

    def __init__(self, augment=False, instance_norm=None):
        self.augment = augment
        self.instance_norm = instance_norm
        self.size = (256, 256)
        # The augmentation pipeline has no per-call parameters, so build it
        # once here instead of on every __call__.
        self._augmentation = T.Compose([
            T.RandomAffine(degrees=0, translate=(0.1, 0.1)),
            T.RandomHorizontalFlip(p=0.5),
            T.RandomVerticalFlip(p=0.5),
        ])

    def __call__(self, item):
        """Return ``(patch, label, item)`` for one dataset item.

        ``item`` is modified in place: its ``"patch"`` entry is removed, and the
        remaining entries are returned as the third element.
        ``label`` is a long tensor: 1 when ``item["grade"] != "Benign"``, else 0.
        """
        patch = copy(item.pop("patch"))

        normalize = (
            config.instance_norm if self.instance_norm is None else self.instance_norm
        )
        if normalize:
            lowest = patch.min()
            value_range = patch.max() - lowest
            # A constant patch has zero range; dividing would give 0/0 = NaN
            # and poison the model input. Map such a patch to all zeros.
            if not value_range > 0:
                value_range = 1
            patch = (patch - lowest) / value_range

        patch = TVImage(patch)
        # `self.size` is read at call time so later changes to it are honoured.
        patch = T.Resize(self.size, antialias=True)(patch).float()

        if self.augment:
            patch = self._augmentation(patch)

        label = torch.tensor(item["grade"] != "Benign").long()
        return patch, label, item



# val_ds = ExactNCT2013RFImagePatches(
#     split="val",
#     transform=Transform(),
#     cohort_selection_options=config.cohort_selection_config,
#     patch_options=config.patch_config,
#     debug=config.debug,
# )
        
# Test split of the selected cohort, using the default (non-augmenting) Transform.
test_ds = ExactNCT2013RFImagePatches(
    debug=config.debug,
    patch_options=config.patch_config,
    cohort_selection_options=config.cohort_selection_config,
    transform=Transform(),
    split="test",
)


# val_loader = DataLoader(
#     val_ds_memo, batch_size=config.batch_size_test, shuffle=config.shffl_test, num_workers=4
# )
# Evaluation loader: fixed order (no shuffling), 4 worker processes.
test_loader = DataLoader(
    test_ds,
    num_workers=4,
    shuffle=False,
    batch_size=config.batch_size,
)



Computing positions: 100%|██████████| 1469/1469 [00:10<00:00, 141.20it/s]


## Model

In [5]:
use_batch_norm = True
num_ensembles = 10

from baseline_experiment import FeatureExtractorConfig

fe_config = FeatureExtractorConfig()

# Normalisation layer factory handed to timm: BatchNorm2d, or a GroupNorm
# whose group count comes from the experiment config.
if use_batch_norm:
    norm_layer = nn.BatchNorm2d
else:
    def norm_layer(channels):
        """Build a GroupNorm over `channels` channels."""
        return nn.GroupNorm(
            num_groups=config.model_config.num_groups,
            num_channels=channels,
        )

# Create the ensemble members (single-channel input, identical architecture).
list_models: tp.List[nn.Module] = [
    timm.create_model(
        fe_config.model_name,
        num_classes=fe_config.num_classes,
        in_chans=1,
        features_only=fe_config.features_only,
        norm_layer=norm_layer,
    )
    for _ in range(num_ensembles)
]

# CHECkPOINT_PATH = os.path.join(os.getcwd(), f'logs/tta/baseline_gn_loco/baseline_gn_{LEAVE_OUT}_loco/checkpoints/', 'best_model.ckpt')
# `os.environ[...]` fails with a clear KeyError('DATA_ROOT') when the variable
# is missing, instead of an opaque TypeError from os.path.join(None, ...).
CHECkPOINT_PATH = os.path.join(
    os.environ["DATA_ROOT"],
    f'checkpoint_store/Mahdi/ensemble_bn_{num_ensembles}mdls_inst-nrm_loco/ensemble_bn_{num_ensembles}mdls_inst-nrm_loco_{LEAVE_OUT}/',
    'best_model.ckpt'
    )

# Load onto the CPU first: the models are still on the CPU when their weights
# are copied in, and this does not depend on the device the checkpoint was
# saved from being available.
state = torch.load(CHECkPOINT_PATH, map_location="cpu")['list_models']

# Restore weights, switch to inference mode, and move each member to the GPU.
# A plain loop (rather than list comprehensions run for their side effects)
# also keeps the cell from echoing the full module reprs.
for i, model in enumerate(list_models):
    model.load_state_dict(state[i])
    model.eval()
    model.cuda()

## Run test Ensemble

In [6]:
loader = test_loader

metric_calculator = MetricCalculator()
desc = "test"

# Evaluation only: disable autograd so the forward passes of all ensemble
# members do not build (and hold on to) computation graphs.
with torch.no_grad():
    for batch in tqdm(loader, desc=desc):
        images, labels, meta_data = batch
        images = images.cuda()

        # Forward pass: one set of logits per ensemble member, stacked on dim 0.
        logits = torch.stack([model(images) for model in list_models])

        # Ensemble prediction = mean of the members' softmax probabilities.
        probs = nn.functional.softmax(logits, dim=-1).mean(dim=0)

        # Update metrics. Labels are only consumed on the CPU, so they are
        # never moved to the GPU.
        metric_calculator.update(
            batch_meta_data=meta_data,
            probs=probs.cpu(),
            labels=labels.cpu(),
        )

test:   0%|          | 0/858 [00:00<?, ?it/s]

In [7]:
# Aggregate everything accumulated by the metric calculator.
metrics = metric_calculator.get_metrics()

# Refresh the running best score and keep shallow copies of both returned values.
best_score_updated, best_score = (
    copy(returned)
    for returned in metric_calculator.update_best_score(metrics, desc)
)

# Prefix every metric with the split name, e.g. "test/core_auroc".
metrics_dict = dict(
    (f"{desc}/{metric_name}", metric_value)
    for metric_name, metric_value in metrics.items()
)

# Best scores are only merged in for the validation split.
if desc == "val":
    metrics_dict.update(best_score)

metrics_dict

{'test/patch_auroc': tensor(0.6473),
 'test/patch_accuracy': tensor(0.5137),
 'test/all_inv_patch_auroc': tensor(0.5943),
 'test/all_inv_patch_accuracy': tensor(0.5168),
 'test/core_auroc': tensor(0.7717),
 'test/core_accuracy': tensor(0.5221),
 'test/all_inv_core_auroc': tensor(0.6823),
 'test/all_inv_core_accuracy': tensor(0.5308)}

## Get core and patch probs

In [8]:
# Core ids kept after MetricCalculator.remove_low_inv_ids
# (presumably "inv" = involvement — confirm against utils.metrics).
high_core_ids = metric_calculator.remove_low_inv_ids(metric_calculator.core_id_invs)
ids = high_core_ids # metric_calculator.core_id_invs


def _kept(per_core):
    """Values of a ``{core_id: list}`` mapping whose core id is in ``ids``, in mapping order."""
    return [values for core_id, values in per_core.items() if core_id in ids]


# Patch level: concatenate the per-core lists of probabilities and labels.
patch_probs = torch.cat(
    [torch.stack(core_patch_probs) for core_patch_probs in _kept(metric_calculator.core_id_probs)]
)
patch_labels = torch.cat(
    [torch.tensor(core_patch_labels) for core_patch_labels in _kept(metric_calculator.core_id_labels)]
)

# Core level: the positive-class score of a core is the fraction of its patches
# whose argmax prediction is class 1; column 0 holds the complement.
positive_fraction = torch.stack([
    torch.stack(core_patch_probs).argmax(dim=1).mean(dim=0, dtype=torch.float32)
    for core_patch_probs in _kept(metric_calculator.core_id_probs)
])
core_probs = torch.stack([1 - positive_fraction, positive_fraction], dim=1)

# The first patch label of each core is used as the core label.
core_labels = torch.stack(
    [core_patch_labels[0] for core_patch_labels in _kept(metric_calculator.core_id_labels)]
)

## Sensitivity and Specificity

In [9]:
import torchmetrics


def _sensitivity_specificity(probs, targets):
    """Return ``(sensitivity, specificity)`` of the argmax predictions of ``probs``.

    Uses the binary ``stat_scores`` counts: sensitivity = TP / (TP + FN),
    specificity = TN / (TN + FP).
    """
    tpos, fpos, tneg, fneg, _support = torchmetrics.functional.stat_scores(
        preds=probs.argmax(dim=-1), target=targets, task="binary"
    )
    return tpos / (tpos + fneg), tneg / (tneg + fpos)


patch_sensitivity, patch_specificity = _sensitivity_specificity(patch_probs, patch_labels)
core_sensitivity, core_specificity = _sensitivity_specificity(core_probs, core_labels)

# Append the four new entries to the metrics being reported.
metrics_dict[f"{desc}/patch_sensitivity"] = patch_sensitivity
metrics_dict[f"{desc}/patch_specificity"] = patch_specificity
metrics_dict[f"{desc}/core_sensitivity"] = core_sensitivity
metrics_dict[f"{desc}/core_specificity"] = core_specificity
metrics_dict

{'test/patch_auroc': tensor(0.6473),
 'test/patch_accuracy': tensor(0.5137),
 'test/all_inv_patch_auroc': tensor(0.5943),
 'test/all_inv_patch_accuracy': tensor(0.5168),
 'test/core_auroc': tensor(0.7717),
 'test/core_accuracy': tensor(0.5221),
 'test/all_inv_core_auroc': tensor(0.6823),
 'test/all_inv_core_accuracy': tensor(0.5308),
 'test/patch_sensitivity': tensor(0.7100),
 'test/patch_specificity': tensor(0.4973),
 'test/core_sensitivity': tensor(0.8519),
 'test/core_specificity': tensor(0.4931)}

## Brier and ECE metrics

In [10]:
# Probability that each patch assigns to its ground-truth class.
patch_rows = torch.arange(len(patch_labels))
patch_probs_1d = patch_probs[patch_rows, patch_labels]

# NOTE(review): the values passed as Brier input and as ECE `confidence` are
# TRUE-class probabilities, not the positive-class / predicted-class (max)
# probabilities these metrics are usually defined on. Confirm that this is what
# utils.metrics.brier_score and expected_calibration_error expect.
patch_brier = brier_score(patch_probs_1d.numpy(), patch_labels.numpy())
patch_ece, _ = expected_calibration_error(
    preds=patch_probs.argmax(dim=-1).numpy(),
    confidence=patch_probs_1d.numpy(),
    targets=patch_labels.numpy(),
)

metrics_dict[f"{desc}/patch_brier"] = patch_brier
metrics_dict[f"{desc}/patch_ece"] = patch_ece

In [11]:
# Score that each core assigns to its ground-truth class.
core_rows = torch.arange(len(core_labels))
core_probs_1d = core_probs[core_rows, core_labels]

# NOTE(review): the values passed as Brier input and as ECE `confidence` are
# TRUE-class scores, not the positive-class / predicted-class (max) scores
# these metrics are usually defined on. Confirm that this is what
# utils.metrics.brier_score and expected_calibration_error expect.
core_brier = brier_score(core_probs_1d.numpy(), core_labels.numpy())
core_ece, _ = expected_calibration_error(
    preds=core_probs.argmax(dim=-1).numpy(),
    confidence=core_probs_1d.numpy(),
    targets=core_labels.numpy(),
)

metrics_dict[f"{desc}/core_brier"] = core_brier
metrics_dict[f"{desc}/core_ece"] = core_ece
metrics_dict

{'test/patch_auroc': tensor(0.6473),
 'test/patch_accuracy': tensor(0.5137),
 'test/all_inv_patch_auroc': tensor(0.5943),
 'test/all_inv_patch_accuracy': tensor(0.5168),
 'test/core_auroc': tensor(0.7717),
 'test/core_accuracy': tensor(0.5221),
 'test/all_inv_core_auroc': tensor(0.6823),
 'test/all_inv_core_accuracy': tensor(0.5308),
 'test/patch_sensitivity': tensor(0.7100),
 'test/patch_specificity': tensor(0.4973),
 'test/core_sensitivity': tensor(0.8519),
 'test/core_specificity': tensor(0.4931),
 'test/patch_brier': 0.4784570948855554,
 'test/patch_ece': 0.2884738316691325,
 'test/core_brier': 0.3671951286505839,
 'test/core_ece': 0.3344730977329144}

In [15]:
import wandb

# One W&B group per ensemble configuration; the run name adds the held-out center.
group = f"offline_ensemble_bn_{num_ensembles}mdls_inst-nrm_loco"
name = f"{group}_{LEAVE_OUT}"
wandb.init(
    entity="mahdigilany",
    project="tta",
    group=group,
    name=name,
)


In [16]:
# Log the whole metrics dict once, tagged as epoch 0, then close the run.
metrics_dict["epoch"] = 0
wandb.log(metrics_dict)
wandb.finish()



VBox(children=(Label(value='0.017 MB of 0.017 MB uploaded\r'), FloatProgress(value=1.0, max=1.0)))

0,1
epoch,▁
test/all_inv_core_accuracy,▁
test/all_inv_core_auroc,▁
test/all_inv_patch_accuracy,▁
test/all_inv_patch_auroc,▁
test/core_accuracy,▁
test/core_auroc,▁
test/core_brier,▁
test/core_ece,▁
test/core_sensitivity,▁

0,1
epoch,0.0
test/all_inv_core_accuracy,0.53084
test/all_inv_core_auroc,0.68232
test/all_inv_patch_accuracy,0.51676
test/all_inv_patch_auroc,0.59434
test/core_accuracy,0.52206
test/core_auroc,0.77173
test/core_brier,0.3672
test/core_ece,0.33447
test/core_sensitivity,0.85185
