# Evaluate models

In [2]:
import os
import torch
import numpy as np
import matplotlib.pyplot as plt
from sklearn.metrics import accuracy_score, f1_score, confusion_matrix
from torch.nn.functional import softmax
from sklearn.calibration import calibration_curve
import sys
# Replace argv with a single empty program name.
# NOTE(review): presumably this keeps the notebook kernel's own command-line
# arguments away from the `parse_args()` call made later in this notebook —
# confirm before removing.
sys.argv = ['']


## Utility functions

In [3]:
def compute_ece(probs, labels, n_bins=15):
    """Compute the Expected Calibration Error (ECE).

    Samples are grouped into ``n_bins`` equal-width confidence bins over
    ``(0, 1]``. For each non-empty bin the absolute gap between the mean
    confidence and the accuracy is weighted by the fraction of samples that
    fall into the bin; the weighted gaps are summed.

    Args:
        probs: 2-D torch tensor of per-class probabilities; the class axis is
            dimension 1.
        labels: Tensor of ground-truth class indices, comparable element-wise
            with ``probs.argmax(1)``.
        n_bins: Number of equal-width confidence bins.

    Returns:
        The ECE as a Python float. Returns ``0.0`` when no sample falls into
        any bin (e.g. empty input).
    """
    # `Tensor.max(dim)` returns a (values, indices) pair rather than a tensor,
    # so the confidence values have to be extracted explicitly.
    confidences = probs.max(1).values
    predictions = probs.argmax(1)
    accuracies = predictions == labels

    bin_boundaries = torch.linspace(0, 1, n_bins + 1)
    ece = torch.zeros(1, device=probs.device)

    for lower, upper in zip(bin_boundaries[:-1], bin_boundaries[1:]):
        # Bins are left-open / right-closed: a confidence of exactly 0 lands
        # in no bin, while a confidence of exactly 1 lands in the last one.
        in_bin = (confidences > lower) & (confidences <= upper)
        if in_bin.any():
            bin_accuracy = accuracies[in_bin].float().mean()
            bin_confidence = confidences[in_bin].mean()
            # Weight the calibration gap by the share of samples in this bin.
            ece += (bin_confidence - bin_accuracy).abs() * in_bin.float().mean()

    return ece.item()

In [4]:
def plot_confusion(y_true, y_pred, labels, title):
    """Render an annotated confusion matrix and show the figure.

    Args:
        y_true: Ground-truth class labels.
        y_pred: Predicted class labels.
        labels: Tick labels used for both axes, in matrix row/column order.
        title: Title placed on the plot.

    NOTE(review): the matrix is built without an explicit ``labels=``
    argument, so its size follows the classes present in the data; it is
    assumed to match ``len(labels)`` — confirm for sparse test sets.
    """
    matrix = confusion_matrix(y_true, y_pred)
    _fig, axis = plt.subplots(figsize=(8, 6))
    image = axis.matshow(matrix, cmap=plt.cm.Blues)
    plt.colorbar(image)

    tick_positions = np.arange(len(labels))
    axis.set_xticks(tick_positions)
    axis.set_yticks(tick_positions)
    axis.set_xticklabels(labels, rotation=45, ha='left')
    axis.set_yticklabels(labels)

    # Cells darker than half the maximum count get white text for contrast.
    # (Guarded so an empty matrix is never reduced, as in the per-cell form.)
    half_max = matrix.max() / 2 if matrix.size else 0
    for (row, col), count in np.ndenumerate(matrix):
        text_color = 'white' if count > half_max else 'black'
        axis.text(col, row, str(count), va='center', ha='center', color=text_color)

    axis.set_xlabel('Predicted')
    axis.set_ylabel('True')
    axis.set_title(title)
    plt.tight_layout()
    plt.show()

In [5]:
def compute_accuracy_f1(y_true, y_pred):
    """Return ``(accuracy, macro-averaged F1)`` for the given predictions.

    Args:
        y_true: Ground-truth class labels.
        y_pred: Predicted class labels.
    """
    return (
        accuracy_score(y_true, y_pred),
        f1_score(y_true, y_pred, average='macro'),
    )

## MNIST Addition

In [None]:
from expressive.util import get_device
from torch.utils.data import DataLoader
import torch
import wandb

from expressive.experiments.mnist_op.absorbing_mnist import (
    MNISTAddProblem,
    create_mnistadd,
    vector_to_base10,
)
from expressive.args import MNISTAbsorbingArguments
from expressive.experiments.mnist_op.data import (
    create_nary_multidigit_operation,
    get_mnist_op_dataloaders,
)
import math

from expressive.methods.logger import (
    TestLog,
    TrainingLog,
    TrainLogger,
    TestLogger,
)

In [None]:
def test(
    val_loader: DataLoader,
    test_logger: TestLogger,
    model: MNISTAddProblem,
    device: torch.device,
    args,
):
    """Run ``model.evaluate`` over a data loader and push the results.

    Each batch is a sequence whose first ``2 * args.N`` entries are
    concatenated along dimension 1 to form the model input, whose last entry
    is the label, and whose remaining entries are stacked along a new last
    dimension as the per-digit labels.

    Args:
        val_loader: Loader yielding batches in the layout described above.
        test_logger: Logger whose ``log`` attribute is handed to
            ``model.evaluate`` and whose ``push`` is called once at the end.
        model: Model exposing ``evaluate``.
        device: Device the input and label tensors are moved to.
        args: Run configuration; ``N`` and ``DEBUG`` are read here.
    """
    n_image_entries = 2 * args.N
    for batch in val_loader:
        mn_digits = batch[:n_image_entries]
        label_digits = batch[n_image_entries:-1]
        label = batch[-1]

        x = torch.cat(mn_digits, dim=1)
        model.evaluate(
            x.to(device),
            vector_to_base10(label.to(device), args.N + 1),
            torch.stack(label_digits, dim=-1).to(device),
            test_logger.log,
        )
        if args.DEBUG:
            # Debug runs stop after a single batch.
            break
    # NOTE(review): the full loader length is pushed even when DEBUG stopped
    # after one batch — confirm this is what the logger expects.
    test_logger.push(len(val_loader))


In [7]:
def evaluate_model(model, dataloader, device):
    """Collect task and concept predictions of ``model`` over ``dataloader``.

    Args:
        model: Callable returning ``(preds, concept_preds)`` for an input
            batch; the predicted class is the argmax of ``preds`` over
            dimension 1 and the predicted concept the argmax of
            ``concept_preds`` over the last dimension.
        dataloader: Iterable yielding ``(x, concepts, y)`` tensor triples.
        device: Device the inputs are moved to before the forward pass.

    Returns:
        ``(y_true, y_pred, c_true, c_pred)`` as NumPy arrays; the two concept
        arrays are flattened to one dimension.

    Note:
        The model's train/eval mode is not changed here; callers are expected
        to set it beforehand.
    """
    all_preds = []
    all_labels = []
    all_concepts = []
    all_concept_labels = []

    with torch.no_grad():
        for x, concepts, y in dataloader:
            preds, concept_preds = model(x.to(device))

            all_preds.append(preds.argmax(dim=1).cpu())
            all_concepts.append(concept_preds.argmax(dim=-1).cpu())

            # Targets are only gathered for the CPU-side metrics, so they are
            # not shipped to `device` and back.
            all_labels.append(y.cpu())
            all_concept_labels.append(concepts.cpu())

    y_true = torch.cat(all_labels)
    y_pred = torch.cat(all_preds)

    c_true = torch.cat(all_concept_labels).view(-1)
    c_pred = torch.cat(all_concepts).view(-1)

    return y_true.numpy(), y_pred.numpy(), c_true.numpy(), c_pred.numpy()

In [None]:
# Parse the run configuration and build the model on the selected device.
args = MNISTAbsorbingArguments(explicit_bool=True).parse_args()

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = create_mnistadd(args).to(device)

# Two operands of `args.N` digits each.
arity = 2
digits_per_number = args.N
n_operands = arity * digits_per_number

# Map the configured operation name to its reduction; unknown names yield None.
# NOTE(review): a None operation is passed on unchanged — confirm whether an
# unsupported `args.op` should be rejected here instead.
bin_op = {"sum": sum, "product": math.prod}.get(args.op)
op = create_nary_multidigit_operation(arity, bin_op)

# In test mode all 60000 training images are used and no validation split is made.
train_size = 60000 if args.test else 50000
val_size = 0 if args.test else 10000
train_loader, val_loader, test_loader = get_mnist_op_dataloaders(
    count_train=int(train_size / n_operands),
    count_val=int(val_size / n_operands),
    count_test=int(10000 / n_operands),
    batch_size=args.batch_size,
    n_operands=n_operands,
    op=op,
    shuffle=True,
)

# Restore the trained weights; the model already lives on `device`, so no
# further `.to(device)` is needed after loading.
model.load_state_dict(torch.load("models/gv2nx5am/model_12.pth", map_location=device))
model.eval()

test_logger = TestLogger(TestLog, args, "test")
# `test` requires the parsed arguments as its fifth parameter.
test(test_loader, test_logger, model, device, args)

120
