In [1]:
# Move the kernel's working directory up one level; the cell output shows the
# directory it lands in. A later cell builds the checkpoint path from os.getcwd(),
# so it depends on this having run exactly once.
# NOTE: the target is relative, so re-running this cell without restarting the
# kernel climbs one more level each time.
%cd ..

/fs01/home/abbasgln/codes/medAI/projects/tta


In [2]:
import os
from dotenv import load_dotenv
# Loading environment variables
load_dotenv()

import torch
import torch.nn as nn
import typing as tp
import numpy as np
import torch.optim as optim
from torch.utils.data import DataLoader
from dataclasses import dataclass, field
import logging
import wandb

import medAI
from medAI.utils.setup import BasicExperiment, BasicExperimentConfig

from utils.metrics import MetricCalculator

from timm.optim.optim_factory import create_optimizer

from einops import rearrange, repeat
from tqdm.auto import tqdm
import matplotlib.pyplot as plt
import timm

from copy import copy, deepcopy
import pandas as pd

from datasets.datasets import ExactNCT2013RFImagePatches
from medAI.datasets.nct2013 import (
    KFoldCohortSelectionOptions,
    LeaveOneCenterOutCohortSelectionOptions, 
    PatchOptions
)


In [3]:
# Code of the held-out center: it is passed as `leave_out` to the
# leave-one-center-out cohort selection and embedded in the checkpoint path.
LEAVE_OUT = "JH"

## Model

In [4]:
from baseline_experiment import FeatureExtractorConfig

# Default feature-extractor settings; the fields read below are model_name,
# num_classes, features_only and num_groups.
fe_config = FeatureExtractorConfig()

# Create the model: single-channel input, with every normalisation layer
# produced by the factory below, i.e. a GroupNorm over `channels` channels
# split into fe_config.num_groups groups.
model: nn.Module = timm.create_model(
    fe_config.model_name,
    num_classes=fe_config.num_classes,
    in_chans=1,
    features_only=fe_config.features_only,
    norm_layer=lambda channels: nn.GroupNorm(
        num_groups=fe_config.num_groups,
        num_channels=channels,
    ),
)

# Checkpoint of the baseline run for the held-out center, resolved against the
# current working directory.
CHECKPOINT_PATH = os.path.join(
    os.getcwd(),
    f'logs/tta/baseline_gn_loco/baseline_gn_{LEAVE_OUT}_loco/checkpoints/',
    'best_model.ckpt',
)
# Original (mis-cased) name kept as an alias in case other cells refer to it.
CHECkPOINT_PATH = CHECKPOINT_PATH

# Deserialize onto the CPU first: without map_location the tensors are restored
# onto the device they were saved from, which fails when that device does not
# exist here and otherwise keeps a second copy of the weights on the GPU.
# The weights are stored under the 'model' key of the checkpoint.
model.load_state_dict(torch.load(CHECKPOINT_PATH, map_location="cpu")['model'])
model.eval()
model.cuda()

# NOTE(review): presumably a dummy assignment so the cell does not echo the
# module repr returned by model.cuda() -- confirm nothing else reads `a`.
a = True

## Data SAR

In [8]:
from baseline_experiment import BaselineConfig
from torchvision.transforms import v2 as T
from torchvision.tv_tensors import Image as TVImage

# Baseline configuration whose cohort selection holds out the LEAVE_OUT center;
# every other option keeps the BaselineConfig default.
config = BaselineConfig(
    cohort_selection_config=LeaveOneCenterOutCohortSelectionOptions(
        leave_out=f"{LEAVE_OUT}",
    ),
)

class Transform:
    """Per-item preprocessing applied to dataset items.

    Calling an instance with an item dict returns ``(patch, label, item)``:

    * ``patch`` -- the item's ``"patch"`` entry, optionally min-max normalised
      (controlled by the module-level ``config.instance_norm``), wrapped as a
      torchvision image, resized to ``self.size``, cast to float and, when
      ``augment`` is set, randomly translated / flipped.
    * ``label`` -- a long tensor that is 1 when ``item["grade"]`` differs from
      ``"Benign"`` and 0 otherwise.
    * ``item`` -- the same dict that was passed in, with ``"patch"`` removed.

    Args:
        augment: whether to apply the random augmentations.
    """

    def __init__(self, augment=False):
        self.augment = augment
        self.size = (256, 256)
        # The augmentation pipeline only uses constants, so build it once
        # instead of on every call.
        self._augmentation = T.Compose([
            T.RandomAffine(degrees=0, translate=(0.1, 0.1)),
            T.RandomHorizontalFlip(p=0.5),
            T.RandomVerticalFlip(p=0.5),
        ])

    def __call__(self, item):
        """Transform one item; note that ``"patch"`` is popped from ``item``."""
        # Work on a copy so the dataset's stored patch is never modified.
        # NOTE(review): presumably a 2-D numpy array -- confirm against the dataset.
        patch = copy(item.pop("patch"))

        if config.instance_norm:
            lowest = patch.min()
            value_range = patch.max() - lowest
            # A constant patch has zero range; dividing by it would fill the
            # patch with NaN/inf. Divide by 1 instead so it becomes all zeros.
            if value_range == 0:
                value_range = 1
            patch = (patch - lowest) / value_range

        patch = TVImage(patch)
        # Resize is created per call so a later change to `self.size` is honoured.
        patch = T.Resize(self.size, antialias=True)(patch).float()

        if self.augment:
            # Augment support patches
            patch = self._augmentation(patch)

        # Binary target: anything that is not "Benign" counts as positive.
        label = torch.tensor(item["grade"] != "Benign").long()
        return patch, label, item



def _make_sar_dataset(split):
    """Build the RF patch dataset for `split` from the notebook-level `config`.

    Each dataset gets its own non-augmenting `Transform` instance.
    """
    return ExactNCT2013RFImagePatches(
        split=split,
        transform=Transform(),
        cohort_selection_options=config.cohort_selection_config,
        patch_options=config.patch_config,
        debug=config.debug,
    )


def _make_sar_loader(dataset):
    """Wrap `dataset` in a non-shuffling DataLoader with 4 worker processes."""
    return DataLoader(
        dataset,
        batch_size=config.batch_size,
        shuffle=False,
        num_workers=4,
    )


# Datasets first, loaders second -- same construction order as before.
val_ds_sar = _make_sar_dataset("val")
test_ds_sar = _make_sar_dataset("test")

val_loader_sar = _make_sar_loader(val_ds_sar)
test_loader_sar = _make_sar_loader(test_ds_sar)

Computing positions: 100%|██████████| 1215/1215 [00:15<00:00, 76.97it/s]
Computing positions: 100%|██████████| 616/616 [00:07<00:00, 80.14it/s]
