In [1]:
# Move the working directory one level up. The change is relative, so running
# this cell a second time in the same kernel moves up another level.
# Later cells build the checkpoint path from os.getcwd().
%cd ..

/fs01/home/abbasgln/codes/medAI/projects/tta


In [2]:
import os
from dotenv import load_dotenv
# Loading environment variables
load_dotenv()

import torch
import torch.nn as nn
import torch.nn.functional as F
import typing as tp
import numpy as np
import torch.optim as optim
from torch.utils.data import DataLoader
from dataclasses import dataclass, field
import logging
import wandb

import medAI
from medAI.utils.setup import BasicExperiment, BasicExperimentConfig

from utils.metrics import MetricCalculator

from timm.optim.optim_factory import create_optimizer

from einops import rearrange, repeat
from tqdm.auto import tqdm
import matplotlib.pyplot as plt
import timm

from copy import copy, deepcopy
import pandas as pd

from datasets.datasets import ExactNCT2013RFImagePatches
from medAI.datasets.nct2013 import (
    KFoldCohortSelectionOptions,
    LeaveOneCenterOutCohortSelectionOptions, 
    PatchOptions
)


In [3]:
# Name of the center passed as `leave_out` to the cohort selection; it is also
# embedded in the checkpoint path and in the W&B run name further down.
LEAVE_OUT='JH'

## Data MEMO

In [4]:
###### No support dataset ######

from ensemble_experiment import EnsembleConfig
# Cohort selection for this run: `leave_out` is set from LEAVE_OUT.
cohort_selection = LeaveOneCenterOutCohortSelectionOptions(leave_out=f"{LEAVE_OUT}")

# Experiment configuration; the patch options are left at the EnsembleConfig defaults.
config = EnsembleConfig(
    cohort_selection_config=cohort_selection,
    # patch_config=PatchOptions(needle_mask_threshold=0.6, prostate_mask_threshold=0.9, patch_size_mm = (3,3), strides = (1.2,1.2))
)

from baseline_experiment import BaselineConfig
from torchvision.transforms import v2 as T
from torchvision.tv_tensors import Image as TVImage

class Transform:
    """Convert a raw dataset item into ``(patch_augs, patch, label, item)``.

    * ``patch``: the item's ``"patch"`` entry, min-max normalised when the
      notebook-level ``config.instance_norm`` is truthy, resized to
      ``self.size`` and cast to float.
    * ``patch_augs``: two independently augmented copies of ``patch`` stacked
      along a new leading axis when ``augment`` is true; the placeholder ``-1``
      otherwise.
    * ``label``: long tensor, 0 if ``item["grade"] == "Benign"`` else 1.
    * ``item``: the input dict with its ``"patch"`` key removed (popped in place).

    Args:
        augment: whether to also return the two augmented views.
    """

    def __init__(self, augment=False):
        self.augment = augment
        self.size = (256, 256)
        # Built once here instead of on every call.
        self.resize = T.Resize(self.size, antialias=True)
        # Augmentation
        self.transform = T.Compose([
            T.RandomAffine(degrees=0, translate=(0.2, 0.2)),
            T.RandomErasing(p=0.5, scale=(0.02, 0.1), ratio=(0.3, 3.3), value=0.5),
            T.RandomHorizontalFlip(p=0.5),
            T.RandomVerticalFlip(p=0.5),
        ])

    def __call__(self, item):
        patch = item.pop("patch")
        patch = copy(patch)
        if config.instance_norm:
            lowest = patch.min()
            value_range = patch.max() - lowest
            # A constant patch has zero range; dividing by it would turn the
            # whole patch into NaNs, so fall back to a unit divisor (all zeros).
            if not value_range > 0:
                value_range = 1
            patch = (patch - lowest) / value_range
        patch = TVImage(patch)
        patch = self.resize(patch).float()

        label = torch.tensor(item["grade"] != "Benign").long()

        if self.augment:
            # Two independent random draws of the augmentation pipeline.
            patch_augs = torch.stack([self.transform(patch) for _ in range(2)], dim=0)
            return patch_augs, patch, label, item

        # No augmentation requested: -1 keeps the tuple layout unchanged.
        return -1, patch, label, item


# Validation split. It must exist because the temperature-scaling cell further
# down reads `val_loader_memo`; with it commented out that cell raises a
# NameError on a fresh kernel.
val_ds_memo = ExactNCT2013RFImagePatches(
    split="val",
    transform=Transform(augment=True),
    cohort_selection_options=config.cohort_selection_config,
    patch_options=config.patch_config,
    debug=config.debug,
)

# When UVA is the left-out center the benign-to-cancer ratio is set to 5.0
# before the test dataset is built.
if config.cohort_selection_config.leave_out == "UVA":
    config.cohort_selection_config.benign_to_cancer_ratio = 5.0

test_ds_memo = ExactNCT2013RFImagePatches(
    split="test",
    transform=Transform(augment=True),
    cohort_selection_options=config.cohort_selection_config,
    patch_options=config.patch_config,
    debug=config.debug,
)


val_loader_memo = DataLoader(
    val_ds_memo, batch_size=32, shuffle=True, num_workers=4
)
test_loader_memo = DataLoader(
    test_ds_memo, batch_size=32, shuffle=True, num_workers=4
)



Computing positions: 100%|██████████| 616/616 [00:04<00:00, 139.55it/s]


In [5]:
# Index the first item of the test dataset, then display the dataset length.
batch = test_ds_memo[0]
len(test_ds_memo)

23214

In [8]:
# Pull one batch from the test loader (this overwrites `batch` from the previous cell).
batch = next(iter(test_loader_memo))

In [6]:
# batch[-1].shape, batch[1].shape

In [7]:
# plt.imshow(batch[1][:,32:64,:].permute(1, 2, 0).numpy(), aspect='auto')
# plt.show()
# plt.imshow(batch[-1][:,56:112,:].permute(1, 2, 0).numpy(), aspect='auto')

## Model

In [8]:
from baseline_experiment import FeatureExtractorConfig

fe_config = FeatureExtractorConfig()

# Number of networks in the ensemble.
NUM_ENSEMBLE_MODELS = 5


def _group_norm(channels):
    """Norm-layer factory handed to timm: a GroupNorm over `channels` channels."""
    return nn.GroupNorm(num_groups=fe_config.num_groups, num_channels=channels)


# Create the models: one freshly constructed network per ensemble member.
list_models: tp.List[nn.Module] = [
    timm.create_model(
        fe_config.model_name,
        num_classes=fe_config.num_classes,
        in_chans=1,
        features_only=fe_config.features_only,
        norm_layer=_group_norm,
    )
    for _ in range(NUM_ENSEMBLE_MODELS)
]

CHECKPOINT_PATH = os.path.join(
    os.getcwd(),
    f'logs/tta/ensemble_5mdls_gn_3ratio_loco/ensemble_5mdls_gn_3ratio_loco_{LEAVE_OUT}/',
    'best_model.ckpt',
)

# Deserialize onto the CPU; every model is moved to the GPU explicitly below,
# so the checkpoint does not have to be restored onto the device it was saved from.
state = torch.load(CHECKPOINT_PATH, map_location="cpu")

# Plain loop instead of side-effect list comprehensions: nothing is echoed as
# cell output, so no dummy trailing statement is needed.
for i, model in enumerate(list_models):
    # The checkpoint stores one state dict per ensemble member under "list_models".
    model.load_state_dict(state["list_models"][i])
    model.eval()
    model.cuda()

In [9]:
# # Turn requires_grad off for all layers except the last one
# for model in list_models:
#     for name, params in model.named_parameters():
#         if name != "fc.weight" and name != "fc.bias":
#                 params.requires_grad_(False)
#                 # print(name)
#                 # print(params)
        

## Temperature Scaling

In [None]:
# Fit one temperature `temp` and one bias `beta`, shared by all ensemble members,
# on the validation loader. The networks are evaluated under no_grad, so only
# `temp` and `beta` are optimised.
loader = val_loader_memo

metric_calculator = MetricCalculator()
desc = "val"


temp = torch.tensor(1.0).cuda().requires_grad_(True)
# NOTE(review): `beta` is a scalar added to every class logit. Softmax (and hence
# the cross-entropy below) is invariant to such a shift, so `beta` cannot change
# the predicted probabilities; a per-class bias would be needed for that.
beta = torch.tensor(0.0).cuda().requires_grad_(True)


params = [temp, beta]
_optimizer = optim.Adam(params, lr=1e-3)

for epoch in range(1):
    metric_calculator.reset()
    for batch in tqdm(loader, desc=desc):
        # `images_augs` is not used for calibration.
        images_augs, images, labels, meta_data = batch
        images = images.cuda()
        labels = labels.cuda()

        # Evaluate the frozen ensemble
        with torch.no_grad():
            stacked_logits = torch.stack([model(images) for model in list_models])
        scaled_stacked_logits = stacked_logits / temp + beta

        # One cross-entropy term per ensemble member (iterating the stacked
        # tensor walks its first axis, i.e. the members).
        losses = [
            F.cross_entropy(member_logits, labels)
            for member_logits in scaled_stacked_logits
        ]

        # optimize
        _optimizer.zero_grad()
        sum(losses).backward()
        _optimizer.step()

        # Update metrics (probabilities are those computed before this step)
        metric_calculator.update(
            batch_meta_data = meta_data,
            probs = F.softmax(scaled_stacked_logits, dim=-1).mean(dim=0).detach().cpu(), # Take mean over ensembles
            labels = labels.detach().cpu(),
        )

val:   0%|          | 0/1269 [00:00<?, ?it/s]

In [None]:
# Display the current temperature and bias.
temp, beta

(tensor(1.5950, device='cuda:0', requires_grad=True),
 tensor(-0.8514, device='cuda:0', requires_grad=True))

In [None]:
# Previously fitted (temperature, bias) pairs, keyed by the left-out center.
# Selecting by LEAVE_OUT avoids applying one center's calibration to another
# center's models, which happened when the values were toggled by hand.
FITTED_TEMPERATURE_SCALING = {
    "JH": (1.6793, -1.0168),
    "PCC": (1.5950, -0.8514),
}

if LEAVE_OUT in FITTED_TEMPERATURE_SCALING:
    _fitted_temp, _fitted_beta = FITTED_TEMPERATURE_SCALING[LEAVE_OUT]
    temp = torch.tensor(_fitted_temp).cuda()
    beta = torch.tensor(_fitted_beta).cuda()
else:
    # No stored values for this center: keep whatever the fitting cell above
    # produced, detached so later cells do not track gradients through them.
    temp = temp.detach()
    beta = beta.detach()


## Run test MEMO

### Separate MEMO

In [None]:
# Separate MEMO: each ensemble member gets its own loss from
# `batched_marginal_entropy` on its logits for the augmented views; the
# per-member losses are summed for one SGD step. Adaptation is episodic:
# every batch starts again from the trained weights.
loader = test_loader_memo
enable_memo = True
certain_threshold = 0.8

from memo_experiment import batched_marginal_entropy
metric_calculator = MetricCalculator()
desc = "test"

for batch in tqdm(loader, desc=desc):
    images_augs, images, labels, meta_data = batch
    images_augs = images_augs.cuda()
    images = images.cuda()
    labels = labels.cuda()

    # Fresh copies so that adapting on this batch never touches `list_models`.
    adaptation_model_list = [deepcopy(model) for model in list_models]
    for model in adaptation_model_list:
        model.eval()

    if enable_memo:
        batch_size, aug_size = images_augs.shape[0], images_augs.shape[1]

        params = [{'params': model.parameters()} for model in adaptation_model_list]
        optimizer = optim.SGD(params, lr=5e-4)

        # Fold the augmentation axis into the batch axis for the forward pass.
        _images_augs = images_augs.reshape(-1, *images_augs.shape[2:])
        # Adapt to test
        for j in range(1):
            optimizer.zero_grad()
            # Forward pass; result is indexed (member, sample, augmentation, class)
            stacked_logits = torch.stack([
                model(_images_augs).reshape(batch_size, aug_size, -1)
                for model in adaptation_model_list
            ])

            # Remove uncertain samples from test-time adaptation: confidence is
            # the max class probability after averaging over members and views.
            mean_probs = F.softmax(stacked_logits, dim=-1).mean(dim=0).mean(dim=1)
            certain_idx = mean_probs.max(dim=-1)[0] >= certain_threshold
            if not certain_idx.any():
                # Nothing confident enough in this batch: the mean of an empty
                # loss would be NaN, so skip adaptation altogether.
                break
            stacked_logits = stacked_logits[:, certain_idx, ...]

            # Only the first return value (the loss) is used.
            list_losses = [
                batched_marginal_entropy(member_logits)[0].mean()
                for member_logits in stacked_logits
            ]
            # Backward pass
            sum(list_losses).backward()
            optimizer.step()

    # Evaluate; no gradients are needed here, so no autograd graph is built.
    with torch.no_grad():
        logits = torch.stack([model(images) for model in adaptation_model_list])

    # Update metrics
    metric_calculator.update(
        batch_meta_data = meta_data,
        probs = F.softmax(logits, dim=-1).mean(dim=0).detach().cpu(), # Take mean over ensembles
        labels = labels.detach().cpu(),
    )
    

test:   0%|          | 0/726 [00:00<?, ?it/s]

### Combined MEMO

In [None]:
# Combined MEMO: the members' outputs for the augmented views are averaged
# first, and a single `batched_marginal_entropy` loss on that average drives
# one SGD step over the parameters of all members. Adaptation is episodic:
# every batch starts again from the trained weights.
loader = test_loader_memo
enable_memo = True

from memo_experiment import batched_marginal_entropy
metric_calculator = MetricCalculator()
desc = "test"

for batch in tqdm(loader, desc=desc):
    images_augs, images, labels, meta_data = batch
    images_augs = images_augs.cuda()
    images = images.cuda()
    labels = labels.cuda()

    # Fresh copies so that adapting on this batch never touches `list_models`.
    adaptation_model_list = [deepcopy(model) for model in list_models]
    for model in adaptation_model_list:
        model.eval()

    if enable_memo:
        batch_size, aug_size = images_augs.shape[0], images_augs.shape[1]

        params = [{'params': model.parameters()} for model in adaptation_model_list]
        optimizer = optim.SGD(params, lr=5e-4)

        # Fold the augmentation axis into the batch axis for the forward pass.
        _images_augs = images_augs.reshape(-1, *images_augs.shape[2:])
        # Adapt to test
        for j in range(1):
            optimizer.zero_grad()
            # Forward pass; result is indexed (member, sample, augmentation, class)
            stacked_outputs = torch.stack([
                model(_images_augs).reshape(batch_size, aug_size, -1)
                for model in adaptation_model_list
            ])

            # Average over the member axis, then one joint loss.
            loss, logits = batched_marginal_entropy(stacked_outputs.mean(dim=0))
            # Backward pass
            loss.mean().backward()
            optimizer.step()

    # Evaluate; no gradients are needed here, so no autograd graph is built.
    with torch.no_grad():
        logits = torch.stack([model(images) for model in adaptation_model_list])

    # Update metrics
    metric_calculator.update(
        batch_meta_data = meta_data,
        probs = F.softmax(logits, dim=-1).mean(dim=0).detach().cpu(), # Take mean over ensembles
        labels = labels.detach().cpu(),
    )

test:   0%|          | 0/726 [00:00<?, ?it/s]

### Get metrics

In [None]:
# Flag forwarded to the metric calculator before the metrics are computed.
avg_core_probs_first = False
metric_calculator.avg_core_probs_first = avg_core_probs_first

# Metrics accumulated by the most recent evaluation loop
metrics = metric_calculator.get_metrics()

# Update best score; both returned values are kept as shallow copies
best_score_updated, best_score = map(
    copy, metric_calculator.update_best_score(metrics, desc)
)

# Prefix every metric name with the split name for logging
metrics_dict = {}
for key, value in metrics.items():
    metrics_dict[f"{desc}/{key}"] = value

metrics_dict

{'test/patch_auroc': tensor(0.6533),
 'test/patch_accuracy': tensor(0.7716),
 'test/all_inv_patch_auroc': tensor(0.6178),
 'test/all_inv_patch_accuracy': tensor(0.7589),
 'test/core_auroc': tensor(0.7585),
 'test/core_accuracy': tensor(0.8788),
 'test/all_inv_core_auroc': tensor(0.7104),
 'test/all_inv_core_accuracy': tensor(0.8602)}

## Splitting the test set for a proof of concept of pseudo-labeling

In [None]:
from torch.utils.data import Subset

# Cut the test dataset at its midpoint (by index): the first half goes to
# `test_train`, the remainder to `test_test`.
half = len(test_ds_memo) // 2
train_indices = range(half)
test_indices = range(half, len(test_ds_memo))

test_train = Subset(test_ds_memo, train_indices)
test_test = Subset(test_ds_memo, test_indices)

# The first half is shuffled in larger batches; the second half is read in order.
test_train_loader = DataLoader(
    test_train,
    batch_size=64,
    shuffle=True,
    num_workers=4,
)
test_test_loader = DataLoader(
    test_test,
    batch_size=32,
    shuffle=False,
    num_workers=4,
)


In [10]:
# Ensemble evaluation with optional pseudo-label adaptation: when enabled, each
# batch adapts fresh copies of the models with one SGD step, using the
# ensemble's own confident predictions as targets.
# loader = test_test_loader
loader = test_loader_memo
enable_pseudo_label = False
temp_scale = False
certain_threshold = 0.8

metric_calculator = MetricCalculator()
desc = "test"

for batch in tqdm(loader, desc=desc):
    # `images_augs` is not used in this cell.
    images_augs, images, labels, meta_data = batch
    images = images.cuda()
    labels = labels.cuda()

    if enable_pseudo_label:
        # Fresh copies so that adapting on this batch never touches `list_models`.
        adaptation_model_list = [deepcopy(model) for model in list_models]
    else:
        # Nothing is modified without adaptation, so the per-batch deep copy
        # of every model is skipped.
        adaptation_model_list = list_models
    for model in adaptation_model_list:
        model.eval()

    if enable_pseudo_label:
        params = [{'params': model.parameters()} for model in adaptation_model_list]
        optimizer = optim.SGD(params, lr=5e-4)

        # Adapt to test
        for j in range(1):
            optimizer.zero_grad()
            # Forward pass
            stacked_logits = torch.stack([model(images) for model in adaptation_model_list])
            if temp_scale:
                stacked_logits = stacked_logits / temp + beta

            # Remove uncertain samples from test-time adaptation: confidence is
            # the max class probability of the member-averaged prediction.
            ensemble_probs = F.softmax(stacked_logits, dim=-1).mean(dim=0)
            certain_idx = ensemble_probs.max(dim=-1)[0] >= certain_threshold
            if not certain_idx.any():
                # Cross-entropy over an empty selection would be NaN; skip.
                break
            stacked_logits = stacked_logits[:, certain_idx, ...]

            # Pseudo-labels: the ensemble's predicted class for the kept samples,
            # computed once and shared by all members.
            pseudo_labels = ensemble_probs[certain_idx].argmax(dim=-1)

            list_losses = [
                F.cross_entropy(member_logits, pseudo_labels)
                for member_logits in stacked_logits
            ]
            # Backward pass
            sum(list_losses).backward()
            optimizer.step()

    # Evaluate; no gradients are needed here, so no autograd graph is built.
    with torch.no_grad():
        logits = torch.stack([model(images) for model in adaptation_model_list])
        if temp_scale:
            logits = logits / temp + beta

    # Update metrics
    metric_calculator.update(
        batch_meta_data = meta_data,
        probs = F.softmax(logits, dim=-1).mean(dim=0).detach().cpu(), # Take mean over ensembles
        labels = labels.detach().cpu(),
    )

test:   0%|          | 0/1990 [00:00<?, ?it/s]

In [14]:
# Flag forwarded to the metric calculator before the metrics are computed.
avg_core_probs_first = True
metric_calculator.avg_core_probs_first = avg_core_probs_first

# Metrics accumulated by the most recent evaluation loop
metrics = metric_calculator.get_metrics()

# Update best score; both returned values are kept as shallow copies
best_score_updated, best_score = map(
    copy, metric_calculator.update_best_score(metrics, desc)
)

# Prefix every metric name with the split name for logging
metrics_dict = {}
for key, value in metrics.items():
    metrics_dict[f"{desc}/{key}"] = value

metrics_dict

{'test/patch_auroc': tensor(0.6299),
 'test/patch_accuracy': tensor(0.7978),
 'test/all_inv_patch_auroc': tensor(0.6018),
 'test/all_inv_patch_accuracy': tensor(0.7697),
 'test/core_auroc': tensor(0.7187),
 'test/core_accuracy': tensor(0.8981),
 'test/all_inv_core_auroc': tensor(0.6690),
 'test/all_inv_core_accuracy': tensor(0.8571)}

## Testing 10% max probs

In [24]:
## core probs and labels
# For every kept core, score = (sum of the confidences of its most confident
# patches, each weighted by the predicted class index) / number of patches taken.
# With two classes only patches predicted as class 1 contribute.

ids = metric_calculator.remove_low_inv_ids(metric_calculator.core_id_invs)
# ids = list(metric_calculator.core_id_probs.keys())

probs_avg_max = []
# `core_id` rather than `id`: a module-level loop variable named `id` would
# rebind the builtin for the rest of the notebook.
for core_id, probs_list in metric_calculator.core_id_probs.items():
    if core_id not in ids:
        continue
    # NOTE(review): the section title says 10%, but `// 5 + 1` keeps about the
    # top 20% of the patches. The value is left unchanged; confirm which is meant.
    top_k = len(probs_list) // 5 + 1
    # Per patch: highest class probability and the class that attains it.
    confidences, predicted_classes = torch.stack(probs_list).max(dim=1)
    # Ascending sort, so the last `top_k` entries are the most confident ones.
    sorted_confidences, order = confidences.sort()
    top_confidences = sorted_confidences[-top_k:]
    top_predicted = predicted_classes[order[-top_k:]]
    probs_avg_max.append((top_confidences * top_predicted).sum() / top_k)
probs_avg_max = torch.stack(probs_avg_max)

# One label per kept core, taken from the first entry of that core's label list.
# NOTE(review): scores and labels line up only if `core_id_probs` and
# `core_id_labels` iterate their cores in the same order. Verify.
labels = torch.stack([
    labels_list[0]
    for core_id, labels_list in metric_calculator.core_id_labels.items()
    if core_id in ids
])


In [25]:
import torchmetrics

# Binary AUROC and accuracy of the per-core scores against the per-core labels
core_auroc = torchmetrics.functional.auroc(probs_avg_max, labels, task="binary")
core_accuracy = torchmetrics.functional.accuracy(probs_avg_max, labels, task="binary")
core_auroc, core_accuracy

(tensor(0.6683), tensor(0.8987))

## WandB Log

In [None]:
import wandb

# W&B group for this experiment variant; earlier variants are kept for reference.
# group=f"offline_combEnsmPsdo_.8uncrtnty_gn_3ratio_loco"
# group=f"offline_ensemble_avgprob_5mdls_gn_3ratio_loco"
group = f"offline_combEnsmPsdo_avgprob_.8uncrtnty_gn_3ratio_loco"

# The run name is the group name followed by the left-out center.
run_suffix = f"_{LEAVE_OUT}"
name = group + run_suffix

wandb.init(
    project="tta",
    entity="mahdigilany",
    name=name,
    group=group,
)


In [None]:
# Add an "epoch" entry of 0 to the metrics, log them once, then close the run.
metrics_dict["epoch"] = 0
wandb.log(metrics_dict)
wandb.finish()



VBox(children=(Label(value='0.017 MB of 0.017 MB uploaded\r'), FloatProgress(value=1.0, max=1.0)))

0,1
epoch,▁
test/all_inv_core_accuracy,▁
test/all_inv_core_auroc,▁
test/all_inv_patch_accuracy,▁
test/all_inv_patch_auroc,▁
test/core_accuracy,▁
test/core_auroc,▁
test/patch_accuracy,▁
test/patch_auroc,▁

0,1
epoch,0.0
test/all_inv_core_accuracy,0.85777
test/all_inv_core_auroc,0.67289
test/all_inv_patch_accuracy,0.79639
test/all_inv_patch_auroc,0.61083
test/core_accuracy,0.8994
test/core_auroc,0.7264
test/patch_accuracy,0.8278
test/patch_auroc,0.64418
