# Note!

To run this notebook, first train the models with the bash script we provide. Then, in cell 5 (`In [5]`), set `path` to a valid model checkpoint.

In [1]:
import sys

# Prepend the parent directory (relative to the current working directory,
# presumably the notebook's folder) so the project package `src` is importable.
sys.path[0:0] = ["../"]

In [2]:
import torch
from tqdm.auto import tqdm
from src.config.models import NatPnModel, LeNet5, ResNet18
from src.config.nat_pn.loss import BayesianLoss
from src.config.uncertainty_metrics import (
    load_dataset,
    load_dataloaders,
    choose_threshold,
    load_model,
)
from src.config.utils import evaluate_accuracy, evaluate_switch, quantiles

import matplotlib.pyplot as plt
import numpy as np
from collections import defaultdict
import csv
import pandas as pd

In [3]:
# Uncertainty measure that drives the local -> global switch decision.
# Supported options: 'entropy', 'log_prob', 'epkl'.
# It is also used as the suffix of the results CSV file name.
measure = "entropy"

In [4]:
# Columns of the per-client results table: for each evaluation split
# (in-distribution, out-of-distribution, mix) the accuracy of the local model,
# the global model and the uncertainty-based switch.
header = ['dataset', 'client_id'] + [
    f'{split}_{model}_acc'
    for split in ('ind', 'ood', 'mix')
    for model in ('local', 'global', 'switch')
]

# (Re)create the results file containing only the header row; the evaluation
# loop appends one row per client afterwards.
results_csv = f'mix_threshold_compact_results_{measure}.csv'
with open(results_csv, 'w', newline='') as file:
    writer = csv.writer(file)
    writer.writerow(header)

In [5]:
def _safe_ratio(numerator, denominator):
    """Return ``numerator / denominator``, or NaN when ``denominator`` is zero.

    A split with no samples (e.g. no OOD data for a client whose label set
    covers every class) yields NaN instead of raising ZeroDivisionError;
    the later ``groupby(...).mean()`` skips NaN by default.
    """
    return numerator / denominator if denominator else float('nan')


def _balanced_ood_correct(true_labels, predictions_by_model, samples_per_class):
    """Count correct predictions on a class-balanced slice of the OOD data.

    For every label present in ``true_labels`` only the first
    ``samples_per_class`` samples (in loader order) are kept.

    Args:
        true_labels: tensor of OOD ground-truth labels (the hstack of the
            per-batch label tensors returned by ``evaluate_switch``).
        predictions_by_model: mapping name -> prediction tensor aligned
            element-wise with ``true_labels``.
        samples_per_class: maximum number of samples kept per OOD label.

    Returns:
        ``(correct_counts, picked)``: for each name, the number of correct
        predictions among the kept samples, and the total number of kept
        samples.
    """
    # Correctness masks do not depend on the label, so compute them once.
    correct_masks = {name: preds == true_labels for name, preds in predictions_by_model.items()}
    correct_counts = dict.fromkeys(predictions_by_model, 0)
    picked = 0
    for label in torch.unique(true_labels):
        class_mask = true_labels == label
        for name, correct in correct_masks.items():
            correct_counts[name] += correct[class_mask][:samples_per_class].sum().cpu().item()
        # A class with fewer samples than the quota contributes all of them.
        picked += min(int(class_mask.sum().item()), samples_per_class)
    return correct_counts, picked


for dataset_name in ['mnist', 'fmnist', 'cifar10', 'svhn', 'medmnistA', 'medmnistC', 'medmnistS']:
    prefix = "None"
    if dataset_name in ['cifar10', 'svhn']:
        prefix = 'lenet_' + prefix
    density_model = 'flow'
    embedding_dim = 16

    path = f"../out/FedAvg/{prefix}all_params_stopgrad_logp_{dataset_name}_500_natpn.pt"
    # YOUR PATH TO A VALID MODEL IS HERE
    # for example
    # path = f"../out/FedAvg/actual_models/{prefix}all_params_stopgrad_logp_{dataset_name}_100_natpn.pt"

    backbone = 'lenet5' #'res18' if dataset_name in ['cifar10', 'svhn'] else 'lenet5'
    batch_size = 7000 if dataset_name in ['cifar10', 'svhn'] else 25000

    stopgrad = False
    device = 'cuda:0'

    # Checkpoint with a 'global' entry plus (presumably) one entry per client;
    # the number of clients is taken as len(all_params_dict) - 1 below.
    all_params_dict = torch.load(path)

    data_indices, trainset, testset = load_dataset(
        dataset_name=dataset_name,
        normalization_name=dataset_name,
    )

    global_model = load_model(
        dataset_name=dataset_name,
        backbone=backbone,
        stopgrad=stopgrad,
        density_model=density_model,
        embedding_dim=embedding_dim,
        index='global',
        all_params_dict=all_params_dict,
    )
    global_model.eval()
    global_model = global_model.to(device)

    ind_local_accuracies = []
    ind_global_accuracies = []
    ind_switch_accuracies = []

    ood_local_accuracies = []
    ood_global_accuracies = []
    ood_switch_accuracies = []

    mix_local_accuracies = []
    mix_global_accuracies = []
    mix_switch_accuracies = []

    for index in tqdm(range(len(all_params_dict) - 1), desc=dataset_name):
        local_model = load_model(
            dataset_name=dataset_name,
            backbone=backbone,
            density_model=density_model,
            embedding_dim=embedding_dim,
            stopgrad=stopgrad,
            index=index,
            all_params_dict=all_params_dict,
        )
        local_model.eval()
        local_model = local_model.to(device)

        in_classes = local_model.labels.cpu().numpy().tolist()

        # Split the evaluation pool into the client's own classes (InD) and the rest (OOD).
        # NOTE(review): the pool is the full *training* set (trainset.dataset), not
        # testset -- confirm this is intended.
        targets = trainset.dataset.targets
        in_distribution_indices = [i for i, t in enumerate(targets) if t in in_classes]
        out_distribution_indices = [i for i, t in enumerate(targets) if t not in in_classes]

        in_distribution_dataset = torch.utils.data.Subset(trainset.dataset, in_distribution_indices)
        ind_loader = torch.utils.data.DataLoader(in_distribution_dataset, batch_size=batch_size, shuffle=False)

        _, _, calloader = load_dataloaders(
            client_id=index, data_indices=data_indices, trainset=trainset, testset=testset
        )

        quantile = quantiles[dataset_name]

        # Calibrate the local model's switching threshold on its calibration split.
        threshold, values = choose_threshold(
            model=local_model,
            calloader=calloader,
            device=device,
            alpha=quantile,
        )
        ind_correct_decision, ind_correct_local, ind_correct_global, ind_sample_num = evaluate_switch(
            local_model=local_model,
            global_model=global_model,
            dataloader=ind_loader,
            threshold=threshold[measure],
            uncertainty_measure=measure,
            device=device,
            return_predictions=False
        )

        ind_switch_accuracies.append(_safe_ratio(ind_correct_decision, ind_sample_num))
        ind_local_accuracies.append(_safe_ratio(ind_correct_local, ind_sample_num))
        ind_global_accuracies.append(_safe_ratio(ind_correct_global, ind_sample_num))

        # Mix split: every InD sample plus, for each OOD class, at most
        # ind_sample_num // (number of OOD classes) OOD samples.
        mix_samples = ind_sample_num
        mix_local_correct = ind_correct_local
        mix_global_correct = ind_correct_global
        mix_switch_correct = ind_correct_decision

        if out_distribution_indices:
            out_distribution_dataset = torch.utils.data.Subset(trainset.dataset, out_distribution_indices)
            ood_loader = torch.utils.data.DataLoader(out_distribution_dataset, batch_size=batch_size, shuffle=False)

            (ood_correct_decision, ood_correct_local, ood_correct_global, ood_sample_num,
             ood_local_predictions, ood_global_predictions, ood_switch_predictions,
             ood_true_labels, _, _) = evaluate_switch(
                local_model=local_model,
                global_model=global_model,
                dataloader=ood_loader,
                threshold=threshold[measure],
                uncertainty_measure=measure,
                device=device,
                return_predictions=True
            )

            ood_switch_accuracies.append(_safe_ratio(ood_correct_decision, ood_sample_num))
            ood_local_accuracies.append(_safe_ratio(ood_correct_local, ood_sample_num))
            ood_global_accuracies.append(_safe_ratio(ood_correct_global, ood_sample_num))

            true_labels = torch.hstack(ood_true_labels)
            n_ood_classes = len(torch.unique(true_labels))
            correct_counts, picked = _balanced_ood_correct(
                true_labels,
                {
                    'local': torch.hstack(ood_local_predictions),
                    'global': torch.hstack(ood_global_predictions),
                    'switch': torch.hstack(ood_switch_predictions),
                },
                samples_per_class=int(ind_sample_num) // n_ood_classes,
            )
            mix_local_correct += correct_counts['local']
            mix_global_correct += correct_counts['global']
            mix_switch_correct += correct_counts['switch']
            mix_samples += picked
        else:
            # The client holds every class, so there is no OOD data: OOD accuracies
            # are undefined and the mix split reduces to the InD samples.
            ood_switch_accuracies.append(float('nan'))
            ood_local_accuracies.append(float('nan'))
            ood_global_accuracies.append(float('nan'))

        mix_local_accuracies.append(_safe_ratio(mix_local_correct, mix_samples))
        mix_global_accuracies.append(_safe_ratio(mix_global_correct, mix_samples))
        mix_switch_accuracies.append(_safe_ratio(mix_switch_correct, mix_samples))

        new_row = [
            dataset_name,
            index,
            ind_local_accuracies[-1],
            ind_global_accuracies[-1],
            ind_switch_accuracies[-1],
            ood_local_accuracies[-1],
            ood_global_accuracies[-1],
            ood_switch_accuracies[-1],
            mix_local_accuracies[-1],
            mix_global_accuracies[-1],
            mix_switch_accuracies[-1],
        ]
        # Append this client's row right away so partial results survive an interrupted run.
        with open(f'mix_threshold_compact_results_{measure}.csv', 'a', newline='') as file:
            writer = csv.writer(file)
            writer.writerow(new_row)

  0%|          | 0/20 [00:00<?, ?it/s]

  0%|          | 0/20 [00:00<?, ?it/s]

  0%|          | 0/20 [00:00<?, ?it/s]

Using downloaded and verified file: ../data/svhn/raw/train_32x32.mat
Using downloaded and verified file: ../data/svhn/raw/test_32x32.mat


  0%|          | 0/20 [00:00<?, ?it/s]

  0%|          | 0/20 [00:00<?, ?it/s]

  0%|          | 0/20 [00:00<?, ?it/s]

  0%|          | 0/20 [00:00<?, ?it/s]

In [6]:
# Reload the per-client results written by the evaluation loop.
results_file = f'mix_threshold_compact_results_{measure}.csv'
comp_results = pd.read_csv(results_file)

In [7]:
# Average every accuracy column over the clients of each dataset.
grouped_df = (
    comp_results
    .drop(columns=['client_id'])
    .groupby('dataset')
    .mean()
)

In [8]:
# Display the per-dataset mean accuracies (one column per split/model pair).
grouped_df

Unnamed: 0_level_0,ind_local_acc,ind_global_acc,ind_switch_acc,ood_local_acc,ood_global_acc,ood_switch_acc,mix_local_acc,mix_global_acc,mix_switch_acc
dataset,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1,Unnamed: 8_level_1,Unnamed: 9_level_1
cifar10,0.751285,0.396237,0.591292,0.0,0.388211,0.287977,0.37566,0.39382,0.441049
fmnist,0.95707,0.799873,0.842708,0.0,0.814732,0.781737,0.478535,0.807163,0.812558
medmnistA,0.989735,0.958855,0.961896,0.0,0.952013,0.949385,0.496466,0.954885,0.9552
medmnistC,0.966257,0.938915,0.944187,0.0,0.931231,0.887122,0.483515,0.932255,0.911745
medmnistS,0.907246,0.829766,0.868632,0.0,0.826807,0.755279,0.454031,0.814063,0.800323
mnist,0.994038,0.983275,0.983547,0.0,0.983121,0.982914,0.497059,0.983246,0.98329
svhn,0.922265,0.777547,0.871364,0.0,0.775414,0.621665,0.461163,0.762285,0.733723


In [9]:
# Split each flat column name such as 'ind_local_acc' into a two-level header
# ('ind', 'local_acc') so the table groups columns by evaluation split.
# The guard keeps this cell idempotent: once the columns are a MultiIndex their
# labels are tuples, and re-running the split would fail.
if not isinstance(grouped_df.columns, pd.MultiIndex):
    grouped_df.columns = pd.MultiIndex.from_tuples(
        [tuple(column.split('_', 1)) for column in grouped_df.columns],
        names=['group', 'metric'],
    )

# Per-split views; e.g. ind_df has columns local_acc / global_acc / switch_acc.
ind_df = grouped_df['ind']
ood_df = grouped_df['ood']
mix_df = grouped_df['mix']

In [10]:
# Display the table again, now with columns grouped by split (ind / ood / mix).
grouped_df

group,ind,ind,ind,ood,ood,ood,mix,mix,mix
metric,local_acc,global_acc,switch_acc,local_acc,global_acc,switch_acc,local_acc,global_acc,switch_acc
dataset,Unnamed: 1_level_2,Unnamed: 2_level_2,Unnamed: 3_level_2,Unnamed: 4_level_2,Unnamed: 5_level_2,Unnamed: 6_level_2,Unnamed: 7_level_2,Unnamed: 8_level_2,Unnamed: 9_level_2
cifar10,0.751285,0.396237,0.591292,0.0,0.388211,0.287977,0.37566,0.39382,0.441049
fmnist,0.95707,0.799873,0.842708,0.0,0.814732,0.781737,0.478535,0.807163,0.812558
medmnistA,0.989735,0.958855,0.961896,0.0,0.952013,0.949385,0.496466,0.954885,0.9552
medmnistC,0.966257,0.938915,0.944187,0.0,0.931231,0.887122,0.483515,0.932255,0.911745
medmnistS,0.907246,0.829766,0.868632,0.0,0.826807,0.755279,0.454031,0.814063,0.800323
mnist,0.994038,0.983275,0.983547,0.0,0.983121,0.982914,0.497059,0.983246,0.98329
svhn,0.922265,0.777547,0.871364,0.0,0.775414,0.621665,0.461163,0.762285,0.733723


In [11]:
# Accuracies as percentages, rounded to one decimal place.
grouped_df.mul(100).round(1)

group,ind,ind,ind,ood,ood,ood,mix,mix,mix
metric,local_acc,global_acc,switch_acc,local_acc,global_acc,switch_acc,local_acc,global_acc,switch_acc
dataset,Unnamed: 1_level_2,Unnamed: 2_level_2,Unnamed: 3_level_2,Unnamed: 4_level_2,Unnamed: 5_level_2,Unnamed: 6_level_2,Unnamed: 7_level_2,Unnamed: 8_level_2,Unnamed: 9_level_2
cifar10,75.1,39.6,59.1,0.0,38.8,28.8,37.6,39.4,44.1
fmnist,95.7,80.0,84.3,0.0,81.5,78.2,47.9,80.7,81.3
medmnistA,99.0,95.9,96.2,0.0,95.2,94.9,49.6,95.5,95.5
medmnistC,96.6,93.9,94.4,0.0,93.1,88.7,48.4,93.2,91.2
medmnistS,90.7,83.0,86.9,0.0,82.7,75.5,45.4,81.4,80.0
mnist,99.4,98.3,98.4,0.0,98.3,98.3,49.7,98.3,98.3
svhn,92.2,77.8,87.1,0.0,77.5,62.2,46.1,76.2,73.4


In [12]:
# LaTeX version of the percentage table.
# escape=True turns the '_' in labels such as 'local_acc' into '\_'. pandas >= 2.0
# does not escape by default, and a bare '_' outside math mode breaks LaTeX compilation.
percent_table = (grouped_df * 100).round(1)
print(percent_table.to_latex(float_format="%.1f", escape=True))

\begin{tabular}{lrrrrrrrrr}
\toprule
group & \multicolumn{3}{r}{ind} & \multicolumn{3}{r}{ood} & \multicolumn{3}{r}{mix} \\
metric & local_acc & global_acc & switch_acc & local_acc & global_acc & switch_acc & local_acc & global_acc & switch_acc \\
dataset &  &  &  &  &  &  &  &  &  \\
\midrule
cifar10 & 75.1 & 39.6 & 59.1 & 0.0 & 38.8 & 28.8 & 37.6 & 39.4 & 44.1 \\
fmnist & 95.7 & 80.0 & 84.3 & 0.0 & 81.5 & 78.2 & 47.9 & 80.7 & 81.3 \\
medmnistA & 99.0 & 95.9 & 96.2 & 0.0 & 95.2 & 94.9 & 49.6 & 95.5 & 95.5 \\
medmnistC & 96.6 & 93.9 & 94.4 & 0.0 & 93.1 & 88.7 & 48.4 & 93.2 & 91.2 \\
medmnistS & 90.7 & 83.0 & 86.9 & 0.0 & 82.7 & 75.5 & 45.4 & 81.4 & 80.0 \\
mnist & 99.4 & 98.3 & 98.4 & 0.0 & 98.3 & 98.3 & 49.7 & 98.3 & 98.3 \\
svhn & 92.2 & 77.8 & 87.1 & 0.0 & 77.5 & 62.2 & 46.1 & 76.2 & 73.4 \\
\bottomrule
\end{tabular}

