In [1]:
!git clone https://github.com/ipavlopoulos/toxic_spans.git

fatal: destination path 'toxic_spans' already exists and is not an empty directory.


In [2]:
!pip install transformers



In [3]:
!nvidia-smi

Tue Apr 20 04:27:40 2021       
+-----------------------------------------------------------------------------+
| NVIDIA-SMI 460.67       Driver Version: 460.32.03    CUDA Version: 11.2     |
|-------------------------------+----------------------+----------------------+
| GPU  Name        Persistence-M| Bus-Id        Disp.A | Volatile Uncorr. ECC |
| Fan  Temp  Perf  Pwr:Usage/Cap|         Memory-Usage | GPU-Util  Compute M. |
|                               |                      |               MIG M. |
|   0  Tesla P100-PCIE...  Off  | 00000000:00:04.0 Off |                    0 |
| N/A   46C    P0    27W / 250W |      0MiB / 16280MiB |      0%      Default |
|                               |                      |                  N/A |
+-------------------------------+----------------------+----------------------+
                                                                               
+-----------------------------------------------------------------------------+
| Proces

In [4]:
from ast import literal_eval
import pandas as pd
import numpy as np
import torch
from transformers import AutoTokenizer, AutoModelForTokenClassification
from torch.utils.data import Dataset, TensorDataset, DataLoader, random_split
import torch.nn as nn
from tqdm import tqdm
import math
import os
device = 'cuda' if torch.cuda.is_available() else 'cpu'   

In [5]:
# set up logging
import logging
logging.basicConfig(
        format="%(asctime)s - %(message)s",
        datefmt="%m/%d/%Y %H:%M:%S",
        level=logging.INFO,
)
logger = logging.getLogger(__name__)
logging.getLogger('transformers').setLevel(logging.ERROR)

In [6]:
translationTable = str.maketrans("éàèùaêóïü","eaeuaeoiu")

In [7]:
tsd = pd.read_csv("toxic_spans/data/tsd_train.csv") 
tsd.text = tsd.text.apply(lambda x:x.translate(translationTable))
tsd.spans = tsd.spans.apply(literal_eval)
tsd.tail(5)

Unnamed: 0,spans,text
7934,"[8, 9, 10, 11]",Another fool pipes in.
7935,"[48, 49, 50, 51, 52, 53, 54, 55, 56, 57, 58, 5...",So if a restaurant owner puts up a sign saying...
7936,"[0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13,...",Any faith that can't stand up to logic and rea...
7937,"[5, 6, 7, 8, 9, 10, 11]",This idiotic. Use the surplus to pay down the ...
7938,"[106, 107, 108, 109, 110, 169, 170, 171, 172, ...","Who is this ""we"" of which you speak? Are you r..."


In [8]:
text_list = tsd.text.to_list()
spans_list = tsd.spans.to_list()

In [9]:
print(len(text_list))
print(len(spans_list))

7939
7939


In [10]:
model_name = "bert-base-uncased"

In [11]:
tokenizer = AutoTokenizer.from_pretrained(model_name)

In [12]:
# tokenizer.convert_ids_to_tokens(tokenizer(text_list[-1])['input_ids'])

In [13]:
special_tokens={'unk_token': '[UNK]', 'sep_token': '[SEP]', 'pad_token': '[PAD]', 'cls_token': '[CLS]', 'mask_token': '[MASK]'}

In [14]:
def find_idx(sentence, token, position=0):
    start = sentence.find(token, position)
    end = start + len(token)
    if start == -1:
        return []
    return list(range(start,end))

In [15]:
def encode_and_trans_labels(text_list, spans_list):
    inputs = tokenizer(
        text_list,                      
        add_special_tokens = True,             
        truncation=True,
        padding = 'max_length',     
        return_tensors = 'pt',
        max_length = 128
    )
    input_ids = inputs['input_ids']
    attention_mask = inputs['attention_mask']

    labels = []
    for ids, sentence, toxic_spans in zip(input_ids, text_list, spans_list):
        # print(sentence)
        # print(toxic_spans)
        tokens = tokenizer.convert_ids_to_tokens(ids)
        token_labels = []
        position = 0
        for token in tokens[:]:
            if token in special_tokens.values():
                token_labels.append(0.)
            else:
                token = token.replace("##","")
                spans = find_idx(sentence.lower(), token, position=position)
                if spans == []:
                    spans = list(range(position,position+len(token)))
                    print("not find:",token,spans,position,"".join([sentence[i] for i in spans]))
                # print(token,spans,position)
                position = spans[-1]+1
                if set(spans[:]) <= set(toxic_spans) or (set(toxic_spans)<=set(spans) and len(set(toxic_spans))>0):
                    token_labels.append(1.)
                elif len(set(spans[:-1])) > 0 and (set(spans[:-1]) <= set(toxic_spans)):
                    token_labels.append(1.)
                else:
                    token_labels.append(0.)
                # print(token,spans,position)
        labels.append(token_labels)
    labels = torch.tensor(labels)

    return input_ids, attention_mask, labels

In [16]:
class TextDataSet(Dataset):
    def __init__(self, input_ids, attention_mask, labels, text_list, spans_list):
        self.input_ids = input_ids
        self.attention_mask = attention_mask
        self.labels = labels
        self.text_list = text_list
        self.spans_list = spans_list

    def __len__(self):
        return len(self.input_ids)

    def __getitem__(self, idx):   
        return (self.input_ids[idx], self.attention_mask[idx], self.labels[idx], self.text_list[idx], self.spans_list[idx])

In [17]:
input_ids,attention_mask,labels = encode_and_trans_labels(text_list, spans_list)

In [18]:
dataset = TextDataSet(input_ids, attention_mask, labels, text_list, spans_list)
train_size = int(len(dataset)*0.8)
valid_size = len(dataset)-train_size
train_dataset, valid_dataset = random_split(dataset, [train_size,valid_size])
print('Train samples: {}, Valid samples: {}'.format(len(train_dataset), len(valid_dataset)))

Train samples: 6351, Valid samples: 1588


In [19]:
import re
import collections
from torch._six import string_classes

np_str_obj_array_pattern = re.compile(r'[SaUO]')

def default_collate(batch):
    r"""Puts each data field into a tensor with outer dimension batch size"""

    elem = batch[0]
    elem_type = type(elem)
    if isinstance(elem, torch.Tensor):
        out = None
        if torch.utils.data.get_worker_info() is not None:
            # If we're in a background process, concatenate directly into a
            # shared memory tensor to avoid an extra copy
            numel = sum([x.numel() for x in batch])
            storage = elem.storage()._new_shared(numel)
            out = elem.new(storage)
        return torch.stack(batch, 0, out=out)
    elif elem_type.__module__ == 'numpy' and elem_type.__name__ != 'str_' \
            and elem_type.__name__ != 'string_':
        if elem_type.__name__ == 'ndarray' or elem_type.__name__ == 'memmap':
            # array of string classes and object
            if np_str_obj_array_pattern.search(elem.dtype.str) is not None:
                raise TypeError(default_collate_err_msg_format.format(elem.dtype))

            return default_collate([torch.as_tensor(b) for b in batch])
        elif elem.shape == ():  # scalars
            return torch.as_tensor(batch)
    elif isinstance(elem, float):
        return torch.tensor(batch, dtype=torch.float64)
    elif isinstance(elem, int):
        return torch.tensor(batch)
    elif isinstance(elem, string_classes):
        return batch
    elif isinstance(elem, list):
        return batch
    elif isinstance(elem, collections.abc.Mapping):
        return {key: default_collate([d[key] for d in batch]) for key in elem}
    elif isinstance(elem, tuple) and hasattr(elem, '_fields'):  # namedtuple
        return elem_type(*(default_collate(samples) for samples in zip(*batch)))
    elif isinstance(elem, collections.abc.Sequence):
        # check to make sure that the elements in batch have consistent size
        it = iter(batch)
        elem_size = len(next(it))
        if not all(len(elem) == elem_size for elem in it):
            raise RuntimeError('each element in list of batch should be of equal size')
        transposed = zip(*batch)
        return [default_collate(samples) for samples in transposed]

    raise TypeError(default_collate_err_msg_format.format(elem_type))

In [20]:
batch_size = 32
train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True, collate_fn=default_collate)
valid_loader = DataLoader(valid_dataset, batch_size=batch_size, shuffle=True, collate_fn=default_collate)

In [21]:
model = AutoModelForTokenClassification.from_pretrained(model_name, num_labels=2)

In [22]:
weight_CE = torch.FloatTensor([1,2]).to(device)
loss_fct = nn.CrossEntropyLoss(weight=weight_CE)
def loss_fn(logits, attention_mask, labels):
    active_loss = attention_mask.view(-1) == 1
    active_logits = logits.view(-1, 2)
    active_labels = torch.where(
        active_loss, labels.view(-1), torch.tensor(loss_fct.ignore_index).type_as(labels)
    )
    loss = loss_fct(active_logits, active_labels)
    return loss

In [23]:
class Trainer:

    def __init__(self, model, train_loader, valid_loader, config):
        self.model = model
        self.train_loader = train_loader
        self.valid_loader = valid_loader
        self.config = config

        # take over whatever gpus are on the system
        self.device = 'cpu'
        if torch.cuda.is_available():
            self.device = torch.cuda.current_device()
            self.model = torch.nn.DataParallel(self.model).to(self.device)

    def save_checkpoint(self):
        # DataParallel wrappers keep raw model object in .module attribute
        raw_model = self.model.module if hasattr(self.model, "module") else self.model
        os.makedirs(self.config.ckpt_path, exist_ok=True)
        logger.info("Save model to {}".format(self.config.ckpt_path))
        torch.save(raw_model.state_dict(), self.config.ckpt_path+"bert_model.pt")

    def train(self):
        model, config = self.model, self.config
        raw_model = model.module if hasattr(self.model, "module") else model
        optimizer = torch.optim.AdamW(self.model.parameters(), lr=config.learning_rate, betas=config.betas)
        def run_epoch(split):
            is_train = (split == 'train')
            model.train(is_train)
            loader = self.train_loader if is_train else self.valid_loader
            losses = []
            spans_list_all = []
            spans_pred_all = []
            pbar = tqdm(enumerate(loader), total=len(loader)) if is_train else enumerate(loader)
            for it, (input_ids, attention_mask, y, text_list, spans_list) in pbar:
                # place data on the correct device
                input_ids = input_ids.to(self.device)
                attention_mask = attention_mask.to(self.device)
                y = y.to(self.device).long()
                # forward the model
                with torch.set_grad_enabled(is_train):
                    outputs = model(input_ids, attention_mask, labels=y)
                    logits = outputs.logits
                    # loss = outputs.loss
                    loss = loss_fn(logits, attention_mask, y)
                    loss = loss.mean() # collapse all losses if they are scattered on multiple gpus
                    losses.append(loss.item())

                    spans_pred = decode_and_trans_labels(text_list, input_ids, logits)
                    # print("pred: ",spans_pred)
                    # print("gold: ",spans_list)
                    f1_score, recall_score, precision_score = batch_score(spans_pred, spans_list)

                    spans_list_all.extend(spans_list)
                    spans_pred_all.extend(spans_pred)

                    # gold = ("".join(text_list[0][i] for i in spans_list[0])).lower()
                    # spans_labels = torch.nonzero(y[0].cpu().detach())
                    # spans_labels = tokenizer.convert_ids_to_tokens([input_ids[0][i] for i in spans_labels])
                    # spans_labels = "".join([i.replace("##","") for i in spans_labels])
                    # pred = "".join(text_list[0][i] for i in spans_pred[0])
                    # print("\ngold: {}\nlables: {}\npred: {}".format(gold,spans_labels,pred))
                
                if is_train:

                    # backprop and update the parameters
                    model.zero_grad()
                    loss.backward()
                    torch.nn.utils.clip_grad_norm_(model.parameters(), config.grad_norm_clip)
                    optimizer.step()

                    # decay the learning rate based on our progress
                    if config.lr_decay:
                        self.tokens += batch_size # number of tokens processed this step (i.e. label is not -100)
                        if self.tokens < config.warmup_tokens:
                            # linear warmup
                            lr_mult = float(self.tokens) / float(max(1, config.warmup_tokens))
                        else:
                            # cosine learning rate decay
                            progress = float(self.tokens - config.warmup_tokens) / float(max(1, config.final_tokens - config.warmup_tokens))
                            lr_mult = max(0.1, 0.5 * (1.0 + math.cos(math.pi * progress)))
                        lr = config.learning_rate * lr_mult
                        for param_group in optimizer.param_groups:
                            param_group['lr'] = lr
                    else:
                        lr = config.learning_rate

                    # report progress
                    pbar.set_description("epoch {} iter {}: train loss {:.5f}, f1 {:.2f}%, recall {:.2f}%, precision {:.2f}%, lr {:e}"\
                                         .format(epoch+1,it,loss.item(),f1_score*100,recall_score*100,precision_score*100,lr))
                    # pbar.set_description(f"epoch {epoch+1} iter {it}: train loss {loss.item():.5f}, f1 {f1_score:.2f}, recall_score {recall_score:.5f}, lr {lr:e}")

            if not is_train:
                valid_loss = float(np.mean(losses))
                valid_f1_score, valid_recall_score, valid_precision_score = batch_score(spans_pred_all, spans_list_all)
                logger.info("valid loss: {:.5f}".format(valid_loss))
                logger.info("valid f1 score: {:.2f}%".format(valid_f1_score*100))
                logger.info("valid recall: {:.2f}%".format(valid_recall_score*100))
                logger.info("valid precision: {:.2f}%".format(valid_precision_score*100))
                return valid_loss

        self.tokens = 0 # counter used for learning rate decay
        best_loss = float('inf')
        valid_loss = run_epoch('valid')
        for epoch in range(config.max_epochs):
            
            run_epoch('train')
            if self.valid_loader is not None:
                valid_loss = run_epoch('valid')
            # supports early stopping based on the valid loss, or just save always if no valid set is provided
            good_model = self.valid_loader is None or valid_loss < best_loss
            if self.config.ckpt_path is not None and good_model:
                best_loss = valid_loss
                self.save_checkpoint()

In [24]:
def decode_and_trans_labels(text_list, input_ids, logits):
    token_labels = torch.argmax(logits, dim=-1).cpu().detach().numpy()
    spans_pred = []
    for ids, sentence, labels in zip(input_ids, text_list, token_labels):
        tokens = tokenizer.convert_ids_to_tokens(ids)
        position = 0
        toxic_spans = []
        # print(sentence)
        for token, label in zip(tokens, labels):
            if token in special_tokens.values():
                continue
            token = token.replace("##","")
            spans = find_idx(sentence.lower(), token, position=position)
            # print(token,spans)
            if spans == []:
                spans = list(range(position,position+len(token)))
                print("not find:",token,spans,position,"".join([sentence[i] for i in spans]))
            position = spans[-1]+1
            if label == 1:
                toxic_spans.extend(spans)
        spans_pred.append(toxic_spans)
    return spans_pred

In [25]:
def f1(predictions, gold):
    """
    F1 (a.k.a. DICE) operating on two lists of offsets (e.g., character).
    >>> assert f1([0, 1, 4, 5], [0, 1, 6]) == 0.5714285714285714
    :param predictions: a list of predicted offsets
    :param gold: a list of offsets serving as the ground truth
    :return: a score between 0 and 1
    """
    if len(gold) == 0:
        return 1. if len(predictions) == 0 else 0.
    if len(predictions) == 0:
        return 0.
    predictions_set = set(predictions)
    gold_set = set(gold)
    nom = 2 * len(predictions_set.intersection(gold_set))
    denom = len(predictions_set) + len(gold_set)
    return float(nom)/float(denom)

In [26]:
def recall(predictions, gold):
    if len(gold) == 0:
        return 1. if len(predictions) == 0 else 0.
    if len(predictions) == 0:
        return 0.
    predictions_set = set(predictions)
    gold_set = set(gold)
    nom = len(predictions_set.intersection(gold_set))
    denom = len(gold_set)
    return float(nom)/float(denom)

def precision(predictions, gold):
    if len(gold) == 0:
        return 1. if len(predictions) == 0 else 0.
    if len(predictions) == 0:
        return 0.
    predictions_set = set(predictions)
    gold_set = set(gold)
    nom = len(predictions_set.intersection(gold_set))
    denom = len(predictions_set)
    return float(nom)/float(denom)

In [27]:
def batch_score(spans_pred, spans_list):
    f1_scores = []
    recall_scores = []
    precision_scores = []
    for pred, gold in zip(spans_pred, spans_list):
        f1_score = f1(pred, gold)
        recall_score = recall(pred, gold)
        precision_score = precision(pred, gold)
        f1_scores.append(f1_score)
        recall_scores.append(recall_score)
        precision_scores.append(precision_score)
    return np.mean(f1_scores), np.mean(recall_scores), np.mean(precision_scores)

In [28]:
class TrainerConfig:
    # optimization parameters
    max_epochs = 10
    learning_rate = 1e-5
    betas = (0.9, 0.95)
    grad_norm_clip = 1.0
    weight_decay = 0.1 # may useful optimize method
    # learning rate decay params: linear warmup followed by cosine decay to 10% of original
    lr_decay = False # optimize method
    warmup_tokens = 375e6 # use this to train model from a lower learning rate
    final_tokens = 260e9 # all tokens during whole training process
    # checkpoint settings
    ckpt_path = './models/' # save model path

    def __init__(self, **kwargs):
        for k,v in kwargs.items():
            print(k,v)
            setattr(self, k, v)

In [29]:
# print model all parameters and parameters need training
print('{} : all params: {:4f}M'.format(model._get_name(), sum(p.numel() for p in model.parameters()) / 1000 / 1000))

BertForTokenClassification : all params: 108.893186M


In [30]:
max_epochs = 6
final_tokens = max_epochs * batch_size * len(train_loader)
warmup_tokens = final_tokens//10
tconf = TrainerConfig(max_epochs=max_epochs, learning_rate=1e-5, lr_decay=True, 
                      warmup_tokens=warmup_tokens, final_tokens=final_tokens)

max_epochs 6
learning_rate 1e-05
lr_decay True
warmup_tokens 3820
final_tokens 38208


In [31]:
trainer = Trainer(model, train_loader, valid_loader, tconf)

In [32]:
trainer.train()

04/20/2021 04:28:02 - valid loss: 0.61711
04/20/2021 04:28:02 - valid f1 score: 20.55%
04/20/2021 04:28:02 - valid recall: 43.82%
04/20/2021 04:28:02 - valid precision: 18.47%
epoch 1 iter 198: train loss 0.15111, f1 63.93%, recall 61.61%, precision 80.00%, lr 9.865146e-06: 100%|██████████| 199/199 [01:29<00:00,  2.23it/s]
04/20/2021 04:29:40 - valid loss: 0.31078
04/20/2021 04:29:40 - valid f1 score: 60.69%
04/20/2021 04:29:40 - valid recall: 64.59%
04/20/2021 04:29:40 - valid precision: 67.45%
04/20/2021 04:29:40 - Save model to ./models/
epoch 2 iter 198: train loss 0.27882, f1 72.61%, recall 77.86%, precision 79.13%, lr 8.431011e-06: 100%|██████████| 199/199 [01:29<00:00,  2.23it/s]
04/20/2021 04:31:20 - valid loss: 0.30185
04/20/2021 04:31:20 - valid f1 score: 63.88%
04/20/2021 04:31:20 - valid recall: 70.63%
04/20/2021 04:31:20 - valid precision: 68.97%
04/20/2021 04:31:20 - Save model to ./models/
epoch 3 iter 198: train loss 0.16657, f1 68.92%, recall 76.94%, precision 72.79%, 

## On Test Data

In [33]:
model.load_state_dict(torch.load('models/bert_model.pt'))

<All keys matched successfully>

In [34]:
tsd_test = pd.read_csv("toxic_spans/data/tsd_test.csv") 
tsd_test.text = tsd_test.text.apply(lambda x:x.translate(translationTable))
tsd_test.spans = tsd_test.spans.apply(literal_eval)
tsd_test.tail(5)

Unnamed: 0,spans,text
1995,"[4, 5, 6, 7, 8, 70, 71, 72, 73, 74, 75, 76, 77...",hey loser change your name to something more a...
1996,"[23, 24, 25, 26, 27]",And you are a complete moron who obviously doe...
1997,"[157, 158, 159, 160, 161, 162, 163, 164, 165, ...",Such vitriol from the left. Who would have th...
1998,[],It is now time for most of you to expand your ...
1999,"[828, 829, 830, 831]","Why does this author think she can demand, or ..."


In [35]:
text_list_test = tsd_test.text.to_list()
spans_list_test = tsd_test.spans.to_list()
print(len(text_list_test))
print(len(spans_list_test))

2000
2000


In [36]:
def predict(text,spans):
    input_ids,attention_mask,labels = encode_and_trans_labels([text], [spans])
    outputs = model(input_ids.to(device), attention_mask.to(device), labels=labels.to(device).long())
    logits = outputs.logits 
    spans_pred = decode_and_trans_labels([text], input_ids, logits)[0]
    gold = "".join(text[i] for i in spans)
    pred = "".join(text[i] for i in spans_pred)
    # print("gold: {}\npred: {}".format(gold,pred))
    f1_score = f1(spans_pred, spans)
    recall_score = recall(spans_pred, spans)
    precision_score = precision(spans_pred, spans)
    # print("f1 {:.2f}%, recall {:.2f}%, precision {:.2f}%,".format(f1_score*100,recall_score*100,precision_score*100))
    return f1_score,recall_score,precision_score

In [37]:
f1_scores,recall_scores,precision_scores = [],[],[]
for i in range(len(text_list_test)):
    f1_score,recall_score,precision_score = predict(text_list[i],spans_list[i])
    f1_scores.append(f1_score)
    recall_scores.append(recall_score)
    precision_scores.append(precision_score)
print("All: f1 {:.2f}%, recall {:.2f}%, precision {:.2f}%,".format(np.mean(f1_scores)*100,np.mean(recall_scores)*100,np.mean(precision_scores)*100))

All: f1 67.50%, recall 74.65%, precision 70.95%,
