In [1]:
%cd ..

/fs01/home/abbasgln/codes/medAI/projects/tta


In [2]:
import os
from dotenv import load_dotenv
# Loading environment variables
load_dotenv()

import torch
import torch.nn as nn
import typing as tp
import numpy as np
import torch.optim as optim
from torch.utils.data import DataLoader
from dataclasses import dataclass, field
import logging
import wandb

import medAI
from medAI.utils.setup import BasicExperiment, BasicExperimentConfig

from utils.metrics import MetricCalculator

from timm.optim.optim_factory import create_optimizer

from einops import rearrange, repeat
from tqdm.auto import tqdm
import matplotlib.pyplot as plt
import timm

from copy import copy, deepcopy
import pandas as pd

from datasets.datasets import ExactNCT2013RFImagePatches
from medAI.datasets.nct2013 import (
    KFoldCohortSelectionOptions,
    LeaveOneCenterOutCohortSelectionOptions, 
    PatchOptions
)


In [3]:
LEAVE_OUT='PCC'

## Data Finetuning

In [4]:
###### No support dataset ######

from vicreg_pretrain_experiment import PretrainConfig
config = PretrainConfig(cohort_selection_config=LeaveOneCenterOutCohortSelectionOptions(leave_out=f"{LEAVE_OUT}"))

from baseline_experiment import BaselineConfig
from torchvision.transforms import v2 as T
from torchvision.tv_tensors import Image as TVImage

class Transform:
    def __init__(selfT, augment=False):
        selfT.augment = augment
        selfT.size = (256, 256)
        # Augmentation
        selfT.transform = T.Compose([
            T.RandomAffine(degrees=0, translate=(0.2, 0.2)),
            T.RandomErasing(p=0.5, scale=(0.02, 0.1), ratio=(0.3, 3.3), value=0.5),
            T.RandomHorizontalFlip(p=0.5),
            T.RandomVerticalFlip(p=0.5),
        ])  
    
    def __call__(selfT, item):
        patch = item.pop("patch")
        patch = copy(patch)
        patch = (patch - patch.min()) / (patch.max() - patch.min()) \
            if config.instance_norm else patch
        patch = TVImage(patch)
        patch = T.Resize(selfT.size, antialias=True)(patch).float()
        
        label = torch.tensor(item["grade"] != "Benign").long()
        
        if selfT.augment:
            patch_augs = torch.stack([selfT.transform(patch) for _ in range(2)], dim=0)
            return patch_augs, patch, label, item
        
        return -1, patch, label, item


cohort_selection_options_train = copy(config.cohort_selection_config)
cohort_selection_options_train.min_involvement = config.min_involvement_train
cohort_selection_options_train.benign_to_cancer_ratio = config.benign_to_cancer_ratio_train
cohort_selection_options_train.remove_benign_from_positive_patients = config.remove_benign_from_positive_patients_train

train_ds = ExactNCT2013RFImagePatches(
    split="train",
    transform=Transform(augment=False),
    cohort_selection_options=cohort_selection_options_train,
    patch_options=config.patch_config,
    debug=config.debug,
)

val_ds = ExactNCT2013RFImagePatches(
    split="val",
    transform=Transform(augment=False),
    cohort_selection_options=config.cohort_selection_config,
    patch_options=config.patch_config,
    debug=config.debug,
)

test_ds = ExactNCT2013RFImagePatches(
    split="test",
    transform=Transform(augment=False),
    cohort_selection_options=config.cohort_selection_config,
    patch_options=config.patch_config,
    debug=config.debug,
)


train_loader = DataLoader(
    train_ds, batch_size=config.batch_size, shuffle=True, num_workers=4
)

val_loader = DataLoader(
    val_ds, batch_size=config.batch_size, shuffle=False, num_workers=4
)

test_loader = DataLoader(
    test_ds, batch_size=config.batch_size, shuffle=False, num_workers=4
)



Computing positions: 100%|██████████| 658/658 [00:05<00:00, 130.60it/s]
Computing positions: 100%|██████████| 1026/1026 [00:09<00:00, 110.25it/s]
Computing positions: 100%|██████████| 1599/1599 [00:16<00:00, 94.07it/s] 


## Model

In [5]:
from vicreg_pretrain_experiment import TimmFeatureExtractorWrapper
from timm.layers.adaptive_avgmax_pool import SelectAdaptivePool2d


fe_config = config.model_config

# Create the model
model: nn.Module = timm.create_model(
    fe_config.model_name,
    num_classes=fe_config.num_classes,
    in_chans=1,
    features_only=fe_config.features_only,
    norm_layer=lambda channels: nn.GroupNorm(
                    num_groups=fe_config.num_groups,
                    num_channels=channels
                    ))

# Separate creation of classifier and global pool from feature extractor
global_pool = SelectAdaptivePool2d(
    pool_type='avg',
    flatten=True,
    input_fmt='NCHW',
    )

model = nn.Sequential(TimmFeatureExtractorWrapper(model), global_pool)


# CHECkPOINT_PATH = os.path.join(os.getcwd(), f'logs/tta/vicreg_pretrain_gn_loco/vicreg_pretrain_gn_loco_{LEAVE_OUT}/', 'best_model.ckpt')
# CHECkPOINT_PATH = os.path.join(os.getcwd(), f'logs/tta/vicreg_pretrn_5e-3-20linprob_gn_loco/vicreg_pretrn_5e-3-20linprob_gn_loco_{LEAVE_OUT}/', 'best_model.ckpt')
# CHECkPOINT_PATH = os.path.join(os.getcwd(), f'logs/tta/vicreg_pretrn_5e-3-15linprob_100ep_actual2048zdim_gn_loco_{LEAVE_OUT}/', 'best_model.ckpt')
CHECkPOINT_PATH = os.path.join(os.getcwd(), f'logs/tta/vicreg_pretrn_2048zdim_gn_loco/vicreg_pretrn_2048zdim_gn_loco_{LEAVE_OUT}/', 'best_model.ckpt')

model.load_state_dict(torch.load(CHECkPOINT_PATH)['model'])
model.eval()
model.cuda()

a = True

## Get train reprs

In [6]:
from models.linear_prob import LinearProb

loader = train_loader

desc = "train"
metric_calculator = MetricCalculator()
# linear_prob = nn.Linear(512, 2).cuda()
# optimizer = optim.Adam(linear_prob.parameters(), lr=1e-4)
all_reprs_labels_metadata_train = []
all_reprs = []
all_labels = []
for i, batch in enumerate(tqdm(loader, desc=desc)):
    batch = deepcopy(batch)
    images_augs, images, labels, meta_data = batch
    images_augs = images_augs.cuda()
    images = images.cuda()
    labels = labels.cuda()
    
    reprs = model(images).detach()
    all_reprs.append(reprs.cpu().numpy())
    all_labels.append(labels.cpu().numpy())
    all_reprs_labels_metadata_train.append((reprs, labels, meta_data))

    # logits = linear_prob(reprs)
    # loss = nn.CrossEntropyLoss()(logits, labels)
    
    # optimizer.zero_grad()
    # loss.backward()
    # optimizer.step()
all_reprs = np.concatenate(all_reprs, axis=0)
all_labels = np.concatenate(all_labels, axis=0)


train:   0%|          | 0/803 [00:00<?, ?it/s]

## Train linear model on reprs

### SKlearn logistic regression

In [7]:
import re
from sklearn.linear_model import LogisticRegression

LR = LogisticRegression(solver='lbfgs', max_iter=1000)
LR.fit(all_reprs, all_labels)


STOP: TOTAL NO. of ITERATIONS REACHED LIMIT.

Increase the number of iterations (max_iter) or scale the data as shown in:
    https://scikit-learn.org/stable/modules/preprocessing.html
Please also refer to the documentation for alternative solver options:
    https://scikit-learn.org/stable/modules/linear_model.html#logistic-regression
  n_iter_i = _check_optimize_result(


### Linear prob 

In [8]:
# # os.environ["WANDB_MODE"] = "disabled"
# linear_prob: LinearProb = LinearProb(512, 2, metric_calculator=metric_calculator, log_wandb=False)
# linear_prob.train(all_reprs_labels_metadata_train,
#                   epochs=15,
#                   lr=5e-3
#                   )

## Get test reprs

In [9]:
loader = test_loader

desc = "test"

all_reprs_labels_metadata_test = []
all_reprs_test = []
all_labels_test = []
for i, batch in enumerate(tqdm(loader, desc=desc)):
    batch = deepcopy(batch)
    images_augs, images, labels, meta_data = batch
    images_augs = images_augs.cuda()
    images = images.cuda()
    labels = labels.cuda()
    
    reprs = model(images).detach()
    all_reprs_test .append(reprs.cpu().numpy())
    all_labels_test.append(labels.cpu().numpy())
    all_reprs_labels_metadata_test.append((reprs, labels, meta_data))
                    
    # # Update metrics   
    # metric_calculator.update(
    #     batch_meta_data = meta_data,
    #     probs = nn.functional.softmax(logits, dim=-1).detach().cpu(),
    #     labels = labels.detach().cpu(),
    # )
all_reprs_test = np.concatenate(all_reprs_test, axis=0)
all_labels_test = np.concatenate(all_labels_test, axis=0)


test:   0%|          | 0/1990 [00:00<?, ?it/s]

## Predict

### Maybe train on test before prediction

In [10]:
# # Balance the test set for training
# from sklearn.model_selection import train_test_split

# indices_class_0 = np.where(all_labels_test == 0)[0]
# indices_class_1 = np.where(all_labels_test == 1)[0]

# # Randomly sample indices from class 0 to match the number of class 1 indices
# indices_class_0_sampled, _ = train_test_split(indices_class_0, train_size=len(indices_class_1), random_state=0)
# balanced_idx_test = np.concatenate([indices_class_0_sampled[:int(1.0*len(indices_class_1))], indices_class_1[:int(1.0*len(indices_class_1))]], axis=0)

In [11]:
# # Train using test set and train set
# LR = LogisticRegression(solver='lbfgs', max_iter=1000)
# # LR.fit(
# #     np.concatenate([all_reprs_test[balanced_idx_test, :], all_reprs], axis=0),
# #     np.concatenate([all_labels_test[balanced_idx_test], all_labels], axis=0)
# # )
# LR.fit(all_reprs_test[balanced_idx_test, :], all_labels_test[balanced_idx_test])

In [12]:
# # Train using pseudo labels of test set
# LR_preds = LR.predict_proba(all_reprs_test[balanced_idx_test,:]).argmax(axis=1)
# LR = LogisticRegression(solver='lbfgs', max_iter=1000)
# LR.fit(
#     np.concatenate([all_reprs_test[balanced_idx_test, :], all_reprs], axis=0),
#     np.concatenate([LR_preds, all_labels], axis=0)
# )


### Logistic prediction

In [13]:
# Logisitic regression predictions
LR_probs = LR.predict_proba(all_reprs_test)

# Update metrics
metric_calculator.reset()
probs = torch.tensor(LR_probs)
labels = torch.tensor(all_labels_test)
invs = torch.cat([meta_data["pct_cancer"] for _, _, meta_data in all_reprs_labels_metadata_test])
ids = torch.cat([meta_data["id"] for _, _, meta_data in all_reprs_labels_metadata_test])
for i, id_tensor in enumerate(ids):
    id = id_tensor.item()
    
    # Dict of invs
    metric_calculator.core_id_invs[id] = invs[i]
    
    # Dict of probs and labels
    if id in metric_calculator.core_id_probs:
        metric_calculator.core_id_probs[id].append(probs[i])
        metric_calculator.core_id_labels[id].append(labels[i])
    else:
        metric_calculator.core_id_probs[id] = [probs[i]]
        metric_calculator.core_id_labels[id] = [labels[i]]


### Linear prob prediction

In [14]:
# metric_calculator.reset()
# linear_prob.validate(all_reprs_labels_metadata_test, desc=desc)

## Find metrics

In [15]:
# Log metrics every epoch
metrics = metric_calculator.get_metrics()

# Update best score
(
    best_score_updated,
    best_score
    ) = metric_calculator.update_best_score(metrics, desc)

best_score_updated = copy(best_score_updated)
best_score = copy(best_score)
        
# Log metrics
metrics_dict = {
    f"{desc}/{key}": value for key, value in metrics.items()
    }
metrics_dict.update(best_score) if desc == "val" else None 


# wandb.log(
#     metrics_dict,
#     )
metrics_dict

{'test/patch_auroc': tensor(0.6676),
 'test/patch_accuracy': tensor(0.4920),
 'test/all_inv_patch_auroc': tensor(0.6298),
 'test/all_inv_patch_accuracy': tensor(0.4986),
 'test/core_auroc': tensor(0.7270),
 'test/core_accuracy': tensor(0.4792),
 'test/all_inv_core_auroc': tensor(0.6719),
 'test/all_inv_core_accuracy': tensor(0.4884)}

## Log with wandb

In [None]:
import wandb
group=f"offline_vicreg_finetune_gn_loco"
name= group + f"_{LEAVE_OUT}"
wandb.init(project="tta", entity="mahdigilany", name=name, group=group)

In [9]:
# os.environ["WANDB_MODE"] = "enabled"
metrics_dict.update({"epoch": 0})
wandb.log(
    metrics_dict,
    )
wandb.finish()



VBox(children=(Label(value='0.017 MB of 0.017 MB uploaded\r'), FloatProgress(value=1.0, max=1.0)))

0,1
epoch,▁
test/all_inv_core_accuracy,▁
test/all_inv_core_auroc,▁
test/all_inv_patch_accuracy,▁
test/all_inv_patch_auroc,▁
test/core_accuracy,▁
test/core_auroc,▁
test/patch_accuracy,▁
test/patch_auroc,▁

0,1
epoch,0.0
test/all_inv_core_accuracy,0.7422
test/all_inv_core_auroc,0.72766
test/all_inv_patch_accuracy,0.66819
test/all_inv_patch_auroc,0.65275
test/core_accuracy,0.76589
test/core_auroc,0.8132
test/patch_accuracy,0.68636
test/patch_auroc,0.72705


In [77]:
import resource

def increase_file_descriptor_limit(new_limit):
    # Get the current soft and hard limits
    soft, hard = resource.getrlimit(resource.RLIMIT_NOFILE)

    # Set the new soft limit (cannot exceed the hard limit)
    new_soft_limit = min(new_limit, hard)
    resource.setrlimit(resource.RLIMIT_NOFILE, (new_soft_limit, hard))

    # Verify and return the new limits
    new_soft, new_hard = resource.getrlimit(resource.RLIMIT_NOFILE)
    return new_soft, new_hard

# Increase the limit
new_limit = 4096*2
new_soft, new_hard = increase_file_descriptor_limit(new_limit)
print(f"New file descriptor limits - Soft: {new_soft}, Hard: {new_hard}")

New file descriptor limits - Soft: 4096, Hard: 4096
