In [1]:
import pandas as pd
import torch
import random
import numpy as np

from torch.utils.data import DataLoader
from torch.optim import Adam

from transformers import BartTokenizerFast, DataCollatorWithPadding
from transformers import BartModel, BartForConditionalGeneration, Trainer, TrainingArguments, EvalPrediction
from datasets import Dataset, DatasetDict

# from transformers import AdamW
from transformers import get_scheduler

from tqdm.auto import tqdm

import evaluate

In [2]:
CHECKPOINT = 'facebook/bart-base'

## Dataset

In [3]:
data = pd.read_csv("../data/intermit/merged_dataset.tsv", sep='\t')
data = data.iloc[:20000]

  data = pd.read_csv("../data/intermit/merged_dataset.tsv", sep='\t')


In [4]:
tokenizer = BartTokenizerFast.from_pretrained(CHECKPOINT)


In [5]:
X_tokenized = tokenizer(list(data["toxic"].values), padding=True, truncation=True, return_tensors="pt")
# type(data["toxic"].values)

In [6]:
X_tokenized['input_ids']

tensor([[    0,   700,    56,  ...,     1,     1,     1],
        [    0,   417,  6343,  ...,     1,     1,     1],
        [    0,   757,    45,  ...,     1,     1,     1],
        ...,
        [    0, 40992,   162,  ...,     1,     1,     1],
        [    0, 17762,    64,  ...,     1,     1,     1],
        [    0,   100,   818,  ...,     1,     1,     1]])

In [7]:
y_tokenized = tokenizer(list(data["neutral1"].values), padding=True, truncation=True, return_tensors="pt")

In [8]:
y_tokenized.data

{'input_ids': tensor([[   0,  700,   21,  ...,    1,    1,    1],
         [   0,  243,   74,  ...,    1,    1,    1],
         [   0,  100,  437,  ...,    1,    1,    1],
         ...,
         [   0, 8267,  127,  ...,    1,    1,    1],
         [   0,  627,  604,  ...,    1,    1,    1],
         [   0,  100,  818,  ...,    1,    1,    1]]),
 'attention_mask': tensor([[1, 1, 1,  ..., 0, 0, 0],
         [1, 1, 1,  ..., 0, 0, 0],
         [1, 1, 1,  ..., 0, 0, 0],
         ...,
         [1, 1, 1,  ..., 0, 0, 0],
         [1, 1, 1,  ..., 0, 0, 0],
         [1, 1, 1,  ..., 0, 0, 0]])}

In [9]:
dataset = Dataset.from_dict({"input_ids": X_tokenized['input_ids'], "attention_mask": X_tokenized['attention_mask'], "label_ids": y_tokenized['input_ids']})

In [10]:
dataset

Dataset({
    features: ['input_ids', 'attention_mask', 'label_ids'],
    num_rows: 20000
})

In [11]:
type(dataset)

datasets.arrow_dataset.Dataset

In [12]:
from sklearn.model_selection import train_test_split

# Split your dataset into training, validation, and test sets
train_data, test_data = train_test_split(dataset, test_size=0.2, random_state=42)

In [13]:
train_data = Dataset.from_dict(train_data)

In [14]:
train_data, validation_data = train_test_split(train_data, test_size=0.1, random_state=42)

In [15]:
train_data = Dataset.from_dict(train_data)
test_data = Dataset.from_dict(test_data)
validation_data = Dataset.from_dict(validation_data)

In [16]:
datadict = DatasetDict({"train":train_data, "test":test_data, "validation":validation_data})

In [17]:
datadict

DatasetDict({
    train: Dataset({
        features: ['input_ids', 'attention_mask', 'label_ids'],
        num_rows: 14400
    })
    test: Dataset({
        features: ['input_ids', 'attention_mask', 'label_ids'],
        num_rows: 4000
    })
    validation: Dataset({
        features: ['input_ids', 'attention_mask', 'label_ids'],
        num_rows: 1600
    })
})

In [36]:
datadict.save_to_disk(dataset_dict_path="../data/intermit/dataset-dict")

Saving the dataset (0/1 shards):   0%|          | 0/14400 [00:00<?, ? examples/s]

Saving the dataset (0/1 shards):   0%|          | 0/4000 [00:00<?, ? examples/s]

Saving the dataset (0/1 shards):   0%|          | 0/1600 [00:00<?, ? examples/s]

Torch data loader

In [18]:
data_collator = DataCollatorWithPadding(tokenizer=tokenizer)

In [19]:
train_dataloader = DataLoader(
    datadict["train"], shuffle=True, batch_size=16, collate_fn=data_collator
)
eval_dataloader = DataLoader(
    datadict["validation"], batch_size=16, collate_fn=data_collator
)

In [20]:
for batch in train_dataloader:
    break
{k: v.shape for k, v in batch.items()}

You're using a BartTokenizerFast tokenizer. Please note that with a fast tokenizer, using the `__call__` method is faster than using a method to encode the text followed by a call to the `pad` method to get a padded encoding.


{'input_ids': torch.Size([16, 170]),
 'attention_mask': torch.Size([16, 170]),
 'labels': torch.Size([16, 155])}

## Model

In [21]:
tokenizer = BartTokenizerFast.from_pretrained(CHECKPOINT)
model = BartForConditionalGeneration.from_pretrained(CHECKPOINT)

# inputs = tokenizer("Hello, my dog is cute", return_tensors="pt")
# outputs = model(**inputs)

# last_hidden_states = outputs.last_hidden_state

In [22]:
optimizer = Adam(model.parameters(), lr=3e-5)

In [23]:
num_epochs = 20
num_training_steps = num_epochs * len(train_dataloader)
lr_scheduler = get_scheduler(
    "linear",
    optimizer=optimizer,
    num_warmup_steps=0,
    num_training_steps=num_training_steps,
)
print(num_training_steps)

18000


In [24]:
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
model.to(device)
device

device(type='cuda')

In [25]:
progress_bar = tqdm(range(num_training_steps))

model.train()
for epoch in range(num_epochs):
    for batch in train_dataloader:
        batch = {k: v.to(device) for k, v in batch.items()}
        outputs = model(**batch)
        loss = outputs.loss
        loss.backward()

        optimizer.step()
        lr_scheduler.step()
        optimizer.zero_grad()
        progress_bar.update(1)

  0%|          | 0/18000 [00:00<?, ?it/s]

### Evaluate

In [26]:
metric = evaluate.load("glue", "stsb")
model.eval()

predictions_list = []
references_list = []

for batch in eval_dataloader:
    batch = {k: v.to(device) for k, v in batch.items()}
    with torch.no_grad():
        outputs = model(**batch)

    logits = outputs.logits
    predictions = torch.argmax(logits, dim=-1)
    
    # Flatten the predictions and references tensors
    flat_predictions = predictions.view(-1)
    flat_references = batch["labels"].view(-1)

    predictions_list.append(flat_predictions)
    references_list.append(flat_references)


# Concatenate the lists of predictions and references into tensors
predictions = torch.cat(predictions_list, dim=0)
references = torch.cat(references_list, dim=0)

# Ensure that predictions and references are on the same device
predictions = predictions.to(device)
references = references.to(device)

# Compute the metric
metric.add_batch(predictions=predictions, references=references)
result = metric.compute()

In [27]:
result

{'pearson': 0.7458301157089653, 'spearmanr': 0.999611734921778}

### Generate a batch

In [28]:
qq = eval_dataloader.__iter__()
print(qq.__next__())
batch = qq.__next__()
batch = {k: v.to(device) for k, v in batch.items()}

{'input_ids': tensor([[    0, 12375,    64,  ...,     1,     1,     1],
        [    0,  1185,   214,  ...,     1,     1,     1],
        [    0,  6460,    10,  ...,     1,     1,     1],
        ...,
        [    0,   506, 24029,  ...,     1,     1,     1],
        [    0,  2847,    52,  ...,     1,     1,     1],
        [    0, 27037, 38538,  ...,     1,     1,     1]]), 'attention_mask': tensor([[1, 1, 1,  ..., 0, 0, 0],
        [1, 1, 1,  ..., 0, 0, 0],
        [1, 1, 1,  ..., 0, 0, 0],
        ...,
        [1, 1, 1,  ..., 0, 0, 0],
        [1, 1, 1,  ..., 0, 0, 0],
        [1, 1, 1,  ..., 0, 0, 0]]), 'labels': tensor([[    0,  8155,    64,  ...,     1,     1,     1],
        [    0, 13724,     6,  ...,     1,     1,     1],
        [    0,  6968,   197,  ...,     1,     1,     1],
        ...,
        [    0,   100,   269,  ...,     1,     1,     1],
        [    0,  2527,    52,  ...,     1,     1,     1],
        [    0, 27037, 38538,  ...,     1,     1,     1]])}


In [29]:
outputs = model.generate(**batch)



In [30]:
outputs

tensor([[    2,     0,   118,    21,    95, 27537,    19, 23644,   479,     2,
             1,     1,     1,     1,     1,     1,     1,     1,     1,     1],
        [    2,     0, 16625,  1623,     6,    84,  8453,     6,   362,   130,
          6317,     9,    84,   275,    13,    10,  1402,   744,     4,     2],
        [    2,     0,  2527,    38,   236,    47,     7,   356,    23,   127,
           652,    13,    10,    94,    86,     4,     2,     1,     1,     1],
        [    2,     0,   102,   221,     4,   100,     4, 34850,  1295,   272,
         13093,   102,  1516,   848,   734,     8,   172,   272, 13093,     2],
        [    2,     0,  1185,   357,  1407,   162,   124,     2,     1,     1,
             1,     1,     1,     1,     1,     1,     1,     1,     1,     1],
        [    2,     0,  6968,   236,     7,   283,     7,   127,  6085,   116,
             2,     1,     1,     1,     1,     1,     1,     1,     1,     1],
        [    2,     0,   100,  2638,   167, 30

In [31]:
batch.keys()

dict_keys(['input_ids', 'attention_mask', 'labels'])

In [32]:
# logits = outputs.logits[0].tolist()
# # Convert logits to text
# generated_text = tokenizer.decode(logits, skip_special_tokens=True)

# # Print the generated text
# print(generated_text)
inputs = batch['input_ids']
for i in range(len(inputs)):
    print(tokenizer.decode(inputs[i], skip_special_tokens=True, clean_up_tokenization_spaces=False))
    print(tokenizer.decode(outputs[i], skip_special_tokens=True, clean_up_tokenization_spaces=False))
# print([tokenizer.decode(g, skip_special_tokens=True, clean_up_tokenization_spaces=False) for g in inputs])
# print([tokenizer.decode(g, skip_special_tokens=True, clean_up_tokenization_spaces=False) for g in outputs])

i was just fuckin wit cha .
i was just kidding with cha .
Your husband, our king, has taken 300 of our finest to slaughter.
your husband, our king, took three hundred of our best for a certain death.
So I want you to look at my face one last fucking time.
so I want you to look at my face for a last time.
A P.I. Investigating gedda gets killed... then gedda gets killed... maybe this dirty copdid them both.
a P.I. Investigating Gedda gets killed... and then Gedd
u better follow me back for i swag ya ass and not in a good way
You better follow me back
Do you want to come in my mouth?
you want to come to my mouth?
fucking loved those grenade water balloons .
I loved those grenade water balloons.
aite hold on , downloadin shit .
aite hold on , downloadin it .
they sure as hell fucking do .
they sure do
all the gunmen are dead now , so i hope that was their entire fucking life 's goal .
all the gunmen are dead now , so i hope that was their entire life 's
You're the beneficiary of the cruele

### Save model

In [35]:
torch.save(model.state_dict(), "../models/bart-paradetox")