# Captum Visual Insights for BERT Seq Classification Model

This Notebook helps you to get started with the Captum Insights. Now that we have gotten the response from the Torchserve API for Captum Explanations, it's important to understand the word importance and their attributions.The example covered here is from the Hugginface Transformers pre-trained model used in Torchserve

In [1]:
import json
import captum
from captum.attr import visualization as viz
import os
import logging
import numpy as np
import json
import torch
from transformers import AutoTokenizer

Create a function to map the label with the classes 

In [2]:
def load_label_mapping(mapping_file_path):
    """
    Load a JSON mapping { class ID -> friendly class name }.
    Used in BaseHandler.
    """
    if not os.path.isfile(mapping_file_path):
        logging.info('Missing the index_to_name.json file. Inference output will not include class name.')
        return None

    with open(mapping_file_path) as f:
        mapping = json.load(f)
    if not isinstance(mapping, dict):
        raise Exception('index_to_name mapping should be in "class":"label" json format')

    # Older examples had a different syntax than others. This code accommodates those.
    if 'object_type_names' in mapping and isinstance(mapping['object_type_names'], list):
        mapping = {str(k): v for k, v in enumerate(mapping['object_type_names'])}
        return mapping

    for key, value in mapping.items():
        new_value = value
        if isinstance(new_value, list):
            new_value = value[-1]
        if not isinstance(new_value, str):
            raise Exception('labels in index_to_name must be either str or [str]')
        mapping[key] = new_value
    return mapping

Open the Response JSON Object and load the attributions, word importances and delta key-value pairs

In [3]:
input_file=open('./bert_response.json', 'r')
input_json = json.load(input_file)

In [4]:
attributions = input_json['explanations'][0]['importances']
words = input_json['explanations'][0]['words']
delta = input_json['explanations'][0]['delta']


For visualization purpose using Captum, the attributions and delta parameters should be in the form of Torch Tensors. Please note that the Predictions is returned from the Inference Request.The predicted response should be converted to a torch tensor as the parameter is passed as arguments to the captum visualizer.

In [None]:
#curl request to make a Prediction Request
!curl -H "Content-Type: application/json" --data @examples/Huggingface_Transformers/bert_ts.json http://127.0.0.1:8080/predictions/bert

In [6]:
#Get the predictions value from the Inference Request and place it here.
predictions = [ 
         -0.05393758416175842,
          0.3400498330593109
        ]
#convert predictions to a torch tensor and take the argmax value of it
predictions2 = torch.tensor(predictions)
predictions2 = torch.argmax(predictions2)

#Mapping the class to the labels for the BERT Seq Classification Model
mapping = load_label_mapping("index_to_name_bert.json")
true_label = 'Accepted'

In [7]:
pred_ind = torch.argmax(predictions2)
label = mapping[str(pred_ind.item())]

Creating a word tokenizer using the Hugginface Transformer's Auto tokenizer

In [8]:
if not os.path.isfile(os.path.join(".", "vocab.*")):
            tokenizer = AutoTokenizer.from_pretrained(
                "bert-base-uncased",
                do_lower_case= True,
            )

In [9]:
word_ids = tokenizer.convert_tokens_to_ids(words)

Using the Visualization Data Record method from Captum's Visualization toolkit to render the visualization

In [10]:

result = viz.VisualizationDataRecord(
                        attributions2,
                        predictions2,
                        label,
                        true_label,
                        label,
                        attributions2.sum(),       
                        words,
                        delta2)    


In [11]:
viz.visualize_text([result])

True Label,Predicted Label,Attribution Label,Attribution Score,Word Importance
Accepted,Not Accepted (1.00),Not Accepted,-0.04,[CLS] the recent climate change across world is impact ##ing negatively [SEP]
,,,,
