### Import Libraries

In [20]:
import os
import sys
sys.path.append(os.path.dirname(os.getcwd()))

from dataset import *
from PIL import Image
import json
import time

import torch
from peft import LoraConfig, prepare_model_for_kbit_training, get_peft_model
from transformers import AutoProcessor, BitsAndBytesConfig, AutoModelForImageTextToText
from datasets import Dataset

In [None]:
import gc
import time


def clear_memory():
    # Delete variables if they exist in the current global scope
    if "inputs" in globals():
        del globals()["inputs"]
    if "model" in globals():
        del globals()["model"]
    if "processor" in globals():
        del globals()["processor"]
    if "trainer" in globals():
        del globals()["trainer"]
    if "peft_model" in globals():
        del globals()["peft_model"]
    if "bnb_config" in globals():
        del globals()["bnb_config"]
    time.sleep(2)

    # Garbage collection and clearing CUDA memory
    gc.collect()
    time.sleep(2)
    torch.cuda.empty_cache()
    torch.cuda.synchronize()
    time.sleep(2)
    gc.collect()
    time.sleep(2)

    print(f"GPU allocated memory: {torch.cuda.memory_allocated() / 1024**3:.2f} GB")
    print(f"GPU reserved memory: {torch.cuda.memory_reserved() / 1024**3:.2f} GB")

### Prepare Training Data

In [3]:
model_id = "HuggingFaceTB/SmolVLM2-2.2B-Instruct"

In [4]:
system_message = """You are a highly advanced Vision Language Model (VLM), specialized in extracting visual data.
Your task is to extract from the image the requested fields and return the result only in a valid JSON format."""

In [26]:
processor = AutoProcessor.from_pretrained(model_id, use_fast=True, padding_side="left")

In [31]:
def format_data(sample):
    pil_image = Image.open(sample.image_path)

    field_names = set([entity.label for entity in sample.entities])

    prompt = "Extract the following {fields} from the image" \
        .format(fields = list(field_names))

    conversation = [
        {"role": "system", "content": system_message},
        {
            "role": "user",
            "content": [
                {"type": "image"},
                {"type": "text", "text": prompt},
            ],
        },
    ]
    prompt = processor.apply_chat_template(conversation, add_generation_prompt=True)
    return {
        "prompt": prompt,
        "image": pil_image,
        "answer": json.dumps(sample.to_json("kie"))
    }


In [32]:
train_dataset = Dataset.from_list([format_data(sample) for sample in SROIE(tasks=["kie"], split="train")])
test_dataset = Dataset.from_list([format_data(sample) for sample in SROIE(tasks=["kie"], split="test")])

print(train_dataset)

Dataset({
    features: ['prompt', 'image', 'answer'],
    num_rows: 626
})


In [33]:
print(train_dataset[0]["prompt"])
print(train_dataset[0]["image"])
print(train_dataset[0]["answer"])

<|im_start|>System: <end_of_utterance>
User:<image>Extract the following ['company', 'address', 'total', 'date'] from the image<end_of_utterance>
Assistant:
<PIL.JpegImagePlugin.JpegImageFile image mode=RGB size=463x1013 at 0x7C4B68B265D0>
{"company": "BOOK TA .K (TAMAN DAYA) SDN BHD", "date": "25/12/2018", "address": "NO.53 55,57 & 59, JALAN SAGU 18, TAMAN DAYA, 81100 JOHOR BAHRU, JOHOR.", "total": "9.00"}


### Training Pipeline

In [None]:
bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_use_double_quant=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_compute_dtype=torch.float16
)

model = AutoModelForImageTextToText.from_pretrained(
    model_id,
    _attn_implementation="eager",
    device_map="auto",
    torch_dtype=torch.bfloat16
)

In [None]:
USE_LORA = False
USE_QLORA = True

lora_config = LoraConfig(
    task_type="CAUSAL_LM",
    r=8,
    lora_alpha=8,
    lora_dropout=0.1,
    target_modules=['q_proj', 'v_proj'],
    #init_lora_weights="gaussian"
)

In [None]:
#model = prepare_model_for_kbit_training(model)
#model = get_peft_model(model, lora_config)

In [None]:
from trl import GRPOConfig, GRPOTrainer

training_args = GRPOConfig(
    output_dir="result/smolvlm2-grpo",
    learning_rate=1e-5,
    remove_unused_columns=False,  # to access the solution column in accuracy_reward
    num_train_epochs=1,
    bf16=True,
    # Parameters that control the data preprocessing
    per_device_train_batch_size=2,
    max_completion_length=1024,  # default: 256
    num_generations=2,  # default: 8
    max_prompt_length=2048,
    # Parameters related to reporting and saving
    report_to="none",
    logging_steps=10,
    push_to_hub=False,
    save_strategy="steps",
    save_steps=10,
    use_liger_loss=True
)

In [None]:
def json_parsable(completions, **kwargs) -> list[float]:
    responses = [completion[0] for completion in completions]
    results = []
    for response in responses:
        try:
            json.loads(response)
            results.append(2.0)
        except json.JSONDecodeError:
            results.append(0.0)
    return results

def all_fields_present(completions, **kwargs) -> list[float]:
    labels = {"address", "date", "company", "total"}
    responses = [completion[0] for completion in completions]
    results = []
    for response in responses:
        try:
            response_dict = json.loads(response)
            results.append(2.0 if labels.issubset(response_dict.keys()) else 0.0)
        except json.JSONDecodeError:
            results.append(0.0)
    return results

def correct_labels(completions, answer, **kwargs) -> list[float]:
    labels = {"address", "date", "company", "total"}
    results = []
    for completion, gt in zip(completions, answer):
        score = 0.0
        try:
            completion_dict = json.loads(completion[0])
            gt_dict = json.loads(gt)
            for label in labels:
                pred = completion_dict.get(label, None)
                gt_value = gt_dict.get(label, None)

                if pred == gt_value and pred is not None:
                    score += 1.0
                else:
                    score -= 2.0
            results.append(score)
        except json.JSONDecodeError:
            results.append(0.0)
    return results

# TODO: aggiungere funzione per calcolare editdistance sui campi

In [None]:
from trl import GRPOTrainer

trainer = GRPOTrainer(
    model=model,
    processing_class=processor,
    reward_funcs=[
        json_parsable,
        all_fields_present,
        correct_labels
    ],
    args=training_args,
    train_dataset=train_dataset,
)

In [None]:
trainer.train()

In [None]:
trainer.save_model(training_args.output_dir)