In [None]:
import pandas as pd
from datasets import Dataset, load_dataset
from transformers import AutoTokenizer, AutoModelForCausalLM, TrainingArguments, Trainer, DataCollatorForLanguageModeling
from peft import LoraConfig, get_peft_model
import torch
from trl import SFTTrainer, setup_chat_format

In [None]:
# Load the Q&A dataset (train split) from the Hugging Face Hub
dataset = load_dataset("yunfan-y/trump-qa", split="train")

# Load the base model and its tokenizer
model_name = "google/gemma-2-2b-it"
# NOTE(review): Gemma 2 is usually run in bfloat16; float16 weights may be
# numerically unstable during fine-tuning -- confirm on the target hardware.
model = AutoModelForCausalLM.from_pretrained(model_name, torch_dtype=torch.float16, attn_implementation="eager")
# The Gemma tokenizer is built into transformers, so no remote code has to be trusted.
tokenizer = AutoTokenizer.from_pretrained(model_name)

# LoRA config: low-rank adapters on the attention query/value projections only
peft_config = LoraConfig(
    r=16,
    lora_alpha=32,
    lora_dropout=0.05,
    bias="none",
    task_type="CAUSAL_LM",
    target_modules=["q_proj", "v_proj"],
)

# setup_chat_format is imported from trl: it installs a ChatML chat template,
# adds the ChatML special tokens and resizes the model embeddings to match.
# ChatML is used because the prompts below carry a "system" message, which
# Gemma's built-in template does not accept.
# gemma-2-2b-it already ships a chat template, and recent trl releases raise a
# ValueError when one is present, so clear it before switching formats.
tokenizer.chat_template = None
model, tokenizer = setup_chat_format(model, tokenizer)
# NOTE(review): the adapter only targets q_proj/v_proj, so the embeddings of the
# newly added ChatML tokens stay frozen -- consider `modules_to_save` if they
# should be trained as well.
model = get_peft_model(model, peft_config)

def format_chat_template(row):
    """Attach a chat-formatted ``text`` field to a single dataset row.

    The row's ``instruction``, ``input`` and ``output`` fields become the
    system, user and assistant messages of one conversation, which the
    module-level ``tokenizer`` renders to a string (no tokenization).

    Args:
        row: Mapping with ``instruction``, ``input`` and ``output`` keys.

    Returns:
        The same ``row`` object, with ``row["text"]`` set to the rendered chat.
    """
    role_to_column = (
        ("system", "instruction"),
        ("user", "input"),
        ("assistant", "output"),
    )
    messages = [
        {"role": role, "content": row[column]}
        for role, column in role_to_column
    ]
    row["text"] = tokenizer.apply_chat_template(messages, tokenize=False)
    return row

# Render every example into a single "text" column using the chat template
dataset = dataset.map(format_chat_template)

# Summarize the formatted dataset (columns and row count)
print(dataset)

In [None]:
# Hub repository that receives the fine-tuned adapter and tokenizer
HUB_REPO_ID = "yunfan-y/trump-gemma-qa"

# Training hyperparameters
training_arguments = TrainingArguments(
    output_dir="./result",
    per_device_train_batch_size=8,
    gradient_accumulation_steps=2,  # 8 x 2 = 16 examples per optimizer step per device
    num_train_epochs=1,
    logging_steps=1,
    warmup_steps=50,
    logging_strategy="steps",
    learning_rate=2e-5,
    group_by_length=True,
    report_to="wandb",
    weight_decay=0.1,
    # Target repo for trainer.push_to_hub(); without it the repo name would be
    # derived from output_dir instead.
    hub_model_id=HUB_REPO_ID,
)

# Supervised fine-tuning setup.
# `model` is already a PEFT model (wrapped with get_peft_model earlier in the
# notebook), so the LoRA config is not passed to the trainer a second time.
trainer = SFTTrainer(
    model=model,
    train_dataset=dataset,
    max_seq_length=512,
    dataset_text_field="text",
    tokenizer=tokenizer,
    args=training_arguments,
    packing=False,
)

# The generation KV cache is of no use during training; switch it off.
model.config.use_cache = False
trainer.train()

# Upload the trained model and tokenizer to HUB_REPO_ID.
# Trainer.push_to_hub takes a commit message as its first positional argument,
# not a repo id, so the destination is configured through `hub_model_id` above.
trainer.push_to_hub()