<a href="https://colab.research.google.com/github/trtd56/SpamExplainable/blob/main/spam_train_bert.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# 文書分類における古典的手法とBERTの判断根拠の比較

必要なライブラリのインストール

In [None]:
# Install/upgrade Hugging Face transformers (AutoModel / AutoTokenizer / BertConfig are imported from it below).
!pip install -U transformers
# Download and unpack the sst2_tiny model archive from the LIT resources bucket.
# NOTE(review): nothing in the visible notebook reads the extracted sst2_tiny files - confirm this download is still needed.
!wget https://storage.googleapis.com/what-if-tool-resources/lit-models/sst2_tiny.tar.gz
!tar -xvf sst2_tiny.tar.gz
# lit_nlp is imported in the "LIT(WIP)" section and lime in the "LIME" section.
!pip install lit_nlp tfds-nightly
!pip install lime

SMS Spam Collectionのデータのダウンロードと解凍

In [None]:
# Fetch the SMS Spam Collection archive from the UCI repository and extract it into the working
# directory; the extracted "SMSSpamCollection" file is read with pandas in the data-loading cell.
!wget https://archive.ics.uci.edu/ml/machine-learning-databases/00228/smsspamcollection.zip
!unzip smsspamcollection.zip

In [None]:
from sklearn.model_selection import train_test_split
import pandas as pd
import torch
import torch.nn as nn
from transformers import  AutoModel, AutoTokenizer, get_cosine_schedule_with_warmup, AdamW, BertConfig
from tqdm.notebook import tqdm
from IPython.display import display, HTML
import numpy as np
from matplotlib import pyplot as plt
from sklearn.metrics import confusion_matrix

# Prefer the GPU, but fall back to the CPU so the notebook still runs on machines without CUDA.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

## 学習

データの読み込みと分割

In [None]:
SEED = 0

# Load the tab-separated SMS Spam Collection file. It has no header row:
# column 0 holds the label string and column 1 the message text.
spam_df = pd.read_csv("SMSSpamCollection", sep='\t', header=None)

# Separate the label column from the message column.
labels, sentences = spam_df[0].values, spam_df[1].values

# Label string <-> integer id mappings; "spam" is the positive class (1).
label_dic = {'ham': 0, 'spam': 1}
label_dic_inv = dict((idx, name) for name, idx in label_dic.items())
label_ids = list(map(label_dic.__getitem__, labels))

# Stratified 70/30 train/test split, made reproducible through SEED.
train_sentence, test_sentence, y_train, y_test = train_test_split(
    sentences,
    label_ids,
    test_size=0.3,
    random_state=SEED,
    stratify=label_ids,
)

スパム学習用のデータセット

In [None]:
class SpamDataset(torch.utils.data.Dataset):
    """Map-style dataset pairing pre-tokenized messages with binary labels.

    Subclasses ``torch.utils.data.Dataset`` so it is a proper PyTorch dataset
    rather than relying on duck typing.

    Args:
        toks: Sequence of tokenizer outputs; each element must provide the keys
            ``"input_ids"``, ``"attention_mask"`` and ``"token_type_ids"``.
        targets: Sequence of numeric labels, one per element of ``toks``.

    Raises:
        ValueError: If ``toks`` and ``targets`` differ in length.
    """

    def __init__(self, toks, targets):
        if len(toks) != len(targets):
            raise ValueError(
                "toks and targets must have the same length "
                "({} != {})".format(len(toks), len(targets))
            )
        self.toks = toks
        self.targets = targets

    def __len__(self):
        """Return the number of examples."""
        return len(self.toks)

    def __getitem__(self, item):
        """Return example ``item`` as a dict of tensors.

        The label is returned as a float tensor under ``"target"`` because the
        training code feeds it to ``BCEWithLogitsLoss``.
        """
        tok = self.toks[item]

        return {
            "input_ids": torch.tensor(tok["input_ids"]),
            "attention_mask": torch.tensor(tok["attention_mask"]),
            "token_type_ids": torch.tensor(tok["token_type_ids"]),
            "target": torch.tensor(self.targets[item]).float(),
        }

スパム学習用のBERTクラス

In [None]:
class SpamBert(nn.Module):
    """BERT encoder with a single-logit linear head for ham/spam classification.

    Args:
        model_type: Name or path of the pretrained checkpoint to load.
        tokenizer: Tokenizer whose ``vocab_size`` is written into the model config.
    """

    def __init__(self, model_type, tokenizer):
        super().__init__()

        # Load the checkpoint's own configuration. BertConfig's first positional
        # argument is ``vocab_size``, so ``BertConfig(model_type, ...)`` stored the
        # checkpoint name there and left every other field at the library default.
        bert_conf = BertConfig.from_pretrained(
            model_type, output_hidden_states=False, output_attentions=True
        )
        bert_conf.vocab_size = tokenizer.vocab_size

        self.bert = AutoModel.from_pretrained(model_type, config=bert_conf)
        self.fc = nn.Linear(bert_conf.hidden_size, 1)

    def forward(self, ids, mask, token_type_ids):
        """Run the encoder and classification head.

        Args:
            ids: Batch of input token ids.
            mask: Attention mask for ``ids``.
            token_type_ids: Segment ids for ``ids``.

        Returns:
            A tuple ``(logits, attention)``. ``logits`` holds one raw score per
            example (no sigmoid applied). ``attention`` is taken from the last
            entry of the encoder's ``attentions`` output, summed over dim 1 and
            indexed at position 0 of the next dim.
        """
        out = self.bert(ids, attention_mask=mask, token_type_ids=token_type_ids)
        h = torch.relu(out['pooler_output'])
        # The head has one output unit; drop that axis to get one logit per example.
        h = self.fc(h)[:, 0]
        # assumes attentions[-1] is (batch, heads, query, key) as documented by
        # transformers, i.e. head-summed attention from the first token to every
        # token - TODO confirm against the installed version
        a = out['attentions'][-1].sum(1)[:, 0, :]
        return h, a

学習・評価用関数

In [None]:
# Binary cross-entropy computed directly on raw logits; shared by train_loop and test_loop.
loss_fn = nn.BCEWithLogitsLoss()

def train_loop(train_dataloader, model, optimizer, device, tqdm):
    """Train ``model`` for one pass over ``train_dataloader``.

    Args:
        train_dataloader: Iterable of batches; each batch is a dict with the keys
            "input_ids", "attention_mask", "token_type_ids" and "target".
        model: Module called as ``model(ids, mask, token_type_ids)`` and returning
            a pair whose first element is the logits.
        optimizer: Optimizer stepped once per batch.
        device: Device each batch tensor is moved to.
        tqdm: Progress-bar callable wrapping the batch iterator.

    Returns:
        List with the loss value (Python float) of every batch, in order.
    """
    batch_losses = []
    model.train()
    optimizer.zero_grad()

    progress = tqdm(enumerate(train_dataloader), total=len(train_dataloader))
    for _, batch in progress:
        ids = batch["input_ids"].to(device)
        mask = batch["attention_mask"].to(device)
        segments = batch["token_type_ids"].to(device)
        target = batch["target"].to(device)

        logits, _ = model(ids, mask, segments)
        batch_loss = loss_fn(logits, target)
        batch_loss.backward()

        # One optimizer update per batch, then reset gradients for the next one.
        optimizer.step()
        optimizer.zero_grad()

        batch_losses.append(batch_loss.item())
    return batch_losses

def test_loop(test_dataloader, model, device, tqdm):
    """Evaluate ``model`` on ``test_dataloader`` without updating it.

    Args:
        test_dataloader: Iterable of batches; each batch is a dict with the keys
            "input_ids", "attention_mask", "token_type_ids" and "target".
        model: Module called as ``model(ids, mask, token_type_ids)`` and returning
            a pair whose first element is the logits.
        device: Device each batch tensor is moved to.
        tqdm: Progress-bar callable wrapping the batch iterator.

    Returns:
        Tuple ``(predicts, mean_loss)``: the sigmoid of every logit as a flat
        Python list, and the mean of the per-batch losses.
    """
    model.eval()
    batch_losses = []
    probabilities = []

    progress = tqdm(enumerate(test_dataloader), total=len(test_dataloader))
    for _, batch in progress:
        ids = batch["input_ids"].to(device)
        mask = batch["attention_mask"].to(device)
        segments = batch["token_type_ids"].to(device)
        target = batch["target"].to(device)

        # No gradients are needed for evaluation.
        with torch.no_grad():
            logits, _ = model(ids, mask, segments)

        batch_losses.append(loss_fn(logits, target).item())
        probabilities.extend(logits.sigmoid().cpu().tolist())

    mean_loss = np.array(batch_losses).mean()
    return probabilities, mean_loss

パラメータとトークナイザの定義

In [None]:
# Pretrained checkpoint name, used for both the tokenizer and the model.
MODEL_TYPE = "bert-base-uncased"
# Learning rate passed to the optimizer in the training cell.
# NOTE(review): the identifier is misspelled ("LEAENING"); the training cell refers to it
# by this exact name, so both places must be renamed together.
LEAENING_RATE = 1e-6
# Mini-batch size shared by the train and test DataLoaders.
BATCH_SIZE = 64
# Number of passes over the training set.
N_EPOCHS = 10

# Tokenizer matching MODEL_TYPE; shared by every later cell that encodes text.
TOKENIZER = AutoTokenizer.from_pretrained(MODEL_TYPE)

学習データとテストデータをともにトークナイズし、dataloaderを定義する

In [None]:
def _encode_sentences(sentences, max_length=128):
    """Tokenize each sentence into fixed-length model inputs.

    Every sentence is padded to ``max_length`` and explicitly truncated to it, so
    all examples have the same length and can be stacked by the DataLoader. This
    replaces the deprecated ``pad_to_max_length=True`` and the implicit truncation
    that the library only applied with a warning.

    Args:
        sentences: Iterable of message strings.
        max_length: Number of tokens per encoded example, special tokens included.

    Returns:
        List of tokenizer outputs, one per sentence, in input order.
    """
    return [
        TOKENIZER(
            sent,
            add_special_tokens=True,
            max_length=max_length,
            padding="max_length",
            truncation=True,
        )
        for sent in sentences
    ]


train_toks = _encode_sentences(train_sentence)
test_toks = _encode_sentences(test_sentence)

train_dataset = SpamDataset(train_toks, y_train)
test_dataset = SpamDataset(test_toks, y_test)

# Training batches are shuffled and the incomplete last batch is dropped.
train_dataloader = torch.utils.data.DataLoader(
        train_dataset,
        batch_size=BATCH_SIZE,
        drop_last=True,
        shuffle=True,
)
# Evaluation keeps every example in its original order so that predictions line up with y_test.
test_dataloader = torch.utils.data.DataLoader(
        test_dataset,
        batch_size=BATCH_SIZE,
        drop_last=False,
        shuffle=False,
)

学習を行う

In [None]:
# Seed PyTorch so the classification-head initialisation and batch shuffling are reproducible.
torch.manual_seed(SEED)

model = SpamBert(MODEL_TYPE, TOKENIZER)
model.to(device)

# torch.optim.AdamW replaces the deprecated transformers.AdamW. eps and weight_decay are
# set explicitly to the deprecated class's defaults (PyTorch defaults are 1e-8 and 0.01).
optimizer = torch.optim.AdamW(model.parameters(), lr=LEAENING_RATE, eps=1e-6, weight_decay=0.0)

# train_losses collects one value per training batch; test_losses one value per epoch.
train_losses, test_losses = [], []
for epoch in range(N_EPOCHS):
    print(f"Epoch-{epoch}")
    train_losses += train_loop(train_dataloader, model, optimizer, device, tqdm)
    y_pred, test_loss = test_loop(test_dataloader, model, device, tqdm)

    test_losses.append(test_loss)

    # Show the confusion matrix for this epoch (probabilities thresholded at 0.5).
    _y_pred = (np.array(y_pred) > 0.5).astype(int)
    cm = confusion_matrix(y_test, _y_pred)
    cm_df = pd.DataFrame(cm,columns=['Predicted ham', 'Predicted spam'], index=['Actual ham', 'Actual spam'])
    display(cm_df)

学習結果の確認

In [None]:
# Training loss, one point per mini-batch, concatenated over all epochs.
plt.plot(train_losses)

In [None]:
# Mean test loss, one point per epoch.
plt.plot(test_losses)

## Attention

In [None]:
def highlight_r(word, attn):
  """Wrap ``word`` in an HTML span whose red background encodes ``attn``.

  Args:
      word: Text to display; HTML-escaped so characters such as ``<`` or ``&``
          in a token cannot break the generated markup.
      attn: Highlight strength. 0 gives a white background and 1 pure red.
          Values outside [0, 1] are clamped, because out-of-range values would
          otherwise produce an invalid colour string.

  Returns:
      The HTML snippet as a string.
  """
  import html  # function-scope import keeps this cell self-contained

  strength = min(max(float(attn), 0.0), 1.0)
  fade = int(255 * (1 - strength))
  html_color = '#%02X%02X%02X' % (255, fade, fade)
  return '<span style="background-color: {}">{}</span>'.format(html_color, html.escape(word, quote=False))

In [None]:
def show_bert_explaination(check_idx):
    """Display one test message with each token shaded by the model's attention.

    Args:
        check_idx: Index of the example in ``test_dataset``.

    Raises:
        IndexError: If ``check_idx`` is out of range. Previously the dataset was
            scanned from the start and an out-of-range index silently fell
            through to the last example.
    """
    d = test_dataset[check_idx]

    # Add a batch dimension of 1; the model expects batched inputs.
    input_ids = d["input_ids"].to(device).unsqueeze(0)
    attention_mask = d["attention_mask"].to(device).unsqueeze(0)
    token_type_ids = d["token_type_ids"].to(device).unsqueeze(0)

    # Make sure dropout is disabled regardless of what ran before this cell.
    model.eval()
    with torch.no_grad():
        output, attention = model(input_ids, attention_mask, token_type_ids)

    attention = attention.cpu()[0].numpy()
    attention_mask = attention_mask.cpu()[0].numpy()
    # Keep only unmasked positions, then drop the first and last of them
    # (the special tokens added by the tokenizer).
    attention = attention[attention_mask == 1][1:-1]

    # Apply the same selection to the ids so tokens and attention values line up.
    ids = input_ids.cpu()[0][attention_mask == 1][1:-1].tolist()
    tokens = TOKENIZER.convert_ids_to_tokens(ids)

    html_outputs = [highlight_r(word, attn) for word, attn in zip(tokens, attention)]

    display(HTML(' '.join(html_outputs)))

スパム

In [None]:
# Render the attention highlighting for the three test messages shown under the spam heading.
for spam_idx in (15, 27, 28):
    show_bert_explaination(spam_idx)

非スパム

In [None]:
# Render the attention highlighting for the three test messages shown under the non-spam heading.
for ham_idx in (0, 1, 2):
    show_bert_explaination(ham_idx)

## LIME

In [None]:
def predictor(texts, batch_size=64):
    """Classifier function for LIME: map raw texts to class probabilities.

    Args:
        texts: Iterable of strings to classify.
        batch_size: Number of texts sent through the model at once. Chunking keeps
            memory bounded when the explainer asks for many perturbed samples.

    Returns:
        Array of shape ``(len(texts), 2)`` whose columns are
        ``[1 - p, p]`` with ``p`` the sigmoid of the model's logit.
    """
    texts = list(texts)
    if not texts:
        return np.empty((0, 2), dtype=np.float32)

    # Make sure dropout is disabled regardless of what ran before this cell.
    model.eval()

    chunks = []
    for start in range(0, len(texts), batch_size):
        # Pad to the longest text in the chunk; truncate to the length used for training.
        tok = TOKENIZER(
            texts[start:start + batch_size],
            padding=True,
            truncation=True,
            max_length=128,
        )
        input_ids = torch.tensor(tok['input_ids']).to(device)
        attention_mask = torch.tensor(tok['attention_mask']).to(device)
        token_type_ids = torch.tensor(tok['token_type_ids']).to(device)

        with torch.no_grad():
            output, _ = model(input_ids, attention_mask, token_type_ids)

        chunks.append(output.sigmoid().cpu().numpy())

    probas = np.concatenate(chunks)
    return np.vstack([1 - probas, probas]).T

In [None]:
from lime.lime_text import LimeTextExplainer
# Class names are listed in label-id order (ham=0, spam=1), which is also the column
# order of the array returned by predictor.
explainer = LimeTextExplainer(class_names=['ham', 'spam'])

In [None]:
# Explain the prediction for test message 0 (shown under the non-spam heading in the
# Attention section) with LIME, using num_samples=100 and num_features=20.
str_to_predict = test_sentence[0]
exp = explainer.explain_instance(str_to_predict, predictor, num_features=20, num_samples=100)
exp.show_in_notebook(text=str_to_predict)

In [None]:
# Explain the prediction for test message 27 (shown under the spam heading in the
# Attention section) with LIME, using num_samples=100 and num_features=20.
str_to_predict = test_sentence[27]
exp = explainer.explain_instance(str_to_predict, predictor, num_features=20, num_samples=100)
exp.show_in_notebook(text=str_to_predict)

## Grad-CAM

In [None]:
class GradCAM:
    """Capture a layer's output and the gradient flowing into that output.

    Hooks are registered on ``feature_layer`` at construction time. After a
    forward pass ``feature_map`` holds the layer output, and after a backward
    pass ``feature_grad`` holds the gradient with respect to that output.

    Args:
        model: Module called as ``model(input_ids, attention_mask, token_type_ids)``.
            It is switched to eval mode.
        feature_layer: Sub-module of ``model`` to observe.
    """

    def __init__(self, model, feature_layer):
        self.model = model
        self.feature_layer = feature_layer
        self.model.eval()
        self.feature_grad = None
        self.feature_map = None
        self.hooks = []

        def save_feature_grad(module, in_grad, out_grad):
            # out_grad is a tuple with one entry per layer output; keep the first.
            self.feature_grad = out_grad[0]

        # register_backward_hook is deprecated for modules made of several autograd
        # nodes; use the full backward hook when this PyTorch version provides it.
        register_backward = getattr(self.feature_layer, "register_full_backward_hook", None)
        if register_backward is None:
            register_backward = self.feature_layer.register_backward_hook
        self.hooks.append(register_backward(save_feature_grad))

        def save_feature_map(module, inp, outp):
            # A layer may return either a tuple (first element kept) or a bare tensor.
            self.feature_map = outp[0] if isinstance(outp, (tuple, list)) else outp
        self.hooks.append(self.feature_layer.register_forward_hook(save_feature_map))

    def forward(self, input_ids, attention_mask, token_type_ids):
        """Run the wrapped model and return its output unchanged."""
        return self.model(input_ids, attention_mask, token_type_ids)

    def backward_on_target(self, output, target):
        """Back-propagate from ``output`` using ``target`` as the upstream gradient.

        Parameter gradients are zeroed first; the graph is retained so backward
        can be called again on the same output.
        """
        self.model.zero_grad()
        output.backward(gradient=target, retain_graph=True)

    def clear_hook(self):
        """Remove every registered hook and forget the handles."""
        for hook in self.hooks:
            hook.remove()
        self.hooks.clear()

In [None]:
# Attach the Grad-CAM hooks to the last encoder layer of the trained model.
grad_cam = GradCAM(model, model.bert.encoder.layer[-1])

In [None]:
# Fetch test example 27 by direct indexing rather than scanning the dataset from the start.
d = test_dataset[27]
# Add a batch dimension of 1; the model expects batched inputs.
input_ids = d["input_ids"].to(device).unsqueeze(0)
attention_mask = d["attention_mask"].to(device).unsqueeze(0)
token_type_ids = d["token_type_ids"].to(device).unsqueeze(0)

In [None]:
# Forward pass through the hooked model, then back-propagate using the sigmoid of the
# logit as the upstream gradient, so the hooks record the layer output and its gradient.
model_output, _ = grad_cam.forward(input_ids, attention_mask, token_type_ids)
predicted_label = model_output.sigmoid()
grad_cam.backward_on_target(model_output, predicted_label)

# Gradient for the single example in the batch, averaged over axis 1
# (presumably the hidden dimension, giving one weight per token - TODO confirm).
feature_grad = grad_cam.feature_grad.cpu().data.numpy()[0]
weights = np.mean(feature_grad, axis=1)

feature_map = grad_cam.feature_map.cpu().data.numpy()

# NOTE(review): if feature_map is (1, seq, hidden), feature_map.T is (hidden, seq, 1) and
# multiplying by the 1-D `weights` broadcasts over the LAST axis. After the sum over axis 2
# every feature value is then scaled by sum(weights) rather than by its own token's weight.
# Confirm whether per-token (or per-channel, as in standard Grad-CAM) weighting was intended.
cam = np.sum((weights * feature_map.T), axis=2).T
# ReLU: keep only non-negative values.
cam = np.maximum(cam, 0)
# Remove the hooks; grad_cam records nothing after this point.
grad_cam.clear_hook()

In [None]:
# Convert the mask to a 1-D NumPy array. The guard makes the cell safe to re-run,
# since the name is rebound from a tensor to an array here.
if torch.is_tensor(attention_mask):
    attention_mask = attention_mask.cpu()[0].numpy()
ids = input_ids.cpu()[0][attention_mask == 1][1:-1].tolist()
tokens = TOKENIZER.convert_ids_to_tokens(ids)

# Apply the same mask and slice to the per-position CAM scores so each score lines up
# with its token. Zipping the unsliced scores paired every token with the score of the
# position before it (the first token received the score of the leading special token).
token_cam = cam.sum(1)[attention_mask == 1][1:-1]

html_outputs = [highlight_r(word, attn) for word, attn in zip(tokens, token_cam)]

display(HTML(' '.join(html_outputs)))

In [None]:
# Per-token CAM score: `cam` summed over axis 1, restricted to unmasked positions with the
# first and last of them dropped - the same selection used to build `tokens`.
cam_w = cam.sum(1)[attention_mask == 1][1:-1].tolist()
# Bar positions, one per token.
x = np.arange(len(tokens))

In [None]:
from sklearn.preprocessing import MinMaxScaler

In [None]:
# Min-max scale the CAM scores (reshaped to a single column) so they can drive bar colours.
scaler = MinMaxScaler()
color_arr = scaler.fit_transform(np.array(cam_w).reshape(-1, 1))

In [None]:
# One RGB triple per token: red and blue stay at 1.0 while green decreases as the scaled
# score grows (capped at 1.0), so higher scores move from white towards magenta.
colorlist = [[1.0, 1-min([c, 1.0]), 1.0] for c in color_arr.T[0]]

In [None]:
# Bar width in x-axis units.
width = 0.35

fig, ax = plt.subplots(figsize=(24,4))

# One bar per token: height is the raw CAM score, colour comes from the scaled score.
rect = ax.bar(x, cam_w, width, color=colorlist)
ax.set_xticks(x)
ax.set_xticklabels(tokens, rotation=45)
# NOTE(review): hard-coded y-limits, presumably tuned to the score range of this one
# example; bars outside the range are cut off - confirm or derive from cam_w.
plt.ylim(0.027, 0.035)

plt.show()

## LIT(WIP)

In [None]:
from lit_nlp.api import dataset as lit_dataset
from lit_nlp.api import types as lit_types

class SpamDataset_lit(lit_dataset.Dataset):
  """LIT dataset exposing the test split as sentence/label records.

  Labels are stored as the strings in ``LABELS``, selected by the integer
  label id of each test example.
  """

  LABELS = ['0', '1']

  def __init__(self):
    # One record per (sentence, label id) pair of the test split.
    self._examples = [
        {'sentence': text, 'label': self.LABELS[label_id]}
        for text, label_id in zip(test_sentence, y_test)
    ]

  def spec(self):
    """Return the LIT type of each field in a record."""
    field_types = {}
    field_types['sentence'] = lit_types.TextSegment()
    field_types['label'] = lit_types.CategoryLabel(vocab=self.LABELS)
    return field_types

In [None]:
# Load the checkpoint's configuration. BertConfig's first positional argument is
# vocab_size, so BertConfig(MODEL_TYPE, ...) stored the checkpoint name there and left
# every other field at the library default.
BERT_CONF = BertConfig.from_pretrained(MODEL_TYPE, output_hidden_states=False, output_attentions=True)
BERT_CONF.vocab_size = TOKENIZER.vocab_size

In [None]:
# Attach the config to the SpamBert wrapper, which does not set a `config` attribute itself.
model.config = BERT_CONF

In [None]:
from lit_nlp.api import model as lit_model
class SpamBert_lit(lit_model.Model):
    """LIT wrapper around the fine-tuned ``SpamBert`` model (work in progress).

    For each input it emits the tokens, the raw logit ("score"), the embedding at
    position 0 of the last hidden state and the attention of the last two encoder
    layers. Set ``compute_grads = True`` to also emit gradients of the score with
    respect to the input embeddings.
    """

    compute_grads = False
    # Sequence length used when encoding inputs (the same value as the training cells).
    MAX_LENGTH = 128
    # Number of trailing encoder layers whose attention is exported.
    N_ATTENTION_LAYERS = 2

    def __init__(self):
        self.model = model
        self.config = BERT_CONF
        # Previously never assigned although predict_minibatch read self.tokenizer.
        self.tokenizer = TOKENIZER
        self.model.eval()

    def max_minibatch_size(self):
        """Return the largest batch passed to ``predict_minibatch``."""
        return 8

    def predict_minibatch(self, inputs):
        """Run the model on a batch of examples and yield one output dict each.

        Args:
            inputs: List of dicts with a ``'sentence'`` key.

        Yields:
            Dict with "tokens", "score", "cls_emb", optionally
            "layer_{i}/attention" and, when ``compute_grads`` is set,
            "token_grad_sentence".
        """
        # Function-scope import: `utils` was referenced but never imported.
        from lit_nlp.lib import utils as lit_utils

        # Run on whichever device the model already lives on.
        model_device = next(self.model.parameters()).device
        encoded_input = self.tokenizer(
            [sent['sentence'] for sent in inputs],
            add_special_tokens=True,
            max_length=self.MAX_LENGTH,
            padding="max_length",
            truncation=True,
            return_tensors="pt",
        )
        encoded_input = {key: value.to(model_device) for key, value in encoded_input.items()}

        with torch.set_grad_enabled(self.compute_grads):
            # A single encoder pass provides everything needed. The previous code called the
            # wrapper without token_type_ids (a TypeError) and then ran the encoder again.
            out = self.model.bert(
                encoded_input['input_ids'],
                attention_mask=encoded_input['attention_mask'],
                token_type_ids=encoded_input['token_type_ids'],
                output_attentions=self.config.output_attentions,
                output_hidden_states=self.compute_grads,
            )
            # Same head as SpamBert.forward: ReLU on the pooled output, linear layer, then
            # drop the unit axis. Indexing keeps shape (batch,) even for a batch of one,
            # where squeeze() produced a scalar.
            logits = self.model.fc(torch.relu(out['pooler_output']))[:, 0]

            batched_outputs = {
                "input_ids": encoded_input["input_ids"],
                "ntok": torch.sum(encoded_input["attention_mask"], dim=1),
                "cls_emb": out['last_hidden_state'][:, 0],
                "score": logits,
            }

            attention_keys = []
            if self.config.output_attentions:
                output_attentions = out['attentions']
                assert len(output_attentions) == self.config.num_hidden_layers
                for i, layer_attention in enumerate(output_attentions[-self.N_ATTENTION_LAYERS:]):
                    key = f"layer_{i}/attention"
                    batched_outputs[key] = layer_attention
                    attention_keys.append(key)

            if self.compute_grads:
                # hidden_states[0] is the first entry of the encoder's hidden states from the
                # same forward pass that produced `logits`.
                batched_outputs["input_emb_grad"] = torch.autograd.grad(
                    logits,
                    out['hidden_states'][0],
                    grad_outputs=torch.ones_like(logits)
                )[0]

        # detach() is required when gradients were enabled.
        detached_outputs = {k: v.detach().cpu().numpy() for k, v in batched_outputs.items()}
        for output in lit_utils.unbatch_preds(detached_outputs):
            ntok = output.pop("ntok")
            # Keep all ntok unpadded positions (special tokens included) so the token list
            # has the same length as the attention matrices and gradients sliced below.
            output["tokens"] = self.tokenizer.convert_ids_to_tokens(
                output.pop("input_ids")[:ntok].tolist()
            )
            if self.compute_grads:
                output["token_grad_sentence"] = output.pop("input_emb_grad")[:ntok]
            for key in attention_keys:
                # Crop the padding, swap the last two axes, and copy so the full padded
                # array is not kept alive by the view.
                output[key] = output[key][:, :ntok, :ntok].transpose((0, 2, 1)).copy()
            yield output

    def input_spec(self) -> lit_types.Spec:
        """Describe the fields expected in each input example."""
        return {
            "sentence": lit_types.TextSegment(),
            'label': lit_types.CategoryLabel(vocab=['0', '1']),
        }

    def output_spec(self) -> lit_types.Spec:
        """Describe the fields yielded by ``predict_minibatch``."""
        ret = {
            "tokens": lit_types.Tokens(),
            "score": lit_types.RegressionScore(parent="label"),
            "cls_emb": lit_types.Embeddings()
        }
        if self.compute_grads:
            ret["token_grad_sentence"] = lit_types.TokenGradients(
                align="tokens"
            )
        if self.config.output_attentions:
            for i in range(self.N_ATTENTION_LAYERS):
                ret[f"layer_{i}/attention"] = lit_types.AttentionHeads(
                    align_in="tokens", align_out="tokens")
        return ret

In [None]:
# Datasets and models handed to the LIT widget, keyed by the names used to register them.
datasets = {
    'test':SpamDataset_lit(),
}
models = {
    'model_0': SpamBert_lit(),
}


from lit_nlp import notebook
# Build the in-notebook LIT widget (height=800).
widget = notebook.LitWidget(models, datasets, height=800)

In [None]:
# Display the LIT widget in the cell output.
widget.render()