# Load the Pretrained Model and the dataset
In this example we use `ernie-2.0-base-en` as the model and SST-2 as the dataset. More models can be found in the [PaddleNLP Model Zoo](https://paddlenlp.readthedocs.io/zh/latest/model_zoo/index.html#transformer).

PaddleNLP is required to run this notebook. It can be installed with:
```bash
pip install setuptools_scm 
pip install --upgrade paddlenlp
```

In [None]:
import paddle
import paddlenlp
from paddlenlp.transformers import ErnieTokenizer, ErnieForSequenceClassification

# Name of the pretrained checkpoint; the tokenizer and the classifier are both
# loaded from it so that they stay in sync.
MODEL_NAME = "ernie-2.0-base-en"

tokenizer = ErnieTokenizer.from_pretrained(MODEL_NAME)
# Two output classes, matching the two labels of the dataset used below.
model = ErnieForSequenceClassification.from_pretrained(MODEL_NAME, num_classes=2)

In [2]:
# Replace the attention dropout layer with nn.Dropout, using the helper from
# the local `assets.utils` module.
# NOTE(review): this rebinds `model` to the helper's return value; presumably
# re-running the cell on an already-converted model is harmless -- confirm
# against `layer_replacement` before executing it twice.
from assets.utils import layer_replacement
model = layer_replacement(model)

In [3]:
from paddlenlp.datasets import load_dataset
# Load the "sst-2" task of the "glue" dataset; the three requested splits are
# unpacked in the same order as they are listed.
sst2_splits = load_dataset("glue", name="sst-2", splits=["train", "dev", "test"])
train_ds, dev_ds, test_ds = sst2_splits

# Prepare the Model

## Train the model

In [5]:
# Train the model and save the result to save_dir.
# This only needs to run once.
# Total steps: ~1700 (1 epoch).

from assets.utils import training_model
# Plain string literal: the path has no placeholders, so no f-prefix is needed.
training_model(model, tokenizer, train_ds, dev_ds, save_dir='assets/sst-2-ernie-2.0-en')

dataset labels: ['0', '1']
dataset examples:
{'sentence': 'hide new secretions from the parental units ', 'labels': 0}
{'sentence': 'contains no wit , only labored gags ', 'labels': 0}
{'sentence': 'that loves its characters and communicates something rather beautiful about human nature ', 'labels': 1}
{'sentence': 'remains utterly satisfied to remain the same throughout ', 'labels': 0}
{'sentence': 'on the worst revenge-of-the-nerds clichés the filmmakers could dredge up ', 'labels': 0}
Training Starts:
global step 100, epoch: 1, batch: 100, loss: 0.37217, acc: 0.76719
global step 200, epoch: 1, batch: 200, loss: 0.32806, acc: 0.82625
global step 300, epoch: 1, batch: 300, loss: 0.35646, acc: 0.85000
global step 400, epoch: 1, batch: 400, loss: 0.24727, acc: 0.86250
global step 500, epoch: 1, batch: 500, loss: 0.11523, acc: 0.87169
global step 600, epoch: 1, batch: 600, loss: 0.11323, acc: 0.88021
global step 700, epoch: 1, batch: 700, loss: 0.34572, acc: 0.88714
global step 800, epoc

[32m[2022-07-06 19:58:38,642] [    INFO][0m - tokenizer config file saved in assets/sst-2-ernie-2.0-en/tokenizer_config.json[0m
[32m[2022-07-06 19:58:38,646] [    INFO][0m - Special tokens file saved in assets/sst-2-ernie-2.0-en/special_tokens_map.json[0m


## Or Load the trained model

In [4]:
# Load the trained weights from the directory that was passed as save_dir to
# the training step, and copy them into the model.
# Plain string literal: the path has no placeholders, so no f-prefix is needed.
state_dict = paddle.load('assets/sst-2-ernie-2.0-en/model_state.pdparams')
model.set_dict(state_dict)

# Prepare for Interpretations

In [5]:
import interpretdl as it 
import numpy as np
from assets.utils import aggregate_subwords_and_importances
from interpretdl.data_processor.visualizer import VisualizationTextRecord, visualize_text

# Ground-truth label for each review, looked up by the review's position.
# The [1, 1, 0, 0] pattern is repeated five times, so this list is longer than
# `reviews`; the interpreter cells index it with the review position only.
true_labels = 5 * [1, 1, 0, 0]

# Visualization records collected by the interpreter cells.
recs = list()

# Example sentences that are classified and interpreted in the cells below.
reviews = [
    "it 's a charming and often affecting journey . ",
    "the movie achieves as great an impact by keeping these thoughts hidden as ... ( quills ) did by showing them . ",
    "this one is definitely one to skip , even for horror movie fanatics . ",
    "in its best moments , resembles a bad high school production of grease , without benefit of song . ",
]

def text_to_input(raw_text):
    """Tokenize ``raw_text`` and return the encoded fields as a tuple of arrays.

    The module-level ``tokenizer`` is called with ``max_seq_len=128``. Every
    value of the mapping it returns is wrapped in a one-element list before
    being converted with ``np.array``, so each array gains a leading axis of
    size 1. The tuple follows the iteration order of the mapping's values.
    """
    features = tokenizer(text=raw_text, max_seq_len=128)
    return tuple(np.array([field]) for field in features.values())

In [7]:
from assets.utils import predict

# Example sentences, each passed to `predict` as a {"text": ...} record.
data = [
    {"text": "it 's a charming and often affecting journey . "},
    {"text": 'the movie achieves as great an impact by keeping these thoughts hidden as ... ( quills ) did by showing them . '},
    {"text": 'this one is definitely one to skip , even for horror movie fanatics . '},
    {"text": 'in its best moments , resembles a bad high school production of grease , without benefit of song . '},
]

# Class index -> readable name shown in the printed predictions.
label_map = {0: 'negative', 1: 'positive'}

batch_size = 32

results = predict(
    model, data, tokenizer, label_map, batch_size=batch_size)

# Print every input record next to its predicted label.
for idx, example in enumerate(data):
    print('Data: {} \t Label: {}'.format(example, results[idx]))

Data: {'text': "it 's a charming and often affecting journey . "} 	 Lable: positive
Data: {'text': 'the movie achieves as great an impact by keeping these thoughts hidden as ... ( quills ) did by showing them . '} 	 Lable: positive
Data: {'text': 'this one is definitely one to skip , even for horror movie fanatics . '} 	 Lable: negative
Data: {'text': 'in its best moments , resembles a bad high school production of grease , without benefit of song . '} 	 Lable: negative


## BT Interpreter

### Token-wise

In [8]:
bt = it.BTNLPInterpreter(model, device='gpu:0')

# Start from an empty list so that re-running this cell does not append
# duplicates to records collected by an earlier run (the head-wise and GA
# cells reset `recs` the same way).
recs = []
for i, review in enumerate(reviews):
    bt_weights = bt.interpret(
        review,
        text_to_input_fn=text_to_input,
        ap_mode="token",
        start_layer=11
    )
    # NOTE: 'predcited' (sic) is the attribute spelling used throughout this
    # notebook; check the InterpretDL attribute name before "correcting" it.
    pred_class = bt.predcited_label[0]
    pred_prob = bt.predcited_proba[0, pred_class]

    # Re-tokenize the review and strip the first and last token (the special
    # [CLS] / [SEP] tokens) so that only the subwords of the text remain.
    encoded_inputs = tokenizer(review)
    subwords = tokenizer.convert_ids_to_tokens(encoded_inputs['input_ids'])[1:-1]
    # Drop the last weight of the first batch item so the importances line up
    # with `subwords`; the two printed lengths are expected to be equal.
    subword_importances = bt_weights[0][:-1]
    print(len(subwords), len(subword_importances))

    # Merge subword pieces back into words and scale the scores to unit L2 norm.
    words, word_importances = aggregate_subwords_and_importances(subwords, subword_importances)
    word_importances = np.array(word_importances) / np.linalg.norm(
        word_importances)

    true_label = true_labels[i]
    # The explanation targets the class the model predicted.
    interp_class = pred_class

    # Flip the sign of the scores when the explained class is 0.
    if interp_class == 0:
        word_importances = -word_importances

    rec = VisualizationTextRecord(
        words,
        word_importances,
        true_label,
        pred_class,
        pred_prob,
        interp_class
    )

    recs.append(rec)

visualize_text(recs)
# The visualization is not rendered on GitHub.

10 10
26 26
15 15
19 19


True Label,Predicted Label (Prob),Target Label,Word Importance
1.0,1 (1.00),1.0,it ' s a charming and often affecting journey .
,,,
1.0,1 (0.96),1.0,the movie achieves as great an impact by keeping these thoughts hidden as . . . ( quills ) did by showing them .
,,,
0.0,0 (0.96),0.0,"this one is definitely one to skip , even for horror movie fanatics ."
,,,
0.0,0 (0.99),0.0,"in its best moments , resembles a bad high school production of grease , without benefit of song ."
,,,


### Head-wise

In [9]:
bt = it.BTNLPInterpreter(model, device='gpu:0')

recs = []
for i, review in enumerate(reviews):
    # Same pipeline as the token-wise run, but with ap_mode="head".
    bt_weights = bt.interpret(
        review, text_to_input_fn=text_to_input, ap_mode="head", start_layer=11)
    pred_class = bt.predcited_label[0]
    pred_prob = bt.predcited_proba[0, pred_class]

    # Tokens of the review without the first and last (special) token.
    encoded_inputs = tokenizer(review)
    subwords = tokenizer.convert_ids_to_tokens(encoded_inputs['input_ids'])
    subwords = subwords[1:-1]
    # Weights of the first batch item, minus the trailing entry.
    subword_importances = bt_weights[0][:-1]
    print(len(subwords), len(subword_importances))

    # Word-level scores, divided by their L2 norm.
    words, word_importances = aggregate_subwords_and_importances(subwords, subword_importances)
    word_importances = np.array(word_importances)
    word_importances = word_importances / np.linalg.norm(word_importances)

    true_label = true_labels[i]
    interp_class = pred_class

    # Scores are negated when the explained (predicted) class is 0.
    word_importances = -word_importances if interp_class == 0 else word_importances

    rec = VisualizationTextRecord(
        words, word_importances, true_label, pred_class, pred_prob, interp_class)
    recs.append(rec)

visualize_text(recs)
# The visualization is not rendered on GitHub.

10 10
26 26
15 15
19 19


True Label,Predicted Label (Prob),Target Label,Word Importance
1.0,1 (1.00),1.0,it ' s a charming and often affecting journey .
,,,
1.0,1 (0.96),1.0,the movie achieves as great an impact by keeping these thoughts hidden as . . . ( quills ) did by showing them .
,,,
0.0,0 (0.96),0.0,"this one is definitely one to skip , even for horror movie fanatics ."
,,,
0.0,0 (0.99),0.0,"in its best moments , resembles a bad high school production of grease , without benefit of song ."
,,,


## GA Interpreter

In [10]:
ga = it.GANLPInterpreter(model, device='gpu:0')
recs = []
for i, review in enumerate(reviews):
    ga_weights = ga.interpret(review, text_to_input_fn=text_to_input, start_layer=11)
    pred_class = ga.predcited_label[0]
    pred_prob = ga.predcited_proba[0, pred_class]

    # Subword tokens of the review; the first and last token are sliced off.
    encoded_inputs = tokenizer(review)
    subwords = tokenizer.convert_ids_to_tokens(encoded_inputs['input_ids'])
    subwords = subwords[1:-1]
    # First batch item of the weights, without its final entry.
    subword_importances = ga_weights[0][:-1]
    print(len(subwords), len(subword_importances))

    # Aggregate to word level, then normalize by the L2 norm of the scores.
    words, word_importances = aggregate_subwords_and_importances(subwords, subword_importances)
    word_importances = np.array(word_importances)
    word_importances = word_importances / np.linalg.norm(word_importances)

    true_label = true_labels[i]
    interp_class = pred_class

    # Negate the scores when the class being explained is 0.
    word_importances = -word_importances if interp_class == 0 else word_importances

    rec = VisualizationTextRecord(
        words, word_importances, true_label, pred_class, pred_prob, interp_class)
    recs.append(rec)

visualize_text(recs)
# The visualization is not rendered on GitHub.

10 10
26 26
15 15
19 19


True Label,Predicted Label (Prob),Target Label,Word Importance
1.0,1 (1.00),1.0,it ' s a charming and often affecting journey .
,,,
1.0,1 (0.96),1.0,the movie achieves as great an impact by keeping these thoughts hidden as . . . ( quills ) did by showing them .
,,,
0.0,0 (0.96),0.0,"this one is definitely one to skip , even for horror movie fanatics ."
,,,
0.0,0 (0.99),0.0,"in its best moments , resembles a bad high school production of grease , without benefit of song ."
,,,
