## Quickstart
In this notebook, we are going to showcase how to build a post-hoc explainer agreement neural network. We will break up
this quickstart into the following sections:
- Load Data
- Create Model
- Create Explainers
- Create Disagreement Loss
- Train Model
- Save Model
- Evaluate Model

In [1]:
import torch
from torch.optim import SGD, Adam, AdamW

import pear

# Define constants
dataset = "californiahousing"
batch_size = 64
model_cfg = {"name": "mlp",
             "width": 100,
             "depth": 3}
optimizer = "adamw"
lr = 5e-3
weight_decay = 2e-4
explainers = ["vanilla_gradients", "integrated_gradients"]
disagreement_lambda = 0.25
disagreement_mu = 0.5
epochs = 3

### Load Data

In [2]:
# Get data
loader_train, loader_test, num_classes = pear.get_data(dataset,
                                                       batch_size,
                                                       data_path="../pear/datasets")
input_dim = loader_train.dataset.data.shape[1]
num_training_data = loader_train.dataset.data.shape[0]
num_testing_data = loader_test.dataset.data.shape[0]
print(f"{dataset} dataset with {num_training_data} training samples and {num_testing_data} testing samples"
      f" and {input_dim} features and "
      f"{torch.unique(torch.tensor(loader_train.dataset.targets), return_counts=True)} classes.")

californiahousing dataset with 15475 training samples and 5159 testing samples and 8 features and (tensor([0, 1]), tensor([7686, 7789])) classes.


### Create Model

In [3]:
# Create a model
model = pear.get_model(model_cfg, input_dim, num_classes)
pytorch_total_params = sum(p.numel() for p in model.parameters())
print(f"This {model_cfg['name']} has {pytorch_total_params / 1e3:0.3f} thousand parameters.")

This mlp has 21.302 thousand parameters.


### Create Explainers (and optimizer)

In [4]:
# Get an optimizer
params = model.parameters()
if optimizer == "sgd":
    optim = SGD(params, lr=lr, weight_decay=weight_decay, momentum=0.9)
elif optimizer == "adam":
    optim = Adam(params, lr=lr, weight_decay=weight_decay)
elif optimizer == "adamw":
    optim = AdamW(params, lr=lr, weight_decay=weight_decay)

# Create two explainers for the loss
explainer_a = pear.get_explainer(explainers[0], model, torch.tensor(loader_train.dataset.data))
explainer_b = pear.get_explainer(explainers[1], model, torch.tensor(loader_train.dataset.data))

# Initialize LIME and SHAP for evaluation
lime = pear.get_explainer("lime", model, torch.tensor(loader_train.dataset.data))
shap = pear.get_explainer("shap", model, torch.tensor(loader_train.dataset.data))

### Create Disagreement Loss

In [5]:
# get a metric and define a loss function
disagreement_loss_fn = pear.DisagreementLoss(explainer_a, explainer_b, disagreement_mu)
current_disagreement_lambda = 0

### Train Model

In [6]:
# TODO: set up tqdm
# Training loop
for epoch in range(epochs):
    _ = model.train_loop(trainloader=loader_train,
                         disagreement_lambda=current_disagreement_lambda,
                         optimizer=optim,
                         task_loss_fn=torch.nn.CrossEntropyLoss(),
                         disagreement_loss_fn=disagreement_loss_fn)
    evaluation_on_train_data = model.evaluate_balanced(loader_train,
                                                       task_loss_fn=torch.nn.CrossEntropyLoss(),
                                                       disagreement_loss_fn=disagreement_loss_fn)
    evaluation_on_test_data = model.evaluate_balanced(loader_test,
                                                      task_loss_fn=torch.nn.CrossEntropyLoss(),
                                                      disagreement_loss_fn=disagreement_loss_fn)

    print(f"epoch {epoch:2d} | "
          f"task loss {evaluation_on_train_data['task_loss']:.4f} | "
          f"disagree loss {evaluation_on_train_data['disagreement_loss']:.4f} | "
          f"train bal acc {(evaluation_on_train_data['acc_0'] + evaluation_on_train_data['acc_1']) / 2:.2f} | "
          f"test bal acc {(evaluation_on_test_data['acc_0'] + evaluation_on_test_data['acc_1']) / 2:.2f} | "
          )

    disagreement_vals = pear.get_disagreement_values(loader_train, explainer_a, explainer_b, disagreement_k=5)
    eval_str = f""
    for k in disagreement_vals.keys():
        eval_str += f"{k}: {disagreement_vals[k]:.4f} | "
    print(eval_str)

    disagreement_vals = pear.get_disagreement_values(loader_test, explainer_a, explainer_b, disagreement_k=5)
    eval_str = f""
    for k in disagreement_vals.keys():
        eval_str += f"{k}: {disagreement_vals[k]:.4f} | "
    print(eval_str)

                                                                                                                   

epoch  0 | task loss 0.4101 | disagree loss 0.0043 | train bal acc 80.75 | test bal acc 80.04 | 
feature_agreement: 0.9988 | rank_agreement: 0.8640 | sign_agreement: 0.9988 | signed_rank_agreement: 0.8640 | rank_correlation: 0.9795 | pairwise_rank: 0.9714 | 
feature_agreement: 0.9988 | rank_agreement: 0.8568 | sign_agreement: 0.9988 | signed_rank_agreement: 0.8568 | rank_correlation: 0.9792 | pairwise_rank: 0.9709 | 


                                                                                                                   

epoch  1 | task loss 0.4054 | disagree loss 0.0043 | train bal acc 81.49 | test bal acc 82.28 | 
feature_agreement: 0.9844 | rank_agreement: 0.8476 | sign_agreement: 0.9844 | signed_rank_agreement: 0.8476 | rank_correlation: 0.9797 | pairwise_rank: 0.9734 | 
feature_agreement: 0.9848 | rank_agreement: 0.8216 | sign_agreement: 0.9848 | signed_rank_agreement: 0.8216 | rank_correlation: 0.9773 | pairwise_rank: 0.9706 | 


                                                                                                                   

epoch  2 | task loss 0.3840 | disagree loss 0.0104 | train bal acc 82.67 | test bal acc 83.30 | 
feature_agreement: 0.9696 | rank_agreement: 0.6968 | sign_agreement: 0.9696 | signed_rank_agreement: 0.6968 | rank_correlation: 0.9547 | pairwise_rank: 0.9457 | 
feature_agreement: 0.9704 | rank_agreement: 0.6912 | sign_agreement: 0.9704 | signed_rank_agreement: 0.6912 | rank_correlation: 0.9544 | pairwise_rank: 0.9450 | 


### Evaluate the Model

In [7]:
# Save the results from this training run
result = {"model": model_cfg["name"],
          "dataset": dataset,
          "task_loss_test_data": evaluation_on_test_data['task_loss'],
          "disagreement_loss_test_data": evaluation_on_test_data['disagreement_loss'],
          "task_loss_train_data": evaluation_on_train_data['task_loss'],
          "disagreement_loss_train_data": evaluation_on_train_data['disagreement_loss'],
          "train_acc": evaluation_on_train_data['acc'],
          "test_acc": evaluation_on_test_data['acc'],
          }

In [8]:
print(result)

{'model': 'mlp', 'dataset': 'californiahousing', 'task_loss_test_data': 0.373055636882782, 'disagreement_loss_test_data': 0.010415511259818282, 'task_loss_train_data': 0.3840358555316925, 'disagreement_loss_train_data': 0.010426596146510576, 'train_acc': 82.70759289176091, 'test_acc': 83.25256832719519}
