# 6. 学習・推論・可視化

## 6.1. 準備

In [1]:
import sys
# Add the project directory to the module search path so that the `utils` package can be imported from it.
sys.path.append("/content/drive/My Drive/Transformer")

In [25]:
import torch
import torch.nn as nn
import torch.optim as optim
import os

import utils
from utils.dataloader import get_IMDb_DataLoaders_and_TEXT
from utils.transformer import TransformerClassification

In [21]:
# For reloading: re-import modules without restarting the kernel
# NOTE(review): reloading torch.nn looks unnecessary for picking up edits to utils.transformer — confirm it is needed.
import importlib
importlib.reload(torch.nn)
# Last expression of the cell, so the reloaded module object is echoed as the cell output.
importlib.reload(utils.transformer)

<module 'utils.transformer' from '/content/drive/My Drive/Transformer/utils/transformer.py'>

In [119]:
# Values handed to the TransformerClassification constructor
d_model, d_hidden, d_out = 300, 1024, 2
max_seq_len = 256
drop_ratio = 0.1

# Values handed to the data loader factory and to the optimizer
batch_size = 64
lr = 2e-5

# Use the first CUDA device when one is available, otherwise the CPU
if torch.cuda.is_available():
    device = torch.device("cuda:0")
else:
    device = torch.device("cpu")
print("使用デバイス: {}".format(device))

使用デバイス: cuda:0


### 6.1.1. DataLoader

In [120]:
# Fetch the data loaders (train / validation / test) together with TEXT
train_dl, val_dl, test_dl, TEXT = get_IMDb_DataLoaders_and_TEXT(
    max_seq_len,
    batch_size,
)
# Loaders keyed by phase name, as looked up by the training loop
train_data_dict = dict(train=train_dl, val=val_dl)

### 6.1.2. Transformer

In [121]:
# Model definition: build the classifier, move it to the target device and put it
# in training mode (nn.Module.to() and nn.Module.train() both return the module itself)
net = TransformerClassification(
    TEXT.vocab.vectors,
    d_model,
    max_seq_len,
    d_hidden,
    d_out,
    drop_ratio,
    device,
).to(device).train()

# Weight initialisation
def weight_init(m):
    """Re-initialise a module in place if its class name contains "Linear".

    Written to be passed to ``Module.apply``. A matching module gets its
    ``weight`` filled by ``nn.init.kaiming_normal_`` and, if it has a bias,
    the bias set to 0.0. Every other module is left untouched.

    Args:
        m: The module to inspect.
    """
    if "Linear" not in m.__class__.__name__:
        return

    nn.init.kaiming_normal_(m.weight)
    bias = m.bias
    if bias is not None:
        nn.init.constant_(bias, 0.0)

# Run the initialiser over every submodule of both Transformer blocks
for trm_block in (net.trm1, net.trm2):
    trm_block.apply(weight_init)

print(net)

TransformerClassification(
  (emb): Embedder(
    (emb): Embedding(69959, 300)
  )
  (pe): PositionEncoder()
  (trm1): TransformerBlock(
    (norm1): LayerNorm((300,), eps=1e-05, elementwise_affine=True)
    (attn): Attention(
      (q_linear): Linear(in_features=300, out_features=300, bias=True)
      (k_linear): Linear(in_features=300, out_features=300, bias=True)
      (v_linear): Linear(in_features=300, out_features=300, bias=True)
      (out): Linear(in_features=300, out_features=300, bias=True)
    )
    (dropout1): Dropout(p=0.1, inplace=False)
    (norm2): LayerNorm((300,), eps=1e-05, elementwise_affine=True)
    (ff): FeedForward(
      (layers): Sequential(
        (0): Linear(in_features=300, out_features=1024, bias=True)
        (1): ReLU()
        (2): Dropout(p=0.1, inplace=False)
        (3): Linear(in_features=1024, out_features=300, bias=True)
      )
    )
    (dropout2): Dropout(p=0.1, inplace=False)
  )
  (trm2): TransformerBlock(
    (norm1): LayerNorm((300,), eps=

### 6.1.3. 損失関数と最適化方法

In [122]:
# Loss function: cross-entropy with PyTorch's default settings
loss_func = nn.CrossEntropyLoss()

# Optimisation method: Adam over all parameters of the network
optimizer = optim.Adam(params=net.parameters(), lr=lr)

## 6.2. 学習

In [123]:
def train(num_epochs):
    """Run training and validation passes over the notebook-level data loaders.

    In every epoch the "train" loader is processed with parameter updates and
    the "val" loader without them; the mean per-sample loss and the accuracy
    of each phase are printed. The function works on the notebook-level
    ``net``, ``optimizer``, ``loss_func``, ``train_data_dict``, ``TEXT`` and
    ``device``.

    Args:
        num_epochs: Number of epochs to run.
    """
    torch.backends.cudnn.benchmark = True
    # The padding id never changes, so look it up once rather than once per batch.
    input_pad = TEXT.vocab.stoi["<pad>"]

    for epoch in range(num_epochs):
        for phase in ["train", "val"]:
            is_train = phase == "train"
            if is_train:
                net.train()
            else:
                net.eval()

            loader = train_data_dict[phase]
            num_samples = len(loader.dataset)
            epoch_loss = 0.0    # sum of per-sample losses within the epoch
            epoch_corrects = 0  # number of correct predictions within the epoch

            for batch in loader:
                x = batch.Text[0].to(device)
                t = batch.Label.to(device)
                input_mask = (x != input_pad)   # False at padding positions

                if is_train:
                    optimizer.zero_grad()
                with torch.set_grad_enabled(is_train):  # gradients are only needed while training
                    y, _, _ = net(x, input_mask)
                    loss = loss_func(y, t)

                    if is_train:
                        loss.backward()
                        optimizer.step()

                    _, preds = torch.max(y, 1)  # predicted label = index of the larger score
                    # The loss is a batch mean, so weight it by the real number of samples in
                    # this batch: the final batch can be shorter than the configured batch size.
                    epoch_loss += loss.item() * t.size(0)
                    # Accumulate as a Python int so the accuracy below also works for an empty loop.
                    epoch_corrects += (preds == t).sum().item()

            epoch_loss = epoch_loss / num_samples
            epoch_acc = epoch_corrects / num_samples    # accuracy
            print("Epoch {}/{} | {:^5} | Loss: {:.4f} Acc: {:.4f}".format(epoch + 1, num_epochs, phase, epoch_loss, epoch_acc))


In [124]:
# Load previously trained weights when a checkpoint file exists
model_file = "/content/drive/My Drive/Transformer/model/transformer.pth"
if not os.path.exists(model_file):
    print("Model file not found")
else:
    # NOTE(review): torch.load unpickles the file — only load checkpoints from a trusted source.
    load_data = torch.load(model_file, map_location=device)
    net.load_state_dict(load_data["state_dict"])
    print("Load model")

Load model


In [None]:
num_iterate = 10
num_epochs = 10

for i in range(num_iterate):
    print("{0} Iterate: {1} {0}".format("-" * 10, i))
    train(num_epochs)

    # Save the model after every round of epochs
    save_data = dict(state_dict=net.state_dict())
    torch.save(save_data, model_file)
    print("Save model")

In [126]:
# Save the trained model: the state dict is stored under the "state_dict" key
save_data = dict(state_dict=net.state_dict())
torch.save(save_data, model_file)
print("Save model")

Save model


## 6.3. 推論

In [127]:
net.eval()

input_pad = TEXT.vocab.stoi["<pad>"]
epoch_corrects = 0
with torch.no_grad():   # inference only, so gradients are not tracked
    for batch in test_dl:
        x = batch.Text[0].to(device)
        t = batch.Label.to(device)

        input_mask = (x != input_pad)   # False at padding positions
        y, _, _ = net(x, input_mask)

        preds = torch.max(y, 1)[1]      # predicted label = index of the larger score
        epoch_corrects += (preds == t).sum()

epoch_acc = epoch_corrects.double() / len(test_dl.dataset)   # accuracy
print("テストデータでの正解率: {:.4f}".format(epoch_acc))

テストデータでの正解率: 0.8360


## 6.4. 可視化

In [128]:
# Feed one batch of test data through the model
batch = next(iter(test_dl))
x = batch.Text[0].to(device)
t = batch.Label.to(device)

input_pad = TEXT.vocab.stoi["<pad>"]
input_mask = (x != input_pad)       # False at padding positions

# Set evaluation mode here so the result does not depend on which cell ran last,
# and skip autograd bookkeeping because the outputs are only inspected, never trained on.
net.eval()
with torch.no_grad():
    y, attn_w1, attn_w2 = net(x, input_mask)

_, preds = torch.max(y, 1)  # predicted label = index of the larger score

In [129]:
def highlight(word, attn):
    """Return an HTML ``<span>`` that shows ``word`` on a red-tinted background.

    Args:
        word: Text to display. It is HTML-escaped, so tokens such as ``<pad>``
            are shown literally instead of being parsed as (invisible) tags.
        attn: Attention weight; 0 gives a white background and 1 gives pure
            red. Values outside [0, 1] are clamped so the colour code stays
            a valid ``#RRGGBB`` string.

    Returns:
        The HTML fragment as a string.
    """
    from html import escape  # function-scope import keeps this cell self-contained

    level = min(max(float(attn), 0.0), 1.0)
    fade = int(255 * (1 - level))   # green/blue channel: 255 = white, 0 = full red
    html_color = "#{:02X}{:02X}{:02X}".format(255, fade, fade)
    return "<span style=\"color: #000000; background-color: {}\"> {}</span>".format(html_color, escape(str(word)))

def mk_html(index, batch, preds, attn_w1, attn_w2, TEXT):
    """Build an HTML report for one sample of a batch.

    The report shows the correct and the predicted label, followed by the
    sample's words rendered once per attention map, each word coloured by
    ``highlight`` according to its weight.

    Args:
        index: Position of the sample within the batch.
        batch: Batch providing ``Text[0]`` (word ids) and ``Label``.
        preds: Predicted labels for the batch.
        attn_w1: Attention weights of the first TransformerBlock.
        attn_w2: Attention weights of the second TransformerBlock.
        TEXT: Object whose ``vocab.itos`` maps word ids back to words.

    Returns:
        The HTML string.
    """
    token_ids = batch.Text[0][index]    # the sentence as word ids
    label_names = ["Negative", "Positive"]
    true_name = label_names[batch.Label[index]]
    pred_name = label_names[preds[index]]

    def cls_row(attn_w):
        # Row 0 of the sample's attention map (the original note calls this the
        # attention seen from <cls>), scaled so that its largest weight is 1.
        row = attn_w[index, 0, :]
        return row / row.max()

    def render(weights):
        # One highlighted span per word, concatenated in sentence order.
        return "".join(
            highlight(TEXT.vocab.itos[token], weight)
            for token, weight in zip(token_ids, weights)
        )

    parts = [
        "<hr>",
        "正解ラベル: {}<br>予測ラベル: {}<br><br>".format(true_name, pred_name),
        "[TransformerBlock1段目のAttention]<br>",
        render(cls_row(attn_w1)),
        "<br><br>",
        "[TransformerBlock2段目のAttention]<br>",
        render(cls_row(attn_w2)),
        "<hr>",
    ]
    return "".join(parts)

In [133]:
from IPython.display import HTML

# Position within the batch of the sample to visualise
index = 32
html_out = mk_html(index, batch, preds, attn_w1, attn_w2, TEXT)
# Last expression of the cell, so the HTML object is rendered as the cell output
HTML(html_out)