In [None]:
import argparse
import os
import pickle

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import timm
import torch
import torch.nn as nn
from fastai.vision.all import untar_data, URLs
from PIL import Image
from pytorch_lightning import Trainer
from sklearn.metrics import accuracy_score
from torch.utils.data import DataLoader, random_split
from torchmetrics import AUROC
from torchvision import transforms
from torchvision.datasets import ImageFolder

from dime import MaskingPretrainerPrior, CMIEstimatorPrior
from dime.data_utils import HistopathologyDownsampledDataset, HistopathologyDownsampledEdgeDataset
from dime.resnet_imagenet import resnet18, resnet34, Predictor, ValueNetwork, ResNet18Backbone, resnet50
from dime.sketch_supervision_predictor import SketchSupervisionPredictor
from dime.utils import get_confidence, MaskLayer2d
from dime.vit import PredictorViT, ValueNetworkViT, PredictorViTPrior, ValueNetworkViTPrior

# Get Dataset

In [None]:
# ---------------------------------------------------------------------------
# Evaluation configuration: metric, masking layer, device and test data.
# ---------------------------------------------------------------------------

# The predictor in this notebook is created with num_classes=2, so the AUROC
# metric has to be configured for two classes as well (it was previously 10,
# which does not match two-column predictions).
auc_metric = AUROC(task='multiclass', num_classes=2)

image_size = 224
mask_width = 14
mask_type = 'zero'
# Integer division keeps patch_size an int (16) rather than the float 16.0.
mask_layer = MaskLayer2d(append=False, mask_width=mask_width, patch_size=image_size // mask_width)

# NOTE(review): the GPU index is hard-coded; adjust it to a device that exists
# on the machine running this notebook.
device = torch.device('cuda:6')

# NOTE(review): these look like CIFAR-10 channel statistics rather than
# statistics of the histopathology images. They are left untouched because the
# checkpoint was presumably trained with the same normalisation -- confirm.
norm_constants = ((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010))
data_dir = '/projects/<labname>/<username>/hist_data/mhist/'
transforms_test = transforms.Compose([
        transforms.Resize(256),
        transforms.CenterCrop(image_size),
        transforms.ToTensor(),
        transforms.Normalize(*norm_constants),
    ])

# Get test dataset: only the rows whose 'Partition' column equals 'test'.
df = pd.read_csv(data_dir + 'annotations.csv')
test_dataset = HistopathologyDownsampledEdgeDataset(data_dir + 'images/', df.loc[df['Partition'] == 'test'], transforms_test)
test_dataset_len = len(test_dataset)
mbsize = 32


# Split the test partition 50/50 into validation and test subsets.
# The legacy global seed + np.random.choice call is kept as-is on purpose:
# switching RNG APIs would select different indices and change the split.
np.random.seed(0)
val_inds = np.sort(np.random.choice(test_dataset_len, size=int(test_dataset_len*0.5), replace=False))
test_inds = np.setdiff1d(np.arange(test_dataset_len), val_inds)

val_dataset = torch.utils.data.Subset(test_dataset, val_inds)
test_dataset = torch.utils.data.Subset(test_dataset, test_inds)
# NOTE(review): drop_last=True discards the final incomplete batch, so up to
# mbsize - 1 test samples are excluded from every reported metric -- confirm
# this is intended for evaluation.
test_dataloader = DataLoader(test_dataset, batch_size=mbsize, pin_memory=True, drop_last=True, num_workers=4)


# Load pretrained checkpoint

In [None]:
# ---------------------------------------------------------------------------
# Rebuild the networks and restore the trained CMI estimator from a checkpoint.
# ---------------------------------------------------------------------------
arch = 'vit_small_patch16_224'
pretrained_model_path = "<path_to_pretrained_model>"

# Two separate ViT backbones, both created from the single `arch` setting.
backbone1 = timm.create_model(arch, pretrained=True)
backbone2 = timm.create_model(arch, pretrained=True)

# Previously backbone2 was created but never used: both networks received
# (backbone1, backbone1). Passing the second backbone gives every checkpoint
# entry its own module to load into instead of two entries landing in one
# shared module. If the checkpoint was trained with tied backbones both copies
# simply receive identical weights, so inference results are unchanged.
# NOTE(review): presumably the first backbone encodes the image and the second
# the prior input -- confirm against dime.vit.
predictor = PredictorViTPrior(backbone1, backbone2, num_classes=2).to(device)
value_network = ValueNetworkViTPrior(backbone1, backbone2).to(device)

# The keyword arguments are forwarded to the CMIEstimatorPrior constructor
# while the stored weights come from `pretrained_model_path`.
greedy_cmi_estimator = CMIEstimatorPrior.load_from_checkpoint(
    pretrained_model_path,
    value_network=value_network,
    predictor=predictor,
    mask_layer=mask_layer,
    lr=1e-5,
    min_lr=1e-8,
    max_features=100,
    eps=0.05,
    loss_fn=nn.CrossEntropyLoss(reduction='none'),
    val_loss_fn=auc_metric,
    eps_decay=0.2,
    eps_steps=10,
    patience=3,
    feature_costs=None,
)

# `Trainer` is imported in the notebook's import cell; it runs inference on the
# single GPU selected by `device`.
trainer = Trainer(
    accelerator='gpu',
    devices=[device.index],
    precision=16,
)

# Evaluate Penalized Policy

In [None]:
# Per-lambda results of the penalized policy, in sweep order.
avg_num_features_lamda = []
accuracy_scores_lamda = []  # holds AUROC values (name kept for compatibility)
all_masks_lamda = []

# `results_dict` is written by both policy-evaluation cells but was never
# created, which raised a NameError on a fresh kernel. Create it only when it
# is missing so results already stored by the other cell are preserved.
if 'results_dict' not in globals():
    results_dict = {'acc': {}}
results_dict.setdefault('acc', {})

# Sweep of the penalty weight. Ranges tried earlier:
# np.geomspace(0.0025, 0.04, num=12) and np.linspace(0.02, 0.04, num=12).
lamda_values = list(np.geomspace(0.002, 0.005, num=12))
for lamda in lamda_values:
    metric_dict = greedy_cmi_estimator.inference(trainer, test_dataloader, feature_costs=None, lam=lamda)

    y = metric_dict['y']
    pred = metric_dict['pred']
    # Named `auroc` instead of `accuracy_score`: the value is an AUROC, and the
    # old name shadowed the sklearn function imported at the top.
    auroc = auc_metric(pred, y)
    final_masks = np.array(metric_dict['mask'])
    # Mean number of selected features per sample, computed once and reused.
    # assumes masks are (num_samples, num_features) -- TODO confirm
    avg_num_features = np.mean(np.sum(final_masks, axis=1))

    accuracy_scores_lamda.append(auroc)
    avg_num_features_lamda.append(avg_num_features)
    results_dict['acc'][avg_num_features] = auroc
    all_masks_lamda.append(final_masks)

    print(f"Lambda={lamda}, AUROC={auroc}, Avg. num features={avg_num_features}")

# Evaluate Budget-Constrained Policy

In [None]:
# Per-budget results of the budget-constrained policy, in sweep order.
avg_num_features_budget = []
accuracy_scores_budget = []  # holds AUROC values (name kept for compatibility)
all_masks_budget = []

# `results_dict` is written by both policy-evaluation cells but was never
# created, which raised a NameError on a fresh kernel. Create it only when it
# is missing so results already stored by the other cell are preserved.
if 'results_dict' not in globals():
    results_dict = {'acc': {}}
results_dict.setdefault('acc', {})

# Feature budgets to evaluate: 2, then 10 through 60 in steps of 10.
max_budget_values = [2] + list(range(10, 70, 10))
for budget in max_budget_values:
    metric_dict_budget = greedy_cmi_estimator.inference(trainer, test_dataloader, feature_costs=None, budget=budget)

    y = metric_dict_budget['y']
    pred = metric_dict_budget['pred']
    # Named `auroc` instead of `accuracy_score`: the value is an AUROC, and the
    # old name shadowed the sklearn function imported at the top.
    auroc = auc_metric(pred, y)
    final_masks = np.array(metric_dict_budget['mask'])
    # Mean number of selected features per sample, computed once and reused.
    # assumes masks are (num_samples, num_features) -- TODO confirm
    avg_num_features = np.mean(np.sum(final_masks, axis=1))

    accuracy_scores_budget.append(auroc)
    avg_num_features_budget.append(avg_num_features)
    results_dict['acc'][avg_num_features] = auroc
    all_masks_budget.append(final_masks)

    print(f"Budget={budget}, AUROC={auroc}, Avg. num features={avg_num_features}")