In [1]:
# Move the kernel's working directory one level up; later cells build paths from os.getcwd().
# NOTE(review): this is relative, so re-running the cell moves up one more level each time.
%cd ..

/fs01/home/abbasgln/codes/medAI/projects/tta


  self.shell.db['dhist'] = compress_dhist(dhist)[-100:]


In [2]:
import os
from dotenv import load_dotenv
# Loading environment variables
load_dotenv()

import torch
import torch.nn as nn
import typing as tp
import numpy as np
import torch.optim as optim
from torch.utils.data import DataLoader
from dataclasses import dataclass, field
import logging
import wandb

import medAI
from medAI.utils.setup import BasicExperiment, BasicExperimentConfig

from utils.metrics import MetricCalculator, CoreMetricCalculator

from timm.optim.optim_factory import create_optimizer

from einops import rearrange, repeat
from tqdm.auto import tqdm
import matplotlib.pyplot as plt
import timm

from copy import copy, deepcopy
import pandas as pd

from datasets.datasets import ExactNCT2013RFImagePatches, ExactNCT2013RFCores
from medAI.datasets.nct2013 import (
    KFoldCohortSelectionOptions,
    LeaveOneCenterOutCohortSelectionOptions, 
    PatchOptions
)


In [3]:
# Identifier passed as `leave_out` to LeaveOneCenterOutCohortSelectionOptions;
# it is also embedded in the checkpoint path and in the wandb run name.
LEAVE_OUT='JH'

## Data for fine-tuning

In [4]:
###### No support dataset ######

from doctest import debug

from matplotlib import axis
from vicreg_pretrain_experiment import PretrainConfig
# Experiment configuration: leave-one-center-out cohort selection keyed on LEAVE_OUT,
# debug mode off, and a batch size of 1.
loco_selection = LeaveOneCenterOutCohortSelectionOptions(leave_out=str(LEAVE_OUT))
config = PretrainConfig(
    cohort_selection_config=loco_selection,
    debug=False,
    batch_size=1,
)

from baseline_experiment import BaselineConfig
from torchvision.transforms import v2 as T

class Transform:
    """Dataset transform: optional min-max normalization, resize, and optional augmentation.

    Args:
        augment: when True, ``__call__`` additionally returns two independently
            augmented views of the resized patches.
    """

    def __init__(self, augment=False):
        self.augment = augment
        self.size = (256, 256)
        # Augmentation
        self.transform = T.Compose([
            T.RandomAffine(degrees=0, translate=(0.2, 0.2)),
            T.RandomErasing(p=0.5, scale=(0.02, 0.1), ratio=(0.3, 3.3), value=0.5),
            T.RandomHorizontalFlip(p=0.5),
            T.RandomVerticalFlip(p=0.5),
        ])

    @staticmethod
    def _min_max_normalize(patches):
        """Rescale each patch to [0, 1] using the min and max over its last two axes."""
        lowest = np.min(patches, axis=(-2, -1), keepdims=True)
        highest = np.max(patches, axis=(-2, -1), keepdims=True)
        value_range = highest - lowest
        # A constant patch has zero range; divide by 1 there so it maps to 0 instead of NaN.
        value_range = np.where(value_range == 0, 1, value_range)
        return (patches - lowest) / value_range

    def __call__(self, item):
        """Convert one dataset item into model inputs.

        Args:
            item: mapping with at least the keys ``"patches"`` and ``"grade"``.
                ``"patches"`` is popped, so the mapping is modified in place.
                (``"patches"`` is presumably a NumPy array - confirm against the dataset.)

        Returns:
            ``(patches_augs, patches, label, item)`` where ``patches`` is the resized
            float tensor, ``label`` is 1 when ``item["grade"] != "Benign"`` and 0
            otherwise, and ``patches_augs`` stacks two augmented views along a new
            leading dimension. When ``augment`` is False, ``patches_augs`` is ``-1``.
        """
        patches = copy(item.pop("patches"))
        if config.instance_norm:
            # The original divided by (min - min), which is always zero.
            patches = self._min_max_normalize(patches)
        patches = torch.tensor(patches).float()
        patches = T.Resize(self.size, antialias=True)(patches).float()

        label = torch.tensor(item["grade"] != "Benign").long()

        if self.augment:
            # Two independent random draws from the same augmentation pipeline.
            patches_augs = torch.stack([self.transform(patches) for _ in range(2)], dim=0)
            return patches_augs, patches, label, item

        return -1, patches, label, item


# --- Cohort selection for the train split ---
# Work on a copy so the train-only settings below are not written onto
# config.cohort_selection_config, which the test split uses unchanged.
cohort_selection_options_train = copy(config.cohort_selection_config)
cohort_selection_options_train.min_involvement = config.min_involvement_train
cohort_selection_options_train.benign_to_cancer_ratio = config.benign_to_cancer_ratio_train
cohort_selection_options_train.remove_benign_from_positive_patients = config.remove_benign_from_positive_patients_train

# --- Datasets ---
# Arguments common to every core dataset built in this cell.
shared_dataset_kwargs = dict(
    patch_options=config.patch_config,
    debug=config.debug,
)

# Train split: no augmented views.
train_ds = ExactNCT2013RFCores(
    split="train",
    transform=Transform(augment=False),
    cohort_selection_options=cohort_selection_options_train,
    **shared_dataset_kwargs,
)

# val_ds = ExactNCT2013RFCores(
#     split="val",
#     transform=Transform(augment=False),
#     cohort_selection_options=config.cohort_selection_config,
#     patch_options=config.patch_config,
#     debug=config.debug,
# )

# Test split: the transform also returns two augmented views of each core.
test_ds = ExactNCT2013RFCores(
    split="test",
    transform=Transform(augment=True),
    cohort_selection_options=config.cohort_selection_config,
    **shared_dataset_kwargs,
)

# --- Loaders ---
# Both loaders load in the main process (num_workers=0) with the configured batch size.
shared_loader_kwargs = dict(batch_size=config.batch_size, num_workers=0)

train_loader = DataLoader(train_ds, shuffle=True, **shared_loader_kwargs)

# val_loader = DataLoader(
#     val_ds, batch_size=config.batch_size, shuffle=False, num_workers=4
# )

test_loader = DataLoader(test_ds, shuffle=False, **shared_loader_kwargs)



Computing positions: 100%|██████████| 756/756 [00:08<00:00, 88.08it/s] 
Computing positions: 100%|██████████| 616/616 [00:06<00:00, 95.87it/s] 


## Model

In [5]:
from vicreg_pretrain_experiment import TimmFeatureExtractorWrapper
from timm.layers.adaptive_avgmax_pool import SelectAdaptivePool2d


fe_config = config.model_config


def _group_norm(channels):
    """Norm-layer factory passed to timm: GroupNorm with the configured number of groups."""
    return nn.GroupNorm(num_groups=fe_config.num_groups, num_channels=channels)


# Backbone created by timm with single-channel input.
model: nn.Module = timm.create_model(
    fe_config.model_name,
    num_classes=fe_config.num_classes,
    in_chans=1,
    features_only=fe_config.features_only,
    norm_layer=_group_norm,
)

# Global average pooling is created separately and appended after the wrapped backbone.
global_pool = SelectAdaptivePool2d(pool_type='avg', flatten=True, input_fmt='NCHW')
model = nn.Sequential(TimmFeatureExtractorWrapper(model), global_pool)

# Load the weights stored under the 'model' key of the checkpoint for the current LEAVE_OUT.
CHECkPOINT_PATH = os.path.join(
    os.getcwd(),
    f'logs/tta/vicreg_pretrn_2048zdim_gn_loco/vicreg_pretrn_2048zdim_gn_loco_{LEAVE_OUT}/',
    'best_model.ckpt',
)
model.load_state_dict(torch.load(CHECkPOINT_PATH)['model'])
model.cuda()

# Trailing assignment; presumably here so the cell does not echo the model repr.
a = True

## Train the attention fine-tuner

In [11]:
from models.finetuner import AttentionFineturner, AttentionConfig

# Metric accumulator handed to the fine-tuner; it is read here after training and
# again in the metrics cell after validation.
metric_calculator = CoreMetricCalculator()
# Fine-tuner built on top of the checkpoint-initialized `model` from the previous cell.
attenuer_model: AttentionFineturner = AttentionFineturner(
    feature_extractor=model,
    # presumably must match the pooled feature width of `model` - TODO confirm
    feature_dim=512, 
    num_classes=2,
    core_batch_size=10,
    attention_config=AttentionConfig(nhead=8),
    metric_calculator=metric_calculator,
    log_wandb=False
    )

# Optional: restore previously saved fine-tuner weights (written by the
# commented-out torch.save cell) before the training call below.
# attenuer_state_dict = torch.load(os.path.join(os.getcwd(), f'notebooks/attenuer_model_{LEAVE_OUT}.ckpt'))
# attenuer_model.feature_extractor.load_state_dict(attenuer_state_dict['feature_extractor'])
# attenuer_model.attention.load_state_dict(attenuer_state_dict['attention_head'])
# attenuer_model.linear.load_state_dict(attenuer_state_dict['linear_head'])

# NOTE(review): train_backbone=False presumably leaves the backbone untouched, in which
# case backbone_lr would be unused - confirm in AttentionFineturner.train.
# NOTE(review): the saved output of this cell stops at 0% and its execution count is out
# of sequence, so the stored results may not come from a clean top-to-bottom run.
attenuer_model.train(train_loader,
                  epochs=10,
                  train_backbone=False,
                  backbone_lr=1e-4,
                  head_lr=5e-4,
                  )
# Last expression of the cell: displays the metrics currently held by the calculator.
metric_calculator.get_metrics()

train:   0%|          | 0/750 [00:00<?, ?it/s]

In [7]:
# torch.save(
#     {'feature_extractor': attenuer_model.feature_extractor.state_dict(),
#      'attention_head': attenuer_model.attention.state_dict(),
#      'linear_head': attenuer_model.linear.state_dict(),
#      },
#     os.path.join(os.getcwd(), f'notebooks/attenuer_model_{LEAVE_OUT}.ckpt')
# )

## Test

In [8]:
# Run the fine-tuner over the test loader with use_memo=False.
# NOTE(review): use_memo presumably toggles MEMO-style test-time adaptation - confirm
# in AttentionFineturner.validate. The commented line is the variant with it enabled.
attenuer_model.validate(test_loader, desc="test", use_memo=False)
# attenuer_model.validate(test_loader, desc="test", use_memo=True, memo_lr=5e-4)

test:   0%|          | 2/608 [00:00<00:48, 12.53it/s]

test: 100%|██████████| 608/608 [00:47<00:00, 12.89it/s]


## Compute metrics

In [9]:
# Collect the metrics currently held by the calculator under the "test" prefix.
desc = "test"
metrics = metric_calculator.get_metrics()

# update_best_score returns a pair; keep shallow copies of both parts.
best_score_updated, best_score = (
    copy(part) for part in metric_calculator.update_best_score(metrics, desc)
)

# Prefix every metric name with the split, e.g. "test/<metric>".
metrics_dict = dict(
    (f"{desc}/{key}", value) for key, value in metrics.items()
)
# Best scores are merged in only for the validation split.
if desc == "val":
    metrics_dict.update(best_score)


# wandb.log(
#     metrics_dict,
#     )
metrics_dict

{'test/core_auroc': tensor(0.6203),
 'test/core_accuracy': tensor(0.4714),
 'test/all_inv_core_auroc': tensor(0.5668),
 'test/all_inv_core_accuracy': tensor(0.4704)}

## Log with wandb

In [None]:
import wandb
# One wandb group per experiment family; the run name appends the LEAVE_OUT value.
group = "offline_vicreg_finetune_gn_loco"
name = f"{group}_{LEAVE_OUT}"
wandb.init(project="tta", entity="mahdigilany", name=name, group=group)

In [None]:
# os.environ["WANDB_MODE"] = "enabled"
# Tag the single logged row with epoch 0, push it, then close the run.
metrics_dict["epoch"] = 0
wandb.log(metrics_dict)
wandb.finish()



VBox(children=(Label(value='0.017 MB of 0.017 MB uploaded\r'), FloatProgress(value=1.0, max=1.0)))

0,1
epoch,▁
test/all_inv_core_accuracy,▁
test/all_inv_core_auroc,▁
test/all_inv_patch_accuracy,▁
test/all_inv_patch_auroc,▁
test/core_accuracy,▁
test/core_auroc,▁
test/patch_accuracy,▁
test/patch_auroc,▁

0,1
epoch,0.0
test/all_inv_core_accuracy,0.7422
test/all_inv_core_auroc,0.72766
test/all_inv_patch_accuracy,0.66819
test/all_inv_patch_auroc,0.65275
test/core_accuracy,0.76589
test/core_auroc,0.8132
test/patch_accuracy,0.68636
test/patch_auroc,0.72705


In [None]:
import resource

def increase_file_descriptor_limit(new_limit):
    """Raise the soft limit on open file descriptors for the current process.

    The soft limit is only ever raised: a request at or below the current soft
    limit leaves the limits unchanged. The request is capped at the hard limit
    unless the hard limit is unlimited.

    Args:
        new_limit: desired soft limit for ``RLIMIT_NOFILE``.

    Returns:
        The ``(soft, hard)`` limits in effect after the call.
    """
    soft, hard = resource.getrlimit(resource.RLIMIT_NOFILE)
    unlimited = resource.RLIM_INFINITY

    # RLIM_INFINITY can be a negative sentinel (-1 on Linux), so a plain
    # min(new_limit, hard) would select the sentinel instead of the request.
    target = new_limit if hard == unlimited else min(new_limit, hard)

    # Nothing to do when the soft limit is already unlimited or already high enough.
    if soft != unlimited and target > soft:
        resource.setrlimit(resource.RLIMIT_NOFILE, (target, hard))

    # Re-read the limits so the caller sees what is actually in effect.
    return resource.getrlimit(resource.RLIMIT_NOFILE)

# Increase the limit
# Request 8192 descriptors and report the limits returned by the helper.
new_limit = 2 * 4096
new_soft, new_hard = increase_file_descriptor_limit(new_limit)
print(f"New file descriptor limits - Soft: {new_soft}, Hard: {new_hard}")

New file descriptor limits - Soft: 4096, Hard: 4096
