In [10]:
import sys
import os
sys.path.insert(0, "/home/farnoush/symbolicXAI")
from model.transformer import tiny_transformer_with_3_layers
from model.utils import load_pretrained_weights
from lrp.symbolic_xai import TransformerSymbXAI
import transformers
import torch
import dgl
import networkx as nx
from dgl.data import SSTDataset
from visualization.utils import create_text_heat_map
from itertools import product
from tqdm import tqdm
from lrp.queries import run_query
from datasets import load_dataset
from tqdm import tqdm
import numpy as np
from IPython.display import display, HTML

## Create the model

In [2]:
# Create model: build the 3-layer transformer via `tiny_transformer_with_3_layers`,
# pointing it at the HuggingFace SST-2 BERT model name.
hf_model_name = "textattack/bert-base-uncased-SST-2"
tiny_transformer = tiny_transformer_with_3_layers(pretrained_model_name_or_path=hf_model_name)

# Keep a direct handle on the embedding module; the explainer cell passes it
# to TransformerSymbXAI explicitly.
pretrained_embeddings = tiny_transformer.bert.embeddings

# Load pre-trained weights from the local 3-layer SST-2 checkpoint.
# NOTE(review): absolute, user-specific path — make configurable before sharing.
sst2_checkpoint = '/home/farnoush/fairness/sst2-3layer-model.pt'
load_pretrained_weights(tiny_transformer, sst2_checkpoint)

## Load SST-2 dataset

In [3]:
# Load the SST-2 dataset through HuggingFace `datasets` (reused from the local
# cache when it has already been downloaded).
dataset = load_dataset(path="sst2", name="default")

Found cached dataset sst2 (/home/farnoush/.cache/huggingface/datasets/sst2/default/2.0.0/9896208a8d85db057ac50c72282bcb8fe755accc671a57dd8059d4e130961ed5)


  0%|          | 0/3 [00:00<?, ?it/s]

## Evaluate model

In [10]:
# Top-1 accuracy of the 3-layer model on the SST-2 validation split.
tiny_transformer.eval()

tokenizer = transformers.BertTokenizer.from_pretrained("textattack/bert-base-uncased-SST-2")
UNK_IDX = tokenizer.unk_token_id  # an out-of-vocab token (not used in this cell)

validation = dataset['validation']
# Extract each column once. Indexing `validation['sentence']` inside the loop
# builds the whole column again on every iteration (quadratic work with
# `datasets` 2.x).
sentences = validation['sentence']
labels = validation['label']

acc = 0
# Inference only: no_grad avoids building an autograd graph for every sample.
with torch.no_grad():
    for sentence, target in tqdm(zip(sentences, labels), total=len(labels)):
        x = tokenizer(sentence, return_tensors="pt")
        logits = tiny_transformer(x)
        prediction = logits.argmax()

        if prediction == target:
            acc += 1

print("Top-1 acc: {}".format(acc / len(validation)))

100%|█████████████████████████████████████████| 872/872 [00:13<00:00, 65.17it/s]

Top-1 acc: 0.75





## Explain model

In [17]:
# Explain one validation sentence with symbolic-XAI subgraph relevances.
i = 402
sentence = dataset['validation']['sentence'][i]
# NOTE(review): presumably per-class weights (negative, positive) selecting the
# explained output — confirm against TransformerSymbXAI.
target = torch.tensor([-1, 1])

tokenizer = transformers.BertTokenizer.from_pretrained("textattack/bert-base-uncased-SST-2")

sample = tokenizer(sentence, return_tensors="pt")
words = tokenizer.convert_ids_to_tokens(sample['input_ids'].squeeze())

symb_xai_transformer = TransformerSymbXAI(
    model=tiny_transformer,
    embeddings=pretrained_embeddings,
    sample=sample,
    target=target
)


def show_subgraph_relevance(explainer, tokens, label, subgraph):
    """Print and visualise the relevance of a subgraph of token positions.

    explainer: the TransformerSymbXAI instance to query.
    tokens: the token strings rendered in the heat map.
    label: text printed before the relevance score.
    subgraph: token positions (negative values count from the end) passed to
        `subgraph_relevance`; the same positions are highlighted in the map.
    Returns the relevance reported by `subgraph_relevance`.
    """
    relevance = explainer.subgraph_relevance(subgraph=subgraph, from_walks=False)
    print("{}: {}".format(label, relevance))
    heat = np.zeros(len(tokens))
    heat[list(subgraph)] = relevance
    display(HTML(create_text_heat_map(tokens, heat.squeeze())))
    return relevance


# The token positions below are specific to sentence i = 402.
R1 = show_subgraph_relevance(symb_xai_transformer, words, "it, may, please", range(2, 5))
R2 = show_subgraph_relevance(symb_xai_transformer, words, "away, in, disgust", range(-5, -2))
R3 = show_subgraph_relevance(symb_xai_transformer, words, "however", [1])

it, may, please: 4.5922017097473145


away, in, disgust: -39.019615173339844


however: 0.5737371444702148
