In [1]:
!pip install accelerate transformers datasets -U



Load the pretrained base model (BART-large) and its tokenizer

In [2]:
from transformers import BartForConditionalGeneration, BartTokenizer

# Pretrained BART-large and its tokenizer from the Hugging Face Hub.
# forced_bos_token_id=0 forces token id 0 as the first generated token
# (presumably the <s> BOS token; the tokenized examples below also start with 0).
# NOTE(review): base_model is not used by any cell shown here — the G- and G+
# models are loaded separately below. If no later cell needs it, dropping it saves
# the memory of one bart-large copy.
base_model = BartForConditionalGeneration.from_pretrained("facebook/bart-large", forced_bos_token_id=0)
tokenizer = BartTokenizer.from_pretrained("facebook/bart-large")

Train the G- model

We train the G- model on the toxic comments from the Jigsaw corpus.

In [3]:
from datasets import load_dataset

# Load the Jigsaw toxic-comment data from the manually downloaded Kaggle files in
# `data_dir`. Recent `datasets` releases refuse to run a dataset loading script
# unless trust_remote_code=True is passed (jigsaw_toxicity_pred presumably ships
# one, since it requires a manual download — confirm).
# NOTE(review): the name `datasets` shadows the `datasets` package; it is kept
# because later cells refer to it.
datasets = load_dataset(
    'jigsaw_toxicity_pred',
    data_dir='jigsaw-toxic-content-classification-challenge',
    trust_remote_code=True,
)
# Keep only the comments labelled toxic (toxic == 1), in every split.
toxic_datasets = datasets.filter(lambda x: int(x['toxic']) == 1)
print(toxic_datasets['train'][1])



We define a function that tokenizes the comment text, truncating each comment to at most 1024 tokens.

In [4]:
def tokenize_function(examples):
    """Tokenize a batch of raw comments with the notebook-level BART tokenizer.

    Args:
        examples: batch dict (as passed by ``Dataset.map(batched=True)``) whose
            ``"comment_text"`` entry holds the comment strings.

    Returns:
        The tokenizer output for the batch; comments longer than 1024 tokens
        are truncated.
    """
    comments = examples["comment_text"]
    encoded = tokenizer(comments, truncation=True, max_length=1024)
    return encoded


We tokenize the toxic comments and drop the raw text and label columns.

In [5]:
# Raw text and label columns. They are dropped after tokenization, so only the
# tokenizer outputs (input_ids, attention_mask) remain.
td_columns = [
    "comment_text",
    "toxic",
    "severe_toxic",
    "obscene",
    "threat",
    "insult",
    "identity_hate",
]
tokenized_datasets = toxic_datasets.map(
    tokenize_function,
    batched=True,
    num_proc=4,
    remove_columns=td_columns,
)

In [6]:
# Inspect one tokenized toxic comment; only input_ids and attention_mask remain.
print(tokenized_datasets["train"][1])


{'input_ids': [0, 13368, 734, 99, 16, 24, 7586, 50118, 1039, 1721, 1067, 479, 50118, 2264, 16, 24, 734, 41, 5451, 333, 9, 103, 26218, 255, 2118, 8863, 6557, 734, 8155, 32, 205, 23, 14340, 6, 1403, 12, 19051, 9695, 661, 54, 272, 9298, 5570, 143, 65, 54, 6990, 106, 1142, 4091, 90, 49, 5102, 13216, 12, 104, 4571, 11694, 8, 24566, 6997, 28120, 10002, 36, 13424, 19281, 10800, 35719, 23, 26218, 116, 50118, 50118, 33895, 208, 1571, 3810, 7, 2382, 62, 39, 3650, 87, 696, 162, 42475, 8383, 734, 2], 'attention_mask': [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1]}


We concatenate the tokenized comments and split them into blocks of a fixed size (`block_size` tokens), which are used to fine-tune a copy of the base model into the G- model.

In [7]:
# Length, in tokens, of every training block produced by group_texts below.
block_size = 128


In [8]:
def group_texts(examples, chunk_size=None):
    """Concatenate a batch of tokenized examples and re-split it into fixed-size blocks.

    Every column in ``examples`` (e.g. ``input_ids`` and ``attention_mask``) is
    flattened into one long sequence and cut into consecutive blocks of
    ``chunk_size`` tokens. The trailing remainder shorter than one block is
    dropped; padding could be used instead if the model supported it.

    Args:
        examples: batch dict as passed by ``Dataset.map(batched=True)``, mapping
            column name -> list of token lists. Must contain ``"input_ids"``.
        chunk_size: block length in tokens. Defaults to the notebook-level
            ``block_size`` when omitted, so ``map(group_texts, ...)`` behaves as
            before.

    Returns:
        A dict with the same columns, each a list of blocks, plus a ``"labels"``
        column whose blocks are independent copies of the ``"input_ids"`` blocks.

    Raises:
        ValueError: if the block length is not positive.

    NOTE(review): labels are identical to the encoder input, so an
    encoder-decoder model such as BART sees the target in its input and can
    learn to simply copy it. The ~1.03 eval perplexity reported for the G- model
    is consistent with that. If a denoising objective is intended, corrupt or
    mask the input instead — confirm intent.
    """
    from itertools import chain

    size = block_size if chunk_size is None else chunk_size
    if size <= 0:
        raise ValueError(f"chunk_size must be positive, got {size}")

    # Flatten each column. chain.from_iterable is linear, unlike sum(lists, []),
    # which re-copies the growing accumulator on every addition.
    concatenated = {k: list(chain.from_iterable(v)) for k, v in examples.items()}
    # Columns are aligned token-for-token, so any column gives the total length.
    total_length = len(next(iter(concatenated.values()), []))
    # Drop the remainder that does not fill a whole block.
    total_length = (total_length // size) * size

    result = {
        k: [t[i : i + size] for i in range(0, total_length, size)]
        for k, t in concatenated.items()
    }
    # Copy each block so labels never alias the input_ids lists.
    result["labels"] = [list(block) for block in result["input_ids"]]
    return result


In [9]:
# Regroup the tokenized toxic comments into fixed-size blocks (see group_texts).
# Each call receives 100 tokenized comments; the work is spread over 4 processes.
lm_datasets = tokenized_datasets.map(group_texts, batched=True, num_proc=4, batch_size=100)

We fine-tune the G- model on the toxic-comment blocks.

In [10]:
# A fresh copy of the pretrained BART-large weights, to be fine-tuned on toxic comments (G-).
gminus_model = BartForConditionalGeneration.from_pretrained("facebook/bart-large", forced_bos_token_id=0)

In [11]:
from transformers import Trainer, TrainingArguments

# Fine-tuning configuration for the G- model (output directory "gminus-bart-large").
# `eval_strategy` replaces the deprecated `evaluation_strategy` argument, which the
# recent transformers releases installed by the `-U` upgrade no longer accept.
# The eval_dataset is evaluated at the end of every epoch.
training_args = TrainingArguments(
    "gminus-bart-large",
    eval_strategy="epoch",
    learning_rate=2e-5,
    weight_decay=0.01,
)

In [12]:
# Trainer for the G- model: train on the toxic "train" blocks, evaluate on the toxic
# "test" blocks. No data_collator is passed; group_texts emits blocks of equal
# length, so no padding is expected to be needed.
# NOTE(review): labels equal input_ids (see group_texts), which looks like a copy
# task for an encoder-decoder model — confirm that is intended.
trainer = Trainer(
    model=gminus_model,
    train_dataset=lm_datasets["train"],
    eval_dataset=lm_datasets["test"],
    args=training_args,
)

In [13]:
# Fine-tune G-; per-epoch training and validation losses are reported below.
trainer.train()



Epoch,Training Loss,Validation Loss
1,0.0598,0.042259
2,0.042,0.039843
3,0.0356,0.034121


TrainOutput(global_step=3438, training_loss=0.06486834475576635, metrics={'train_runtime': 3623.5292, 'train_samples_per_second': 7.59, 'train_steps_per_second': 0.949, 'total_flos': 7449692957835264.0, 'train_loss': 0.06486834475576635, 'epoch': 3.0})

In [14]:
import math
# Perplexity of G- on the held-out toxic blocks: exp of the evaluation loss.
eval_results = trainer.evaluate()
print(f"Perplexity: {math.exp(eval_results['eval_loss']):.2f}")

Perplexity: 1.03


Train the G+ model

We extract the non-toxic comments from the Jigsaw corpus.

In [15]:
# Complement of the toxic filter: keep the comments whose toxic label is 0.
nontoxic_datasets = datasets.filter(lambda example: not int(example['toxic']))
print(nontoxic_datasets['train'][1])

Filter:   0%|          | 0/159571 [00:00<?, ? examples/s]

Filter:   0%|          | 0/63978 [00:00<?, ? examples/s]

{'comment_text': "D'aww! He matches this background colour I'm seemingly stuck with. Thanks.  (talk) 21:51, January 11, 2016 (UTC)", 'toxic': 0, 'severe_toxic': 0, 'obscene': 0, 'threat': 0, 'insult': 0, 'identity_hate': 0}


We tokenize the non-toxic comments with the previously defined tokenization function.

In [16]:
# Tokenize the non-toxic comments with the same function and drop the raw columns.
nontoxic_tokenized_datasets = nontoxic_datasets.map(
    tokenize_function,
    remove_columns=td_columns,
    batched=True,
    num_proc=4,
)

print(nontoxic_tokenized_datasets["train"][1])

Map (num_proc=4):   0%|          | 0/144277 [00:00<?, ? examples/s]

Map (num_proc=4):   0%|          | 0/57888 [00:00<?, ? examples/s]

{'input_ids': [0, 495, 108, 1584, 605, 328, 91, 2856, 42, 3618, 7705, 38, 437, 6590, 4889, 19, 4, 4557, 4, 1437, 36, 26594, 43, 733, 35, 4708, 6, 644, 365, 6, 336, 36, 41934, 43, 2], 'attention_mask': [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1]}


As before, we concatenate the tokenized comments and split them into fixed-size blocks for training the G+ model.

In [17]:
# Regroup the tokenized non-toxic comments into fixed-size blocks, exactly as for G-.
nontoxic_lm_datasets = nontoxic_tokenized_datasets.map(group_texts, batched=True, num_proc=4, batch_size=100)

print(nontoxic_lm_datasets['train'])

Map (num_proc=4):   0%|          | 0/144277 [00:00<?, ? examples/s]

Map (num_proc=4):   0%|          | 0/57888 [00:00<?, ? examples/s]

Dataset({
    features: ['input_ids', 'attention_mask', 'labels'],
    num_rows: 108743
})


We load a fresh copy of the base model and train it, this time on the non-toxic comments.

In [18]:
# Another fresh copy of the pretrained BART-large weights, to be fine-tuned on non-toxic comments (G+).
gplus_model = BartForConditionalGeneration.from_pretrained("facebook/bart-large", forced_bos_token_id=0)


In [19]:
# Fine-tuning configuration for the G+ model (output directory "gplus-bart-large");
# hyperparameters match the G- run. `eval_strategy` replaces the deprecated
# `evaluation_strategy` argument, which recent transformers releases reject.
nt_training_args = TrainingArguments(
    "gplus-bart-large",
    eval_strategy="epoch",
    learning_rate=2e-5,
    weight_decay=0.01,
)


In [20]:
# Trainer for the G+ model: train on the non-toxic "train" blocks, evaluate on the
# non-toxic "test" blocks. As for G-, no data_collator is passed because all
# blocks have the same length.
nt_trainer = Trainer(
    model=gplus_model,
    train_dataset=nontoxic_lm_datasets["train"],
    eval_dataset=nontoxic_lm_datasets["test"],
    args=nt_training_args,
)


In [None]:
# Fine-tune G+ on the non-toxic blocks.
nt_trainer.train()


Epoch,Training Loss,Validation Loss


In [None]:
# Perplexity of G+ on the held-out non-toxic blocks (uses `math` imported in the G- perplexity cell).
nt_eval_results = nt_trainer.evaluate()
print(f"Perplexity: {math.exp(nt_eval_results['eval_loss']):.2f}")


In [23]:
# Save the fine-tuned G- model into a local directory.
# NOTE(review): /content is the Colab working directory; adjust when running elsewhere.
# The tokenizer was not passed to the Trainer, so presumably it is not saved here — confirm.
gminus_path = '/content/gminus'
trainer.save_model(gminus_path)

In [None]:
# Save the fine-tuned G+ model into a local directory (Colab path, as for G-).
gplus_path = '/content/gplus'
nt_trainer.save_model(gplus_path)

In [24]:
import shutil

from google.colab import files

# save_model wrote a directory, but files.download sends a single file to the
# browser. Zip the directory first (creating /content/gminus.zip) so the whole
# model is downloaded.
gminus_archive = shutil.make_archive(gminus_path, "zip", gminus_path)
files.download(gminus_archive)

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

In [None]:
import shutil

# files.download expects a single file; zip the saved G+ model directory first.
gplus_archive = shutil.make_archive(gplus_path, "zip", gplus_path)
files.download(gplus_archive)