In [1]:
%pip install attrdict

Note: you may need to restart the kernel to use updated packages.


In [2]:
%pip install torch

Note: you may need to restart the kernel to use updated packages.


In [3]:
%pip install torchtext

Note: you may need to restart the kernel to use updated packages.


In [4]:
import random
import time
import numpy as np
from tqdm import tqdm
import torch
from torch import nn
import torch.optim as optim
import torchtext

In [5]:
import re
import string
from bert import BertTokenizer

#テキストのクリーニング(前処理)
def preprocessing_text(text):
  #改行コードを消去
  text = re.sub('<br />', '', text)

  #カンマ、ピリオド以外の記号をスペースに変換
  for p in string.punctuation:
    if (p==".") or (p==","):
        continue
    else:
        text = text.replace(p, '')
  #記号前後にスペースを挿入
    text = text.replace(".", " . ")
    text = text.replace(",", " , ")
    
    return text

tokenizer_bert = BertTokenizer(
    vocab_file="./vocab/bert-base-uncased-vocab.txt", do_lower_case=True
)

#関数: 前処理+単語分割
def tokenizer_with_preprocessing(text, tokenizer=tokenizer_bert.tokenize):
    text = preprocessing_text(text)
    ret = tokenizer(text)
    
    return ret

In [6]:
# データセットへの処理
max_length = 256

TEXT = torchtext.data.Field(sequential=True, tokenize=tokenizer_with_preprocessing,
                            use_vocab=True, lower=True, include_lengths=True, batch_first=True,
                            fix_length=max_length, init_token="[CLS]", eos_token="[SEP]", pad_token="[PAD]", unk_token="[UNK]")

LABEL = torchtext.data.Field(sequential=False, use_vocab=False)

In [7]:
# データ分割
train_val_ds, test_ds = torchtext.data.TabularDataset.splits(
    path='./data/', train='IMDb_train.tsv', test='IMDb_test.tsv', format='tsv',
    fields=[('Text', TEXT), ('Label', LABEL)]
)

train_ds, val_ds = train_val_ds.split(
    split_ratio=0.8, random_state=random.seed(1234)
)

In [8]:
#単語辞書を辞書型変数に
from bert import BertTokenizer, load_vocab

vocab_bert, ids_to_tokens_bert = load_vocab(
    vocab_file="./vocab/bert-base-uncased-vocab.txt"
)

TEXT.build_vocab(train_ds, min_freq=1)
TEXT.vocab.stoi = vocab_bert

In [9]:
# DataLoader
batch_size = 32

train_dl = torchtext.data.Iterator(
    train_ds, batch_size=batch_size, train=True
)
val_dl = torchtext.data.Iterator(
    val_ds, batch_size=batch_size, train=False, sort=False
)
test_dl = torchtext.data.Iterator(
    test_ds, batch_size=batch_size, train=False, sort=False
)

dataloaders_dict = {"train": train_dl, "val": val_dl}

In [11]:
from bert import get_config, BertModel, set_learned_params

config = get_config(file_path="./weights/bert_config.json")

#BERTモデル
net_bert = BertModel(config)

#パラメーターセット
net_bert = set_learned_params(net_bert, weights_path="./weights/pytorch_model.bin")

bert.embeddings.word_embeddings.weight→embeddings.word_embeddings.weight
bert.embeddings.position_embeddings.weight→embeddings.position_embeddings.weight
bert.embeddings.token_type_embeddings.weight→embeddings.token_type_embeddings.weight
bert.embeddings.LayerNorm.gamma→embeddings.LayerNorm.gamma
bert.embeddings.LayerNorm.beta→embeddings.LayerNorm.beta
bert.encoder.layer.0.attention.self.query.weight→encoder.layer.0.attention.selfattn.query.weight
bert.encoder.layer.0.attention.self.query.bias→encoder.layer.0.attention.selfattn.query.bias
bert.encoder.layer.0.attention.self.key.weight→encoder.layer.0.attention.selfattn.key.weight
bert.encoder.layer.0.attention.self.key.bias→encoder.layer.0.attention.selfattn.key.bias
bert.encoder.layer.0.attention.self.value.weight→encoder.layer.0.attention.selfattn.value.weight
bert.encoder.layer.0.attention.self.value.bias→encoder.layer.0.attention.selfattn.value.bias
bert.encoder.layer.0.attention.output.dense.weight→encoder.layer.0.attention.output

In [12]:
class BertForIMDb(nn.Module):
    
    def __init__(self, net_bert):
        super(BertForIMDb, self).__init__()

        self.bert = net_bert

        self.cls = nn.Linear(in_features=768, out_features=2)

        # 重みの初期化
        nn.init.normal_(self.cls.weight, std=0.02)
        nn.init.normal_(self.cls.bias, 0)

    def forward(self, input_ids, token_type_ids=None, attention_mask=None, output_all_encoded_layers=False,
              attention_show_flg=False):

        #BERTの基本モデル部分の順伝播
        if attention_show_flg == True:
                '''attention_showのときは、attention_probsもリターンする'''
                encoded_layers, pooled_output, attention_probs = self.bert(
                    input_ids, token_type_ids, attention_mask, output_all_encoded_layers, attention_show_flg)
        elif attention_show_flg == False:
            encoded_layers, pooled_output = self.bert(
                input_ids, token_type_ids, attention_mask, output_all_encoded_layers, attention_show_flg)

        #入力単語の先頭("CLS")の特徴量から予測
        vec_0 = encoded_layers[:, 0, :]
        vec_0 = vec_0.view(-1, 768)
        out = self.cls(vec_0)

        if attention_show_flg == True:
            return out, attention_probs
        elif attention_show_flg == False:
            return out

In [13]:
net = BertForIMDb(net_bert)

net.train()

print('ネットワーク設定 - 終了')

ネットワーク設定 - 終了


In [14]:
# 最適化手法
optimizer = optim.Adam([
                        {'params':net.bert.encoder.layer[-1].parameters(), 'lr':5e-5},
                        {'params':net.cls.parameters(), 'lr': 5e-5}
                        ], betas=(0.9, 0.999))

# 損失関数
criterion = nn.CrossEntropyLoss()

In [15]:
#BERTモデル
net_bert = BertModel(config)
print("-"*20 + "BERTモデル 設定完了" + "-"*20)

#感情分析用　BERT
net_bert = BertForIMDb(net_bert)
print("-"*20 + "感情分析モデル 設定完了" + "-"*20)

#パラメーターセット
net = set_learned_params(net_bert, weights_path="./weights/bert_fine_tuning_IMDb.pth")
print("-"*20 + "fine tuning 設定完了" + "-"*20)

--------------------BERTモデル 設定完了--------------------
--------------------感情分析モデル 設定完了--------------------


RuntimeError: Attempting to deserialize object on a CUDA device but torch.cuda.is_available() is False. If you are running on a CPU-only machine, please use torch.load with map_location=torch.device('cpu') to map your storages to the CPU.

In [None]:
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

batch = next(iter(test_dl))

# GPUが使えるならGPUにデータを送る
inputs = batch.Text[0].to(device)  # 文章
labels = batch.Label.to(device)  # ラベル

outputs, attention_probs = net(inputs,
                               token_type_ids=None,
                               attention_mask=None,
                               output_all_encoded_layers=False,
                               attention_show_flg=True)

_, preds = torch.max(outputs, 1)

In [None]:
#HTMLを作成する関数

def highlight(word, attn):
    html_color = "#%02X%02X%02X" % (255, int(255*(1-attn)), int(255*(1-attn)))

    return '<span style="background-color: {}> {}</span>'.format(html_color, word)

def mk_html(index, batch, preds, normlized_weights, TEXT):

    #indexの結果を抽出
    sentence = batch.Text[0][index]
    label = batch.Label[index]
    pred = preds[index]

    #正解ラベルと予測ラベル
    if label == 0:
        label_str = "Negative"
    else:
        label_str = "Positive"

    if pred == 0:
        pred_str = "Negative"
    else:
        pred_str = "Postive"

    html = '正解ラベル:{}<br>推論ラベル:{}<br><br>'.format(label_str, pred_str)

    for i in range(12):
        attens = normllized_weights[index, i, 0, :]
        attens /= attens.max()

        html += '[BERTのAttentionを可視化_' + str(i+1) * ']<br>'

        for word, attn in zip(sentence, attens):
          #単語が[SEP](文章の終わり)の場合はbreak
          if tokenizer_bert.convert_ids_to_tokens([word.numpy().tolist()])[0] == '[SEP]':
            break

          #highlight
          html += highlight(tokenizer_bert.convert_ids_to_tokens(
              [word.numpy().tolist()])[0], attn
          )
        html += "<br><br>"


    all_attens = attens*0
    for i in range(12):
        attens += normlized_weights[index, i, 0, :]
    attens /= attens.max()

    html += '[BERTのAttentionを可視化_ALL]<br>'
    for word, attn in zip(sentence, attens):

        #単語が[SEP](文章の終わり)の場合はbreak
        if tokenizer_bert.convert_ids_to_tokens([word.numpy().tolist()])[0] == '[SEP]':
            break

        html += highlight(tokenizer_bert.convert_ids_to_tokens([word.numpy().tolist()])[0], attn)
    html += "<br><br>"

    return html