# Compute GM

This notebook computes the **geometric mean score** (GM) for a trained model, following [Kim, Jeong, and Shin](https://arxiv.org/abs/2004.00431).

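The geometric mean score (GM) is defined as the geometric mean of the per-class recalls. The notebook also reports the balanced accuracy (bACC), which is the arithmetic mean of the same per-class recalls.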

## Imports

In [1]:
import numpy as np
import torch

## Configuration

In [2]:
from omegaconf import OmegaConf

# All evaluation settings live in one OmegaConf object so that the cells below
# can read them as `CONFIG.<name>` attributes.
CONFIG = OmegaConf.create(dict(
    # --- Model: forwarded to `initialize_model` ---
    model="WideResNet-28-10-torchdistill",
    dropout_rate=0.3,
    num_classes=10,
    noise_bn_option="DARBN",  # looked up by name in `NoiseBnOption`
    # --- Dataset / data loader ---
    image_size=32,
    normalize_mean=[0.4914, 0.4822, 0.4465],  # passed to `InputNormalize`
    normalize_std=[0.2023, 0.1994, 0.2010],  # passed to `InputNormalize`
    valid_transform_reprs=["ToTensor()"],
    batch_size=128,
    num_workers=8,
    enable_pin_memory=True,

    # --- To change for each evaluation run ---
    checkpoint_filename="CIFAR10IR100-open__epoch_199.pt",
    valid_dataset="CIFAR10",  # the dataset cell accepts "CIFAR10" or "CIFAR100"
))

In [3]:
# Evaluate on the GPU when CUDA is available, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

## Download checkpoint

In [4]:
import os
import gdown
from storage import CHECKPOINT_URLS

# Local path of the checkpoint; it is downloaded only when not already present,
# so re-running this cell is cheap.
checkpoint_filepath = f"checkpoints/{CONFIG.checkpoint_filename}"
if not os.path.exists(checkpoint_filepath):
    # Make sure the target directory exists first: on a fresh checkout
    # `checkpoints/` may be missing and the download would have nowhere to write.
    os.makedirs(os.path.dirname(checkpoint_filepath), exist_ok=True)
    # A filename without an entry in CHECKPOINT_URLS raises KeyError here.
    gdown.download(CHECKPOINT_URLS[CONFIG.checkpoint_filename], checkpoint_filepath, quiet=False)

## Initialize Model

In [5]:
from initializers import initialize_model
from initializers import InputNormalize
from models.noise_bn_option import NoiseBnOption

# Build the network from the architecture settings in CONFIG and move it to
# the evaluation device in the same expression.
net = initialize_model(
    model_name=CONFIG.model,
    num_classes=CONFIG.num_classes,
    dropout_rate=CONFIG.dropout_rate,
    noise_bn_option=NoiseBnOption[CONFIG.noise_bn_option],
).to(device)

# Input normalization is kept as a separate module, applied to every batch
# right before it is fed to `net`.
mean_tensor = torch.Tensor(CONFIG.normalize_mean).to(device)
std_tensor = torch.Tensor(CONFIG.normalize_std).to(device)
normalizer = InputNormalize(mean_tensor, std_tensor).to(device)

In [6]:
from checkpointing import load_checkpoint

# Load the downloaded checkpoint file into `net`; no optimizer is passed
# because this notebook only evaluates.
load_checkpoint(
    net,
    checkpoint_filepath=checkpoint_filepath,
    optimizer=None,
)
# Switch the network to evaluation mode before generating predictions.
net = net.eval()

## Initialize Dataset

In [7]:
from torch.utils.data import DataLoader
from torchvision.datasets import CIFAR10, CIFAR100
from initializers import initialize_transforms

DATA_ROOT = "./data"
valid_transform = initialize_transforms(CONFIG.valid_transform_reprs)

# Map the configured dataset name to its torchvision class; any other name is
# rejected up front.
dataset_classes = {"CIFAR10": CIFAR10, "CIFAR100": CIFAR100}
if CONFIG.valid_dataset not in dataset_classes:
    raise ValueError("CONFIG.valid_dataset should either be CIFAR10 or CIFAR100")

# `train=False` selects the held-out split; it is downloaded into DATA_ROOT
# when it is not already there.
valid_dataset = dataset_classes[CONFIG.valid_dataset](
    root=DATA_ROOT,
    train=False,
    transform=valid_transform,
    download=True,
)

# Shuffling is left at the DataLoader default; loader settings come from CONFIG.
valid_loader = DataLoader(
    valid_dataset,
    batch_size=CONFIG.batch_size,
    num_workers=CONFIG.num_workers,
    pin_memory=CONFIG.enable_pin_memory,
)

Files already downloaded and verified


## Generate Predictions

In [8]:
# Run the model over the whole validation set, collecting the raw outputs and
# the ground-truth labels batch by batch.
valid_outputs = []
valid_labels = []
# Evaluation only: disable autograd so no computation graph is recorded for
# each forward pass (less memory, faster inference, same outputs).
with torch.no_grad():
    for inputs, labels in valid_loader:
        inputs = normalizer(inputs.float().to(device))
        outputs = net(inputs)

        valid_outputs.extend(outputs.cpu().tolist())
        # Labels are only collected here, so they are not moved to `device`.
        valid_labels.extend(labels.tolist())

valid_outputs = np.array(valid_outputs)
valid_labels = np.array(valid_labels)
# Predicted class = index of the largest output along the class axis.
valid_preds = np.argmax(valid_outputs, axis=1)

## Compute bACC and GM

In [9]:
from sklearn.metrics import recall_score
from scipy.stats import gmean

# Per-class recall: `average=None` yields one value per listed class id
# instead of a single aggregate.
class_ids = np.arange(CONFIG.num_classes)
recalls = recall_score(
    y_true=valid_labels,
    y_pred=valid_preds,
    labels=class_ids,
    average=None,
)

# bACC is the arithmetic mean of the recalls, GM is their geometric mean.
bacc = recalls.mean()
gm = gmean(recalls)
print(f"bACC: {bacc * 100:.2f}, GM: {gm * 100:.2f}")

bACC: 85.04, GM: 84.79
