In [9]:
import torch
from torch.utils.data import Dataset, DataLoader
import pandas as pd
import numpy as np

from transformers import BertTokenizer, BertForSequenceClassification, Trainer, TrainingArguments

from sklearn.metrics import accuracy_score, precision_recall_fscore_support

category_columns = [
    "Unlawful detention",
    "Human trafficking",
    "Enslavement",
    "Willful killing of civilians",
    "Mass execution",
    "Kidnapping",
    "Extrajudicial killing",
    "Forced disappearance",
    "Damage or destruction of civilian critical infrastructure",
    "Damage or destruction, looting, or theft of cultural heritage",
    "Military operations (battle, shelling)",
    "Gender-based or other conflict-related sexual violence",
    "Violent crackdowns on protesters/opponents/civil rights abuse",
    "Indiscriminate use of weapons",
    "Torture or indications of torture",
    "Persecution based on political, racial, ethnic, gender, or sexual orientation",
    "Movement of military, paramilitary, or other troops and equipment"
]

# 2) Custom Dataset class for articles
class ArticleDataset(Dataset):
    def __init__(self, encodings, labels):
        self.encodings = encodings
        self.labels = labels

    def __getitem__(self, idx):
        item = {key: torch.tensor(val[idx]) for key, val in self.encodings.items()}
        item['labels'] = torch.tensor(self.labels[idx], dtype=torch.float)
        return item

    def __len__(self):
        return len(self.labels)


  from .autonotebook import tqdm as notebook_tqdm


In [10]:
train_df = pd.read_csv("train.csv")  
val_df   = pd.read_csv("val.csv")    
test_df  = pd.read_csv("test.csv")  

# Load the tokenizer
tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")

# Tokenize the text columns
train_encodings = tokenizer(
    list(train_df["Incident Narrative"].values),
    truncation=True,
    padding=True
)
val_encodings   = tokenizer(
    list(val_df["Incident Narrative"].values),
    truncation=True,
    padding=True
)
test_encodings  = tokenizer(
    list(test_df["Incident Narrative"].values),
    truncation=True,
    padding=True
)

# Extract labels (multi-label targets in your category columns)
train_labels = train_df[category_columns].values
val_labels   = val_df[category_columns].values
test_labels  = test_df[category_columns].values

# Create Dataset objects
train_dataset = ArticleDataset(train_encodings, train_labels)
val_dataset   = ArticleDataset(val_encodings, val_labels)
test_dataset  = ArticleDataset(test_encodings, test_labels)


In [11]:
# Note: num_labels = number of category columns for multi-label classification
model = BertForSequenceClassification.from_pretrained(
    "bert-base-uncased", 
    num_labels=len(category_columns)
)

# Define compute_metrics for multi-label classification
def compute_metrics(p):
    # p.predictions are logits; p.label_ids are the ground truth
    preds = torch.sigmoid(torch.tensor(p.predictions))  # Convert logits to probabilities
    preds = (preds > 0.5).int().cpu().numpy() 
    labels = torch.tensor(p.label_ids).cpu().numpy()
    
    precision, recall, f1, _ = precision_recall_fscore_support(
        labels,
        preds,
        average='weighted'
    )
    acc = accuracy_score(labels, preds)
    return {
        'accuracy': acc,
        'f1': f1,
        'precision': precision,
        'recall': recall
    }

# Define training arguments
training_args = TrainingArguments(
    output_dir='./results',          # Output directory
    eval_strategy="epoch",           # Evaluate at the end of each epoch
    save_strategy="epoch",           # Save model at the end of each epoch
    learning_rate=2e-5,              # Learning rate
    per_device_train_batch_size=16,
    per_device_eval_batch_size=16,
    num_train_epochs=30,
    weight_decay=0.01,
    load_best_model_at_end=True,
    metric_for_best_model="f1",      # Use F1 score for best model
    logging_dir='./logs'
)

# Initialize Trainer with training and validation sets
trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=train_dataset,
    eval_dataset=val_dataset,
    compute_metrics=compute_metrics
)

# Train the model
trainer.train()

# Save the model
trainer.save_model("./bert-multiclass-model")


Some weights of BertForSequenceClassification were not initialized from the model checkpoint at bert-base-uncased and are newly initialized: ['classifier.bias', 'classifier.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
                                                 
  2%|▏         | 21/1260 [00:14<10:04,  2.05it/s]

{'eval_loss': 0.44863569736480713, 'eval_accuracy': 0.05263157894736842, 'eval_f1': 0.2951841065113377, 'eval_precision': 0.2799389778794813, 'eval_recall': 0.3565217391304348, 'eval_runtime': 0.7557, 'eval_samples_per_second': 75.429, 'eval_steps_per_second': 5.293, 'epoch': 1.0}


  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
                                                 
  3%|▎         | 42/1260 [00:31<10:07,  2.00it/s]

{'eval_loss': 0.34153708815574646, 'eval_accuracy': 0.017543859649122806, 'eval_f1': 0.23878215930794291, 'eval_precision': 0.3019367588932806, 'eval_recall': 0.2956521739130435, 'eval_runtime': 0.7158, 'eval_samples_per_second': 79.626, 'eval_steps_per_second': 5.588, 'epoch': 2.0}


  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
                                                 
  5%|▌         | 63/1260 [00:49<09:47,  2.04it/s]

{'eval_loss': 0.2994852662086487, 'eval_accuracy': 0.08771929824561403, 'eval_f1': 0.3151808711029101, 'eval_precision': 0.2992236024844721, 'eval_recall': 0.3739130434782609, 'eval_runtime': 0.7258, 'eval_samples_per_second': 78.532, 'eval_steps_per_second': 5.511, 'epoch': 3.0}


  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
                                                 
  7%|▋         | 84/1260 [01:07<09:36,  2.04it/s]

{'eval_loss': 0.27817434072494507, 'eval_accuracy': 0.08771929824561403, 'eval_f1': 0.3688308226838961, 'eval_precision': 0.5605072463768116, 'eval_recall': 0.3130434782608696, 'eval_runtime': 0.7124, 'eval_samples_per_second': 80.011, 'eval_steps_per_second': 5.615, 'epoch': 4.0}


  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
                                                  
  8%|▊         | 105/1260 [01:25<09:24,  2.04it/s]

{'eval_loss': 0.26534804701805115, 'eval_accuracy': 0.15789473684210525, 'eval_f1': 0.45729042039057605, 'eval_precision': 0.5534161490683229, 'eval_recall': 0.45217391304347826, 'eval_runtime': 0.7165, 'eval_samples_per_second': 79.553, 'eval_steps_per_second': 5.583, 'epoch': 5.0}


  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
                                                  
 10%|█         | 126/1260 [01:43<09:17,  2.04it/s]

{'eval_loss': 0.25428733229637146, 'eval_accuracy': 0.21052631578947367, 'eval_f1': 0.4723389501421539, 'eval_precision': 0.5125465838509317, 'eval_recall': 0.45217391304347826, 'eval_runtime': 0.7407, 'eval_samples_per_second': 76.952, 'eval_steps_per_second': 5.4, 'epoch': 6.0}


  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
                                                  
 12%|█▏        | 147/1260 [02:01<09:05,  2.04it/s]

{'eval_loss': 0.24289239943027496, 'eval_accuracy': 0.24561403508771928, 'eval_f1': 0.5120114950443688, 'eval_precision': 0.5228245411877125, 'eval_recall': 0.5130434782608696, 'eval_runtime': 0.7135, 'eval_samples_per_second': 79.887, 'eval_steps_per_second': 5.606, 'epoch': 7.0}


  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
                                                  
 13%|█▎        | 168/1260 [02:19<08:53,  2.05it/s]

{'eval_loss': 0.23462191224098206, 'eval_accuracy': 0.3157894736842105, 'eval_f1': 0.5351395140697182, 'eval_precision': 0.5120760672295431, 'eval_recall': 0.5739130434782609, 'eval_runtime': 0.7137, 'eval_samples_per_second': 79.865, 'eval_steps_per_second': 5.605, 'epoch': 8.0}


  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
                                                  
 15%|█▌        | 189/1260 [02:37<08:42,  2.05it/s]

{'eval_loss': 0.22411537170410156, 'eval_accuracy': 0.2631578947368421, 'eval_f1': 0.5207155163676903, 'eval_precision': 0.5473291925465837, 'eval_recall': 0.5130434782608696, 'eval_runtime': 0.7149, 'eval_samples_per_second': 79.73, 'eval_steps_per_second': 5.595, 'epoch': 9.0}


  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
                                                  
 17%|█▋        | 210/1260 [02:55<08:48,  1.99it/s]

{'eval_loss': 0.21607302129268646, 'eval_accuracy': 0.24561403508771928, 'eval_f1': 0.528358695652174, 'eval_precision': 0.5033501120457643, 'eval_recall': 0.5565217391304348, 'eval_runtime': 0.757, 'eval_samples_per_second': 75.296, 'eval_steps_per_second': 5.284, 'epoch': 10.0}


  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
                                                  
 18%|█▊        | 231/1260 [03:14<08:29,  2.02it/s]

{'eval_loss': 0.20978175103664398, 'eval_accuracy': 0.24561403508771928, 'eval_f1': 0.5392258993769679, 'eval_precision': 0.6136521739130434, 'eval_recall': 0.5304347826086957, 'eval_runtime': 0.73, 'eval_samples_per_second': 78.079, 'eval_steps_per_second': 5.479, 'epoch': 11.0}


  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
                                                  
 20%|██        | 252/1260 [03:32<08:23,  2.00it/s]

{'eval_loss': 0.2058655321598053, 'eval_accuracy': 0.2631578947368421, 'eval_f1': 0.5725047387863076, 'eval_precision': 0.6790460956812563, 'eval_recall': 0.5478260869565217, 'eval_runtime': 0.7219, 'eval_samples_per_second': 78.954, 'eval_steps_per_second': 5.541, 'epoch': 12.0}


  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
                                                  
 22%|██▏       | 273/1260 [03:50<08:10,  2.01it/s]

{'eval_loss': 0.20015521347522736, 'eval_accuracy': 0.24561403508771928, 'eval_f1': 0.5650793650793651, 'eval_precision': 0.6439481605351169, 'eval_recall': 0.5652173913043478, 'eval_runtime': 0.7256, 'eval_samples_per_second': 78.556, 'eval_steps_per_second': 5.513, 'epoch': 13.0}


  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
                                                  
 23%|██▎       | 294/1260 [04:09<08:38,  1.86it/s]

{'eval_loss': 0.20107874274253845, 'eval_accuracy': 0.24561403508771928, 'eval_f1': 0.5798498396013925, 'eval_precision': 0.6600853419444124, 'eval_recall': 0.5652173913043478, 'eval_runtime': 0.7826, 'eval_samples_per_second': 72.838, 'eval_steps_per_second': 5.111, 'epoch': 14.0}


  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
                                                  
 25%|██▌       | 315/1260 [04:28<08:33,  1.84it/s]

{'eval_loss': 0.19547012448310852, 'eval_accuracy': 0.2631578947368421, 'eval_f1': 0.5934763076436069, 'eval_precision': 0.6746727878296876, 'eval_recall': 0.5826086956521739, 'eval_runtime': 0.7763, 'eval_samples_per_second': 73.427, 'eval_steps_per_second': 5.153, 'epoch': 15.0}


  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
                                                  
 27%|██▋       | 336/1260 [04:47<08:42,  1.77it/s]

{'eval_loss': 0.1910184621810913, 'eval_accuracy': 0.2807017543859649, 'eval_f1': 0.6151786019392516, 'eval_precision': 0.6891885458237066, 'eval_recall': 0.6, 'eval_runtime': 0.7819, 'eval_samples_per_second': 72.9, 'eval_steps_per_second': 5.116, 'epoch': 16.0}


  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
                                                  
 28%|██▊       | 357/1260 [05:08<08:33,  1.76it/s]

{'eval_loss': 0.18763896822929382, 'eval_accuracy': 0.2631578947368421, 'eval_f1': 0.6224953379439966, 'eval_precision': 0.6541036546943918, 'eval_recall': 0.6086956521739131, 'eval_runtime': 0.818, 'eval_samples_per_second': 69.685, 'eval_steps_per_second': 4.89, 'epoch': 17.0}


  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
                                                  
 30%|███       | 378/1260 [05:26<07:18,  2.01it/s]

{'eval_loss': 0.1872773915529251, 'eval_accuracy': 0.2631578947368421, 'eval_f1': 0.6255215977037877, 'eval_precision': 0.6945109786244001, 'eval_recall': 0.591304347826087, 'eval_runtime': 0.7245, 'eval_samples_per_second': 78.68, 'eval_steps_per_second': 5.521, 'epoch': 18.0}


  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
                                                  
 32%|███▏      | 399/1260 [05:45<08:15,  1.74it/s]

{'eval_loss': 0.18535009026527405, 'eval_accuracy': 0.2982456140350877, 'eval_f1': 0.6278152086184511, 'eval_precision': 0.6999130434782608, 'eval_recall': 0.6, 'eval_runtime': 0.8318, 'eval_samples_per_second': 68.528, 'eval_steps_per_second': 4.809, 'epoch': 19.0}


  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
                                                  
 33%|███▎      | 420/1260 [06:05<08:11,  1.71it/s]

{'eval_loss': 0.18538178503513336, 'eval_accuracy': 0.2631578947368421, 'eval_f1': 0.6117276215470948, 'eval_precision': 0.6833684400036006, 'eval_recall': 0.6, 'eval_runtime': 0.779, 'eval_samples_per_second': 73.175, 'eval_steps_per_second': 5.135, 'epoch': 20.0}


  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
                                                  
 35%|███▌      | 441/1260 [06:24<07:09,  1.91it/s]

{'eval_loss': 0.18566496670246124, 'eval_accuracy': 0.2807017543859649, 'eval_f1': 0.6088636149542718, 'eval_precision': 0.6960051375278545, 'eval_recall': 0.591304347826087, 'eval_runtime': 0.7739, 'eval_samples_per_second': 73.652, 'eval_steps_per_second': 5.169, 'epoch': 21.0}


  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
                                                  
 37%|███▋      | 462/1260 [06:41<06:25,  2.07it/s]

{'eval_loss': 0.17975933849811554, 'eval_accuracy': 0.2807017543859649, 'eval_f1': 0.617155657701427, 'eval_precision': 0.6519420289855072, 'eval_recall': 0.6, 'eval_runtime': 0.7199, 'eval_samples_per_second': 79.175, 'eval_steps_per_second': 5.556, 'epoch': 22.0}


  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
                                                  
 38%|███▊      | 483/1260 [07:01<06:54,  1.88it/s]

{'eval_loss': 0.18113550543785095, 'eval_accuracy': 0.2807017543859649, 'eval_f1': 0.6266203131288153, 'eval_precision': 0.6886510590858417, 'eval_recall': 0.6086956521739131, 'eval_runtime': 0.773, 'eval_samples_per_second': 73.736, 'eval_steps_per_second': 5.174, 'epoch': 23.0}


 40%|███▉      | 500/1260 [07:14<08:38,  1.47it/s]

{'loss': 0.1993, 'grad_norm': 0.3011831045150757, 'learning_rate': 1.2063492063492064e-05, 'epoch': 23.81}


  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
                                                  
 40%|████      | 504/1260 [07:17<06:41,  1.88it/s]

{'eval_loss': 0.1799945831298828, 'eval_accuracy': 0.2982456140350877, 'eval_f1': 0.64028326236655, 'eval_precision': 0.7163456656764785, 'eval_recall': 0.6086956521739131, 'eval_runtime': 0.7685, 'eval_samples_per_second': 74.173, 'eval_steps_per_second': 5.205, 'epoch': 24.0}


  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
                                                  
 42%|████▏     | 525/1260 [07:51<06:26,  1.90it/s]

{'eval_loss': 0.179915189743042, 'eval_accuracy': 0.2807017543859649, 'eval_f1': 0.6365417657679106, 'eval_precision': 0.7095764349072479, 'eval_recall': 0.6086956521739131, 'eval_runtime': 0.7723, 'eval_samples_per_second': 73.807, 'eval_steps_per_second': 5.179, 'epoch': 25.0}


  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
                                                  
 43%|████▎     | 546/1260 [08:07<06:13,  1.91it/s]

{'eval_loss': 0.1775854080915451, 'eval_accuracy': 0.2982456140350877, 'eval_f1': 0.6408537218176754, 'eval_precision': 0.7168737060041408, 'eval_recall': 0.6086956521739131, 'eval_runtime': 0.7588, 'eval_samples_per_second': 75.114, 'eval_steps_per_second': 5.271, 'epoch': 26.0}


  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
                                                  
 45%|████▌     | 567/1260 [08:23<06:03,  1.91it/s]

{'eval_loss': 0.18057061731815338, 'eval_accuracy': 0.2982456140350877, 'eval_f1': 0.6238121250538503, 'eval_precision': 0.6936812994468949, 'eval_recall': 0.6, 'eval_runtime': 0.7528, 'eval_samples_per_second': 75.72, 'eval_steps_per_second': 5.314, 'epoch': 27.0}


  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
                                                  
 47%|████▋     | 588/1260 [08:39<05:52,  1.91it/s]

{'eval_loss': 0.1787862926721573, 'eval_accuracy': 0.2807017543859649, 'eval_f1': 0.6002717696411151, 'eval_precision': 0.6370146327802283, 'eval_recall': 0.591304347826087, 'eval_runtime': 0.7555, 'eval_samples_per_second': 75.447, 'eval_steps_per_second': 5.295, 'epoch': 28.0}


  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
                                                  
 48%|████▊     | 609/1260 [08:55<05:41,  1.91it/s]

{'eval_loss': 0.17692343890666962, 'eval_accuracy': 0.2982456140350877, 'eval_f1': 0.64028326236655, 'eval_precision': 0.7163456656764785, 'eval_recall': 0.6086956521739131, 'eval_runtime': 0.7815, 'eval_samples_per_second': 72.94, 'eval_steps_per_second': 5.119, 'epoch': 29.0}


  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
                                                  
 50%|█████     | 630/1260 [09:29<05:31,  1.90it/s]

{'eval_loss': 0.17550718784332275, 'eval_accuracy': 0.2807017543859649, 'eval_f1': 0.6390574619013678, 'eval_precision': 0.7041159420289854, 'eval_recall': 0.6173913043478261, 'eval_runtime': 0.7511, 'eval_samples_per_second': 75.888, 'eval_steps_per_second': 5.325, 'epoch': 30.0}


  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
                                                  
 52%|█████▏    | 651/1260 [09:45<05:19,  1.91it/s]

{'eval_loss': 0.1765577346086502, 'eval_accuracy': 0.3333333333333333, 'eval_f1': 0.6592772762554893, 'eval_precision': 0.7241901561296646, 'eval_recall': 0.6260869565217392, 'eval_runtime': 0.778, 'eval_samples_per_second': 73.264, 'eval_steps_per_second': 5.141, 'epoch': 31.0}


  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
                                                  
 53%|█████▎    | 672/1260 [10:01<05:07,  1.91it/s]

{'eval_loss': 0.1772458553314209, 'eval_accuracy': 0.2982456140350877, 'eval_f1': 0.64028326236655, 'eval_precision': 0.7163456656764785, 'eval_recall': 0.6086956521739131, 'eval_runtime': 0.769, 'eval_samples_per_second': 74.123, 'eval_steps_per_second': 5.202, 'epoch': 32.0}


  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
                                                  
 55%|█████▌    | 693/1260 [10:23<05:04,  1.86it/s]

{'eval_loss': 0.17465616762638092, 'eval_accuracy': 0.2982456140350877, 'eval_f1': 0.64028326236655, 'eval_precision': 0.7163456656764785, 'eval_recall': 0.6086956521739131, 'eval_runtime': 0.7988, 'eval_samples_per_second': 71.354, 'eval_steps_per_second': 5.007, 'epoch': 33.0}


  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
                                                  
 57%|█████▋    | 714/1260 [10:54<05:06,  1.78it/s]

{'eval_loss': 0.17605800926685333, 'eval_accuracy': 0.2982456140350877, 'eval_f1': 0.6344539736210922, 'eval_precision': 0.714206259189246, 'eval_recall': 0.6, 'eval_runtime': 0.8211, 'eval_samples_per_second': 69.416, 'eval_steps_per_second': 4.871, 'epoch': 34.0}


  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
                                                  
 58%|█████▊    | 735/1260 [11:22<04:37,  1.89it/s]

{'eval_loss': 0.1750870794057846, 'eval_accuracy': 0.2982456140350877, 'eval_f1': 0.643261599883149, 'eval_precision': 0.7048788716009888, 'eval_recall': 0.6086956521739131, 'eval_runtime': 0.7632, 'eval_samples_per_second': 74.69, 'eval_steps_per_second': 5.241, 'epoch': 35.0}


  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
                                                  
 60%|██████    | 756/1260 [11:38<04:25,  1.90it/s]

{'eval_loss': 0.1772109717130661, 'eval_accuracy': 0.3157894736842105, 'eval_f1': 0.6461601506077868, 'eval_precision': 0.7222701759488149, 'eval_recall': 0.6086956521739131, 'eval_runtime': 0.7656, 'eval_samples_per_second': 74.454, 'eval_steps_per_second': 5.225, 'epoch': 36.0}


  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
                                                  
 62%|██████▏   | 777/1260 [11:54<03:58,  2.03it/s]

{'eval_loss': 0.1762123703956604, 'eval_accuracy': 0.2982456140350877, 'eval_f1': 0.637878577108822, 'eval_precision': 0.7222701759488149, 'eval_recall': 0.6, 'eval_runtime': 0.7095, 'eval_samples_per_second': 80.343, 'eval_steps_per_second': 5.638, 'epoch': 37.0}


  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
                                                  
 63%|██████▎   | 798/1260 [12:09<04:01,  1.91it/s]

{'eval_loss': 0.17762839794158936, 'eval_accuracy': 0.2982456140350877, 'eval_f1': 0.637878577108822, 'eval_precision': 0.7222701759488149, 'eval_recall': 0.6, 'eval_runtime': 0.7276, 'eval_samples_per_second': 78.341, 'eval_steps_per_second': 5.498, 'epoch': 38.0}


  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
                                                  
 65%|██████▌   | 819/1260 [12:29<03:39,  2.01it/s]

{'eval_loss': 0.1763109713792801, 'eval_accuracy': 0.2982456140350877, 'eval_f1': 0.637878577108822, 'eval_precision': 0.7222701759488149, 'eval_recall': 0.6, 'eval_runtime': 0.717, 'eval_samples_per_second': 79.493, 'eval_steps_per_second': 5.578, 'epoch': 39.0}


  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
                                                  
 67%|██████▋   | 840/1260 [13:00<03:35,  1.95it/s]

{'eval_loss': 0.17506076395511627, 'eval_accuracy': 0.2982456140350877, 'eval_f1': 0.6437493259449303, 'eval_precision': 0.7241901561296646, 'eval_recall': 0.6086956521739131, 'eval_runtime': 0.8378, 'eval_samples_per_second': 68.038, 'eval_steps_per_second': 4.775, 'epoch': 40.0}


  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
                                                  
 68%|██████▊   | 861/1260 [13:18<03:47,  1.75it/s]

{'eval_loss': 0.17628300189971924, 'eval_accuracy': 0.2982456140350877, 'eval_f1': 0.6437493259449303, 'eval_precision': 0.7241901561296646, 'eval_recall': 0.6086956521739131, 'eval_runtime': 0.836, 'eval_samples_per_second': 68.181, 'eval_steps_per_second': 4.785, 'epoch': 41.0}


  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
                                                  
 70%|███████   | 882/1260 [13:36<03:35,  1.75it/s]

{'eval_loss': 0.17723269760608673, 'eval_accuracy': 0.2982456140350877, 'eval_f1': 0.6437493259449303, 'eval_precision': 0.7241901561296646, 'eval_recall': 0.6086956521739131, 'eval_runtime': 0.8548, 'eval_samples_per_second': 66.681, 'eval_steps_per_second': 4.679, 'epoch': 42.0}


  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
                                                  
 72%|███████▏  | 903/1260 [14:13<03:16,  1.81it/s]

{'eval_loss': 0.17512091994285583, 'eval_accuracy': 0.2982456140350877, 'eval_f1': 0.6437493259449303, 'eval_precision': 0.7241901561296646, 'eval_recall': 0.6086956521739131, 'eval_runtime': 0.7795, 'eval_samples_per_second': 73.124, 'eval_steps_per_second': 5.132, 'epoch': 43.0}


  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
                                                  
 73%|███████▎  | 924/1260 [14:28<02:57,  1.90it/s]

{'eval_loss': 0.17502492666244507, 'eval_accuracy': 0.2982456140350877, 'eval_f1': 0.6437493259449303, 'eval_precision': 0.7241901561296646, 'eval_recall': 0.6086956521739131, 'eval_runtime': 0.7704, 'eval_samples_per_second': 73.989, 'eval_steps_per_second': 5.192, 'epoch': 44.0}


  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
                                                  
 75%|███████▌  | 945/1260 [14:44<02:46,  1.90it/s]

{'eval_loss': 0.17551009356975555, 'eval_accuracy': 0.2982456140350877, 'eval_f1': 0.6437493259449303, 'eval_precision': 0.7241901561296646, 'eval_recall': 0.6086956521739131, 'eval_runtime': 0.7639, 'eval_samples_per_second': 74.617, 'eval_steps_per_second': 5.236, 'epoch': 45.0}


  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
                                                  
 77%|███████▋  | 966/1260 [14:59<02:34,  1.90it/s]

{'eval_loss': 0.1758652925491333, 'eval_accuracy': 0.2982456140350877, 'eval_f1': 0.637878577108822, 'eval_precision': 0.7222701759488149, 'eval_recall': 0.6, 'eval_runtime': 0.7647, 'eval_samples_per_second': 74.538, 'eval_steps_per_second': 5.231, 'epoch': 46.0}


  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
                                                  
 78%|███████▊  | 987/1260 [15:33<02:23,  1.90it/s]

{'eval_loss': 0.1764882504940033, 'eval_accuracy': 0.2982456140350877, 'eval_f1': 0.637878577108822, 'eval_precision': 0.7222701759488149, 'eval_recall': 0.6, 'eval_runtime': 0.7634, 'eval_samples_per_second': 74.667, 'eval_steps_per_second': 5.24, 'epoch': 47.0}


 79%|███████▉  | 1000/1260 [15:43<02:56,  1.47it/s]

{'loss': 0.0745, 'grad_norm': 0.21877256035804749, 'learning_rate': 4.126984126984127e-06, 'epoch': 47.62}


  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
                                                   
 80%|████████  | 1008/1260 [15:48<02:11,  1.92it/s]

{'eval_loss': 0.17622601985931396, 'eval_accuracy': 0.2982456140350877, 'eval_f1': 0.637878577108822, 'eval_precision': 0.7222701759488149, 'eval_recall': 0.6, 'eval_runtime': 0.7537, 'eval_samples_per_second': 75.626, 'eval_steps_per_second': 5.307, 'epoch': 48.0}


  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
                                                   
 82%|████████▏ | 1029/1260 [16:04<02:01,  1.91it/s]

{'eval_loss': 0.17526519298553467, 'eval_accuracy': 0.2982456140350877, 'eval_f1': 0.637878577108822, 'eval_precision': 0.7222701759488149, 'eval_recall': 0.6, 'eval_runtime': 0.7677, 'eval_samples_per_second': 74.245, 'eval_steps_per_second': 5.21, 'epoch': 49.0}


  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
                                                   
 83%|████████▎ | 1050/1260 [16:36<01:51,  1.89it/s]

{'eval_loss': 0.17657418549060822, 'eval_accuracy': 0.2982456140350877, 'eval_f1': 0.637878577108822, 'eval_precision': 0.7222701759488149, 'eval_recall': 0.6, 'eval_runtime': 0.7549, 'eval_samples_per_second': 75.502, 'eval_steps_per_second': 5.298, 'epoch': 50.0}


  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
                                                   
 85%|████████▌ | 1071/1260 [17:05<01:40,  1.89it/s]

{'eval_loss': 0.17651699483394623, 'eval_accuracy': 0.2982456140350877, 'eval_f1': 0.637878577108822, 'eval_precision': 0.7222701759488149, 'eval_recall': 0.6, 'eval_runtime': 0.7676, 'eval_samples_per_second': 74.26, 'eval_steps_per_second': 5.211, 'epoch': 51.0}


  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
                                                   
 87%|████████▋ | 1092/1260 [17:26<01:28,  1.91it/s]

{'eval_loss': 0.17665186524391174, 'eval_accuracy': 0.2982456140350877, 'eval_f1': 0.6266984528852195, 'eval_precision': 0.7222701759488149, 'eval_recall': 0.591304347826087, 'eval_runtime': 0.7522, 'eval_samples_per_second': 75.775, 'eval_steps_per_second': 5.318, 'epoch': 52.0}


  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
                                                   
 88%|████████▊ | 1113/1260 [17:41<01:17,  1.90it/s]

{'eval_loss': 0.17677801847457886, 'eval_accuracy': 0.2982456140350877, 'eval_f1': 0.637878577108822, 'eval_precision': 0.7222701759488149, 'eval_recall': 0.6, 'eval_runtime': 0.7582, 'eval_samples_per_second': 75.179, 'eval_steps_per_second': 5.276, 'epoch': 53.0}


  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
                                                   
 90%|█████████ | 1134/1260 [17:56<01:01,  2.04it/s]

{'eval_loss': 0.17554762959480286, 'eval_accuracy': 0.2982456140350877, 'eval_f1': 0.637878577108822, 'eval_precision': 0.7222701759488149, 'eval_recall': 0.6, 'eval_runtime': 0.7089, 'eval_samples_per_second': 80.401, 'eval_steps_per_second': 5.642, 'epoch': 54.0}


  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
                                                   
 92%|█████████▏| 1155/1260 [18:11<00:56,  1.85it/s]

{'eval_loss': 0.17584124207496643, 'eval_accuracy': 0.2982456140350877, 'eval_f1': 0.637878577108822, 'eval_precision': 0.7222701759488149, 'eval_recall': 0.6, 'eval_runtime': 0.7772, 'eval_samples_per_second': 73.339, 'eval_steps_per_second': 5.147, 'epoch': 55.0}


  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
                                                   
 93%|█████████▎| 1176/1260 [18:43<00:47,  1.76it/s]

{'eval_loss': 0.17556345462799072, 'eval_accuracy': 0.2982456140350877, 'eval_f1': 0.6437493259449303, 'eval_precision': 0.7241901561296646, 'eval_recall': 0.6086956521739131, 'eval_runtime': 0.8136, 'eval_samples_per_second': 70.06, 'eval_steps_per_second': 4.916, 'epoch': 56.0}


  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
                                                   
 95%|█████████▌| 1197/1260 [19:00<00:35,  1.77it/s]

{'eval_loss': 0.17673492431640625, 'eval_accuracy': 0.2982456140350877, 'eval_f1': 0.637878577108822, 'eval_precision': 0.7222701759488149, 'eval_recall': 0.6, 'eval_runtime': 0.8119, 'eval_samples_per_second': 70.208, 'eval_steps_per_second': 4.927, 'epoch': 57.0}


  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
                                                   
 97%|█████████▋| 1218/1260 [19:16<00:22,  1.90it/s]

{'eval_loss': 0.1757735162973404, 'eval_accuracy': 0.2982456140350877, 'eval_f1': 0.6437493259449303, 'eval_precision': 0.7241901561296646, 'eval_recall': 0.6086956521739131, 'eval_runtime': 0.763, 'eval_samples_per_second': 74.703, 'eval_steps_per_second': 5.242, 'epoch': 58.0}


  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
                                                   
 98%|█████████▊| 1239/1260 [19:49<00:11,  1.90it/s]

{'eval_loss': 0.1764277070760727, 'eval_accuracy': 0.2982456140350877, 'eval_f1': 0.637878577108822, 'eval_precision': 0.7222701759488149, 'eval_recall': 0.6, 'eval_runtime': 0.775, 'eval_samples_per_second': 73.545, 'eval_steps_per_second': 5.161, 'epoch': 59.0}


  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
                                                   
100%|██████████| 1260/1260 [20:24<00:00,  1.90it/s]

{'eval_loss': 0.17636726796627045, 'eval_accuracy': 0.2982456140350877, 'eval_f1': 0.637878577108822, 'eval_precision': 0.7222701759488149, 'eval_recall': 0.6, 'eval_runtime': 0.9695, 'eval_samples_per_second': 58.792, 'eval_steps_per_second': 4.126, 'epoch': 60.0}


100%|██████████| 1260/1260 [20:29<00:00,  1.02it/s]


{'train_runtime': 1229.3785, 'train_samples_per_second': 15.764, 'train_steps_per_second': 1.025, 'train_loss': 0.1206161135718936, 'epoch': 60.0}


In [12]:
def evaluate_model(dataset, threshold=0.5):
    loader = DataLoader(dataset, batch_size=16)
    model.eval()  # Set model to eval mode
    
    preds_list = []
    labels_list = []
    
    with torch.no_grad():
        for batch in loader:
            inputs = {
                key: val.to(model.device) for key, val in batch.items() if key != 'labels'
            }
            outputs = model(**inputs)
            logits = outputs.logits
            # Convert logits to probabilities
            probs = torch.sigmoid(logits).cpu().numpy()
            # Apply threshold
            preds = (probs > threshold).astype(int)
            
            preds_list.extend(preds)
            labels_list.extend(batch['labels'].cpu().numpy())
    
    preds_array = np.array(preds_list)
    labels_array = np.array(labels_list)
    
    precision, recall, f1, _ = precision_recall_fscore_support(labels_array, preds_array, average='weighted')
    acc = accuracy_score(labels_array, preds_array)
    
    return {
        'accuracy': acc,
        'f1': f1,
        'precision': precision,
        'recall': recall
    }

# Run evaluation on the test dataset
test_results = evaluate_model(test_dataset)
print("Final Test Set Evaluation Results:", test_results)


Final Test Set Evaluation Results: {'accuracy': 0.3023255813953488, 'f1': 0.6284655723680114, 'precision': 0.7117797017797018, 'recall': 0.5777777777777777}


  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
