In [None]:
# Install packages
%%capture
import os
if "COLAB_" not in "".join(os.environ.keys()):
    !pip install unsloth vllm==0.7.3
else:
    # [NOTE] Do the below ONLY in Colab! Use [[pip install unsloth vllm]]
    !pip install --no-deps unsloth vllm==0.7.3

In [None]:
#@title Colab Extra Install { display-mode: "form" }
%%capture
import os
if "COLAB_" not in "".join(os.environ.keys()):
    !pip install unsloth vllm
else:
    !pip install --no-deps unsloth vllm
    # [NOTE] Do the below ONLY in Colab! Use [[pip install unsloth vllm]]
    # Skip restarting message in Colab
    import sys, re, requests; modules = list(sys.modules.keys())
    for x in modules: sys.modules.pop(x) if "PIL" in x or "google" in x else None
    !pip install --no-deps bitsandbytes accelerate xformers==0.0.29.post3 peft "trl==0.15.2" triton cut_cross_entropy unsloth_zoo
    !pip install sentencepiece protobuf datasets huggingface_hub hf_transfer

    # vLLM requirements - vLLM breaks Colab due to reinstalling numpy
    f = requests.get("https://raw.githubusercontent.com/vllm-project/vllm/refs/heads/main/requirements/common.txt").content
    with open("vllm_requirements.txt", "wb") as file:
        file.write(re.sub(rb"(transformers|numpy|xformers)[^\n]{1,}\n", b"", f))
    !pip install -r vllm_requirements.txt

!pip install huggingface_hub
!pip install -U langchain langchain-experimental transformers accelerate
!pip install -U langgraph
!pip install langchain.core
! pip install langchain-huggingface
! pip install torch
! pip install selenium
! pip install firecrawl-py
! pip install langchain_community
! pip install openai
! pip install langchain_openai
! pip install trl
! pip install rapidfuzz
! pip install sentence-transformers


Load Gemma 3 Model

In [None]:
# Load Gemma 3 model (1B instruction-tuned) through Unsloth with vLLM generation.
from unsloth import FastLanguageModel
import torch
max_seq_length = 4096
# Single LoRA rank shared by vLLM (`max_lora_rank`) and the PEFT adapters (`r`).
# `max_lora_rank` must be >= `r`. Keeping them equal avoids provisioning vLLM's
# LoRA support for rank 64 when the adapters are only rank 16.
lora_rank = 16

model, tokenizer = FastLanguageModel.from_pretrained(
    model_name = "unsloth/gemma-3-1b-it",
    max_seq_length = max_seq_length,
    load_in_4bit = False,
    fast_inference = True, # Enable vLLM fast inference
    max_lora_rank = lora_rank,
    gpu_memory_utilization = 0.8, # Reduce if out of memory
)

model = FastLanguageModel.get_peft_model(
    model,
    r=lora_rank,  # Moderate rank for balanced capacity
    target_modules=[
        "q_proj", "v_proj", "up_proj", "down_proj"  # Reduced set for memory efficiency
    ],
    lora_alpha=lora_rank * 2,  # 2 * r for amplified updates
    lora_dropout=0.1,  # Add dropout for regularization
    bias="none",  # Default, no bias adaptation
    use_gradient_checkpointing="unsloth",  # Keep for long contexts
    random_state=3407,  # Keep for reproducibility
)

Load the dataset generated by an OpenAI model

In [None]:
import json
from datasets import load_dataset
import re
# Load the training examples (one JSON object per line) from products.jsonl.
# With a single data file, load_dataset puts every row in the "train" split.
try:
    dataset = load_dataset("json", data_files="products.jsonl")["train"]
except Exception as e:
    # Chain the original error so the real cause (missing file, malformed line,
    # ...) stays visible in the traceback.
    raise ValueError(f"Failed to load dataset: {e}") from e

# System prompt shared by every training example.
# Raw string: the text contains the regex hint "£\d+\.\d{2}". In a normal string
# literal "\d" and "\." are invalid escape sequences (SyntaxWarning on Python
# 3.12+). The resulting string value is unchanged.
system_prompt = r"""You are an expert data extractor.
Your task is to extract products and their prices from text.

Rules:
- If any keyword like "organic", "bio", or "eco" is found in the product description, set "is_organic" to "True", otherwise set it to "False".
- Extract all product mentions separately, even if multiple products are in one sentence.
- Standardize the "product" name by removing only adjectives like "organic", "premium", "fresh", "natural", etc., but retain brand names, specific descriptors (e.g., "Whole"), and the base product.
- Extract the "price" as a monetary value (e.g., "£2.99"), excluding quantities or other text. Use regex like £\d+\.\d{2} if needed.
- Extract "quantity" from the product’s packaging description (e.g., "400g" for sausages sold in 400g packs).
- Strictly avoid unnecessary explanations or text; only return structured data.

Ignore:
- Any promotional text, ads, filters, or non-product information.

Output format:
Return only a valid JSON array, like this:
[
  {
    "product": "product_name",
    "price": "price_value",
    "quantity": "quantity_value",
    "is_organic": "is_organic_value"
  }
]
"""

def clean_input_content(input_content):
    """Strip URLs and markdown links from raw page text before prompting.

    Removes bare http(s) URLs and markdown links of the form [text](http...),
    then trims surrounding whitespace. None is treated as empty text.

    The link text must not contain a closing bracket or a newline. A lazy
    "match anything" there could start at an unrelated "[" (e.g. "[Offer]") and
    run forward to a later "](http", deleting everything in between, including
    prices.
    """
    if input_content is None:
        return ""
    cleaned_text = re.sub(r'https?://\S+|\[[^\]\n]*\]\(https?://\S+\)', '', input_content)
    return cleaned_text.strip()

# Build the chat-format prompt: shared system instructions, then the cleaned page text.
def build_prompt(example):
    """Return the [system, user] message pair for one dataset example.

    The system turn carries the whitespace-trimmed system prompt. The user turn
    carries the example's "input" field (default "") after URL/link removal.
    """
    system_message = {"role": "system", "content": system_prompt.strip()}
    user_text = clean_input_content(example.get("input", ""))
    user_message = {"role": "user", "content": user_text}
    return [system_message, user_message]

# Format output cleanly as JSON string if possible
def format_output(example):
    """Serialize an example's "output" field into the reference-answer string.

    Lists and dicts are dumped as indented JSON so the reward code can
    json.loads them back. A dict must be dumped too: str() of a dict is a Python
    repr with single quotes, which json.loads rejects. Values that fail to
    serialize, and every other type (e.g. an already-serialized string, or None),
    are returned via str().
    """
    output_data = example.get("output")
    if isinstance(output_data, (list, dict)):
        try:
            return json.dumps(output_data, indent=2)
        except (TypeError, ValueError):
            pass
    return str(output_data)

# Convert every raw example into the two columns used downstream:
# "prompt" (chat messages for the trainer) and "answer" (reference JSON string
# consumed by the reward function).
def _to_grpo_record(example):
    return {"prompt": build_prompt(example), "answer": format_output(example)}

dataset_with_prompts = dataset.map(
    _to_grpo_record,
    remove_columns=dataset.column_names,  # drop every original raw field
)



Reward Function

In [None]:
import json
import re
from sentence_transformers import SentenceTransformer, util

# Sentence-embedding checkpoint used by item_match_score to compare product
# names via cosine similarity; loaded once when this cell runs.
_SIMILARITY_MODEL_NAME = 'all-MiniLM-L6-v2'
similarity_model = SentenceTransformer(_SIMILARITY_MODEL_NAME)

# Multiplicative factors that bring a quantity to a base amount: grams for
# mass, millilitres for volume, and a plain count for pack-style units.
# Note that grams and millilitres share the same scale (factor 1).
_UNIT_ALIAS_GROUPS = (
    (('g', 'gram', 'grams'), 1),
    (('kg', 'kilogram', 'kilograms'), 1000),
    (('mg', 'milligram', 'milligrams'), 0.001),
    (('l', 'liter', 'liters', 'litre', 'litres'), 1000),
    (('ml', 'milliliter', 'milliliters', 'millilitre', 'millilitres'), 1),
    (('pack', 'packs', 'each', 'unit', 'units'), 1),
)
UNIT_CONVERSIONS = {
    alias: factor
    for aliases, factor in _UNIT_ALIAS_GROUPS
    for alias in aliases
}

def normalize_price(price_str):
    """Canonicalize a price string so that equal amounts compare equal.

    Strips currency symbols, thousands separators and every other character
    that is not a digit or '.'. If the remainder is a valid decimal number,
    insignificant trailing zeros are removed, so "£2.90", "2.9" and "2.90 GBP"
    all become "2.9", and "£1,299.00" becomes "1299".

    Args:
        price_str: raw price text such as "£2.99". Falsy values yield "".

    Returns:
        The canonical numeric string. If the digits-and-dots residue cannot be
        parsed as a number (e.g. "1.2.3"), that residue is returned unchanged.
    """
    from decimal import Decimal, InvalidOperation

    if not price_str:
        return ""
    digits = re.sub(r'[^\d.]', '', str(price_str).strip())
    try:
        amount = Decimal(digits)
    except InvalidOperation:
        # Empty or malformed residue: fall back to the plain digit string.
        return digits
    # normalize() drops trailing zeros; the 'f' format avoids exponent notation ("2E+2").
    return format(amount.normalize(), 'f')

def parse_quantity(quantity_str):
    """Split a quantity such as "400g" or "1.5 kg" into (number, unit).

    Returns:
        (None, None) for empty input.
        (None, stripped_text) when the text is not a single number optionally
        followed by a unit word.
        (float, lower-case unit or None) otherwise.
    """
    if not quantity_str:
        return None, None
    text = quantity_str.strip()
    found = re.match(r'^\s*(\d*\.?\d+)\s*([a-zA-Z\-]+)?\s*$', text, re.IGNORECASE)
    if found is None:
        return None, text
    amount_text, unit_text = found.groups()
    unit = unit_text.lower() if unit_text else None
    return float(amount_text), unit

# Physical dimension of each unit spelled in UNIT_CONVERSIONS. The conversion
# factors put grams and millilitres on the same scale (and count units at 1), so
# comparing converted values alone would score "1kg" == "1l" or "1 pack" == "1g".
_QUANTITY_UNIT_DIMENSIONS = {
    **dict.fromkeys(('g', 'gram', 'grams', 'kg', 'kilogram', 'kilograms',
                     'mg', 'milligram', 'milligrams'), 'mass'),
    **dict.fromkeys(('l', 'liter', 'liters', 'litre', 'litres',
                     'ml', 'milliliter', 'milliliters', 'millilitre', 'millilitres'), 'volume'),
    **dict.fromkeys(('pack', 'packs', 'each', 'unit', 'units'), 'count'),
}


def _units_compatible(unit_a, unit_b):
    """Return False only when both units have known and different dimensions."""
    dimension_a = _QUANTITY_UNIT_DIMENSIONS.get(unit_a)
    dimension_b = _QUANTITY_UNIT_DIMENSIONS.get(unit_b)
    return dimension_a is None or dimension_b is None or dimension_a == dimension_b


def compare_quantities(pred_quantity, gt_quantity):
    """Score agreement between a predicted and a ground-truth quantity string.

    Returns 2.0 for a match and 0.0 otherwise:
    - Both have a number and a unit: match when both units are in
      UNIT_CONVERSIONS, measure the same dimension (mass/volume/count), and the
      converted values agree within 1%. Otherwise fall back to case-insensitive
      exact text equality.
    - Both are bare numbers: match when they agree within 1%.
    - Neither parses as a number (including both empty): exact text equality.
    - Any other combination (e.g. "400" vs "400g"): no match.
    """
    pred_number, pred_unit = parse_quantity(pred_quantity)
    gt_number, gt_unit = parse_quantity(gt_quantity)
    # Case-insensitive exact-text fallback; `or ""` tolerates None inputs.
    same_text = (pred_quantity or "").strip().lower() == (gt_quantity or "").strip().lower()

    if pred_number is not None and gt_number is not None and pred_unit and gt_unit:
        if (pred_unit in UNIT_CONVERSIONS and gt_unit in UNIT_CONVERSIONS
                and _units_compatible(pred_unit, gt_unit)):
            pred_value = pred_number * UNIT_CONVERSIONS[pred_unit]
            gt_value = gt_number * UNIT_CONVERSIONS[gt_unit]
            if abs(pred_value - gt_value) / max(gt_value, 1e-6) <= 0.01:
                return 2.0
        return 2.0 * same_text
    if pred_number is not None and gt_number is not None and not pred_unit and not gt_unit:
        if abs(pred_number - gt_number) / max(gt_number, 1e-6) <= 0.01:
            return 2.0
    if pred_number is None and gt_number is None:
        return 2.0 * same_text
    return 0.0

def item_match_score(gt_item, pred_item):
    """Score how well one predicted item matches one ground-truth item (up to 7.0).

    Components:
    - product: 2.0 * cosine similarity of the lower-cased names from
      similarity_model (negative similarity clipped to 0). Skipped when either
      name is empty or missing.
    - price: 2.0 when normalize_price gives equal strings.
    - quantity: 0.0 or 2.0 from compare_quantities.
    - is_organic: 1.0 when the lower-cased string forms are equal.
    """
    def _product_name(item):
        # Model JSON may contain "product": null or a non-string value; coerce so
        # .strip()/.lower() cannot raise inside the reward function.
        value = item.get("product")
        return ("" if value is None else str(value)).strip().lower()

    score = 0.0
    pred_product = _product_name(pred_item)
    gt_product = _product_name(gt_item)
    if pred_product and gt_product:
        embeddings = similarity_model.encode([pred_product, gt_product], convert_to_tensor=True)
        similarity = util.cos_sim(embeddings[0], embeddings[1]).item()
        score += 2.0 * max(0.0, similarity)
    if normalize_price(str(pred_item.get("price", ""))) == normalize_price(str(gt_item.get("price", ""))):
        score += 2.0
    score += compare_quantities(
        str(pred_item.get("quantity", "")).strip(),
        str(gt_item.get("quantity", "")).strip()
    )
    if str(pred_item.get("is_organic", "")).lower() == str(gt_item.get("is_organic", "")).lower():
        score += 1.0
    return score

def normalize_ground_truth(answer):
    """Coerce a reference answer into a list of item dicts.

    Accepts a dict, a JSON string (array or object), or a list whose elements
    are dicts or JSON strings (each an array or object). Unparseable strings are
    reported with print() and skipped. Non-dict values are dropped, because
    item_match_score reads every item with .get(). Unsupported inputs such as
    None yield [].
    """
    def _dict_items(value):
        # Flatten one parsed JSON value into its dict items.
        if isinstance(value, dict):
            return [value]
        if isinstance(value, list):
            return [item for item in value if isinstance(item, dict)]
        return []

    if isinstance(answer, dict):
        return [answer]
    if isinstance(answer, list):
        ground_truth = []
        for item in answer:
            if isinstance(item, dict):
                ground_truth.append(item)
            elif isinstance(item, str):
                try:
                    parsed = json.loads(item)
                except Exception as e:
                    print(f"Failed to parse item: {item[:100]}... Error: {e}")
                    continue
                ground_truth.extend(_dict_items(parsed))
        return ground_truth
    if isinstance(answer, str):
        try:
            parsed = json.loads(answer)
        except Exception as e:
            print(f"Failed to parse string answer: {e}")
            return []
        return _dict_items(parsed)
    return []

def parse_prediction(response):
    """Parse a model completion into the JSON value it contains.

    Args:
        response: raw completion text, or an already-parsed list/dict
            (returned unchanged).

    Returns:
        For text, the parsed list on success. On a JSON decode error during
        repair, three outcomes are possible:
        - the prefix up to the last "}" before the error, closed with "]", if
          that parses;
        - [] if no "}" precedes the error;
        - otherwise None.
        None is also returned when nothing array-shaped remains after cleanup,
        and for unsupported input types.

    Raises:
        ImportError: if the module-level json name has been clobbered.

    Text goes through two stages.
    - Stage 1 strips an optional ``` / ```json fence and calls json.loads
      directly, returning the result if it is a list.
    - Stage 2 is only reached when stage 1 fails or yields a non-list. It applies
      a chain of regex repairs for malformed output and then requires the text
      to look like a single JSON array.
    Stage 1 exists because the repair chain is lossy on well-formed input. One
    substitution blanks the last value of every pretty-printed object (the value
    directly before a newline and a closing brace). The ASCII filter also drops
    characters such as the pound sign.
    """
    if json is None or not hasattr(json, 'loads'):
        raise ImportError("json module is not properly imported or has been overwritten")
    if isinstance(response, (list, dict)):
        return response
    if isinstance(response, str):
        # Stage 1: well-formed JSON (optionally fenced) is parsed untouched.
        strict_candidate = re.sub(r'(?i)^\s*```(?:json)?\s*', '', response)
        strict_candidate = re.sub(r'\s*```\s*$', '', strict_candidate).strip()
        try:
            strict_parsed = json.loads(strict_candidate)
        except ValueError:
            strict_parsed = None
        if isinstance(strict_parsed, list):
            return strict_parsed

        # Stage 2: heuristic repair of malformed output.
        try:
            # Drop a leading ```json fence and a trailing ``` fence.
            cleaned_response = re.sub(r'(?i)^\s*```?\s*json\s*', '', response)
            cleaned_response = re.sub(r'\s*```?\s*$', '', cleaned_response)
            # Insert missing commas between adjacent objects: "}{" -> "},{".
            cleaned_response = re.sub(r'}\s*{', '},{', cleaned_response)
            # Presumably meant to repair an unterminated string before a closing
            # ]/} on the next line. It also blanks well-formed trailing values,
            # which is why valid JSON is handled by stage 1.
            cleaned_response = re.sub(r'"\s*[^"]*?("[^"]*?)?\s*(\n\s*[\]\}])', r'""\2', cleaned_response, flags=re.DOTALL)
            # Wrap bare "key": "value" runs that appear directly inside [...] into objects.
            if cleaned_response.startswith('[') and not cleaned_response.startswith('[{'):
                cleaned_response = re.sub(
                    r'^\[\s*((?:"[^"]*"\s*:\s*"[^"]*"\s*,?\s*)+)',
                    r'[{\1}]',
                    cleaned_response,
                    flags=re.DOTALL
                )
                cleaned_response = re.sub(
                    r'],\s*((?:"[^"]*"\s*:\s*"[^"]*"\s*,?\s*)+)',
                    r'},{\1}]',
                    cleaned_response,
                    flags=re.DOTALL
                )
                cleaned_response = re.sub(r',\s*}', '}', cleaned_response)
            # Collapse repeated trailing closers ("]]" -> "]", "}}" -> "}").
            cleaned_response = re.sub(r'\]+$', ']', cleaned_response)
            cleaned_response = re.sub(r'\}+$', '}', cleaned_response)
            # Drop non-ASCII characters (this also removes currency symbols).
            cleaned_response = cleaned_response.encode('ascii', errors='ignore').decode('ascii')
            cleaned_response = re.sub(r'}\s*\n\s*\{', '},\n{', cleaned_response)
            cleaned_response = cleaned_response.strip()
            if not cleaned_response:
                print("Cleaned response is empty")
                return None
            json_match = re.match(r'\s*\[.*?\]\s*$', cleaned_response, re.DOTALL)
            if not json_match:
                print("No valid JSON array found in cleaned_response")
                return None
            parsed = json.loads(cleaned_response)
            return parsed
        except json.JSONDecodeError as e:
            print(f"JSONDecodeError: Failed to parse cleaned_response: {repr(cleaned_response)[:200]}... Error: {e}")
            char_pos = e.pos if hasattr(e, 'pos') else None
            if char_pos:
                start = max(0, char_pos - 50)
                end = min(len(cleaned_response), char_pos + 50)
                print(f"Context around error (char {char_pos}): {repr(cleaned_response[start:end])}")
            print(f"Full cleaned_response: {repr(cleaned_response)}")
            try:
                if char_pos:
                    # Keep everything up to the last complete object and close the array.
                    last_valid = cleaned_response[:char_pos].rfind('}')
                    if last_valid != -1:
                        fallback_response = cleaned_response[:last_valid + 1] + ']'
                        print(f"Fallback response: {repr(fallback_response)[:200]}...")
                        parsed = json.loads(fallback_response)
                        return parsed
                    fallback_response = '[]'
                    print(f"Fallback response (empty array): {repr(fallback_response)}")
                    return []
            except json.JSONDecodeError as e2:
                print(f"Fallback JSONDecodeError: {repr(fallback_response)[:200]}... Error: {e2}")
            return None
        except Exception as e:
            print(f"Unexpected error parsing response: {repr(response)[:200]}... Error: {e}")
            return None
    print(f"Invalid input type for parse_prediction: {type(response)}")
    return None

def _completion_text(completion):
    """Extract generated text from one completion.

    Conversational completions are lists of message dicts, with the text in the
    first message's "content". Plain-text completions are returned as-is.
    Anything else yields "".
    """
    if isinstance(completion, str):
        return completion
    if completion and isinstance(completion, list) and isinstance(completion[0], dict) and "content" in completion[0]:
        return completion[0]["content"]
    return ""


def _score_predicted_items(predicted_output, ground_truth):
    """Greedily pair each reference item with its best unused prediction; return a score in [0, 1].

    The summed item scores are divided by 7 per reference item (the maximum of
    item_match_score: product 2 + price 2 + quantity 2 + is_organic 1). The
    result is multiplied by min(n_pred, n_ref) / max(n_pred, n_ref) to penalize
    missing or extra items.
    """
    total_score = 0.0
    used_pred_indices = set()
    for gt_item in ground_truth:
        best_score = 0.0
        best_pred_index = None
        for idx, pred_item in enumerate(predicted_output):
            if idx in used_pred_indices:
                continue
            score = item_match_score(gt_item, pred_item)
            if score > best_score:
                best_score = score
                best_pred_index = idx
        if best_pred_index is not None:
            used_pred_indices.add(best_pred_index)
            total_score += best_score

    max_score = len(ground_truth) * 7
    length_penalty = min(len(predicted_output), len(ground_truth)) / max(len(predicted_output), len(ground_truth), 1)
    reward = (total_score / max_score) * length_penalty
    return max(0.0, min(1.0, reward))


def structured_data_reward(prompts, completions, answer, **kwargs):
    """GRPO reward: score each completion's extracted JSON against its own reference.

    Args:
        prompts: prompts passed by the trainer (unused).
        completions: one entry per generated sample. Either a list of chat
            messages whose first element has "content", or a plain string.
        answer: reference data. GRPOTrainer forwards dataset columns as lists
            aligned with `completions`. A list/tuple of str (or list) entries
            with the same length is therefore treated as one reference per
            completion. Anything else is treated as one reference shared by all
            completions.
        **kwargs: other dataset columns forwarded by the trainer (ignored).

    Returns:
        One float in [0, 1] per completion. Each completion keeps its own score:
        GRPO derives advantages from reward differences within a group, so
        averaging rewards across completions would zero the learning signal.
    """
    per_completion = (
        isinstance(answer, (list, tuple))
        and len(answer) == len(completions)
        and all(isinstance(ref, (str, list, tuple)) for ref in answer)
    )
    references = list(answer) if per_completion else [answer] * len(completions)

    # Each prompt's reference repeats once per generation; parse each JSON string once.
    parsed_references = {}
    rewards = []
    for completion, reference in zip(completions, references):
        if isinstance(reference, str):
            if reference not in parsed_references:
                parsed_references[reference] = normalize_ground_truth(reference)
            ground_truth = parsed_references[reference]
        else:
            ground_truth = normalize_ground_truth(reference)
        if not ground_truth:
            rewards.append(0.0)
            continue

        predicted_output = parse_prediction(_completion_text(completion))
        if predicted_output is None:
            rewards.append(0.0)
            continue
        if isinstance(predicted_output, dict):
            predicted_output = [predicted_output]
        elif not isinstance(predicted_output, list) or not all(isinstance(item, dict) for item in predicted_output):
            rewards.append(0.0)
            continue

        rewards.append(_score_predicted_items(predicted_output, ground_truth))
    return rewards

GRPO Config

In [None]:
from trl import GRPOConfig, GRPOTrainer

# Token budgets for GRPO: prompts are capped at max_prompt_length, and each
# sampled completion gets the rest of a 2048-token budget. A separate name is
# used so the max_seq_length chosen when loading the model is not overwritten.
max_prompt_length = 588
grpo_max_seq_length = 2048
assert grpo_max_seq_length <= max_seq_length, "GRPO token budget exceeds the model's max_seq_length"

training_args = GRPOConfig(
    learning_rate = 5e-6,
    weight_decay = 0.01,
    warmup_ratio = 0.2,
    lr_scheduler_type = "cosine",
    optim = "adamw_torch_fused",
    logging_steps = 1,
    gradient_accumulation_steps = 4, # Increase to 4 for smoother training
    num_generations = 4, # Decrease if out of memory
    max_prompt_length = max_prompt_length,
    max_completion_length = grpo_max_seq_length - max_prompt_length,
    max_steps = 1000,
    num_train_epochs = 1, # Ignored while max_steps > 0 (max_steps takes precedence)
    save_steps = 250,
    max_grad_norm = 1.0,
    report_to = "none", # Can use Weights & Biases
    output_dir = "outputs",
)

Train Model

In [None]:
# Build the GRPO trainer and run training. Every sampled completion is scored by
# structured_data_reward; TRL forwards the dataset's extra "answer" column to
# reward functions as a keyword argument.
grpo_trainer_kwargs = dict(
    model=model,
    processing_class=tokenizer,
    reward_funcs=[structured_data_reward],
    args=training_args,
    train_dataset=dataset_with_prompts,
)
trainer = GRPOTrainer(**grpo_trainer_kwargs)
trainer.train()

Save Model

In [None]:
import os

# Root output folder. LoRA adapters and the merged 16-bit model go into separate
# subfolders. Saving both into one folder leaves adapter files
# (adapter_config.json) next to full merged weights, which can make loaders
# treat the folder as an adapter checkpoint instead of the merged model.
save_directory = "./finetuned_model"
lora_directory = os.path.join(save_directory, "lora_adapters")
merged_directory = os.path.join(save_directory, "merged_16bit")

# Save LoRA adapters (recommended for efficiency) and the tokenizer.
model.save_pretrained(lora_directory)
tokenizer.save_pretrained(lora_directory)

# Optional: save the merged model (base model + LoRA adapters) in 16-bit precision.
model.save_pretrained_merged(merged_directory, tokenizer, save_method="merged_16bit")

# Trainer artifacts: save_model writes the (adapter) model next to the adapters.
# save_state writes trainer_state.json to training_args.output_dir ("outputs").
trainer.save_model(lora_directory)
trainer.save_state()