<a href="https://colab.research.google.com/github/re-study/re-study/blob/yanagi/colab/7_7_mycode.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# はじめに
ここではTransformerの学習・推論を行い、Attentionを可視化して判定根拠を確認する。

In [None]:
pip install torchtext==0.8.1



In [None]:
# Mount Google Drive at /content/drive so project files stored on Drive are reachable.
from google.colab import drive

drive.mount(mountpoint='/content/drive')

Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).


In [None]:
# Change into the project directory that contains the `utils` package, which the
# `from utils.dataloader ...` / `from utils.transformer ...` cells below import.
# Replace the placeholder path with the real one and uncomment before running.
#%cd drive/MyDrive/hogehoge…

In [None]:
# パッケージのimport
import numpy as np
import random

import torch
import torch.nn as nn
import torch.optim as optim

import torchtext

In [None]:
# Seed every RNG in use (PyTorch, NumPy, Python's `random`) so runs are repeatable.
SEED = 1234
for seed_fn in (torch.manual_seed, np.random.seed, random.seed):
    seed_fn(SEED)

# DatasetとDataLoaderを作成

In [None]:
from utils.dataloader import get_IMDb_DataLoaders_and_TEXT

# Load the IMDb train/val/test DataLoaders and the TEXT field via the project helper.
# NOTE(review): max_length presumably truncates/pads each review to 256 tokens — confirm in utils/dataloader.
loader_kwargs = dict(max_length=256, batch_size=64)
train_dl, val_dl, test_dl, TEXT = get_IMDb_DataLoaders_and_TEXT(**loader_kwargs)

# Map each phase name to its DataLoader (consumed by train_model below)
dataloaders_dict = dict(train=train_dl, val=val_dl)



# ネットワークモデルの作成

In [None]:
from utils.transformer import TransformerClassification

# Build the Transformer classifier, passing the vocabulary's word vectors in as the
# embedding table; output_dim=2 -> two classes (Negative / Positive).
# NOTE(review): max_seq_len presumably has to match the DataLoader max_length (256).
net = TransformerClassification(
    text_embedding_vectors=TEXT.vocab.vectors,
    d_model=300,
    max_seq_len=256,
    output_dim=2,
)

# Define the network initialisation
def weights_init(m):
    """Kaiming-normal init for ``nn.Linear`` weights, zero init for their biases.

    Intended to be passed to ``nn.Module.apply``, which calls it on every submodule.

    Args:
        m: a module visited by ``apply``. Modules that are not ``nn.Linear``
            (including subclasses-by-name-only such as a custom ``LinearBlock``)
            are left untouched.
    """
    # isinstance rather than searching "Linear" in the class name: a name match also
    # hits unrelated modules that may have no `.weight`, which then raises AttributeError.
    if isinstance(m, nn.Linear):
        # https://pytorch.org/docs/stable/nn.init.html
        nn.init.kaiming_normal_(m.weight)
        if m.bias is not None:
            nn.init.constant_(m.bias, 0.0)

# Put the whole model in training mode
net.train()

# Re-initialise the nn.Linear layers inside the sub-networks net3_1 and net3_2
# (presumably the two TransformerBlocks — confirm in utils/transformer)
for subnet in (net.net3_1, net.net3_2):
    subnet.apply(weights_init)

print('設定完了')

設定完了


# 損失関数と最適化手法を定義

In [None]:
# Cross-entropy loss on the model's class scores
criterion = nn.CrossEntropyLoss()

# Adam over all model parameters with a small learning rate
learning_rate = 2e-5
optimizer = optim.Adam(params=net.parameters(), lr=learning_rate)

# 学習を実施

In [None]:
def train_model(net, dataloaders_dict, criterion, optimizer, num_epochs, input_pad=1):
    """Train ``net`` and evaluate it on the validation set once per epoch.

    Each epoch runs a 'train' phase (forward, backward, optimizer step) and then a
    'val' phase (forward only), printing the mean loss and accuracy of each phase.

    Args:
        net: model called as ``net(inputs, input_mask)`` that returns a 3-tuple
            whose first element is the class scores of shape (batch, n_classes).
        dataloaders_dict: dict with 'train' and 'val' iterables of batches exposing
            ``batch.Text[0]`` (token ids) and ``batch.Label``, each with a
            ``.dataset`` whose length is the number of examples.
        criterion: loss applied as ``criterion(outputs, labels)``.
        optimizer: optimizer over ``net``'s parameters.
        num_epochs: number of epochs to run.
        input_pad: token id treated as padding; positions equal to it are False in
            the mask passed to ``net`` (default 1, the '<pad>' id in this notebook).

    Returns:
        The same ``net`` object, moved to the selected device.
    """
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    print("使用デバイス:", device)
    print("----------start----------")
    net.to(device)

    # Let cuDNN benchmark and pick the fastest kernels; pays off when shapes stay fixed
    torch.backends.cudnn.benchmark = True

    for epoch in range(num_epochs):
        for phase in ['train', 'val']:
            is_train = phase == 'train'
            net.train(is_train)  # train mode for 'train', eval mode for 'val'

            epoch_loss = 0.0  # sum of per-example losses over the phase
            epoch_corrects = 0  # number of correct predictions (plain Python int)

            for batch in dataloaders_dict[phase]:
                inputs = batch.Text[0].to(device)
                labels = batch.Label.to(device)

                # True for real tokens, False for padding
                input_mask = (inputs != input_pad)

                # Gradients are only needed (and only reset) when training
                if is_train:
                    optimizer.zero_grad()

                with torch.set_grad_enabled(is_train):
                    outputs, _, _ = net(inputs, input_mask)
                    loss = criterion(outputs, labels)
                    _, preds = torch.max(outputs, 1)  # predicted class per example

                    if is_train:
                        loss.backward()
                        optimizer.step()

                # criterion returns the batch mean, so re-weight by batch size
                epoch_loss += loss.item() * inputs.size(0)
                epoch_corrects += (preds == labels).sum().item()

            n_examples = len(dataloaders_dict[phase].dataset)
            epoch_loss = epoch_loss / n_examples
            epoch_acc = epoch_corrects / n_examples

            print('Epoch {}/{} ({:^5}) ----------------> Loss: {:.4f} Acc: {:.4f}'.format(epoch+1, num_epochs, phase, epoch_loss, epoch_acc))

    return net


In [None]:
# Run training + validation (takes about 15 minutes)
num_epochs = 10
net_trained = train_model(
    net=net,
    dataloaders_dict=dataloaders_dict,
    criterion=criterion,
    optimizer=optimizer,
    num_epochs=num_epochs,
)

使用デバイス: cuda:0
----------start----------




Epoch 1/10 (train) ----------------> Loss: 0.4034 Acc: 0.8213
Epoch 1/10 ( val ) ----------------> Loss: 0.3971 Acc: 0.8294
Epoch 2/10 (train) ----------------> Loss: 0.3818 Acc: 0.8311
Epoch 2/10 ( val ) ----------------> Loss: 0.3811 Acc: 0.8360
Epoch 3/10 (train) ----------------> Loss: 0.3710 Acc: 0.8375
Epoch 3/10 ( val ) ----------------> Loss: 0.3689 Acc: 0.8392
Epoch 4/10 (train) ----------------> Loss: 0.3587 Acc: 0.8442
Epoch 4/10 ( val ) ----------------> Loss: 0.3763 Acc: 0.8432
Epoch 5/10 (train) ----------------> Loss: 0.3485 Acc: 0.8495
Epoch 5/10 ( val ) ----------------> Loss: 0.3598 Acc: 0.8466
Epoch 6/10 (train) ----------------> Loss: 0.3414 Acc: 0.8528
Epoch 6/10 ( val ) ----------------> Loss: 0.3556 Acc: 0.8516
Epoch 7/10 (train) ----------------> Loss: 0.3373 Acc: 0.8544
Epoch 7/10 ( val ) ----------------> Loss: 0.3531 Acc: 0.8474
Epoch 8/10 (train) ----------------> Loss: 0.3299 Acc: 0.8583
Epoch 8/10 ( val ) ----------------> Loss: 0.3552 Acc: 0.8506
Epoch 9/

# テストデータでの正解率を求める

In [None]:
# Evaluate the trained model on the test set
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

net_trained.eval()
net_trained.to(device)

input_pad = 1  # token id of '<pad>'; those positions are False in the mask
epoch_corrects = 0  # running count of correct predictions

# Inference only: no autograd graph is needed
with torch.no_grad():
    for batch in test_dl:
        inputs = batch.Text[0].to(device)
        labels = batch.Label.to(device)
        input_mask = (inputs != input_pad)

        outputs, _, _ = net_trained(inputs, input_mask)
        _, preds = torch.max(outputs, 1)  # predicted class per example
        epoch_corrects += torch.sum(preds == labels.data)

# Accuracy over the whole test set
epoch_acc = epoch_corrects.double() / len(test_dl.dataset)

print('テストデータ{}個での正解率：{:.4f}'.format(len(test_dl.dataset),epoch_acc))



テストデータ25000個での正解率：0.8542


# Attentionの可視化で判定根拠を探る

In [None]:
# Helpers that build the attention-visualisation HTML

def highlight(word, attn):
    """Return an HTML <span> showing ``word`` on a red background that deepens with ``attn``.

    Args:
        word: token text to display. It is HTML-escaped so tokens such as ``<br``
            or ``&`` render as text instead of breaking the surrounding markup.
        attn: attention weight expected in [0, 1] (float or 0-dim tensor). Values
            outside that range are clamped so the colour code stays valid.

    Returns:
        ``'<span style="background-color: #RRGGBB"> word</span>'``; attn=0 gives
        white (#FFFFFF) and attn=1 gives pure red (#FF0000).
    """
    from html import escape

    attn = min(max(float(attn), 0.0), 1.0)
    # Lower green and blue together: the larger attn, the redder the background.
    green_blue = int(255 * (1 - attn))
    html_color = '#%02X%02X%02X' % (255, green_blue, green_blue)
    return '<span style="background-color: {}"> {}</span>'.format(html_color, escape(str(word)))

def mk_html(index, batch, preds, normalized_weights_1, normalized_weights_2, TEXT):
    """Build HTML that visualises the attention of both TransformerBlocks for one example.

    Args:
        index: position of the example inside the mini-batch.
        batch: mini-batch exposing ``batch.Text[0]`` (token ids) and ``batch.Label``.
        preds: predicted class ids for the mini-batch (0 -> "Negative", otherwise "Positive").
        normalized_weights_1: attention weights of the 1st TransformerBlock, indexed as
            ``[index, 0, :]``. NOTE(review): row 0 is presumably the attention from the
            first token, which the classifier reads — confirm in utils/transformer.
        normalized_weights_2: the same for the 2nd TransformerBlock.
        TEXT: field whose ``vocab.itos`` maps token ids back to words.

    Returns:
        HTML string with the true and predicted labels, followed by every token of the
        example highlighted by its attention weight (scaled so the row maximum is 1).
        The weight tensors passed in are not modified.
    """
    def to_label(value):
        return "Negative" if value == 0 else "Positive"

    sentence = batch.Text[0][index]
    label_str = to_label(batch.Label[index])
    pred_str = to_label(preds[index])

    html = '正解ラベル：{}<br>推論ラベル：{}<br><br>'.format(label_str, pred_str)

    blocks = (
        ('[TransformerBlockの1段目のAttentionを可視化]<br>', normalized_weights_1),
        ('[TransformerBlockの2段目のAttentionを可視化]<br>', normalized_weights_2),
    )
    for title, weights in blocks:
        # Out-of-place division: `/=` on this indexing view would silently
        # overwrite the caller's attention tensor.
        attens = weights[index, 0, :]
        attens = attens / attens.max()

        html += title
        for word, attn in zip(sentence, attens):
            html += highlight(TEXT.vocab.itos[word], attn)
        html += "<br><br>"

    return html


In [None]:
from IPython.display import HTML

# Run the trained Transformer on one test mini-batch and visualise its attention.
batch = next(iter(test_dl))

# Send the data to the GPU if one is available
inputs = batch.Text[0].to(device)  # token ids
labels = batch.Label.to(device)  # labels

# Build the mask: the token id of '<pad>' is 1
input_pad = 1
input_mask = (inputs != input_pad)

# Inference only: eval mode, and no autograd graph — the attention tensors are only
# displayed, so recording a graph would just waste memory.
# (The `normlized_*` names keep their original spelling because later cells use them.)
net_trained.eval()
with torch.no_grad():
    outputs, normlized_weights_1, normlized_weights_2 = net_trained(
        inputs, input_mask)
_, preds = torch.max(outputs, 1)  # predicted class per example

index = 3  # example within the mini-batch to display
html_output = mk_html(index, batch, preds, normlized_weights_1,
                      normlized_weights_2, TEXT)  # build the HTML
HTML(html_output)  # render it as HTML



判定に効きそうな単語にうまくAttentionがかかっている。

In [None]:
# Visualise another example from the same mini-batch
index = 9
html_output = mk_html(
    index=index,
    batch=batch,
    preds=preds,
    normalized_weights_1=normlized_weights_1,
    normalized_weights_2=normlized_weights_2,
    TEXT=TEXT,
)
HTML(html_output)

この例は、そもそも正解ラベル（Positive）自体が妥当なのか疑問が残る…？

In [None]:
# (intentionally empty cell)