<a href="https://colab.research.google.com/github/ssdorsey/transformers-interp/blob/main/interpret_tweets.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

## Setup

In [None]:
# Install the packages we need
try:
    import transformers
except:
    !pip install transformers -q

try:
    import captum
except:
    !pip install captum -q

# !pip install emoji

[K     |████████████████████████████████| 1.8MB 16.4MB/s 
[K     |████████████████████████████████| 3.2MB 61.3MB/s 
[K     |████████████████████████████████| 890kB 61.2MB/s 
[?25h  Building wheel for sacremoses (setup.py) ... [?25l[?25hdone
[K     |████████████████████████████████| 4.4MB 18.0MB/s 
[?25h

In [None]:
# download the model
import os
# if 'model' not in os.listdir():
!wget -O model.zip https://www.dropbox.com/sh/sx6g8hp8miylrvk/AADr49JboXw26njcHn5Xloqla?dl=1 # this is the "polarizing" model
!unzip model.zip -d model

# download the script to display the interpretation 
# !wget -O robertainterp.py https://www.dropbox.com/sh/j6v5kgbmgki15ap/AABI4b4sx5gnJCvPkD8oE5lia?dl=1 -q

# code to display
from ipywidgets import interact, widgets, Layout
from IPython.display import display, clear_output
# code to interpret
# from robertainterp import robertainterp

--2021-02-23 04:33:11--  https://www.dropbox.com/sh/sx6g8hp8miylrvk/AADr49JboXw26njcHn5Xloqla?dl=1
Resolving www.dropbox.com (www.dropbox.com)... 162.125.6.18, 2620:100:601c:18::a27d:612
Connecting to www.dropbox.com (www.dropbox.com)|162.125.6.18|:443... connected.
HTTP request sent, awaiting response... 301 Moved Permanently
Location: /sh/dl/sx6g8hp8miylrvk/AADr49JboXw26njcHn5Xloqla [following]
--2021-02-23 04:33:12--  https://www.dropbox.com/sh/dl/sx6g8hp8miylrvk/AADr49JboXw26njcHn5Xloqla
Reusing existing connection to www.dropbox.com:443.
HTTP request sent, awaiting response... 302 Found
Location: https://uc34fade432d9889a6832edb6b60.dl.dropboxusercontent.com/zip_download_get/As3m1qJYnYm7S7TV714F8yEI4MlJtT_JxmZol7WDnW2qn0vXu3ycQScirFHMdjXBCV7nCyOAjNvffipCahTXWDcktf-pl2IQOf2Dq5tXlYRmuw?dl=1 [following]
--2021-02-23 04:33:12--  https://uc34fade432d9889a6832edb6b60.dl.dropboxusercontent.com/zip_download_get/As3m1qJYnYm7S7TV714F8yEI4MlJtT_JxmZol7WDnW2qn0vXu3ycQScirFHMdjXBCV7nCyOAjNvffi

In [None]:
# List the contents of the extracted model directory
!ls model/

cached_dev_roberta_128_2_796	 merges.txt		  tokenizer_config.json
cached_train_roberta_128_2_3172  model_args.json	  training_args.bin
config.json			 pytorch_model.bin	  vocab.json
eval_results.txt		 special_tokens_map.json


In [None]:
from typing import Any, Iterable, List, Tuple, Union
# from captum.attr import visualization as viz
from captum.attr import IntegratedGradients, LayerConductance, LayerIntegratedGradients
from captum.attr import configure_interpretable_embedding_layer, remove_interpretable_embedding_layer

import torch

import warnings
from enum import Enum
from typing import Any, Iterable, List, Tuple, Union

import numpy as np
from matplotlib import pyplot as plt
from matplotlib.colors import LinearSegmentedColormap
from matplotlib.figure import Figure
from matplotlib.pyplot import axis, figure
from mpl_toolkits.axes_grid1 import make_axes_locatable
from numpy import ndarray

from transformers import RobertaTokenizer, RobertaForSequenceClassification, RobertaModel, AutoModel, AutoTokenizer


# IPython is optional at import time; visualize_text() asserts HAS_IPYTHON
# before it renders any HTML.
try:
    from IPython.display import display, HTML
    HAS_IPYTHON = True
except ImportError:
    HAS_IPYTHON = False


# load tokenizer
# (alternative checkpoint tried previously: 'vinai/bertweet-base')
tokenizer = AutoTokenizer.from_pretrained('roberta-base')
slow_tokenizer = RobertaTokenizer.from_pretrained('roberta-base')


# choose device
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

# load model
# The sequence-classification class keeps the classification head, so that
# `model(inputs)[0]` yields one logit per class. The bare RobertaModel returns
# per-token hidden states instead, which made the "predicted class" an
# arbitrary flattened index rather than a label id.
model = RobertaForSequenceClassification.from_pretrained('model/')
model.to(device)
model.eval()
model.zero_grad()

ref_token_id = tokenizer.pad_token_id # Padding token, used to build the all-padding reference (baseline) input
sep_token_id = tokenizer.sep_token_id # Separator token, appended after the text
cls_token_id = tokenizer.cls_token_id # Start-of-sequence token, prepended to the text


# label dictionary: classifier output index -> label shown in the table
# NOTE(review): the order (0 = 'Not Polarizing', 1 = 'Polarizing') mirrors the
# earlier {0: 'Civil', 1: 'Uncivil'} mapping -- confirm against the training data.
label_dict = {
    0: 'Not Polarizing',
    1: 'Polarizing'
}

# Rendering modes accepted by visualize_image_attr(). Members are looked up by
# name, e.g. ImageVisualizationMethod["heat_map"].
ImageVisualizationMethod = Enum(
    "ImageVisualizationMethod",
    {
        "heat_map": 1,
        "blended_heat_map": 2,
        "original_image": 3,
        "masked_image": 4,
        "alpha_scaling": 5,
    },
)


# Which attribution values are visualized. Members are looked up by name,
# e.g. VisualizeSign["absolute_value"].
VisualizeSign = Enum(
    "VisualizeSign",
    {
        "positive": 1,
        "absolute_value": 2,
        "negative": 3,
        "all": 4,
    },
)


def _prepare_image(attr_visual: ndarray):
    return np.clip(attr_visual.astype(int), 0, 255)


def _normalize_scale(attr: ndarray, scale_factor: float):
    if abs(scale_factor) < 1e-5:
        warnings.warn(
            "Attempting to normalize by value approximately 0, skipping normalization."
            "This likely means that attribution values are all close to 0."
        )
        return np.clip(attr, -1, 1)
    attr_norm = attr / scale_factor
    return np.clip(attr_norm, -1, 1)


def _cumulative_sum_threshold(values: ndarray, percentile: Union[int, float]):
    # given values should be non-negative
    assert percentile >= 0 and percentile <= 100, (
        "Percentile for thresholding must be " "between 0 and 100 inclusive."
    )
    sorted_vals = np.sort(values.flatten())
    cum_sums = np.cumsum(sorted_vals)
    threshold_id = np.where(cum_sums >= cum_sums[-1] * 0.01 * percentile)[0][0]
    return sorted_vals[threshold_id]


def _normalize_image_attr(
    attr: ndarray, sign: str, outlier_perc: Union[int, float] = 2
):
    """Sum attributions over axis 2, keep the values selected by `sign`, and
    rescale them into [-1, 1].

    The scale is the cumulative-sum threshold at `100 - outlier_perc` percent
    (negated for the `negative` sign), so values beyond it are clipped.
    """
    attr_combined = np.sum(attr, axis=2)
    mode = VisualizeSign[sign]
    keep_perc = 100 - outlier_perc

    # Choose appropriate signed values and rescale, removing given outlier percentage.
    if mode is VisualizeSign.all:
        threshold = _cumulative_sum_threshold(np.abs(attr_combined), keep_perc)
    elif mode is VisualizeSign.positive:
        # zero out everything that is not strictly positive
        attr_combined = (attr_combined > 0) * attr_combined
        threshold = _cumulative_sum_threshold(attr_combined, keep_perc)
    elif mode is VisualizeSign.negative:
        # zero out everything that is not strictly negative
        attr_combined = (attr_combined < 0) * attr_combined
        magnitude = _cumulative_sum_threshold(np.abs(attr_combined), keep_perc)
        threshold = -1 * magnitude
    elif mode is VisualizeSign.absolute_value:
        attr_combined = np.abs(attr_combined)
        threshold = _cumulative_sum_threshold(attr_combined, keep_perc)
    else:
        raise AssertionError("Visualize Sign type is not valid.")

    return _normalize_scale(attr_combined, threshold)


def visualize_image_attr(
    attr: ndarray,
    original_image: Union[None, ndarray] = None,
    method: str = "heat_map",
    sign: str = "absolute_value",
    plt_fig_axis: Union[None, Tuple[figure, axis]] = None,
    outlier_perc: Union[int, float] = 2,
    cmap: Union[None, str] = None,
    alpha_overlay: float = 0.5,
    show_colorbar: bool = False,
    title: Union[None, str] = None,
    fig_size: Tuple[int, int] = (6, 6),
    use_pyplot: bool = True,
):
    r"""
        Draws the attributions of one image on a matplotlib axis, after
        normalizing the values selected by `sign`.

        Args:

            attr (numpy.array): Attributions to visualize, shape (H, W, C)
                        with channels last. Must match the shape of
                        `original_image` when that is given.
            original_image (numpy.array, optional): Original image, shape
                        (H, W, C). If its maximum is at most 1.0 it is scaled
                        by 255 and converted to integers. Required by every
                        method except `heat_map`.
                        Default: None
            method (string, optional): Name of an ImageVisualizationMethod:

                        1. `heat_map` - heat map of the chosen attributions.

                        2. `blended_heat_map` - heat map drawn with alpha
                           `alpha_overlay` over a greyscale version of the
                           original image.

                        3. `original_image` - only the original image.

                        4. `masked_image` - original image multiplied
                           pixel-wise by the normalized attributions.

                        5. `alpha_scaling` - normalized attributions appended
                           to the original image as an extra (alpha) channel.
                        Default: `heat_map`
            sign (string, optional): Name of a VisualizeSign: `positive`,
                        `absolute_value`, `negative` or `all`. `all` is
                        rejected for `masked_image` and `alpha_scaling`.
                        Default: `absolute_value`
            plt_fig_axis (tuple, optional): (figure, axis) to draw on. When
                        None, a new figure and axis are created.
                        Default: None
            outlier_perc (float or int, optional): Percentage of the total
                        attribution treated as outliers when choosing the
                        normalization scale; values beyond the scale are
                        clipped.
                        Default: 2
            cmap (string, optional): Colormap for heat maps. When None the
                        default depends on `sign`: "Greens" for positive,
                        "Reds" for negative, "Blues" for absolute value and a
                        red-white-green map for all.
                        Default: None
            alpha_overlay (float, optional): Alpha of the heat map in
                        `blended_heat_map` mode.
                        Default: 0.5
            show_colorbar (boolean, optional): Adds a horizontal colorbar axis
                        below the plot. If no heat map was drawn, the axis is
                        still created but hidden, which keeps several plots
                        aligned.
                        Default: False
            title (string, optional): Axis title; nothing is set when None.
                        Default: None
            fig_size (tuple, optional): Size of a newly created figure.
                        Default: (6,6)
            use_pyplot (boolean, optional): If True, the figure is created
                        through pyplot and `plt.show()` is called at the end.
                        If False, a plain `Figure` is created and not shown.
                        Default: True.

        Returns:
            2-element tuple of **figure**, **axis**: the figure and axis that
            were drawn on (the ones from `plt_fig_axis` when it is given).

        Examples::

            >>> ig = IntegratedGradients(net)
            >>> attribution, delta = ig.attribute(orig_image, target=3)
            >>> _ = visualize_image_attr(attribution, orig_image, "blended_heat_map")
    """
    # Reuse the caller's figure/axis, or create a fresh pair.
    if plt_fig_axis is None:
        if use_pyplot:
            plt_fig, plt_axis = plt.subplots(figsize=fig_size)
        else:
            plt_fig = Figure(figsize=fig_size)
            plt_axis = plt_fig.subplots()
    else:
        plt_fig, plt_axis = plt_fig_axis

    if original_image is None:
        assert (
            ImageVisualizationMethod[method] == ImageVisualizationMethod.heat_map
        ), "Original Image must be provided for any visualization other than heatmap."
    elif np.max(original_image) <= 1.0:
        # image given in the 0-1 range: bring it to 0-255 integers
        original_image = _prepare_image(original_image * 255)

    # Remove ticks and tick labels from plot.
    for axis_line in (plt_axis.xaxis, plt_axis.yaxis):
        axis_line.set_ticks_position("none")
    plt_axis.set_yticklabels([])
    plt_axis.set_xticklabels([])

    chosen_method = ImageVisualizationMethod[method]
    heat_map = None

    if chosen_method is ImageVisualizationMethod.original_image:
        # Show original image
        plt_axis.imshow(original_image)
    else:
        # Choose appropriate signed attributions and normalize.
        norm_attr = _normalize_image_attr(attr, sign, outlier_perc)
        chosen_sign = VisualizeSign[sign]

        # Set default colormap and bounds based on sign.
        single_colour_maps = {
            VisualizeSign.positive: "Greens",
            VisualizeSign.negative: "Reds",
            VisualizeSign.absolute_value: "Blues",
        }
        if chosen_sign is VisualizeSign.all:
            default_cmap = LinearSegmentedColormap.from_list(
                "RdWhGn", ["red", "white", "green"]
            )
            vmin, vmax = -1, 1
        elif chosen_sign in single_colour_maps:
            default_cmap = single_colour_maps[chosen_sign]
            vmin, vmax = 0, 1
        else:
            raise AssertionError("Visualize Sign type is not valid.")
        if cmap is None:
            cmap = default_cmap

        # Show appropriate image visualization.
        if chosen_method is ImageVisualizationMethod.heat_map:
            heat_map = plt_axis.imshow(norm_attr, cmap=cmap, vmin=vmin, vmax=vmax)
        elif chosen_method is ImageVisualizationMethod.blended_heat_map:
            greyscale = np.mean(original_image, axis=2)
            plt_axis.imshow(greyscale, cmap="gray")
            heat_map = plt_axis.imshow(
                norm_attr, cmap=cmap, vmin=vmin, vmax=vmax, alpha=alpha_overlay
            )
        elif chosen_method is ImageVisualizationMethod.masked_image:
            assert chosen_sign != VisualizeSign.all, (
                "Cannot display masked image with both positive and negative "
                "attributions, choose a different sign option."
            )
            masked = original_image * np.expand_dims(norm_attr, 2)
            plt_axis.imshow(_prepare_image(masked))
        elif chosen_method is ImageVisualizationMethod.alpha_scaling:
            assert chosen_sign != VisualizeSign.all, (
                "Cannot display alpha scaling with both positive and negative "
                "attributions, choose a different sign option."
            )
            alpha_channel = _prepare_image(np.expand_dims(norm_attr, 2) * 255)
            plt_axis.imshow(np.concatenate([original_image, alpha_channel], axis=2))
        else:
            raise AssertionError("Visualize Method type is not valid.")

    # Add colorbar. If given method is not a heatmap and no colormap is relevant,
    # then a colormap axis is created and hidden. This is necessary for appropriate
    # alignment when visualizing multiple plots, some with heatmaps and some
    # without.
    if show_colorbar:
        divider = make_axes_locatable(plt_axis)
        colorbar_axis = divider.append_axes("bottom", size="5%", pad=0.1)
        if heat_map:
            plt_fig.colorbar(heat_map, orientation="horizontal", cax=colorbar_axis)
        else:
            colorbar_axis.axis("off")

    if title:
        plt_axis.set_title(title)

    if use_pyplot:
        plt.show()

    return plt_fig, plt_axis


def visualize_image_attr_multiple(
    attr: ndarray,
    original_image: Union[None, ndarray],
    methods: List[str],
    signs: List[str],
    titles: Union[None, List[str]] = None,
    fig_size: Tuple[int, int] = (8, 6),
    use_pyplot: bool = True,
    **kwargs: Any
):
    r"""
        Draws several visualizations of the same attributions side by side in
        a 1 x k grid, one subplot per entry of `methods`.

        Args:

            attr (numpy.array): Attributions to visualize, shape (H, W, C)
                        with channels last. Must match the shape of
                        `original_image` when that is given.
            original_image (numpy.array, optional): Original image, shape
                        (H, W, C), with values in range 0-1 or 0-255. Needed
                        by every method that uses the original image.
            methods (list of strings): k method names, each a valid `method`
                        argument of visualize_image_attr.
            signs (list of strings): k sign names, each a valid `sign`
                        argument of visualize_image_attr.
            titles (list of strings, optional): k subplot titles. When None,
                        no titles are set.
                        Default: None
            fig_size (tuple, optional): Size of figure created.
                        Default: (8, 6)
            use_pyplot (boolean, optional): If True, the figure is created
                        through pyplot and shown at the end. If False, a plain
                        `Figure` is created and not shown.
                        Default: True.
            **kwargs (Any, optional): Extra arguments forwarded to every
                        visualize_image_attr call, such as `show_colorbar`,
                        `alpha_overlay`, `cmap`.

        Returns:
            2-element tuple of **figure**, **axis**: the created figure and
            its subplot axes (wrapped in a list when there is exactly one).

        Examples::

            >>> ig = IntegratedGradients(net)
            >>> attribution, delta = ig.attribute(orig_image, target=3)
            >>> _ = visualize_image_attr_multiple(attribution, orig_image,
            >>>                     ["original_image", "heat_map"],
            >>>                     ["all", "positive"])
    """
    assert len(methods) == len(signs), "Methods and signs array lengths must match."
    if titles is not None:
        assert len(methods) == len(titles), (
            "If titles list is given, length must match that of methods list."
        )

    plt_fig = plt.figure(figsize=fig_size) if use_pyplot else Figure(figsize=fig_size)
    plt_axis = plt_fig.subplots(1, len(methods))

    # A single subplot comes back as a bare axis; wrap it so indexing works.
    if len(methods) == 1:
        plt_axis = [plt_axis]

    for idx, method_name in enumerate(methods):
        visualize_image_attr(
            attr,
            original_image=original_image,
            method=method_name,
            sign=signs[idx],
            plt_fig_axis=(plt_fig, plt_axis[idx]),
            use_pyplot=False,  # the shared figure is shown once, below
            title=titles[idx] if titles else None,
            **kwargs
        )

    plt_fig.tight_layout()
    if use_pyplot:
        plt.show()
    return plt_fig, plt_axis


# These visualization methods are for text and are partially copied from
# experiments conducted by Davide Testuggine at Facebook.


class VisualizationDataRecord:
    r"""
        Plain container for the values shown in one row of the text
        attribution table. Only the attributes named in `__slots__` exist.
    """
    __slots__ = [
        "word_attributions",
        "pred_prob",
        "pred_class",
        "true_class",
        "attr_class",
        "attr_score",
        "raw_input",
        "convergence_score",
    ]

    def __init__(
        self,
        word_attributions,
        pred_prob,
        pred_class,
        true_class,
        attr_class,
        attr_score,
        raw_input,
        convergence_score,
    ):
        # Positional order here matches the order of __slots__.
        field_values = (
            word_attributions,
            pred_prob,
            pred_class,
            true_class,
            attr_class,
            attr_score,
            raw_input,
            convergence_score,
        )
        for slot_name, slot_value in zip(self.__slots__, field_values):
            setattr(self, slot_name, slot_value)


# def _get_color(attr):
#     # clip values to prevent CSS errors (Values should be from [-1,1])
#     attr = max(-1, min(1, attr))
#     if attr > 0:
#         hue = 120
#         sat = 75
#         lig = 100 - int(90 * attr)
#     else:
#         hue = 0
#         sat = 75
#         lig = 100 - int(-80 * attr)
#     return "hsl({}, {}%, {}%)".format(hue, sat, lig)


def format_classname(classname):
    """Return an HTML table cell showing `classname` in bold with right padding."""
    return f'<td><text style="padding-right:2em"><b>{classname}</b></text></td>'


def format_special_tokens(token):
    """Rewrite tokens of the form `<...>` as `#...`; other tokens pass through.

    All leading and trailing `<` / `>` characters are stripped from such a
    token before the `#` prefix is added.
    """
    is_bracketed = token.startswith("<") and token.endswith(">")
    return "#" + token.strip("<>") if is_bracketed else token


def format_tooltip(item, text):
    """Wrap `item` in a tooltip `<div>` whose hover text is `text`."""
    # The pieces are separated by the same run of spaces as the original template.
    gap = " " * 8
    return (
        f'<div class="tooltip">{item}'
        f'{gap}<span class="tooltiptext">{text}</span>'
        f'{gap}</div>'
    )


def format_word_importances(words, importances):
    """Build the HTML table cell that shows each word on a coloured background.

    Args:
        words: Sequence of token strings to display.
        importances: Sequence of attribution scores, at least as long as
            `words`; entry i colours word i via `_get_color`. Extra entries
            are ignored.

    Returns:
        A `<td>...</td>` string with one `<mark>` element per word, or an
        empty `<td></td>` when `importances` is None or empty.

    Raises:
        AssertionError: If there are more words than importances.
    """
    # Local import keeps this block self-contained.
    from html import escape

    if importances is None or len(importances) == 0:
        return "<td></td>"
    assert len(words) <= len(importances)
    tags = ["<td>"]
    for word, importance in zip(words, importances[: len(words)]):
        # Tokens come from user-typed text, so characters such as '&', '<'
        # and quotes are escaped before they are placed into the markup.
        word = escape(format_special_tokens(word))
        color = _get_color(importance)
        unwrapped_tag = '<mark style="background-color: {color}; opacity:1.0; \
                    line-height:1.75"><font color="black"> {word}\
                    </font></mark>'.format(
            color=color, word=word
        )
        tags.append(unwrapped_tag)
    tags.append("</td>")
    return "".join(tags)



def visualize_text(datarecords: Iterable[VisualizationDataRecord]) -> None:
    """Render the given records as an HTML table in the notebook.

    Each record becomes one row: the predicted label (looked up in
    `label_dict`) followed by the colour-coded words. The first and last
    entries of `raw_input` and `word_attributions` are dropped before display.

    Args:
        datarecords: Records to show; `pred_class` must provide `.item()`.

    Raises:
        AssertionError: If IPython could not be imported.
    """
    assert HAS_IPYTHON, (
        "IPython must be available to visualize text. "
        "Please run 'pip install ipython'."
    )
    # The width is set through a style attribute so that it is valid HTML.
    dom = ['<table style="width: 65%">']
    rows = [
        "<tr>"
        "<th>Predicted<br /> Label</th>"
        "<th>Word Importance</th>"
        "</tr>"
    ]
    for datarecord in datarecords:
        pred_class = datarecord.pred_class.item()
        # Show an explicit placeholder rather than raising KeyError when the
        # predicted index has no entry in label_dict.
        label = label_dict.get(pred_class, "Unknown ({})".format(pred_class))
        rows.append(
            "".join(
                [
                    "<tr>",
                    format_classname("{0}".format(label)),
                    format_word_importances(
                        datarecord.raw_input[1:-1], datarecord.word_attributions[1:-1]
                    ),
                    "</tr>",
                ]
            )
        )

    dom.append("".join(rows))
    dom.append("</table>")
    display(HTML("".join(dom)))

def construct_input_ref_pair(text, ref_token_id, sep_token_id, cls_token_id):
    """Tokenize `text` and build the input ids plus a same-length reference.

    Both sequences start with `cls_token_id` and end with `sep_token_id`; in
    the reference every text token is replaced by `ref_token_id`.

    Returns:
        (input tensor, reference tensor, number of text tokens); each tensor
        holds a single row and is created on `device`.
    """
    body_ids = tokenizer.encode(text, add_special_tokens=False)
    n_body = len(body_ids)

    # construct input token ids
    input_ids = [cls_token_id, *body_ids, sep_token_id]
    # construct reference token ids
    ref_input_ids = [cls_token_id, *([ref_token_id] * n_body), sep_token_id]

    input_tensor = torch.tensor([input_ids], device=device)
    ref_tensor = torch.tensor([ref_input_ids], device=device)
    return input_tensor, ref_tensor, n_body

def construct_input_ref_token_type_pair(input_ids, sep_ind=0):
    """Build token-type ids and an all-zero reference of the same shape.

    Positions up to and including `sep_ind` get type 0, later positions get
    type 1. The length is taken from dimension 1 of `input_ids`.
    """
    seq_len = input_ids.size(1)
    segment_row = []
    for position in range(seq_len):
        segment_row.append(0 if position <= sep_ind else 1)
    token_type_ids = torch.tensor([segment_row], device=device)
    ref_token_type_ids = torch.zeros_like(token_type_ids, device=device)
    return token_type_ids, ref_token_type_ids

def construct_input_ref_pos_id_pair(input_ids):
    """Build position ids (0..n-1) and an all-zero reference.

    Both results are expanded to the shape of `input_ids`; n is the size of
    its dimension 1.
    """
    seq_length = input_ids.size(1)

    position_ids = (
        torch.arange(seq_length, dtype=torch.long, device=device)
        .unsqueeze(0)
        .expand_as(input_ids)
    )
    # we could potentially also use random permutation with `torch.randperm(seq_length, device=device)`
    ref_position_ids = (
        torch.zeros(seq_length, dtype=torch.long, device=device)
        .unsqueeze(0)
        .expand_as(input_ids)
    )
    return position_ids, ref_position_ids
    
def construct_attention_mask(input_ids):
    """Return a mask of ones with the same shape and dtype as `input_ids`."""
    return torch.full_like(input_ids, 1)


def custom_forward(inputs):
    """Forward function used for attribution.

    Runs `predict`, applies a softmax over dimension 1 and returns entry
    [0][0] of the result with one trailing dimension added.
    """
    # NOTE(review): only row 0 of the softmax output is read, so any further
    # rows in `inputs` do not influence the returned value -- confirm that
    # this is intended for batched calls.
    normalized = torch.softmax(predict(inputs), dim=1)
    first_entry = normalized[0][0]
    return first_entry.unsqueeze(-1)

def summarize_attributions(attributions):
    """Sum attributions over the last dimension, drop dimension 0 if it has
    size 1, and divide the result by its norm."""
    per_position = attributions.sum(dim=-1).squeeze(0)
    return per_position / torch.norm(per_position)


def fix_characters(text):
    """Replace curly single and double quotes in `text` with straight ones."""
    straight_quotes = str.maketrans({
        "‘": "'",
        "’": "'",
        "“": "\"",
        "”": "\"",
    })
    return text.translate(straight_quotes)

def predict(inputs):
    """Run the module-level `model` on `inputs` and return element 0 of its output."""
    model_outputs = model(inputs)
    return model_outputs[0]


def _get_color(attr):
    # clip values to prevent CSS errors (Values should be from [-1,1])
    attr = max(-1, min(1, attr))
    if attr < 0:
        hue = 360
        sat = 1
        lig = max([100 - int(-900 * attr), 50])
    else:
        hue = 200
        sat = 100
        lig = 100 #- int(90 * attr)

    return "hsl({}, {}%, {}%)".format(hue, sat, lig)



# lig = LayerIntegratedGradients(custom_forward, model.roberta.embeddings)
# Layer Integrated Gradients over the embeddings of the model's base model,
# using custom_forward as the function whose output is attributed.
lig = LayerIntegratedGradients(custom_forward, model.base_model.embeddings)
# lig = LayerIntegratedGradients(custom_forward, model.embeddings)



def robertainterp(text):
    """Compute token attributions for `text` and display them as an HTML table.

    The text is quote-normalized, tokenized together with an all-reference
    baseline, attributed with `lig`, scored with `predict`, and the result is
    passed to `visualize_text`. Nothing is returned.
    """
    text = fix_characters(text)

    # The third return value (number of text tokens) is not needed here.
    input_ids, ref_input_ids, _ = construct_input_ref_pair(
        text, ref_token_id, sep_token_id, cls_token_id
    )

    # Tokens for display, with the 'Ġ' marker removed.
    token_ids = input_ids[0].detach().tolist()
    all_tokens = [
        tok.replace('Ġ', '') for tok in tokenizer.convert_ids_to_tokens(token_ids)
    ]

    attributions, delta = lig.attribute(
        inputs=input_ids,
        baselines=ref_input_ids,
        return_convergence_delta=True,
    )

    score = predict(input_ids)
    attributions_sum = summarize_attributions(attributions)

    # Softmax over dimension 1 of the score, first row only.
    first_row = torch.softmax(score, dim=1)[0]

    score_vis = VisualizationDataRecord(
        word_attributions=attributions_sum,
        pred_prob=first_row[0],
        pred_class=torch.argmax(first_row),
        true_class=0,
        attr_class=text,
        attr_score=attributions_sum.sum(),
        raw_input=all_tokens,
        convergence_score=delta,
    )

    visualize_text([score_vis])



## Run it!

In [None]:
# Text box pre-filled with an example tweet
text = widgets.Text(
    value='You & your children won’t be SAFE in Biden’s America, and neither will anyone else!',
    placeholder='Enter tweet to test here',
    description='Tweet:',
    disabled=False, 
    layout=Layout(width='75%', height='80px')
)

def callback(wdgt):
    # Interpret whatever text the widget currently holds
    robertainterp(wdgt.value)

# Show the box, then interpret the pre-filled example once
display(text)
callback(text)

# Re-run the interpretation each time the text box is submitted
text.on_submit(callback)

Text(value='You & your children won’t be SAFE in Biden’s America, and neither will anyone else!', description=…

0,1
Not Polarizing,"You & your children won 't be SAF E in Biden 's America , and neither will anyone else !"
,


In [None]:
# Example: interpret a fixed sample tweet
robertainterp("The Trump Admin wants to create an unnecessary process that would hurt Minnesota home care workers, the majority of which are women, including many women of color. I'm standing up against this misguided attack with my fellow Minnesota colleagues.")

0,1
Polarizing,"The Trump Admin wants to create an unnecessary process that would hurt Minnesota home care workers , the majority of which are women , including many women of color . I 'm standing up against this misguided attack with my fellow Minnesota colleagues ."
,


In [None]:
# Example: interpret a second fixed sample tweet
robertainterp("On this Memorial Day, join me in taking time to honor those who made the ultimate sacrifice & through that sacrifice blessed us with liberty")

0,1
Not Polarizing,"On this Memorial Day , join me in taking time to honor those who made the ultimate sacrifice & through that sacrifice blessed us with liberty"
,


The Trump Admin wants to create an unnecessary process that would hurt Minnesota home care workers, the majority of which are women, including many women of color. I'm standing up against this misguided attack with my fellow Minnesota colleagues.

On this Memorial Day, join me in taking time to honor those who made the ultimate sacrifice & through that sacrifice blessed us with liberty