In [2]:
import torch
import pickle
import argparse
import numpy as np
import torch.nn as nn
from torchmetrics import AUROC
from sklearn.metrics import accuracy_score
from torch.utils.data import DataLoader, random_split
from torchvision import transforms
from torchvision.datasets import MNIST
from torchvision.datasets import ImageFolder
import os
from fastai.vision.all import untar_data, URLs
import pandas as pd
from pytorch_lightning import Trainer
from experiments import MaskLayerGaussian, MaskLayer2d, HistopathologyDownsampledDataset
from dime.greedy_models import GreedyCMIEstimator
from dime.masking_pretrainer import MaskingPretrainer
from dime.greedy_model_pl import GreedyCMIEstimatorPL
from dime.resnet_imagenet import resnet18, resnet34, Predictor, ValueNetwork, ResNet18Backbone, resnet50
from dime.utils import accuracy, auc, normalize, selection_with_lamda, MaskLayer
from dime.vit import PredictorViT, ValueNetworViT
import timm
import matplotlib.pyplot as plt

# Load Dataset

In [None]:
# Load the MNIST training split. Every image is converted to a tensor and
# flattened to a 1-D vector by the transform below.
# NOTE(review): no train/val split is performed here and `mnist_dataset` is not
# read by any later cell visible in this notebook; it is kept in case other code
# relies on the name.
flatten_transform = transforms.Compose([transforms.ToTensor(), transforms.Lambda(torch.flatten)])
mnist_dataset = MNIST('/tmp/mnist/', download=True, train=True,
                      transform=flatten_transform)
np.random.seed(0)
# Load the MNIST test split with the same preprocessing.
test_dataset = MNIST('/tmp/mnist/', download=True, train=False,
                     transform=flatten_transform)

# Prefer GPU index 2 (the original hard-coded choice), but fall back to the
# first GPU or to the CPU so the notebook still runs on hosts with fewer devices.
PREFERRED_GPU = 2
if torch.cuda.is_available() and torch.cuda.device_count() > PREFERRED_GPU:
    device = torch.device(f'cuda:{PREFERRED_GPU}')
elif torch.cuda.is_available():
    device = torch.device('cuda:0')
else:
    device = torch.device('cpu')

# NOTE(review): drop_last=True discards the final partial batch, so any test
# examples beyond a multiple of the batch size are never evaluated. Left
# unchanged to keep reported numbers comparable -- confirm this is intended.
test_dataloader = DataLoader(
        test_dataset, batch_size=128, shuffle=False,
        pin_memory=(device.type == 'cuda'),  # pinning only helps host->GPU copies
        drop_last=True, num_workers=4)

# Set up networks

In [None]:
# Network dimensions and regularisation strength.
d_in, d_out = 784, 10
hidden = 512
dropout = 0.3


def _hidden_stack():
    """Build the two Linear -> ReLU -> Dropout stages both networks start with.

    The first layer takes ``d_in * 2`` inputs -- presumably the features
    concatenated with the mask appended by ``MaskLayer(append=True)``; confirm.
    """
    return [
        nn.Linear(d_in * 2, hidden),
        nn.ReLU(),
        nn.Dropout(dropout),
        nn.Linear(hidden, hidden),
        nn.ReLU(),
        nn.Dropout(dropout),
    ]


# Outcome predictor: hidden stack followed by a d_out-wide linear head.
predictor = nn.Sequential(*_hidden_stack(), nn.Linear(hidden, d_out))
predictor = predictor.to(device)

# CMI predictor: hidden stack followed by one sigmoid-squashed output per input feature.
value_network = nn.Sequential(*_hidden_stack(), nn.Linear(hidden, d_in), nn.Sigmoid())
value_network = value_network.to(device)

# Tie weights: both hidden Linear layers of the value network are replaced by
# the predictor's own layer objects, so the two networks share those parameters.
for shared_idx in (0, 3):
    value_network[shared_idx] = predictor[shared_idx]

mask_layer = MaskLayer(append=True, mask_size=d_in)

# Evaluate Penalized Policy

In [None]:

# Evaluate the penalized ('lamda-penalty') policy for five trials. Each trial
# sweeps a range of penalty values and pickles a dict mapping the average
# number of selected features to the measured accuracy.

# Create the output directory up front so the pickle dump cannot fail on a
# missing folder.
os.makedirs('results', exist_ok=True)

for trial in range(0, 5):
    results_dict = {"acc": {}}
    # NOTE(review): placeholder path. It does not depend on `trial`, so every
    # trial loads the same checkpoint until a per-trial path is filled in.
    path = f'<path_to_trained_model>'

    # Rebuild the Lightning module around the networks defined above and load
    # the trained weights from the checkpoint.
    greedy_cmi_estimator = GreedyCMIEstimatorPL.load_from_checkpoint(
        path,
        value_network=value_network,
        predictor=predictor,
        mask_layer=mask_layer,
        lr=1e-3,
        min_lr=1e-6,
        max_features=50,
        eps=0.05,
        loss_fn=nn.CrossEntropyLoss(reduction='none'),
        val_loss_fn=auc,
        eps_decay=True,
        eps_decay_rate=0.2,
        patience=3,
        feature_costs=None,
        use_entropy=True).to(device)

    # Per-trial accumulators (reset on every trial).
    avg_num_features_lamda = []
    accuracy_scores_lamda = []
    all_masks_lamda = []

    # Ten geometrically spaced penalty values between 0.00016 and 0.28.
    lamda_values = list(np.geomspace(0.00016, 0.28, num=10))
    for lamda in lamda_values:
        metric_dict = greedy_cmi_estimator.evaluate(
            test_dataloader, performance_func=accuracy,
            feature_costs=None, use_entropy=True,
            evaluation_mode='lamda-penalty', lamda=lamda)

        # Named `acc` rather than `accuracy_score` so the function imported
        # from sklearn.metrics at the top of the notebook is not overwritten.
        acc = metric_dict['performance']
        final_masks = metric_dict['final_masks']
        # Mean over rows of the per-row mask sum (presumably the number of
        # features selected per sample); computed once and reused below.
        avg_num_features = np.mean(np.sum(final_masks, axis=1))

        accuracy_scores_lamda.append(acc)
        avg_num_features_lamda.append(avg_num_features)
        results_dict['acc'][avg_num_features] = acc

        print(f"Lambda={lamda}, Acc={acc}, Avg. num features={avg_num_features}")
        all_masks_lamda.append(final_masks)

    # NOTE(review): with range(0, 5) the `trial-4` suffix produces file names
    # ending in -4 ... 0. Kept as-is because other code may read these exact
    # names -- confirm the offset is intended.
    with open(f'results/mnist_lamda_ours_trial_{trial-4}.pkl', 'wb') as f:
        pickle.dump(results_dict, f)


# Evaluate Budget Constrained Policy

In [None]:
# Evaluate the budget-constrained ('fixed-budget') policy for five trials. Each
# trial sweeps a list of feature budgets and pickles a dict mapping the average
# number of selected features to the measured accuracy.

# Create the output directory up front so the pickle dump cannot fail on a
# missing folder.
os.makedirs('results', exist_ok=True)

for trial in range(0, 5):
    # Reset per trial so one trial's pickle never carries entries left over
    # from an earlier trial.
    results_dict = {"acc": {}}

    # NOTE(review): placeholder path. It does not depend on `trial`, so every
    # trial loads the same checkpoint until a per-trial path is filled in.
    path = f'<path_to_trained_model>'

    # Rebuild the Lightning module around the networks defined above and load
    # the trained weights. The explicit .to(device) matches the penalized-policy
    # cell and keeps the model on the chosen device regardless of where the
    # checkpoint tensors were saved.
    greedy_cmi_estimator = GreedyCMIEstimatorPL.load_from_checkpoint(
        path,
        value_network=value_network,
        predictor=predictor,
        mask_layer=mask_layer,
        lr=1e-3,
        min_lr=1e-6,
        max_features=50,
        eps=0.05,
        loss_fn=nn.CrossEntropyLoss(reduction='none'),
        val_loss_fn=auc,
        eps_decay=True,
        eps_decay_rate=0.2,
        patience=3,
        feature_costs=None,
        use_entropy=True).to(device)

    # Per-trial accumulators (reset on every trial).
    avg_num_features_budget = []
    accuracy_scores_budget = []
    all_masks_budget = []

    # Budgets to evaluate: 3, then 5..25 in steps of 5.
    max_budget_values = [3] + list(range(5, 30, 5))
    for budget in max_budget_values:
        metric_dict_budget = greedy_cmi_estimator.evaluate(
            test_dataloader, performance_func=accuracy,
            feature_costs=None, evaluation_mode='fixed-budget', budget=budget)

        # Named `acc` rather than `accuracy_score` so the function imported
        # from sklearn.metrics at the top of the notebook is not overwritten.
        acc = metric_dict_budget['performance']
        final_masks = metric_dict_budget['final_masks']
        # Mean over rows of the per-row mask sum (presumably the number of
        # features selected per sample); computed once and reused below.
        avg_num_features = np.mean(np.sum(final_masks, axis=1))

        accuracy_scores_budget.append(acc)
        avg_num_features_budget.append(avg_num_features)
        results_dict['acc'][avg_num_features] = acc
        print(f"Budget={budget}, Acc={acc}, Avg. num features={avg_num_features}")

        # Record the masks for every budget, not only the last one.
        all_masks_budget.append(final_masks)

    # NOTE(review): with range(0, 5) the `trial-4` suffix produces file names
    # ending in -4 ... 0. Kept as-is because other code may read these exact
    # names -- confirm the offset is intended.
    with open(f'results/mnist_ours_trial_{trial-4}.pkl', 'wb') as f:
        pickle.dump(results_dict, f)
