In [1]:
# Change the working directory to the parent folder (the echoed path below shows
# where that lands). The checkpoint path later in this notebook is built from
# os.getcwd(), so it depends on this cell. Not idempotent: re-running the cell
# moves up one more level.
%cd ..

/fs01/home/abbasgln/codes/medAI/projects/tta


In [2]:
import os
from dotenv import load_dotenv
# Loading environment variables
load_dotenv()

import torch
import torch.nn as nn
import typing as tp
import numpy as np
import torch.optim as optim
from torch.utils.data import DataLoader
from dataclasses import dataclass, field
import logging
import wandb

import medAI
from medAI.utils.setup import BasicExperiment, BasicExperimentConfig

from utils.metrics import MetricCalculator

from timm.optim.optim_factory import create_optimizer

from einops import rearrange, repeat
from tqdm.auto import tqdm
import matplotlib.pyplot as plt
import timm

from copy import copy, deepcopy
import pandas as pd

from datasets.datasets import ExactNCT2013RFImagePatches
from medAI.datasets.nct2013 import (
    KFoldCohortSelectionOptions,
    LeaveOneCenterOutCohortSelectionOptions, 
    PatchOptions
)


In [3]:
# Identifier of the held-out center. It is passed as `leave_out` to
# LeaveOneCenterOutCohortSelectionOptions and is also embedded in the
# checkpoint directory name and in the wandb run name.
LEAVE_OUT='JH'

## Data MEMO

In [4]:
###### No support dataset ######

from vicreg_pretrain_experiment import PretrainConfig

# Leave-one-center-out cohort selection, holding out the center named by LEAVE_OUT.
loco_cohort_selection = LeaveOneCenterOutCohortSelectionOptions(leave_out=f"{LEAVE_OUT}")

# Experiment configuration reused by the dataset, model and checkpoint cells.
config = PretrainConfig(cohort_selection_config=loco_cohort_selection)

from baseline_experiment import BaselineConfig
from torchvision.transforms import v2 as T
from torchvision.tv_tensors import Image as TVImage

class Transform:
    """Per-item preprocessing applied by the patch datasets.

    Calling an instance on a dataset item (a dict holding at least the keys
    ``"patch"`` and ``"grade"``) returns a 4-tuple
    ``(patch_augs, patch, label, item)``:

    * ``patch_augs`` -- two independently augmented views of the resized patch
      stacked along a new leading dimension when ``augment`` is True, otherwise
      the placeholder ``-1``.
    * ``patch`` -- the (optionally min-max normalised) patch resized to
      ``self.size`` and cast to float.
    * ``label`` -- long tensor, 1 when ``item["grade"] != "Benign"`` else 0.
    * ``item`` -- the input dict with ``"patch"`` removed (it is popped in place).

    Min-max normalisation is applied only when the notebook-level
    ``config.instance_norm`` is truthy.
    """

    def __init__(self, augment=False):
        self.augment = augment
        self.size = (256, 256)
        # Augmentation pipeline; only used when `augment` is True.
        self.transform = T.Compose([
            T.RandomAffine(degrees=0, translate=(0.2, 0.2)),
            T.RandomErasing(p=0.5, scale=(0.02, 0.1), ratio=(0.3, 3.3), value=0.5),
            T.RandomHorizontalFlip(p=0.5),
            T.RandomVerticalFlip(p=0.5),
        ])

    def __call__(self, item):
        patch = copy(item.pop("patch"))

        if config.instance_norm:
            patch_min = patch.min()
            patch_range = patch.max() - patch_min
            if patch_range > 0:
                patch = (patch - patch_min) / patch_range
            else:
                # Constant patch: dividing by a zero range would fill it with
                # NaN/inf, so map it to all zeros instead.
                patch = patch - patch_min

        patch = TVImage(patch)
        patch = T.Resize(self.size, antialias=True)(patch).float()

        # Binary target: anything that is not graded "Benign" is positive.
        label = torch.tensor(item["grade"] != "Benign").long()

        if self.augment:
            patch_augs = torch.stack([self.transform(patch) for _ in range(2)], dim=0)
            return patch_augs, patch, label, item

        # -1 keeps the tuple layout stable when no augmented views are produced.
        return -1, patch, label, item


# Training-only cohort filters, read from the experiment config.
train_cohort_overrides = {
    "min_involvement": config.min_involvement_train,
    "benign_to_cancer_ratio": config.benign_to_cancer_ratio_train,
    "remove_benign_from_positive_patients": config.remove_benign_from_positive_patients_train,
}

# Work on a copy so the options used for the val/test splits stay untouched.
cohort_selection_options_train = copy(config.cohort_selection_config)
for option_name, option_value in train_cohort_overrides.items():
    setattr(cohort_selection_options_train, option_name, option_value)

def _build_split_dataset(split, cohort_options):
    """Create the patch dataset for one split, without augmentation.

    Every split shares the patch options and debug flag from `config`; only
    the split name and the cohort selection options differ.
    """
    return ExactNCT2013RFImagePatches(
        split=split,
        transform=Transform(augment=False),
        cohort_selection_options=cohort_options,
        patch_options=config.patch_config,
        debug=config.debug,
    )


# The train split uses the training-specific cohort filters; val and test use
# the unmodified options from the config.
train_ds = _build_split_dataset("train", cohort_selection_options_train)
val_ds = _build_split_dataset("val", config.cohort_selection_config)
test_ds = _build_split_dataset("test", config.cohort_selection_config)


# Settings common to all three loaders.
shared_loader_kwargs = dict(batch_size=config.batch_size, num_workers=4)

# Only the training loader shuffles; val/test keep the dataset order.
train_loader = DataLoader(train_ds, shuffle=True, **shared_loader_kwargs)
val_loader = DataLoader(val_ds, shuffle=False, **shared_loader_kwargs)
test_loader = DataLoader(test_ds, shuffle=False, **shared_loader_kwargs)



Computing positions: 100%|██████████| 756/756 [00:06<00:00, 115.53it/s]
Computing positions:  87%|████████▋ | 1059/1215 [00:09<00:01, 114.78it/s]


KeyboardInterrupt: 

## Model

In [None]:
from vicreg_pretrain_experiment import TimmFeatureExtractorWrapper
from timm.layers.adaptive_avgmax_pool import SelectAdaptivePool2d


fe_config = config.model_config


def _group_norm(channels):
    """Norm-layer factory handed to timm: GroupNorm with the configured group count."""
    return nn.GroupNorm(
        num_groups=fe_config.num_groups,
        num_channels=channels,
    )


# Single-channel backbone created by timm, with GroupNorm as its norm layer.
backbone = timm.create_model(
    fe_config.model_name,
    num_classes=fe_config.num_classes,
    in_chans=1,
    features_only=fe_config.features_only,
    norm_layer=_group_norm,
)

# Global average pooling is kept outside the feature extractor so the final
# module emits one flattened vector per input.
global_pool = SelectAdaptivePool2d(
    pool_type='avg',
    flatten=True,
    input_fmt='NCHW',
)

model: nn.Module = nn.Sequential(TimmFeatureExtractorWrapper(backbone), global_pool)


# Checkpoint location, resolved against the current working directory (so it
# depends on the earlier `%cd ..`). The commented line is an alternative run.
# CHECkPOINT_PATH = os.path.join(os.getcwd(), f'logs/tta/vicreg_pretrain_gn_loco/vicreg_pretrain_gn_loco_{LEAVE_OUT}/', 'best_model.ckpt')
CHECkPOINT_PATH = os.path.join(os.getcwd(), f'logs/tta/vicreg_pretrn_5e-3-20linprob_gn_loco/vicreg_pretrn_5e-3-20linprob_gn_loco_{LEAVE_OUT}/', 'best_model.ckpt')

# map_location="cpu" restores the tensors on the CPU regardless of the device
# they were saved from. Without it, loading fails when that device is absent and
# otherwise holds a second copy of the weights on the GPU. The model is moved to
# the GPU once, below, after the weights are in place.
# NOTE(review): torch.load unpickles the file -- only load trusted checkpoints.
checkpoint = torch.load(CHECkPOINT_PATH, map_location="cpu")
model.load_state_dict(checkpoint['model'])
model.eval()
model.cuda()

# Presumably only here so the cell does not end on `model.cuda()` and print the
# module repr; kept in case another cell reads `a`.
a = True

## Run train linear eval

In [None]:
from models.linear_prob import LinearProb

loader = train_loader

desc = "train"

# Cache the frozen encoder's outputs for the training split as
# (representations, labels, metadata) tuples, one per batch. The linear probe is
# fit later from this cache rather than inside this loop.
all_reprs_labels_metadata_train = []

# The encoder is only used for inference here, so autograd is switched off for
# the forward pass; no graph or intermediate activations are kept.
with torch.no_grad():
    for batch in tqdm(loader, desc=desc):
        # images_augs is the -1 placeholder when the dataset transform was built
        # with augment=False; it is not used in this loop.
        images_augs, images, labels, meta_data = batch
        images = images.cuda()
        labels = labels.cuda()

        reprs = model(images).detach()
        all_reprs_labels_metadata_train.append((reprs, labels, meta_data))



train:   0%|          | 0/932 [00:00<?, ?it/s]

In [None]:
# os.environ["WANDB_MODE"] = "disabled"

# The metric calculator is shared with the later evaluation cells.
metric_calculator = MetricCalculator()

# 512 and 2 are presumably the representation width and the number of classes
# (labels are binary) -- TODO confirm against LinearProb's signature.
linear_prob: LinearProb = LinearProb(
    512,
    2,
    metric_calculator=metric_calculator,
    log_wandb=False,
)

# Fit the probe on the cached training representations.
probe_fit_kwargs = dict(epochs=15, lr=5e-3)
linear_prob.train(all_reprs_labels_metadata_train, **probe_fit_kwargs)

train_linear_prob:   0%|          | 0/932 [00:00<?, ?it/s]

train_linear_prob: 100%|██████████| 932/932 [00:21<00:00, 42.61it/s] 
train_linear_prob: 100%|██████████| 932/932 [00:01<00:00, 693.27it/s]
train_linear_prob: 100%|██████████| 932/932 [00:01<00:00, 515.21it/s]
train_linear_prob: 100%|██████████| 932/932 [00:01<00:00, 681.74it/s]
train_linear_prob: 100%|██████████| 932/932 [00:01<00:00, 680.41it/s]
train_linear_prob: 100%|██████████| 932/932 [00:01<00:00, 468.87it/s]
train_linear_prob: 100%|██████████| 932/932 [00:01<00:00, 677.84it/s]
train_linear_prob: 100%|██████████| 932/932 [00:01<00:00, 690.04it/s]
train_linear_prob: 100%|██████████| 932/932 [00:01<00:00, 685.67it/s]
train_linear_prob: 100%|██████████| 932/932 [00:02<00:00, 440.45it/s]
train_linear_prob: 100%|██████████| 932/932 [00:01<00:00, 680.69it/s]
train_linear_prob: 100%|██████████| 932/932 [00:01<00:00, 680.97it/s]
train_linear_prob: 100%|██████████| 932/932 [00:01<00:00, 669.38it/s]
train_linear_prob: 100%|██████████| 932/932 [00:02<00:00, 386.81it/s]
train_linear_prob: 1

In [None]:
loader = val_loader

# NOTE(review): the split iterated here is `val_loader`, yet metrics are logged
# under the "test" prefix -- confirm this pairing is intentional.
desc = "test"

metric_calculator.reset()

# Cache the frozen encoder's outputs for the evaluation split as
# (representations, labels, metadata) tuples, one per batch.
all_reprs_labels_metadata_test = []

# Inference only: autograd is switched off so no graph or intermediate
# activations are kept for the forward pass.
with torch.no_grad():
    for batch in tqdm(loader, desc=desc):
        # images_augs is the -1 placeholder when the dataset transform was built
        # with augment=False; it is not used in this loop.
        images_augs, images, labels, meta_data = batch
        images = images.cuda()
        labels = labels.cuda()

        reprs = model(images).detach()
        all_reprs_labels_metadata_test.append((reprs, labels, meta_data))

# Evaluate the trained probe on the cached representations.
linear_prob.validate(all_reprs_labels_metadata_test, desc=desc)

test:   0%|          | 0/726 [00:00<?, ?it/s]

train_linear_prob: 100%|██████████| 726/726 [00:01<00:00, 695.98it/s]
train_linear_prob: 100%|██████████| 726/726 [00:01<00:00, 682.51it/s]
train_linear_prob: 100%|██████████| 726/726 [00:01<00:00, 688.31it/s]
train_linear_prob: 100%|██████████| 726/726 [00:41<00:00, 17.44it/s] 
train_linear_prob: 100%|██████████| 726/726 [00:01<00:00, 699.63it/s]
train_linear_prob: 100%|██████████| 726/726 [00:01<00:00, 693.29it/s]
train_linear_prob: 100%|██████████| 726/726 [00:01<00:00, 680.14it/s]
train_linear_prob: 100%|██████████| 726/726 [00:01<00:00, 464.67it/s]
train_linear_prob: 100%|██████████| 726/726 [00:01<00:00, 689.62it/s]
train_linear_prob: 100%|██████████| 726/726 [00:01<00:00, 682.90it/s]
test_linear_prob: 100%|██████████| 726/726 [00:00<00:00, 1418.64it/s]


In [None]:
# Collect the metrics accumulated by the metric calculator.
metrics = metric_calculator.get_metrics()

# Let the calculator refresh its best-score bookkeeping, then keep copies of
# what it returned.
best_score_updated, best_score = metric_calculator.update_best_score(metrics, desc)
best_score_updated = copy(best_score_updated)
best_score = copy(best_score)

# Prefix every metric name with the split label.
metrics_dict = dict(
    (f"{desc}/{key}", value) for key, value in metrics.items()
)

# Best scores are only attached for the "val" split.
if desc == "val":
    metrics_dict.update(best_score)


# wandb.log(
#     metrics_dict,
#     )
metrics_dict

{'test/patch_auroc': tensor(0.6423),
 'test/patch_accuracy': tensor(0.9443),
 'test/all_inv_patch_auroc': tensor(0.6331),
 'test/all_inv_patch_accuracy': tensor(0.9211),
 'test/core_auroc': tensor(0.5000),
 'test/core_accuracy': tensor(0.9444),
 'test/all_inv_core_auroc': tensor(0.5000),
 'test/all_inv_core_accuracy': tensor(0.9227)}

In [None]:
import wandb
# Run name = group name followed by the held-out center.
group = "offline_vicreg_finetune_gn_loco"
name = f"{group}_{LEAVE_OUT}"
wandb.init(project="tta", entity="mahdigilany", name=name, group=group)

In [9]:
# os.environ["WANDB_MODE"] = "enabled"

# Tag the metrics with epoch 0, send them to the active run, then close it.
metrics_dict["epoch"] = 0
wandb.log(metrics_dict)
wandb.finish()



VBox(children=(Label(value='0.017 MB of 0.017 MB uploaded\r'), FloatProgress(value=1.0, max=1.0)))

0,1
epoch,▁
test/all_inv_core_accuracy,▁
test/all_inv_core_auroc,▁
test/all_inv_patch_accuracy,▁
test/all_inv_patch_auroc,▁
test/core_accuracy,▁
test/core_auroc,▁
test/patch_accuracy,▁
test/patch_auroc,▁

0,1
epoch,0.0
test/all_inv_core_accuracy,0.7422
test/all_inv_core_auroc,0.72766
test/all_inv_patch_accuracy,0.66819
test/all_inv_patch_auroc,0.65275
test/core_accuracy,0.76589
test/core_auroc,0.8132
test/patch_accuracy,0.68636
test/patch_auroc,0.72705


In [77]:
import resource

def increase_file_descriptor_limit(new_limit):
    """Raise this process's soft limit on open file descriptors.

    The soft RLIMIT_NOFILE limit is raised to ``new_limit``, capped at the hard
    limit. The limit is never lowered: if the current soft limit is already at
    least ``new_limit`` (or is unlimited), nothing is changed.

    Args:
        new_limit: Desired soft limit on open file descriptors.

    Returns:
        The ``(soft, hard)`` limits in effect after the call.

    Raises:
        ValueError, OSError: Propagated from ``resource.setrlimit`` if the
            system rejects the new limit.
    """
    soft, hard = resource.getrlimit(resource.RLIMIT_NOFILE)

    # RLIM_INFINITY is a sentinel (it can be negative), so it must not take part
    # in ordinary min/max comparisons.
    if soft == resource.RLIM_INFINITY:
        return soft, hard

    if hard == resource.RLIM_INFINITY:
        target = new_limit
    else:
        # The soft limit cannot exceed the hard limit.
        target = min(new_limit, hard)

    # Only ever raise the soft limit, as the function name promises.
    if target > soft:
        resource.setrlimit(resource.RLIMIT_NOFILE, (target, hard))

    # Re-read so the caller sees the limits actually in effect.
    return resource.getrlimit(resource.RLIMIT_NOFILE)

# Ask for twice 4096 descriptors. The helper caps the request at the hard
# limit, so the values printed are the ones actually in effect.
new_limit = 2 * 4096
new_soft, new_hard = increase_file_descriptor_limit(new_limit)
limits_message = f"New file descriptor limits - Soft: {new_soft}, Hard: {new_hard}"
print(limits_message)

New file descriptor limits - Soft: 4096, Hard: 4096
