In [1]:
import torch
from datasets import load_dataset
from peft import LoraConfig, get_peft_model
from transformers import AutoModelForCausalLM, AutoTokenizer
from trl import GRPOConfig, GRPOTrainer
import wandb

In [2]:
wandb.login()

[34m[1mwandb[0m: Currently logged in as: [33mvaknin34[0m ([33mvaknin34-hugging-face[0m) to [32mhttps://api.wandb.ai[0m. Use [1m`wandb login --relogin`[0m to force relogin


True

In [3]:
dataset = load_dataset("mlabonne/smoltldr")
print(dataset)

DatasetDict({
    train: Dataset({
        features: ['prompt', 'completion'],
        num_rows: 2000
    })
    validation: Dataset({
        features: ['prompt', 'completion'],
        num_rows: 200
    })
    test: Dataset({
        features: ['prompt', 'completion'],
        num_rows: 200
    })
})


In [4]:
print(dataset["train"][0]["prompt"])

SUBREDDIT: r/tifu

TITLE: TIFU by trying to pet a dog.


TL;DR:


In [5]:
print(dataset["train"][0]["completion"])

 Tried to pet a dog, foot got impaled by a demon stick, never even got to pet the dog.


In [23]:
model_id = "HuggingFaceTB/SmolLM-135M-Instruct"
model = AutoModelForCausalLM.from_pretrained(
    model_id,
    torch_dtype="auto",
    device_map="auto",
)
tokenizer = AutoTokenizer.from_pretrained(model_id, padding_side='left')

# Add padding token if it doesn't exist
if tokenizer.pad_token is None:
    tokenizer.pad_token = tokenizer.eos_token

In [24]:
from transformers import pipeline
text_generator = pipeline(
    "text-generation",
    model=model,
    tokenizer=tokenizer,
)

Device set to use mps


In [25]:
# Create the baseline for both metrics we will optimize
from sentence_transformers import SentenceTransformer
import torch.nn.functional as F

# The output model should the same length as the answer
def same_length(prompts, completions, answers, **kwargs):
    # Reward for 2/difference in length
    return [2 / (abs(len(completion) - len(answer)) + 1) for completion, answer in zip(completions, answers)]

# Semantic similarity using embeddings
# Load a pre-trained sentence transformer model
embedding_model = SentenceTransformer('all-MiniLM-L6-v2', device="mps")
def semantic_similarity_embeddings(prompts, completions, answers, **kwargs):
    # Get embeddings for completions and answers
    completion_embeddings = embedding_model.encode(completions, convert_to_tensor=True)
    answer_embeddings = embedding_model.encode(answers, convert_to_tensor=True)

    # Compute cosine similarity between embeddings
    cosine_scores = F.cosine_similarity(completion_embeddings, answer_embeddings, dim=-1).tolist()
    return cosine_scores

In [26]:
from tqdm import tqdm

In [27]:
prompts = dataset["test"]["prompt"]
answers = dataset["test"]["completion"]

# Generate completions in batches for efficiency
batch_size = 100
completions = []
for i in tqdm(range(0, len(prompts), batch_size), desc="Generating completions"):
    batch_prompts = prompts[i:i+batch_size]
    batch_completions = text_generator(batch_prompts, do_sample=False, batch_size=batch_size)
    completions.extend([output[0]["generated_text"] for output in batch_completions])

Generating completions:   0%|          | 0/2 [00:00<?, ?it/s]The following generation flags are not valid and may be ignored: ['temperature']. Set `TRANSFORMERS_VERBOSITY=info` for more details.
The following generation flags are not valid and may be ignored: ['temperature']. Set `TRANSFORMERS_VERBOSITY=info` for more details.
Generating completions:  50%|█████     | 1/2 [01:45<01:45, 105.30s/it]The following generation flags are not valid and may be ignored: ['temperature']. Set `TRANSFORMERS_VERBOSITY=info` for more details.
The following generation flags are not valid and may be ignored: ['temperature']. Set `TRANSFORMERS_VERBOSITY=info` for more details.
Generating completions: 100%|██████████| 2/2 [03:29<00:00, 104.81s/it]
Generating completions: 100%|██████████| 2/2 [03:29<00:00, 104.81s/it]


In [31]:
# Compute same length scores
same_length_score = same_length(prompts, completions, answers)
print(f"Same length score: {sum(same_length_score) / len(same_length_score)}")

# Compute cosine similarity scores
cosine_scores = semantic_similarity_embeddings(prompts, completions, answers)
print(f"Cosine similarity score: {sum(cosine_scores) / len(cosine_scores)}")

Same length score: 0.0022290948449059946
Cosine similarity score: 0.577838778346777


In [32]:
# Load LoRA
lora_config = LoraConfig(
    task_type="CAUSAL_LM",
    r=32,
    lora_alpha=32,
    target_modules="all-linear",
)
model = get_peft_model(model, lora_config)
print(model.print_trainable_parameters())

trainable params: 9,768,960 || all params: 144,283,968 || trainable%: 6.7706
None


In [43]:
# Training arguments
training_args = GRPOConfig(
    output_dir="GRPO",
    learning_rate=2e-5,
    per_device_train_batch_size=4,  # Reduced batch size for stability
    gradient_accumulation_steps=4,  # Increased to maintain effective batch size
    max_prompt_length=512,
    max_completion_length=96,
    num_generations=4,  # Reduced for stability
    bf16=False,  # Keep False for MPS
    fp16=False,  # Ensure no mixed precision on MPS
    dataloader_pin_memory=False,  # Disable pin_memory for MPS
    report_to=["wandb"],
    remove_unused_columns=False,
    logging_steps=1,
    max_steps=100,  # Reduced for testing
    label_names=["completion"],
    temperature=0.7,  # Add temperature for more stable generation
    top_p=0.9,  # Add nucleus sampling
)

In [61]:
import math

class NaNGuardCallback:
    def on_step_end(self, args, state, control, **kwargs):
        for k, v in state.log_history[-1].items():
            if isinstance(v, (float, int)):
                if not math.isfinite(v):
                    raise RuntimeError(f"⚠️  {k} became {v} at step {state.global_step}")
    def on_train_begin(self, args, state, control, **kwargs):
        print("NaNGuardCallback initialized. Monitoring for NaN values during training.")
    def on_init_end(self, args, state, control, **kwargs):
        print("NaNGuardCallback initialized. Monitoring for NaN values during training.")
    def on_epoch_begin(self, args, state, control, **kwargs):
        print("Epoch started. Monitoring for NaN values during training.")
    def on_step_begin(self, args, state, control, **kwargs):
        print("Step started. Monitoring for NaN values during training.")

In [62]:
# Trainer
trainer = GRPOTrainer(
    model=model,
    reward_funcs=[same_length, semantic_similarity_embeddings],
    args=training_args,
    train_dataset=dataset["train"],
    processing_class=tokenizer,
    callbacks=[NaNGuardCallback()],
)

NaNGuardCallback initialized. Monitoring for NaN values during training.


In [None]:
# Train model
wandb.init(project="GRPO")
try:
    trainer.train()
except Exception as e:
    print(f"Training failed with err|or: {e}")
    print("Consider reducing batch size or learning rate further.")

NaNGuardCallback initialized. Monitoring for NaN values during training.
Epoch started. Monitoring for NaN values during training.
Step started. Monitoring for NaN values during training.
Training failed with err|or: probability tensor contains either `inf`, `nan` or element < 0
Consider reducing batch size or learning rate further.
