In [64]:
from pprint import pprint
from torch.utils.data import Dataset, DataLoader
from transformers import BertTokenizer, BertModel
from typing import List

import pandas as pd
import torch

In [65]:
class TextDataset(Dataset):
    def __init__(self, csv_file, tokenizer, max_len=512):
        self.dataframe = pd.read_csv(csv_file)
        self.tokenizer = tokenizer
        self.max_len = max_len

    def __len__(self):
        return len(self.dataframe)

    def __getitem__(self, idx):
        text = self.dataframe.iloc[idx, 0]
        label = 1 if self.dataframe.iloc[idx, 1] == '米津玄師' else 0
        inputs = self.tokenizer.encode_plus(
            text,
            add_special_tokens=True,
            max_length=self.max_len,
            pad_to_max_length=True,
            return_attention_mask=True,
            return_tensors='pt'
        )
        input_ids = inputs['input_ids'].flatten()
        attention_mask = inputs['attention_mask'].flatten()
        return input_ids, attention_mask, label

In [66]:
tokenizer = BertTokenizer.from_pretrained('cl-tohoku/bert-base-japanese-whole-word-masking')
dataset = TextDataset('data/train.csv', tokenizer)
loader = DataLoader(dataset, batch_size=32, shuffle=True)

The tokenizer class you load from this checkpoint is not the same type as the class this function is called from. It may result in unexpected tokenization. 
The tokenizer class you load from this checkpoint is 'BertJapaneseTokenizer'. 
The class this function is called from is 'BertTokenizer'.


In [67]:
class ClassifierModel(torch.nn.Module):
    def __init__(self):
        super(ClassifierModel, self).__init__()
        self.bert = BertModel.from_pretrained('cl-tohoku/bert-base-japanese-whole-word-masking')
        self.linear = torch.nn.Linear(768, 2)  # BERTの隠れ層の次元数と出力クラス数

    def forward(self, input_ids, attention_mask):
        with torch.no_grad():
            outputs = self.bert(input_ids=input_ids, attention_mask=attention_mask)
        last_hidden_state = outputs[0]
        pooled_output = last_hidden_state[:, 0]
        return self.linear(pooled_output)

In [68]:
is_cuda = torch.cuda.is_available()
device = torch.device('cuda' if is_cuda else 'cpu')
pprint(f'device: {device}')

'device: cuda'


In [69]:
def to_device(device, *args):
    return [x.to(device) for x in args]

In [70]:
# Train
epoch = 3 # for debug
net: torch.nn.Module = ClassifierModel().to(device)
optimizer = torch.optim.SGD(net.parameters(), lr=0.01)

for e in range(epoch):
    for input_ids, attention_mask, labels in loader:
        input_ids, attention_mask, labels = to_device(device, input_ids, attention_mask, labels)
        optimizer.zero_grad()
        outputs = net(input_ids, attention_mask)
        loss = torch.nn.functional.cross_entropy(outputs, labels)
        loss.backward()
        optimizer.step()

Truncation was not explicitly activated but `max_length` is provided a specific value, please use `truncation=True` to explicitly truncate examples to max length. Defaulting to 'longest_first' truncation strategy. If you encode pairs of sequences (GLUE-style) with the tokenizer you can select this strategy more precisely by providing a specific strategy to `truncation`.


In [72]:
# Test
net.eval()

with torch.no_grad():
    dataset = TextDataset('data/test.csv', tokenizer)
    loader = DataLoader(dataset, batch_size=32, shuffle=False)
    for input_ids, attention_mask, labels in loader:
        input_ids, attention_mask, labels = to_device(device, input_ids, attention_mask, labels)
        outputs = net(input_ids, attention_mask)
        predicted = torch.argmax(outputs, dim=1)
        decoded = tokenizer.batch_decode(input_ids, skip_special_tokens=True)
        label_text = ['米津玄師' if l == 1 else 'ちいかわ' for l in labels]
        predicted_text = ['米津玄師' if l == 1 else 'ちいかわ' for l in predicted]
        zipped = zip(decoded, label_text, predicted_text)
        pprint("テキスト, 正解ラベル, 予測ラベル")
        for d, l, p in zipped:
            print(f'{d} {l} {p}')



'テキスト, 正解ラベル, 予測ラベル'
泣 いても 涙 がでないや ちいかわ 米津玄師
お 前 になんかやるもんか 米津玄師 米津玄師
守 りたいんだ みんなが 戻 ってくるまで ちいかわ 米津玄師
この 像 に 誓 ったんだ 強 くなると ちいかわ 米津玄師
どうしてどうしてどうして 米津玄師 米津玄師
難 解 なパズルみたい ちいかわ 米津玄師
そこから 見 ていてね 大 丈 夫 ありがとう 米津玄師 米津玄師
ヤンパパン ラララルルラ ちいかわ ちいかわ
ヒッピヒッピシェイク ダンディダンディドン 米津玄師 ちいかわ
るるらったったったった 米津玄師 米津玄師
