In [1]:
import json
import numpy as np
import evaluate

from datasets import load_dataset
from scipy.special import softmax
from scipy.special import expit

from sklearn.metrics import multilabel_confusion_matrix, classification_report
from sklearn.metrics import accuracy_score, balanced_accuracy_score, f1_score, precision_score, recall_score

from transformers import AutoTokenizer
from transformers import AutoModelForSequenceClassification
from transformers import TrainingArguments
from transformers import TrainingArguments, Trainer, EarlyStoppingCallback

from focal_trainer import FocalLossTrainer

In [2]:
# helper functions

def tokenizerFunction(example):
    
    title_mod = [f"{t}[SEP]{s}" for t, s in zip(example['title'], example['summary'])]
    
    return tokenizer(title_mod, example['genres'], padding = 'max_length', truncation = True)

In [3]:
# load datasets using hugging face

data_files = {
        'train' : '../data/training',
        'val' : '../data/validation',
        'test' : '../data/test',
        } 

training = load_dataset('json', data_files = data_files, split = 'train')
validation = load_dataset('json', data_files = data_files, split = 'val')
test = load_dataset('json', data_files = data_files, split = 'test')

In [4]:
# train using PyTorch Trainer API

# tokenize datasets using DeBERTaV3 tokenizer
tokenizer = AutoTokenizer.from_pretrained(
    'microsoft/deberta-v3-base',
    padding = True,
    truncation = True,
    max_length = 512,
    model_max_length = 512,
)

tokenized_training = training.map(tokenizerFunction, batched = True).remove_columns(["genres", "title", "summary",])
tokenized_validation = validation.map(tokenizerFunction, batched = True).remove_columns(["genres", "title", "summary",])
tokenized_test = test.map(tokenizerFunction, batched = True).remove_columns(["genres", "title", "summary",])



In [5]:
# define the model

model = AutoModelForSequenceClassification.from_pretrained('microsoft/deberta-v3-base', num_labels = 3)

Some weights of DebertaV2ForSequenceClassification were not initialized from the model checkpoint at microsoft/deberta-v3-base and are newly initialized: ['classifier.bias', 'classifier.weight', 'pooler.dense.bias', 'pooler.dense.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


In [6]:
# define a TrainingArguments class that will contain all the hyperparameters the Trainer will use for training and
# evaluation
# the model will be saved in the input directory

training_args = TrainingArguments(
    per_device_train_batch_size = 24,
    gradient_accumulation_steps = 2,
    per_device_eval_batch_size = 128,
    num_train_epochs = 15,
    eval_strategy = "epoch",
    save_strategy = "epoch",
    output_dir = '../models/fine_tuned_DeBERTaV3',
    log_level = "info",
    fp16 = True,
    optim = "adamw_torch",
    learning_rate = 5.0E-6,
    load_best_model_at_end = True,
    metric_for_best_model = "eval_balanced_accuracy",
    greater_is_better = True,
)

In [7]:
# compute_metrics() function to calculate a metric when evaluating the model during training 
# (otherwise the evaluation would just print the loss, which is not a very intuitive number).


def computeMetrics(eval_pred):
#     convert the logits to predictions before passing the predictions to compute
    logits, labels = eval_pred
    predictions = np.argmax(logits, axis = -1)
        
    accuracy = accuracy_score(y_true = labels, y_pred = predictions)
    balanced_accuracy = balanced_accuracy_score(y_true = labels, y_pred = predictions)
    precision = precision_score(y_true = labels, y_pred = predictions, average = 'macro')
    recall = recall_score(y_true = labels, y_pred = predictions, average = 'macro')
    f1 = f1_score(y_true = labels, y_pred = predictions, average = 'macro')
    cls_report = classification_report(labels, predictions, 
                                                  output_dict = True, labels = [0, 1, 2], 
                                                  target_names = ['bad_rating', 'average_rating', 'good_rating'])

    res = {
        'accuracy': accuracy, 
        'balanced_accuracy': balanced_accuracy, 
        'precision': precision, 
        'recall': recall, 
        'f1': f1, 
        'classification_report': cls_report,
        }
    
    return res

In [8]:
# TRAINER
# define the trainer object

#trainer = Trainer(
trainer = FocalLossTrainer(
    model = model,
    args = training_args,
    train_dataset = tokenized_training,
    eval_dataset = tokenized_validation,
    compute_metrics = computeMetrics,
    callbacks = [EarlyStoppingCallback(early_stopping_patience = 3)]
)

Using auto half precision backend


In [9]:
# fine-tune the model

trainer.train()

***** Running training *****
  Num examples = 370,940
  Num Epochs = 15
  Instantaneous batch size per device = 24
  Total train batch size (w. parallel, distributed & accumulation) = 48
  Gradient Accumulation steps = 2
  Total optimization steps = 115,920
  Number of trainable parameters = 184,424,451


Epoch,Training Loss,Validation Loss,Accuracy,Balanced Accuracy,Precision,Recall,F1,Classification Report
1,0.0993,0.099282,0.726757,0.395734,0.591128,0.395734,0.397759,"{'bad_rating': {'precision': 0.4973544973544973, 'recall': 0.06483503851017358, 'f1-score': 0.11471575307637547, 'support': 8699.0}, 'average_rating': {'precision': 0.7393844620487, 'recall': 0.9675076452599388, 'f1-score': 0.8382019188623192, 'support': 44472.0}, 'good_rating': {'precision': 0.5366439727673208, 'recall': 0.15485958627065757, 'f1-score': 0.24035874439461882, 'support': 8653.0}, 'accuracy': 0.726756599378882, 'macro avg': {'precision': 0.5911276440568393, 'recall': 0.39573409001359, 'f1-score': 0.39775880544443787, 'support': 61824.0}, 'weighted avg': {'precision': 0.6769534948413591, 'recall': 0.726756599378882, 'f1-score': 0.6527279420112428, 'support': 61824.0}}"
2,0.0917,0.100606,0.728503,0.400215,0.601941,0.400215,0.407127,"{'bad_rating': {'precision': 0.48288075560802834, 'recall': 0.09403379698815956, 'f1-score': 0.15741364379871067, 'support': 8699.0}, 'average_rating': {'precision': 0.7408970340004822, 'recall': 0.9672378125562151, 'f1-score': 0.8390714912708476, 'support': 44472.0}, 'good_rating': {'precision': 0.582046332046332, 'recall': 0.13937362764359182, 'f1-score': 0.2248951048951049, 'support': 8653.0}, 'accuracy': 0.7285034937888198, 'macro avg': {'precision': 0.6019413738849476, 'recall': 0.40021507906265547, 'f1-score': 0.40712674665488774, 'support': 61824.0}, 'weighted avg': {'precision': 0.6823595933666633, 'recall': 0.7285034937888198, 'f1-score': 0.657196978355646, 'support': 61824.0}}"
3,0.0864,0.103557,0.724136,0.420492,0.567201,0.420492,0.435467,"{'bad_rating': {'precision': 0.45172878311629994, 'recall': 0.11564547649155076, 'f1-score': 0.18414790408200624, 'support': 8699.0}, 'average_rating': {'precision': 0.7487216946676406, 'recall': 0.9449766144990106, 'f1-score': 0.8354787725768134, 'support': 44472.0}, 'good_rating': {'precision': 0.5011534025374856, 'recall': 0.2008551947301514, 'f1-score': 0.28677501856282483, 'support': 8653.0}, 'accuracy': 0.7241362577639752, 'macro avg': {'precision': 0.5672012934404753, 'recall': 0.4204924285735709, 'f1-score': 0.43546723174054813, 'support': 61824.0}, 'weighted avg': {'precision': 0.6722829367518257, 'recall': 0.7241362577639752, 'f1-score': 0.6670351130187231, 'support': 61824.0}}"
4,0.0788,0.116994,0.715984,0.42999,0.539995,0.42999,0.446655,"{'bad_rating': {'precision': 0.4074074074074074, 'recall': 0.1353029083802736, 'f1-score': 0.2031411805315844, 'support': 8699.0}, 'average_rating': {'precision': 0.7526696583936258, 'recall': 0.9239971217844937, 'f1-score': 0.8295799812248277, 'support': 44472.0}, 'good_rating': {'precision': 0.4599078341013825, 'recall': 0.23067144343002427, 'f1-score': 0.3072423612714539, 'support': 8653.0}, 'accuracy': 0.7159840838509317, 'macro avg': {'precision': 0.5399949666341385, 'recall': 0.4299904911982639, 'f1-score': 0.4466545076759553, 'support': 61824.0}, 'weighted avg': {'precision': 0.6631137515139367, 'recall': 0.7159840838509317, 'f1-score': 0.6683290341381452, 'support': 61824.0}}"
5,0.0698,0.125013,0.685883,0.459687,0.502182,0.459687,0.474179,"{'bad_rating': {'precision': 0.3218321530812114, 'recall': 0.24554546499597654, 'f1-score': 0.27856025039123633, 'support': 8699.0}, 'average_rating': {'precision': 0.7666950492640798, 'recall': 0.8503777657852132, 'f1-score': 0.8063711379773556, 'support': 44472.0}, 'good_rating': {'precision': 0.4180174031735199, 'recall': 0.2831387957933665, 'f1-score': 0.33760507096596387, 'support': 8653.0}, 'accuracy': 0.6858825051759835, 'macro avg': {'precision': 0.502181535172937, 'recall': 0.4596873421915187, 'f1-score': 0.4741788197781853, 'support': 61824.0}, 'weighted avg': {'precision': 0.6552986658932791, 'recall': 0.6858825051759835, 'f1-score': 0.6664956901098411, 'support': 61824.0}}"
6,0.0624,0.149842,0.700763,0.432632,0.507462,0.432632,0.448325,"{'bad_rating': {'precision': 0.3568716780561883, 'recall': 0.16208759627543395, 'f1-score': 0.22292490118577074, 'support': 8699.0}, 'average_rating': {'precision': 0.7541458427987279, 'recall': 0.8957771181867242, 'f1-score': 0.8188825851011347, 'support': 44472.0}, 'good_rating': {'precision': 0.41136858783917607, 'recall': 0.24003235871951925, 'f1-score': 0.3031674208144796, 'support': 8653.0}, 'accuracy': 0.7007634575569358, 'macro avg': {'precision': 0.5074620362313641, 'recall': 0.4326323577272258, 'f1-score': 0.448324969033795, 'support': 61824.0}, 'weighted avg': {'precision': 0.650271303036494, 'recall': 0.7007634575569358, 'f1-score': 0.6628474012089217, 'support': 61824.0}}"
7,0.0546,0.168041,0.679267,0.457234,0.489766,0.457234,0.468459,"{'bad_rating': {'precision': 0.3165735567970205, 'recall': 0.2149672376135188, 'f1-score': 0.2560591537724223, 'support': 8699.0}, 'average_rating': {'precision': 0.7652524508299052, 'recall': 0.8407762187443785, 'f1-score': 0.8012385758520566, 'support': 44472.0}, 'good_rating': {'precision': 0.3874716553287982, 'recall': 0.3159597827343118, 'f1-score': 0.34808071805971097, 'support': 8653.0}, 'accuracy': 0.6792669513457557, 'macro avg': {'precision': 0.489765887651908, 'recall': 0.4572344130307364, 'f1-score': 0.4684594825613966, 'support': 61824.0}, 'weighted avg': {'precision': 0.6492458041770982, 'recall': 0.6792669513457557, 'f1-score': 0.6611038266260618, 'support': 61824.0}}"
8,0.0491,0.176479,0.674544,0.452214,0.481574,0.452214,0.46169,"{'bad_rating': {'precision': 0.31450026819238336, 'recall': 0.20220715024715485, 'f1-score': 0.2461516932549678, 'support': 8699.0}, 'average_rating': {'precision': 0.7632054176072235, 'recall': 0.8362790070156503, 'f1-score': 0.7980730027252634, 'support': 44472.0}, 'good_rating': {'precision': 0.3670177309692041, 'recall': 0.3181555529874032, 'f1-score': 0.34084437291073416, 'support': 8653.0}, 'accuracy': 0.6745438664596274, 'macro avg': {'precision': 0.48157447225627026, 'recall': 0.4522139034167361, 'f1-score': 0.4616896896303218, 'support': 61824.0}, 'weighted avg': {'precision': 0.6446188145527709, 'recall': 0.6745438664596274, 'f1-score': 0.6564198776465363, 'support': 61824.0}}"



***** Running Evaluation *****
  Num examples = 61824
  Batch size = 128
Saving model checkpoint to ../models/fine_tuned_DeBERTaV3/checkpoint-7728
Configuration saved in ../models/fine_tuned_DeBERTaV3/checkpoint-7728/config.json
Model weights saved in ../models/fine_tuned_DeBERTaV3/checkpoint-7728/model.safetensors

***** Running Evaluation *****
  Num examples = 61824
  Batch size = 128
Saving model checkpoint to ../models/fine_tuned_DeBERTaV3/checkpoint-15456
Configuration saved in ../models/fine_tuned_DeBERTaV3/checkpoint-15456/config.json
Model weights saved in ../models/fine_tuned_DeBERTaV3/checkpoint-15456/model.safetensors

***** Running Evaluation *****
  Num examples = 61824
  Batch size = 128
Saving model checkpoint to ../models/fine_tuned_DeBERTaV3/checkpoint-23184
Configuration saved in ../models/fine_tuned_DeBERTaV3/checkpoint-23184/config.json
Model weights saved in ../models/fine_tuned_DeBERTaV3/checkpoint-23184/model.safetensors

***** Running Evaluation *****
  Num ex

TrainOutput(global_step=61824, training_loss=0.07559675356184227, metrics={'train_runtime': 44136.8539, 'train_samples_per_second': 126.065, 'train_steps_per_second': 2.626, 'total_flos': 7.808083318923264e+17, 'train_loss': 0.07559675356184227, 'epoch': 8.0})

In [10]:
trainer.evaluate(tokenized_validation, metric_key_prefix = "val")


***** Running Evaluation *****
  Num examples = 61824
  Batch size = 128


early stopping required metric_for_best_model, but did not find eval_balanced_accuracy so early stopping is disabled


{'val_loss': 0.12501265108585358,
 'val_accuracy': 0.6858825051759835,
 'val_balanced_accuracy': 0.4596873421915187,
 'val_precision': 0.502181535172937,
 'val_recall': 0.4596873421915187,
 'val_f1': 0.4741788197781853,
 'val_classification_report': {'bad_rating': {'precision': 0.3218321530812114,
   'recall': 0.24554546499597654,
   'f1-score': 0.27856025039123633,
   'support': 8699.0},
  'average_rating': {'precision': 0.7666950492640798,
   'recall': 0.8503777657852132,
   'f1-score': 0.8063711379773556,
   'support': 44472.0},
  'good_rating': {'precision': 0.4180174031735199,
   'recall': 0.2831387957933665,
   'f1-score': 0.33760507096596387,
   'support': 8653.0},
  'accuracy': 0.6858825051759835,
  'macro avg': {'precision': 0.502181535172937,
   'recall': 0.4596873421915187,
   'f1-score': 0.4741788197781853,
   'support': 61824.0},
  'weighted avg': {'precision': 0.6552986658932791,
   'recall': 0.6858825051759835,
   'f1-score': 0.6664956901098411,
   'support': 61824.0}},


In [11]:
trainer.evaluate(tokenized_test, metric_key_prefix = "test")


***** Running Evaluation *****
  Num examples = 61824
  Batch size = 128
early stopping required metric_for_best_model, but did not find eval_balanced_accuracy so early stopping is disabled


{'test_loss': 0.1266673505306244,
 'test_accuracy': 0.6852840320910973,
 'test_balanced_accuracy': 0.4550989336998814,
 'test_precision': 0.5002329412741152,
 'test_recall': 0.4550989336998814,
 'test_f1': 0.46988696029032245,
 'test_classification_report': {'bad_rating': {'precision': 0.3149879372738239,
   'recall': 0.23874285714285715,
   'f1-score': 0.27161617474970745,
   'support': 8750.0},
  'average_rating': {'precision': 0.7646394579770931,
   'recall': 0.8530357906103075,
   'f1-score': 0.8064224573342549,
   'support': 44453.0},
  'good_rating': {'precision': 0.4210714285714286,
   'recall': 0.27351815334647955,
   'f1-score': 0.33162224878700514,
   'support': 8621.0},
  'accuracy': 0.6852840320910973,
  'macro avg': {'precision': 0.5002329412741152,
   'recall': 0.4550989336998814,
   'f1-score': 0.46988696029032245,
   'support': 61824.0},
  'weighted avg': {'precision': 0.6530913409406697,
   'recall': 0.6852840320910973,
   'f1-score': 0.664522748960474,
   'support': 6

In [12]:
trainer.save_model(output_dir = '../models/fine_tuned_DeBERTaV3')
tokenizer.save_pretrained('../models/fine_tuned_DeBERTaV3')

Saving model checkpoint to ../models/fine_tuned_DeBERTaV3
Configuration saved in ../models/fine_tuned_DeBERTaV3/config.json
Model weights saved in ../models/fine_tuned_DeBERTaV3/model.safetensors
tokenizer config file saved in ../models/fine_tuned_DeBERTaV3/tokenizer_config.json
Special tokens file saved in ../models/fine_tuned_DeBERTaV3/special_tokens_map.json


('../models/fine_tuned_DeBERTaV3/tokenizer_config.json',
 '../models/fine_tuned_DeBERTaV3/special_tokens_map.json',
 '../models/fine_tuned_DeBERTaV3/spm.model',
 '../models/fine_tuned_DeBERTaV3/added_tokens.json',
 '../models/fine_tuned_DeBERTaV3/tokenizer.json')

In [1]:
total = 8750.0 + 8621.0 + 44453.0
good = 8621.0
average = 44453.0
bad = 8750.0

f"{good/total:1.3F}  {average/total:1.3F}   {bad/total:1.3F}"

'0.139  0.719   0.142'