In [2]:
import pathlib
import shutil

import numpy
from datasets import load_from_disk
from evaluate import load
from peft import LoraConfig, PeftType, TaskType
from sklearn.metrics import accuracy_score, f1_score, fbeta_score, precision_score, recall_score
from torch import Tensor, float16
from transformers import (
    AutoModelForCausalLM,
    AutoTokenizer,
    BitsAndBytesConfig,
    EarlyStoppingCallback,
    EvalPrediction,
    SchedulerType,
    TrainingArguments,
)
from trl import DataCollatorForCompletionOnlyLM, SFTTrainer

In [None]:
# Root folder of this pipeline step; the three folders below are its children.
step_identifier = pathlib.Path("step_2")

# Where archives are read from, where scratch data lives, and where results are exported.
input_directory = step_identifier / "input_directory"
working_directory = step_identifier / "working_directory"
output_directory = step_identifier / "output_directory"

In [None]:
# Zipped dataset supplied as an input, and the folder it is loaded from once extracted.
hugging_face_dataset_archive = input_directory / "hugging_face_dataset_archive.zip"
hugging_face_dataset_path = working_directory / "hugging_face_dataset_directory"

In [None]:
# Extract the dataset archive into the working directory so it can be loaded from disk.
shutil.unpack_archive(hugging_face_dataset_archive, extract_dir=working_directory)

In [None]:
hugging_face_dataset = load_from_disk(hugging_face_dataset_path)

# Pull out the two splits handed to the trainer for tuning and validation.
train_subset, validation_subset = (
    hugging_face_dataset[split_name] for split_name in ("train", "validation")
)

In [None]:
# Hub identifier of the pretrained checkpoint used for both the model and the tokeniser.
base_model_identifier = "facebook/opt-350m"

# Label value passed to the completion-only collator as its ignore index and swapped
# for the padding token before decoding in the validation metrics.
mask_token_index = -100

# 4-bit NF4 loading with double quantisation and float16 compute.
quantisation_configuration = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_use_double_quant=True,
    bnb_4bit_compute_dtype=float16,
)

In [None]:
# Training checkpoints: scratch folder and the archive exported from it.
tuning_checkpoints_path = working_directory / "tuning_checkpoints_directory"
tuning_checkpoints_archive = output_directory / "tuning_checkpoints_archive.zip"

# Final tuned adapter: scratch folder and the archive exported from it.
tuned_adapter_path = working_directory / "tuned_adapter_directory"
tuned_adapter_archive = output_directory / "tuned_adapter_archive.zip"

In [None]:
# Load the base checkpoint with the quantisation settings defined earlier.
model = AutoModelForCausalLM.from_pretrained(
    base_model_identifier,
    device_map="auto",
    low_cpu_mem_usage=True,
    quantization_config=quantisation_configuration,
)

# Two plain attribute overrides on the loaded model's configuration object.
model.config.pretraining_tp = 1
model.config.use_cache = False

In [None]:
tokeniser = AutoTokenizer.from_pretrained(base_model_identifier)

# Pad on the right, reusing the end-of-sequence token as the padding token.
tokeniser.padding_side = "right"
tokeniser.pad_token = tokeniser.eos_token

In [None]:
# LoRA adapter settings for causal language modelling.
peft_configuration = LoraConfig(
    task_type=TaskType.CAUSAL_LM,
    peft_type=PeftType.LORA,
    r=8,
    lora_alpha=16,
    use_rslora=True,
    lora_dropout=0.1,
    bias="none",
)

In [None]:
training_configuration = TrainingArguments(
    # Checkpoint location and retention.
    output_dir=str(tuning_checkpoints_path),
    overwrite_output_dir=True,
    save_strategy="epoch",
    save_total_limit=5,
    save_safetensors=True,
    save_only_model=True,
    push_to_hub=False,
    # Optimisation schedule and batching.
    num_train_epochs=50,
    learning_rate=1e-4,
    lr_scheduler_type=SchedulerType.REDUCE_ON_PLATEAU,
    warmup_ratio=0.03,
    weight_decay=0.001,
    max_grad_norm=0.3,
    optim="paged_adamw_32bit",
    gradient_accumulation_steps=1,
    auto_find_batch_size=True,
    group_by_length=True,
    # Evaluation and best-checkpoint selection.
    evaluation_strategy="epoch",
    eval_delay=3,
    load_best_model_at_end=True,
    metric_for_best_model="eval_google_bleu",
    greater_is_better=True,
    # Precision and device.
    use_cpu=False,
    fp16=True,
    fp16_full_eval=False,
    half_precision_backend="auto",
    # Reproducibility.
    seed=0,
    data_seed=0,
    # Logging and reporting.
    log_level="error",
    logging_strategy="epoch",
    report_to=["none"],
    skip_memory_metrics=True,
)

In [None]:
# Section markers that prefix the context, question and answer lines of a formatted example.
context_template, question_template, answer_template = (
    f" ### {section_name}:" for section_name in ("Context", "Question", "Answer")
)

In [None]:
def format_inputs(examples: dict[str, list[str]]) -> list[str]:
    """Render a batch of examples as training texts.

    Args:
        examples: Column-oriented batch holding equally long lists under the keys
            ``"context"``, ``"question"`` and ``"answer"``.

    Returns:
        One string per row of the batch: the context, question and answer lines,
        each prefixed by its section marker and joined with newlines.

    Raises:
        KeyError: If one of the three columns is missing.
        ValueError: If the columns do not all have the same length.
    """
    # Iterate over the rows of the batch. ``len(examples)`` would count the columns
    # of the mapping, not the examples, and so format the wrong number of rows.
    return [
        "\n".join(
            [
                f"{context_template} {context}",
                f"{question_template} {question}",
                f"{answer_template} {answer}",
            ]
        )
        for context, question, answer in zip(
            examples["context"], examples["question"], examples["answer"], strict=True
        )
    ]

In [None]:
# The answer marker is encoded as it appears inside a formatted example, i.e. preceded
# by the newline that separates it from the question line.
response_template_with_context = f"\n{answer_template}"

# Only the token ids from index 2 onward are kept as the response template.
# NOTE(review): this presumes the first two ids belong to the leading context rather
# than to the marker itself — confirm against this tokeniser's actual output.
response_template_token_indices = tokeniser.encode(
    response_template_with_context, add_special_tokens=False
)[2:]

collator = DataCollatorForCompletionOnlyLM(
    response_template=response_template_token_indices,
    tokenizer=tokeniser,
    ignore_index=mask_token_index,
)

In [None]:
# Text-overlap metrics used by the validation callback, loaded in a fixed order.
bleu_metric, google_bleu_metric, rouge_metric = (
    load(metric_name, module_type="metric")
    for metric_name in ("bleu", "google_bleu", "rouge")
)

In [None]:
def get_positions_of_most_likely_token(logits: Tensor, labels: Tensor | None) -> Tensor:
    del labels

    if isinstance(logits, tuple):
        logits = logits[0]

    return logits.argmax(dim=-1)

In [None]:
def calculate_multi_class_classification_metrics(
    y_true: numpy.ndarray, y_pred: numpy.ndarray
) -> dict[str, float]:
    """Compute micro-averaged classification scores for predicted versus true labels.

    Args:
        y_true: Reference labels.
        y_pred: Predicted labels, aligned element-wise with ``y_true``.

    Returns:
        A mapping with the keys ``accuracy``, ``precision``, ``recall``,
        ``f1_balanced``, ``f1_precision`` (F-beta with beta 0.5) and ``f1_recall``
        (F-beta with beta 2). ``accuracy`` is requested with ``normalize=False``,
        which scikit-learn documents as the number of matches rather than a fraction.
    """
    # Options shared by every averaged score below.
    micro_options = {"average": "micro", "zero_division": 1}

    return {
        "accuracy": accuracy_score(y_true, y_pred, normalize=False),
        "precision": precision_score(y_true, y_pred, **micro_options),
        "recall": recall_score(y_true, y_pred, **micro_options),
        "f1_balanced": f1_score(y_true, y_pred, **micro_options),
        "f1_precision": fbeta_score(y_true, y_pred, beta=0.5, **micro_options),
        "f1_recall": fbeta_score(y_true, y_pred, beta=2, **micro_options),
    }

In [None]:
def track_validation_metrics(validation_outputs: EvalPrediction) -> dict[str, float]:
    """Score validation predictions against the labelled tokens of each example.

    The predictions are per-position token ids and the labels use
    ``mask_token_index`` for every position that must not be scored. A causal
    language model's output at position ``i`` is its guess for the token at
    position ``i + 1``, so predictions and labels are shifted by one before being
    compared. Only positions that carry a real label are decoded and scored.

    Args:
        validation_outputs: Predictions (an array, or a tuple whose first element is
            the array) and label ids, both indexed as ``[example, position]``.

    Returns:
        The merged outputs of the BLEU, Google BLEU and ROUGE metrics on the decoded
        texts, plus the token-level classification scores over the labelled positions.
    """
    predictions = validation_outputs.predictions
    labels = validation_outputs.label_ids

    if isinstance(predictions, tuple):
        predictions = predictions[0]

    # Align every prediction with the token it is a guess for.
    predictions = predictions[:, :-1]
    labels = labels[:, 1:]

    # Positions without a real label are excluded from all metrics.
    scored_positions = labels != mask_token_index

    # Unscored positions (and any mask value inside the predictions) become padding,
    # which ``skip_special_tokens`` removes when decoding.
    predictions = numpy.where(
        scored_positions & (predictions != mask_token_index),
        predictions,
        tokeniser.pad_token_id,
    )
    labels = numpy.where(scored_positions, labels, tokeniser.pad_token_id)

    decoded_predictions = tokeniser.batch_decode(predictions, skip_special_tokens=True)
    decoded_labels = tokeniser.batch_decode(labels, skip_special_tokens=True)

    bleu_score = bleu_metric.compute(predictions=decoded_predictions, references=decoded_labels)
    google_bleu_score = google_bleu_metric.compute(
        predictions=decoded_predictions, references=decoded_labels
    )
    rouge_score = rouge_metric.compute(predictions=decoded_predictions, references=decoded_labels)

    # Boolean indexing keeps only the labelled tokens, so padding-versus-padding
    # matches cannot inflate the token-level scores.
    classification_scores = calculate_multi_class_classification_metrics(
        labels[scored_positions], predictions[scored_positions]
    )

    return {**bleu_score, **google_bleu_score, **rouge_score, **classification_scores}

In [None]:
# Stop after 10 evaluations without an improvement larger than the threshold.
early_stopping_callback = EarlyStoppingCallback(
    early_stopping_threshold=1e-6,
    early_stopping_patience=10,
)

In [None]:
supervised_trainer = SFTTrainer(
    # Model, adapter and run configuration.
    model=model,
    peft_config=peft_configuration,
    args=training_configuration,
    # Data and how it is turned into training text.
    train_dataset=train_subset,
    eval_dataset=validation_subset,
    formatting_func=format_inputs,
    tokenizer=tokeniser,
    data_collator=collator,
    packing=False,
    max_seq_length=512,
    # Validation-time scoring and stopping.
    preprocess_logits_for_metrics=get_positions_of_most_likely_token,
    compute_metrics=track_validation_metrics,
    callbacks=[early_stopping_callback],
)

In [None]:
# Run the fine-tuning loop configured on the trainer; the call's result is the cell output.
supervised_trainer.train()

In [None]:
# Write the trainer's model to the adapter directory with safe serialisation enabled.
supervised_trainer.model.save_pretrained(tuned_adapter_path, safe_serialization=True)

In [None]:
# Pack each exported working sub-folder into its archive under the output directory.
for exported_path, archive_target in (
    (tuned_adapter_path, tuned_adapter_archive),
    (tuning_checkpoints_path, tuning_checkpoints_archive),
):
    _ = shutil.make_archive(
        # ``make_archive`` appends the extension itself, so pass the path without it.
        str(archive_target.with_suffix("")),
        # Archive format is the extension without its leading dot (e.g. "zip").
        archive_target.suffix[1:],
        root_dir=working_directory,
        # ``name`` rather than ``stem``: a folder name is not a file name, and ``stem``
        # would truncate a folder whose name happens to contain a dot.
        base_dir=exported_path.name,
    )