In [1]:
# Run from the parent directory (the `projects/tta` root, per the output below) so the
# local `utils`, `datasets`, `memo_experiment` and `baseline_experiment` imports and the
# `os.getcwd()`-relative `logs/` checkpoint paths used later resolve.
%cd ..

/fs01/home/abbasgln/codes/medAI/projects/tta


  self.shell.db['dhist'] = compress_dhist(dhist)[-100:]


In [2]:
import os
from dotenv import load_dotenv
# Loading environment variables
load_dotenv()

import torch
import torch.nn as nn
import typing as tp
import numpy as np
import torch.optim as optim
from torch.utils.data import DataLoader
from dataclasses import dataclass, field
import logging
import wandb

import medAI
from medAI.utils.setup import BasicExperiment, BasicExperimentConfig

from utils.metrics import MetricCalculator

from timm.optim.optim_factory import create_optimizer

from einops import rearrange, repeat
from tqdm.auto import tqdm
import matplotlib.pyplot as plt
import timm

from copy import copy, deepcopy
import pandas as pd

from datasets.datasets import ExactNCT2013RFImagePatches
from medAI.datasets.nct2013 import (
    KFoldCohortSelectionOptions,
    LeaveOneCenterOutCohortSelectionOptions, 
    PatchOptions
)


In [3]:
# Center excluded by the leave-one-center-out cohort; it also selects the checkpoint
# path and the W&B run name further down.
LEAVE_OUT = "JH"

## Data MEMO

In [4]:
###### No support dataset ######
from memo_experiment import MEMOConfig

# Leave-one-center-out cohort selection keyed on the LEAVE_OUT center.
cohort_selection = LeaveOneCenterOutCohortSelectionOptions(leave_out=str(LEAVE_OUT))
config = MEMOConfig(cohort_selection_config=cohort_selection)

from baseline_experiment import BaselineConfig
from torchvision.transforms import v2 as T
from torchvision.tv_tensors import Image as TVImage

class Transform:
    """Per-item transform for ``ExactNCT2013RFImagePatches``.

    Pops ``"patch"`` out of the incoming ``item`` dict (mutating it), optionally
    min-max normalizes it, resizes it to ``size`` and derives a binary label
    (``item["grade"] != "Benign"`` -> 1).

    Returns ``(patch_augs, patch, label, item)``:
      * ``patch_augs`` -- ``num_augs`` independently augmented views of the resized
        patch stacked on a new leading dim when ``augment`` is true, otherwise the
        sentinel ``-1``;
      * ``patch`` -- the resized float patch;
      * ``label`` -- 0-d long tensor;
      * ``item`` -- the remaining metadata.

    Args:
        augment: produce augmented views (needed by MEMO).
        num_augs: number of augmented views per patch (5, as before).
        instance_norm: min-max normalize each patch. ``None`` keeps the original
            behavior of reading the notebook-global ``config.instance_norm`` at call time.
        size: output spatial size for ``T.Resize``.
    """

    def __init__(self, augment=False, num_augs=5, instance_norm=None, size=(256, 256)):
        self.augment = augment
        self.num_augs = num_augs
        self.instance_norm = instance_norm
        self.size = size
        # Resize is deterministic, so build it once instead of on every call.
        self.resize = T.Resize(self.size, antialias=True)
        # Augmentation pipeline; every call draws fresh random parameters.
        self.transform = T.Compose([
            T.RandomAffine(degrees=0, translate=(0.2, 0.2)),
            T.RandomErasing(p=0.5, scale=(0.02, 0.1), ratio=(0.3, 3.3), value=0.5),
            T.RandomHorizontalFlip(p=0.5),
            T.RandomVerticalFlip(p=0.5),
        ])

    def __call__(self, item):
        patch = copy(item.pop("patch"))

        instance_norm = config.instance_norm if self.instance_norm is None else self.instance_norm
        if instance_norm:
            patch_min = patch.min()
            value_range = patch.max() - patch_min
            # A constant patch would otherwise compute 0 / 0 and turn into NaNs.
            patch = (patch - patch_min) / value_range if value_range > 0 else patch - patch_min

        patch = self.resize(TVImage(patch)).float()

        label = torch.tensor(item["grade"] != "Benign").long()

        if self.augment:
            patch_augs = torch.stack([self.transform(patch) for _ in range(self.num_augs)], dim=0)
            return patch_augs, patch, label, item

        return -1, patch, label, item


# The validation split is only used to fit temperature scaling, which never reads the
# augmented views, so skip generating them (augment=False returns the -1 sentinel in
# their place). The test split keeps augment=True because MEMO adapts on those views.
val_ds_memo = ExactNCT2013RFImagePatches(
    split="val",
    transform=Transform(augment=False),
    cohort_selection_options=config.cohort_selection_config,
    patch_options=config.patch_config,
    debug=config.debug,
)

test_ds_memo = ExactNCT2013RFImagePatches(
    split="test",
    transform=Transform(augment=True),
    cohort_selection_options=config.cohort_selection_config,
    patch_options=config.patch_config,
    debug=config.debug,
)


val_loader_memo = DataLoader(
    val_ds_memo, batch_size=config.batch_size_test, shuffle=True, num_workers=4
)
test_loader_memo = DataLoader(
    test_ds_memo, batch_size=config.batch_size_test, shuffle=config.shffl_test, num_workers=4
)



Computing positions: 100%|██████████| 1215/1215 [00:15<00:00, 80.95it/s] 
Computing positions: 100%|██████████| 616/616 [00:08<00:00, 74.62it/s]


## Model

In [10]:
from baseline_experiment import FeatureExtractorConfig

fe_config = FeatureExtractorConfig()

# Single-channel timm backbone with GroupNorm in place of the default norm layer.
model: nn.Module = timm.create_model(
    fe_config.model_name,
    num_classes=fe_config.num_classes,
    in_chans=1,
    features_only=fe_config.features_only,
    norm_layer=lambda channels: nn.GroupNorm(
        num_groups=fe_config.num_groups,
        num_channels=channels,
    ),
)

# Checkpoint selected by LEAVE_OUT. Alternative runs:
# CHECkPOINT_PATH = os.path.join(os.getcwd(), f'logs/tta/baseline_gn_loco/baseline_gn_loco_{LEAVE_OUT}/', 'best_model.ckpt')
# CHECkPOINT_PATH = os.path.join(os.getcwd(), f'logs/tta/memo_gn_loco/memo_gn_{LEAVE_OUT}_loco/checkpoints/', 'best_model.ckpt')
CHECkPOINT_PATH = os.path.join(os.getcwd(), f'logs/tta/baseline_gn_crtd3ratio_loco/baseline_gn_crtd3ratio_loco_{LEAVE_OUT}/', 'best_model.ckpt')

# Load onto the CPU first: the weights are copied into the (CPU) model anyway, so this
# avoids a transient GPU copy and works whatever device the checkpoint was saved from.
# torch.load unpickles the file -- only load checkpoints you trust.
model.load_state_dict(torch.load(CHECkPOINT_PATH, map_location="cpu")["model"])
model.eval()
# Trailing semicolon suppresses the (very long) module repr in the cell output.
model.cuda();

In [11]:
# # Freeze every layer except the final classifier (fc)
# for name, param in model.named_parameters():
#     if name not in ("fc.weight", "fc.bias"):
#         param.requires_grad_(False)

## Temperature Scaling

In [12]:
# Fit temperature scaling on the validation split with the backbone frozen: logits are
# calibrated as `logits / temp + beta`, and only (temp, beta) are optimized for
# cross-entropy.
loader = val_loader_memo

metric_calculator = MetricCalculator()
desc = "val"

TS_EPOCHS = 1
TS_LR = 1e-3

model.eval()

temp = torch.tensor(1.0).cuda().requires_grad_(True)
# Per-class bias (assumes the model emits (batch, fe_config.num_classes) logits -- TODO
# confirm). A *scalar* bias shifts every logit by the same amount; softmax and
# cross-entropy are invariant to that, so it has an analytically zero gradient and any
# change in it would only be floating-point noise amplified by Adam's normalization.
beta = torch.zeros(fe_config.num_classes).cuda().requires_grad_(True)

params = [temp, beta]
_optimizer = optim.Adam(params, lr=TS_LR)
ts_criterion = nn.CrossEntropyLoss()

for epoch in range(TS_EPOCHS):
    metric_calculator.reset()
    for i, batch in enumerate(tqdm(loader, desc=desc)):
        images_augs, images, labels, meta_data = batch
        images = images.cuda()
        labels = labels.cuda()

        # Backbone is frozen: only temp/beta receive gradients.
        with torch.no_grad():
            logits = model(images)
        scaled_logits = logits / temp + beta
        loss = ts_criterion(scaled_logits, labels)
        _optimizer.zero_grad()
        loss.backward()
        _optimizer.step()

        # Update metrics
        metric_calculator.update(
            batch_meta_data=meta_data,
            probs=nn.functional.softmax(scaled_logits, dim=-1).detach().cpu(),
            labels=labels.detach().cpu(),
        )

# Freeze the fitted values so later cells that reuse them don't build graphs into them.
temp, beta = temp.detach(), beta.detach()
temp, beta

val:   0%|          | 0/1492 [00:00<?, ?it/s]

In [13]:
# Temperature-scaling parameters previously fitted on each held-out center's validation
# split, so the (slow) fitting cell above can be skipped. Only JH has been recorded;
# any other LEAVE_OUT must be fitted first instead of silently reusing JH's values.
# NOTE: a scalar `beta` adds the same constant to every logit and therefore does not
# change softmax probabilities or argmax; only `temp` affects the outputs.
FITTED_TEMP_SCALING = {
    "JH": (0.8034, -0.5266),
}
if LEAVE_OUT not in FITTED_TEMP_SCALING:
    raise KeyError(
        f"No recorded temperature-scaling parameters for {LEAVE_OUT!r}; run the fitting cell above."
    )
temp = torch.tensor(FITTED_TEMP_SCALING[LEAVE_OUT][0]).cuda()
beta = torch.tensor(FITTED_TEMP_SCALING[LEAVE_OUT][1]).cuda()

(tensor(0.8034, device='cuda:0', requires_grad=True),
 tensor(-0.5266, device='cuda:0', requires_grad=True))

## Run MEMO on the test set

In [21]:
# Episodic MEMO: for every test batch, adapt a fresh copy of the trained model by
# minimizing the marginal entropy over the augmented views, then evaluate that copy on
# the un-augmented patches. The original `model` is never modified.
loader = test_loader_memo
enable_memo = True
temp_scale = True
MEMO_STEPS = 1
MEMO_LR = 5e-4

from memo_experiment import batched_marginal_entropy
metric_calculator = MetricCalculator()
desc = "test"

model.eval()
criterion = nn.CrossEntropyLoss()

for i, batch in enumerate(tqdm(loader, desc=desc)):
    images_augs, images, labels, meta_data = batch
    images = images.cuda()
    labels = labels.cuda()

    if enable_memo:
        # Fresh copy per batch so adaptation never carries over between batches.
        adaptation_model = deepcopy(model)
        adaptation_model.eval()

        # images_augs is (batch, n_augs, ...); fold the augmentation dim into the batch.
        images_augs = images_augs.cuda()
        batch_size, aug_size = images_augs.shape[0], images_augs.shape[1]
        _images_augs = images_augs.reshape(-1, *images_augs.shape[2:])
        optimizer = optim.SGD(adaptation_model.parameters(), lr=MEMO_LR)

        for j in range(MEMO_STEPS):
            optimizer.zero_grad()
            outputs = adaptation_model(_images_augs).reshape(batch_size, aug_size, -1)
            if temp_scale:
                outputs = outputs / temp + beta
            loss, logits = batched_marginal_entropy(outputs)
            loss.mean().backward()
            optimizer.step()
    else:
        # Nothing is adapted, so the shared model can be evaluated directly.
        adaptation_model = model

    # Evaluate without building an autograd graph.
    with torch.no_grad():
        logits = adaptation_model(images)
        if temp_scale:
            logits = logits / temp + beta
        loss = criterion(logits, labels)

    # Update metrics
    metric_calculator.update(
        batch_meta_data=meta_data,
        probs=nn.functional.softmax(logits, dim=-1).detach().cpu(),
        labels=labels.detach().cpu(),
    )

test:   0%|          | 0/726 [00:00<?, ?it/s]

In [25]:
# Summarize the accumulated predictions and prefix every metric with the split name.
# `avg_core_probs_first` presumably makes the calculator average patch probabilities per
# core before scoring -- TODO confirm in MetricCalculator.
avg_core_probs_first = True
metric_calculator.avg_core_probs_first = avg_core_probs_first

metrics = metric_calculator.get_metrics()

# Track the best score for this split; the copies decouple these names from whatever
# objects the calculator keeps internally.
best_score_updated, best_score = metric_calculator.update_best_score(metrics, desc)
best_score_updated, best_score = copy(best_score_updated), copy(best_score)

metrics_dict = {}
for key, value in metrics.items():
    metrics_dict[f"{desc}/{key}"] = value

metrics_dict

{'test/patch_auroc': tensor(0.6440),
 'test/patch_accuracy': tensor(0.8820),
 'test/all_inv_patch_auroc': tensor(0.6012),
 'test/all_inv_patch_accuracy': tensor(0.8630),
 'test/core_auroc': tensor(0.7466),
 'test/core_accuracy': tensor(0.9276),
 'test/all_inv_core_auroc': tensor(0.6879),
 'test/all_inv_core_accuracy': tensor(0.9062)}

## Splitting the test set for a pseudo-labeling proof of concept

In [None]:
from torch.utils.data import Subset

# Split the test set in two by position: the first half is used for fine-tuning, the
# second for evaluation.
# NOTE(review): the split is by patch index, not by core/patient. If one core's patches
# are contiguous in the dataset (not verified here), the core at the midpoint can end up
# in both halves.
n_test_patches = len(test_ds_memo)
midpoint = n_test_patches // 2
train_indices = range(midpoint)
test_indices = range(midpoint, n_test_patches)

test_train, test_test = Subset(test_ds_memo, train_indices), Subset(test_ds_memo, test_indices)

test_train_loader = DataLoader(test_train, batch_size=64, shuffle=True, num_workers=4)
test_test_loader = DataLoader(test_test, batch_size=32, shuffle=False, num_workers=4)


In [None]:
# test_core_info = test_ds_memo.dataset.dataset.core_info
# train_test_core_info = test_core_info[test_core_info['id'] <= test_core_info.index[len(test_core_info) // 2]]
# labels = train_test_core_info['grade'] != "Benign"
# benign_ids = list(labels[labels == False].sample(len(labels[labels == True])).index)
# cancer_ids = list(labels[labels == True].index)
# balanced_ids = benign_ids + cancer_ids
# len(balanced_ids)

In [None]:
# Proof of concept: fine-tune `model` IN PLACE on the first half of the test set using
# its ground-truth labels. Each batch is class-balanced by undersampling benign patches
# to the number of cancer patches.
# NOTE(review): this mutates `model`, so later cells that use `model` depend on whether
# this cell was run.
loader = test_train_loader
epochs = 2
FINETUNE_LR = 1e-4

metric_calculator = MetricCalculator()
desc = "train"
model.train()
optimizer = optim.SGD(model.parameters(), lr=FINETUNE_LR)
finetune_criterion = nn.CrossEntropyLoss()

for epoch in range(epochs):
    for i, batch in enumerate(tqdm(loader, desc=desc)):
        images_augs, images, labels, meta_data = batch

        labels_np = labels.numpy()
        benign_indx = np.flatnonzero(labels_np == 0)
        cancer_indx = np.flatnonzero(labels_np == 1)
        if len(benign_indx) == 0 or len(cancer_indx) == 0:
            # A single-class batch cannot be balanced: with no benign patches
            # np.random.choice raises, and with no cancer patches the batch would be empty.
            continue
        # Undersample benign to match the cancer count; only sample with replacement
        # when there are fewer benign than cancer patches.
        benign_indx = np.random.choice(
            benign_indx, len(cancer_indx), replace=len(benign_indx) < len(cancer_indx)
        )
        balanced_indices = np.concatenate([benign_indx, cancer_indx])

        images = images[balanced_indices, ...].cuda()
        labels = labels[balanced_indices].cuda()
        # Only these two metadata entries are subset -- assumes they are all that
        # MetricCalculator.update reads per sample (TODO confirm).
        meta_data['id'] = meta_data['id'][balanced_indices]
        meta_data['pct_cancer'] = meta_data['pct_cancer'][balanced_indices]

        # Train
        logits = model(images)
        loss = finetune_criterion(logits, labels)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        # Update metrics
        metric_calculator.update(
            batch_meta_data=meta_data,
            probs=nn.functional.softmax(logits, dim=-1).detach().cpu(),
            labels=labels.detach().cpu(),
        )

In [26]:
# Episodic pseudo-label adaptation: for every test batch, copy the model, take a few SGD
# steps of cross-entropy against the copy's own argmax predictions (hard pseudo-labels),
# then evaluate the adapted copy on the same batch. The original `model` is not modified.
# loader = test_test_loader
loader = test_loader_memo
enable_pseudo_label = True
temp_scale = True
PSEUDO_LABEL_STEPS = 1
PSEUDO_LABEL_LR = 5e-4

metric_calculator = MetricCalculator()
desc = "test"
model.eval()
criterion = nn.CrossEntropyLoss()

for i, batch in enumerate(tqdm(loader, desc=desc)):
    images_augs, images, labels, meta_data = batch
    images = images.cuda()
    labels = labels.cuda()

    if enable_pseudo_label:
        # Fresh copy per batch so adaptation never carries over between batches.
        adaptation_model = deepcopy(model)
        adaptation_model.eval()
        optimizer = optim.SGD(adaptation_model.parameters(), lr=PSEUDO_LABEL_LR)
        for j in range(PSEUDO_LABEL_STEPS):
            optimizer.zero_grad()
            logits = adaptation_model(images)
            if temp_scale:
                logits = logits / temp + beta
            # argmax yields integer targets, so no gradient flows through the labels.
            loss = criterion(logits, logits.argmax(dim=-1))
            loss.backward()
            optimizer.step()
    else:
        # Nothing is adapted, so the shared model can be evaluated directly.
        adaptation_model = model

    # Evaluate the adapted copy without building an autograd graph.
    with torch.no_grad():
        logits = adaptation_model(images)
        if temp_scale:
            logits = logits / temp + beta
        loss = criterion(logits, labels)

    # Update metrics
    metric_calculator.update(
        batch_meta_data=meta_data,
        probs=nn.functional.softmax(logits, dim=-1).detach().cpu(),
        labels=labels.detach().cpu(),
    )

test:   0%|          | 0/726 [00:00<?, ?it/s]

In [27]:
# Summarize the pseudo-label run's predictions, keyed as "<split>/<metric>".
# `avg_core_probs_first` presumably switches core scoring to averaging patch
# probabilities first -- TODO confirm in MetricCalculator.
avg_core_probs_first = True
metric_calculator.avg_core_probs_first = avg_core_probs_first

metrics = metric_calculator.get_metrics()

# Record the best score for this split, keeping independent copies of the results.
best_score_updated, best_score = metric_calculator.update_best_score(metrics, desc)
best_score_updated = copy(best_score_updated)
best_score = copy(best_score)

metrics_dict = dict((f"{desc}/{name}", score) for name, score in metrics.items())

metrics_dict

{'test/patch_auroc': tensor(0.6474),
 'test/patch_accuracy': tensor(0.8810),
 'test/all_inv_patch_auroc': tensor(0.6046),
 'test/all_inv_patch_accuracy': tensor(0.8620),
 'test/core_auroc': tensor(0.7470),
 'test/core_accuracy': tensor(0.9192),
 'test/all_inv_core_auroc': tensor(0.6925),
 'test/all_inv_core_accuracy': tensor(0.8980)}

In [13]:
# core probs and labels
ids = metric_calculator.remove_low_inv_ids(metric_calculator.core_id_invs)
# ids = list(metric_calculator.core_id_probs.keys())
probs = torch.stack(
    [torch.stack(probs_list).argmax(dim=1).mean(dim=0, dtype=torch.float32)
    for id, probs_list in metric_calculator.core_id_probs.items() if id in ids])
probs = torch.cat([(1 - probs).unsqueeze(1), probs.unsqueeze(1)], dim=1)

probs2 = torch.stack(
    [torch.stack(probs_list).mean(dim=0) for id, probs_list in metric_calculator.core_id_probs.items() if id in ids])  

labels = torch.stack(
    [labels_list[0] for id, labels_list in metric_calculator.core_id_labels.items() if id in ids])


In [15]:
# Core metrics computed from the mean-probability scores (`probs2`) of the previous cell.
# The prefix mirrors the calculator's own naming; `ids` is a concrete collection here
# because the previous cell iterated over it.
core_prefix = "all_inv_core_" if ids is None else "core_"
metric_calculator._get_metrics(probs2, labels, prefix=core_prefix)

{'core_auroc': tensor(0.7463), 'core_accuracy': tensor(0.9175)}

## W&B Logging

In [None]:
import wandb

# One W&B group per experiment, one run per held-out center.
group = "offline_pslabel_gn_3ratio_loco"
name = f"{group}_{LEAVE_OUT}"
# The entity can be overridden via WANDB_ENTITY (load_dotenv() at the top populates
# os.environ); otherwise the original account is used.
wandb.init(
    project="tta",
    entity=os.environ.get("WANDB_ENTITY", "mahdigilany"),
    name=name,
    group=group,
)


Failed to detect the name of this notebook, you can set it manually with the WANDB_NOTEBOOK_NAME environment variable to enable code saving.


[34m[1mwandb[0m: Currently logged in as: [33mmahdigilany[0m. Use [1m`wandb login --relogin`[0m to force relogin


In [None]:
# Tag the metrics with the single evaluation "epoch", push them, and close the run.
# NOTE(review): the run summary saved below this cell does not match the metrics_dict
# displayed earlier (e.g. test/core_auroc 0.79331 vs 0.7470), so it presumably came from
# a different execution -- re-run before trusting the logged numbers.
metrics_dict["epoch"] = 0
wandb.log(metrics_dict)
wandb.finish()



VBox(children=(Label(value='0.017 MB of 0.017 MB uploaded\r'), FloatProgress(value=1.0, max=1.0)))

0,1
epoch,▁
test/all_inv_core_accuracy,▁
test/all_inv_core_auroc,▁
test/all_inv_patch_accuracy,▁
test/all_inv_patch_auroc,▁
test/core_accuracy,▁
test/core_auroc,▁
test/patch_accuracy,▁
test/patch_auroc,▁

0,1
epoch,0.0
test/all_inv_core_accuracy,0.90625
test/all_inv_core_auroc,0.75075
test/all_inv_patch_accuracy,0.86461
test/all_inv_patch_auroc,0.63496
test/core_accuracy,0.92761
test/core_auroc,0.79331
test/patch_accuracy,0.88328
test/patch_auroc,0.67782
