In [None]:
import re
import torch
import argparse
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
sns.set_theme()
from IPython.display import HTML, display

from transformers import AutoConfig, AutoModelForSequenceClassification, AutoTokenizer

In [None]:
def display_table(data):
    """
    Render rows of values as a bordered HTML table in the notebook output.

    Each field is converted to text and HTML-escaped before it is placed in
    its cell, so values containing ``<``, ``>`` or ``&`` are shown literally
    instead of being interpreted as markup.

    :param data: iterable of rows; each row is an iterable of cell values
    :return: None (the table is shown through IPython's ``display``)
    """
    from html import escape

    border_style = "border: 1px double black; border-collapse: collapse;"

    # Collect fragments and join once instead of growing a string in a loop.
    parts = [f'<table style="{border_style}" cellpadding="2" width="1400">']
    for row in data:
        parts.append("<tr>")
        for field in row:
            parts.append(f'<td style="{border_style}"><h4>{escape(str(field))}</h4></td>')
        parts.append("</tr>")
    parts.append("</table>")

    display(HTML("".join(parts)))

In [None]:
def preprocess_text(text: str) -> str:
    """
    Normalise raw text before tokenization.

    Steps, in order: lowercase the text; replace every character that is not
    a Cyrillic/Latin letter, a digit, a space or one of hyphen, backslash,
    at-sign, slash, plus, equals, underscore, percent, numero sign with a
    space; map "ё" to "е"; replace a hyphen followed by whitespace with a
    single space; collapse whitespace runs; strip both ends.

    :param text: str
    :return: str
    """
    text = text.lower()
    # Raw strings: the regex escapes reach `re` unchanged and Python does not
    # warn about invalid string escape sequences such as "\-" or "\s".
    text = re.sub(r"[^а-яА-Яa-zA-Z0-9ё\-\\@/+=_%№ ]", " ", text)
    text = text.replace("ё", "е")
    text = re.sub(r"-\s+", " ", text)
    text = re.sub(r"\s+", " ", text)
    return text.strip()

In [None]:
def _one_hot(token_ids, vocab_size):
    token_ids = token_ids.squeeze()
    return torch.zeros(len(token_ids), vocab_size).scatter_(1, token_ids.unsqueeze(1), 1.)

In [None]:
def saliency(prediction_logit, token_ids_tensor_one_hot, norm=True):
    """
    Gradient-based importance of every input row for one output logit.

    Back-propagates ``prediction_logit`` and reads the gradient accumulated in
    ``token_ids_tensor_one_hot.grad``. That gradient buffer is zeroed before
    returning, so repeated calls do not accumulate into each other.

    :param prediction_logit: scalar tensor to differentiate
    :param token_ids_tensor_one_hot: 2-D tensor with ``requires_grad=True``
        that ``prediction_logit`` was computed from
    :param norm: if True, return the L2 norm of each row's gradient,
        normalised so the values sum to 1 (left as all zeros when every
        gradient is zero); if False, return the plain sum of each row's
        gradient
    :return: 1-D tensor with one value per row of ``token_ids_tensor_one_hot``
    """
    # Back-propagate the gradient from the selected output logit.
    # retain_graph=True keeps the graph alive so it can be back-propagated again.
    prediction_logit.backward(retain_graph=True)

    # Gradient of the logit with respect to every entry of the input tensor.
    grad = token_ids_tensor_one_hot.grad

    if norm:
        # One scalar per row: the L2 norm of that row's gradient.
        token_importance_raw = torch.norm(grad, dim=1)

        # Normalise so the values add up to 1; skip the division when the
        # total is zero, which would otherwise turn every score into NaN.
        total = torch.sum(token_importance_raw)
        if total == 0:
            token_importance = token_importance_raw
        else:
            token_importance = token_importance_raw / total
    else:
        # Signed sum of each row's gradient entries.
        token_importance = torch.sum(grad, dim=1)

    # Reset the accumulated gradient for the next call.
    grad.zero_()
    return token_importance

In [None]:
# Notebook configuration, kept in a Namespace so it reads like parsed CLI arguments.
args = argparse.Namespace(
    # Hugging Face identifiers for the config, tokenizer and base model.
    config_name="DeepPavlov/rubert-base-cased-conversational",
    tokenizer_name="DeepPavlov/rubert-base-cased-conversational",
    model_name_or_path="DeepPavlov/rubert-base-cased-conversational",
    # Data files, relative to the working directory.
    test_data="data/sentiment_data/test_examples.txt",
    test_labels="data/sentiment_data/test_labels.txt",
    labels="data/sentiment_data/labels.txt",
    # Fine-tuned weights loaded on top of the base model.
    checkpoint="outdir/14_0.79_sentim.pt",
    num_labels=3,
    # Fall back to CPU when no CUDA device is available so the notebook still runs.
    device="cuda" if torch.cuda.is_available() else "cpu",
    # Passed to the tokenizer as max_length.
    maxlen=128,
)

# Initialize BERT model

In [None]:
# Use the dedicated config/tokenizer names when set, otherwise the model path.
config_source = args.config_name or args.model_name_or_path
tokenizer_source = args.tokenizer_name or args.model_name_or_path

config = AutoConfig.from_pretrained(config_source, num_labels=args.num_labels)
tokenizer = AutoTokenizer.from_pretrained(tokenizer_source)

# from_tf is switched on when the model path contains ".ckpt".
model = AutoModelForSequenceClassification.from_pretrained(
    args.model_name_or_path,
    config=config,
    from_tf=".ckpt" in args.model_name_or_path,
).to(args.device)

# Load model from checkpoint

In [None]:
# Restore the fine-tuned weights into the model and switch it to evaluation mode.
model = model.to(args.device)
# map_location remaps tensors stored for another device (e.g. a checkpoint
# written on GPU and loaded on a CPU-only machine).
# NOTE: torch.load unpickles the file; only load checkpoints from a trusted source.
checkpoint = torch.load(args.checkpoint, map_location=args.device)
model.load_state_dict(checkpoint["model_state_dict"])
# The trailing semicolon keeps the notebook from echoing the model repr.
model.eval();

# Do prediction and calculate saliency

In [None]:
def get_saliency_scores(args, model, tokenizer, sample):
    """
    Classify one text sample and score each input position by gradient saliency.

    The cleaned text is encoded to exactly ``args.maxlen`` ids (truncated or
    padded), turned into a one-hot matrix and multiplied with the model's
    input embedding matrix, so the gradient of the winning logit can be read
    back per input position. Padding positions are masked out of attention.

    :param args: namespace providing ``maxlen`` and ``device``
    :param model: sequence-classification model that accepts ``inputs_embeds``
    :param tokenizer: tokenizer matching ``model``
    :param sample: raw text
    :return: tuple ``(saliency_scores, sample_txt, predicted_label_index)``:
        one score per position of the padded input, the token strings of the
        non-padding positions (special tokens included), and the index of the
        highest logit
    """
    # prepare input
    text = preprocess_text(sample)
    sample_ids = tokenizer.encode(
        text, max_length=args.maxlen, padding="max_length", truncation=True
    )
    input_ids = torch.tensor(sample_ids).unsqueeze(0)

    # Mark real tokens so padding is neither attended to nor credited with saliency.
    attention_mask = (input_ids != tokenizer.pad_token_id).long()
    n_tokens = int(attention_mask.sum())
    # Take the token strings from the encoded ids so they stay aligned with the
    # scores even after truncation (assumes padding is appended on the right).
    sample_txt = tokenizer.convert_ids_to_tokens(sample_ids[:n_tokens])

    # do prediction and calculate saliency
    embedding_matrix = model.get_input_embeddings().weight.cpu()
    vocab_size = embedding_matrix.shape[0]
    token_ids_tensor_one_hot = _one_hot(input_ids, vocab_size).requires_grad_(True)
    # One-hot rows times the embedding matrix reproduce the embedding lookup
    # while keeping a differentiable path back to the one-hot tensor.
    inputs_embeds = torch.matmul(token_ids_tensor_one_hot, embedding_matrix)

    output = model(
        inputs_embeds=inputs_embeds.unsqueeze(0).to(args.device),
        attention_mask=attention_mask.to(args.device),
    )
    logits = output[0]
    predicted_label_index = torch.argmax(logits).item()
    predicted_logit = logits[0][predicted_label_index]

    saliency_scores = saliency(predicted_logit, token_ids_tensor_one_hot)

    return saliency_scores, sample_txt, predicted_label_index

# Show saliency for current sample

In [None]:
# Score one demo sentence; keep only the scores that belong to listed tokens.
demo_text = "Едем с Ясей и Алисой на фабрику криков))) Занятное путешествие))"
saliency_scores, sample_txt, predicted_label_index = get_saliency_scores(
    args, model, tokenizer, demo_text
)
plot_data = [saliency_scores[:len(sample_txt)].numpy()]

In [None]:
# One-row heatmap of the per-token saliency scores, labelled with the tokens.
fig, ax = plt.subplots(figsize=(20, 1), dpi=600)
# Draw on the axes created above; the trailing semicolon hides the Axes repr.
sns.heatmap(plot_data, xticklabels=sample_txt, annot=True, ax=ax);

In [None]:
# `labels` is otherwise only assigned in a later cell; read it here so the
# notebook runs top to bottom on a fresh kernel.
with open(args.labels, encoding="utf-8") as labels_file:
    labels = labels_file.read().split("\n")

# Name of the class predicted for the demo sample.
labels[predicted_label_index]

In [None]:
# Position of the single highest saliency score.
torch.topk(saliency_scores, k=1).indices.tolist()

# Calculate top key words for each class

In [None]:
# Explicit UTF-8 so the text is read the same way regardless of the OS locale
# (assumes the data files are UTF-8 — confirm for your copy of the data).
with open(args.test_data, encoding="utf-8") as samples_file:
    samples = samples_file.read().split("\n")

with open(args.labels, encoding="utf-8") as labels_file:
    labels = labels_file.read().split("\n")

# classes[label][token] -> how often `token` was among the top-2 salient
# tokens of a sample predicted as `label`.
classes = {}
for sample in samples:
    # split("\n") yields empty entries for blank/trailing lines; they are not samples.
    if not sample.strip():
        continue

    saliency_scores, sample_txt, predicted_label_index = get_saliency_scores(args, model, tokenizer, sample)

    cur_label = labels[predicted_label_index]

    # Rank only positions that have a token string, so a high-scoring padding
    # position can never index past the end of sample_txt.
    token_scores = saliency_scores[:len(sample_txt)]
    top_k = min(2, token_scores.numel())
    cur_key_words = [sample_txt[ind] for ind in token_scores.topk(top_k).indices.tolist()]

    label_counts = classes.setdefault(cur_label, {})
    for key_word in cur_key_words:
        label_counts[key_word] = label_counts.get(key_word, 0) + 1

In [None]:
# One row per class: the label followed by at most 15 of its key words,
# most frequent first (ties keep their insertion order).
data = []
for label, word_counts in classes.items():
    ranked = sorted(word_counts.items(), key=lambda pair: pair[1], reverse=True)
    data.append([label] + [word for word, _ in ranked[:15]])

In [None]:
# Render the per-class key-word rows as an HTML table in the cell output.
display_table(data)