In [1]:
from datasets import load_dataset
from datetime import date
from loguru import logger

from setfit import TrainingArguments, Trainer, SetFitModel

In [2]:
# Notebook-wide configuration: base checkpoint and trainer hyperparameters.
TRAINING_MODEL = "sentence-transformers/paraphrase-mpnet-base-v2"
# NOTE(review): TEST_SIZE is not referenced by any cell visible here — confirm it is still needed.
TEST_SIZE = 0.15
Training_Arguments = TrainingArguments(
    batch_size=36,
    num_epochs=1,
    seed=123,
)

In [None]:
def data_load(num_samples_per_class=8, seed=42):
    """
    Loads the SetFit/SentEval-CR dataset from the Datasets library and builds
    a few-shot training split.

    Args:
        num_samples_per_class: number of training examples to draw for each
            label value (defaults to 8).
        seed: seed forwarded to the sampler so the drawn subset is repeatable
            (defaults to 42).

    Returns:
        A ``(train_ds, test_ds)`` tuple: the sampled training subset and the
        untouched "test" split.
    """
    # Imported locally so this cell works without editing the notebook's import cell.
    from setfit import sample_dataset

    dataset = load_dataset("SetFit/SentEval-CR")
    # A plain shuffle + select(range(16)) yields 16 random rows, which is not
    # the same as 8 rows per class; sample_dataset draws per label instead.
    # Assumes the label column is named "label" — verify against the dataset card.
    train_ds = sample_dataset(
        dataset["train"],
        label_column="label",
        num_samples=num_samples_per_class,
        seed=seed,
    )
    test_ds = dataset["test"]
    return train_ds, test_ds

In [None]:
def model_finetuning(MODEL, TrainingArguments, Train_Data, Val_Data):
    """
    Fine-tunes the specified model on the provided training data and evaluates it on the validation data.

    Args:
        MODEL: identifier handed to ``SetFitModel.from_pretrained``.
        TrainingArguments: training-arguments object passed to ``Trainer`` as
            ``args``. Note that inside this function the parameter name hides
            the ``TrainingArguments`` class imported at the top of the notebook.
        Train_Data: dataset passed to ``Trainer`` as ``train_dataset``.
        Val_Data: dataset passed to ``Trainer`` as ``eval_dataset``; it is what
            ``trainer.evaluate()`` scores against.

    Returns:
        A ``(trainer, metrics)`` tuple: the trained ``Trainer`` instance and the
        value returned by ``trainer.evaluate()``.
    """
    model = SetFitModel.from_pretrained(MODEL)
    trainer = Trainer(
        model=model,
        args=TrainingArguments,
        train_dataset=Train_Data,
        eval_dataset=Val_Data,
    )

    logger.info('fine-tuning the Setfit model on dataset')
    trainer.train()

    logger.info('saving the fine-tuned model')
    # The directory name only carries today's date (YYYYMMDD), so a second run
    # on the same day writes to the same directory.
    model_directory_timestamp = f'{date.today().strftime("%Y%m%d")}-reviews-text-classification'
    trainer.model.save_pretrained(model_directory_timestamp)

    metrics = trainer.evaluate()
    # The previous message had a stray leading quote and a stray ", " before
    # the metrics; loguru fills the "{}" placeholder from the extra argument.
    logger.info("Performance of fine-tuned model: {}", metrics)
    return trainer, metrics

In [None]:
def main():
    """
    Runs the whole pipeline: loads the two dataset splits, then fine-tunes
    and evaluates the configured model on them.

    Returns:
        The ``(trainer, metrics)`` tuple produced by ``model_finetuning``.
    """
    train_split, eval_split = data_load()
    return model_finetuning(
        TRAINING_MODEL,
        Training_Arguments,
        train_split,
        eval_split,
    )

In [None]:
# Entry point: run the pipeline and keep both results as notebook-level
# variables so later cells can inspect them.
if __name__ == "__main__":
    trainer, metrics = main()