In [14]:
MAX_LEN = 256 # 128
BATCH_SIZE = 16
NUM_EPOCHS = 10
LEARNING_RATE = 2e-5

In [15]:
import numpy as np
import pandas as pd
from sklearn.model_selection import train_test_split

import transformers
import torch
from torch.utils.data import Dataset, DataLoader
from transformers import BertModel
from torch import optim
from torch import cuda
import time
from matplotlib import pyplot as plt

from transformers import BertJapaneseTokenizer
from tqdm import tqdm_notebook as tqdm

# test
path = './BERT-base_mecab-ipadic-bpe-32k'
tokenizer = BertJapaneseTokenizer.from_pretrained(path, word_tokenizer_type='mecab')
print(tokenizer.encode_plus(text='\nいつぞや、日向地方を行乞した時の出来事である。', add_special_tokens=False))

{'input_ids': [4794, 29189, 28528, 6, 11627, 794, 11, 77, 18765, 15, 10, 72, 5, 6157, 12, 31, 8], 'token_type_ids': [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], 'attention_mask': [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1]}


In [68]:
files = pd.read_csv('./train_author_novel.csv')[['author_id', 'filename']]
classes = list(files['author_id'].unique())

In [17]:
train, valid = train_test_split(files, test_size=0.2, shuffle=True, random_state=42, stratify=files['author_id'])
train.reset_index(drop=True, inplace=True)
valid.reset_index(drop=True, inplace=True)

train.head()

Unnamed: 0,author_id,filename
0,153,49679.txt
1,281,3597.txt
2,146,48246.txt
3,1670,54627.txt
4,305,50390.txt


In [33]:
def read_files(files, phase='train'):
    texts = []
    for filename in files:
        with open(f'./pp_{phase}/{filename}') as f:
            text = f.read()
            text = text[:MAX_LEN*3] # bertのtokenize的に２〜3倍程度あれば大丈夫そう
        texts.append(text)
    return texts

train_texts = read_files(list(train['filename']))
train_auther_ids = list(train['author_id'])
valid_texts = read_files(list(valid['filename']))
valid_auther_ids = list(valid['author_id'])

In [38]:
class AozoraDataset(Dataset):
    def __init__(self, texts, auther_ids, classes, tokenizer, max_len):
        self.texts = texts
        self.auther_ids = auther_ids
        self.classes = classes
        self.tokenizer = tokenizer
        self.max_len = max_len
    
    def __len__(self):
        return len(self.auther_ids)
    
    def __getitem__(self, index):
        inputs = self.tokenizer.encode_plus(
          self.texts[index],
          add_special_tokens=True,
          max_length=self.max_len,
          pad_to_max_length=True
        )

        labels = len(classes) * [0]
        if self.auther_ids[index] is not None:
            labels[self.classes.index(self.auther_ids[index])] = 1

        return {
          'ids': torch.LongTensor(inputs['input_ids']),
          'mask': torch.LongTensor(inputs['attention_mask']),
          'labels': torch.Tensor(labels)
        }

In [20]:
train_dataset = AozoraDataset(train_texts, train_auther_ids, classes, tokenizer, MAX_LEN)
valid_dataset = AozoraDataset(valid_texts, valid_auther_ids, classes, tokenizer, MAX_LEN)
train_dataset[0]

Truncation was not explicitly activated but `max_length` is provided a specific value, please use `truncation=True` to explicitly truncate examples to max length. Defaulting to 'longest_first' truncation strategy. If you encode pairs of sequences (GLUE-style) with the tokenizer you can select this strategy more precisely by providing a specific strategy to `truncation`.


{'ids': tensor([    2,  2090,   114, 29333,   120,     7,  7330,    16,     6,  1325,
             9,  5206, 28593,     6,  2894,   284,  6566, 28489,    16,  1276,
            10,     8,     1,  2090,     7,  7330,    16,  1726,  1582,    12,
             9,    80,     8,  1325,     5,  3246,     1, 12622,    16,  1276,
            10,   732,     9,     6,     1,   319,     7,   830,    16, 26075,
         28449, 24108, 31785,     5, 24108,     5,  2867,   465, 30794,  1346,
         18147,  3488,    49, 28489,     7,     6,  1326,   114, 29333,   120,
             1, 11367,     5,  1863,     7,  1040,     5,    36,  9467,    38,
             5,    32,    52,    32,     7,  4192,    26,    20,    16,  1276,
            10,   732,  3337,    12,    31,     8, 13818, 28545,  4914,  4914,
             5,  1040,    11,  3649,    34,  1863,     7, 16166,     5,   493,
         30642,    11,   474, 28647,  2078,    16,   212,    16,    28,     6,
           218,    14,     6,  1037,    72,  

In [21]:
train_dataloader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True)
valid_dataloader = DataLoader(valid_dataset, batch_size=BATCH_SIZE, shuffle=False)

# Train

In [22]:
class AozoraClassficationModel(torch.nn.Module):
    def __init__(self, path, output_size):
        super().__init__()
        self.bert = BertModel.from_pretrained(path)
        self.ln = torch.nn.Linear(768, output_size)

    def forward(self, ids, mask):
        _, x = self.bert(ids, attention_mask=mask)
        x = self.ln(x)
        return x

model = AozoraClassficationModel(path, len(classes))

In [23]:
def evaluate(model, loader, device='cpu', criterion=None):
    model.eval()
    loss = 0.0
    total = 0
    correct = 0
    with torch.no_grad():
        print('*'*20 + 'start valid' + '*'*20)
        for data in tqdm(loader):
            ids = data['ids'].to(device)
            mask = data['mask'].to(device)
            labels = data['labels'].to(device)
            outputs = model(ids, mask)
            if criterion != None:
                loss += criterion(outputs, labels).item()
            pred = torch.argmax(outputs, dim=-1).cpu().numpy()
            labels = torch.argmax(labels, dim=-1).cpu().numpy()
            total += len(labels)
            correct += (pred == labels).sum().item()

    return loss / len(loader), correct / total

def train_model(train_dataloader, valid_dataloader, model, criterion, optimizer, device='cpu'):
    model.to(device)
    train_log = []
    valid_log = []
    for epoch in range(NUM_EPOCHS):
        start = time.time()
        model.train()
        print('*'*20 + 'start train' + '*'*20)
        for data in tqdm(train_dataloader):
            ids = data['ids'].to(device)
            mask = data['mask'].to(device)
            labels = data['labels'].to(device)
            
            optimizer.zero_grad()
            outputs = model(ids, mask)
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()

        train_loss, train_acc = evaluate(model, train_dataloader, device, criterion)
        valid_loss, valid_acc = evaluate(model, valid_dataloader, device, criterion)
        train_log.append([train_loss, train_acc])
        valid_log.append([valid_loss, valid_acc])

        torch.save({'epoch': epoch, 'model_state_dict': model.state_dict(),
                    'optimizer_state_dict': optimizer.state_dict()}, f'./models/checkpoint{epoch + 1}.pt')
        end = time.time()
        print(f'epoch: {epoch + 1}, loss_train: {train_loss:.4f}, accuracy_train: {train_acc:.4f}, loss_valid: {valid_loss:.4f}, accuracy_valid: {valid_acc:.4f}, {(end - start):.4f}sec') 

    return {'train': train_log, 'valid': valid_log}

In [24]:
criterion = torch.nn.BCEWithLogitsLoss()
optimizer = torch.optim.AdamW(params=model.parameters(), lr=LEARNING_RATE)
log = train_model(train_dataloader, valid_dataloader, model, criterion, optimizer)

********************start train********************


Please use `tqdm.notebook.tqdm` instead of `tqdm.tqdm_notebook`
  for data in tqdm(train_dataloader):


HBox(children=(FloatProgress(value=0.0, max=80.0), HTML(value='')))


********************start valid********************


Please use `tqdm.notebook.tqdm` instead of `tqdm.tqdm_notebook`
  for data in tqdm(loader):


HBox(children=(FloatProgress(value=0.0, max=80.0), HTML(value='')))


********************start valid********************


HBox(children=(FloatProgress(value=0.0, max=20.0), HTML(value='')))


epoch: 1, loss_train: 0.2095, accuracy_train: 0.2408, loss_valid: 0.2098, accuracy_valid: 0.2327, 3394.1501sec
********************start train********************


HBox(children=(FloatProgress(value=0.0, max=80.0), HTML(value='')))


********************start valid********************


HBox(children=(FloatProgress(value=0.0, max=80.0), HTML(value='')))


********************start valid********************


HBox(children=(FloatProgress(value=0.0, max=20.0), HTML(value='')))


epoch: 2, loss_train: 0.1643, accuracy_train: 0.5594, loss_valid: 0.1666, accuracy_valid: 0.5189, 3478.4826sec
********************start train********************


HBox(children=(FloatProgress(value=0.0, max=80.0), HTML(value='')))


********************start valid********************


HBox(children=(FloatProgress(value=0.0, max=80.0), HTML(value='')))


********************start valid********************


HBox(children=(FloatProgress(value=0.0, max=20.0), HTML(value='')))


epoch: 3, loss_train: 0.1277, accuracy_train: 0.6703, loss_valid: 0.1361, accuracy_valid: 0.6069, 3548.2596sec
********************start train********************


HBox(children=(FloatProgress(value=0.0, max=80.0), HTML(value='')))


********************start valid********************


HBox(children=(FloatProgress(value=0.0, max=80.0), HTML(value='')))


********************start valid********************


HBox(children=(FloatProgress(value=0.0, max=20.0), HTML(value='')))


epoch: 4, loss_train: 0.0948, accuracy_train: 0.8529, loss_valid: 0.1065, accuracy_valid: 0.7893, 3448.9762sec
********************start train********************


HBox(children=(FloatProgress(value=0.0, max=80.0), HTML(value='')))


********************start valid********************


HBox(children=(FloatProgress(value=0.0, max=80.0), HTML(value='')))


********************start valid********************


HBox(children=(FloatProgress(value=0.0, max=20.0), HTML(value='')))


epoch: 5, loss_train: 0.0758, accuracy_train: 0.9064, loss_valid: 0.0952, accuracy_valid: 0.7767, 3365.9806sec
********************start train********************


HBox(children=(FloatProgress(value=0.0, max=80.0), HTML(value='')))


********************start valid********************


HBox(children=(FloatProgress(value=0.0, max=80.0), HTML(value='')))


********************start valid********************


HBox(children=(FloatProgress(value=0.0, max=20.0), HTML(value='')))


epoch: 6, loss_train: 0.0542, accuracy_train: 0.9827, loss_valid: 0.0754, accuracy_valid: 0.8553, 3241.3957sec
********************start train********************


HBox(children=(FloatProgress(value=0.0, max=80.0), HTML(value='')))


********************start valid********************


HBox(children=(FloatProgress(value=0.0, max=80.0), HTML(value='')))


********************start valid********************


HBox(children=(FloatProgress(value=0.0, max=20.0), HTML(value='')))


epoch: 7, loss_train: 0.0413, accuracy_train: 0.9921, loss_valid: 0.0659, accuracy_valid: 0.8679, 3414.5583sec
********************start train********************


HBox(children=(FloatProgress(value=0.0, max=80.0), HTML(value='')))


********************start valid********************


HBox(children=(FloatProgress(value=0.0, max=80.0), HTML(value='')))


********************start valid********************


HBox(children=(FloatProgress(value=0.0, max=20.0), HTML(value='')))


epoch: 8, loss_train: 0.0323, accuracy_train: 0.9992, loss_valid: 0.0582, accuracy_valid: 0.8868, 3305.5577sec
********************start train********************


HBox(children=(FloatProgress(value=0.0, max=80.0), HTML(value='')))


********************start valid********************


HBox(children=(FloatProgress(value=0.0, max=80.0), HTML(value='')))


********************start valid********************


HBox(children=(FloatProgress(value=0.0, max=20.0), HTML(value='')))


epoch: 9, loss_train: 0.0264, accuracy_train: 1.0000, loss_valid: 0.0539, accuracy_valid: 0.8899, 3341.9573sec
********************start train********************


HBox(children=(FloatProgress(value=0.0, max=80.0), HTML(value='')))


********************start valid********************


HBox(children=(FloatProgress(value=0.0, max=80.0), HTML(value='')))


********************start valid********************


HBox(children=(FloatProgress(value=0.0, max=20.0), HTML(value='')))


epoch: 10, loss_train: 0.0222, accuracy_train: 1.0000, loss_valid: 0.0502, accuracy_valid: 0.8994, 3309.7975sec


# Inference

In [47]:
model = AozoraClassficationModel(path, len(classes))
model.load_state_dict(torch.load(f'./models/checkpoint10.pt')['model_state_dict'])

<All keys matched successfully>

In [40]:
test = pd.read_csv('./test_author_novel.csv')
test['author_id'] = None

test_texts = read_files(list(test['filename']), 'test')
test_auther_ids = list(test['author_id'])

test_dataset = AozoraDataset(test_texts, test_auther_ids, classes, tokenizer, MAX_LEN)
test_dataloader = DataLoader(test_dataset, batch_size=BATCH_SIZE, shuffle=False)

In [60]:
def inference(test_dataloader, model, device='cpu'):
    preds = []
    model.to(device)
    model.eval()
    for data in tqdm(test_dataloader):
        ids = data['ids'].to(device)
        mask = data['mask'].to(device)

        outputs = model(ids, mask)
        pred = torch.argmax(outputs, dim=-1).cpu().numpy()
        preds += list(pred)
    return preds

In [61]:
preds = inference(test_dataloader, model)

Please use `tqdm.notebook.tqdm` instead of `tqdm.tqdm_notebook`
  for data in tqdm(test_dataloader):


HBox(children=(FloatProgress(value=0.0, max=34.0), HTML(value='')))




In [72]:
test['author_id'] = preds
test.head()

Unnamed: 0,novel_id,filename,author_id
0,2198,2198.txt,0
1,2213,2213.txt,0
2,2627,2627.txt,0
3,2618,2618.txt,12
4,2621,2621.txt,12


In [73]:
test.to_csv('./test_author_novel.csv', index=None)