### 参考https://www.kaggle.com/hawkeoni/pytorch-simple-bert

In [None]:
import os
from typing import Tuple, List

import pandas as pd
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader, RandomSampler
from torch.nn.utils.rnn import pad_sequence
from transformers import BertTokenizer, BertModel, AdamW, WarmupLinearSchedule, BertPreTrainedModel
from sklearn.model_selection import train_test_split
from sklearn.metrics import roc_auc_score
from tqdm import tqdm

In [2]:
path = "./toxic-comment-classification/"
device = torch.device('cpu')
if torch.cuda.is_available():
    device = torch.device('cuda:0')
tokenizer = BertTokenizer.from_pretrained('bert-base-cased')
# tokenizer = BertTokenizer.from_pretrained('bert-base-chinese')
assert tokenizer.pad_token_id == 0, "Padding value used in masks is set to zero, please change it everywhere"
train_df = pd.read_csv(os.path.join(path, 'train.csv'))
# training on a part of data for speed
# train_df = train_df.sample(frac=0.33)
train_df, val_df = train_test_split(train_df, test_size=0.05)

100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 213450/213450 [00:00<00:00, 226383.07B/s]


In [3]:
class ToxicDataset(Dataset):
    
    def __init__(self, tokenizer, dataframe, device):
        self.device = device
        self.tokenizer = tokenizer
        self.pad_idx = tokenizer.pad_token_id
        self.X = []
        self.Y = []
        for i, (row) in tqdm(dataframe.iterrows()):
            if len(tokenizer.tokenize(row["comment_text"])) > 120:
                continue
            text = tokenizer.encode(row["comment_text"], add_special_tokens=True)
            text = torch.LongTensor(text)
            tags = torch.FloatTensor(row[["toxic", "severe_toxic", "obscene", "threat", "insult", "identity_hate"]])
            self.X.append(text)
            self.Y.append(tags)
    
    def __len__(self):
        return len(self.X)

    def __getitem__(self, index: int) -> Tuple[torch.LongTensor, torch.LongTensor]:
        return self.X[index], self.Y[index]

def collate_fn(batch: List[Tuple[torch.LongTensor, torch.LongTensor]]) \
        -> Tuple[torch.LongTensor, torch.LongTensor]:
    x, y = list(zip(*batch))
    x = pad_sequence(x, batch_first=True, padding_value=0)
    y = torch.stack(y)
    return x.to(device), y.to(device)

train_dataset = ToxicDataset(tokenizer, train_df, device)
dev_dataset = ToxicDataset(tokenizer, val_df, device)

BATCH_SIZE = 32
train_sampler = RandomSampler(train_dataset)
dev_sampler = RandomSampler(dev_dataset)
train_iterator = DataLoader(train_dataset, batch_size=BATCH_SIZE, sampler=train_sampler, collate_fn=collate_fn)
dev_iterator = DataLoader(dev_dataset, batch_size=BATCH_SIZE, sampler=dev_sampler, collate_fn=collate_fn)

151592it [04:44, 532.78it/s]
7979it [00:14, 532.22it/s]


In [4]:
class BertClassifier(BertPreTrainedModel):
    
    def __init__(self, config):
        super(BertClassifier, self).__init__(config)
        self.bert = BertModel(config)
        self.classifier = nn.Linear(config.hidden_size, 6)
        
    def forward(self, input_ids, attention_mask=None, token_type_ids=None, position_ids=None, head_mask=None,
                
            labels=None):
        outputs = self.bert(input_ids,
                               attention_mask=attention_mask,
                               token_type_ids=token_type_ids,
                               position_ids=position_ids,
                               head_mask=head_mask)
        cls_output = outputs[1] # batch, hidden
        cls_output = self.classifier(cls_output) # batch, 6
        cls_output = torch.sigmoid(cls_output)
        criterion = nn.BCELoss()
        loss = 0
        if labels is not None:
            loss = criterion(cls_output, labels)
        return loss, cls_output

model = BertClassifier.from_pretrained('bert-base-cased').to(device)

100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 313/313 [00:00<00:00, 304527.29B/s]
100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 435779157/435779157 [47:53<00:00, 151676.20B/s]


In [5]:
def train(model, iterator, optimizer, scheduler):
    model.train()
    total_loss = 0
    for x, y in tqdm(iterator):
        optimizer.zero_grad()
        mask = (x != 0).float()
        loss, outputs = model(x, attention_mask=mask, labels=y)
        total_loss += loss.item()
        loss.backward()
        optimizer.step()
        scheduler.step()
    print(f"Train loss {total_loss / len(iterator)}")

def evaluate(model, iterator):
    model.eval()
    pred = []
    true = []
    with torch.no_grad():
        total_loss = 0
        for x, y in tqdm(iterator):
            mask = (x != 0).float()
            loss, outputs = model(x, attention_mask=mask, labels=y)
            total_loss += loss
            true += y.cpu().numpy().tolist()
            pred += outputs.cpu().numpy().tolist()
    true = np.array(true)
    pred = np.array(pred)
    for i, name in enumerate(['toxic', 'severe_toxic', 'obscene', 'threat', 'insult', 'identity_hate']):
        print(f"{name} roc_auc {roc_auc_score(true[:, i], pred[:, i])}")
    print(f"Evaluate loss {total_loss / len(iterator)}")

In [6]:
no_decay = ['bias', 'LayerNorm.weight']
optimizer_grouped_parameters = [
{'params': [p for n, p in model.named_parameters() if not any(nd in n for nd in no_decay)], 'weight_decay': 0.01},
{'params': [p for n, p in model.named_parameters() if any(nd in n for nd in no_decay)], 'weight_decay': 0.0}
]
EPOCH_NUM = 2
# triangular learning rate, linearly grows untill half of first epoch, then linearly decays 
warmup_steps = int(0.5 * len(train_iterator))
total_steps = len(train_iterator) * EPOCH_NUM - warmup_steps
optimizer = AdamW(optimizer_grouped_parameters, lr=2e-5, eps=1e-8)
scheduler = WarmupLinearSchedule(optimizer, warmup_steps=warmup_steps, t_total=total_steps)

In [7]:
for i in range(EPOCH_NUM):
    print('=' * 50, f"EPOCH {i}", '=' * 50)
    train(model, train_iterator, optimizer, scheduler)
    evaluate(model, dev_iterator)



100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 3718/3718 [11:39:58<00:00, 11.30s/it]


Train loss 0.09369936712543846


100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 196/196 [13:06<00:00,  4.02s/it]


toxic roc_auc 0.9864197344390776
severe_toxic roc_auc 0.9927542098445596
obscene roc_auc 0.9941830211470146
threat roc_auc 0.9837078651685393
insult roc_auc 0.9893937215409495
identity_hate roc_auc 0.9845207903926951
Evaluate loss 0.04085414856672287


100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 3718/3718 [11:30:15<00:00, 11.14s/it]


Train loss 0.03512513823299997


100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 196/196 [13:01<00:00,  3.99s/it]


toxic roc_auc 0.987829738498702
severe_toxic roc_auc 0.9928280249923804
obscene roc_auc 0.994578459237941
threat roc_auc 0.989795918367347
insult roc_auc 0.9894437330099259
identity_hate roc_auc 0.9912329860088714
Evaluate loss 0.03872765228152275
