# BERTでクラス分類をしてみよう

## ライブラリのインストール・データのダウンロード

In [None]:
!pip install transformers["ja"] nlp

In [None]:
%cd "/content/drive/My Drive/Colab Notebooks/introduction_to_pytorch_and_bert"

In [None]:
# -p makes re-running this cell safe when the directories already exist.
!mkdir -p data checkpoints
!wget https://github.com/tealgreen0503/introduction_to_pytorch_and_bert/raw/main/data/amazon_reviews_multilingual_JP_v1_00_20000_binary.tsv.gz -P data/
!wget https://github.com/tealgreen0503/introduction_to_pytorch_and_bert/raw/main/data/amazon_reviews_multilingual_JP_v1_00_10000_binary.tsv.gz -P data/
# -f overwrites a .tsv left by an earlier run instead of refusing to decompress.
!gunzip -d -f ./data/amazon_reviews_multilingual_JP_v1_00_20000_binary.tsv.gz
!gunzip -d -f ./data/amazon_reviews_multilingual_JP_v1_00_10000_binary.tsv.gz

## ライブラリのインポート・シードの固定・定数の設定

In [None]:
import os
import random
import collections

from bs4 import BeautifulSoup
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import seaborn as sns
from sklearn.metrics import f1_score, classification_report, confusion_matrix
from sklearn.model_selection import train_test_split
from tqdm.notebook import tqdm
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim

from transformers import AutoTokenizer, AutoModel, AdamW
import nlp

In [None]:
SEED = 42


def seed_everything(seed):
    """Seed every random number generator this notebook relies on.

    Covers Python's ``random`` module, the ``PYTHONHASHSEED`` environment
    variable, NumPy and PyTorch. When CUDA is available the CUDA generators
    are seeded as well and cuDNN is switched to deterministic mode.

    Args:
        seed: Integer seed applied to every generator.
    """
    os.environ['PYTHONHASHSEED'] = str(seed)
    for seed_fn in (random.seed, np.random.seed, torch.manual_seed):
        seed_fn(seed)

    if not torch.cuda.is_available():
        return

    torch.cuda.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    # Deterministic cuDNN kernels, with benchmark mode turned off.
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False


seed_everything(SEED)

In [None]:
# Show which GPU the runtime provides; nothing is printed on CPU-only runtimes.
if torch.cuda.is_available():
    current_device = torch.cuda.current_device()
    device_name = torch.cuda.get_device_name(current_device)
    print('Device:', device_name)

In [None]:
# Run on the GPU when one is available, otherwise fall back to the CPU.
DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
# TSV files placed under ./data by the download cell.
TRAIN_PATH = './data/amazon_reviews_multilingual_JP_v1_00_20000_binary.tsv'
TEST_PATH = './data/amazon_reviews_multilingual_JP_v1_00_10000_binary.tsv'
# Directory prefix for saved model weights (created by the download cell).
CKPT_DIR = './checkpoints/'
# Pretrained checkpoint identifier handed to AutoTokenizer / AutoModel.
MODEL_NAME = 'cl-tohoku/bert-base-japanese-whole-word-masking'
TRAIN_BATCH_SIZE = 32
# Also used as the batch size of the test dataloader.
VALID_BATCH_SIZE = 128
# Number of output classes of the classifier head.
NUM_CLASSES = 2
# Number of training epochs passed to Trainer.
NUM_EPOCH = 5

## データの確認

In [None]:
# Load the raw training TSV and preview its first rows.
train_df = pd.read_csv(TRAIN_PATH, sep='\t')
train_df.head()

In [None]:
# Print column dtypes and non-null counts of the raw training data.
train_df.info()

In [None]:
# Count how many rows carry each value of the label column.
train_df['binary_star_rating'].value_counts()

In [None]:
def load_data(data_path, target='binary_star_rating'):
    """Load a review TSV and return cleaned review text with its labels.

    Args:
        data_path: Path (or file-like object) of a tab-separated file that
            contains a ``review_body`` column and the ``target`` column.
        target: Name of the label column; it is renamed to ``labels`` in the
            returned frame.

    Returns:
        A new DataFrame with two columns: ``review_body`` (HTML markup
        removed, missing values replaced by ``''``) and ``labels``.
    """
    df = pd.read_csv(data_path, sep='\t')
    # Copy the column subset so the assignment below writes to an independent
    # frame rather than triggering pandas' SettingWithCopyWarning.
    df = df[['review_body', target]].copy()

    def clean_html(html, strip=True):
        """Return the visible text of an HTML fragment ('' for missing values)."""
        # Missing reviews are read as NaN, which BeautifulSoup cannot parse.
        if pd.isna(html):
            return ''
        soup = BeautifulSoup(str(html), 'html.parser')
        return soup.get_text(strip=strip)

    df['review_body'] = df['review_body'].map(clean_html)
    return df.rename(columns={target: 'labels'})

In [None]:
# Check the cleaned two-column frame produced by load_data.
df = load_data(TRAIN_PATH)
df.head()

## 関数・モデルの定義

In [None]:
def make_dataset(df, tokenizer, device):
    """Turn a review DataFrame into a tokenized dataset that yields tensors.

    Each ``review_body`` is tokenized, then padded or truncated to 128
    tokens. The resulting dataset is formatted to return torch tensors for
    the model inputs and the labels on ``device``.

    Args:
        df: DataFrame with ``review_body`` and ``labels`` columns.
        tokenizer: Callable tokenizer applied to each review text.
        device: Device on which the returned tensors are placed.

    Returns:
        The tokenized ``nlp.Dataset``.
    """
    tensor_columns = ['input_ids', 'token_type_ids', 'attention_mask', 'labels']

    def encode(example):
        return tokenizer(
            example["review_body"],
            padding="max_length",
            truncation=True,
            max_length=128,
        )

    tokenized = nlp.Dataset.from_pandas(df).map(encode)
    tokenized.set_format(type='torch', columns=tensor_columns, device=device)
    return tokenized

In [None]:
class Classifier(nn.Module):
    """Pretrained encoder followed by dropout and a linear classification head.

    Args:
        model_name: Identifier passed to ``AutoModel.from_pretrained``.
        num_classes: Number of output logits.
    """

    def __init__(self, model_name, num_classes=2):
        super().__init__()

        self.bert = AutoModel.from_pretrained(model_name)
        self.dropout = nn.Dropout(0.1)
        # Take the head's input width from the encoder config instead of
        # hard-coding 768, so encoders with another hidden size also work.
        self.linear = nn.Linear(self.bert.config.hidden_size, num_classes)
        # Initialise the new head with the encoder's own initializer range.
        nn.init.normal_(self.linear.weight, std=self.bert.config.initializer_range)
        nn.init.zeros_(self.linear.bias)

    def forward(self, **inputs):
        """Return class logits for a batch.

        Args:
            **inputs: Keyword arguments forwarded unchanged to the encoder.

        Returns:
            Tensor of logits with one row per example and ``num_classes``
            columns.
        """
        outputs = self.bert(**inputs)
        # Use the hidden state at the first token position as the sequence
        # representation (the [CLS] position for BERT-style inputs).
        first_token_state = outputs.last_hidden_state[:, 0, :]
        return self.linear(self.dropout(first_token_state))

In [None]:
class Trainer:
    """Train a classifier and checkpoint the best epoch by a validation metric.

    Args:
        model: Module called as ``model(**inputs)`` that returns class logits.
        train_dataloader: Loader yielding dict batches that contain a
            ``'labels'`` entry plus the model inputs.
        valid_dataloader: Loader of the same form used for validation.
        criterion: Loss called as ``criterion(logits, labels)``.
        optimizer: Optimizer updating the model parameters.
        scheduler: Optional learning-rate scheduler; it is stepped after
            every training batch. When omitted, a constant-rate scheduler
            is used.
        num_epoch: Number of train/validate rounds run by ``train``.
        ckpt_name: Checkpoint path without extension; weights are written to
            ``f"{ckpt_name}.pth"``.
    """

    def __init__(self,
                 model,
                 train_dataloader,
                 valid_dataloader,
                 criterion,
                 optimizer,
                 scheduler=None,
                 num_epoch=10,
                 ckpt_name='./bert'):
        self.model = model
        self.train_dataloader = train_dataloader
        self.valid_dataloader = valid_dataloader
        self.criterion = criterion
        self.optimizer = optimizer
        if scheduler is not None:
            self.scheduler = scheduler
        else:
            # gamma=1.0 keeps the learning rate constant, so the training
            # loop can call scheduler.step() unconditionally.
            self.scheduler = optim.lr_scheduler.StepLR(self.optimizer,
                                                       step_size=1e+10,
                                                       gamma=1.0)
        self.num_epoch = num_epoch
        self.ckpt_name = ckpt_name

    def _run_epoch(self, dataloader, epoch, training):
        """Run one pass over ``dataloader`` and return (loss, accuracy, f1).

        Args:
            dataloader: Loader to iterate over.
            epoch: Zero-based epoch index, used only for the progress label.
            training: When True, back-propagate and step the optimizer and
                scheduler after every batch.

        Raises:
            ValueError: If ``dataloader`` has no batches.
        """
        num_batches = len(dataloader)
        if num_batches == 0:
            raise ValueError('dataloader is empty; cannot compute metrics')

        tag = 'Train' if training else 'Valid'
        total_loss = 0
        total_corrects = 0
        all_labels = np.array([])
        all_preds = np.array([])
        f1 = 0.0

        progress = tqdm(dataloader, total=num_batches)
        for i, batch in enumerate(progress):
            progress.set_description(f"<{tag}> Epoch{epoch+1}")

            # Everything left in the batch after removing the labels is fed
            # to the model as keyword arguments.
            labels = batch.pop('labels')
            inputs = batch

            if training:
                self.optimizer.zero_grad()

            output = self.model(**inputs)
            loss = self.criterion(output, labels)
            preds = torch.argmax(output, dim=1)

            if training:
                loss.backward()
                self.optimizer.step()
                self.scheduler.step()

            total_loss += loss.item()
            total_corrects += torch.sum(preds == labels).item()
            all_labels = np.r_[all_labels, labels.to('cpu').detach().numpy()]
            all_preds = np.r_[all_preds, preds.to('cpu').detach().numpy()]
            # F1 over everything seen so far in this epoch, shown live.
            f1 = f1_score(all_labels, all_preds)

            progress.set_postfix(loss=total_loss/(i+1), f1=f1)

        mean_loss = total_loss / num_batches
        accuracy = total_corrects / len(dataloader.dataset)
        return mean_loss, accuracy, f1

    def _train_step(self, epoch):
        """Train for one epoch; return (train_loss, train_acc, train_f1)."""
        self.model.train()
        return self._run_epoch(self.train_dataloader, epoch, training=True)

    def _eval_step(self, epoch):
        """Validate for one epoch; return (valid_loss, valid_acc, valid_f1)."""
        self.model.eval()
        with torch.no_grad():
            return self._run_epoch(self.valid_dataloader, epoch, training=False)

    def train(self, metric='f1'):
        """Run all epochs, saving the weights whenever ``metric`` improves.

        Args:
            metric: ``'f1'`` or ``'acc'`` (higher is better) or ``'loss'``
                (lower is better), measured on the validation loader.

        Returns:
            The best value of ``metric`` observed over all epochs.

        Raises:
            RuntimeError: If ``metric`` is not one of the supported names.
        """
        if metric not in ('f1', 'acc', 'loss'):
            raise RuntimeError(
                f"Unsupported metric {metric!r}; expected 'f1', 'acc' or 'loss'.")

        minimize = metric == 'loss'
        best_metric = np.inf if minimize else 0

        for epoch in range(self.num_epoch):
            self._train_step(epoch)
            valid_loss, valid_acc, valid_f1 = self._eval_step(epoch)
            print(f'Loss: {valid_loss}  Acc: {valid_acc}  f1: {valid_f1}', end='  ')

            current = {'f1': valid_f1, 'acc': valid_acc, 'loss': valid_loss}[metric]
            improved = current < best_metric if minimize else current > best_metric
            if improved:
                best_metric = current
                print('model saving!', end='')
                torch.save(self.model.state_dict(), f"{self.ckpt_name}.pth")
            print('\n\n')

        return best_metric

## 学習

In [None]:
# Hold out 20% of the training data for validation; stratifying on the labels
# keeps the class ratio the same in both splits.
df = load_data(TRAIN_PATH)
train_df, valid_df = train_test_split(df, test_size=0.2, random_state=SEED, stratify=df['labels'])

In [None]:
# Tokenizer matching the pretrained model, shared by all data splits.
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)

train_dataset, valid_dataset = (
    make_dataset(split_df, tokenizer, DEVICE) for split_df in (train_df, valid_df)
)

# Only the training batches are shuffled; validation keeps its order.
train_dataloader = torch.utils.data.DataLoader(
    dataset=train_dataset,
    batch_size=TRAIN_BATCH_SIZE,
    shuffle=True,
)
valid_dataloader = torch.utils.data.DataLoader(
    dataset=valid_dataset,
    batch_size=VALID_BATCH_SIZE,
    shuffle=False,
)

In [None]:
# Model, loss and optimizer for fine-tuning.
model = Classifier(MODEL_NAME, num_classes=NUM_CLASSES).to(DEVICE)
criterion = nn.CrossEntropyLoss()
optimizer = AdamW(model.parameters(), lr=2e-5)

# '/' in the model name would be read as a directory separator, so it is
# replaced before building the checkpoint path.
ckpt_name = f"{CKPT_DIR}{MODEL_NAME.replace('/', '_')}"

trainer = Trainer(
    model=model,
    train_dataloader=train_dataloader,
    valid_dataloader=valid_dataloader,
    criterion=criterion,
    optimizer=optimizer,
    num_epoch=NUM_EPOCH,
    ckpt_name=ckpt_name,
)

In [None]:
# Train with the default metric ('f1'); the return value is the best validation F1.
valid_f1 = trainer.train()

## テスト

In [None]:
# Rebuild the architecture and restore the best weights saved during training.
model = Classifier(MODEL_NAME, num_classes=NUM_CLASSES)
# map_location lets a checkpoint saved from a GPU run load on a CPU-only runtime.
state_dict = torch.load(ckpt_name + '.pth', map_location=DEVICE)
model.load_state_dict(state_dict)
model.to(DEVICE)
model.eval()

In [None]:
# Test split, tokenized the same way as the training data and read in order.
test_df = load_data(TEST_PATH)
test_dataset = make_dataset(test_df, tokenizer, DEVICE)
test_dataloader = torch.utils.data.DataLoader(
    dataset=test_dataset,
    batch_size=VALID_BATCH_SIZE,
    shuffle=False,
)

In [None]:
# Predict a class id for every test example, batch by batch, without gradients.
final_output = np.array([])

with torch.no_grad():
    progress = tqdm(test_dataloader, total=len(test_dataloader))

    for batch in progress:
        progress.set_description("<Test>")

        # Labels are not fed to the model; scoring happens in a later cell.
        _ = batch.pop('labels')
        inputs = batch

        probabilities = torch.softmax(model(**inputs), dim=1)
        batch_preds = np.argmax(probabilities.to('cpu').detach().numpy(), axis=1)

        final_output = np.r_[final_output, batch_preds]

In [None]:
# Accuracy is the share of predictions equal to the label; F1 comes from sklearn.
is_correct = test_df['labels'] == final_output
test_acc = is_correct.sum() / len(is_correct)
test_f1 = f1_score(test_df['labels'], final_output)

print(f'Test Acc: {test_acc}  Test f1: {test_f1}')

In [None]:
def print_classification_report(all_labels, all_preds):
    """Print per-class metrics and plot a row-normalized confusion matrix.

    Each heatmap cell (i, j) shows the fraction of samples whose true class
    is ``i`` that were predicted as class ``j``.

    Args:
        all_labels: Ground-truth class ids.
        all_preds: Predicted class ids, aligned with ``all_labels``.
    """
    print(classification_report(all_labels, all_preds))

    # Pin the label order so the matrix is always NUM_CLASSES x NUM_CLASSES,
    # even if a class never occurs.
    cm = confusion_matrix(all_labels, all_preds, labels=list(range(NUM_CLASSES)))
    # Rows index the true class, so each row is divided by its own total.
    # keepdims gives a column vector so the division broadcasts across rows;
    # dividing by a flat per-class list would scale the columns instead.
    row_totals = cm.sum(axis=1, keepdims=True)
    # A class with no samples keeps an all-zero row instead of dividing by 0.
    cm = cm / np.maximum(row_totals, 1)
    sns.heatmap(cm, cmap="Reds", annot=True)
    plt.show()

In [None]:
# Report per-class metrics and the confusion matrix for the test predictions.
print_classification_report(test_df['labels'], final_output)