**TL;DR**

<font color="red">!!! Best viewed in nbviewer <a href="https://nbviewer.org/github/sjtu-xai-lab/aog/blob/main/src/demo_sentiment_classification.ipynb"><img src="https://img.shields.io/badge/Open%20in-nbviewer-orange.svg"></img></a></font>

This demo shows how to

1. compute the interactions between input variables (words/phrases in NLP tasks) encoded by a given DNN;
2. boost the conciseness of such an interaction-based explanation by learning the optimal baseline values;
3. re-organize this explanation into a hierarchical And-Or Graph (AOG).

Based on the demo code below, you can generate an interactive version of the AOG. When the user hovers over a pattern in the AOG, its corresponding parse graph is highlighted, showing the interaction effect of that pattern.

**Load modules**

In [1]:
import os
os.environ["CUDA_VISIBLE_DEVICES"] = "2"  # GPU to run on; adjust for your machine
from pprint import pprint
import argparse
import torch
import numpy as np
from tools.utils import set_seed
set_seed(0)

# packages for computing interactions
from harsanyi import AndHarsanyi, AndBaselineSparsifier
from harsanyi.init_baseline_values import pad_baseline_nlp, calc_word_emb_std
from harsanyi.remove_noisy import remove_noisy_greedy
from tools.metrics import eval_explain_ratio_given_indices
from tools.aggregate import aggregate_pattern_iterative
from graph.aog_utils import construct_AOG

# for models & datasets used
import models.nlp as models
from datasets import get_dataset
from setup_exp import init_dataset_model_settings


Set SEED: 0


**Load args**

In this example, we explain the interactions encoded by an LSTM model trained on the SST-2 dataset. The arguments for the dataset and the model are as follows.

In [2]:
args = argparse.Namespace(
    data_root = "../data/NLP",
    dataset = "sst2",
    arch = "lstm2_uni"
)
init_dataset_model_settings(args)

pprint(vars(args))

{'arch': 'lstm2_uni',
 'data_root': '../data/NLP',
 'dataset': 'sst2',
 'dataset_kwargs': {},
 'model_kwargs': {'embedding_dim': 100, 'hidden_dim': 256, 'output_dim': 1},
 'task': 'logistic_regression'}


**Load dataset**

In [3]:
sst2 = get_dataset(args.data_root, args.dataset, **args.dataset_kwargs)
TEXT, LABEL = sst2.get_fields()
train_loader, test_loader = sst2.get_dataloader(batch_size=1)

print(sst2)

The SST-2 dataset, #train: 67349 | #test: 872 | vocab_size: 13889 | label_to_id: defaultdict(None, {'1': 0, '0': 1})


**Load model**

In [4]:
net = models.__dict__[args.arch](
    vocab_size=len(TEXT.vocab),
    pad_idx=TEXT.vocab.stoi[TEXT.pad_token],
    **args.model_kwargs
).cuda()

ckpt_path = f"../saved-models/{args.arch}-{args.dataset}.pt"
net.load_state_dict(torch.load(ckpt_path, map_location=torch.device("cuda")))
net.to_eval_mode()

print(net)

LSTM(
  (embedding): Embedding(13889, 100, padding_idx=1)
  (lstm): LSTM(100, 256, num_layers=2, batch_first=True)
  (fc): Linear(in_features=256, out_features=1, bias=True)
  (dropout): Dropout(p=0.5, inplace=False)
)


**Load baseline value configs**

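We initialize the baseline value of each word to the embedding of the `[PAD]` token (an all-zero vector here, as the output below shows). The baseline value of each input variable is then constrained to a small range around its initialization: in each embedding dimension, it may deviate by at most `bound_threshold` (0.05) times the standard deviation of the word embeddings in that dimension.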

In [5]:
single_baseline_vector = pad_baseline_nlp(net, text_field=TEXT)
word_emb_std = calc_word_emb_std(net, text_field=TEXT)

bound_threshold = 0.05
min_baseline_vector = single_baseline_vector - bound_threshold * word_emb_std  # [1, emb_dim]
max_baseline_vector = single_baseline_vector + bound_threshold * word_emb_std  # [1, emb_dim]

print("initial baseline value:", single_baseline_vector)
print("lower bound:", min_baseline_vector)
print("upper bound:", max_baseline_vector)

initial baseline value: tensor([[0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
         0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
         0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
         0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
         0., 0., 0., 0.]], device='cuda:0')
lower bound: tensor([[-0.0533, -0.0510, -0.0529, -0.0514, -0.0505, -0.0501, -0.0491, -0.0510,
         -0.0526, -0.0523, -0.0486, -0.0538, -0.0519, -0.0513, -0.0499, -0.0526,
         -0.0516, -0.0518, -0.0520, -0.0515, -0.0508, -0.0519, -0.0557, -0.0510,
         -0.0522, -0.0517, -0.0535, -0.0518, -0.0510, -0.0547, -0.0530, -0.0501,
         -0.0505, -0.0519, -0.0519, -0.0529, -0.0514, -0.0507, -0.0527, -0.0518,
         -0.0513, -0.0532, -0.0502, -0.0516, -0.0545, -0.0497, -0.0542, -0.0532,
         -0.0

**Define the mask function**

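When computing $v(\mathcal{S})$, we mask the input variables in $\mathcal{N}\setminus\mathcal{S}$ and keep the input variables in $\mathcal{S}$. The function below specifies how each input variable is masked: the embedding of a masked word is replaced by its baseline value, and one masked input is built for each coalition in `S_list`.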

In [6]:
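def masked_input_fn(input_embs, baselines, S_list):
    """Build one masked input per coalition S in S_list.

    input_embs, baselines: [1, n_words, emb_dim] -> returns [len(S_list), n_words, emb_dim]
    """
    assert input_embs.shape[0] == 1
    assert baselines.shape[0] == 1
    batch_size = len(S_list)

    # replicate the input and the baseline once per coalition
    input_embs_batch = input_embs.expand(batch_size, input_embs.shape[1], input_embs.shape[2]).clone()
    baselines_batch = baselines.expand(batch_size, input_embs.shape[1], input_embs.shape[2]).clone()

    mask = torch.zeros_like(input_embs_batch, device=input_embs.device)

    for i, S in enumerate(S_list):
        mask[i, S] = 1.  # variables in S are NOT masked

    # keep the words in S, replace every other word with its baseline value
    return mask * input_embs_batch + (1 - mask) * baselines_batch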
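For reference, the same masking can be expressed with a boolean keep-mask and `torch.where`, which avoids the explicit `mask * x + (1 - mask) * b` arithmetic. The sketch below is a self-contained toy version; `all_coalitions` and `masked_inputs_vectorized` are hypothetical names, not part of the `harsanyi` package. It also enumerates all $2^n$ coalitions, which (judging from the `[128, 7]` mask shape printed later for the 7-token sentence) is the set of coalitions evaluated in this demo.

```python
import itertools
import torch

def all_coalitions(n):
    """Every subset S of {0, ..., n-1} as a list of indices (2**n subsets, empty set first)."""
    return [list(S) for k in range(n + 1) for S in itertools.combinations(range(n), k)]

def masked_inputs_vectorized(input_embs, baselines, S_list):
    """Toy equivalent of masked_input_fn: keep words in S, use baselines elsewhere.

    input_embs, baselines: [1, n_words, emb_dim] -> [len(S_list), n_words, emb_dim]
    """
    n_words = input_embs.shape[1]
    # keep[i, j] is True iff word j belongs to coalition S_list[i]
    keep = torch.tensor([[j in S for j in range(n_words)] for S in S_list],
                        dtype=torch.bool, device=input_embs.device)
    # broadcast [batch, n_words, 1] against [1, n_words, emb_dim]
    return torch.where(keep.unsqueeze(-1), input_embs, baselines)

# usage: 3 toy "words" with 2-dim embeddings and an all-zero baseline
x = torch.randn(1, 3, 2)
b = torch.zeros(1, 3, 2)
S_list = all_coalitions(3)
out = masked_inputs_vectorized(x, b, S_list)
print(out.shape)  # torch.Size([8, 3, 2])
```

The first row (empty coalition) is the pure baseline, and the last row (all words kept) is the original input.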

All the preparation work is now done, and we can extract the interaction patterns encoded by the LSTM model.

**The example sentence**

In this example, we use the sentence "it's just not very smart." The sentence must be tokenized and converted into token ids before being fed into the model.

In [7]:
sentence = "it's just not very smart."
words = TEXT.tokenize(sentence)
print("tokenized", words)
text_ids = [[TEXT.vocab.stoi[word] for word in words]]
input = torch.LongTensor(text_ids).cuda()
input_length = torch.LongTensor([len(a) for a in text_ids])

label = LABEL.vocab.stoi['0']  # The ground truth is negative sentiment
label2text = {
    '0': 'negative sentiment',
    '1': 'positive sentiment'
}

tokenized ['it', "'s", 'just', 'not', 'very', 'smart', '.']


Then, we generate the word embedding for each token.

In [8]:
with torch.no_grad():
    input_emb = net.get_emb(input)

print(input_emb.shape)

torch.Size([1, 7, 100])


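Next, we initialize the baseline value of each token, along with its lower and upper bounds, by repeating the single-token vectors defined above once for every token in the sentence.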

In [9]:
baseline_init = torch.cat([single_baseline_vector] * input_length.item(), dim=0).unsqueeze(0)
baseline_max = torch.cat([max_baseline_vector] * input_length.item(), dim=0).unsqueeze(0)
baseline_min = torch.cat([min_baseline_vector] * input_length.item(), dim=0).unsqueeze(0)

print(baseline_init.shape, baseline_max.shape, baseline_min.shape)

torch.Size([1, 7, 100]) torch.Size([1, 7, 100]) torch.Size([1, 7, 100])


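The inference score is defined on the output logit of the model. Since the model has a single output unit (`output_dim=1`), `selected_dim` selects which quantity is explained, depending on the ground-truth label: the logistic odds (`"logistic-odds"`) or their negation (`"neg-logistic-odds"`).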

In [10]:
if label == LABEL.vocab.stoi['0']:
    selected_dim = "neg-logistic-odds"
else:
    selected_dim = "logistic-odds"
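As a quick sanity check on the naming, and assuming that `"logistic-odds"` means the log-odds $\log\frac{p}{1-p}$ of the sigmoid output $p$ (an assumption based on the names, not on the `harsanyi` source), this quantity is exactly the logit, and `"neg-logistic-odds"` is its negation $\log\frac{1-p}{p}$. A minimal sketch with a made-up logit:

```python
import math

def sigmoid(z):
    return 1.0 / (1.0 + math.exp(-z))

def logistic_odds(p):
    # log-odds of probability p
    return math.log(p / (1.0 - p))

def neg_logistic_odds(p):
    # log-odds of the complementary event, i.e. log((1 - p) / p)
    return math.log((1.0 - p) / p)

logit = 1.3  # toy logit, not a value from the demo
p = sigmoid(logit)
print(round(logistic_odds(p), 6), round(neg_logistic_odds(p), 6))  # 1.3 -1.3
```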

Using the initialized baseline values, we compute the Harsanyi interactions as follows.

In [11]:
calculator = AndHarsanyi(
    model=lambda x: net.emb2out(x),
    selected_dim=selected_dim,
    x=input_emb,
    baseline=baseline_init,
    y=label,
    all_players=list(range(input_length.item())),
    mask_input_fn=masked_input_fn,
    verbose=0
)
calculator.attribute()
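For intuition, the Harsanyi interaction of a coalition $\mathcal{S}$ is usually defined as the Möbius transform of the value function, $I(\mathcal{S})=\sum_{\mathcal{T}\subseteq\mathcal{S}}(-1)^{|\mathcal{S}|-|\mathcal{T}|}\,v(\mathcal{T})$, and the interactions of all $2^{|\mathcal{N}|}$ coalitions sum to $v(\mathcal{N})$. The brute-force sketch below computes this for a hypothetical toy value function `v` on three players; it illustrates the definition and is not a reproduction of `AndHarsanyi`'s internals.

```python
import itertools

def harsanyi_interactions(v, n):
    """Brute-force Harsanyi interactions I(S) for all S subsets of {0, ..., n-1}.

    v: callable mapping a frozenset of player indices to a real number.
    """
    subsets = [frozenset(S) for k in range(n + 1)
               for S in itertools.combinations(range(n), k)]
    values = {S: v(S) for S in subsets}          # evaluate v once per coalition
    interactions = {}
    for S in subsets:
        total = 0.0
        for k in range(len(S) + 1):
            for T in itertools.combinations(sorted(S), k):
                total += (-1) ** (len(S) - k) * values[frozenset(T)]
        interactions[S] = total
    return interactions

def v(S):
    # toy value: a bias, an effect of player 0 alone, and an AND effect of players 1 and 2
    return 0.5 + 3.0 * (0 in S) + 2.0 * ({1, 2} <= S)

I = harsanyi_interactions(v, 3)
for S, val in I.items():
    if abs(val) > 1e-9:
        print(sorted(S), round(val, 6))
print("sum of interactions:", sum(I.values()), "| v(N):", v(frozenset({0, 1, 2})))
```

Only the bias, the effect of player 0, and the AND effect of players 1 and 2 survive; every other interaction is zero. With 7 tokens the same loop would visit $2^7=128$ coalitions.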

**Boosting conciseness by learning the optimal baseline values**

As discussed in [this paper](https://arxiv.org/abs/2105.10719), optimal baseline values can simplify the explanation of a deep model. Building on this idea, we learn baseline values that increase the sparsity of the interaction patterns and thus make the explanation more concise, while keeping each baseline value within the bounds `baseline_min` and `baseline_max` defined above.

In [12]:
sparsifier = AndBaselineSparsifier(
    calculator=calculator, loss="l1",
    baseline_min=baseline_min,
    baseline_max=baseline_max,
    baseline_lr=1e-3, niter=50
)
sparsifier.sparsify()
masks = sparsifier.get_masks()
interactions = sparsifier.get_interaction()
rewards = sparsifier.get_rewards()

print(masks.shape, interactions.shape, rewards.shape)

[0/50] loss: 356.6899
[10/50] loss: 201.7698
[20/50] loss: 187.5781
[30/50] loss: 182.4417
[40/50] loss: 180.0465
[49/50] loss: 178.9564
torch.Size([128, 7]) torch.Size([128]) torch.Size([128])
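The idea behind this step can be sketched in PyTorch: evaluate $v(\mathcal{S})$ for every coalition, turn the values into Harsanyi interactions with a fixed $\pm1$ matrix, penalize their L1 norm, and take projected gradient steps that clip the baseline back into its allowed box. This is only a sketch under stated assumptions: the toy scorer `f`, the helper names, and the plain SGD-plus-clipping scheme are hypothetical, and `AndBaselineSparsifier` may optimize differently (its exact loss is not shown in this demo).

```python
import itertools
import torch

def coalition_masks(n):
    """[2**n, n] boolean masks, one row per coalition S (True = player kept)."""
    rows = [[j in S for j in range(n)]
            for k in range(n + 1) for S in itertools.combinations(range(n), k)]
    return torch.tensor(rows, dtype=torch.bool)

def harsanyi_matrix(masks):
    """M with M[s, t] = (-1)^(|S|-|T|) if T is a subset of S, else 0, so that I = M @ v."""
    m = masks.long()
    subset = (m.unsqueeze(0) <= m.unsqueeze(1)).all(dim=-1)   # subset[s, t]: T within S
    sizes = m.sum(dim=-1)
    parity = (sizes.unsqueeze(1) - sizes.unsqueeze(0)).remainder(2)
    sign = 1 - 2 * parity                                     # +1 or -1
    return (subset.long() * sign).float()

def l1_of_interactions(f, x, b, keep, M):
    # keep words in S, baseline elsewhere: [2**n, n, d]
    x_masked = torch.where(keep, x, b)
    return (M @ f(x_masked)).abs().sum()

def learn_baseline(f, x, b_init, b_min, b_max, lr=1e-2, n_iter=100):
    """Projected gradient descent on the L1 norm of all Harsanyi interactions."""
    masks = coalition_masks(x.shape[0])
    M = harsanyi_matrix(masks)
    keep = masks.unsqueeze(-1)               # [2**n, n, 1], broadcast over embedding dims
    b = b_init.clone().requires_grad_(True)
    opt = torch.optim.SGD([b], lr=lr)
    for _ in range(n_iter):
        loss = l1_of_interactions(f, x, b, keep, M)
        opt.zero_grad()
        loss.backward()
        opt.step()
        with torch.no_grad():                # project back into [b_min, b_max]
            b.copy_(torch.maximum(torch.minimum(b, b_max), b_min))
    b = b.detach()
    with torch.no_grad():
        final = l1_of_interactions(f, x, b, keep, M).item()
    return b, final

# toy usage: 4 "words" with 3-dim embeddings and a made-up nonlinear scorer
torch.manual_seed(0)
w = torch.randn(3)
f = lambda z: torch.tanh(z @ w).prod(dim=-1)     # [2**n, n, d] -> [2**n]
x = torch.randn(4, 3)
b0 = torch.zeros(4, 3)
b_learned, final_l1 = learn_baseline(f, x, b0, b0 - 0.05, b0 + 0.05)
print(final_l1)
```

Clipping after every step mirrors the box constraint built from `bound_threshold` earlier; the box here uses a flat width of 0.05 instead of a per-dimension standard deviation.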





**Removing noisy patterns**

We use a greedy strategy to remove noisy patterns from $2^{\mathcal{N}}=\{\mathcal{S}:\mathcal{S}\subseteq \mathcal{N}\}$ and keep only the salient patterns, which form the set $\Omega\subseteq 2^{\mathcal{N}}$. In the output below, 33 of the 128 patterns are retained, with a normalized error of 0.0098 and an explain ratio of 0.9921.

In [13]:
selected_indices = remove_noisy_greedy(
    rewards, interactions.clone(), masks,
    min_patterns=15, n_greedy=40,
    thres_square_error=0.01,
    thres_explain_ratio=0.95
)["final_retained"]
explain_ratio = eval_explain_ratio_given_indices(interactions, masks, selected_indices)

Removing noisy patterns -- # coalitions: 33 | normalized error: 0.0098 | explain ratio: 0.9921
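The exact procedure of `remove_noisy_greedy` is not shown here. As a rough illustration of the trade-off it controls, the hypothetical helper `keep_salient` below uses a simpler top-$k$ rule instead of a greedy search: it keeps the patterns with the largest $|I(\mathcal{S})|$, taking the smallest $k \ge$ `min_patterns` for which the kept patterns reproduce the total output $\sum_{\mathcal{S}} I(\mathcal{S})$ within a relative tolerance. This error measure is an assumption and need not match the normalized error or explain ratio reported above.

```python
import numpy as np

def keep_salient(interactions, min_patterns=15, rel_tol=0.01):
    """Hypothetical top-k stand-in for noisy-pattern removal (not remove_noisy_greedy).

    Returns the indices of the k patterns with the largest |I(S)|, for the smallest
    k >= min_patterns whose summed effect matches sum_S I(S) up to rel_tol.
    """
    interactions = np.asarray(interactions, dtype=float)
    order = np.argsort(-np.abs(interactions))      # most salient patterns first
    total = interactions.sum()
    scale = max(abs(total), 1e-12)                 # avoid dividing by zero
    for k in range(min(min_patterns, len(order)), len(order) + 1):
        kept = order[:k]
        if abs(total - interactions[kept].sum()) / scale <= rel_tol:
            return kept
    return order                                   # fallback: keep everything

# toy interactions (made up, not taken from the demo)
toy = [5.0, -4.0, 0.01, -0.02, 3.0]
print(sorted(keep_salient(toy, min_patterns=2).tolist()))  # [0, 1, 4]
```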



**Re-organizing the explanation into an AOG**

We identify common coalitions shared by different patterns and use them to rewrite the above explanation as an AOG.

In [14]:
merged_patterns, aggregated_concepts, _ = aggregate_pattern_iterative(
    patterns=masks[selected_indices],
    interactions=interactions[selected_indices],
)
single_features = np.vstack([np.eye(len(words)).astype(bool), merged_patterns])

# construct the AOG
aog = construct_AOG(
    attributes=words,
    single_features=single_features,
    concepts=aggregated_concepts,
    interactions=interactions[selected_indices]
)
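To make "common coalitions" concrete: a coalition that appears inside many salient patterns can become a single AND node that those patterns reuse. The hypothetical helper below, which is not how `aggregate_pattern_iterative` is implemented, finds the multi-word coalition contained in the most patterns, using pairwise intersections of patterns as candidates. The toy patterns are made up for illustration.

```python
from collections import Counter
from itertools import combinations

def most_shared_coalition(patterns, words):
    """Hypothetical helper: the multi-word coalition contained in the most patterns.

    patterns: boolean rows, one per pattern (like masks[selected_indices]);
    candidates are pairwise intersections of patterns with at least two words.
    """
    sets = [frozenset(i for i, on in enumerate(p) if on) for p in patterns]
    candidates = {a & b for a, b in combinations(sets, 2) if len(a & b) >= 2}
    if not candidates:
        return None, 0
    counts = Counter({c: sum(c <= s for s in sets) for c in candidates})
    best, n_patterns = counts.most_common(1)[0]
    return [words[i] for i in sorted(best)], n_patterns

toy_words = ['it', "'s", 'just', 'not', 'very', 'smart', '.']
toy_patterns = [
    [w in ('not', 'very', 'smart') for w in toy_words],
    [w in ("'s", 'not', 'very', 'smart') for w in toy_words],
    [w in ('just', 'not', 'very') for w in toy_words],
    [w in ('smart',) for w in toy_words],
]
print(most_shared_coalition(toy_patterns, toy_words))  # (['not', 'very'], 3)
```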

**Visualizing the AOG**

Finally, we generate an interactive version of the AOG. When the user **hovers over a pattern in the AOG**, its corresponding parse graph is highlighted, showing the interaction effect of that pattern.

In [15]:
from IPython.display import HTML, display
import mpld3

max_node = 12  # at most 12 pattern nodes per row

fig = aog.visualize(
    save_path="nlp_aog.html", figsize=(15, 7),
    renderer="networkx",
    n_row_interaction=int(np.ceil(len(selected_indices) / max_node)),
    title=f"output={interactions[selected_indices].sum():.2f} "
          f"| label: {label2text[LABEL.vocab.itos[label]]} "
          f"| R={100 * explain_ratio:.2f}%",
    show=True
)
# mpld3.display(fig) renders the figure too large, so embed it at 50% zoom instead
display(HTML('<div style="zoom:50%;">' + mpld3.fig_to_html(fig) + '</div>'))

- The leaf nodes (in gray) are input variables (words).
- The AND nodes in the second layer (in purple) indicate common coalitions shared by different patterns.
- The AND nodes in red/blue represent interaction patterns encoded by the model.
- The OR node (the root node) sums up the interaction effects of all patterns.

Some of the most salient patterns are listed below.

- Patterns with positive effects

    - I(smart)=6.57
    - I('s, smart)=5.83
    - I(very, smart)=3.72

- Patterns with negative effects

    - I(not, smart)=-13.48
    - I(not, very)=-12.83
    - I(just)=-7.28
    - I('s, not, smart)=-5.23