In [None]:
import torch
import matplotlib.pyplot as plt
from transformers import (
    AutoTokenizer,
    AutoModelForCausalLM,
    TrainingArguments,
    Trainer,
    TrainerCallback
)
from peft import LoraConfig, get_peft_model
from datasets import load_dataset
import os

In [None]:
# Paths to fill in before running the notebook:
#   model_path  - passed to `from_pretrained` for both tokenizer and model.
#   data_path   - JSON data file(s) handed to `load_dataset`.
#   output_path - directory that receives the adapter weights and loss plot.
model_path = r""
data_path = r""
output_path = r""

# Training is GPU-only. Raise explicitly instead of using `assert`, because
# assert statements are removed when Python runs with -O.
if not torch.cuda.is_available():
    raise RuntimeError("Use GPU!")
device = torch.device("cuda")

In [None]:
class LossCallback(TrainerCallback):
    """Collect the training loss every time the Trainer emits a log entry.

    Attributes:
        losses: list of the "loss" values seen so far, in logging order.
    """

    def __init__(self):
        self.losses = []

    def on_log(self, args, state, control, logs=None, **kwargs):
        """Append ``logs["loss"]`` to ``self.losses`` when the entry has one.

        ``logs`` defaults to None in the signature, so guard against it
        before the membership test; entries without a "loss" key are ignored.
        """
        if logs and "loss" in logs:
            self.losses.append(logs["loss"])

In [None]:
def process_data(tokenizer, max_length=512):
    """Load the first 1500 training records and tokenize them for causal LM.

    Each record's 'Question', 'Complex_CoT' and 'Response' fields are merged
    into one prompt string, terminated with the tokenizer's EOS token, then
    padded/truncated to ``max_length`` tokens.

    Args:
        tokenizer: callable tokenizer; must have a pad token configured
            because ``padding="max_length"`` is used.
        max_length: fixed token length of every example (default 512).

    Returns:
        The mapped dataset containing only "input_ids" and "attention_mask".
    """
    dataset = load_dataset("json", data_files=data_path, split="train[:1500]")

    # Use the EOS marker the tokenizer actually defines; a hard-coded
    # "<|endoftext|>" is tokenized as ordinary text by vocabularies that
    # use a different EOS symbol. Keep the literal as a fallback.
    eos_token = getattr(tokenizer, "eos_token", None) or "<|endoftext|>"

    def format_example(example):
        instruction = f"Question: {example['Question']}\nAnalysis: {example['Complex_CoT']}"
        text = f"{instruction}\n### Answer: \n{example['Response']}{eos_token}"
        # No return_tensors: the dataset stores plain lists anyway, so
        # building and squeezing tensors per example is wasted work.
        encoded = tokenizer(
            text,
            padding="max_length",
            truncation=True,
            max_length=max_length,
        )
        return {
            "input_ids": encoded["input_ids"],
            "attention_mask": encoded["attention_mask"],
        }

    return dataset.map(format_example, remove_columns=dataset.column_names)

In [None]:
# LoRA adapter settings used when wrapping the base model.
peft_config = LoraConfig(
    task_type="CAUSAL_LM",                # task type: causal language model.
    target_modules=["q_proj", "v_proj"],  # module names LoRA is applied to.
    r=16,                                 # rank of the low-rank decomposition.
    lora_alpha=32,                        # scaling factor.
    lora_dropout=0.05,                    # dropout rate.
    bias="none",
)

In [None]:
# Training arguments.
# The redundant `no_cuda=False` was dropped: it only restated the default and
# the option is deprecated in favour of `use_cpu`.
training_args = TrainingArguments(
    output_dir=output_path,
    per_device_train_batch_size=2,  # small per-device batch (memory limited).
    gradient_accumulation_steps=4,  # 2 x 4 = effective batch of 8 per device.
    num_train_epochs=3,
    learning_rate=3e-4,
    fp16=True,  # fp16 mixed-precision training.
    logging_steps=20,  # log the loss every 20 optimizer steps.
    save_strategy="no",  # no intermediate checkpoints are written.
    report_to="none",  # no external experiment trackers.
    optim="adamw_torch",
    dataloader_pin_memory=False,  # pinned host memory is turned OFF.
    remove_unused_columns=False  # keep all dataset columns for the collator.
)

In [None]:
def _collate_batch(data):
    """Stack tokenized examples into one batch and derive the loss labels.

    Args:
        data: list of examples, each holding equal-length "input_ids" and
            "attention_mask" sequences.

    Returns:
        dict with "input_ids", "attention_mask" and "labels" tensors. Labels
        copy the input ids, with every padded position (attention mask 0)
        set to -100 so the loss skips padding. Masking by attention mask
        rather than by pad id keeps the real EOS token trainable even though
        pad and EOS share an id.
    """
    input_ids = torch.stack([torch.as_tensor(d["input_ids"]) for d in data])
    attention_mask = torch.stack([torch.as_tensor(d["attention_mask"]) for d in data])
    labels = input_ids.clone()
    labels[attention_mask == 0] = -100
    # Tensors stay on the CPU: Trainer moves each batch to its device, and
    # CPU tensors remain compatible with pinned memory / loader workers.
    return {
        "input_ids": input_ids,
        "attention_mask": attention_mask,
        "labels": labels,
    }


def _save_loss_curve(losses, logging_steps, out_dir):
    """Plot the recorded losses against training steps and save a PNG."""
    # One loss is recorded per log event, i.e. every `logging_steps` steps
    # (assumes the step-based logging configured in training_args).
    steps = [logging_steps * (i + 1) for i in range(len(losses))]
    fig, ax = plt.subplots(figsize=(10, 6))
    ax.plot(steps, losses)
    ax.set_title("Training Loss Curve")
    ax.set_xlabel("Steps")
    ax.set_ylabel("Loss")
    fig.savefig(os.path.join(out_dir, "loss_curve.png"))
    print("Loss curve saved to loss_curve.png")


def main():
    """Fine-tune the base model with LoRA, save the adapter and a loss plot.

    Reads the notebook-level settings `model_path`, `output_path`, `device`,
    `peft_config` and `training_args`, and tokenizes data via `process_data`.
    """
    # create output path.
    os.makedirs(output_path, exist_ok=True)

    # load tokenizer; padding reuses the EOS token.
    tokenizer = AutoTokenizer.from_pretrained(model_path)
    tokenizer.pad_token = tokenizer.eos_token

    # load the base model in half precision and attach the LoRA adapters.
    model = AutoModelForCausalLM.from_pretrained(
        model_path,
        torch_dtype=torch.float16,
        device_map={"": device}
    )
    model = get_peft_model(model, peft_config)

    # fp16 mixed precision needs fp32 trainable weights; fp16 trainable
    # parameters fail with "Attempting to unscale FP16 gradients". This is a
    # no-op when the adapter weights are already float32.
    for param in model.parameters():
        if param.requires_grad:
            param.data = param.data.to(torch.float32)
    model.print_trainable_parameters()

    # preprocess data.
    dataset = process_data(tokenizer)

    # loss callback.
    loss_callback = LossCallback()

    # create trainer.
    trainer = Trainer(
        model=model,
        args=training_args,
        train_dataset=dataset,
        data_collator=_collate_batch,
        callbacks=[loss_callback]
    )

    # start training.
    print("Start training...")
    trainer.train()

    # save model.
    trainer.model.save_pretrained(output_path)
    print(f"Model saved to {output_path}")

    # plot loss curve.
    _save_loss_curve(loss_callback.losses, training_args.logging_steps, output_path)

if __name__ == "__main__":
    main()