# In [1]:
#Using PEFT
from peft import prepare_model_for_kbit_training

# Gradient (activation) checkpointing: activations are recomputed during the
# backward pass instead of being kept in memory, trading compute for memory.
model.gradient_checkpointing_enable()

# Ready the k-bit (quantized) model for adapter training. The keyword below is
# PEFT's default, written out so the intent is visible at the call site; see
# the PEFT docs for exactly what this helper freezes and casts.
model = prepare_model_for_kbit_training(model, use_gradient_checkpointing=True)

def _param_count(param):
    """Return the logical number of values held by ``param``.

    bitsandbytes ``Params4bit`` tensors pack two 4-bit values into each byte
    of their storage dtype, so ``numel()`` under-reports them. Scale the count
    back up the same way PEFT's ``get_nb_trainable_parameters`` does.
    """
    count = param.numel()
    if type(param).__name__ == "Params4bit":
        if hasattr(param, "element_size"):
            bytes_per_elem = param.element_size()
        else:
            # Fall back to the declared packed storage dtype, or 1 byte.
            bytes_per_elem = getattr(
                getattr(param, "quant_storage", None), "itemsize", 1
            )
        count *= 2 * bytes_per_elem
    return count


def print_trainable_parameters(model):
    """Print trainable vs. total parameter counts for ``model``.

    Parameters
    ----------
    model
        Any object whose ``named_parameters()`` yields ``(name, param)``
        pairs, where each ``param`` has ``numel()`` and ``requires_grad``
        (e.g. a ``torch.nn.Module`` or a PEFT model).

    Prints one line of the form
    ``trainable params: T || all params: A || trainable%: P``.
    A model with no parameters reports ``0.0`` percent instead of raising
    ``ZeroDivisionError``.
    """
    trainable_params = 0
    all_param = 0
    for _, param in model.named_parameters():
        num = _param_count(param)
        all_param += num
        if param.requires_grad:
            trainable_params += num
    # Guard the empty-model case, which previously divided by zero.
    trainable_pct = 100 * trainable_params / all_param if all_param else 0.0
    print(
        f"trainable params: {trainable_params} || all params: {all_param} "
        f"|| trainable%: {trainable_pct}"
    )

# LoRA adapter setup.
# prepare_model_for_kbit_training() was already applied to `model` above; it
# is intentionally not called a second time here, since a repeat call only
# re-applies the same preparation.
# LoraConfig / get_peft_model are imported explicitly because only
# prepare_model_for_kbit_training is imported at the top of this cell.
from peft import LoraConfig, get_peft_model

# LoRA hyperparameters: scaling numerator (by default the adapter update is
# scaled by lora_alpha / r), dropout probability for the LoRA layers, and the
# adapter rank r.
lora_alpha = 16
lora_dropout = 0.1
lora_rank = 64

peft_config = LoraConfig(
    lora_alpha=lora_alpha,
    lora_dropout=lora_dropout,
    r=lora_rank,
    bias="none",
    task_type="CAUSAL_LM",
    # GPT-NeoX/Falcon-style module names; presumably these match the loaded
    # base model's layers — confirm against model.named_modules().
    target_modules=[
        "query_key_value",
        "dense",
        "dense_h_to_4h",
        "dense_4h_to_h",
    ],
)

# Wrap the prepared base model with LoRA adapters.
peft_model = get_peft_model(model, peft_config)

# Report how many parameters the adapters leave trainable.
print_trainable_parameters(peft_model)

from transformers import Trainer, TrainingArguments, DataCollatorForLanguageModeling

# Reuse EOS as the padding token (presumably the tokenizer ships without a
# pad token of its own — confirm).
# NOTE(review): with mlm=False, DataCollatorForLanguageModeling copies
# input_ids into labels and sets every position equal to pad_token_id to -100.
# Because pad == eos here, genuine EOS tokens in the training text are
# excluded from the loss as well, so the model may not learn to stop
# generating. Consider a dedicated pad token or a collator that masks only
# real padding.
tokenizer.pad_token = tokenizer.eos_token

# The KV cache is not used for training and is incompatible with the gradient
# checkpointing enabled earlier; disable it explicitly rather than relying on
# Transformers to override it with a warning. Set it back to True before
# running generation with this model.
peft_model.config.use_cache = False

# Training hyperparameters: effective batch size is
# per_device_train_batch_size * gradient_accumulation_steps (per device).
training_args = TrainingArguments(
    gradient_accumulation_steps=4,
    per_device_train_batch_size=1,
    learning_rate=2e-4,
    fp16=True,
    save_total_limit=4,
    logging_steps=25,
    output_dir="output_dir",
    save_strategy='epoch',
    optim="paged_adamw_8bit",
    lr_scheduler_type='cosine',
    warmup_ratio=0.05,
)

# Causal-LM training: mlm=False makes the collator derive labels from input_ids.
trainer = Trainer(
    model=peft_model,
    train_dataset=split_dataset["train"],
    args=training_args,
    data_collator=DataCollatorForLanguageModeling(tokenizer, mlm=False),
)

# Start the training process
trainer.train()
