In [9]:
import torch
from torch.utils.data import Dataset, DataLoader
import pandas as pd
import numpy as np

from transformers import BertTokenizer, BertForSequenceClassification, Trainer, TrainingArguments
from sklearn.metrics import accuracy_score, precision_recall_fscore_support

########################
# 1) Device Setup      #
########################
# Run on the default CUDA device when PyTorch can see one, otherwise fall back to CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")
print("Using device:", device)

########################
# 2) Data + Categories #
########################

# Label columns expected in every split. For each article, the values in these
# columns form its multi-hot target vector, and the list order fixes which
# classifier output corresponds to which category. Must stay a list: pandas
# treats a list key as a multi-column selection.
category_columns = [
    "Unlawful detention",
    "Human trafficking",
    "Enslavement",
    "Willful killing of civilians",
    "Mass execution",
    "Kidnapping",
    "Extrajudicial killing",
    "Forced disappearance",
    "Damage or destruction of civilian critical infrastructure",
    "Damage or destruction, looting, or theft of cultural heritage",
    "Military operations (battle, shelling)",
    "Gender-based or other conflict-related sexual violence",
    "Violent crackdowns on protesters/opponents/civil rights abuse",
    "Indiscriminate use of weapons",
    "Torture or indications of torture",
    "Persecution based on political, racial, ethnic, gender, or sexual orientation",
    "Movement of military, paramilitary, or other troops and equipment"
]

# Load the train / validation / test splits, in that order
# (adjust the paths below to your actual file locations).
train_df, val_df, test_df = (
    pd.read_csv(csv_path) for csv_path in ("train.csv", "val.csv", "test.csv")
)

##################################
# 3) Custom Dataset for Articles #
##################################

class ArticleDataset(Dataset):
    """Map-style dataset pairing tokenizer output with per-article label vectors.

    Args:
        encodings: mapping of field name to a per-example sequence (in this file,
            the output of the tokenizer). Item ``idx`` of every field is turned
            into a tensor.
        labels: indexable of per-example label rows; row ``idx`` is returned as a
            ``torch.float`` tensor under the ``'labels'`` key.
    """

    def __init__(self, encodings, labels):
        self.encodings = encodings
        self.labels = labels

    def __len__(self):
        # The dataset size is defined by the number of label rows.
        return len(self.labels)

    def __getitem__(self, idx):
        sample = {}
        for field_name, per_example in self.encodings.items():
            sample[field_name] = torch.tensor(per_example[idx])
        # Float targets: multi-label training compares logits against 0/1 floats.
        sample['labels'] = torch.tensor(self.labels[idx], dtype=torch.float)
        return sample

###############################
# 4) Tokenization + Datasets  #
###############################

# Load the tokenizer
tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")


def _tokenize_narratives(frame):
    """Tokenize one split's "Incident Narrative" column.

    ``truncation=True`` caps long narratives at the tokenizer's maximum length;
    ``padding=True`` pads to the longest sequence within this split only.
    """
    narratives = list(frame["Incident Narrative"].values)
    return tokenizer(narratives, truncation=True, padding=True)


# Tokenize each split, in train / val / test order
train_encodings, val_encodings, test_encodings = (
    _tokenize_narratives(split_df) for split_df in (train_df, val_df, test_df)
)

# Extract labels: one multi-hot row per article, columns ordered as category_columns
train_labels, val_labels, test_labels = (
    split_df[category_columns].values for split_df in (train_df, val_df, test_df)
)

# Wrap encodings + labels of each split into Dataset objects
train_dataset, val_dataset, test_dataset = (
    ArticleDataset(split_encodings, split_labels)
    for split_encodings, split_labels in zip(
        (train_encodings, val_encodings, test_encodings),
        (train_labels, val_labels, test_labels),
    )
)

########################
# 5) Model + Training  #
########################

# One output logit per category column.
# - problem_type is set explicitly so the model always uses the multi-label
#   (sigmoid/BCE-with-logits) loss, instead of inferring the loss from the
#   label dtype on the first forward pass.
# - id2label / label2id store the category names in the saved config, so the
#   exported model maps each logit back to its category.
model = BertForSequenceClassification.from_pretrained(
    "bert-base-uncased",
    num_labels=len(category_columns),
    problem_type="multi_label_classification",
    id2label=dict(enumerate(category_columns)),
    label2id={name: idx for idx, name in enumerate(category_columns)},
)

# Move the model to the correct device
model.to(device)

# Define compute_metrics for multi-label classification
def compute_metrics(p, threshold=0.5):
    """Compute multi-label metrics from a Trainer evaluation prediction.

    Args:
        p: object with ``predictions`` (raw logits, one column per category)
            and ``label_ids`` (0/1 ground truth, same shape). The Trainer
            passes an ``EvalPrediction``.
        threshold: sigmoid probability above which a category counts as
            predicted. Defaults to 0.5, the previously hard-coded value.

    Returns:
        dict with ``accuracy``, which for multi-label input is *subset*
        accuracy (all categories of an article must match), and the
        support-weighted ``f1``, ``precision`` and ``recall``.
    """
    logits = p.predictions
    # Some models return (logits, extra outputs...); the logits come first.
    if isinstance(logits, tuple):
        logits = logits[0]
    probs = torch.sigmoid(torch.as_tensor(logits, dtype=torch.float32))
    preds = (probs > threshold).int().numpy()
    labels = np.asarray(p.label_ids)

    # zero_division=0 gives the same scores as sklearn's default ("warn" -> 0)
    # but silences the UndefinedMetricWarning raised for categories with no
    # predicted or no true samples in the evaluation set.
    precision, recall, f1, _ = precision_recall_fscore_support(
        labels, preds, average='weighted', zero_division=0
    )
    acc = accuracy_score(labels, preds)
    return {
        'accuracy': acc,
        'f1': f1,
        'precision': precision,
        'recall': recall
    }

# Training arguments
training_args = TrainingArguments(
    output_dir='./results',         # Output directory
    eval_strategy="epoch",          # Evaluate at the end of each epoch
    save_strategy="epoch",          # Save at the end of each epoch (must match eval_strategy for load_best_model_at_end)
    save_total_limit=2,             # Bound disk use: keep only the latest checkpoint plus the best one
    learning_rate=2e-5,
    per_device_train_batch_size=16,
    per_device_eval_batch_size=16,
    num_train_epochs=30,
    weight_decay=0.01,
    load_best_model_at_end=True,    # Restore the best checkpoint (by metric_for_best_model) after training
    metric_for_best_model="f1",     # Use F1 score for best model
    logging_dir='./logs'
)

# Initialize Trainer
trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=train_dataset,
    eval_dataset=val_dataset,
    compute_metrics=compute_metrics
)

# Train the model
trainer.train()

# Save the (best) model weights and config
trainer.save_model("./bert-multiclass-model")
# The Trainer was not given the tokenizer, so save_model() does not write it;
# store it next to the model so the directory can be reloaded on its own.
tokenizer.save_pretrained("./bert-multiclass-model")

###########################################
# 6) Single-Category Inference on Test Set
###########################################

# Create a DataLoader for the test dataset
test_loader = DataLoader(
    test_dataset,
    batch_size=16,  # Adjust as appropriate
    shuffle=False
)

model.eval()  # evaluation mode (disables dropout)

label_batches = []
logit_batches = []

with torch.no_grad():
    for batch in test_loader:
        # Only the model inputs go to the device; labels stay on the CPU,
        # where they are compared with the predictions.
        inputs = {k: v.to(device) for k, v in batch.items() if k != 'labels'}
        logits = model(**inputs).logits   # raw scores, shape: (batch_size, num_categories)

        label_batches.append(batch['labels'].numpy())
        logit_batches.append(logits.cpu().numpy())

# Stack the per-batch arrays
y_true = np.concatenate(label_batches)  # shape: (num_samples, num_categories), 0/1 values
y_pred = np.concatenate(logit_batches)  # shape: (num_samples, num_categories), raw logits, NOT probabilities

# 1) Single-label prediction: the highest-scoring category of each article.
#    Sigmoid is monotonic, so argmax over logits equals argmax over probabilities.
y_pred_single = np.argmax(y_pred, axis=1)  # shape: (num_samples,)

# 2) Top-1 hit rate: the fraction of articles whose top-scoring category is one
#    of their true categories. This is not exact-match (subset) accuracy.
#    Articles with no positive label in the ground truth always count as misses.
hits = y_true[np.arange(len(y_true)), y_pred_single] == 1
accuracy = hits.mean()

print("Exact-match single-category accuracy:", accuracy)


Using device: cuda


Some weights of BertForSequenceClassification were not initialized from the model checkpoint at bert-base-uncased and are newly initialized: ['classifier.bias', 'classifier.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))

  3%|▎         | 21/630 [00:13<05:03,  2.01it/s]

{'eval_loss': 0.46326664090156555, 'eval_accuracy': 0.0, 'eval_f1': 0.0287856071964018, 'eval_precision': 0.08347826086956522, 'eval_recall': 0.017391304347826087, 'eval_runtime': 0.7218, 'eval_samples_per_second': 78.967, 'eval_steps_per_second': 5.542, 'epoch': 1.0}


  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))

  7%|▋         | 42/630 [00:28<04:46,  2.05it/s]

{'eval_loss': 0.35947632789611816, 'eval_accuracy': 0.08771929824561403, 'eval_f1': 0.3071107720993304, 'eval_precision': 0.31105304125430894, 'eval_recall': 0.3217391304347826, 'eval_runtime': 0.7053, 'eval_samples_per_second': 80.815, 'eval_steps_per_second': 5.671, 'epoch': 2.0}


  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))

 10%|█         | 63/630 [00:43<04:36,  2.05it/s]

{'eval_loss': 0.3125326633453369, 'eval_accuracy': 0.017543859649122806, 'eval_f1': 0.19375, 'eval_precision': 0.1878787878787879, 'eval_recall': 0.2, 'eval_runtime': 0.7102, 'eval_samples_per_second': 80.258, 'eval_steps_per_second': 5.632, 'epoch': 3.0}


  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))

 13%|█▎        | 84/630 [00:57<04:27,  2.04it/s]

{'eval_loss': 0.2882603704929352, 'eval_accuracy': 0.03508771929824561, 'eval_f1': 0.179131702286404, 'eval_precision': 0.5659420289855074, 'eval_recall': 0.11304347826086956, 'eval_runtime': 0.7119, 'eval_samples_per_second': 80.062, 'eval_steps_per_second': 5.618, 'epoch': 4.0}


  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))

 17%|█▋        | 105/630 [01:32<04:20,  2.02it/s]

{'eval_loss': 0.2715749740600586, 'eval_accuracy': 0.14035087719298245, 'eval_f1': 0.4156473960821787, 'eval_precision': 0.5922430830039526, 'eval_recall': 0.3739130434782609, 'eval_runtime': 0.7096, 'eval_samples_per_second': 80.329, 'eval_steps_per_second': 5.637, 'epoch': 5.0}


  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))

 20%|██        | 126/630 [01:46<04:07,  2.04it/s]

{'eval_loss': 0.2606455385684967, 'eval_accuracy': 0.19298245614035087, 'eval_f1': 0.4705199148677409, 'eval_precision': 0.5024357569114603, 'eval_recall': 0.45217391304347826, 'eval_runtime': 0.7135, 'eval_samples_per_second': 79.884, 'eval_steps_per_second': 5.606, 'epoch': 6.0}


  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))

 23%|██▎       | 147/630 [02:01<03:57,  2.04it/s]

{'eval_loss': 0.2492636740207672, 'eval_accuracy': 0.17543859649122806, 'eval_f1': 0.49885187276491627, 'eval_precision': 0.5132312252964426, 'eval_recall': 0.48695652173913045, 'eval_runtime': 0.7096, 'eval_samples_per_second': 80.331, 'eval_steps_per_second': 5.637, 'epoch': 7.0}


  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))

 27%|██▋       | 168/630 [02:15<03:46,  2.04it/s]

{'eval_loss': 0.24140220880508423, 'eval_accuracy': 0.2807017543859649, 'eval_f1': 0.5183062285095911, 'eval_precision': 0.4954654293406657, 'eval_recall': 0.5478260869565217, 'eval_runtime': 0.7103, 'eval_samples_per_second': 80.246, 'eval_steps_per_second': 5.631, 'epoch': 8.0}


  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))

 30%|███       | 189/630 [02:33<03:36,  2.04it/s]

{'eval_loss': 0.23583588004112244, 'eval_accuracy': 0.19298245614035087, 'eval_f1': 0.4706521739130435, 'eval_precision': 0.5158260869565218, 'eval_recall': 0.45217391304347826, 'eval_runtime': 0.7083, 'eval_samples_per_second': 80.477, 'eval_steps_per_second': 5.648, 'epoch': 9.0}


  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))

 33%|███▎      | 210/630 [03:04<03:28,  2.01it/s]

{'eval_loss': 0.2287553995847702, 'eval_accuracy': 0.2807017543859649, 'eval_f1': 0.5256468584350176, 'eval_precision': 0.5129250278706801, 'eval_recall': 0.5391304347826087, 'eval_runtime': 0.7075, 'eval_samples_per_second': 80.563, 'eval_steps_per_second': 5.654, 'epoch': 10.0}


  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))

 37%|███▋      | 231/630 [03:19<03:16,  2.03it/s]

{'eval_loss': 0.22480817139148712, 'eval_accuracy': 0.24561403508771928, 'eval_f1': 0.5170301003344481, 'eval_precision': 0.48979750949060413, 'eval_recall': 0.5478260869565217, 'eval_runtime': 0.7297, 'eval_samples_per_second': 78.115, 'eval_steps_per_second': 5.482, 'epoch': 11.0}


  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))

 40%|████      | 252/630 [03:33<03:05,  2.04it/s]

{'eval_loss': 0.2195204645395279, 'eval_accuracy': 0.22807017543859648, 'eval_f1': 0.5375494071146245, 'eval_precision': 0.6086956521739131, 'eval_recall': 0.5304347826086957, 'eval_runtime': 0.7282, 'eval_samples_per_second': 78.272, 'eval_steps_per_second': 5.493, 'epoch': 12.0}


  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))

 43%|████▎     | 273/630 [04:03<02:56,  2.02it/s]

{'eval_loss': 0.21884743869304657, 'eval_accuracy': 0.2807017543859649, 'eval_f1': 0.5559175488625858, 'eval_precision': 0.5905615942028984, 'eval_recall': 0.5739130434782609, 'eval_runtime': 0.7081, 'eval_samples_per_second': 80.493, 'eval_steps_per_second': 5.649, 'epoch': 13.0}


  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))

 47%|████▋     | 294/630 [04:25<02:45,  2.03it/s]

{'eval_loss': 0.21434271335601807, 'eval_accuracy': 0.24561403508771928, 'eval_f1': 0.5258912362000818, 'eval_precision': 0.6050291093769354, 'eval_recall': 0.5130434782608696, 'eval_runtime': 0.7161, 'eval_samples_per_second': 79.598, 'eval_steps_per_second': 5.586, 'epoch': 14.0}


  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))

 50%|█████     | 315/630 [04:39<02:34,  2.04it/s]

{'eval_loss': 0.21053160727024078, 'eval_accuracy': 0.24561403508771928, 'eval_f1': 0.5532091097308489, 'eval_precision': 0.597752508361204, 'eval_recall': 0.5565217391304348, 'eval_runtime': 0.7255, 'eval_samples_per_second': 78.571, 'eval_steps_per_second': 5.514, 'epoch': 15.0}


  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))

 53%|█████▎    | 336/630 [04:54<02:25,  2.02it/s]

{'eval_loss': 0.20665861666202545, 'eval_accuracy': 0.2807017543859649, 'eval_f1': 0.5656462585034013, 'eval_precision': 0.6132173913043478, 'eval_recall': 0.5652173913043478, 'eval_runtime': 0.7246, 'eval_samples_per_second': 78.667, 'eval_steps_per_second': 5.521, 'epoch': 16.0}


  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))

 57%|█████▋    | 357/630 [05:24<02:15,  2.01it/s]

{'eval_loss': 0.20395605266094208, 'eval_accuracy': 0.2631578947368421, 'eval_f1': 0.5705533789881615, 'eval_precision': 0.6011881270903009, 'eval_recall': 0.5739130434782609, 'eval_runtime': 0.71, 'eval_samples_per_second': 80.277, 'eval_steps_per_second': 5.633, 'epoch': 17.0}


  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))

 60%|██████    | 378/630 [05:38<02:03,  2.04it/s]

{'eval_loss': 0.20423869788646698, 'eval_accuracy': 0.24561403508771928, 'eval_f1': 0.5639016818636002, 'eval_precision': 0.5978615136876005, 'eval_recall': 0.5652173913043478, 'eval_runtime': 0.7148, 'eval_samples_per_second': 79.745, 'eval_steps_per_second': 5.596, 'epoch': 18.0}


  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))

 63%|██████▎   | 399/630 [05:53<01:53,  2.04it/s]

{'eval_loss': 0.20222172141075134, 'eval_accuracy': 0.2982456140350877, 'eval_f1': 0.5832836438923394, 'eval_precision': 0.607438127090301, 'eval_recall': 0.5826086956521739, 'eval_runtime': 0.7269, 'eval_samples_per_second': 78.416, 'eval_steps_per_second': 5.503, 'epoch': 19.0}


  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))

 67%|██████▋   | 420/630 [06:08<01:43,  2.02it/s]

{'eval_loss': 0.20126166939735413, 'eval_accuracy': 0.3157894736842105, 'eval_f1': 0.5834090157802252, 'eval_precision': 0.601503105590062, 'eval_recall': 0.591304347826087, 'eval_runtime': 0.7241, 'eval_samples_per_second': 78.717, 'eval_steps_per_second': 5.524, 'epoch': 20.0}


  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))

 70%|███████   | 441/630 [06:22<01:32,  2.03it/s]

{'eval_loss': 0.20006106793880463, 'eval_accuracy': 0.2982456140350877, 'eval_f1': 0.60125702454895, 'eval_precision': 0.665391304347826, 'eval_recall': 0.591304347826087, 'eval_runtime': 0.7328, 'eval_samples_per_second': 77.789, 'eval_steps_per_second': 5.459, 'epoch': 21.0}


  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))

 73%|███████▎  | 462/630 [06:43<01:22,  2.03it/s]

{'eval_loss': 0.1982196718454361, 'eval_accuracy': 0.2982456140350877, 'eval_f1': 0.5885214955547438, 'eval_precision': 0.609816425120773, 'eval_recall': 0.591304347826087, 'eval_runtime': 0.7079, 'eval_samples_per_second': 80.522, 'eval_steps_per_second': 5.651, 'epoch': 22.0}


  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))

 77%|███████▋  | 483/630 [07:13<01:12,  2.02it/s]

{'eval_loss': 0.19652602076530457, 'eval_accuracy': 0.2982456140350877, 'eval_f1': 0.5902509014706716, 'eval_precision': 0.6205507246376811, 'eval_recall': 0.5826086956521739, 'eval_runtime': 0.7173, 'eval_samples_per_second': 79.46, 'eval_steps_per_second': 5.576, 'epoch': 23.0}


 79%|███████▉  | 500/630 [07:25<01:22,  1.57it/s]

{'loss': 0.2166, 'grad_norm': 0.43137043714523315, 'learning_rate': 4.126984126984127e-06, 'epoch': 23.81}


  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))

 80%|████████  | 504/630 [07:27<01:01,  2.04it/s]

{'eval_loss': 0.1961987018585205, 'eval_accuracy': 0.2982456140350877, 'eval_f1': 0.5902509014706716, 'eval_precision': 0.6205507246376811, 'eval_recall': 0.5826086956521739, 'eval_runtime': 0.7065, 'eval_samples_per_second': 80.684, 'eval_steps_per_second': 5.662, 'epoch': 24.0}


  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))

 83%|████████▎ | 525/630 [07:42<00:51,  2.04it/s]

{'eval_loss': 0.19729112088680267, 'eval_accuracy': 0.3157894736842105, 'eval_f1': 0.5907719609582965, 'eval_precision': 0.6141404682274247, 'eval_recall': 0.591304347826087, 'eval_runtime': 0.7087, 'eval_samples_per_second': 80.426, 'eval_steps_per_second': 5.644, 'epoch': 25.0}


  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))

 87%|████████▋ | 546/630 [08:16<00:41,  2.03it/s]

{'eval_loss': 0.1960185021162033, 'eval_accuracy': 0.2982456140350877, 'eval_f1': 0.5902509014706716, 'eval_precision': 0.6205507246376811, 'eval_recall': 0.5826086956521739, 'eval_runtime': 0.721, 'eval_samples_per_second': 79.055, 'eval_steps_per_second': 5.548, 'epoch': 26.0}


  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))

 90%|█████████ | 567/630 [08:31<00:30,  2.04it/s]

{'eval_loss': 0.19597844779491425, 'eval_accuracy': 0.2982456140350877, 'eval_f1': 0.5902509014706716, 'eval_precision': 0.6205507246376811, 'eval_recall': 0.5826086956521739, 'eval_runtime': 0.7114, 'eval_samples_per_second': 80.127, 'eval_steps_per_second': 5.623, 'epoch': 27.0}


  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))

 93%|█████████▎| 588/630 [08:45<00:20,  2.03it/s]

{'eval_loss': 0.19515424966812134, 'eval_accuracy': 0.2982456140350877, 'eval_f1': 0.5902509014706716, 'eval_precision': 0.6205507246376811, 'eval_recall': 0.5826086956521739, 'eval_runtime': 0.7143, 'eval_samples_per_second': 79.804, 'eval_steps_per_second': 5.6, 'epoch': 28.0}


  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))

 97%|█████████▋| 609/630 [09:15<00:10,  2.02it/s]

{'eval_loss': 0.19533133506774902, 'eval_accuracy': 0.2982456140350877, 'eval_f1': 0.5902509014706716, 'eval_precision': 0.6205507246376811, 'eval_recall': 0.5826086956521739, 'eval_runtime': 0.7183, 'eval_samples_per_second': 79.358, 'eval_steps_per_second': 5.569, 'epoch': 29.0}


  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))

100%|██████████| 630/630 [09:38<00:00,  2.03it/s]

{'eval_loss': 0.1951710730791092, 'eval_accuracy': 0.2982456140350877, 'eval_f1': 0.5902509014706716, 'eval_precision': 0.6205507246376811, 'eval_recall': 0.5826086956521739, 'eval_runtime': 0.6712, 'eval_samples_per_second': 84.924, 'eval_steps_per_second': 5.96, 'epoch': 30.0}


100%|██████████| 630/630 [09:39<00:00,  1.09it/s]


{'train_runtime': 579.9261, 'train_samples_per_second': 16.709, 'train_steps_per_second': 1.086, 'train_loss': 0.1972462472461519, 'epoch': 30.0}
Exact-match single-category accuracy: 0.7674418604651163
