In [1]:
import os
from datetime import datetime

import torch
# import wandb
from datasets import load_dataset
from peft import LoraConfig, get_peft_model
from transformers import AutoModelForCausalLM, AutoTokenizer, pipeline
from trl import GRPOConfig, GRPOTrainer

# Turn tokenizers parallelism off through its environment switch.
os.environ.update(TOKENIZERS_PARALLELISM="false")

# Weights & Biases login is currently disabled (kept for reference).
# wandb.login()

# Timestamped run name, plus the log directory path derived from it.
started_at = datetime.now()
run_name = f"grpo_run_{started_at.strftime('%Y%m%d_%H%M%S')}"
log_dir = os.path.join("../logs", run_name)

# Fetch the dataset, then show its splits followed by the first training record.
dataset = load_dataset("mlabonne/smoltldr")
print(dataset)
first_train_example = dataset["train"][0]
print(first_train_example)

DatasetDict({
    train: Dataset({
        features: ['prompt', 'completion'],
        num_rows: 2000
    })
    validation: Dataset({
        features: ['prompt', 'completion'],
        num_rows: 200
    })
    test: Dataset({
        features: ['prompt', 'completion'],
        num_rows: 200
    })
})


In [2]:
# Load model
model_id = "HuggingFaceTB/SmolLM-135M-Instruct"
model = AutoModelForCausalLM.from_pretrained(
    model_id,
    torch_dtype="auto",
    device_map="auto",
    # attn_implementation="flash_attention_2",
)
tokenizer = AutoTokenizer.from_pretrained(model_id)

# Load LoRA: wrap the base model with rank-16 adapters on every linear layer
lora_config = LoraConfig(
    task_type="CAUSAL_LM",
    r=16,
    lora_alpha=32,
    target_modules="all-linear",
)
model = get_peft_model(model, lora_config)

# Only move to the Apple-silicon backend when it actually exists; an
# unconditional .to("mps") raises on machines without MPS. Elsewhere the
# placement chosen by device_map="auto" above is kept.
if torch.backends.mps.is_available():
    model.to("mps")

# print_trainable_parameters() prints the summary itself and returns None,
# so it must not be wrapped in print() (that emitted a stray "None").
model.print_trainable_parameters()


trainable params: 4,884,480 || all params: 139,399,488 || trainable%: 3.5039
None


In [3]:
# Reward function
def reward_len(completions, **kwargs):
    """Reward completions whose length is close to 50.

    Args:
        completions: Iterable of completions; each one is measured with
            ``len()`` (i.e. characters when the completions are strings).
        **kwargs: Accepted and ignored, so callers may pass extra fields.

    Returns:
        A list with one reward per completion: ``0`` for a length of exactly
        50, and increasingly negative the further the length is from 50.
    """
    target_length = 50
    rewards = []
    for text in completions:
        distance = abs(target_length - len(text))
        rewards.append(-distance)
    return rewards


In [4]:
# Training arguments, grouped by what they control
training_args = GRPOConfig(
    # Output location and run length
    output_dir="GRPO",
    num_train_epochs=1,
    # Optimisation
    optim="adamw_torch",
    learning_rate=2e-5,
    per_device_train_batch_size=8,
    gradient_accumulation_steps=2,
    bf16=False,
    # Generation
    num_generations=8,
    max_prompt_length=512,
    max_completion_length=96,
    # Logging (TensorBoard, written to the per-run log directory)
    report_to="tensorboard",
    logging_dir=log_dir,
    logging_steps=1,
    # Dataset / column handling
    remove_unused_columns=False,
    label_names=[],
)

# Trainer: LoRA-wrapped model, length-based reward, training split only
trainer = GRPOTrainer(
    args=training_args,
    model=model,
    train_dataset=dataset["train"],
    reward_funcs=[reward_len],
)


# W&B run initialisation is disabled; metrics are reported to TensorBoard.
# wandb.init(project="grpo-smoltldr")

# Train model
trainer.train()


KeyboardInterrupt: 

In [None]:
# Sample document used as the single user turn for generation.
prompt = """
# A long document about the Cat

The cat (Felis catus), also referred to as the domestic cat or house cat, is a small 
domesticated carnivorous mammal. It is the only domesticated species of the family Felidae.
Advances in archaeology and genetics have shown that the domestication of the cat occurred
in the Near East around 7500 BC. It is commonly kept as a pet and farm cat, but also ranges
freely as a feral cat avoiding human contact. It is valued by humans for companionship and
its ability to kill vermin. Its retractable claws are adapted to killing small prey species
such as mice and rats. It has a strong, flexible body, quick reflexes, and sharp teeth,
and its night vision and sense of smell are well developed. It is a social species,
but a solitary hunter and a crepuscular predator. Cat communication includes
vocalizations—including meowing, purring, trilling, hissing, growling, and grunting—as
well as body language. It can hear sounds too faint or too high in frequency for human ears,
such as those made by small mammals. It secretes and perceives pheromones.
"""

# Chat-format input: one message, with the document as the user's content.
messages = [dict(role="user", content=prompt)]


In [None]:
# `model` is the LoRA-wrapped model built with get_peft_model earlier; keep the
# result of merge_and_unload() as `merged_model` for the generation pipeline.
# NOTE(review): per the PEFT docs this folds the adapter weights into the base
# model and returns it without the PEFT wrapper — confirm for the installed version.
merged_model = model.merge_and_unload()

In [None]:
# Generate text
# `pipeline` is imported with the other transformers names in the first cell.

# Build the generator from the merged model and the tokenizer loaded earlier
generator = pipeline("text-generation", model=merged_model, tokenizer=tokenizer)

# Sampling settings for generation
generate_kwargs = {
    "max_new_tokens": 256,
    "do_sample": True,
    "temperature": 0.5,
    "min_p": 0.1,
}

generated_text = generator(messages, **generate_kwargs)

# Print each entry of the first result's "generated_text", one per line
for message in generated_text[0]["generated_text"]:
    print(message)
