In [1]:
# Install packages
%%capture
import os
if "COLAB_" not in "".join(os.environ.keys()):
    !pip install unsloth vllm==0.7.3
else:
    # [NOTE] Do the below ONLY in Colab! Use [[pip install unsloth vllm]]
    !pip install --no-deps unsloth vllm==0.7.3

In [2]:
#@title Colab Extra Install { display-mode: "form" }
%%capture
import os
if "COLAB_" not in "".join(os.environ.keys()):
    !pip install unsloth vllm
else:
    !pip install --no-deps unsloth vllm
    # [NOTE] Do the below ONLY in Colab! Use [[pip install unsloth vllm]]
    # Skip restarting message in Colab
    import sys, re, requests; modules = list(sys.modules.keys())
    for x in modules: sys.modules.pop(x) if "PIL" in x or "google" in x else None
    !pip install --no-deps bitsandbytes accelerate xformers==0.0.29.post3 peft "trl==0.15.2" triton cut_cross_entropy unsloth_zoo
    !pip install sentencepiece protobuf datasets huggingface_hub hf_transfer

    # vLLM requirements - vLLM breaks Colab due to reinstalling numpy
    f = requests.get("https://raw.githubusercontent.com/vllm-project/vllm/refs/heads/main/requirements/common.txt").content
    with open("vllm_requirements.txt", "wb") as file:
        file.write(re.sub(rb"(transformers|numpy|xformers)[^\n]{1,}\n", b"", f))
    !pip install -r vllm_requirements.txt

!pip install huggingface_hub
!pip install -U langchain langchain-experimental transformers accelerate
!pip install -U langgraph
!pip install langchain.core
! pip install langchain-huggingface
! pip install torch
! pip install selenium
! pip install firecrawl-py
! pip install langchain_community
! pip install openai
! pip install langchain_openai
! pip install trl


# Gemma3 Model

In [3]:
from unsloth import FastLanguageModel
import torch
max_seq_length = 4096 # Can increase for longer reasoning traces
lora_rank = 64 # Larger rank = smarter, but slower

model, tokenizer = FastLanguageModel.from_pretrained(
    model_name = "unsloth/gemma-3-1b-it",
    max_seq_length = max_seq_length,
    load_in_4bit = False, # False for LoRA 16bit
    fast_inference = True, # Enable vLLM fast inference
    max_lora_rank = lora_rank,
    gpu_memory_utilization = 0.8, # Reduce if out of memory
)

model = FastLanguageModel.get_peft_model(
    model,
    r = lora_rank, # Choose any number > 0 ! Suggested 8, 16, 32, 64, 128
    target_modules = [
        "q_proj", "k_proj", "v_proj", "o_proj",
        "gate_proj", "up_proj", "down_proj",
    ], # Remove QKVO if out of memory
    lora_alpha = lora_rank,
    use_gradient_checkpointing = "unsloth", # Enable long context finetuning
    random_state = 3407,
)

🦥 Unsloth: Will patch your computer to enable 2x faster free finetuning.
🦥 Unsloth Zoo will now patch everything to make training faster!
INFO 05-05 20:13:07 __init__.py:207] Automatically detected platform cuda.
==((====))==  Unsloth 2025.4.7: Fast Gemma3 patching. Transformers: 4.51.3. vLLM: 0.7.3.
   \\   /|    NVIDIA L4. Num GPUs = 1. Max memory: 22.161 GB. Platform: Linux.
O^O/ \_/ \    Torch: 2.6.0+cu124. CUDA: 8.9. CUDA Toolkit: 12.4. Triton: 3.2.0
\        /    Bfloat16 = TRUE. FA [Xformers = 0.0.29.post3. FA2 = False]
 "-____-"     Free license: http://github.com/unslothai/unsloth
Unsloth: Fast downloading is enabled - ignore downloading bars which are red colored!
Unsloth: QLoRA and full finetuning all not selected. Switching to 16bit LoRA.


model.safetensors:   0%|          | 0.00/2.00G [00:00<?, ?B/s]

generation_config.json:   0%|          | 0.00/215 [00:00<?, ?B/s]

tokenizer_config.json:   0%|          | 0.00/1.16M [00:00<?, ?B/s]

tokenizer.model:   0%|          | 0.00/4.69M [00:00<?, ?B/s]

tokenizer.json:   0%|          | 0.00/33.4M [00:00<?, ?B/s]

added_tokens.json:   0%|          | 0.00/35.0 [00:00<?, ?B/s]

special_tokens_map.json:   0%|          | 0.00/662 [00:00<?, ?B/s]

Unsloth: Making `model.base_model.model.model` require gradients


# Load dataset which was generated by OpenAI model

In [7]:
import json
from datasets import load_dataset
import re
# Load dataset
try:
    dataset = load_dataset("json", data_files="products.jsonl")["train"]
except Exception as e:
    raise ValueError(f"Failed to load dataset: {e}")

# System prompt
system_prompt = """You are an expert data extractor.
Your task is to extract products and their prices from text.

Rules:
- If any keyword like "organic", "bio", or "eco" is found in the product description, set "is_organic" to "True", otherwise set it to "False".
- Extract all product mentions separately, even if multiple products are in one sentence.
- Standardize the "product" name by removing adjectives like "organic", "premium", "fresh", "natural", etc., and retain only the base product (e.g., "organic carrots" → "carrots").
- Extract the "price" exactly as shown in the text, including any currency symbol (e.g., "$2.99").
- Extract "quantity" as written (e.g., "1kg", "6-pack", "2L").
- Strictly avoid unnecessary explanations or text; only return structured data.

Ignore:
- Any promotional text, ads, filters, or non-product information.

Output format:
Return only a valid JSON array, like this:
[
  {
    "product": "product_name",
    "price": "price_value",
    "quantity": "quantity_value",
    "is_organic": "is_organic_value"
  }
]
"""

def clean_input_content(input_content):
  cleaned_text = re.sub(r'https?://\S+|\[.*?\]\((https?://\S+)\)', '', input_content)
  return cleaned_text.strip()

# Format prompt as a list of chat messages
def build_prompt(example):
    return [
        {"role": "system", "content": system_prompt.strip()},
        {"role": "user", "content": clean_input_content(example.get("input", ""))}

    ]

# Format output cleanly as JSON string if possible
def format_output(example):
    output_data = example.get("output")
    if isinstance(output_data, list):
        try:
            return json.dumps(output_data, indent=2)
        except (TypeError, ValueError):
            pass
    return str(output_data)

# Apply transformation
dataset_with_prompts = dataset.map(
    lambda x: {
        "prompt": build_prompt(x),
        "answer": format_output(x),
    },
    remove_columns=dataset.column_names  # This removes all other fields
)



In [8]:
len(dataset_with_prompts)

2818

#Reward function



In [9]:
import json
from collections import Counter

def normalize_ground_truth(answer):
    ground_truth = []

    if isinstance(answer, dict):
        ground_truth = [answer]
    elif isinstance(answer, list):
        for item in answer:
            if isinstance(item, dict):
                ground_truth.append(item)
            elif isinstance(item, str):
                try:
                    parsed = json.loads(item)
                    if isinstance(parsed, list):
                        ground_truth.extend(parsed)
                    elif isinstance(parsed, dict):
                        ground_truth.append(parsed)
                except Exception as e:
                    print(f"Failed to parse item: {item[:100]}... Error: {e}")
    elif isinstance(answer, str):
        try:
            parsed = json.loads(answer)
            if isinstance(parsed, list):
                ground_truth = parsed
            elif isinstance(parsed, dict):
                ground_truth = [parsed]
        except Exception as e:
            print(f"Failed to parse string answer: {e}")

    return ground_truth

def parse_prediction(response):
    if isinstance(response, (list, dict)):
        return response
    if isinstance(response, str):
        try:
            return json.loads(response.strip())
        except json.JSONDecodeError:
            return None
    return None


def item_match_score(gt_item, pred_item):
    score = 0.0
    if pred_item.get("product", "").strip().lower() == gt_item.get("product", "").strip().lower():
        score += 2.0
    if str(pred_item.get("price", "")).strip() == gt_item.get("price", "").strip():
        score += 2.0
    if str(pred_item.get("quantity", "")).strip() == gt_item.get("quantity", "").strip():
        score += 2.0
    if str(pred_item.get("is_organic", "")).lower() == str(gt_item.get("is_organic", "")).lower():
        score += 1.0
    return score

def structured_data_reward(prompts, completions, answer, **kwargs):
    # Handle ground_truth
    ground_truth = normalize_ground_truth(answer)
    if not ground_truth:
        return [0.0 for _ in completions]

    rewards = []
    for completion in completions:
        response = completion[0]["content"] if completion and isinstance(completion, list) and "content" in completion[0] else ""
        parsed = parse_prediction(response)
        if parsed is None:
            rewards.append(0.0)
            continue

        predicted_output = parsed
        if isinstance(predicted_output, dict):
            predicted_output = [predicted_output]
        elif not isinstance(predicted_output, list) or not all(isinstance(item, dict) for item in predicted_output):
            rewards.append(0.0)
            continue

        # Compute all pairwise match scores to ignore order
        total_score = 0.0
        used_pred_indices = set()
        for gt_item in ground_truth:
            best_score = 0.0
            best_pred_index = None
            for idx, pred_item in enumerate(predicted_output):
                if idx in used_pred_indices:
                    continue
                score = item_match_score(gt_item, pred_item)
                if score > best_score:
                    best_score = score
                    best_pred_index = idx
            if best_pred_index is not None:
                used_pred_indices.add(best_pred_index)
                total_score += best_score

        # Calculate reward
        max_score = len(ground_truth) * 4  # 4 fields per item
        length_penalty = min(len(predicted_output), len(ground_truth)) / max(len(predicted_output), len(ground_truth), 1)
        reward = (total_score / max_score) * length_penalty
        rewards.append(max(0.0, min(1.0, reward)))

    return rewards

In [10]:
max_prompt_length = 287 + 1 # + 1 just in case!
max_seq_length = 1024
from trl import GRPOConfig, GRPOTrainer
training_args = GRPOConfig(
    learning_rate = 5e-6,
    weight_decay = 0.1,
    warmup_ratio = 0.1,
    lr_scheduler_type = "cosine",
    optim = "adamw_torch_fused",
    logging_steps = 1,
    gradient_accumulation_steps = 4, # Increase to 4 for smoother training
    num_generations = 8, # Decrease if out of memory
    max_prompt_length = max_prompt_length,
    max_completion_length = max_seq_length - max_prompt_length,
    # num_train_epochs = 1, # Set to 1 for a full training run
    max_steps = 1000,
    save_steps = 250,
    max_grad_norm = 0.1,
    num_train_epochs=1,
    report_to = "none", # Can use Weights & Biases
    output_dir = "outputs",
)


Unsloth: We now expect `per_device_train_batch_size` to be a multiple of `num_generations`.
We will change the batch size of 4 to the `num_generations` of 8


In [None]:
trainer = GRPOTrainer(
    model = model,
    processing_class = tokenizer,
    reward_funcs = [structured_data_reward],
    args = training_args,
    train_dataset = dataset_with_prompts,
)
trainer.train()

==((====))==  Unsloth - 2x faster free finetuning | Num GPUs used = 1
   \\   /|    Num examples = 2,818 | Num Epochs = 2 | Total steps = 1,000
O^O/ \_/ \    Batch size per device = 8 | Gradient accumulation steps = 4
\        /    Data Parallel GPUs = 1 | Total batch size (8 x 4 x 1) = 32
 "-____-"     Trainable parameters = 52,183,040/1,052,068,992 (4.96% trained)
`generation_config` default values have been modified to match model-specific defaults: {'max_length': 32768, 'top_k': 64, 'top_p': 0.95, 'bos_token_id': 2, 'eos_token_id': [1, 106]}. If this is not desired, please set these values explicitly.
`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`.


Unsloth: Will smartly offload gradients to save VRAM!


Step,Training Loss,reward,reward_std,completion_length,kl,rewards / structured_data_reward
1,0.0,0.0,0.0,156.9375,0.0,0.0
2,0.0,0.0,0.0,43.1875,0.0,0.0
3,0.0,7.8e-05,0.000206,60.15625,0.0,7.8e-05
4,0.0004,0.0,0.0,88.9375,0.010842,0.0
5,0.0001,8e-06,2.2e-05,114.15625,0.002155,8e-06
6,0.0001,0.0,0.0,98.15625,0.002282,0.0
7,0.0002,0.0,0.0,92.375,0.003854,0.0
8,0.0005,1e-05,2.8e-05,124.90625,0.012027,1e-05
9,0.0002,0.0,0.0,99.46875,0.005429,0.0
