In [1]:
%cd ..

/fs01/home/abbasgln/codes/medAI/projects/tta


In [2]:
import os
from dotenv import load_dotenv
# Loading environment variables
load_dotenv()

import torch
import torch.nn as nn
import typing as tp
import numpy as np
import torch.optim as optim
from torch.utils.data import DataLoader
from dataclasses import dataclass, field
import logging
import wandb

import medAI
from medAI.utils.setup import BasicExperiment, BasicExperimentConfig

from utils.metrics import MetricCalculator

from timm.optim.optim_factory import create_optimizer

from einops import rearrange, repeat
from tqdm.auto import tqdm
import matplotlib.pyplot as plt
import timm

from copy import copy, deepcopy
import pandas as pd

from datasets.datasets import ExactNCT2013RFImagePatches, ExactNCT2013RFPatchesWithSupportPatches, SupportPatchConfig
from medAI.datasets.nct2013 import (
    KFoldCohortSelectionOptions,
    LeaveOneCenterOutCohortSelectionOptions, 
    PatchOptions,
)


In [3]:
LEAVE_OUT='PCC'

## Data MEMO

In [4]:
###### With support dataset ######
num_support_patches = 2
include_query_patch = False


from memo_experiment import MEMOConfig
config = MEMOConfig(cohort_selection_config=LeaveOneCenterOutCohortSelectionOptions(leave_out=f"{LEAVE_OUT}"))

from baseline_experiment import BaselineConfig
from torchvision.transforms import v2 as T
from torchvision.tv_tensors import Image as TVImage

class Transform:
    def __init__(selfT, augment=False):
        selfT.augment = augment
        selfT.size = (256, 256)
        # Augment support patches
        selfT.transform = T.Compose([
            T.RandomAffine(degrees=0, translate=(0.2, 0.2)),
            T.RandomErasing(p=0.5, scale=(0.02, 0.1), ratio=(0.3, 3.3), value=0.5),
            T.RandomHorizontalFlip(p=0.5),
            T.RandomVerticalFlip(p=0.5),
        ])   
    
    def __call__(selfT, item):
        patch = item.pop("patch")
        patch = copy(patch)
        patch = (patch - patch.min()) / (patch.max() - patch.min()) \
            if config.instance_norm else patch
        patch = TVImage(patch)
        patch = T.Resize(selfT.size, antialias=True)(patch).float()

        # Support patches
        support_patches = item.pop("support_patches")
        support_patches = copy(support_patches)
        # Normalize support patches along last two dimensions
        support_patches = (support_patches - support_patches.min(axis=(1, 2), keepdims=True)) \
        / (support_patches.max(axis=(1,2), keepdims=True) \
            - support_patches.min(axis=(1, 2), keepdims=True)) if config.instance_norm else support_patches
        support_patches = TVImage(support_patches)
        support_patches = T.Resize(selfT.size, antialias=True)(support_patches).float()
        

        support_patches_aug = torch.stack([selfT.transform(support_patches) for _ in range(5)], dim=0)
        
        if selfT.augment:
            patch = selfT.transform(patch)
        
        label = torch.tensor(item["grade"] != "Benign").long()
        return support_patches_aug, patch, label, item


test_ds_memo = ExactNCT2013RFPatchesWithSupportPatches(
            split="test",
            transform=Transform(),
            cohort_selection_options=config.cohort_selection_config,
            patch_options=config.patch_config,
            support_patch_config=SupportPatchConfig(
                num_support_patches=num_support_patches,
                include_query_patch=include_query_patch
            ),
            debug=config.debug,
        )

test_loader_memo = DataLoader(
    test_ds_memo, batch_size=config.batch_size_test, shuffle=config.shffl_test, num_workers=4
)

Computing positions test: 100%|██████████| 1599/1599 [02:04<00:00, 12.81it/s]


## Model

In [5]:
from baseline_experiment import FeatureExtractorConfig

fe_config = FeatureExtractorConfig()

# Create the model
model: nn.Module = timm.create_model(
    fe_config.model_name,
    num_classes=fe_config.num_classes,
    in_chans=1,
    features_only=fe_config.features_only,
    norm_layer=lambda channels: nn.GroupNorm(
                    num_groups=fe_config.num_groups,
                    num_channels=channels
                    ))

CHECkPOINT_PATH = os.path.join(os.getcwd(), f'logs/tta/baseline_gn_loco/baseline_gn_loco_{LEAVE_OUT}/', 'best_model.ckpt')
# CHECkPOINT_PATH = os.path.join(os.getcwd(), f'logs/tta/memo_gn_loco/memo_gn_{LEAVE_OUT}_loco/checkpoints/', 'best_model.ckpt')

model.load_state_dict(torch.load(CHECkPOINT_PATH)['model'])
model.eval()
model.cuda()

a = True

## Run test MEMO

In [6]:
loader = test_loader_memo


from memo_experiment import batched_marginal_entropy
metric_calculator = MetricCalculator()
desc = "test"

model.eval()
criterion = nn.CrossEntropyLoss()

for i, batch in enumerate(tqdm(loader, desc=desc)):
    images_suprt_aug, images, labels, meta_data = batch
    images_suprt_aug = images_suprt_aug.cuda()
    images = images.cuda()
    labels = labels.cuda()
    
    batch_size, aug_size, num_support_patches = images_suprt_aug.shape[0], images_suprt_aug.shape[1], images_suprt_aug.shape[2]

    # Adapt to test
    _images_suprt_aug = images_suprt_aug.reshape(-1, 1, *images_suprt_aug.shape[3:]).cuda()
    adaptation_model = deepcopy(model)
    adaptation_model.eval()
    optimizer = optim.SGD(adaptation_model.parameters(), lr=1e-3)
    
    for j in range(config.adaptation_steps):
        optimizer.zero_grad()
        outputs = adaptation_model(_images_suprt_aug).reshape(num_support_patches, aug_size, -1)  
        loss, logits = batched_marginal_entropy(outputs)
        loss.mean().backward()
        optimizer.step()
    
    # Evaluate
    logits = adaptation_model(images)
    loss = criterion(logits, labels)
                    
    # Update metrics   
    metric_calculator.update(
        batch_meta_data = meta_data,
        probs = nn.functional.softmax(logits, dim=-1).detach().cpu(),
        labels = labels.detach().cpu(),
    )

test:   0%|          | 0/1990 [00:00<?, ?it/s]

In [None]:
# Log metrics every epoch
metrics = metric_calculator.get_metrics()

# Update best score
(
    best_score_updated,
    best_score
    ) = metric_calculator.update_best_score(metrics, desc)

best_score_updated = copy(best_score_updated)
best_score = copy(best_score)
        
# Log metrics
metrics_dict = {
    f"{desc}/{key}": value for key, value in metrics.items()
    }
metrics_dict.update(best_score) if desc == "val" else None 


# wandb.log(
#     metrics_dict,
#     )
metrics_dict

{'test/patch_auroc': tensor(0.5903),
 'test/patch_accuracy': tensor(0.7081),
 'test/all_inv_patch_auroc': tensor(0.5724),
 'test/all_inv_patch_accuracy': tensor(0.6896),
 'test/core_auroc': tensor(0.6653),
 'test/core_accuracy': tensor(0.8504),
 'test/all_inv_core_auroc': tensor(0.6332),
 'test/all_inv_core_accuracy': tensor(0.8169)}

In [None]:
import wandb
group=f"offline_memo_2+0sprt_gn_loco"
name=f"offline_memo_2+0sprt_gn_loco_{LEAVE_OUT}"
wandb.init(project="tta", entity="mahdigilany", name=name, group=group)


Failed to detect the name of this notebook, you can set it manually with the WANDB_NOTEBOOK_NAME environment variable to enable code saving.


[34m[1mwandb[0m: Currently logged in as: [33mmahdigilany[0m. Use [1m`wandb login --relogin`[0m to force relogin


In [None]:
metrics_dict.update({"epoch": 0})
wandb.log(
    metrics_dict,
    )
wandb.finish()



VBox(children=(Label(value='0.017 MB of 0.017 MB uploaded\r'), FloatProgress(value=1.0, max=1.0)))

0,1
epoch,▁
test/all_inv_core_accuracy,▁
test/all_inv_core_auroc,▁
test/all_inv_patch_accuracy,▁
test/all_inv_patch_auroc,▁
test/core_accuracy,▁
test/core_auroc,▁
test/patch_accuracy,▁
test/patch_auroc,▁

0,1
epoch,0.0
test/all_inv_core_accuracy,0.81687
test/all_inv_core_auroc,0.6332
test/all_inv_patch_accuracy,0.68962
test/all_inv_patch_auroc,0.57242
test/core_accuracy,0.85043
test/core_auroc,0.66531
test/patch_accuracy,0.70807
test/patch_auroc,0.59028
