## Setup and Imports

In [1]:
experiment = 'ISHate-lora-back-translation'

In [2]:
import os

COLAB = False
if 'google.colab' in str(get_ipython()):
    COLAB = True

if COLAB:
    from google.colab import drive, userdata
    drive.mount('/content/drive')
    repo_path = '/content/drive/Othercomputers/My Mac/266-implicit-hate-speech-detection'

    hf_token = userdata.get('hf_token')

else:
    repo_path = '..'

!python -m pip install transformers accelerate datasets evaluate peft bitsandbytes tqdm

data_path = os.path.join(repo_path, 'data/processed')
aug_path = os.path.join(repo_path, 'data/back_translation')

Mounted at /content/drive
Collecting accelerate
  Downloading accelerate-0.29.2-py3-none-any.whl (297 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m297.4/297.4 kB[0m [31m5.5 MB/s[0m eta [36m0:00:00[0m
[?25hCollecting datasets
  Downloading datasets-2.18.0-py3-none-any.whl (510 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m510.5/510.5 kB[0m [31m10.5 MB/s[0m eta [36m0:00:00[0m
[?25hCollecting evaluate
  Downloading evaluate-0.4.1-py3-none-any.whl (84 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m84.1/84.1 kB[0m [31m9.4 MB/s[0m eta [36m0:00:00[0m
[?25hCollecting peft
  Downloading peft-0.10.0-py3-none-any.whl (199 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m199.1/199.1 kB[0m [31m10.0 MB/s[0m eta [36m0:00:00[0m
[?25hCollecting bitsandbytes
  Downloading bitsandbytes-0.43.1-py3-none-manylinux_2_24_x86_64.whl (119.8 MB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m 

In [1]:
import pandas as pd
import numpy as np
import torch
from torch.utils.data import DataLoader, SequentialSampler, TensorDataset

from transformers import (
    BertForSequenceClassification,
    BertConfig,
    BertTokenizer,
    EvalPrediction,
    TrainingArguments,
    Trainer,
    DataCollatorWithPadding,
    BitsAndBytesConfig
)

from peft import (
    PeftModel,
    PeftConfig,
    PeftType,
    prepare_model_for_kbit_training,
    LoraConfig,
    get_peft_model
)

import accelerate

import evaluate
from datasets import load_dataset
from datetime import datetime
from sklearn.metrics import classification_report
import time
import math

import bitsandbytes as bnb

ModuleNotFoundError: No module named 'peft'

In [4]:
# Path Definitions
exp_dir = os.path.join(repo_path, 'experiments', experiment)

model_dir = os.path.join(repo_path, f'models/hateBERT-{experiment}')
model_target = 'GroNLP/hateBERT'

train_file = os.path.join(aug_path, 'ishate/ishate_train.csv')
val_file = os.path.join(data_path, 'ishate/ishate_val.csv')
test_file = os.path.join(data_path, 'ishate/ishate_test.csv')

results_file = os.path.join(exp_dir, 'results.csv')
metrics_file = os.path.join(exp_dir, 'metrics.csv')

## Load Data/Model/Tokenizer

In [5]:
data = load_dataset(
    "csv",
    data_files = {
        "train": train_file,
    }
)

val = load_dataset(
    'csv',
    data_files = {
        "val": val_file,
    }
)

Generating train split: 0 examples [00:00, ? examples/s]

Generating val split: 0 examples [00:00, ? examples/s]

In [6]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

tokenizer = BertTokenizer.from_pretrained(model_target, token=hf_token, max_length=512)

# set padding_side and truncation side to 'left', following hateBERT procedure
tokenizer.padding_side = 'left'
tokenizer.truncation_side = 'left'

data_collator = DataCollatorWithPadding(
    tokenizer=tokenizer,
    padding = 'max_length',
    max_length = 512,
)

tokenizer_config.json:   0%|          | 0.00/151 [00:00<?, ?B/s]

vocab.txt:   0%|          | 0.00/232k [00:00<?, ?B/s]

special_tokens_map.json:   0%|          | 0.00/112 [00:00<?, ?B/s]

config.json:   0%|          | 0.00/1.24k [00:00<?, ?B/s]

## Preprocess Data

In [7]:
def preprocess(example):
    encoded = tokenizer(
        example['cleaned_text'],
        add_special_tokens=True,
        padding='max_length'
    )

    return encoded

In [8]:
processed = data.map(preprocess)
processed.set_format("torch")

processed_val = val.map(preprocess)
processed_val.set_format("torch")

Map:   0%|          | 0/21021 [00:00<?, ? examples/s]

Map:   0%|          | 0/4367 [00:00<?, ? examples/s]

In [9]:
processed

DatasetDict({
    train: Dataset({
        features: ['Unnamed: 0', 'index', 'id', 'text', 'cleaned_text', 'label_name', 'label', 'orig_id', 'orig_cleaned_text', 'aug_method', 'input_ids', 'token_type_ids', 'attention_mask'],
        num_rows: 21021
    })
})

## Define model

In [None]:
peft_config = LoraConfig(
    task_type="SEQ_CLS",
    inference_mode=False,
    r=8,
    lora_alpha=16,
    lora_dropout=0.1,
)

In [None]:
model = BertForSequenceClassification.from_pretrained(
    model_target,
    num_labels=3,
    output_attentions=False,
    output_hidden_states=False,
    token=hf_token,
#    quantization_config=bnb_config
)

model.to(device)

model.safetensors:   0%|          | 0.00/440M [00:00<?, ?B/s]

Some weights of BertForSequenceClassification were not initialized from the model checkpoint at GroNLP/hateBERT and are newly initialized: ['classifier.bias', 'classifier.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


BertForSequenceClassification(
  (bert): BertModel(
    (embeddings): BertEmbeddings(
      (word_embeddings): Embedding(30522, 768, padding_idx=0)
      (position_embeddings): Embedding(512, 768)
      (token_type_embeddings): Embedding(2, 768)
      (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
      (dropout): Dropout(p=0.1, inplace=False)
    )
    (encoder): BertEncoder(
      (layer): ModuleList(
        (0-11): 12 x BertLayer(
          (attention): BertAttention(
            (self): BertSelfAttention(
              (query): Linear(in_features=768, out_features=768, bias=True)
              (key): Linear(in_features=768, out_features=768, bias=True)
              (value): Linear(in_features=768, out_features=768, bias=True)
              (dropout): Dropout(p=0.1, inplace=False)
            )
            (output): BertSelfOutput(
              (dense): Linear(in_features=768, out_features=768, bias=True)
              (LayerNorm): LayerNorm((768,), eps=1e-12,

In [None]:
model = prepare_model_for_kbit_training(model)
model = get_peft_model(model, peft_config)
model.print_trainable_parameters()

trainable params: 297,219 || all params: 109,781,766 || trainable%: 0.27073621679578375



## Train setup

In [None]:
batch_size = 20
metric_name = "f1"

args = TrainingArguments(
    model_dir,
    evaluation_strategy = "epoch",
    save_strategy = "epoch",
    learning_rate=2e-5,
    per_device_train_batch_size=batch_size,
    per_device_eval_batch_size=batch_size,
    num_train_epochs=10,
    weight_decay=0.01,
    load_best_model_at_end=True,
    metric_for_best_model=metric_name,
    push_to_hub=False,
)

In [None]:
def compute_metrics(p: EvalPrediction):
    preds = p.predictions[0] if isinstance(p.predictions,
            tuple) else p.predictions

    y_pred = np.argmax(preds, axis=1).flatten()
    y_true = p.label_ids

    result = classification_report(y_pred, y_true, output_dict=True)
    result['f1'] = result['weighted avg']['f1-score']
    return result

## Train

In [None]:
trainer = Trainer(
    model,
    args,
    train_dataset=processed['train'],
    eval_dataset=processed_val['val'],
    tokenizer=tokenizer,
    data_collator=data_collator,
    compute_metrics=compute_metrics
)

dataloader_config = DataLoaderConfiguration(dispatch_batches=None, split_batches=False, even_batches=True, use_seedable_sampler=True)


## Run Fine-tuning

In [None]:
start = time.time()
trainer.train()
end = time.time()

print(f"Total training time: ~{(end - start) // 60} minutes")

Epoch,Training Loss,Validation Loss,0,1,2,Accuracy,Macro avg,Weighted avg,F1
1,0.631,0.811225,"{'precision': 0.4962686567164179, 'recall': 0.835427135678392, 'f1-score': 0.6226591760299626, 'support': 1592}","{'precision': 0.8507661558960693, 'recall': 0.4645325572935613, 'f1-score': 0.6009411764705883, 'support': 2749}","{'precision': 0.010752688172043012, 'recall': 0.07692307692307693, 'f1-score': 0.018867924528301886, 'support': 26}",0.597435,"{'precision': 0.4525958335948434, 'recall': 0.45896092329834337, 'f1-score': 0.41415609234295087, 'support': 4367}","{'precision': 0.7165320434951465, 'recall': 0.5974353102816579, 'f1-score': 0.605393008563106, 'support': 4367}",0.605393
2,0.5092,0.63033,"{'precision': 0.7261194029850746, 'recall': 0.8625886524822695, 'f1-score': 0.7884927066450568, 'support': 2256}","{'precision': 0.8194536975349767, 'recall': 0.5936293436293436, 'f1-score': 0.688497061293031, 'support': 2072}","{'precision': 0.026881720430107527, 'recall': 0.1282051282051282, 'f1-score': 0.04444444444444444, 'support': 39}",0.728418,"{'precision': 0.5241516069833863, 'recall': 0.5281410414389137, 'f1-score': 0.5071447374608441, 'support': 4367}","{'precision': 0.7641588782971317, 'recall': 0.7284176780398443, 'f1-score': 0.7344032036921779, 'support': 4367}",0.734403
3,0.4353,0.581591,"{'precision': 0.7667910447761194, 'recall': 0.8685545224006762, 'f1-score': 0.8145065398335314, 'support': 2366}","{'precision': 0.8261159227181879, 'recall': 0.6365503080082136, 'f1-score': 0.7190489997100609, 'support': 1948}","{'precision': 0.04838709677419355, 'recall': 0.16981132075471697, 'f1-score': 0.07531380753138076, 'support': 53}",0.756583,"{'precision': 0.5470980214228336, 'recall': 0.5583053837212023, 'f1-score': 0.5362897823583244, 'support': 4367}","{'precision': 0.7845353665043189, 'recall': 0.7565834669109228, 'f1-score': 0.7629543293978698, 'support': 4367}",0.762954
4,0.4234,0.556322,"{'precision': 0.8, 'recall': 0.867664912990692, 'f1-score': 0.8324597165598914, 'support': 2471}","{'precision': 0.8181212524983345, 'recall': 0.6655826558265583, 'f1-score': 0.7340107591153615, 'support': 1845}","{'precision': 0.06451612903225806, 'recall': 0.23529411764705882, 'f1-score': 0.10126582278481013, 'support': 51}",0.774903,"{'precision': 0.5608791271768642, 'recall': 0.5895138954881031, 'f1-score': 0.5559120994866876, 'support': 4367}","{'precision': 0.7990666437920936, 'recall': 0.774902679184795, 'f1-score': 0.7823270820126766, 'support': 4367}",0.782327
5,0.4136,0.555491,"{'precision': 0.7880597014925373, 'recall': 0.8734491315136477, 'f1-score': 0.8285602196939977, 'support': 2418}","{'precision': 0.8307794803464357, 'recall': 0.6573537163943068, 'f1-score': 0.7339611536197764, 'support': 1897}","{'precision': 0.07526881720430108, 'recall': 0.2692307692307692, 'f1-score': 0.11764705882352942, 'support': 52}",0.772384,"{'precision': 0.564702666347758, 'recall': 0.6000112057129079, 'f1-score': 0.5600561440457679, 'support': 4367}","{'precision': 0.7981293819374324, 'recall': 0.7723837874971377, 'f1-score': 0.7790017326987922, 'support': 4367}",0.779002
6,0.4,0.547858,"{'precision': 0.7884328358208955, 'recall': 0.8767634854771784, 'f1-score': 0.8302554027504913, 'support': 2410}","{'precision': 0.8347768154563624, 'recall': 0.6657810839532412, 'f1-score': 0.7407626367129767, 'support': 1882}","{'precision': 0.10752688172043011, 'recall': 0.26666666666666666, 'f1-score': 0.1532567049808429, 'support': 75}",0.775361,"{'precision': 0.576912177665896, 'recall': 0.6030704120323621, 'f1-score': 0.5747582481481036, 'support': 4367}","{'precision': 0.7967111557467975, 'recall': 0.7753606594916419, 'f1-score': 0.7800606951674077, 'support': 4367}",0.780061
7,0.3862,0.54903,"{'precision': 0.7835820895522388, 'recall': 0.8808724832214765, 'f1-score': 0.8293838862559243, 'support': 2384}","{'precision': 0.8481012658227848, 'recall': 0.6578811369509044, 'f1-score': 0.7409778812572759, 'support': 1935}","{'precision': 0.08064516129032258, 'recall': 0.3125, 'f1-score': 0.1282051282051282, 'support': 48}",0.775819,"{'precision': 0.570776172221782, 'recall': 0.6170845400574603, 'f1-score': 0.5661889652394428, 'support': 4367}","{'precision': 0.804443924571001, 'recall': 0.7758186397984886, 'f1-score': 0.7825045182552779, 'support': 4367}",0.782505


Trainer is attempting to log a value of "{'precision': 0.4962686567164179, 'recall': 0.835427135678392, 'f1-score': 0.6226591760299626, 'support': 1592}" of type <class 'dict'> for key "eval/0" as a scalar. This invocation of Tensorboard's writer.add_scalar() is incorrect so we dropped this attribute.
Trainer is attempting to log a value of "{'precision': 0.8507661558960693, 'recall': 0.4645325572935613, 'f1-score': 0.6009411764705883, 'support': 2749}" of type <class 'dict'> for key "eval/1" as a scalar. This invocation of Tensorboard's writer.add_scalar() is incorrect so we dropped this attribute.
Trainer is attempting to log a value of "{'precision': 0.010752688172043012, 'recall': 0.07692307692307693, 'f1-score': 0.018867924528301886, 'support': 26}" of type <class 'dict'> for key "eval/2" as a scalar. This invocation of Tensorboard's writer.add_scalar() is incorrect so we dropped this attribute.
Trainer is attempting to log a value of "{'precision': 0.4525958335948434, 'recall': 0

Epoch,Training Loss,Validation Loss,0,1,2,Accuracy,Macro avg,Weighted avg,F1
1,0.631,0.811225,"{'precision': 0.4962686567164179, 'recall': 0.835427135678392, 'f1-score': 0.6226591760299626, 'support': 1592}","{'precision': 0.8507661558960693, 'recall': 0.4645325572935613, 'f1-score': 0.6009411764705883, 'support': 2749}","{'precision': 0.010752688172043012, 'recall': 0.07692307692307693, 'f1-score': 0.018867924528301886, 'support': 26}",0.597435,"{'precision': 0.4525958335948434, 'recall': 0.45896092329834337, 'f1-score': 0.41415609234295087, 'support': 4367}","{'precision': 0.7165320434951465, 'recall': 0.5974353102816579, 'f1-score': 0.605393008563106, 'support': 4367}",0.605393
2,0.5092,0.63033,"{'precision': 0.7261194029850746, 'recall': 0.8625886524822695, 'f1-score': 0.7884927066450568, 'support': 2256}","{'precision': 0.8194536975349767, 'recall': 0.5936293436293436, 'f1-score': 0.688497061293031, 'support': 2072}","{'precision': 0.026881720430107527, 'recall': 0.1282051282051282, 'f1-score': 0.04444444444444444, 'support': 39}",0.728418,"{'precision': 0.5241516069833863, 'recall': 0.5281410414389137, 'f1-score': 0.5071447374608441, 'support': 4367}","{'precision': 0.7641588782971317, 'recall': 0.7284176780398443, 'f1-score': 0.7344032036921779, 'support': 4367}",0.734403
3,0.4353,0.581591,"{'precision': 0.7667910447761194, 'recall': 0.8685545224006762, 'f1-score': 0.8145065398335314, 'support': 2366}","{'precision': 0.8261159227181879, 'recall': 0.6365503080082136, 'f1-score': 0.7190489997100609, 'support': 1948}","{'precision': 0.04838709677419355, 'recall': 0.16981132075471697, 'f1-score': 0.07531380753138076, 'support': 53}",0.756583,"{'precision': 0.5470980214228336, 'recall': 0.5583053837212023, 'f1-score': 0.5362897823583244, 'support': 4367}","{'precision': 0.7845353665043189, 'recall': 0.7565834669109228, 'f1-score': 0.7629543293978698, 'support': 4367}",0.762954
4,0.4234,0.556322,"{'precision': 0.8, 'recall': 0.867664912990692, 'f1-score': 0.8324597165598914, 'support': 2471}","{'precision': 0.8181212524983345, 'recall': 0.6655826558265583, 'f1-score': 0.7340107591153615, 'support': 1845}","{'precision': 0.06451612903225806, 'recall': 0.23529411764705882, 'f1-score': 0.10126582278481013, 'support': 51}",0.774903,"{'precision': 0.5608791271768642, 'recall': 0.5895138954881031, 'f1-score': 0.5559120994866876, 'support': 4367}","{'precision': 0.7990666437920936, 'recall': 0.774902679184795, 'f1-score': 0.7823270820126766, 'support': 4367}",0.782327
5,0.4136,0.555491,"{'precision': 0.7880597014925373, 'recall': 0.8734491315136477, 'f1-score': 0.8285602196939977, 'support': 2418}","{'precision': 0.8307794803464357, 'recall': 0.6573537163943068, 'f1-score': 0.7339611536197764, 'support': 1897}","{'precision': 0.07526881720430108, 'recall': 0.2692307692307692, 'f1-score': 0.11764705882352942, 'support': 52}",0.772384,"{'precision': 0.564702666347758, 'recall': 0.6000112057129079, 'f1-score': 0.5600561440457679, 'support': 4367}","{'precision': 0.7981293819374324, 'recall': 0.7723837874971377, 'f1-score': 0.7790017326987922, 'support': 4367}",0.779002
6,0.4,0.547858,"{'precision': 0.7884328358208955, 'recall': 0.8767634854771784, 'f1-score': 0.8302554027504913, 'support': 2410}","{'precision': 0.8347768154563624, 'recall': 0.6657810839532412, 'f1-score': 0.7407626367129767, 'support': 1882}","{'precision': 0.10752688172043011, 'recall': 0.26666666666666666, 'f1-score': 0.1532567049808429, 'support': 75}",0.775361,"{'precision': 0.576912177665896, 'recall': 0.6030704120323621, 'f1-score': 0.5747582481481036, 'support': 4367}","{'precision': 0.7967111557467975, 'recall': 0.7753606594916419, 'f1-score': 0.7800606951674077, 'support': 4367}",0.780061
7,0.3862,0.54903,"{'precision': 0.7835820895522388, 'recall': 0.8808724832214765, 'f1-score': 0.8293838862559243, 'support': 2384}","{'precision': 0.8481012658227848, 'recall': 0.6578811369509044, 'f1-score': 0.7409778812572759, 'support': 1935}","{'precision': 0.08064516129032258, 'recall': 0.3125, 'f1-score': 0.1282051282051282, 'support': 48}",0.775819,"{'precision': 0.570776172221782, 'recall': 0.6170845400574603, 'f1-score': 0.5661889652394428, 'support': 4367}","{'precision': 0.804443924571001, 'recall': 0.7758186397984886, 'f1-score': 0.7825045182552779, 'support': 4367}",0.782505
8,0.3934,0.529031,"{'precision': 0.8097014925373134, 'recall': 0.8735909822866345, 'f1-score': 0.8404337722695584, 'support': 2484}","{'precision': 0.8294470353097935, 'recall': 0.6781045751633987, 'f1-score': 0.7461792028768355, 'support': 1836}","{'precision': 0.08602150537634409, 'recall': 0.3404255319148936, 'f1-score': 0.13733905579399144, 'support': 47}",0.785665,"{'precision': 0.5750566777411503, 'recall': 0.630707029788309, 'f1-score': 0.5746506769801285, 'support': 4367}","{'precision': 0.8102143977660077, 'recall': 0.785665216395695, 'f1-score': 0.7932396250107099, 'support': 4367}",0.79324
9,0.3923,0.543591,"{'precision': 0.7873134328358209, 'recall': 0.8813700918964077, 'f1-score': 0.8316909735908554, 'support': 2394}","{'precision': 0.8481012658227848, 'recall': 0.6626756897449245, 'f1-score': 0.7440093512565751, 'support': 1921}","{'precision': 0.08064516129032258, 'recall': 0.28846153846153844, 'f1-score': 0.12605042016806722, 'support': 52}",0.778109,"{'precision': 0.5720199533163094, 'recall': 0.6108357733676235, 'f1-score': 0.5672502483384992, 'support': 4367}","{'precision': 0.8056387538909139, 'recall': 0.7781085413327227, 'f1-score': 0.7847182909066014, 'support': 4367}",0.784718
10,0.3896,0.537454,"{'precision': 0.7940298507462686, 'recall': 0.8778877887788779, 'f1-score': 0.8338557993730408, 'support': 2424}","{'precision': 0.8407728181212525, 'recall': 0.6677248677248677, 'f1-score': 0.7443232084930699, 'support': 1890}","{'precision': 0.08602150537634409, 'recall': 0.3018867924528302, 'f1-score': 0.13389121338912136, 'support': 53}",0.77994,"{'precision': 0.5736080580812883, 'recall': 0.6158331496521919, 'f1-score': 0.5706900737517441, 'support': 4367}","{'precision': 0.8056670767673617, 'recall': 0.77994046256011, 'f1-score': 0.7866117600278856, 'support': 4367}",0.786612


Trainer is attempting to log a value of "{'precision': 0.8097014925373134, 'recall': 0.8735909822866345, 'f1-score': 0.8404337722695584, 'support': 2484}" of type <class 'dict'> for key "eval/0" as a scalar. This invocation of Tensorboard's writer.add_scalar() is incorrect so we dropped this attribute.
Trainer is attempting to log a value of "{'precision': 0.8294470353097935, 'recall': 0.6781045751633987, 'f1-score': 0.7461792028768355, 'support': 1836}" of type <class 'dict'> for key "eval/1" as a scalar. This invocation of Tensorboard's writer.add_scalar() is incorrect so we dropped this attribute.
Trainer is attempting to log a value of "{'precision': 0.08602150537634409, 'recall': 0.3404255319148936, 'f1-score': 0.13733905579399144, 'support': 47}" of type <class 'dict'> for key "eval/2" as a scalar. This invocation of Tensorboard's writer.add_scalar() is incorrect so we dropped this attribute.
Trainer is attempting to log a value of "{'precision': 0.5750566777411503, 'recall': 0.6

Total training time: ~81.0 minutes


In [None]:
trainer.evaluate()

Trainer is attempting to log a value of "{'precision': 0.8097014925373134, 'recall': 0.8735909822866345, 'f1-score': 0.8404337722695584, 'support': 2484}" of type <class 'dict'> for key "eval/0" as a scalar. This invocation of Tensorboard's writer.add_scalar() is incorrect so we dropped this attribute.
Trainer is attempting to log a value of "{'precision': 0.8294470353097935, 'recall': 0.6781045751633987, 'f1-score': 0.7461792028768355, 'support': 1836}" of type <class 'dict'> for key "eval/1" as a scalar. This invocation of Tensorboard's writer.add_scalar() is incorrect so we dropped this attribute.
Trainer is attempting to log a value of "{'precision': 0.08602150537634409, 'recall': 0.3404255319148936, 'f1-score': 0.13733905579399144, 'support': 47}" of type <class 'dict'> for key "eval/2" as a scalar. This invocation of Tensorboard's writer.add_scalar() is incorrect so we dropped this attribute.
Trainer is attempting to log a value of "{'precision': 0.5750566777411503, 'recall': 0.6

{'eval_loss': 0.5290306806564331,
 'eval_0': {'precision': 0.8097014925373134,
  'recall': 0.8735909822866345,
  'f1-score': 0.8404337722695584,
  'support': 2484},
 'eval_1': {'precision': 0.8294470353097935,
  'recall': 0.6781045751633987,
  'f1-score': 0.7461792028768355,
  'support': 1836},
 'eval_2': {'precision': 0.08602150537634409,
  'recall': 0.3404255319148936,
  'f1-score': 0.13733905579399144,
  'support': 47},
 'eval_accuracy': 0.785665216395695,
 'eval_macro avg': {'precision': 0.5750566777411503,
  'recall': 0.630707029788309,
  'f1-score': 0.5746506769801285,
  'support': 4367},
 'eval_weighted avg': {'precision': 0.8102143977660077,
  'recall': 0.785665216395695,
  'f1-score': 0.7932396250107099,
  'support': 4367},
 'eval_f1': 0.7932396250107099,
 'eval_runtime': 41.6494,
 'eval_samples_per_second': 104.851,
 'eval_steps_per_second': 5.258,
 'epoch': 10.0}

## Save best model checkpoint

In [None]:
trainer.save_model(os.path.join(model_dir, 'final_model'))