In [46]:
%%capture
!pip install transformers
!pip install captum

# Tokenization

In [47]:
from transformers import BertTokenizer, BertForSequenceClassification
from transformers import pipeline
import torch

tokenizer = BertTokenizer.from_pretrained('yiyanghkust/finbert-tone')


text = 'growth is strong and we have plenty of liquidity'
label = 1;

# Tokenize input text
text_ids = tokenizer.encode(text, add_special_tokens=True)

# Print the tokens
print(tokenizer.convert_ids_to_tokens(text_ids))
# Output: ['[CLS]', 'The', 'movie', 'is', 'superb', '[SEP]']

# Print the ids of the tokens
print(text_ids)
# Output: [101, 1109, 2523, 1110, 25876, 102]

['[CLS]', 'growth', 'is', 'strong', 'and', 'we', 'have', 'plenty', 'of', 'liquidity', '[SEP]']
[3, 64, 17, 253, 8, 13, 29, 9146, 7, 466, 4]


# Fetch Embedding of Tokens

In [48]:
from transformers import BertModel

# Instantiate BERT model
model = BertForSequenceClassification.from_pretrained('yiyanghkust/finbert-tone',num_labels=3)

embeddings = model.bert.embeddings(torch.tensor([text_ids]))
print(embeddings.size())
# Output: torch.Size([1, 6, 768]), since there are 6 tokens in text_ids

torch.Size([1, 11, 768])


# Inference

In [49]:
text = 'growth is strong and we have plenty of liquidity'
inputs = tokenizer(text, return_tensors="pt")

from transformers import AutoModelForSequenceClassification

with torch.no_grad():
    logits = model(**inputs).logits

predicted_class_id = logits.argmax().item()

print(model.config.id2label)
print(model.config.id2label[predicted_class_id])

{0: 'Neutral', 1: 'Positive', 2: 'Negative'}
Positive


# Define Model Input and Output

In [50]:
# # Define model output
# def model_output(inputs):
#   return model(inputs)[0]

# # Define model input
# model_input = model.bert.embeddings

# Instantiate Integrated Gradients Method

In [51]:
from captum.attr import LayerIntegratedGradients
lig = LayerIntegratedGradients(lambda inputs: model(inputs)[0], model.bert.embeddings)

# Construct Original and Baseline Input

In [52]:
def construct_input_and_baseline(text):

    max_length = 510
    baseline_token_id = tokenizer.pad_token_id 
    sep_token_id = tokenizer.sep_token_id 
    cls_token_id = tokenizer.cls_token_id 

    text_ids = tokenizer.encode(text, max_length=max_length, truncation=True, add_special_tokens=False)
   
    input_ids = [cls_token_id] + text_ids + [sep_token_id]
    token_list = tokenizer.convert_ids_to_tokens(input_ids)
  

    baseline_input_ids = [cls_token_id] + [baseline_token_id] * len(text_ids) + [sep_token_id]
    return torch.tensor([input_ids], device='cpu'), torch.tensor([baseline_input_ids], device='cpu'), token_list

# text = 'This movie is superb'
input_ids, baseline_input_ids, all_tokens = construct_input_and_baseline(text)

print(f'original text: {input_ids}')
print(f'baseline text: {baseline_input_ids}')

# Output: original text: tensor([[  101,  1109,  2523,  1110, 25876,   102]])
# Output: baseline text: tensor([[101,   0,   0,   0,   0, 102]])


original text: tensor([[   3,   64,   17,  253,    8,   13,   29, 9146,    7,  466,    4]])
baseline text: tensor([[3, 0, 0, 0, 0, 0, 0, 0, 0, 0, 4]])


# Compute Attributions

In [53]:
attributions, delta = lig.attribute(
    inputs=input_ids,
    baselines=baseline_input_ids,
    target=label,
    return_convergence_delta=True
)
print(attributions.size())
# Output: torch.Size([1, 6, 768])

torch.Size([1, 11, 768])


# Compute Attribution for Each Token

In [54]:
def summarize_attributions(attributions):

    attributions = attributions.sum(dim=-1).squeeze(0)
    attributions = attributions / torch.norm(attributions)
    
    return attributions

attributions_sum = summarize_attributions(attributions)
print(attributions_sum.size())
# Output: torch.Size([6])

torch.Size([11])


# Encapsulate All the Steps Above

In [59]:
def interpret_text(text, true_class):

    input_ids, baseline_input_ids, all_tokens = construct_input_and_baseline(text)
    attributions, delta = lig.attribute(inputs= input_ids,
                                    baselines= baseline_input_ids,
                                    target=true_class,
                                    return_convergence_delta=True
                                    )
    attributions_sum = summarize_attributions(attributions)

    score_vis = viz.VisualizationDataRecord(
                        word_attributions = attributions_sum,
                        pred_prob = torch.max(model(input_ids)[0]),
                        pred_class = torch.argmax(model(input_ids)[0]).numpy(),
                        true_class = true_class,
                        attr_class = text,
                        attr_score = attributions_sum.sum(),       
                        raw_input_ids = all_tokens,
                        convergence_score = delta)

    viz.visualize_text([score_vis])

In [73]:
text = "Stocks rallied and the British pound gained."
true_class = model.config.label2id["Positive"]

interpret_text(text, true_class)


text = 'growth is strong and we have plenty of liquidity'
true_class = model.config.label2id["Positive"];
interpret_text(text, true_class)
  

True Label,Predicted Label,Attribution Label,Attribution Score,Word Importance
1.0,1 (2.99),Stocks rallied and the British pound gained.,0.52,[CLS] stocks ral ##lied and the british pound gained . [SEP]
,,,,


True Label,Predicted Label,Attribution Label,Attribution Score,Word Importance
1.0,1 (10.98),growth is strong and we have plenty of liquidity,0.85,[CLS] growth is strong and we have plenty of liquidity [SEP]
,,,,
