In [8]:
from transformers import BertForSequenceClassification, BertJapaneseTokenizer
from torch.utils.data import SequentialSampler
from src.dataset import My_DATASET
import pandas as pd
import matplotlib.pyplot as plt
import torch

# Attentionの可視化

In [4]:
# Pre-trained Japanese BERT checkpoint the classifier architecture is built from.
MODEL_NAME = "cl-tohoku/bert-base-japanese-whole-word-masking"
# Fine-tuned weights, loaded below via load_state_dict(torch.load(...)).
model_path = "save_model/best_model.pth"
# CSV of collected tweets whose attention is visualized.
data_path = "DATA/serched_tweet/イーロンマスク.csv"

In [7]:
# Load the fine-tuned classifier: build the architecture from the pre-trained
# checkpoint, then overwrite its weights with the saved state_dict.
model = BertForSequenceClassification.from_pretrained(
    MODEL_NAME,  # Japanese pre-trained BERT checkpoint
    num_labels=3,  # number of classes; must match the saved classifier head
    output_attentions=False,  # attentions are requested per forward call where needed
    output_hidden_states=False,  # hidden states are not used
)
# map_location='cpu' lets a checkpoint saved from a GPU session load on a
# CPU-only machine; load_state_dict then copies the tensors into the model.
# NOTE(review): torch.load unpickles the file — only load trusted checkpoints.
model.load_state_dict(torch.load(model_path, map_location='cpu'))
# Inference only: make sure dropout is disabled.
model.eval()

tokenizer = BertJapaneseTokenizer.from_pretrained(MODEL_NAME)

# Read the tweets; the CSV is expected to hold exactly one text column.
# NOTE(review): read_csv treats the first row as a header — confirm the file has one.
df = pd.read_csv(data_path)
df.columns = ['tweet']
tweets = df.tweet.values.tolist()

# Dataset (not a DataLoader): indexing it returns (input_ids, input_mask) for one tweet.
dataset = My_DATASET(MODEL_NAME, tweets)

Some weights of the model checkpoint at cl-tohoku/bert-base-japanese-whole-word-masking were not used when initializing BertForSequenceClassification: ['cls.predictions.transform.LayerNorm.bias', 'cls.predictions.transform.dense.weight', 'cls.seq_relationship.bias', 'cls.predictions.transform.dense.bias', 'cls.predictions.bias', 'cls.predictions.decoder.weight', 'cls.predictions.transform.LayerNorm.weight', 'cls.seq_relationship.weight']
- This IS expected if you are initializing BertForSequenceClassification from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BertForSequenceClassification from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).
Some weights of BertForSequenceClassification were not initialize

In [51]:
def highlight(word, attn):
    """Wrap a token in an HTML <span> whose red background deepens with attention.

    Args:
        word: Token text to display. It is HTML-escaped because tokens come from
            tweet text and may contain characters such as ``<`` or ``&``.
        attn: Attention score. 0 gives a white background and 1 gives pure red.
            Any float-convertible scalar is accepted (a Python float or a 0-d
            tensor). Values outside [0, 1] are clamped.

    Returns:
        An HTML ``<span>`` string.
    """
    # Local import: the notebook reuses the name `html` for a string variable.
    from html import escape

    # Clamp so the green/blue components stay within 0..255. Unclamped values
    # above 1 would format as negative hex (e.g. '-FF') and yield invalid CSS.
    weight = min(max(float(attn), 0.0), 1.0)
    fade = int(255 * (1 - weight))
    html_color = '#%02X%02X%02X' % (255, fade, fade)
    return f'<span style="background-color: {html_color}">{escape(word)}</span>'

def id2label(id):
    """Translate a predicted class index into its sentiment name.

    0 -> 'positive', 1 -> 'negative', 2 -> 'neutral'. Any other value yields None.
    """
    for label_id, name in ((0, 'positive'), (1, 'negative'), (2, 'neutral')):
        if id == label_id:
            return name
    return None

In [45]:
# Show the raw text of tweet #85, the same index the attention cell below visualizes.
tweets[85]

'イーロンマスクは旧経営陣と同類の独裁者ではあるけど、旧悪によって苦しめられた層にとっては旧悪を倒してくれた救世主だし、自分達自身には直接的な不利益はもたらしてはいない以上は良質な経営者。本質的には単に視界に入っていないから、手出ししていないだけだとしてもこの事実には変わりない。'

In [55]:
# Index of the tweet to visualize (the same tweet inspected above).
ind = 85
input_ids, input_mask = dataset[ind]

# Inference only: no_grad avoids building an autograd graph, so the attention
# tensors below are plain values and no activation memory is retained.
with torch.no_grad():
    output = model(input_ids.unsqueeze(0),
                   token_type_ids=None,
                   attention_mask=input_mask.unsqueeze(0),
                   output_attentions=True)

# Attention weights of the final encoder layer. The indexing below assumes the
# Hugging Face layout (batch, num_heads, seq_len, seq_len).
attention_weight = output.attentions[-1]
label_id = output.logits.argmax(dim=1).item()
label = id2label(label_id)

# Attention paid by the [CLS] token (query position 0) to every token, summed
# over all heads. The head count comes from the tensor instead of a hard-coded 12.
all_attens = attention_weight[0, :, 0, :].sum(dim=0)
# Each head's softmax row sums to 1, so the head-sum can exceed 1. Rescale to
# [0, 1] so highlight() produces valid colour components.
all_attens = all_attens / all_attens.max()

tokens = tokenizer.convert_ids_to_tokens(input_ids.tolist())
html = f'<big>推論ラベル：{label}</big><br>'
for word, attn in zip(tokens, all_attens.tolist()):
    # Stop at the first [SEP]; it and any padding after it are not shown.
    if word == "[SEP]":
        break
    html += highlight(word, attn)
html += "<br><br>"

In [56]:
# Render the highlighted tokens as HTML. `display` is imported explicitly
# instead of relying on the builtin IPython injects into the kernel.
from IPython.display import HTML, display
display(HTML(html))