## PEAR Training Experiment

In this notebook we train two MLPs on the California Housing data with different PEAR disagreement-loss hyperparameters ($\lambda = 0.5$ and $\lambda = 0$) and compare their accuracy and explanation-agreement metrics.


In [1]:
import matplotlib.pyplot as plt
import numpy as np
import seaborn as sns
import torch
from torch.optim import AdamW

import pear

In [2]:
# Define constants
dataset = "californiahousing"
batch_size = 64
model_cfg = {"name": "mlp",
             "width": 100,
             "depth": 3}
lr = 5e-4
weight_decay = 2e-4
# Names of the two explainers whose attributions the PEAR disagreement loss compares.
explainers = ["vanilla_gradients", "integrated_gradients"]
# Passed as the third argument to pear.DisagreementLoss below.
disagreement_mu = 0.5

# Seed the NumPy and PyTorch RNGs so that model initialisation and any
# random data ordering are repeatable across Restart & Run All.
seed = 0
np.random.seed(seed)
torch.manual_seed(seed)

In [3]:
loader_train, loader_test, num_classes = pear.get_data(dataset,
                                                       batch_size,
                                                       data_path="../pear/datasets")
# Rows of `.data` are samples and columns are features.
input_dim = loader_train.dataset.data.shape[1]
num_training_data = loader_train.dataset.data.shape[0]
num_testing_data = loader_test.dataset.data.shape[0]

# Count training samples per class label. torch.as_tensor accepts a list,
# an array or an existing tensor without the copy-construct warning that
# torch.tensor emits for tensor inputs.
class_labels, class_counts = torch.unique(torch.as_tensor(loader_train.dataset.targets),
                                          return_counts=True)
class_count_summary = dict(zip(class_labels.tolist(), class_counts.tolist()))
print(f"{dataset} dataset with {num_training_data} training samples and {num_testing_data} testing samples"
      f" and {input_dim} features and {num_classes} classes "
      f"(training samples per class: {class_count_summary}).")

californiahousing dataset with 15475 training samples and 5159 testing samples and 8 features and (tensor([0, 1]), tensor([7686, 7789])) classes.


In [4]:
# Model trained with the PEAR disagreement term weighted by lambda = 0.5.
epochs = 50
disagreement_lambda = 0.5

model_pear = pear.get_model(model_cfg, input_dim, num_classes)
# Model size = total number of scalar elements across all parameter tensors.
parameter_sizes = [tensor.numel() for tensor in model_pear.parameters()]
pytorch_total_params = sum(parameter_sizes)
params_in_thousands = pytorch_total_params / 1e3
print(f"This {model_cfg['name']} has {params_in_thousands:0.3f} thousand parameters.")

This mlp has 21.302 thousand parameters.


In [5]:
# get an optimizer
params = model_pear.parameters()
optim = AdamW(params, lr=lr, weight_decay=weight_decay)

# create two explainers for the loss
explainer_a = pear.get_explainer(explainers[0], model_pear, torch.tensor(loader_train.dataset.data))
explainer_b = pear.get_explainer(explainers[1], model_pear, torch.tensor(loader_train.dataset.data))
disagreement_loss_fn = pear.DisagreementLoss(explainer_a, explainer_b, disagreement_mu)

In [None]:
for epoch in range(epochs):
    # One training pass with the task loss plus the lambda-weighted disagreement loss.
    _ = model_pear.train_loop(trainloader=loader_train,
                              disagreement_lambda=disagreement_lambda,
                              optimizer=optim,
                              task_loss_fn=torch.nn.CrossEntropyLoss(),
                              disagreement_loss_fn=disagreement_loss_fn)

    # Evaluate on the train split first, then the test split.
    evaluation_on_train_data, evaluation_on_test_data = (
        model_pear.evaluate_balanced(split_loader,
                                     task_loss_fn=torch.nn.CrossEntropyLoss(),
                                     disagreement_loss_fn=disagreement_loss_fn)
        for split_loader in (loader_train, loader_test)
    )

    # Balanced accuracy: mean of the per-class accuracies for classes 0 and 1.
    train_bal_acc = (evaluation_on_train_data['acc_0'] + evaluation_on_train_data['acc_1']) / 2
    test_bal_acc = (evaluation_on_test_data['acc_0'] + evaluation_on_test_data['acc_1']) / 2

    print(f"epoch {epoch:2d} | "
          f"task loss {evaluation_on_train_data['task_loss']:.4f} | "
          f"disagree loss {evaluation_on_train_data['disagreement_loss']:.4f} | "
          f"train bal acc {train_bal_acc:.2f} | "
          f"test bal acc {test_bal_acc:.2f} | "
          )

                                                  

epoch  0 | task loss 0.4462 | disagree loss 0.0035 | train bal acc 78.04 | test bal acc 78.80 | 


                                                  

epoch  1 | task loss 0.4118 | disagree loss 0.0034 | train bal acc 80.11 | test bal acc 80.73 | 


                                                  

epoch  2 | task loss 0.4014 | disagree loss 0.0028 | train bal acc 80.90 | test bal acc 81.37 | 


                                                  

epoch  3 | task loss 0.4015 | disagree loss 0.0030 | train bal acc 80.87 | test bal acc 80.28 | 


                                                  

epoch  4 | task loss 0.3946 | disagree loss 0.0034 | train bal acc 81.39 | test bal acc 81.65 | 


                                                  

epoch  5 | task loss 0.4092 | disagree loss 0.0024 | train bal acc 80.65 | test bal acc 81.37 | 


                                                  

epoch  6 | task loss 0.3920 | disagree loss 0.0029 | train bal acc 81.39 | test bal acc 81.52 | 


                                                  

epoch  7 | task loss 0.3903 | disagree loss 0.0027 | train bal acc 81.50 | test bal acc 82.12 | 


                                                  

epoch  8 | task loss 0.4034 | disagree loss 0.0038 | train bal acc 80.65 | test bal acc 79.92 | 


                                                  

epoch  9 | task loss 0.3919 | disagree loss 0.0032 | train bal acc 81.58 | test bal acc 82.01 | 


                                                  

epoch 10 | task loss 0.3876 | disagree loss 0.0030 | train bal acc 81.80 | test bal acc 82.00 | 


                                                  

epoch 11 | task loss 0.3899 | disagree loss 0.0029 | train bal acc 81.72 | test bal acc 81.59 | 


                                                 

epoch 12 | task loss 0.3850 | disagree loss 0.0036 | train bal acc 81.97 | test bal acc 82.20 | 


                                                 

epoch 13 | task loss 0.3875 | disagree loss 0.0046 | train bal acc 81.88 | test bal acc 81.68 | 


                                                 

epoch 14 | task loss 0.3836 | disagree loss 0.0041 | train bal acc 82.17 | test bal acc 82.33 | 


  9%|▉         | 22/242 [00:00<00:02, 102.29it/s]

In [None]:
# Baseline: same architecture, trained with disagreement_lambda = 0.0.
# NOTE(review): this model trains for 30 epochs vs. 50 for the PEAR model;
# confirm the mismatch is intentional, since it confounds the comparison.
epochs = 30
disagreement_lambda = 0.0

model_vanilla = pear.get_model(model_cfg, input_dim, num_classes)
pytorch_total_params = 0
for weight_tensor in model_vanilla.parameters():
    pytorch_total_params += weight_tensor.numel()
print(f"This {model_cfg['name']} has {pytorch_total_params / 1e3:0.3f} thousand parameters.")

In [None]:
# get an optimizer
params = model_vanilla.parameters()
optim = AdamW(params, lr=lr, weight_decay=weight_decay)

# create two explainers for the loss, bound to the vanilla model
explainer_a = pear.get_explainer(explainers[0], model_vanilla, torch.tensor(loader_train.dataset.data))
explainer_b = pear.get_explainer(explainers[1], model_vanilla, torch.tensor(loader_train.dataset.data))
# Rebuild the disagreement loss from the vanilla model's explainers. Without
# this, the cells below keep using the loss built on model_pear's explainers,
# so the "disagree loss" they log would presumably describe the wrong model.
disagreement_loss_fn = pear.DisagreementLoss(explainer_a, explainer_b, disagreement_mu)

In [None]:
for epoch in range(epochs):
    # One training pass; with disagreement_lambda = 0.0 this is the baseline run.
    _ = model_vanilla.train_loop(trainloader=loader_train,
                                 disagreement_lambda=disagreement_lambda,
                                 optimizer=optim,
                                 task_loss_fn=torch.nn.CrossEntropyLoss(),
                                 disagreement_loss_fn=disagreement_loss_fn)

    # Evaluate on the train split first, then the test split.
    evaluation_on_train_data, evaluation_on_test_data = (
        model_vanilla.evaluate_balanced(split_loader,
                                        task_loss_fn=torch.nn.CrossEntropyLoss(),
                                        disagreement_loss_fn=disagreement_loss_fn)
        for split_loader in (loader_train, loader_test)
    )

    # Balanced accuracy: mean of the per-class accuracies for classes 0 and 1.
    train_bal_acc = (evaluation_on_train_data['acc_0'] + evaluation_on_train_data['acc_1']) / 2
    test_bal_acc = (evaluation_on_test_data['acc_0'] + evaluation_on_test_data['acc_1']) / 2

    print(f"epoch {epoch:2d} | "
          f"task loss {evaluation_on_train_data['task_loss']:.4f} | "
          f"disagree loss {evaluation_on_train_data['disagreement_loss']:.4f} | "
          f"train bal acc {train_bal_acc:.2f} | "
          f"test bal acc {test_bal_acc:.2f} | "
          )

In [None]:
metric = "pairwise_rank"
# Explainer-agreement tables for each model (PEAR first, then vanilla), k = 5.
red_grid_data_pear, red_grid_data_vanilla = (
    pear.disagreement_matrices(trained_model, loader_train, loader_test, k=5, metric=metric)
    for trained_model in (model_pear, model_vanilla)
)

# Row/column position of each explainer in the agreement heatmaps.
explainer_indices = dict(vanilla_gradients=2,
                         integrated_gradients=4,
                         shap=1,
                         lime=0,
                         input_x_gradient=3,
                         smooth_grad=5)

# Every unordered pair of explainers (self-pairs included), named
# "<first>_v_<second>" with the two names in alphabetical order. These are the
# keys used to look up scores in the disagreement tables.
explainer_pairs = [
    f"{first}_v_{second}"
    for position, first in enumerate(sorted(explainer_indices))
    for second in sorted(explainer_indices)[position:]
]

In [None]:
# Human-readable name for each agreement metric, used in plot titles and labels.
metric_strs = dict(
    feature_agreement="Feature Agreement",
    rank_agreement="Rank Agreement",
    sign_agreement="Sign Agreement",
    signed_rank_agreement="Signed Rank Agreement",
    rank_correlation="Rank Correlation",
    pairwise_rank="Pairwise Rank Agreement",
)

In [None]:
cmap = sns.color_palette("light:darkred", as_cmap=True)
metric = "pairwise_rank"
fs = 12
tables = [red_grid_data_pear, red_grid_data_vanilla]
lams = [0.5, 0.0]
# Shown in the title only. NOTE(review): keep in sync with the k=5 passed to
# pear.disagreement_matrices when the tables were computed.
top_k = 5

# Short axis label per explainer. Tick labels are ordered by explainer_indices
# so they always line up with the matrix rows/columns they annotate.
explainer_labels = {"lime": "LIME",
                    "shap": "SHAP",
                    "vanilla_gradients": "Grad",
                    "input_x_gradient": "Grad*\nInput",
                    "integrated_gradients": "IntGrad",
                    "smooth_grad": "Smooth\nGrad"}
num_explainers = len(explainer_indices)
tick_labels = [explainer_labels[name] for name in sorted(explainer_indices, key=explainer_indices.get)]
tick_positions = np.arange(num_explainers) + 0.5  # heatmap cell centres

for table, lam in zip(tables, lams):
    # Each pair is listed once, so mirror its score into both (i, j) and (j, i).
    to_plot = np.zeros((num_explainers, num_explainers))
    for explainer_pair in explainer_pairs:
        first, second = explainer_pair.split("_v_")
        row, col = explainer_indices[first], explainer_indices[second]
        to_plot[row, col] = to_plot[col, row] = table[metric][explainer_pair]

    fig, ax = plt.subplots(figsize=(5, 4.8))
    # Colour range: [-1, 1] for correlation metrics, [0, 1] otherwise.
    # "\\lambda" is escaped because "\l" is an invalid string escape
    # (SyntaxWarning on Python >= 3.12).
    if "correlation" in metric:
        sns.heatmap(to_plot, vmin=-1, vmax=1, cmap=cmap, ax=ax, annot=True)
        title = f"California Housing Data\n{metric_strs[metric]}\n$\\lambda$ = {lam}"
    else:
        sns.heatmap(to_plot, vmin=0, vmax=1, cmap=cmap, ax=ax, annot=True)
        title = f"California Housing Data\n{metric_strs[metric]}\n$\\lambda$ = {lam} and k = {top_k}"
    ax.set_title(title, fontsize=fs)

    ax.set_xticks(tick_positions, tick_labels, rotation=0, fontsize=fs)
    ax.set_yticks(tick_positions, tick_labels, rotation=0, fontsize=fs)
    plt.tight_layout()
plt.show()

In [None]:
fig, ax = plt.subplots(figsize=(12, 9))
title = "California Housing Data\nLIME vs. SHAP"
fs = 48
ax.set_title(title, fontsize=fs)

# One point per model: x = test accuracy from evaluate()["acc"],
# y = LIME-vs-SHAP score under `metric` from that model's disagreement table.
vanilla_acc = model_vanilla.evaluate(loader_test,
                                     task_loss_fn=torch.nn.CrossEntropyLoss(),
                                     disagreement_loss_fn=disagreement_loss_fn)["acc"]
vanilla_agreement = red_grid_data_vanilla[metric]["lime_v_shap"]
ax.plot([vanilla_acc], [vanilla_agreement],
        marker="*",
        markersize=22,
        linestyle="none",
        color="grey",
        label="Vanilla")

# Evaluate the PEAR model itself. This previously called model_vanilla.evaluate,
# which drew the "PEAR" point at the vanilla model's accuracy.
pear_acc = model_pear.evaluate(loader_test,
                               task_loss_fn=torch.nn.CrossEntropyLoss(),
                               disagreement_loss_fn=disagreement_loss_fn)["acc"]
pear_agreement = red_grid_data_pear[metric]["lime_v_shap"]
ax.plot([pear_acc], [pear_agreement],
        marker="o",
        markersize=22,
        linestyle="none",
        color="red",
        label="PEAR")

# Fixed axis window. NOTE(review): points outside it are clipped from view.
ax.set_xlim([81, 85.5])
x = np.arange(81, 86, 1)
ax.set_xticks(x)
ax.set_xticklabels(x, fontsize=fs, rotation=0)
ax.tick_params(axis='x', labelsize=fs - 8)
ax.set_xlabel("Test Accuracy (%)", fontsize=fs)

ax.set_ylim([0.7, 0.85])
y = np.arange(0.72, 0.85, 0.04)
ax.set_yticks(y)
ax.set_yticklabels([f"{i:0.2f}" for i in y], fontsize=fs, rotation=0)
ax.tick_params(axis='y', labelsize=fs - 8)
ax.set_ylabel(f"{metric_strs[metric]}", fontsize=fs)

ax.spines["right"].set_visible(False)
ax.spines["top"].set_visible(False)
ax.legend(loc='upper left', bbox_to_anchor=(0.2, 0.5), fontsize=fs / 2)
plt.tight_layout()
plt.show()