In [1]:
import torch
from transformers import BertModel, BertConfig, BertTokenizer
import json
import pandas as pd
from torchtext.data.utils import get_tokenizer
from torchtext import data
from torch import nn
from torch.autograd import Variable
from torch.utils.data import (TensorDataset, random_split,
                              RandomSampler, DataLoader,
                              SequentialSampler)
from tqdm.notebook import tqdm
import numpy as np
import torch.nn.functional as F
import os
from sklearn.metrics import accuracy_score 
import warnings
warnings.filterwarnings('ignore')
from transformers import BertForSequenceClassification
from transformers import AdamW

In [2]:
np.random.seed(seed=54)

# Load data

Dataset: https://www.kaggle.com/c/jigsaw-toxic-comment-classification-challenge/overview

In [3]:
train = pd.read_csv('./dataset/train.csv')
test_data = pd.read_csv('./dataset/test.csv')
test_labels = pd.read_csv('./dataset/test_labels.csv')
test = pd.merge(test_data, test_labels, on='id')
train.head(5)

Unnamed: 0,id,comment_text,toxic,severe_toxic,obscene,threat,insult,identity_hate
0,0000997932d777bf,Explanation\nWhy the edits made under my usern...,0,0,0,0,0,0
1,000103f0d9cfb60f,D'aww! He matches this background colour I'm s...,0,0,0,0,0,0
2,000113f07ec002fd,"Hey man, I'm really not trying to edit war. It...",0,0,0,0,0,0
3,0001b41b1c6bb37e,"""\nMore\nI can't make any real suggestions on ...",0,0,0,0,0,0
4,0001d958c54c6e35,"You, sir, are my hero. Any chance you remember...",0,0,0,0,0,0


In [4]:
f'train: {train.shape}, test:  {test.shape}'

'train: (159571, 8), test:  (153164, 8)'

### Remove long sentences

In [6]:
MAX_LEN = 200
TRAIN_SIZE = 5000
TEST_SIZE = 500

In [7]:
def get_data(data, size, ratio):
    X, y = data['comment_text'].values, data['toxic'].values
    lengths = np.array([len(x.split()) for x in X])
    idxs = np.where(lengths <= MAX_LEN)[0]
    X, y = X[idxs], y[idxs]
    
    lengths = np.array([len(x) for x in X])
    idxs = np.where(lengths <= 512)[0]
    X, y = X[idxs], y[idxs]


    toxic = np.where(y == 1)[0]
    normal = np.where(y == 0)[0]
    idxs_toxic = np.random.choice(toxic, round(size * ratio), replace=False)
    idxs_normal = np.random.choice(normal, round(size * (1 - ratio)), replace=False)

    X = np.concatenate([X[idxs_toxic], X[idxs_normal]])
    y = np.concatenate([y[idxs_toxic], y[idxs_normal]])
    
    return X, y

In [8]:
X_train, y_train = get_data(train, TRAIN_SIZE, 0.5)
X_test, y_test = get_data(test, TEST_SIZE, 0.6)

In [9]:
f'Train: normal - {len(np.where(y_train == 0)[0])} toxic - {len(np.where(y_train == 1)[0])}'

'Train: normal - 2500 toxic - 2500'

In [10]:
f'Test: normal - {len(np.where(y_test == 0)[0])} toxic - {len(np.where(y_test == 1)[0])}'

'Test: normal - 200 toxic - 300'

### Preprocessing

In [11]:
def get_vocab(X):
    X_split = [t.split() for t in X]
    text_field = data.Field()
    text_field.build_vocab(X_split, max_size=10000)
    return text_field

def pad(seq, max_len):
    if len(seq) < max_len:
        seq = seq + ['<pad>'] * (max_len - len(seq))
    return seq[0:max_len]

def to_indexes(vocab, words):
    return [vocab.stoi[w] for w in words]

def to_dataset(x, y, y_real):
    torch_x = torch.tensor(x, dtype=torch.long)
    torch_y = torch.tensor(y, dtype=torch.float)
    torch_real_y = torch.tensor(y_real, dtype=torch.long)
    return TensorDataset(torch_x, torch_y, torch_real_y)

# Bi-LSTM

In [12]:
class SimpleLSTM(nn.Module):

    def __init__(self, input_dim, embedding_dim, hidden_dim, output_dim, n_layers,
                 bidirectional, dropout, batch_size):
        super(SimpleLSTM, self).__init__()

        self.batch_size = batch_size
        self.hidden_dim = hidden_dim
        self.n_layers = n_layers
        self.embedding = nn.Embedding(input_dim, embedding_dim)

        self.rnn = nn.LSTM(self.embedding.embedding_dim,
                           hidden_dim,
                           num_layers=n_layers,
                           bidirectional=bidirectional,
                           dropout=dropout)

        self.fc = nn.Linear(hidden_dim * 2, output_dim)

        self.dropout = nn.Dropout(dropout)

    def forward(self, text, text_lengths=None):
        x = self.embedding(text)
        x, hidden = self.rnn(x)
        hidden, cell = hidden
        hidden = self.dropout(torch.cat((hidden[-2, :, :], hidden[-1, :, :]), dim=1))
        x = self.fc(hidden)
        return x

In [13]:
def get_optimizer(model):
    optimizer = torch.optim.Adam(model.parameters())
    scheduler = torch.optim.lr_scheduler.StepLR(optimizer, 2, gamma=0.9)
    return optimizer, scheduler

def epoch_train_func(model, dataset, loss_func, batch_size):
    train_loss = 0
    train_sampler = RandomSampler(dataset)
    data_loader = DataLoader(dataset, sampler=train_sampler,
                             batch_size=batch_size,
                             drop_last=True)
    model.train()
    optimizer, scheduler = get_optimizer(model)
    for i, (text, bert_prob, real_label) in enumerate(tqdm(data_loader, desc='Train')):
        model.zero_grad()
        output = model(text.t(), None).squeeze(1)
        loss = loss_func(output, bert_prob, real_label)
        loss.backward()
        optimizer.step()
        train_loss += loss.item()
    scheduler.step()
    return train_loss / len(data_loader)

In [14]:
def epoch_evaluate_func(model, eval_dataset, loss_func, batch_size):
    eval_sampler = SequentialSampler(eval_dataset)
    data_loader = DataLoader(eval_dataset, sampler=eval_sampler,
                             batch_size=batch_size,
                             drop_last=True)

    eval_loss = 0.0
    model.eval()
    for i, (text, bert_prob, real_label) in enumerate(tqdm(data_loader, desc='Val')):
        output = model(text.t(), None).squeeze(1)
        loss = loss_func(output, bert_prob, real_label)
        eval_loss += loss.item()

    return eval_loss / len(data_loader)

In [15]:
class LSTMBaseline(object):
    vocab_name = 'text_vocab.pt'
    weights_name = 'simple_lstm.pt'

    def __init__(self, settings):
        self.settings = settings
        self.criterion = torch.nn.BCEWithLogitsLoss()

    def loss(self, output, bert_prob, real_label):
        return self.criterion(output, real_label.float())

    def model(self, text_field):
        model = SimpleLSTM(
            input_dim=len(text_field.vocab),
            embedding_dim=64,
            hidden_dim=128,
            output_dim=1,
            n_layers=1,
            bidirectional=True,
            dropout=0.5,
            batch_size=self.settings['train_batch_size'])
        return model

    def train(self, X, y, y_real, output_dir):
        max_len = self.settings['max_seq_length']
        text_field = get_vocab(X)

        X_split = [t.split() for t in X]
        X_pad = [pad(s, max_len) for s in tqdm(X_split, desc='pad')]
        X_index = [to_indexes(text_field.vocab, s) for s in tqdm(X_pad, desc='to index')]

        dataset = to_dataset(X_index, y, y_real)
        val_len = int(len(dataset) * 0.1)
        train_dataset, val_dataset = random_split(dataset, (len(dataset) - val_len, val_len))

        model = self.model(text_field)
        
        self.full_train(model, train_dataset, val_dataset, output_dir)
        torch.save(text_field, os.path.join(output_dir, self.vocab_name))
        
        return model, text_field.vocab

    def full_train(self, model, train_dataset, val_dataset, output_dir):
        train_settings = self.settings
        num_train_epochs = train_settings['num_train_epochs']
        best_eval_loss = 100000
        losses = []
        for epoch in tqdm(range(num_train_epochs), desc='Epochs'):
            train_loss = epoch_train_func(model, train_dataset, self.loss, self.settings['train_batch_size'])
            eval_loss = epoch_evaluate_func(model, val_dataset, self.loss, self.settings['eval_batch_size'])
            
            print(f'Epoch: {epoch} loss: {eval_loss}')
            losses.append(eval_loss)
            
            if eval_loss < best_eval_loss:
                best_eval_loss = eval_loss
                torch.save(model.state_dict(), os.path.join(output_dir, self.weights_name))

# Train

In [16]:
lstm_settings = {
    'max_seq_length': MAX_LEN,
    'num_train_epochs': 0,
    'train_batch_size': 32,
    'eval_batch_size': 32,
}

In [76]:
trainer = LSTMBaseline(lstm_settings)

In [77]:
model, vocab = trainer.train(X_train, y_train, y_train, './output/')

HBox(children=(FloatProgress(value=0.0, description='pad', max=5000.0, style=ProgressStyle(description_width='…




HBox(children=(FloatProgress(value=0.0, description='to index', max=5000.0, style=ProgressStyle(description_wi…




HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Epochs', max=1.0, style=ProgressStyle(d…




# Test

In [78]:
def preprocessing(X, y, vocab):
    X_split = [t.split() for t in X]
    X_pad = [pad(s, MAX_LEN) for s in tqdm(X_split, desc='pad')]
    X_index = [to_indexes(vocab, s) for s in tqdm(X_pad, desc='to index')]
    X = torch.transpose(torch.tensor(X_index, dtype=torch.long), 1, 0)
    
    y = torch.tensor(y, dtype=torch.float)
    return X, y

In [79]:
def test(model, vocab, X, y):
    X_test, y_test = preprocessing(X, y, vocab)
    softmax = nn.Sigmoid()
#     softmax = nn.Softmax()
    y_pred = softmax(model(X_test)).detach().numpy()
    y_pred = np.around(y_pred)
#     y_pred = np.argmax(y_pred, axis=1)
    accuracy = accuracy_score(y_pred, y_test)
    return accuracy

In [80]:
model.load_state_dict(torch.load('./output/simple_lstm.pt'))
model.eval()

SimpleLSTM(
  (embedding): Embedding(10002, 64)
  (rnn): LSTM(64, 128, dropout=0.5, bidirectional=True)
  (fc): Linear(in_features=256, out_features=1, bias=True)
  (dropout): Dropout(p=0.5, inplace=False)
)

In [81]:
f'Accuracy: {test(model, vocab, X_test, y_test)}'

HBox(children=(FloatProgress(value=0.0, description='pad', max=500.0, style=ProgressStyle(description_width='i…




HBox(children=(FloatProgress(value=0.0, description='to index', max=500.0, style=ProgressStyle(description_wid…




'Accuracy: 0.564'

# Distillation

In [28]:
class LSTMDistilled(LSTMBaseline):
    vocab_name = 'distil_text_vocab.pt'
    weights_name = 'distil_lstm.pt'

    def __init__(self, settings):
        super(LSTMDistilled, self).__init__(settings)
        self.criterion_mse = torch.nn.MSELoss()
        self.criterion_ce = torch.nn.CrossEntropyLoss()
        self.a = 0.5

    def loss(self, output, bert_prob, real_label):
        return self.a * self.criterion_ce(output, real_label) + (1 - self.a) * self.criterion_mse(output, bert_prob)

    def model(self, text_field):
        model = SimpleLSTM(
            input_dim=len(text_field.vocab),
            embedding_dim=64,
            hidden_dim=128,
            output_dim=2,
            n_layers=1,
            bidirectional=True,
            dropout=0.5,
            batch_size=self.settings['train_batch_size'])
        return model

# BERT

In [17]:
tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')

In [24]:
X_bert = torch.tensor([tokenizer.encode(x, add_special_tokens=True, pad_to_max_length=True) for x in X_train])
y_bert = torch.tensor(y_train)

In [19]:
y_bert = torch.tensor(y_train)

In [20]:
bert = BertForSequenceClassification.from_pretrained('bert-base-uncased')

In [22]:
outputs = bert(X_bert, labels=y_bert)