In [1]:
!pip install transformers[torch]

Collecting transformers[torch]
  Downloading transformers-4.35.0-py3-none-any.whl (7.9 MB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m7.9/7.9 MB[0m [31m48.2 MB/s[0m eta [36m0:00:00[0m
Collecting huggingface-hub<1.0,>=0.16.4 (from transformers[torch])
  Downloading huggingface_hub-0.18.0-py3-none-any.whl (301 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m302.0/302.0 kB[0m [31m29.1 MB/s[0m eta [36m0:00:00[0m
Collecting tokenizers<0.15,>=0.14 (from transformers[torch])
  Downloading tokenizers-0.14.1-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (3.8 MB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m3.8/3.8 MB[0m [31m70.6 MB/s[0m eta [36m0:00:00[0m
[?25hCollecting safetensors>=0.3.1 (from transformers[torch])
  Downloading safetensors-0.4.0-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (1.3 MB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m1.3/1.3 MB[0m [31m67.6 MB/s

In [2]:
import torch
from torch.utils.data import Dataset, DataLoader
from transformers import BertTokenizer, BertForSequenceClassification
from transformers import AdamW
import pandas as pd
from tqdm import tqdm


In [3]:
# Load the dataset
df = pd.read_csv('most_toxic_data.csv').head(1000)

# Tokenization
tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
tokens = tokenizer(df['reference'].tolist(), max_length=128, padding=True, truncation=True, return_tensors="pt")

In [4]:
# Dataset class
class DetoxificationDataset(Dataset):
    def __init__(self, encodings, labels):
        self.encodings = encodings
        self.labels = labels

    def __getitem__(self, idx):
        item = {key: torch.tensor(val[idx]) for key, val in self.encodings.items()}
        item['labels'] = torch.tensor(int(self.labels[idx]))
        return item

    def __len__(self):
        return len(self.labels)

# Prepare dataset and dataloader
dataset = DetoxificationDataset(tokens, df['trn_tox'])
dataloader = DataLoader(dataset, batch_size=16, shuffle=True)

In [5]:
# Load pre-trained BERT model
model = BertForSequenceClassification.from_pretrained('bert-base-uncased')

# Define optimizer
optimizer = AdamW(model.parameters(), lr=1e-5)
epochs = 5


Some weights of BertForSequenceClassification were not initialized from the model checkpoint at bert-base-uncased and are newly initialized: ['classifier.bias', 'classifier.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


In [6]:
# Assuming 'epochs' is defined and you have a 'dataloader' for your data
loss_values = []  # List to store loss values

for epoch in tqdm(range(epochs), desc="Epoch"):
    total_loss = 0  # Variable to store the epoch loss

    for batch in tqdm(dataloader, desc="Iteration"):
        optimizer.zero_grad()

        # Unpack the batch contents and feed them into the model
        input_ids = batch['input_ids']
        attention_mask = batch['attention_mask']
        labels = batch['labels']

        # Forward pass and compute loss
        outputs = model(input_ids, attention_mask=attention_mask, labels=labels)
        loss = outputs.loss

        # Backward pass
        loss.backward()
        optimizer.step()

        # Update total loss
        total_loss += loss.item()

    # Calculate the average loss over the entire epoch
    avg_epoch_loss = total_loss / len(dataloader)
    print(f"Average Loss at Epoch {epoch}: {avg_epoch_loss}")

    # Store the average loss value for plotting or analysis later
    loss_values.append(avg_epoch_loss)

# After training loop, you might want to save the model


Epoch:   0%|          | 0/5 [00:00<?, ?it/s]
  item = {key: torch.tensor(val[idx]) for key, val in self.encodings.items()}

Iteration:   2%|▏         | 1/63 [00:17<17:42, 17.14s/it][A
Iteration:   3%|▎         | 2/63 [00:26<12:48, 12.59s/it][A
Iteration:   5%|▍         | 3/63 [00:37<11:36, 11.62s/it][A
Iteration:   6%|▋         | 4/63 [00:46<10:28, 10.66s/it][A
Iteration:   8%|▊         | 5/63 [00:55<09:52, 10.22s/it][A
Iteration:  10%|▉         | 6/63 [01:06<09:45, 10.28s/it][A
Iteration:  11%|█         | 7/63 [01:14<09:06,  9.76s/it][A
Iteration:  13%|█▎        | 8/63 [01:24<08:58,  9.79s/it][A
Iteration:  14%|█▍        | 9/63 [01:35<09:02, 10.05s/it][A
Iteration:  16%|█▌        | 10/63 [01:44<08:38,  9.78s/it][A
Iteration:  17%|█▋        | 11/63 [01:54<08:30,  9.82s/it][A
Iteration:  19%|█▉        | 12/63 [02:04<08:29, 10.00s/it][A
Iteration:  21%|██        | 13/63 [02:13<07:57,  9.54s/it][A
Iteration:  22%|██▏       | 14/63 [02:23<08:06,  9.92s/it][A
Iteration:  24%|█

Average Loss at Epoch 0: 0.07926585171963015



Iteration:   0%|          | 0/63 [00:00<?, ?it/s][A
Iteration:   2%|▏         | 1/63 [00:10<10:42, 10.37s/it][A
Iteration:   3%|▎         | 2/63 [00:18<09:24,  9.25s/it][A
Iteration:   5%|▍         | 3/63 [00:29<09:47,  9.78s/it][A
Iteration:   6%|▋         | 4/63 [00:39<09:40,  9.83s/it][A
Iteration:   8%|▊         | 5/63 [00:48<09:10,  9.49s/it][A
Iteration:  10%|▉         | 6/63 [01:00<09:52, 10.39s/it][A
Iteration:  11%|█         | 7/63 [01:10<09:43, 10.42s/it][A
Iteration:  13%|█▎        | 8/63 [01:19<08:58,  9.80s/it][A
Iteration:  14%|█▍        | 9/63 [01:29<08:59,  9.99s/it][A
Iteration:  16%|█▌        | 10/63 [01:39<08:52, 10.05s/it][A
Iteration:  17%|█▋        | 11/63 [01:48<08:19,  9.61s/it][A
Iteration:  19%|█▉        | 12/63 [01:58<08:23,  9.87s/it][A
Iteration:  21%|██        | 13/63 [02:08<08:12,  9.85s/it][A
Iteration:  22%|██▏       | 14/63 [02:17<07:48,  9.57s/it][A
Iteration:  24%|██▍       | 15/63 [02:27<07:51,  9.82s/it][A
Iteration:  25%|██▌      

Average Loss at Epoch 1: 0.0042768856732263456



Iteration:   0%|          | 0/63 [00:00<?, ?it/s][A
Iteration:   2%|▏         | 1/63 [00:09<09:23,  9.09s/it][A
Iteration:   3%|▎         | 2/63 [00:18<09:27,  9.31s/it][A
Iteration:   5%|▍         | 3/63 [00:28<09:48,  9.80s/it][A
Iteration:   6%|▋         | 4/63 [00:37<09:11,  9.34s/it][A
Iteration:   8%|▊         | 5/63 [00:47<09:16,  9.59s/it][A
Iteration:  10%|▉         | 6/63 [00:57<09:20,  9.84s/it][A
Iteration:  11%|█         | 7/63 [01:06<08:45,  9.38s/it][A
Iteration:  13%|█▎        | 8/63 [01:16<08:54,  9.72s/it][A
Iteration:  14%|█▍        | 9/63 [01:27<08:55,  9.92s/it][A
Iteration:  16%|█▌        | 10/63 [01:35<08:21,  9.46s/it][A
Iteration:  17%|█▋        | 11/63 [01:48<09:00, 10.39s/it][A
Iteration:  19%|█▉        | 12/63 [01:57<08:35, 10.11s/it][A
Iteration:  21%|██        | 13/63 [02:06<08:09,  9.80s/it][A
Iteration:  22%|██▏       | 14/63 [02:17<08:08,  9.97s/it][A
Iteration:  24%|██▍       | 15/63 [02:25<07:42,  9.64s/it][A
Iteration:  25%|██▌      

Average Loss at Epoch 2: 0.0018442948441213323



Iteration:   0%|          | 0/63 [00:00<?, ?it/s][A
Iteration:   2%|▏         | 1/63 [00:10<10:37, 10.28s/it][A
Iteration:   3%|▎         | 2/63 [00:19<10:06,  9.94s/it][A
Iteration:   5%|▍         | 3/63 [00:28<09:28,  9.48s/it][A
Iteration:   6%|▋         | 4/63 [00:39<09:39,  9.83s/it][A
Iteration:   8%|▊         | 5/63 [00:48<09:15,  9.57s/it][A
Iteration:  10%|▉         | 6/63 [00:57<09:02,  9.51s/it][A
Iteration:  11%|█         | 7/63 [01:08<09:08,  9.79s/it][A
Iteration:  13%|█▎        | 8/63 [01:16<08:38,  9.42s/it][A
Iteration:  14%|█▍        | 9/63 [01:26<08:38,  9.61s/it][A
Iteration:  16%|█▌        | 10/63 [01:37<08:41,  9.84s/it][A
Iteration:  17%|█▋        | 11/63 [01:45<08:09,  9.41s/it][A
Iteration:  19%|█▉        | 12/63 [01:55<08:13,  9.69s/it][A
Iteration:  21%|██        | 13/63 [02:06<08:14,  9.88s/it][A
Iteration:  22%|██▏       | 14/63 [02:14<07:42,  9.44s/it][A
Iteration:  24%|██▍       | 15/63 [02:26<08:14, 10.30s/it][A
Iteration:  25%|██▌      

Average Loss at Epoch 3: 0.0010203922782758518



Iteration:   0%|          | 0/63 [00:00<?, ?it/s][A
Iteration:   2%|▏         | 1/63 [00:08<08:45,  8.47s/it][A
Iteration:   3%|▎         | 2/63 [00:18<09:44,  9.58s/it][A
Iteration:   5%|▍         | 3/63 [00:29<09:57,  9.96s/it][A
Iteration:   6%|▋         | 4/63 [00:37<09:13,  9.38s/it][A
Iteration:   8%|▊         | 5/63 [00:48<09:25,  9.76s/it][A
Iteration:  10%|▉         | 6/63 [00:58<09:27,  9.96s/it][A
Iteration:  11%|█         | 7/63 [01:07<08:51,  9.49s/it][A
Iteration:  13%|█▎        | 8/63 [01:17<08:58,  9.79s/it][A
Iteration:  14%|█▍        | 9/63 [01:27<08:54,  9.89s/it][A
Iteration:  16%|█▌        | 10/63 [01:36<08:24,  9.52s/it][A
Iteration:  17%|█▋        | 11/63 [01:46<08:29,  9.79s/it][A
Iteration:  19%|█▉        | 12/63 [01:56<08:17,  9.75s/it][A
Iteration:  21%|██        | 13/63 [02:05<07:58,  9.57s/it][A
Iteration:  22%|██▏       | 14/63 [02:15<08:00,  9.81s/it][A
Iteration:  24%|██▍       | 15/63 [02:24<07:40,  9.60s/it][A
Iteration:  25%|██▌      

Average Loss at Epoch 4: 0.0006582681856502498





In [7]:
# Save the model's state_dict
model.save_pretrained('bert_model')
tokenizer.save_pretrained('bert_model')

('bert_model/tokenizer_config.json',
 'bert_model/special_tokens_map.json',
 'bert_model/vocab.txt',
 'bert_model/added_tokens.json')

In [8]:
!zip -r bert_model.zip bert_model

  adding: bert_model/ (stored 0%)
  adding: bert_model/vocab.txt (deflated 53%)
  adding: bert_model/tokenizer_config.json (deflated 75%)
  adding: bert_model/special_tokens_map.json (deflated 42%)
  adding: bert_model/config.json (deflated 49%)
  adding: bert_model/model.safetensors (deflated 7%)


In [13]:
def predict(toxicity_classifier, tokenizer, text):
    # Prepare the input text
    inputs = tokenizer(text, return_tensors="pt", padding=True, truncation=True, max_length=128)
    input_ids = inputs['input_ids']
    attention_mask = inputs['attention_mask']

    # Get model predictions
    with torch.no_grad():
        outputs = toxicity_classifier(input_ids, attention_mask=attention_mask)

    logits = outputs.logits.squeeze()  # Squeeze the logits to remove any extra dimensions

    # Apply softmax if there are multiple classes (logits would be a 1D Tensor with more than one element)
    if logits.numel() > 1:  # More than one element means multi-class classification
        probabilities = torch.softmax(logits, dim=0)
        predicted_class = torch.argmax(probabilities).item()
        return probabilities, predicted_class
    # Apply sigmoid if it's binary classification (logits would be a single-element tensor)
    else:  # Single element means binary classification
        probability = torch.sigmoid(logits).item()
        return probability

In [17]:
example_text = "I like that shit."
prediction = predict(model, tokenizer, example_text)

# Process the prediction
if isinstance(prediction, tuple):  # Multi-class classification
    probabilities, predicted_class = prediction
    print(f"Probabilities: {probabilities[0].item()}")
    print(f"Predicted toxic: {'Yes' if predicted_class == 0 else 'No'}")


Probabilities: 0.9992675185203552
Predicted toxic: Yes
