### Import Libraries

In [None]:
import os
import sys
sys.path.append(os.path.dirname(os.getcwd()))

from dataset import *
from PIL import Image
import json
import time

In [None]:
from vlmgrpo import VLMGRPOTrainer # YOU MUST IMPORT vlmgrpo before unsloth
from trl import GRPOConfig
from unsloth import FastVisionModel
from unsloth import is_bf16_supported

In [None]:
# System prompt placed first in every conversation built by format_data().
system_message = (
    "You are a highly advanced Vision Language Model (VLM), specialized in extracting visual data.\n"
    "Your task is to process and extract meaningful insights from images that are asked in the prompt."
)

In [None]:
# The model, its processor (`tokenizer`) and the LoRA adapters are created once,
# in the "Training Pipeline" section. That section loaded the exact same
# checkpoint with the exact same LoRA settings and overwrote `model`/`tokenizer`
# before they were ever used, so loading them here as well only doubled the
# load time and could temporarily hold two copies of the model in GPU memory.

### Training Pipeline

In [None]:
# Load the 4-bit quantised Qwen2.5-VL-3B-Instruct checkpoint together with its
# processor (bound to `tokenizer`; it is passed to the trainer as the
# processing class).
model, tokenizer = FastVisionModel.from_pretrained(
    "unsloth/Qwen2.5-VL-3B-Instruct-bnb-4bit",
    load_in_4bit = True, # Use 4bit to reduce memory use. False for 16bit LoRA.
    use_gradient_checkpointing = "unsloth", # True or "unsloth" for long context
)

# Attach LoRA adapters to the attention and MLP modules of both the vision
# layers and the language layers.
model = FastVisionModel.get_peft_model(
    model,
    finetune_vision_layers     = True, # False if not finetuning vision layers
    finetune_language_layers   = True, # False if not finetuning language layers
    finetune_attention_modules = True, # False if not finetuning attention layers
    finetune_mlp_modules       = True, # False if not finetuning MLP layers

    r = 16,           # LoRA rank: larger ranks add capacity (and accuracy) but may overfit
    lora_alpha = 16,  # LoRA scaling factor; recommended to be at least r
    lora_dropout = 0,
    bias = "none",
    random_state = 3407,  # random seed (presumably for LoRA weight initialisation)
    use_rslora = False,  # rank-stabilised LoRA is supported but disabled here
    loftq_config = None, # LoftQ initialisation is supported but disabled here
    # target_modules = "all-linear", # Optional now! Can specify a list if needed
)

In [None]:
def format_data(sample):
    """Convert one KIE sample into a GRPO training record.

    Args:
        sample: a dataset sample exposing ``image_path``, ``entities`` (each
            with a ``label`` attribute) and ``to_json(task)``. Presumably an
            item of the ``SROIE`` dataset class -- verify against ``dataset``.

    Returns:
        A dict with:
        - ``"prompt"``: chat conversation made of the system message and a
          user turn holding an image placeholder plus the extraction request;
        - ``"image"``: single-element list with the fully loaded PIL image;
        - ``"answer"``: JSON string of the sample's ``"kie"`` ground truth.
    """
    # Image.open() is lazy and keeps the file open until the pixel data is
    # read. Every record of both splits is kept in memory, so load the pixels
    # now and let the context manager close the file handle immediately.
    with Image.open(sample.image_path) as pil_image:
        pil_image.load()

    # Sort the labels so the prompt text is identical across runs: iterating a
    # set of strings depends on the per-process hash seed.
    field_names = sorted({entity.label for entity in sample.entities})

    # NOTE(review): the reward functions score JSON output, but the prompt does
    # not ask for JSON -- consider stating the expected output format.
    prompt = f"Extract the following {field_names} from the image"

    conversation = [
        {"role": "system", "content": system_message},
        {
            "role": "user",
            "content": [
                {"type": "image"},
                {"type": "text", "text": prompt},
            ],
        },
    ]

    return {
        "prompt": conversation,
        "image": [pil_image],
        "answer": json.dumps(sample.to_json("kie")),
    }


In [None]:
# Pre-format every KIE sample of both SROIE splits into prompt/image/answer
# records. NOTE(review): test_dataset is built but not handed to the trainer.
train_dataset = list(map(format_data, SROIE(tasks=["kie"], split="train")))
test_dataset = list(map(format_data, SROIE(tasks=["kie"], split="test")))

# Display the first training record as the cell output.
train_dataset[0]

In [None]:
from trl import GRPOConfig, GRPOTrainer

# GRPO hyper-parameters. Anything not passed here (e.g. num_train_epochs,
# learning_rate) falls back to the GRPOConfig defaults.
training_args = GRPOConfig(
    # Optimiser and learning-rate schedule
    lr_scheduler_type = "cosine",
    optim = "adamw_8bit",
    # Mixed precision: bf16 where the GPU supports it, fp16 otherwise
    bf16 = is_bf16_supported(),
    fp16 = not is_bf16_supported(),
    # Batching
    per_device_train_batch_size = 2,
    gradient_accumulation_steps = 4,  # raise for smoother (but slower) updates
    num_generations = 2,              # completions sampled per prompt; decrease if out of memory
    # Sequence lengths, in tokens
    # NOTE(review): 256 prompt tokens must cover the system message, the request
    # and the image tokens -- confirm the trainer does not truncate image tokens.
    max_prompt_length = 256,
    max_completion_length = 200,
    # Output and logging
    output_dir = "outputs",
    report_to = "none",  # can be pointed at Weights & Biases instead
)

In [None]:
# Keys every SROIE KIE answer is expected to contain.
_KIE_LABELS = frozenset({"address", "date", "company", "total"})


def _completion_text(completion) -> str:
    """Return the text content of a conversational completion's first message."""
    return completion[0]["content"]


def _parse_json_object(text: str) -> "dict | None":
    """Parse *text* as JSON and return it only when it is a JSON object.

    Returns None for invalid JSON and for JSON values that are not objects
    (lists, numbers, strings, ...), so callers can safely use dict methods.
    """
    try:
        parsed = json.loads(text)
    except json.JSONDecodeError:
        return None
    return parsed if isinstance(parsed, dict) else None


def json_parsable(completions, **kwargs) -> list[float]:
    """Reward 2.0 for each completion whose text is valid JSON, 0.0 otherwise.

    Any JSON value counts here; structural checks are left to the other
    reward functions.
    """
    results = []
    for completion in completions:
        try:
            json.loads(_completion_text(completion))
        except json.JSONDecodeError:
            results.append(0.0)
        else:
            results.append(2.0)
    return results


def all_fields_present(completions, **kwargs) -> list[float]:
    """Reward 2.0 when a completion is a JSON object holding every KIE label.

    Completions that are not valid JSON, or are JSON but not an object, get 0.0.
    """
    results = []
    for completion in completions:
        # Non-object JSON (e.g. a list) has no .keys(); treat it as "fields
        # missing" instead of raising and aborting the training step.
        parsed = _parse_json_object(_completion_text(completion))
        has_all = parsed is not None and _KIE_LABELS.issubset(parsed.keys())
        results.append(2.0 if has_all else 0.0)
    return results


def correct_labels(completions, answer, **kwargs) -> list[float]:
    """Score each completion field-by-field against its ground-truth answer.

    Args:
        completions: conversational completions; the first message's
            ``content`` is parsed as JSON.
        answer: JSON strings with the ground-truth fields, aligned with
            ``completions``.

    Returns:
        Per completion: +1.0 for every KIE label whose predicted value exactly
        equals the (non-null) ground truth, -2.0 for every other label. Output
        that is not a JSON object scores as if every label were wrong, so that
        emitting invalid JSON never beats a parsable but imperfect answer.
    """
    all_wrong = -2.0 * len(_KIE_LABELS)
    results = []
    for completion, gt in zip(completions, answer):
        predicted = _parse_json_object(_completion_text(completion))
        if predicted is None:
            results.append(all_wrong)
            continue
        # Ground truth is produced by json.dumps in format_data, so it is valid JSON.
        expected = json.loads(gt)
        score = 0.0
        for label in _KIE_LABELS:
            pred = predicted.get(label)
            if pred is not None and pred == expected.get(label):
                score += 1.0
            else:
                score -= 2.0
        results.append(score)
    return results

# TODO: add a function that computes the edit distance on each extracted field

In [None]:
# Build the GRPO trainer from `vlmgrpo` for the LoRA-wrapped model.
# correct_labels additionally declares an `answer` parameter, matching the
# "answer" field of the records produced by format_data.
# NOTE(review): no eval dataset is passed, so test_dataset is unused here.
trainer = VLMGRPOTrainer(
    model=model,
    reward_funcs = [
        json_parsable,       # rewards completions that parse as JSON
        all_fields_present,  # rewards JSON containing all four KIE fields
        correct_labels       # per-field exact-match score against "answer"
    ],
    args=training_args,
    train_dataset=train_dataset,
    processing_class=tokenizer, # MUST put unsloth processor here !
    reward_processing_classes = tokenizer, # Here also: the same unsloth processor
    grad_verbose = True # Enable to monitor loss and grad during training
)

In [None]:

# Run GRPO fine-tuning with the trainer configured above.
trainer.train()