In [1]:
import torch
import pickle
import argparse
import numpy as np
import torch.nn as nn
from torchmetrics import AUROC
from sklearn.metrics import accuracy_score
from torch.utils.data import DataLoader, random_split
from torchvision import transforms
from torchvision.datasets import ImageFolder
import os
from fastai.vision.all import untar_data, URLs
from dime.data_utils import MaskLayerGaussian, MaskLayer2d
from dime.greedy_models import GreedyCMIEstimator
from dime.masking_pretrainer import MaskingPretrainer
from dime.resnet_imagenet import resnet18, Predictor, ValueNetwork, ResNet18Backbone
from dime.utils import accuracy, auc, normalize, selection_with_lamda
from dime.vit import PredictorViT, ValueNetworViT
import timm
import matplotlib.pyplot as plt

# Load Dataset

In [None]:
# Evaluation-wide configuration: input resolution and mask grid.
image_size = 224
mask_width = 14
mask_type = 'zero'

# Both mask layers take the same arguments; only the class depends on mask_type.
# patch_size is passed as image_size/mask_width (true division, 16.0 here).
mask_layer_cls = MaskLayerGaussian if mask_type == 'gaussian' else MaskLayer2d
mask_layer = mask_layer_cls(append=False, mask_width=mask_width, patch_size=image_size/mask_width)

# Prefer the second GPU (the original setting) but fall back gracefully so the
# notebook still runs on single-GPU and CPU-only machines.
if torch.cuda.device_count() > 1:
    device = torch.device('cuda:1')
elif torch.cuda.is_available():
    device = torch.device('cuda:0')
else:
    device = torch.device('cpu')

# Per-channel (mean, std) used by Normalize, and the matching inverse transform
# constants (-mean/std, 1/std).
# NOTE(review): these look like the commonly used CIFAR-10 statistics rather
# than ImageNet ones — confirm they match what the checkpoints were trained
# with before changing them.
norm_constants = ((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010))
unnorm_constants = ((-0.4914/0.2023, -0.4822/0.1994, -0.4465/0.2010), (1/0.2023, 1/0.1994, 1/0.2010))
mbsize = 32

# Resolve the Imagenette (320px) folder through fastai instead of a
# user-specific absolute path; untar_data downloads the archive on first use
# and returns the local directory on subsequent runs.
dataset_path = str(untar_data(URLs.IMAGENETTE_320))

# Deterministic preprocessing: resize, center-crop to the model input size,
# convert to tensor, then normalize.
transforms_test = transforms.Compose([
        transforms.Resize(256),
        transforms.CenterCrop(image_size),
        transforms.ToTensor(),
        transforms.Normalize(*norm_constants),
    ])

test_dataset = ImageFolder(dataset_path+'/val', transforms_test)
# drop_last=True discards the final incomplete batch, so reported metrics are
# computed on a multiple of mbsize samples.
# Pinned memory only helps host-to-GPU copies, so it is enabled for CUDA only.
test_dataloader = DataLoader(test_dataset, batch_size=mbsize, pin_memory=(device.type == 'cuda'),
                             drop_last=True, num_workers=4, shuffle=False)

# Load Pretrained Model

In [None]:
# Build the predictor / value network pair on a shared backbone, restore the
# trained weights, and wrap everything in the greedy CMI estimator.
arch = "vit"
# NOTE(review): the checkpoints loaded below carry "use_entropy_True" in their
# file names while the value network is constructed with use_entropy=False —
# confirm this combination is intended.
use_entropy=False
if arch == 'resnet':
    backbone = ResNet18Backbone(resnet18(pretrained=True))
    predictor =  Predictor(backbone)
    # A 14-wide mask grid uses a block_layer_stride of 0.5 instead of 1.
    block_layer_stride = 0.5 if mask_width == 14 else 1
    value_network = ValueNetwork(backbone, block_layer_stride=block_layer_stride)
else:
    backbone = timm.create_model('vit_small_patch16_224', pretrained=True)
    predictor =  PredictorViT(backbone)
    value_network = ValueNetworViT(backbone, mask_width=mask_width, use_entropy=use_entropy)

# NOTE(review): both checkpoint names refer to vit_small_patch16_224, so the
# 'resnet' branch above would need its own checkpoint paths to load correctly.
# map_location='cpu' makes loading independent of the device the checkpoint was
# saved from; the modules are moved to `device` right after.
predictor.load_state_dict(torch.load("results/predictor_trained_max_features_50_vit_small_patch16_224_lr_1e-5_use_entropy_True_zero_mask_width_14_save_best_perf.pth", map_location='cpu'))
value_network.load_state_dict(torch.load("results/value_network_trained_max_features_50_vit_small_patch16_224_lr_1e-5_use_entropy_True_zero_mask_width_14_save_best_perf.pth", map_location='cpu'))

greedy_cmi_estimator = GreedyCMIEstimator(value_network, predictor, mask_layer).to(device)

# Evaluate Penalized Policy

In [None]:
# Sweep the lambda penalty and record, for each value, the accuracy, the mean
# number of selected features, and the final selection masks.
avg_num_features_lamda = []
accuracy_scores_lamda = []
all_masks_lamda =[]

lamda_values = [0.02]  # for a full sweep use list(np.geomspace(0.001, 0.3, num=10))
for lamda in lamda_values:
    metric_dict = greedy_cmi_estimator.evaluate(test_dataloader, performance_func=accuracy,
                                                feature_costs=None, use_entropy=True,
                                                evaluation_mode='lamda-penalty', lamda=lamda)

    # Named lamda_accuracy (not accuracy_score) so the sklearn.metrics
    # accuracy_score import is not shadowed in the notebook namespace.
    lamda_accuracy = metric_dict['performance']
    final_masks = metric_dict['final_masks']
    # Sum each mask along axis 1, then average over rows
    # (presumably one row per evaluated sample — TODO confirm).
    avg_num_features = np.mean(np.sum(final_masks, axis=1))

    accuracy_scores_lamda.append(lamda_accuracy)
    avg_num_features_lamda.append(avg_num_features)
    print(f"Lambda={lamda}, Acc={lamda_accuracy}, Avg. num features={avg_num_features}")
    all_masks_lamda.append(final_masks)


# Evaluate Budget Constrained Policy

In [None]:
# Evaluate the fixed-budget policy for each budget and record the accuracy,
# the mean number of selected features, and the final selection masks.
avg_num_features_budget = []
accuracy_scores_budget = []
all_masks_budget=[]

max_budget_values = [10]  # for a full sweep use list(range(1, 15, 1))
for budget in max_budget_values:
    # NOTE(review): unlike the lambda-penalty and confidence cells, use_entropy
    # is not passed here, so evaluate() falls back to its default — confirm.
    metric_dict_budget = greedy_cmi_estimator.evaluate(test_dataloader, performance_func=accuracy,
                                                       feature_costs=None,
                                                       evaluation_mode='fixed-budget', budget=budget)

    # Named budget_accuracy (not accuracy_score) so the sklearn.metrics
    # accuracy_score import is not shadowed in the notebook namespace.
    budget_accuracy = metric_dict_budget['performance']
    final_masks = metric_dict_budget['final_masks']
    # Sum each mask along axis 1, then average over rows
    # (presumably one row per evaluated sample — TODO confirm).
    avg_num_features = np.mean(np.sum(final_masks, axis=1))

    accuracy_scores_budget.append(budget_accuracy)
    avg_num_features_budget.append(avg_num_features)
    print(f"Budget={budget}, Acc={budget_accuracy}, Avg. num features={avg_num_features}")

    all_masks_budget.append(final_masks)

# Evaluate Confidence Constrained Policy

In [None]:
# Sweep the minimum-confidence threshold and record, for each value, the
# accuracy, the mean number of selected features, and the final masks.
avg_num_features_confidence = []
accuracy_scores_confidence = []
all_masks_confidence=[]
# Thresholds 0.1, 0.2, ..., 0.9. linspace + round avoids the accumulated
# floating-point error of np.arange with a fractional step (e.g.
# 0.30000000000000004) and guarantees exactly nine values.
confidence_values = np.round(np.linspace(0.1, 0.9, num=9), 1).tolist()

for confidence in confidence_values:
    metric_dict = greedy_cmi_estimator.evaluate(test_dataloader, performance_func=accuracy,
                                                feature_costs=None, use_entropy=True,
                                                evaluation_mode='confidence', min_confidence=confidence)

    # Named confidence_accuracy (not accuracy_score) so the sklearn.metrics
    # accuracy_score import is not shadowed in the notebook namespace.
    confidence_accuracy = metric_dict['performance']
    final_masks = metric_dict['final_masks']
    # Sum each mask along axis 1, then average over rows
    # (presumably one row per evaluated sample — TODO confirm).
    avg_num_features = np.mean(np.sum(final_masks, axis=1))

    accuracy_scores_confidence.append(confidence_accuracy)
    avg_num_features_confidence.append(avg_num_features)
    print(f"Confidence={confidence}, Acc={confidence_accuracy}, Avg. num features={avg_num_features}")
    all_masks_confidence.append(final_masks)