In [40]:
import numpy as np
import torch
from torch import Tensor
from torch import nn
from torch.functional import F
import matplotlib.pyplot as plt
import os
from typing import Dict, Tuple, Union, NewType, List, Optional, Any
from pathlib import Path, WindowsPath
import warnings
import pickle
warnings.filterwarnings("ignore")

In [49]:
cls_tokens_heads = torch.tensor([[0.0432, 0.0381, 0.0384, 0.0199, 0.0373, 0.0222, 0.0121, 0.0135, 0.0071,
         0.0214, 0.0632, 0.0286, 0.0160, 0.0277, 0.0151, 0.1030, 0.0407, 0.0523,
         0.0601, 0.0748, 0.1127, 0.0975],
        [0.0196, 0.0148, 0.0240, 0.0141, 0.0195, 0.0144, 0.0367, 0.0268, 0.0252,
         0.0194, 0.0770, 0.0090, 0.0037, 0.0409, 0.0246, 0.1645, 0.0629, 0.0079,
         0.0196, 0.0188, 0.1806, 0.1480],
        [0.0247, 0.0327, 0.0276, 0.0323, 0.0679, 0.0466, 0.0245, 0.0169, 0.0175,
         0.0341, 0.0878, 0.0196, 0.0225, 0.0127, 0.0186, 0.0514, 0.0391, 0.1011,
         0.0445, 0.0903, 0.0543, 0.0572],
        [0.0669, 0.1618, 0.0277, 0.0332, 0.0906, 0.1035, 0.0333, 0.0247, 0.0160,
         0.0930, 0.0282, 0.0014, 0.0025, 0.0007, 0.0003, 0.0260, 0.0247, 0.0590,
         0.0329, 0.0671, 0.0288, 0.0430],
        [0.0297, 0.0198, 0.0613, 0.0359, 0.0415, 0.0191, 0.0418, 0.0438, 0.0169,
         0.0156, 0.1545, 0.0018, 0.0006, 0.0031, 0.0011, 0.0408, 0.1228, 0.0521,
         0.0416, 0.0394, 0.0461, 0.0790],
        [0.0386, 0.0724, 0.0323, 0.0274, 0.0456, 0.0881, 0.0192, 0.0187, 0.0124,
         0.0322, 0.0488, 0.0086, 0.0106, 0.0152, 0.0099, 0.0512, 0.0581, 0.0556,
         0.0639, 0.0761, 0.0567, 0.0648],
        [0.0428, 0.0303, 0.0148, 0.0158, 0.0265, 0.0300, 0.0190, 0.0343, 0.0226,
         0.0228, 0.0295, 0.0588, 0.0195, 0.0876, 0.1674, 0.0675, 0.0777, 0.0199,
         0.0117, 0.0244, 0.0722, 0.0807],
        [0.0320, 0.0682, 0.0224, 0.0261, 0.0584, 0.1119, 0.0197, 0.0262, 0.0261,
         0.0314, 0.0550, 0.0023, 0.0052, 0.0053, 0.0056, 0.1080, 0.0214, 0.0598,
         0.0300, 0.0599, 0.1199, 0.0665],
        [0.0390, 0.0259, 0.0652, 0.0463, 0.1139, 0.0477, 0.0242, 0.0202, 0.0232,
         0.0490, 0.0928, 0.0101, 0.0068, 0.0086, 0.0037, 0.0231, 0.0367, 0.0429,
         0.0840, 0.1244, 0.0248, 0.0278],
        [0.0556, 0.0273, 0.1120, 0.0775, 0.1500, 0.0288, 0.0296, 0.0278, 0.0235,
         0.0489, 0.0376, 0.0023, 0.0014, 0.0035, 0.0015, 0.0067, 0.0154, 0.1192,
         0.0857, 0.1073, 0.0073, 0.0099],
        [0.0705, 0.0525, 0.0492, 0.0244, 0.0621, 0.0516, 0.0242, 0.0223, 0.0108,
         0.0427, 0.0458, 0.0300, 0.0133, 0.0287, 0.0129, 0.0436, 0.0709, 0.0505,
         0.0636, 0.0938, 0.0485, 0.0431],
        [0.0153, 0.0452, 0.0083, 0.0070, 0.0121, 0.0645, 0.0142, 0.0067, 0.0076,
         0.0199, 0.0250, 0.0499, 0.1046, 0.0418, 0.0677, 0.1479, 0.0328, 0.0252,
         0.0123, 0.0127, 0.1613, 0.1026]])

In [133]:
from transformers import AutoTokenizer
model_name = "textattack/bert-base-uncased-SST-2"
tokenizer = AutoTokenizer.from_pretrained(model_name)

def get_input_ids(tokenizer, text):
    encoding = tokenizer(text, return_tensors='pt')
    input_ids = encoding['input_ids']
    return input_ids[0]

def generate_scoring_by_head_agg(cls_tokens_scores,input_ids, agg: str ='median'):
    if agg == 'median':
        scores = cls_tokens_heads.median(dim=0)[0]
    else:
        scores = cls_tokens_heads.mean(dim=0)
    for token_id in torch.topk(scores, k=cls_tokens_heads.shape[-1]-1, largest=True)[1].tolist():
        print(tokenizer.convert_ids_to_tokens(input_ids[1:])[token_id])

In [136]:
text_1 = "This movie was the best movie I have ever seen! some scenes were ridiculous, but acting was great."
text_2 = "I really didn't like this movie. Some of the actors were good, but overall the movie was boring."
input_ids_1 = get_input_ids(tokenizer, text_1)
input_ids_2 = get_input_ids(tokenizer, text_2)

In [None]:
text_1 = "This movie was the best movie I have ever seen! some scenes were ridiculous, but acting was great."

In [135]:
generate_scoring_by_head_agg(cls_tokens_scores=cls_tokens_heads, input_ids=input_ids_1, agg='mean')

.
,
[SEP]
great
!
best
acting
movie
but
movie
was
was
this
seen
the
ridiculous
i
have
were
some
ever


In [140]:
cls_tokens = torch.tensor([[6.4104e-03, 5.8950e-03, 7.3963e-03, 1.1112e-01, 4.9748e-03, 1.1208e-02,
         1.2484e-02, 3.0761e-02, 1.1140e-01, 3.7053e-02, 3.4892e-02, 3.4263e-02,
         2.3692e-02, 7.7873e-02, 6.7125e-02, 9.3894e-02, 2.1148e-02, 2.7793e-02,
         1.9592e-02, 3.9174e-02, 2.4819e-02, 3.1904e-02, 1.1053e-01, 2.6811e-02],
        [1.6081e-02, 5.2960e-03, 3.7188e-03, 2.0494e-01, 2.7993e-03, 6.3701e-03,
         8.1040e-03, 4.1798e-02, 2.0626e-01, 2.2433e-03, 2.2982e-03, 5.6787e-03,
         1.2111e-02, 1.0414e-02, 1.1217e-02, 1.6521e-01, 1.0568e-02, 4.5132e-03,
         2.5439e-03, 2.5158e-02, 5.0774e-03, 5.2127e-03, 2.0313e-01, 3.4380e-02],
        [5.5605e-02, 2.2398e-02, 3.1904e-02, 8.9320e-02, 2.7649e-02, 5.7096e-02,
         6.1439e-02, 2.5013e-02, 8.9268e-02, 8.8895e-03, 8.4077e-03, 9.2272e-03,
         3.8409e-03, 7.8632e-03, 1.2310e-02, 7.6161e-02, 1.7915e-02, 3.0854e-02,
         4.2111e-02, 1.8258e-02, 3.3905e-02, 2.9221e-02, 8.9485e-02, 7.6939e-02],
        [6.7770e-02, 2.9528e-02, 5.8736e-02, 1.6303e-02, 7.3499e-02, 4.8006e-02,
         4.2763e-02, 4.9812e-02, 1.6232e-02, 3.7550e-04, 1.0014e-04, 1.2337e-03,
         4.2675e-04, 4.0432e-04, 1.8811e-04, 1.2231e-02, 4.6974e-02, 5.7076e-02,
         5.0806e-02, 6.3656e-02, 9.0147e-02, 1.2159e-01, 1.5791e-02, 3.6935e-02],
        [7.0917e-02, 3.9410e-02, 4.4955e-02, 1.9495e-02, 4.3082e-02, 1.9820e-02,
         1.0992e-01, 8.9653e-02, 1.9615e-02, 1.3402e-02, 3.4477e-03, 1.1566e-02,
         9.8973e-03, 2.3897e-02, 9.0668e-03, 1.5490e-02, 1.5208e-02, 3.5183e-02,
         5.1934e-02, 1.0888e-01, 6.0186e-02, 8.0560e-02, 1.9521e-02, 4.1533e-02],
        [7.3402e-02, 5.5083e-02, 7.3332e-02, 5.2238e-02, 4.8512e-02, 2.6609e-02,
         7.1359e-02, 3.0227e-02, 5.2184e-02, 3.7800e-03, 1.4793e-03, 3.5185e-03,
         1.2530e-03, 6.2069e-03, 4.4170e-03, 3.9913e-02, 3.1964e-02, 7.2925e-02,
         2.6476e-02, 4.3997e-02, 6.8329e-02, 8.1571e-02, 5.0849e-02, 1.7959e-02],
        [4.0512e-02, 3.9551e-02, 6.5053e-02, 3.2882e-02, 4.9105e-02, 7.2485e-02,
         2.2645e-02, 1.2710e-02, 3.2836e-02, 3.0208e-02, 4.0186e-02, 6.9794e-02,
         2.6245e-02, 9.2029e-02, 5.8314e-02, 2.8314e-02, 2.5194e-02, 4.2633e-02,
         2.2257e-02, 2.4377e-02, 4.1102e-02, 5.1038e-02, 3.2597e-02, 2.1396e-02],
        [4.7662e-03, 8.0312e-03, 8.4058e-03, 1.4366e-01, 1.1627e-02, 2.4519e-02,
         1.1019e-02, 2.5886e-02, 1.4293e-01, 5.0888e-03, 1.2422e-02, 1.1076e-02,
         1.0603e-02, 1.8308e-02, 1.0975e-02, 1.1430e-01, 3.8103e-02, 4.1989e-02,
         1.6911e-02, 2.5934e-02, 4.0428e-02, 1.0573e-01, 1.3983e-01, 5.5560e-03],
        [1.4957e-02, 3.2509e-02, 3.1079e-02, 3.4686e-03, 5.0143e-02, 7.0228e-03,
         2.8079e-02, 6.3270e-03, 3.4780e-03, 3.4937e-03, 4.9953e-04, 1.0301e-03,
         1.3514e-03, 3.8533e-03, 1.8502e-03, 3.0443e-03, 1.0073e-01, 1.6688e-01,
         3.7605e-02, 2.4433e-02, 1.9494e-01, 1.8356e-01, 3.5092e-03, 4.5591e-03],
        [1.6805e-01, 7.0912e-02, 8.2866e-02, 4.1625e-02, 4.6934e-02, 3.8950e-02,
         1.3644e-01, 1.9286e-02, 4.1325e-02, 4.8915e-03, 1.8935e-03, 1.2695e-03,
         1.3345e-03, 4.3902e-03, 3.1051e-03, 3.6085e-02, 7.3072e-03, 3.8269e-02,
         2.3223e-02, 1.0033e-02, 4.7459e-02, 4.9558e-02, 3.9952e-02, 4.4906e-02],
        [3.2699e-02, 5.3724e-02, 1.8123e-02, 4.1401e-02, 1.8138e-02, 1.7870e-02,
         1.3234e-01, 2.6384e-02, 4.1400e-02, 2.6874e-02, 8.9529e-03, 2.3039e-02,
         4.0302e-03, 2.8555e-02, 1.8251e-02, 3.4747e-02, 4.9756e-02, 7.5777e-02,
         5.4779e-02, 3.8367e-02, 4.6330e-02, 6.3429e-02, 4.1174e-02, 5.9552e-02],
        [6.0238e-02, 2.5602e-02, 4.2559e-02, 2.6144e-02, 4.3952e-02, 1.0944e-01,
         1.4571e-01, 4.2017e-02, 2.6101e-02, 5.0681e-03, 1.7276e-03, 4.2100e-03,
         8.0228e-04, 5.7174e-03, 3.1362e-03, 2.0360e-02, 1.8735e-02, 6.0264e-02,
         5.2154e-02, 3.3526e-02, 7.2682e-02, 1.2959e-01, 2.5471e-02, 1.1575e-02]])

In [144]:
generate_scoring_by_head_agg(cls_tokens_scores=cls_tokens, input_ids=input_ids_2, agg='mean')

was
,
boring
movie
of
t
overall
like
but
really
the
didn
i
some
'
good
this
movie
were
the
.


Lior

In [85]:
l = [('[CLS]', 0.0), ('this', 0.4398406744003296), ('movie', 0.3385171890258789), ('was', 0.2850261628627777), ('the', 0.3722951412200928), ('best', 0.6413642764091492), ('movie', 0.3098682463169098), ('i', 0.20284101366996765), ('have', 0.12214731425046921), ('ever', 0.15835356712341309), ('seen', 0.2082878053188324), ('!', 0.6001579761505127), ('some', 0.021879158914089203), ('scenes', 0.05488050356507301), ('were', 0.0371897891163826), ('ridiculous', 0.03780526667833328), (',', 0.02076297625899315), ('but', 0.44531309604644775), ('acting', 0.45006945729255676), ('was', 0.5168584585189819), ('great', 1.0), ('.', 0.035734280943870544), ('[SEP]', 0.10382220149040222)]
d = {}
for token, score in l:
    d[token] = score
{k: v for k, v in sorted(d.items(), key=lambda item: item[1])}