In [None]:
import yaml
from peft import LoraConfig, get_peft_model
from transformers import (
    Trainer,
    TrainingArguments,
    DataCollatorForLanguageModeling)
from trl import SFTTrainer
from components.instructor import Instructor

# Read the instructor settings from YAML into a plain dictionary and echo
# them so the values used for this run are visible in the cell output.
with open('./config/instructor_config.yaml', 'r') as config_file:
    config = yaml.safe_load(config_file)
print(config)

# Build the Instructor from the parsed settings, then load the model and the
# data, passing test_size=0.1 to load_data.
instructor = Instructor(config)
instructor.load_model()
instructor.load_data(test_size=0.1)
# Explicit tokenization is currently disabled:
# instructor.tokenize_data()

In [None]:
# Load config file into dictionary -- ONLY CHANGE TRAINING AND LORA PARAMETERS HERE
# The YAML is re-read in this cell so that edited LoRA/training values are
# picked up when the cell is re-run; `instructor` itself is not rebuilt here.
with open('./config/instructor_config.yaml', 'r') as file:
    config = yaml.safe_load(file)
print(config)

# Causal-LM collator (mlm=False): no masked-language-modelling corruption.
data_collator = DataCollatorForLanguageModeling(
    tokenizer=instructor.tokenizer,
    mlm=False,
)

# LoRA configuration
peft_config = LoraConfig(
    r=config["r"],
    lora_alpha=config["lora_alpha"],
    lora_dropout=config["lora_dropout"],
    bias=config["bias"],
    task_type=config["task_type"],
    target_modules=config["target_modules"],
)

# Training arguments
training_args = TrainingArguments(
    output_dir=config["output_dir"],
    per_device_train_batch_size=config["per_device_train_batch_size"],
    gradient_accumulation_steps=config["gradient_accumulation_steps"],
    gradient_checkpointing=config["gradient_checkpointing"],
    learning_rate=config["learning_rate"],
    lr_scheduler_type=config["lr_scheduler_type"],
    max_steps=config["max_steps"],
    logging_steps=config["logging_steps"],
    optim=config["optim"],
    warmup_ratio=config["warmup_ratio"],
    report_to=config["report_to"],
)

# Create trainer
# NOTE: the base model is handed over un-wrapped together with `peft_config`,
# so SFTTrainer applies the LoRA adapters exactly once. Calling
# get_peft_model() on `instructor.model` beforehand injects the adapters in
# place, and the trainer would then try to inject them a second time into
# the already-modified model.
trainer = SFTTrainer(
    model=instructor.model,
    args=training_args,
    peft_config=peft_config,
    data_collator=data_collator,
    train_dataset=instructor.train_dataset,
    eval_dataset=instructor.test_dataset,
    dataset_text_field="text",
    tokenizer=instructor.tokenizer,
)

# The adapter-wrapped model built by the trainer is the one that is actually
# trained; keep the `lora_model` name bound to it and report its trainable
# parameter count.
lora_model = trainer.model
lora_model.print_trainable_parameters()

trainer.train()