## Cleaning the data

### Fix the test data

In [7]:
import pandas as pd

test_data = pd.read_csv("../data/test.csv")
test_labels = pd.read_csv("../data/test_labels.csv")

In [11]:
test_data.head()

Unnamed: 0,id,comment_text
0,00001cee341fdb12,Yo bitch Ja Rule is more succesful then you'll...
1,0000247867823ef7,== From RfC == \n\n The title is fine as it is...
2,00013b17ad220c46,""" \n\n == Sources == \n\n * Zawe Ashton on Lap..."
3,00017563c3f7919a,":If you have a look back at the source, the in..."
4,00017695ad8997eb,I don't anonymously edit articles at all.


In [10]:
test_labels.head()

Unnamed: 0,id,toxic,severe_toxic,obscene,threat,insult,identity_hate
0,00001cee341fdb12,-1,-1,-1,-1,-1,-1
1,0000247867823ef7,-1,-1,-1,-1,-1,-1
2,00013b17ad220c46,-1,-1,-1,-1,-1,-1
3,00017563c3f7919a,-1,-1,-1,-1,-1,-1
4,00017695ad8997eb,-1,-1,-1,-1,-1,-1


In [28]:
# Merge test_data and test_labels on 'id'
merged_test_data = pd.merge(test_data, test_labels, on='id', how='inner')

# Remove rows where any label is -1
valid_test_data = merged_test_data[~(merged_test_data[['toxic', 'severe_toxic', 'obscene', 'threat', 'insult', 'identity_hate']] == -1).any(axis=1)]

print(valid_test_data.head())
# Reset index after filtering
valid_test_data.reset_index(drop=True, inplace=True)

                  id                                       comment_text  \
5   0001ea8717f6de06  Thank you for understanding. I think very high...   
7   000247e83dcc1211                   :Dear god this site is horrible.   
11  0002f87b16116a7f  "::: Somebody will invariably try to add Relig...   
13  0003e1cccfd5a40a  " \n\n It says it right there that it IS a typ...   
14  00059ace3e3e9a53  " \n\n == Before adding a new product to the l...   

    toxic  severe_toxic  obscene  threat  insult  identity_hate  
5       0             0        0       0       0              0  
7       0             0        0       0       0              0  
11      0             0        0       0       0              0  
13      0             0        0       0       0              0  
14      0             0        0       0       0              0  


In [32]:
valid_test_data.head()

Unnamed: 0,id,comment_text,toxic,severe_toxic,obscene,threat,insult,identity_hate
0,0001ea8717f6de06,Thank you for understanding. I think very high...,0,0,0,0,0,0
1,000247e83dcc1211,:Dear god this site is horrible.,0,0,0,0,0,0
2,0002f87b16116a7f,"""::: Somebody will invariably try to add Relig...",0,0,0,0,0,0
3,0003e1cccfd5a40a,""" \n\n It says it right there that it IS a typ...",0,0,0,0,0,0
4,00059ace3e3e9a53,""" \n\n == Before adding a new product to the l...",0,0,0,0,0,0


In [22]:
print((valid_test_data[['toxic', 'severe_toxic', 'obscene', 'threat', 'insult', 'identity_hate']] == -1).any(axis=1).sum() == 0)

True


In [35]:
#valid_test_data.to_csv('../data/train_data.csv', index=False)

### Fix the train data

In [2]:
import pandas as pd
train_data = pd.read_csv("../data/train.csv")

In [3]:
train_data.info()

<class 'pandas.core.frame.DataFrame'>
RangeIndex: 159571 entries, 0 to 159570
Data columns (total 8 columns):
 #   Column         Non-Null Count   Dtype 
---  ------         --------------   ----- 
 0   id             159571 non-null  object
 1   comment_text   159571 non-null  object
 2   toxic          159571 non-null  int64 
 3   severe_toxic   159571 non-null  int64 
 4   obscene        159571 non-null  int64 
 5   threat         159571 non-null  int64 
 6   insult         159571 non-null  int64 
 7   identity_hate  159571 non-null  int64 
dtypes: int64(6), object(2)
memory usage: 9.7+ MB


In [17]:
import pandas as pd
import numpy as np

# Read the data
df = pd.read_csv('../data/train.csv')

# List of toxic categories
toxic_categories = ['toxic', 'severe_toxic', 'obscene', 'threat', 'insult', 'identity_hate']

def analyze_toxic_comments(df):
    # 1. Exclusive labels analysis
    print("1. EXCLUSIVE LABELS ANALYSIS")
    print("-" * 50)
    
    exclusive_counts = {}
    for category in toxic_categories:
        # Get rows where only this category is 1 and others are 0
        mask = df[toxic_categories].sum(axis=1) == 1
        exclusive = df[mask & (df[category] == 1)]
        if len(exclusive) > 0:
            exclusive_counts[category] = len(exclusive)
    
    print("Comments with exactly one label:")
    for category, count in exclusive_counts.items():
        print(f"{category}: {count}")
    print(f"Total exclusive labels: {sum(exclusive_counts.values())}")
    
    print("\n2. COMBINATION LABELS ANALYSIS")
    print("-" * 50)
    
    # Get number of labels per comment
    label_counts = df[toxic_categories].sum(axis=1)
    
    # Count combinations
    for i in range(2, len(toxic_categories) + 1):
        count = len(df[label_counts == i])
        if count > 0:
            print(f"Comments with {i} labels: {count}")
    
    print("\n3. TOTAL COUNT PER LABEL")
    print("-" * 50)
    
    # Count total occurrences of each label
    for category in toxic_categories:
        print(f"{category}: {df[category].sum()}")
    
    print("\n4. TOTAL COMMENTS")
    print("-" * 50)
    print(f"Total number of comments: {len(df)}")

# Run the analysis
analyze_toxic_comments(df)

1. EXCLUSIVE LABELS ANALYSIS
--------------------------------------------------
Comments with exactly one label:
toxic: 5666
obscene: 317
threat: 22
insult: 301
identity_hate: 54
Total exclusive labels: 6360

2. COMBINATION LABELS ANALYSIS
--------------------------------------------------
Comments with 2 labels: 3480
Comments with 3 labels: 4209
Comments with 4 labels: 1760
Comments with 5 labels: 385
Comments with 6 labels: 31

3. TOTAL COUNT PER LABEL
--------------------------------------------------
toxic: 15294
severe_toxic: 1595
obscene: 8449
threat: 478
insult: 7877
identity_hate: 1405

4. TOTAL COMMENTS
--------------------------------------------------
Total number of comments: 159571


Exclusive Labels Analysis
- **Total comments with single labels:** 6,360
  - **toxic alone:** 5,666 (89% of exclusive labels)
  - Very few comments are exclusively labeled as other categories.
  - **severe_toxic:** never appears alone (0 exclusive counts).
  - **threat:** only 22 exclusive cases.
  - This suggests that **severe_toxic** always appears with other labels.

Label Combinations
- **Total comments with multiple labels:** 9,865 (3,480 + 4,209 + 1,760 + 385 + 31).
- **Interesting pattern:** More comments have **3 labels (4,209)** than **2 labels (3,480)**.
- Very few comments have all 6 labels (31).

Distribution of Combinations:
- **2 labels:** 35.3% of multi-label cases.
- **3 labels:** 42.7% of multi-label cases.
- **4 labels:** 17.8% of multi-label cases.
- **5 labels:** 3.9% of multi-label cases.
- **6 labels:** 0.3% of multi-label cases.

Total Label Distribution:
- **toxic** appears in 15,294 comments.
- Other labels in descending order:
  - **obscene:** 8,449 (55.2% of toxic comments).
  - **insult:** 7,877 (51.5% of toxic comments).
  - **severe_toxic:** 1,595 (10.4% of toxic comments).
  - **identity_hate:** 1,405 (9.2% of toxic comments).
  - **threat:** 478 (3.1% of toxic comments).

Key Findings:
1. The dataset has a **hierarchical nature**, where **toxic** acts as a parent category.
2. There's a **strong correlation between labels**, especially with **toxic**.
3. **Imbalances** in the dataset:
   - Between **single** vs **multiple labels**.
   - Between **different categories**.
   - In the **combination patterns** of labels.


In [10]:
import pandas as pd
import numpy as np
from sklearn.model_selection import train_test_split

def create_balanced_datasets(input_file, output_dir, toxic_cols):
    # Read original data
    df = pd.read_csv(input_file)
    
    # Identify toxic and non-toxic comments
    toxic_mask = df[toxic_cols].sum(axis=1) > 0
    toxic_samples = df[toxic_mask]
    non_toxic_samples = df[~toxic_mask]
    
    # Calculate sizes
    n_toxic = len(toxic_samples)
    
    # Create 1:1 ratio dataset
    n_non_toxic_1_1 = n_toxic
    non_toxic_1_1 = non_toxic_samples.sample(n=n_non_toxic_1_1, random_state=11)
    balanced_1_1 = pd.concat([toxic_samples, non_toxic_1_1]).sample(frac=1, random_state=11)
    
    # Create 2:1 ratio dataset
    n_non_toxic_2_1 = n_toxic * 2
    non_toxic_2_1 = non_toxic_samples.sample(n=n_non_toxic_2_1, random_state=11)
    balanced_2_1 = pd.concat([toxic_samples, non_toxic_2_1]).sample(frac=1, random_state=11)
    
    # Save datasets
    balanced_1_1.to_csv(f"{output_dir}/balanced_1_1_ratio.csv", index=False)
    balanced_2_1.to_csv(f"{output_dir}/balanced_2_1_ratio.csv", index=False)
    
    # Create distribution statistics
    def get_stats(df, name):
        total = len(df)
        toxic = df[toxic_mask].shape[0]
        non_toxic = df[~toxic_mask].shape[0]
        return {
            'dataset': name,
            'total_samples': total,
            'toxic_samples': toxic,
            'non_toxic_samples': non_toxic,
            'toxic_ratio': toxic/total,
            'non_toxic_ratio': non_toxic/total
        }
    
    stats = pd.DataFrame([
        get_stats(df, 'original'),
        get_stats(balanced_1_1, 'balanced_1_1'),
        get_stats(balanced_2_1, 'balanced_2_1')
    ])
    
    return stats

# Create the datasets and get statistics
toxic_categories = ['toxic', 'severe_toxic', 'obscene', 'threat', 'insult', 'identity_hate']
stats = create_balanced_datasets('../data/train.csv', '../data', toxic_categories)
print("\nDataset Statistics:")
print("-" * 80)
print(stats.to_string(index=False))


Dataset Statistics:
--------------------------------------------------------------------------------
     dataset  total_samples  toxic_samples  non_toxic_samples  toxic_ratio  non_toxic_ratio
    original         159571          16225             143346     0.101679         0.898321
balanced_1_1          32450          16225              16225     0.500000         0.500000
balanced_2_1          48675          16225              32450     0.333333         0.666667


  toxic = df[toxic_mask].shape[0]
  non_toxic = df[~toxic_mask].shape[0]
  toxic = df[toxic_mask].shape[0]
  non_toxic = df[~toxic_mask].shape[0]


In [11]:
bal11 = pd.read_csv('../data/balanced_1_1_ratio.csv')
bal21 = pd.read_csv('../data/balanced_2_1_ratio.csv')

In [12]:
bal11.head()

Unnamed: 0,id,comment_text,toxic,severe_toxic,obscene,threat,insult,identity_hate
0,8ce265d58ecae917,Or what your gonna ban my IP? What a joke... D...,1,0,1,0,1,0
1,222575836beb0c3d,"""\n\nA couple of points: First, relief maps ty...",0,0,0,0,0,0
2,459e25ec70a22c74,"""\nSimply stating it isn't good enough, please...",0,0,0,0,0,0
3,2d5f99355347c0cf,Arrgh \n\nOK I'll go back to wikiHow it's fun ...,1,0,1,0,0,0
4,ce40ff512091244d,"""\nYou obviously do need some kind of lecture,...",0,0,0,0,0,0


In [13]:
bal21.head()

Unnamed: 0,id,comment_text,toxic,severe_toxic,obscene,threat,insult,identity_hate
0,6063cac387260c02,"Kimchi,\nthats not advertising. Its the artist...",1,0,1,0,1,0
1,bad17dbfea78f651,I am a little surprised and shocked at the ton...,0,0,0,0,0,0
2,5eae46b6e87b85ac,"Barring a citation of course, no. A religion o...",0,0,0,0,0,0
3,8f079550cbe1980e,If someone does end up putting that sock stuff...,0,0,0,0,0,0
4,0332f7a0090832c1,"Hi, a pleasure, I'm fascinated by coal mines, ...",0,0,0,0,0,0


In [18]:
analyze_toxic_comments(bal11)

1. EXCLUSIVE LABELS ANALYSIS
--------------------------------------------------
Comments with exactly one label:
toxic: 5666
obscene: 317
threat: 22
insult: 301
identity_hate: 54
Total exclusive labels: 6360

2. COMBINATION LABELS ANALYSIS
--------------------------------------------------
Comments with 2 labels: 3480
Comments with 3 labels: 4209
Comments with 4 labels: 1760
Comments with 5 labels: 385
Comments with 6 labels: 31

3. TOTAL COUNT PER LABEL
--------------------------------------------------
toxic: 15294
severe_toxic: 1595
obscene: 8449
threat: 478
insult: 7877
identity_hate: 1405

4. TOTAL COMMENTS
--------------------------------------------------
Total number of comments: 32450


In [19]:
analyze_toxic_comments(bal21)

1. EXCLUSIVE LABELS ANALYSIS
--------------------------------------------------
Comments with exactly one label:
toxic: 5666
obscene: 317
threat: 22
insult: 301
identity_hate: 54
Total exclusive labels: 6360

2. COMBINATION LABELS ANALYSIS
--------------------------------------------------
Comments with 2 labels: 3480
Comments with 3 labels: 4209
Comments with 4 labels: 1760
Comments with 5 labels: 385
Comments with 6 labels: 31

3. TOTAL COUNT PER LABEL
--------------------------------------------------
toxic: 15294
severe_toxic: 1595
obscene: 8449
threat: 478
insult: 7877
identity_hate: 1405

4. TOTAL COMMENTS
--------------------------------------------------
Total number of comments: 48675


In [21]:
import pandas as pd
import numpy as np

def analyze_toxic_categories(df, toxic_cols):
    # Get only toxic comments
    toxic_mask = df[toxic_cols].sum(axis=1) > 0
    toxic_samples = df[toxic_mask]
    
    print("\nDistribution of Specific Toxic Categories:")
    print("-" * 50)
    for col in toxic_cols[1:]:  # Skip 'toxic' as it's the parent category
        count = toxic_samples[col].sum()
        percentage = (count / len(toxic_samples)) * 100
        print(f"{col}: {count} ({percentage:.1f}%)")
    
    print("\nLabel Combination Analysis:")
    print("-" * 50)
    label_counts = toxic_samples[toxic_cols[1:]].sum(axis=1)
    for i in range(1, len(toxic_cols)):
        count = (label_counts == i).sum()
        if count > 0:
            percentage = (count / len(toxic_samples)) * 100
            print(f"Comments with {i} toxic categories: {count} ({percentage:.1f}%)")

# You can run this for both datasets
toxic_cols = ['toxic', 'severe_toxic', 'obscene', 'threat', 'insult', 'identity_hate']
df_1_1 = pd.read_csv('../data/balanced_1_1_ratio.csv')
df_2_1 = pd.read_csv('../data/balanced_2_1_ratio.csv')

## Rest

In [1]:
data_url = "https://raw.githubusercontent.com/vinaysanga/toxic-comments-classifier/refs/heads/main/data/train.csv"

In [None]:
labels = ['toxic', 'severe_toxic', 'obscene', 'threat', 'insult', 'identity_hate']
# Count the number of rows with exclusive labels (only one label)
exclusive_labels_count = train_data[train_data[['toxic', 'severe_toxic', 'obscene', 'threat', 'insult', 'identity_hate']].sum(axis=1) == 1].shape[0]

# Filter the valid test data for each label
valid_labels = ['toxic', 'severe_toxic', 'obscene', 'threat', 'insult', 'identity_hate']
label_counts = train_data[valid_labels].apply(lambda x: x.sum(), axis=0)

# Show the label names and their counts
for label, count in label_counts.items():
    print(f"{label}: {count}")
label_combinations_count = {}
for i in range(2, len(labels) + 1):
    count = train_data[train_data[['toxic', 'severe_toxic', 'obscene', 'threat', 'insult', 'identity_hate']].sum(axis=1) == i].shape[0]
    label_combinations_count[f'{i} labels'] = count

print(f'Exclusive labels count: {exclusive_labels_count}')
print('Combination of labels count:', label_combinations_count)


toxic: 15294
severe_toxic: 1595
obscene: 8449
threat: 478
insult: 7877
identity_hate: 1405
Exclusive labels count: 6360
Combination of labels count: {'2 labels': 3480, '3 labels': 4209, '4 labels': 1760, '5 labels': 385, '6 labels': 31}


In [2]:
from datasets import load_dataset
data = load_dataset("csv", data_files=data_url)
data = data.shuffle(seed=11)

  from .autonotebook import tqdm as notebook_tqdm


In [3]:
train_test_dataset = data['train'].train_test_split(test_size=0.2, seed=11)
test_val_dataset = train_test_dataset['test'].train_test_split(test_size=0.5, seed=11)

In [60]:
from datasets import DatasetDict
train_test_val_dataset = DatasetDict({
    'train': train_test_dataset['train'],
    'test': test_val_dataset['test'],
    'validation': test_val_dataset['train']
})

In [None]:
train_test_val_dataset.csv

In [61]:
train_test_val_dataset['train'].column_names

['id',
 'comment_text',
 'toxic',
 'severe_toxic',
 'obscene',
 'threat',
 'insult',
 'identity_hate']

In [62]:
labels = [column for column in train_test_val_dataset['train'].column_names[2:]]

In [63]:
labels

['toxic', 'severe_toxic', 'obscene', 'threat', 'insult', 'identity_hate']

In [64]:
non_labels = [column for column in train_test_val_dataset['train'].column_names if column not in labels]

In [65]:
label2id = {label:id for id, label in enumerate(labels)}

In [66]:
id2label = {id:label for label, id in label2id.items()}

In [69]:
from transformers import AutoTokenizer, DataCollatorWithPadding

checkpoint = 'distilbert-base-uncased'
tokenizer = AutoTokenizer.from_pretrained(checkpoint, use_fast=False)

In [70]:
label2id, id2label

({'toxic': 0,
  'severe_toxic': 1,
  'obscene': 2,
  'threat': 3,
  'insult': 4,
  'identity_hate': 5},
 {0: 'toxic',
  1: 'severe_toxic',
  2: 'obscene',
  3: 'threat',
  4: 'insult',
  5: 'identity_hate'})

In [73]:
def tokenize_function(batch, labels, tokenizer):
    batch_labels = [
        [float(batch[label][i]) for label in labels]
        for i in range(len(batch["comment_text"]))
    ]

    tokenized_output = tokenizer(
        batch["comment_text"],
        truncation=True,
        max_length=512
    )

    tokenized_output["labels"] = batch_labels
    return tokenized_output


In [74]:
tokenized_datasets = train_test_val_dataset.map(
    tokenize_function,
    fn_kwargs={'labels': labels, 'tokenizer': tokenizer},
    remove_columns=non_labels + labels,
    batched=True,
)


Map: 100%|██████████| 127656/127656 [00:39<00:00, 3233.71 examples/s]
Map: 100%|██████████| 15958/15958 [00:05<00:00, 3023.30 examples/s]
Map: 100%|██████████| 15957/15957 [00:05<00:00, 3001.99 examples/s]


In [75]:
tokenized_datasets.save_to_disk('datasets-distilbert')

Saving the dataset (1/1 shards): 100%|██████████| 127656/127656 [00:00<00:00, 745671.70 examples/s]
Saving the dataset (1/1 shards): 100%|██████████| 15958/15958 [00:00<00:00, 1027552.32 examples/s]
Saving the dataset (1/1 shards): 100%|██████████| 15957/15957 [00:00<00:00, 3284512.39 examples/s]


In [14]:
tokenized_datasets['train'][1]

{'input_ids': [1,
  307,
  93910,
  274,
  270,
  262,
  1321,
  264,
  272,
  1030,
  260,
  367,
  348,
  1938,
  4293,
  361,
  5744,
  274,
  281,
  261,
  401,
  272,
  1030,
  261,
  423,
  288,
  262,
  1547,
  652,
  291,
  294,
  307,
  309,
  31229,
  510,
  760,
  307,
  309,
  8705,
  280,
  297,
  282,
  266,
  40584,
  309,
  309,
  269,
  491,
  265,
  266,
  40584,
  271,
  29664,
  267,
  1161,
  309,
  309,
  7169,
  7169,
  274,
  281,
  266,
  5744,
  40584,
  300,
  1263,
  307,
  2],
 'token_type_ids': [0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0],
 'attention_mask': [1,
  1,
  1,
  1,
  1,
  1,
  1,
  1,
  1,
  1,
  1,
  1,
  1,
  1,
  1,
  1,
  1,
  1,
  1,
  1,
  1,
  1,
  

In [15]:
example_sentence_decoded = tokenizer.decode(tokenized_datasets['train'][1]['input_ids'], skip_special_tokens=True)
print(example_sentence_decoded)

" Thak you for the link to that article. You just showed everybody how stupid you are, because that article, right at the beginning says this: ""Telling someone ""Don't be a dick"" is something of a dick-move in itself"" ha ha you are a stupid dick! ) "


In [16]:
example_labels = [id2label[idx] for idx, val in enumerate(tokenized_datasets['train'][1]['labels']) if val]
print(example_labels)

['toxic', 'obscene', 'insult']


In [17]:
tokenized_datasets.save_to_disk('tokenized_datasets')

Saving the dataset (1/1 shards): 100%|██████████| 127656/127656 [00:00<00:00, 661216.59 examples/s]
Saving the dataset (1/1 shards): 100%|██████████| 15958/15958 [00:00<00:00, 556821.29 examples/s]
Saving the dataset (1/1 shards): 100%|██████████| 15957/15957 [00:00<00:00, 1893310.01 examples/s]


In [18]:
import evaluate, torch
import numpy as np

clf_metrics = evaluate.combine(["accuracy", "f1", "precision", "recall"])


def compute_metrics(eval_pred):
    logits, labels = eval_pred
    # Apply sigmoid to logits to get probabilities
    sigmoid = torch.nn.Sigmoid()
    predictions = sigmoid(logits)

    # Convert probabilities to binary predictions using a threshold of 0.5
    predictions = (predictions > 0.5).astype(int).reshape(-1)

    # Ensure labels are integers
    labels = labels.astype(int).reshape(-1)

    # Compute metrics
    return clf_metrics.compute(predictions=predictions, references=labels)


In [19]:
from transformers import AutoModelForSequenceClassification, TrainingArguments, Trainer

model = AutoModelForSequenceClassification.from_pretrained(
   checkpoint, num_labels=len(labels),
           id2label=id2label, label2id=label2id,
                       problem_type = "multi_label_classification"
)


Some weights of DebertaV2ForSequenceClassification were not initialized from the model checkpoint at microsoft/deberta-v3-small and are newly initialized: ['classifier.bias', 'classifier.weight', 'pooler.dense.bias', 'pooler.dense.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


In [20]:
training_args = TrainingArguments("test-trainer", report_to="none", eval_strategy="epoch")

In [21]:
trainer = Trainer(
    model,
    training_args,
    train_dataset=tokenized_datasets["train"],
    eval_dataset=tokenized_datasets["validation"],
    data_collator=data_collator,
    processing_class=tokenizer,
    compute_metrics=compute_metrics,
)

In [None]:
data_collator = DataCollatorWithPadding(tokenizer=tokenizer)

In [None]:
trainer.train()

### Load dataset and model and train

In [36]:
from datasets import load_from_disk
import evaluate, torch
import numpy as np
from transformers import AutoModelForSequenceClassification, TrainingArguments, Trainer

tokenized_datasets = load_from_disk("tokenized_datasets")   
tokenized_datasets

model = AutoModelForSequenceClassification.from_pretrained(
   checkpoint, num_labels=len(labels),
           id2label=id2label, label2id=label2id,
                       problem_type = "multi_label_classification"
)
data_collator = DataCollatorWithPadding(tokenizer=tokenizer)


clf_metrics = evaluate.combine(["accuracy", "f1", "precision", "recall"])


def compute_metrics(eval_pred):
    logits, labels = eval_pred
    # Apply sigmoid to logits to get probabilities
    sigmoid = torch.nn.Sigmoid()
    predictions = sigmoid(logits)

    # Convert probabilities to binary predictions using a threshold of 0.5
    predictions = (predictions > 0.5).astype(int).reshape(-1)

    # Ensure labels are integers
    labels = labels.astype(int).reshape(-1)

    # Compute metrics
    return clf_metrics.compute(predictions=predictions, references=labels)

training_args = TrainingArguments("toxic-trainer", report_to="none", eval_strategy="epoch")
trainer = Trainer(
    model,
    training_args,
    train_dataset=tokenized_datasets["train"],
    eval_dataset=tokenized_datasets["validation"],
    data_collator=data_collator,
    processing_class=tokenizer,
    compute_metrics=compute_metrics,
)
trainer.train()