# Attention with BERT

In [None]:
import torch
from transformers import BertForPreTraining
from transformers import BertJapaneseTokenizer, BertModel

# Tokenizer and encoder are loaded from the same pretrained checkpoint so that
# the token ids produced by the tokenizer match the model's vocabulary.
checkpoint = 'cl-tohoku/bert-base-japanese-v2'
tokenizer = BertJapaneseTokenizer.from_pretrained(checkpoint)
model = BertModel.from_pretrained(checkpoint)

In [None]:
text = '私の大好きな食べ物はカレーライスです。'

import pandas as pd
import numpy as np

# Encode the sentence with the tokenizer's public __call__ API, which adds
# the special tokens and returns PyTorch tensors (batch of one sequence).
# Note the keyword is `truncation` (a misspelled keyword is silently ignored),
# and the private `_encode_plus` helper is avoided.
encoded_input = tokenizer(text, truncation=True, padding=False, return_tensors='pt')

## トークナイズ結果の確認と BERT の最終隠れ層の取得

In [None]:
input_ids = encoded_input["input_ids"]
# Map ids back to vocabulary strings; [0] selects the single sequence in the
# batch returned by the tokenizer.
tokens = tokenizer.convert_ids_to_tokens(input_ids[0])
# BERT's WordPiece vocabulary marks word-continuation pieces with a '##'
# prefix (not '__'); strip it so the DataFrame labels read as plain text.
tokens = [s[2:] if s.startswith('##') else s for s in tokens]

# Inference only: disable autograd so no computation graph is recorded.
with torch.no_grad():
    outputs = model(**encoded_input)

# One row per token, one column per hidden dimension of the last layer
# (squeeze drops the batch dimension of size 1).
df = pd.DataFrame(
    data=torch.squeeze(outputs.last_hidden_state, 0).detach().numpy(),
    index=tokens,
)
df

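## Self-Attention の実装
$Q$, $K$, $V$ はすべて、同じ文章から作られる。本来の Transformer では、入力の隠れ状態に学習済みの重み行列 $W^Q$, $W^K$, $W^V$ を掛けて $Q$, $K$, $V$ を得るが、ここでは簡略化のため、BERT の最終隠れ層の出力をそのまま $Q$, $K$, $V$ として用いる。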

In [None]:
from scipy.special import softmax
import numpy as np
import pandas as pd

In [None]:
# Simplified self-attention: no learned projections, so queries, keys and
# values are all the same hidden-state matrix (one row per token).
hidden_states = df.values
Q, K, V = hidden_states, hidden_states, hidden_states

$$
\mathrm{Attention}(Q, K, V) = \mathrm{softmax} \left( \frac{Q K^\top}{\sqrt{d_k}} \right) V
$$

ここで $d_k$ はキー（$K$ の各行）の次元で、コードでは `K.shape[1]` に相当する。softmax は行ごと（各クエリトークンごと、`axis=1`）に計算する。

In [None]:
# Scaled dot-product attention, one step at a time.
d_k = K.shape[1]  # key dimension used for the 1/sqrt(d_k) scaling
raw_scores = np.dot(Q, K.T)  # query-key dot products, one row per query token
scaled_attention_logits = raw_scores / np.sqrt(d_k)
# Normalise each row (one query) into a distribution over the keys.
attention_weights = softmax(scaled_attention_logits, axis=1)
attention_output = np.dot(attention_weights, V)
attention_output.shape

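## 計算結果の出力
$Q$ と $K$ が同じ行列なので、自分自身への重み（対角成分）が大きくなりやすく、他の重みが見えにくくなる。そのため対角成分を 0 にし、さらに特殊トークン `[CLS]`, `[SEP]` の行と列を除いてから表示する。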

In [None]:
# Work on a copy so attention_weights itself stays intact, then zero the
# diagonal: with Q == K a token's score against itself presumably dominates
# and would wash out the other weights in the plot.
weights = attention_weights.copy()
np.fill_diagonal(weights, 0)

# Label rows/columns by token and remove the special tokens from both axes.
special_tokens = ['[CLS]', '[SEP]']
result_df = pd.DataFrame(data=weights, index=tokens, columns=tokens).drop(
    index=special_tokens, columns=special_tokens
)
result_df

In [None]:
import matplotlib.font_manager as fm

# 利用可能なフォントのリストを取得
# Family names of every font matplotlib has registered. ttflist holds one
# entry per font file, so families with several weights repeat; de-duplicate
# and sort for a readable listing.
fonts = sorted({f.name for f in fm.fontManager.ttflist})
# Heuristic name match for Japanese-capable families (macOS, Windows, Linux).
# It may also catch unrelated families that merely contain e.g. 'Gothic'.
japanese_keywords = (
    'Gothic', 'Mincho', 'Hiragino', 'IPA', 'Yu', 'Meiryo',
    'Noto Sans CJK', 'Noto Sans JP',
)
japanese_fonts = [f for f in fonts if any(k in f for k in japanese_keywords)]
print(japanese_fonts)

In [None]:
import seaborn as sns
import matplotlib as mpl
import matplotlib.pyplot as plt
# Japanese-capable families in preference order (macOS, Windows, Linux).
# matplotlib resolves the list in order (newer versions also fall back per
# glyph); the generic 'sans-serif' at the end keeps Latin text rendering when
# none of the Japanese fonts are installed.
mpl.rcParams['font.family'] = [
    'Hiragino Sans',
    'Yu Gothic',
    'MS Gothic',
    'IPAexGothic',
    'Noto Sans CJK JP',
    'sans-serif',
]

# Heatmap of attention weights between content tokens: each row is a query
# token, each column the key token it attends to.
ax = sns.heatmap(result_df)
plt.show()