# Text Summarization Project

<u>Notes</u>
* Dataset: https://huggingface.co/datasets/har1/MTS_Dialogue-Clinical_Note
* Fine-tuning model: bart-base
* NLI model: https://huggingface.co/FacebookAI/roberta-large-mnli
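* Teacher forcing is applied automatically during training: when `labels` are passed, BART builds `decoder_input_ids` by shifting the labels one position to the right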
* During inference, you should consider beam search and temperature
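
To make the temperature note concrete, here is a minimal sketch in plain Python (not the `transformers` implementation) of how temperature rescales next-token logits before sampling. The function name and the toy logits are made up for illustration.

```python
import math


def softmax_with_temperature(logits, temperature=1.0):
    """Turn logits into probabilities after dividing them by the temperature."""
    if temperature <= 0:
        raise ValueError("temperature must be positive")
    scaled = [x / temperature for x in logits]
    m = max(scaled)  # subtract the max for numerical stability
    exps = [math.exp(s - m) for s in scaled]
    total = sum(exps)
    return [e / total for e in exps]


toy_logits = [2.0, 1.0, 0.1]  # hypothetical scores for three candidate tokens
sharp = softmax_with_temperature(toy_logits, temperature=0.5)
neutral = softmax_with_temperature(toy_logits, temperature=1.0)
flat = softmax_with_temperature(toy_logits, temperature=2.0)
print(sharp, neutral, flat)
```

A temperature below 1 concentrates probability on the top token and a temperature above 1 flattens the distribution. In `generate()`, temperature only takes effect when sampling is enabled (`do_sample=True`), so it is a separate knob from beam search.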

<u>Action Items</u>

* <s>Apply a fine-tuned text summarization model to the dataset and see what comes out (done)</s>
* <s>Apply bart-base to the dataset and see what comes out (done)</s>
* You should understand how BART is pretrained, what the "checkpoint" you are loading with bart-base actually contains, and how BART-base differs from BART-large.
* Preprocessing - do you want to skip any samples? Also, I don't think they flagged training vs. validation, so you may want to do the splits yourself. Do you want to ignore classes with low representation?
* Is it possible to load multiple examples into a single example to take advantage of the max input length of the model?
* Which PEFT method should I use? Understand how to implement it in code, and understand how the technique itself works.
* Perhaps write your own training loop to experiment with before moving on to Trainer and TrainingArguments. I'm thinking of training and inferring on one sample to begin - basically, can the model learn and apply to the same example? Then gradually increase the number of samples.
    * You'll have to figure out which metrics to use for validation, and whether you should use model.generate(), and if so, whether you should play with beam search and temperature
* push_to_hub saves the final model state, I think (including if you load best model). Check out wandb.
* unsloth - quantization, fine-tuning, pruning

<u>Ayman's Suggestions</u>
* Compare other small models like flan-t5, gpt-2
* Try packing the samples for efficiency
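
One way to picture packing is as bin packing over token counts. The sketch below is a first-fit-decreasing pass in plain Python. `pack_by_length` and the toy lengths are hypothetical, and it only decides which samples could share one 1024-token window. It says nothing about keeping packed examples from attending to each other, which a seq2seq model like BART would still need.

```python
def pack_by_length(lengths, max_len=1024):
    """Group sample indices so each group's total token count fits in max_len."""
    # Visit the longest samples first (first-fit decreasing)
    order = sorted(range(len(lengths)), key=lambda i: lengths[i], reverse=True)
    bins = []  # each bin is [remaining_capacity, [sample indices]]
    for i in order:
        n = lengths[i]
        if n > max_len:
            raise ValueError(f"sample {i} has {n} tokens, more than max_len")
        for b in bins:
            if b[0] >= n:  # first bin with enough room
                b[0] -= n
                b[1].append(i)
                break
        else:
            bins.append([max_len - n, [i]])
    return [indices for _, indices in bins]


toy_lengths = [700, 300, 512, 512, 100, 24]  # made-up token counts
groups = pack_by_length(toy_lengths)
print(groups)
```

Each inner list holds the indices of samples that would be concatenated into one input. Whether this helps depends on how short the dialogues are relative to the 1024-token limit.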
 
<u>Questions</u>

* How useful is ROUGE? Oftentimes people report ROUGE as their validation metric, but it never seems to change much over epochs. And it's usually kind of low.
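
Part of the answer may be that ROUGE counts exact n-gram overlap, so a faithful paraphrase can still score low. The sketch below is a simplified ROUGE-1 F1 with whitespace tokenization. It is not the `evaluate` implementation, and the sentences are made up for illustration.

```python
from collections import Counter


def rouge1_f1(prediction, reference):
    """Unigram-overlap F1 between a prediction and a reference (simplified ROUGE-1)."""
    pred = Counter(prediction.lower().split())
    ref = Counter(reference.lower().split())
    overlap = sum((pred & ref).values())  # clipped unigram matches
    if overlap == 0:
        return 0.0
    precision = overlap / sum(pred.values())
    recall = overlap / sum(ref.values())
    return 2 * precision * recall / (precision + recall)


reference = "patient denies fever and chills"
exact_copy = "patient denies fever and chills"
paraphrase = "she reports no fever or chills"

print(rouge1_f1(exact_copy, reference))
print(rouge1_f1(paraphrase, reference))
```

The paraphrase carries the same meaning but shares only two unigrams with the reference, which is one reason to pair ROUGE with an embedding-based or entailment-based metric.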

In [None]:
from huggingface_hub import notebook_login
notebook_login()


In [None]:
# !pip install transformers
# !pip install datasets
# !pip install evaluate

In [110]:
from transformers import pipeline, AutoTokenizer, AutoModelForSeq2SeqLM, Trainer, TrainingArguments, EarlyStoppingCallback
from datasets import load_dataset
from peft import LoraConfig, TaskType, get_peft_model
import evaluate
import torch
from torch.utils.data import DataLoader
import numpy as np

# Get dataset and tokenize

In [12]:
model_name = 'facebook/bart-base'
# model_name = 'facebook/bart-large-cnn'

tokenizer = AutoTokenizer.from_pretrained(model_name)



In [13]:
# All the data is in the 'train' split, but the documentation says 1201 training samples, 100 validation samples

ds = load_dataset("har1/MTS_Dialogue-Clinical_Note", streaming=False)
ds

DatasetDict({
    train: Dataset({
        features: ['ID', 'section_header', 'section_text', 'dialogue'],
        num_rows: 1301
    })
})
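
Since everything arrives in a single `train` split, a seeded shuffle of row indices is one way to carve out a validation set. This is a minimal sketch in plain Python. `split_indices` is a hypothetical helper, the 1201/100 sizes come from the comment above, and the result will not necessarily match the dataset authors' original split.

```python
import random


def split_indices(num_rows, num_val, seed=0):
    """Return (train_idx, val_idx) from a reproducible shuffle of the row indices."""
    if not 0 < num_val < num_rows:
        raise ValueError("num_val must be between 1 and num_rows - 1")
    indices = list(range(num_rows))
    random.Random(seed).shuffle(indices)  # local RNG, global state is untouched
    return sorted(indices[num_val:]), sorted(indices[:num_val])


train_idx, val_idx = split_indices(num_rows=1301, num_val=100, seed=42)
print(len(train_idx), len(val_idx))
```

The index lists could then be passed to `ds['train'].select(...)`, the same method used later in this notebook for the single-sample subset. `Dataset.train_test_split` in `datasets` is another option.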

In [29]:
# Class to handle tokenization and dataloading

class TextPreprocessor:
    def __init__(self, tokenizer, model, dataset, input_col, output_col=None, padding="longest", truncation=True, batch_size=8):
        """
        Initialize the class with a tokenizer, model, dataset, and various tokenization options.
        - dataset - currently assumes this is a dataset dictionary with keys 'train', 'val', 'test'
        - input_col: The column name for the input sequences.
        - output_col: The column name for the output sequences (if applicable). If None, assume single sequence input.
        - padding: Padding strategy (e.g., "max_length", "longest", or a specific integer length).
        - truncation: Truncation strategy (boolean or "longest_first").
        """
        self.tokenizer = tokenizer
        self.model = model
        self.model_max_length = model.config.max_position_embeddings
        self.dataset = dataset
        self.input_col = input_col
        self.output_col = output_col
        self.padding = padding
        self.truncation = truncation
        self.batch_size = batch_size
        self.tokenized_dataset = {}
        self.truncation_tracker = {}
        self.dataloader = {}

    def tokenize_function(self, examples):
        # Tokenize a pair of sequences (input-output pair, e.g. for summarization)
        if self.output_col:
            tokenized = self.tokenizer(
                examples[self.input_col],
                text_target=examples[self.output_col],   # target sequence
                padding=self.padding,
                truncation=self.truncation,
                max_length=self.model_max_length,
                return_tensors='np',
                return_length=True,
                # return_overflowing_tokens=True
            )

        # Tokenize a single sequence
        else:
            tokenized = self.tokenizer(
                examples[self.input_col],
                padding=self.padding,
                truncation=self.truncation,
                max_length=self.model_max_length,
                return_tensors='np',
                return_length=True,
                # return_overflowing_tokens=True
            )

        # tokenized = dictionary
        # keys = input_ids, attention_mask, labels
        # size = batch size, sequence length (input_ids and attention_mask have the same length)

        # Check for truncation
        if self.truncation not in [False, 'do_not_truncate']:
            # Get non-truncated lengths
            input_lens = self.tokenizer(examples[self.input_col], truncation=False, return_tensors='np', return_length=True)['length']
            # print(type(input_lens))
            # print(input_lens.shape)
            # print(input_lens[:10])

            # Calculate how many tokens were truncated
            trunc_lens = np.clip(input_lens - tokenized['length'], a_min=0, a_max=None)
            trunc_lens_indexing = np.nonzero(trunc_lens > 0)[0]
            # print(type(trunc_lens))
            # print(trunc_lens.shape)
            # print(trunc_lens_indexing[:10])

            # Get IDs for non-zero truncations
            # print(type(examples['ID']))
            # trunc_ids = examples['ID'][trunc_lens_indexing]
            trunc_ids = [examples['ID'][i] for i in trunc_lens_indexing]
            trunc_lens = trunc_lens[trunc_lens > 0].tolist()

            self.truncation_tracker.update(dict(zip(trunc_ids, trunc_lens)))

        return tokenized

    def tokenize_dataset(self):
        """Applies tokenization to the dataset using the tokenize_function."""

        self.tokenized_dataset = self.dataset.map(self.tokenize_function, batched=True)
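        # 'labels' only exists when an output column was tokenized
        columns = ['input_ids', 'attention_mask'] + (['labels'] if self.output_col else [])
        self.tokenized_dataset.set_format(type='torch', columns=columns)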
        
        # for key in self.dataset:
        #     print(f'Tokenizing {key} split...')
        #     self.tokenized_dataset[key] = self.dataset[key].map(self.tokenize_function, batched=True)
        #     self.tokenized_dataset[key].set_format(type='torch', columns=['input_ids', 'attention_mask', 'labels'])

    # def create_dataloader(self, shuffle_train=True):
    #     """Creates a DataLoader from the tokenized dataset."""
    #     if self.tokenized_dataset == {}:
    #         raise ValueError("Tokenized dataset is not available. Run tokenize_dataset() first.")

    #     # Convert to PyTorch DataLoader for batching
    #     for key in self.tokenized_dataset:
    #         shuffle = shuffle_train if key == 'train' else False
    #         print(f'Creating {key} dataloader...')
    #         self.dataloader[key] = DataLoader(self.tokenized_dataset[key], batch_size=self.batch_size, shuffle=shuffle)

    # def get_dataloader(self):
    #     """Returns the DataLoader, creating it if necessary."""
    #     if self.dataloader == {}:
    #         self.create_dataloader()
    #     return self.dataloader

In [30]:
# `model` must be loaded before this cell runs (it is instantiated in the section below)
preprocessor = TextPreprocessor(tokenizer, model, ds, 'dialogue', output_col='section_text', padding='max_length')

preprocessor.tokenize_dataset()


In [31]:
print(type(preprocessor.tokenized_dataset['train']['input_ids']))
print(preprocessor.tokenized_dataset['train']['input_ids'].shape)
print(type(preprocessor.tokenized_dataset['train']['attention_mask']))
print(preprocessor.tokenized_dataset['train']['attention_mask'].shape)
print(type(preprocessor.tokenized_dataset['train']['labels']))
print(preprocessor.tokenized_dataset['train']['labels'].shape)

<class 'torch.Tensor'>
torch.Size([1301, 1024])
<class 'torch.Tensor'>
torch.Size([1301, 1024])
<class 'torch.Tensor'>
torch.Size([1301, 1024])
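
The shapes above show that `labels` is padded to 1024 just like the inputs. The cross-entropy loss used for training ignores only positions equal to -100, so pad tokens left in the labels would count toward the loss. Below is a minimal sketch in plain Python lists. `mask_label_padding` is a hypothetical helper, and the pad id of 1 is an assumption that should be checked against `tokenizer.pad_token_id`.

```python
IGNORE_INDEX = -100  # default ignore_index of torch.nn.CrossEntropyLoss


def mask_label_padding(label_rows, pad_token_id):
    """Replace padding ids in each label row so the loss skips those positions."""
    return [
        [IGNORE_INDEX if tok == pad_token_id else tok for tok in row]
        for row in label_rows
    ]


# Toy label rows; 1 is assumed to be the pad id
toy_labels = [[0, 713, 16, 2, 1, 1], [0, 100, 2, 1, 1, 1]]
masked = mask_label_padding(toy_labels, pad_token_id=1)
print(masked)
```

An alternative is to tokenize without padding and let `DataCollatorForSeq2Seq` pad each batch, since it pads labels with -100 by default. That would also fit the commented-out `data_collator` argument in the `Trainer` call.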


# Instantiate model and PEFT

In [82]:
model = AutoModelForSeq2SeqLM.from_pretrained(model_name)

In [None]:
# print(model)
# dir(model)
# model.config
# model.config.max_position_embeddings # maximum input sequence length

In [83]:
peft_config = LoraConfig(task_type=TaskType.SEQ_2_SEQ_LM, inference_mode=False, r=8, lora_alpha=32, lora_dropout=0.1)
model = get_peft_model(model, peft_config)
model.print_trainable_parameters()

trainable params: 442,368 || all params: 139,862,784 || trainable%: 0.3163
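
The trainable count printed above can be reproduced by hand. The sketch assumes LoRA is attached to the query and value projections of every attention block, which is consistent with the printed number but is not confirmed by the notebook itself. The layer count and hidden size are those of bart-base.

```python
d_model = 768                 # bart-base hidden size
rank = 8                      # r in the LoraConfig above
attention_blocks = 6 + 6 + 6  # encoder self-attn, decoder self-attn, decoder cross-attn
targets_per_block = 2         # assumption: q_proj and v_proj

# Each adapted projection adds A (rank x d_model) and B (d_model x rank)
params_per_target = rank * (d_model + d_model)
trainable = attention_blocks * targets_per_block * params_per_target

total = 139_862_784           # "all params" from the printout above
percent = round(100 * trainable / total, 4)
print(trainable, percent)
```

Doubling `r` would double the trainable count, because the adapter size is linear in the rank.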


# Training

In [117]:
output_dir = './output/'
logging_dir = './logs/'

ds_train = preprocessor.tokenized_dataset['train'].select(range(1))
ds_eval = preprocessor.tokenized_dataset['train'].select(range(1))

train_batch_size = 1
eval_batch_size = 1

training_args = TrainingArguments(
    # disable_tqdm=True,  # disable progress bar
    output_dir=output_dir,
    save_strategy='no',
    # save_strategy='epoch', # save model checkpoints at the end of each epoch
    # save_strategy='steps',
    # save_steps=0.1,
    # save_total_limit=2, # only save the last N checkpoints
    eval_strategy='epoch',  # run validation at the end of each epoch
    # eval_strategy='steps', # experimentation
    # eval_steps=1, # experimentation

    # load_best_model_at_end=True,  # load the best model found during training
    # metric_for_best_model="eval_loss",  # Metric to monitor for the best model
    # metric_for_best_model='accuracy', # this needs to match a key in the returned dictionary from compute_metrics()
    # greater_is_better=True, # higher value for metric_for_best_model is better

    # Logs report training and validation metrics, which may be visualized (including in real time) using TensorBoard.
    logging_dir=logging_dir,
    logging_strategy='no',
    # logging_strategy='epoch',
    # logging_strategy='steps',
    # logging_steps=0.02,
    # logging_steps=0.25, # 4 times per epoch
    # logging_steps=1,

    per_device_train_batch_size=train_batch_size, # batch size per device (CPU core/GPU/etc.)
    per_device_eval_batch_size=eval_batch_size, # batch size per device (CPU core/GPU/etc.)

    # Training hyperparameters
    num_train_epochs=1000, # max epochs
    # learning_rate=5e-8, # default is 5e-5
    # learning_rate=3e-5, # FYI: the default linear scheduler updates learning rate per batch
    learning_rate=5e-5,
    # warmup_ratio=0.2, # 20% of training steps
    # max_grad_norm=1.0,
    # weight_decay=0.01,
)

trainer = Trainer(
    model=model,
    args=training_args,
    # train_dataset=tokenized_datasets['train'],
    # eval_dataset=tokenized_datasets['test'],  # This is the validation dataset
    train_dataset=ds_train,
    eval_dataset=ds_eval,
    # tokenizer=tokenizer, # only used for generative tasks like translation, summarization, question-answering (to convert decoder predictions/logits back to words) or to work with metrics that require post-processing (BLEU, ROUGE)
    # data_collator=data_collator, # dynamic padding
    # callbacks=[EarlyStoppingCallback(early_stopping_patience=3)],  # Stop after 3 epochs without improvement
    # compute_metrics=compute_metrics,  # custom function to compute validation metrics
)

In [118]:
trainer.train()

Epoch,Training Loss,Validation Loss
1,No log,13.495691
2,No log,12.048134
3,No log,11.691601
4,No log,11.560511
5,No log,11.097082
6,No log,10.572636
7,No log,9.976832
8,No log,9.095772
9,No log,8.076225
10,No log,7.029835


KeyboardInterrupt: 

In [119]:
sample = {}
sample['input_ids'] = ds_train[0]['input_ids'].unsqueeze(dim=0)
sample['attention_mask'] = ds_train[0]['attention_mask'].unsqueeze(dim=0)
sample['labels'] = ds_train[0]['labels'].unsqueeze(dim=0)

In [120]:
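model.to('cpu')
# Pass only the encoder inputs; `labels` are not needed for generation
gen_inputs = {k: sample[k] for k in ('input_ids', 'attention_mask')}
# output = model.generate(**gen_inputs, num_beams=4, max_length=150, early_stopping=True)
output = model.generate(**gen_inputs, max_new_tokens=200)
decoded = tokenizer.decode(output[0], skip_special_tokens=True)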

print('Dialogue:\n')
print(ds_train['dialogue'][0] + '\n')
print('Ground truth:\n')
print(ds_train['section_text'][0] + '\n')
print('Prediction:\n')
print(decoded)

Dialogue:

Doctor: What brings you back into the clinic today, miss? 
Patient: I came in for a refill of my blood pressure medicine. 
Doctor: It looks like Doctor Kumar followed up with you last time regarding your hypertension, osteoarthritis, osteoporosis, hypothyroidism, allergic rhinitis and kidney stones.  Have you noticed any changes or do you have any concerns regarding these issues?  
Patient: No. 
Doctor: Have you had any fever or chills, cough, congestion, nausea, vomiting, chest pain, chest pressure?
Patient: No.  
Doctor: Great. Also, for our records, how old are you and what race do you identify yourself as?
Patient: I am seventy six years old and identify as a white female.

Ground truth:

Symptoms: no fever, no chills, no cough, no congestion, no nausea, no vomiting, no chest pain, no chest pressure.
Diagnosis: hypertension, osteoarthritis, osteoporosis, hypothyroidism, allergic rhinitis, kidney stones
History of Patient: 76-year-old white female, presents to the clinic 

In [121]:
# Baseline

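model_base = AutoModelForSeq2SeqLM.from_pretrained(model_name)
# Pass only the encoder inputs; `labels` are not needed for generation.
# Decoding settings differ from the fine-tuned run above (explicit beam search, max_length=150),
# so the two outputs are not strictly comparable.
output = model_base.generate(input_ids=sample['input_ids'], attention_mask=sample['attention_mask'], num_beams=4, max_length=150, early_stopping=True)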
decoded_base = tokenizer.decode(output[0], skip_special_tokens=True)

print('Dialogue:\n')
print(ds_train['dialogue'][0] + '\n')
print('Ground truth:\n')
print(ds_train['section_text'][0] + '\n')
print('Prediction:\n')
print(decoded_base)

Dialogue:

Doctor: What brings you back into the clinic today, miss? 
Patient: I came in for a refill of my blood pressure medicine. 
Doctor: It looks like Doctor Kumar followed up with you last time regarding your hypertension, osteoarthritis, osteoporosis, hypothyroidism, allergic rhinitis and kidney stones.  Have you noticed any changes or do you have any concerns regarding these issues?  
Patient: No. 
Doctor: Have you had any fever or chills, cough, congestion, nausea, vomiting, chest pain, chest pressure?
Patient: No.  
Doctor: Great. Also, for our records, how old are you and what race do you identify yourself as?
Patient: I am seventy six years old and identify as a white female.

Ground truth:

Symptoms: no fever, no chills, no cough, no congestion, no nausea, no vomiting, no chest pain, no chest pressure.
Diagnosis: hypertension, osteoarthritis, osteoporosis, hypothyroidism, allergic rhinitis, kidney stones
History of Patient: 76-year-old white female, presents to the clinic 

# Evaluation

In [111]:
# nli = pipeline("text-classification", model="FacebookAI/roberta-large-mnli")
bertscore = evaluate.load('bertscore')

In [122]:
r = ds_train['section_text'][0]
p = decoded

In [108]:
print(f"{r} </s> {p}")
results = nli(f"{r} </s> {p}", truncation=True, top_k=None)
results

Symptoms: no fever, no chills, no cough, no congestion, no nausea, no vomiting, no chest pain, no chest pressure.
Diagnosis: hypertension, osteoarthritis, osteoporosis, hypothyroidism, allergic rhinitis, kidney stones
History of Patient: 76-year-old white female, presents to the clinic today originally for hypertension and a med check, followed by Dr. Kumar, issues stable
Plan of Action: N/A </s> Doctor: What brings you back into the clinic today, miss?   Patient: I came in for a refill of my blood pressure medicine.  Doctor: It looks like Doctor Kumar followed up with you last time regarding your hypertension, osteoarthritis, osteoporosis, hypothyroidism, allergic rhinitis and kidney stones.  Have you noticed any changes or do you have any concerns regarding these issues? _______________________________________________Doctor: Yes, I have noticed some changes in your blood pressure, but I am not sure if any of these are related to your hypertension. ____________________________________

[{'label': 'NEUTRAL', 'score': 0.7339072823524475},
 {'label': 'ENTAILMENT', 'score': 0.15110178291797638},
 {'label': 'CONTRADICTION', 'score': 0.1149909570813179}]

In [109]:
del nli

In [123]:
results = bertscore.compute(predictions=[p], references=[r], model_type='bert-base-uncased')
results

{'precision': [0.9952623248100281],
 'recall': [0.9955103993415833],
 'f1': [0.9953863620758057],
 'hashcode': 'bert-base-uncased_L9_no-idf_version=0.3.12(hug_trans=4.41.2)'}

In [124]:
p = decoded_base
results = bertscore.compute(predictions=[p], references=[r], model_type='bert-base-uncased')
results

{'precision': [0.39993295073509216],
 'recall': [0.6002448201179504],
 'f1': [0.48003003001213074],
 'hashcode': 'bert-base-uncased_L9_no-idf_version=0.3.12(hug_trans=4.41.2)'}
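
As a sanity check on the two results above, the reported `f1` should be the harmonic mean of the reported `precision` and `recall`. The sketch below recomputes it from the printed values. `f1_from_pr` is a hypothetical helper name.

```python
def f1_from_pr(precision, recall):
    """Harmonic mean of precision and recall."""
    if precision + recall == 0:
        return 0.0
    return 2 * precision * recall / (precision + recall)


# Values copied from the two bertscore outputs above
finetuned_f1 = f1_from_pr(0.9952623248100281, 0.9955103993415833)
baseline_f1 = f1_from_pr(0.39993295073509216, 0.6002448201179504)
print(finetuned_f1, baseline_f1)
```

Both values agree with the printed `f1` to within rounding. Keep in mind that each score comes from a single sample that was also the only training example, so the gap says nothing yet about generalization.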

In [None]:
# del bertscore

# Early Notes
* This is the question you should ask yourself before anything else: what do I want the model to do, and why? Once I figure that out, is there a model out there that already does it? If not, that's when you start thinking about fine-tuning.
* Flexible model: instruction tuning, e.g. "Summarize this into one sentence, summarize this dialogue, etc."
* google/pegasus-xsum and facebook/bart-large-xsum are already pretty good for single-sentence summaries
* facebook/bart-large-cnn and google/pegasus-large - I'm assuming these are pretty good for summaries in general
* You could try to find a dataset that these models don't work well on: https://huggingface.co/datasets/argilla/FinePersonas-Conversations-Email-Summaries?row=18
* Dialog summarization: https://huggingface.co/datasets/knkarthick/dialogsum?row=0
* Question to answer: even the fine-tuned model doesn't summarize the document in one sentence. So what exactly is the objective of my fine-tuning? Remember, I want to change the behavior of my model.
    * **Is it to shorten the summary to one sentence? This might be useful. You could see if you can apply the length objective function.**
    * Is it to summarize in a specific format? Can I find a dataset in that format? Or do I have to spend money to get GPT to do it for me?
    * Is it to get better entailment? For this, I'd need to feed in the original document as the reference, and I don't even know if my entailment model could handle that size.
    * Personally I think bart-large-cnn does a better job of summarization than distilbart-cnn-12-6 (one sample). Is there some way to measure this?
    * I'm not even sure the dataset summary is particularly good (one sample)

# Old Stuff

In [None]:
print(type(preprocessor.dataloader)) # dictionary. Contains 'train', 'val', 'test'.
print(type(preprocessor.dataloader['train'])) # dataloader


first_batch = next(iter(preprocessor.dataloader['train']))
print(type(first_batch)) # dictionary
print(first_batch.keys())
print(type(first_batch['input_ids']))
print(first_batch['input_ids'].shape)
print(type(first_batch['attention_mask']))
print(first_batch['attention_mask'].shape)
print(type(first_batch['labels']))
print(first_batch['labels'].shape)

<class 'dict'>
<class 'torch.utils.data.dataloader.DataLoader'>
<class 'dict'>
dict_keys(['input_ids', 'attention_mask', 'labels'])
<class 'torch.Tensor'>
torch.Size([8, 1024])
<class 'torch.Tensor'>
torch.Size([8, 1024])
<class 'torch.Tensor'>
torch.Size([8, 1024])


In [None]:
# Example code from https://huggingface.co/docs/datasets/en/use_with_pytorch


import numpy as np
from datasets import Dataset
from torch.utils.data import DataLoader

# Data and labels are scalars, so when you batch with size 4, each batch consists of tensors of length 4
# data = np.random.rand(16)
# label = np.random.randint(0, 2, size=16)


# Modify data and label to be vectors of length 7 (16 total samples)
# Now, when I batch with size 4, each batch consists of 4 tensors (samples) of length 7. So it's working correctly.
data = np.random.rand(16,7)
label = np.random.randint(0, 2, size=(16,7))
ds = Dataset.from_dict({"data": data, "label": label}).with_format("torch")
dataloader = DataLoader(ds, batch_size=4)
for batch in dataloader:
    print(batch)

{'data': tensor([[0.1669, 0.0426, 0.8950, 0.5661, 0.5331, 0.0896, 0.0655],
        [0.4167, 0.6302, 0.8909, 0.7454, 0.2768, 0.5057, 0.4715],
        [0.7749, 0.5724, 0.7707, 0.8365, 0.5864, 0.0931, 0.3090],
        [0.9752, 0.2727, 0.1433, 0.6041, 0.2172, 0.5122, 0.3490]]), 'label': tensor([[1, 1, 0, 1, 0, 0, 1],
        [1, 0, 0, 0, 1, 0, 0],
        [0, 0, 0, 0, 1, 0, 1],
        [1, 1, 0, 0, 0, 1, 0]])}
{'data': tensor([[0.0481, 0.6371, 0.2640, 0.8596, 0.1547, 0.9490, 0.6544],
        [0.2074, 0.4172, 0.1228, 0.5960, 0.4128, 0.8995, 0.5214],
        [0.5998, 0.7333, 0.1849, 0.4986, 0.0315, 0.0149, 0.2415],
        [0.2617, 0.9657, 0.2709, 0.6876, 0.7811, 0.2927, 0.6612]]), 'label': tensor([[0, 0, 1, 1, 1, 0, 0],
        [1, 1, 0, 1, 0, 0, 0],
        [1, 0, 0, 0, 0, 0, 1],
        [1, 0, 0, 0, 0, 0, 1]])}
{'data': tensor([[0.3727, 0.3212, 0.9347, 0.9349, 0.3552, 0.0577, 0.5582],
        [0.4870, 0.5731, 0.8875, 0.2646, 0.7427, 0.6187, 0.8907],
        [0.3159, 0.7686, 0.6664, 0.0537

In [None]:
label = np.random.randint(0, 2, size=(16,7))
label = np.ones((16,7))

In [None]:
label.shape

(16, 7)

In [None]:
np.nonzero(label)

(array([ 0,  0,  0,  0,  0,  0,  0,  1,  1,  1,  1,  1,  1,  1,  2,  2,  2,
         2,  2,  2,  2,  3,  3,  3,  3,  3,  3,  3,  4,  4,  4,  4,  4,  4,
         4,  5,  5,  5,  5,  5,  5,  5,  6,  6,  6,  6,  6,  6,  6,  7,  7,
         7,  7,  7,  7,  7,  8,  8,  8,  8,  8,  8,  8,  9,  9,  9,  9,  9,
         9,  9, 10, 10, 10, 10, 10, 10, 10, 11, 11, 11, 11, 11, 11, 11, 12,
        12, 12, 12, 12, 12, 12, 13, 13, 13, 13, 13, 13, 13, 14, 14, 14, 14,
        14, 14, 14, 15, 15, 15, 15, 15, 15, 15]),
 array([0, 1, 2, 3, 4, 5, 6, 0, 1, 2, 3, 4, 5, 6, 0, 1, 2, 3, 4, 5, 6, 0,
        1, 2, 3, 4, 5, 6, 0, 1, 2, 3, 4, 5, 6, 0, 1, 2, 3, 4, 5, 6, 0, 1,
        2, 3, 4, 5, 6, 0, 1, 2, 3, 4, 5, 6, 0, 1, 2, 3, 4, 5, 6, 0, 1, 2,
        3, 4, 5, 6, 0, 1, 2, 3, 4, 5, 6, 0, 1, 2, 3, 4, 5, 6, 0, 1, 2, 3,
        4, 5, 6, 0, 1, 2, 3, 4, 5, 6, 0, 1, 2, 3, 4, 5, 6, 0, 1, 2, 3, 4,
        5, 6]))

In [None]:
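def compute_metrics(eval_pred):
    # Placeholder: compute a real score from eval_pred here.
    # The value must be a number, because Trainer compares it across evaluations
    # when metric_for_best_model points at this key.
    return {'my_custom_metric': 0.0}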

In [None]:
training_args = TrainingArguments(
    # Save model checkpoints. # For logging and saving model checkpoints, each 'step' refers to one update of the model's weights.
    # By default, this is once per batch of data, but this isn't necessarily true with gradient accumulation.
    # Typically, we only run validation once per epoch, so there's not much benefit to saving checkpoints more than once per epoch.
    # The only time it's beneficial is if your code is prone to crashing - saving more often means you lose less progress when this happens.
    output_dir='./train_results',
    save_strategy='epoch', # save model checkpoints at the end of each epoch
    eval_strategy='epoch',  # run validation at the end of each epoch

    # Loading the best model state based on the desired metric
    load_best_model_at_end=True,  # load the best model found during training
    # metric_for_best_model="eval_loss",  # Metric to monitor for the best model
    metric_for_best_model='my_custom_metric', # this needs to match a key in the returned dictionary from compute_metrics()
    greater_is_better=True, # higher value for 'my_custom_metric' is better

    # Logs report training and validation metrics, which may be visualized (including in real time) using TensorBoard.
    logging_dir='./train_logs',
    logging_strategy='steps',
    logging_steps=100, # frequency of logging - you need to consider your dataset size and your batch size/gradient accumulation

    per_device_train_batch_size=8, # batch size per device (CPU core/GPU/etc.)
    per_device_eval_batch_size=8, # batch size per device (CPU core/GPU/etc.)

    # Training hyperparameters
    num_train_epochs=10, # max epochs
    weight_decay=0.01,
)

trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=train_dataset,
    eval_dataset=eval_dataset,  # This is the validation dataset
    callbacks=[EarlyStoppingCallback(early_stopping_patience=3)],  # Stop after 3 epochs without improvement
    compute_metrics=compute_metrics,  # custom function to compute validation metrics
)

trainer.train()
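
To choose `logging_steps`, it helps to estimate how many optimizer steps one epoch contains. The sketch below is rough arithmetic only. It assumes a single device, uses the 1201 training samples mentioned in the dataset note, and approximates how the library rounds a partial final accumulation group. `steps_per_epoch` is a hypothetical helper.

```python
import math


def steps_per_epoch(num_samples, per_device_batch_size, grad_accum_steps=1, num_devices=1):
    """Approximate number of weight updates in one pass over the data."""
    batches = math.ceil(num_samples / (per_device_batch_size * num_devices))
    return max(1, math.ceil(batches / grad_accum_steps))


steps = steps_per_epoch(num_samples=1201, per_device_batch_size=8)
steps_accum = steps_per_epoch(num_samples=1201, per_device_batch_size=8, grad_accum_steps=4)
print(steps, steps_accum)
```

With `logging_steps=100` and batch size 8, that is at most two log entries per epoch, and with gradient accumulation of 4 a full epoch can pass without any log entry.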