In [1]:
import sys
import ast

In [2]:
# !{sys.executable} -m pip install numpy==1.26.4
# !{sys.executable} -m pip install unsloth vllm==0.7.2
# !{sys.executable} -m pip install -U ipywidgets

In [3]:
import torch
import re
from datasets import load_dataset

In [4]:
from unsloth import FastLanguageModel, is_bfloat16_supported

🦥 Unsloth: Will patch your computer to enable 2x faster free finetuning.
Unsloth: Failed to patch Gemma3ForConditionalGeneration.
🦥 Unsloth Zoo will now patch everything to make training faster!
INFO 04-21 21:59:43 __init__.py:190] Automatically detected platform cuda.


In [5]:
from transformers import (
    AutoTokenizer,
    Seq2SeqTrainingArguments,
    Seq2SeqTrainer,
    DataCollatorForSeq2Seq,
    TrainingArguments,
    Trainer
)

In [6]:
def get_connections_questions(SYSTEM_PROMPT, split="train"):
    data = load_dataset('csv', data_files='dataset/final_transformed_connections.csv')[split]
    data = data.train_test_split(test_size=0.1, seed=3407)
    train_data = data['train']
    test_data  = data['test']
    # each x has 'questions' (the 16‐word list) and 'answer' (the reference)
    train_data = train_data.map(lambda x: {
        'prompt': [
            {'role': 'assistant', 'content': SYSTEM_PROMPT},
            {'role': 'user',      'content': x['questions']}
        ],
        'answer': x['answer']
    })
    test_data = test_data.map(lambda x: {
        'prompt': [
            {'role': 'assistant', 'content': SYSTEM_PROMPT},
            {'role': 'user',      'content': x['questions']}
        ],
        'answer': x['answer']
    })
    return train_data, test_data


In [7]:
SFT_SYSTEM_PROMPT = '''
You are playing the NY Times Connections game. Your task is to categorize 16 given words into exactly 4 groups of 4 words each, based on shared common themes.

Only output your final solution in the following format:
<answer>
[['WORD1', 'WORD2', 'WORD3', 'WORD4'],
 ['WORD5', 'WORD6', 'WORD7', 'WORD8'],
 ['WORD9', 'WORD10', 'WORD11', 'WORD12'],
 ['WORD13', 'WORD14', 'WORD15', 'WORD16']]
</answer>

Rules:
	• Each word must belong to one group only.
	• Groups must have a clear, shared theme (e.g., weather, NBA teams, keyboard keys, etc.).
	• Do not include any words not present in the input list.

Here is an example:

USER: [BUCKS, HAIL, JAZZ, SHIFT, LEVEL, MOM, SNOW, RACECAR, SLEET, TAB, KAYAK, RETURN, OPTION, NETS, RAIN, HEAT]

SOLUTION:
[['HAIL', 'RAIN', 'SLEET', 'SNOW'],
 ['BUCKS', 'HEAT', 'JAZZ', 'NETS'],
 ['OPTION', 'RETURN', 'SHIFT', 'TAB'],
 ['KAYAK', 'LEVEL', 'MOM', 'RACECAR']]

Explanation:
- WEATHER TERMS: 'HAIL', 'RAIN', 'SLEET', 'SNOW'
- NBA TEAMS: 'BUCKS', 'HEAT', 'JAZZ', 'NETS'
- KEYBOARD KEYS: 'OPTION', 'RETURN', 'SHIFT', 'TAB'
- PALINDROMES: 'KAYAK', 'LEVEL', 'MOM', 'RACECAR'
'''

In [8]:
# train_ds, test_ds = get_connections_questions(SFT_SYSTEM_PROMPT)

In [16]:
GRPO_SYSTEM_PROMPT = '''
You are playing the NY Times Connections game. Your task is to categorize 16 given words into exactly 4 groups of 4 words each, based on shared common themes.

Solve the puzzle using these clear steps:

1. THINK STEP-BY-STEP: Begin by carefully analyzing the words within <think> tags. Identify their meanings, relationships, and possible groupings logically.

2. PROVIDE FINAL ANSWER: After clearly grouping and justifying all four sets, provide ONLY your final solution within <answer> tags. Format your solution exactly as shown below.

Example:
<think>
......
......
</think>
<answer>
[['HAIL', 'RAIN', 'SLEET', 'SNOW'],
 ['BUCKS', 'HEAT', 'JAZZ', 'NETS'],
 ['OPTION', 'RETURN', 'SHIFT', 'TAB'],
 ['KAYAK', 'LEVEL', 'MOM', 'RACECAR']]
</answer>

Important Notes:
- Categories should be specific
- Words cannot appear in more than one group.
- Categories can include compound words, shared prefixes/suffixes, pop culture references, or common phrases.
- DO NOT ADD NEW WORDS THAT ARE NOT MENTIONED IN THE QUESTION. USE ONLY WORDS MENTIONED AND GROUP THEM

Here is an example:

USER: [BUCKS, HAIL, JAZZ, SHIFT, LEVEL, MOM, SNOW, RACECAR, SLEET, TAB, KAYAK, RETURN, OPTION, NETS, RAIN, HEAT]

SOLUTION:
[['HAIL', 'RAIN', 'SLEET', 'SNOW'],
 ['BUCKS', 'HEAT', 'JAZZ', 'NETS'],
 ['OPTION', 'RETURN', 'SHIFT', 'TAB'],
 ['KAYAK', 'LEVEL', 'MOM', 'RACECAR']]

Explanation:
- WEATHER TERMS: 'HAIL', 'RAIN', 'SLEET', 'SNOW'
- NBA TEAMS: 'BUCKS', 'HEAT', 'JAZZ', 'NETS'
- KEYBOARD KEYS: 'OPTION', 'RETURN', 'SHIFT', 'TAB'
- PALINDROMES: 'KAYAK', 'LEVEL', 'MOM', 'RACECAR'
'''

In [17]:
train_ds, test_ds = get_connections_questions(GRPO_SYSTEM_PROMPT)

In [18]:
# ——————————————————————————————
# 2) Initialize model + LoRA adapter
# ——————————————————————————————

# max_seq_length = 1024
# lora_rank       = 8

# # load base model
# model, tokenizer = FastLanguageModel.from_pretrained(
#     model_name            = "unsloth/Llama-3.2-3B-Instruct",
#     max_seq_length        = max_seq_length,
#     load_in_4bit          = False,    
#     fast_inference        = True,
#     max_lora_rank         = lora_rank
# )

# # wrap with PEFT
# model = FastLanguageModel.get_peft_model(
#     model              = model,
#     r                  = lora_rank,
#     target_modules     = ["q_proj","k_proj","v_proj","o_proj","gate_proj","up_proj","down_proj"],
#     lora_alpha         = lora_rank,
#     random_state       = 3407,
# )

# # Make sure we also have a HF‐style tokenizer
# hf_tokenizer = AutoTokenizer.from_pretrained("unsloth/Llama-3.2-3B-Instruct", use_fast=True)
# hf_tokenizer.pad_token_id = hf_tokenizer.eos_token_id

In [19]:
# # ——————————————————————————————
# # 3) Preprocess: pad labels to 1024
# # ——————————————————————————————
# def preprocess_sft(batch):
#     inputs = [
#         tokenizer.apply_chat_template(p, tokenize=False, add_generation_prompt=False)
#         for p in batch["prompt"]
#     ]
#     model_inputs = hf_tokenizer(
#         inputs,
#         max_length = max_seq_length,
#         padding    = "max_length",
#         truncation = True,
#     )
#     labels = hf_tokenizer(
#         batch["answer"],
#         max_length = max_seq_length,    # ← pad answers out to 1024
#         padding    = "max_length",
#         truncation = True,
#     )
#     model_inputs["labels"] = labels["input_ids"]
#     return model_inputs

# train_enc = train_ds.map(
#     preprocess_sft, batched=True, remove_columns=train_ds.column_names
# )
# test_enc  = test_ds.map(
#     preprocess_sft, batched=True, remove_columns=test_ds.column_names
# )

# data_collator = DataCollatorForSeq2Seq(
#     tokenizer           = hf_tokenizer,
#     model               = model.model,   # unwrapped base
#     label_pad_token_id  = -100,          # ignore padded labels
# )

In [20]:
# # ——————————————————————————————
# # 4) Seq2SeqTrainer with padded labels
# # ——————————————————————————————
# sft_args = TrainingArguments(
#     output_dir                  = "sft_outputs",
#     per_device_train_batch_size = 1,
#     gradient_accumulation_steps = 16,
#     num_train_epochs            = 3,
#     learning_rate               = 5e-5,
#     fp16                        = not is_bfloat16_supported(),
#     bf16                        = is_bfloat16_supported(),
#     logging_steps               = 10,
#     save_steps                  = 100,
#     save_total_limit            = 2,
# )

# sft_trainer = Trainer(
#     model         = model,
#     args          = sft_args,
#     train_dataset = train_enc,
#     eval_dataset  = test_enc,
#     tokenizer     = hf_tokenizer,
#     data_collator = DataCollatorForSeq2Seq(
#         tokenizer          = hf_tokenizer,
#         model              = model.model,
#         label_pad_token_id = -100,
#     ),
# )

In [21]:
# run the SFT pass
# sft_trainer.train()
# model.save_lora("sft_saved_lora_3B")

In [25]:
max_seq_length = 1024
lora_rank = 8

In [23]:
# ——————————————————————————————
# 4) GRPO stage, initialized from SFT adapter
# ——————————————————————————————

# re‑load the same base & adapter structure
model, tokenizer = FastLanguageModel.from_pretrained(
    model_name             = "unsloth/Llama-3.2-3B-Instruct",
    max_seq_length         = max_seq_length,
    load_in_4bit           = False,
    fast_inference         = True,
    max_lora_rank          = lora_rank,
    gpu_memory_utilization = 0.9,
)
model = FastLanguageModel.get_peft_model(
    model          = model,
    r              = lora_rank,
    target_modules = ["q_proj","k_proj","v_proj","o_proj","gate_proj","up_proj","down_proj"],
    lora_alpha     = lora_rank,
)
# load your SFT‐trained adapter
model.load_lora("sft_saved_lora_3B")

==((====))==  Unsloth 2025.3.19: Fast Llama patching. Transformers: 4.51.0. vLLM: 0.7.2.
   \\   /|    NVIDIA A40. Num GPUs = 2. Max memory: 44.451 GB. Platform: Linux.
O^O/ \_/ \    Torch: 2.5.1+cu124. CUDA: 8.6. CUDA Toolkit: 12.4. Triton: 3.1.0
\        /    Bfloat16 = TRUE. FA [Xformers = 0.0.28.post3. FA2 = False]
 "-____-"     Free license: http://github.com/unslothai/unsloth
Unsloth: Fast downloading is enabled - ignore downloading bars which are red colored!
Unsloth: vLLM loading unsloth/Llama-3.2-3B-Instruct with actual GPU utilization = 89.39%
Unsloth: Your GPU has CUDA compute capability 8.6 with VRAM = 44.45 GB.
Unsloth: Using conservativeness = 1.0. Chunked prefill tokens = 1024. Num Sequences = 320.
Unsloth: vLLM's KV Cache can use up to 33.73 GB. Also swap space = 6 GB.
INFO 04-21 22:01:39 config.py:542] This model supports multiple tasks: {'score', 'generate', 'reward', 'classify', 'embed'}. Defaulting to 'generate'.
INFO 04-21 22:01:39 llm_engine.py:234] Initializing a



INFO 04-21 22:01:40 weight_utils.py:252] Using model weights format ['*.safetensors']


Loading safetensors checkpoint shards:   0% Completed | 0/2 [00:00<?, ?it/s]


INFO 04-21 22:01:41 model_runner.py:1115] Loading model weights took 6.0316 GB
INFO 04-21 22:01:41 punica_selector.py:18] Using PunicaWrapperGPU.
INFO 04-21 22:01:43 worker.py:267] Memory profiling takes 1.13 seconds
INFO 04-21 22:01:43 worker.py:267] the current vLLM instance can use total_gpu_memory (44.45GiB) x gpu_memory_utilization (0.89) = 39.74GiB
INFO 04-21 22:01:43 worker.py:267] model weights take 6.03GiB; non_torch_memory takes 0.06GiB; PyTorch activation peak memory takes 1.48GiB; the rest of the memory reserved for KV Cache is 32.16GiB.
INFO 04-21 22:01:43 executor_base.py:110] # CUDA blocks: 18819, # CPU blocks: 3510
INFO 04-21 22:01:43 executor_base.py:115] Maximum concurrency for 1024 tokens per request: 294.05x
INFO 04-21 22:01:46 model_runner.py:1434] Capturing cudagraphs for decoding. This may lead to unexpected consequences if the model is not static. To run the model in eager mode, set 'enforce_eager=True' or use '--enforce-eager' in the CLI. If out-of-memory error

Capturing CUDA graph shapes: 100%|██████████| 43/43 [00:26<00:00,  1.63it/s]

INFO 04-21 22:02:12 model_runner.py:1562] Graph capturing finished in 26 secs, took 0.39 GiB
INFO 04-21 22:02:12 llm_engine.py:431] init engine (profile, create kv cache, warmup model) took 30.82 seconds



Unsloth 2025.3.19 patched 28 layers with 28 QKV layers, 28 O layers and 28 MLP layers.


LoRARequest(lora_name='0', lora_int_id=0, lora_path='sft_saved_lora_3B', lora_tensors=None, lora_config=(None,), lora_local_path=None, long_lora_max_len=None, base_model_name=None, lora_embeddings=None)

In [26]:
def thinking_reward_func(completions, **kwargs) -> list[float]:
    """Reward function for including thinking tags"""
    rewards = []
    for completion in completions:
        try:
            reward = 0.0
            # Extract all thinking blocks
            for message in completion:
                if message["role"] == "assistant" and message.get("content"):
                    content = message["content"]

                    # Count opening and closing tags
                    opening_tags = len(re.findall(r"<think>", content))
                    closing_tags = len(re.findall(r"</think>", content))

                    if opening_tags == 0 or closing_tags == 0:
                        continue

                    if opening_tags == closing_tags:
                        reward += 0.5
                    else:
                        reward += 0.1
            reward = min(reward, 1.5)
            rewards.append(reward)
        except Exception as e:
            print(f"{RED}Error in thinking_reward_func: {e}{RESET}")
            rewards.append(0.0)
    assert len(rewards) == len(completions)
    return rewards


def answer_format_reward_func(completions, **kwargs) -> list[float]:
    """Reward function for including answer tags"""
    rewards = []
    for completion in completions:
        try:
            reward = 0.0
            # Extract all answer blocks
            for message in completion:
                if message["role"] == "assistant" and message.get("content"):
                    content = message["content"]

                    # Count opening and closing tags
                    opening_tags = len(re.findall(r"<answer>", content))
                    closing_tags = len(re.findall(r"</answer>", content))

                    if opening_tags == 0 or closing_tags == 0:
                        continue

                    if opening_tags == 1 and closing_tags == 1:
                        reward += 0.5
                    else:
                        reward += 0.1
            reward = min(reward, 1.5)
            rewards.append(reward)
        except Exception as e:
            print(f"{RED}Error in thinking_reward_func: {e}{RESET}")
            rewards.append(0.0)
    assert len(rewards) == len(completions)
    return rewards


def correctness_reward_func(completions, answer, **kwargs) -> list[float]:
    """Safe, robust reward for NYT Connections."""
    rewards = []
    for completion, expected_answer in zip(completions, answer):
        # default
        reward = 0.0
        # 1) Extract last <answer>…</answer> block
        predicted_block = None
        for msg in completion:
# Ensure msg is a dictionary and check its role and content
            if isinstance(msg, dict) and msg.get("role") == "assistant" and "content" in msg:
                matches = re.findall(r"<answer>(.*?)</answer>", msg["content"], re.DOTALL)
                if matches:
                    predicted_block = matches[-1].strip()
        
        if not predicted_block:
            rewards.append(0.0)
            continue

        # 2) Safe parse
        try:
            pred_groups = ast.literal_eval(predicted_block)
            exp_groups  = ast.literal_eval(expected_answer)
        except Exception:
            rewards.append(0.0)
            continue

        # 3) Normalize: keep only lists of strings
        def clean(groups):
            cleaned = []
            for g in groups:
                if isinstance(g, list):
                    cleaned.append([w for w in g if isinstance(w, str)])
            return cleaned

        pred = clean(pred_groups)
        exp  = clean(exp_groups)

        # 4) Scoring
        # Perfect match
        if pred == exp:
            rewards.append(6.0)
            continue

        # +1.5 for each fully correct group (set‐equality)
        used = set()
        for pg in pred:
            for i, eg in enumerate(exp):
                if i in used: 
                    continue
                if set(pg) == set(eg):
                    reward += 1.5
                    used.add(i)
                    break

        # +0.75 for any 3‐word overlap in unmatched groups
        for pg in pred:
            for i, eg in enumerate(exp):
                if i in used:
                    continue
                overlap = len([w for w in pg if w in eg])
                if overlap == 3:
                    reward += 0.75
                    used.add(i)
                    break

        # +0.25 if every predicted group has exactly 4 words
        if all(len(pg) == 4 for pg in pred):
            reward += 0.25

        # +0.5 if no word is repeated across all predicted groups
        flat = [w for pg in pred for w in pg]
        if len(flat) == len(set(flat)):
            reward += 0.5

        rewards.append(reward)

    assert len(rewards) == len(completions)
    return rewards

In [27]:
from trl import GRPOConfig, GRPOTrainer
from rich import print
rl_args = GRPOConfig(
    use_vllm                     = True,
    learning_rate                = 5e-6,
    adam_beta1                   = 0.9,
    adam_beta2                   = 0.99,
    weight_decay                 = 0.1,
    warmup_ratio                 = 0.1,
    lr_scheduler_type            = "cosine",
    optim                        = "adamw_8bit",
    logging_steps                = 1,
    bf16                         = is_bfloat16_supported(),
    fp16                         = not is_bfloat16_supported(),
    per_device_train_batch_size  = 1,
    gradient_accumulation_steps  = 16,
    num_generations              = 4,
    max_prompt_length            = 700,
    max_completion_length        = 700,
    # num_train_epochs             = 1,
    max_steps                    = 200,
    save_steps                   = 50,
    max_grad_norm                = 0.1,
    report_to                    = "none",
    output_dir                   = "outputs_grpo",
    log_completions              = 5, 
)

trainer = GRPOTrainer(
    model            = model,
    processing_class = tokenizer,
    reward_funcs     = [
        thinking_reward_func,
        correctness_reward_func,
        answer_format_reward_func,
    ],
    args             = rl_args,
    train_dataset    = train_ds,  # or a mix of train & test
)

Detected kernel version 4.18.0, which is below the recommended minimum of 5.5.0; this can cause the process to hang. It is recommended to upgrade the kernel to the minimum version or higher.


Unsloth: We now expect `per_device_train_batch_size` to be a multiple of `num_generations`.
We will change the batch size of 1 to the `num_generations` of 4


In [28]:
trainer.train()

==((====))==  Unsloth - 2x faster free finetuning | Num GPUs used = 1
   \\   /|    Num examples = 580 | Num Epochs = 12 | Total steps = 200
O^O/ \_/ \    Batch size per device = 8 | Gradient accumulation steps = 16
\        /    Data Parallel GPUs = 1 | Total batch size (8 x 16 x 1) = 128
 "-____-"     Trainable parameters = 12,156,928/3,224,906,752 (0.38% trained)
Unsloth: Input IDs of length 1029 > the model's max sequence length of 1024.
We shall truncate it ourselves. It's imperative if you correct this issue first.
Unsloth: Input IDs of length 1027 > the model's max sequence length of 1024.
We shall truncate it ourselves. It's imperative if you correct this issue first.


Unsloth: Will smartly offload gradients to save VRAM!


Unsloth: Input IDs of length 1030 > the model's max sequence length of 1024.
We shall truncate it ourselves. It's imperative if you correct this issue first.
Unsloth: Input IDs of length 1025 > the model's max sequence length of 1024.
We shall truncate it ourselves. It's imperative if you correct this issue first.
Unsloth: Input IDs of length 1036 > the model's max sequence length of 1024.
We shall truncate it ourselves. It's imperative if you correct this issue first.
Unsloth: Input IDs of length 1026 > the model's max sequence length of 1024.
We shall truncate it ourselves. It's imperative if you correct this issue first.
Unsloth: Input IDs of length 1028 > the model's max sequence length of 1024.
We shall truncate it ourselves. It's imperative if you correct this issue first.


Step,Training Loss,reward,reward_std,completion_length,kl,rewards / thinking_reward_func,rewards / correctness_reward_func,rewards / answer_format_reward_func
1,-0.0,1.521094,1.157509,303.054688,0.0,0.340625,0.851562,0.328906
2,0.0,1.539844,1.156473,272.039062,0.0,0.344531,0.863281,0.332031
3,0.0001,1.269531,1.039544,280.367188,0.0014,0.335156,0.640625,0.29375
4,0.0001,1.521094,1.087884,273.335938,0.001703,0.372656,0.8125,0.335938
5,0.0001,1.500781,1.228315,275.289062,0.001701,0.332031,0.835938,0.332812
6,0.0,1.322656,1.135398,290.351562,0.001233,0.328906,0.699219,0.294531


Unsloth: Input IDs of length 1032 > the model's max sequence length of 1024.
We shall truncate it ourselves. It's imperative if you correct this issue first.
Unsloth: Input IDs of length 1069 > the model's max sequence length of 1024.
We shall truncate it ourselves. It's imperative if you correct this issue first.
Unsloth: Input IDs of length 1031 > the model's max sequence length of 1024.
We shall truncate it ourselves. It's imperative if you correct this issue first.
Unsloth: Input IDs of length 1034 > the model's max sequence length of 1024.
We shall truncate it ourselves. It's imperative if you correct this issue first.
Unsloth: Input IDs of length 1050 > the model's max sequence length of 1024.
We shall truncate it ourselves. It's imperative if you correct this issue first.
Unsloth: Input IDs of length 1035 > the model's max sequence length of 1024.
We shall truncate it ourselves. It's imperative if you correct this issue first.
Unsloth: Input IDs of length 1033 > the model's max 

KeyboardInterrupt: 

In [29]:
model.save_lora("grpo_saved_lora_3B_comp1024")

In [48]:
text = tokenizer.apply_chat_template([
    {"role" : "system", "content" : GRPO_SYSTEM_PROMPT},
    {"role" : "user", "content" : "[FLY, SINK, SHOWER, SALSA, TAP, RACE, DIP, MODERN, FALL, SWING, CARROT, TEAR, DROP, TALK, BLAZE, BOOM]"},
], tokenize = False, add_generation_prompt = True)

from vllm import SamplingParams
sampling_params = SamplingParams(
    temperature = 0.8,
    top_p = 0.95,
    max_tokens = 2048,
)
output = model.fast_generate(
    text,
    sampling_params = sampling_params,
    # lora_request = model.load_lora("grpo_saved_lora_3B"),
)[0].outputs[0].text

Processed prompts: 100%|██████████| 1/1 [00:06<00:00,  6.82s/it, est. speed input: 81.56 toks/s, output: 68.79 toks/s]


In [50]:
output

"<think>\n- FLY, SHOWER, SHOOTER, DIP all relate to SHOOT, implying they are either actions of shooting or types of shooting, or associated with shooting like a sport or device.\n- SINK, RACE, SWING, DROP all can imply a loss or failure in these words; however, SINK and SWING also relate to furniture, so considering different aspects of these words may also be valid. Also, the words TALK, TEAR, and BLAZE also imply loss or failure in the context of public speaking or burning down something. MODERN and CARROT do not fit into these categories. \n- FLY, SHOOT, RACE, and DIP might be associated with performance and speed in sports, while SINK, SWING, and DROP relate to falls, SINK to furniture, and TEAR to tears and could also be associated with loss of control in sports, TALK and BLAZE to loss of control or mess in public, and MODERN and CARROT do not fit into these categories.\n- FLY, SHOWER, SINK, TALK all have multiple related meanings, TALK also implies loss of control in public as in

In [None]:
[['DIP', 'DROP', 'FALL', 'SINK'], ['BLAZE', 'FLY', 'RACE', 'TEAR'], ['MODERN', 'SALSA', 'SWING', 'TAP'], ['BOOM', 'CARROT', 'SHOWER', 'TALK']]