### Import Libraries

In [None]:
import os
import sys
sys.path.append(os.path.dirname(os.getcwd()))

from dataset import *
from PIL import Image
import json
import time

In [None]:
!pip install unsloth

In [None]:
from unsloth import FastVisionModel
from unsloth.trainer import UnslothVisionDataCollator
from trl import SFTTrainer, SFTConfig

In [None]:
# System-role text; format_data places it in the "system" turn of the
# conversations it builds for the "normal" and "xml" training types.
system_message = (
    "You are a highly advanced Vision Language Model (VLM), specialized in extracting visual data.\n"
    "Your task is to process and extract meaningful insights from images that are asked in the prompt."
)

In [None]:
from json2xml import json2xml
from json2xml.utils import readfromstring
from lxml import etree
import base64


def _build_prompt(field_names, as_xml):
    """Return the user instruction listing ``field_names`` plus an example of the output shape."""
    if as_xml:
        xml_fields = "".join(f"<{field}>..</{field}>" for field in field_names)
        output_format = f"<kie>{xml_fields}</kie>"
        template = "Extract the following {fields} from the above document. If a field is not present, return ''. Return the output in a valid XML format like {output_format}"
    else:
        # NOTE(review): str.format renders this dict with Python repr (single
        # quotes), so the example shown to the model is not strictly valid JSON.
        # Left unchanged so the prompt text stays identical to what any existing
        # checkpoints / inference code were built against — confirm before changing.
        output_format = {field: ".." for field in field_names}
        template = "Extract the following {fields} from the above document. If a field is not present, return ''. Return the output in a valid JSON format like {output_format}"

    return template.format(fields=list(field_names), output_format=output_format)


def _to_xml_label(payload):
    """Serialise ``payload`` to a compact XML string wrapped in a ``<kie>`` root element."""
    raw_xml = json2xml.Json2xml(
        data=readfromstring(json.dumps(payload)),
        wrapper="kie",
        pretty=False,
        attr_type=False,
    ).to_xml()
    # Round-trip through lxml to get a unicode string without pretty-printing.
    return etree.tostring(
        etree.fromstring(raw_xml),
        encoding="unicode",
        pretty_print=False,
    )


def format_data(sample, train_type: str):
    """Turn one dataset sample into a chat-style conversation for fine-tuning.

    Args:
        sample: Object exposing ``image_path``, ``entities`` (items with a
            ``label`` attribute) and ``to_json("kie")``.
        train_type: One of
            ``"normal"``    - system + user (image and prompt) + JSON answer,
            ``"no-prompt"`` - user (image only) + JSON answer,
            ``"xml"``       - system + user (image and prompt) + XML answer.

    Returns:
        A list of message dicts (``role`` / ``content``) ending with the
        assistant turn that carries the target text.

    Raises:
        ValueError: If ``train_type`` is not one of the values above.
    """
    # Reject bad input before opening the image or building any text.
    if train_type not in ("normal", "no-prompt", "xml"):
        raise ValueError(f"{train_type} value error")

    pil_image = Image.open(sample.image_path)
    payload = sample.to_json("kie")

    answer = _to_xml_label(payload) if train_type == "xml" else json.dumps(payload)
    assistant_turn = {
        "role": "assistant",
        "content": [{"type": "text", "text": answer}],
    }

    if train_type == "no-prompt":
        # Image-only variant: no system message and no textual instruction.
        return [
            {
                "role": "user",
                "content": [{"type": "image", "image": pil_image}],
            },
            assistant_turn,
        ]

    # Sort the de-duplicated labels so the prompt is the same on every run;
    # iterating a plain set of strings depends on hash randomisation.
    field_names = sorted({entity.label for entity in sample.entities}, key=str)
    prompt = _build_prompt(field_names, as_xml=(train_type == "xml"))

    return [
        {
            "role": "system",
            "content": [{"type": "text", "text": system_message}],
        },
        {
            "role": "user",
            "content": [
                {"type": "image", "image": pil_image},
                {"type": "text", "text": prompt},
            ],
        },
        assistant_turn,
    ]

In [None]:
# Selects which conversation layout format_data produces.
train_type = "normal"

# Format the SROIE "kie" samples of each split; the generator is consumed in
# order, so the train split is processed before the test split.
train_dataset, test_dataset = (
    [format_data(sample, train_type) for sample in SROIE(tasks=["kie"], split=split_name)]
    for split_name in ("train", "test")
)

# Show the first formatted training conversation as the cell output.
train_dataset[0]

### Training Pipeline

In [None]:
# Pre-quantised 4-bit checkpoint to fine-tune.
base_checkpoint = "unsloth/Qwen2.5-VL-3B-Instruct-bnb-4bit"

model, tokenizer = FastVisionModel.from_pretrained(
    base_checkpoint,
    load_in_4bit=True,                     # 4-bit to reduce memory use; False for 16-bit LoRA
    use_gradient_checkpointing="unsloth",  # True or "unsloth" for long context
)

# Which parts of the network receive LoRA adapters (set False to skip a part).
lora_targets = dict(
    finetune_vision_layers=True,
    finetune_language_layers=True,
    finetune_attention_modules=True,
    finetune_mlp_modules=True,
)

# LoRA hyper-parameters.
lora_settings = dict(
    r=16,               # larger rank: higher accuracy, but may overfit
    lora_alpha=16,      # recommended to be at least r
    lora_dropout=0,
    bias="none",
    random_state=3407,
    use_rslora=False,   # rank-stabilised LoRA is available
    loftq_config=None,  # as is LoftQ
    # target_modules="all-linear" is optional; a list can be given if needed
)

model = FastVisionModel.get_peft_model(model, **lora_targets, **lora_settings)

In [None]:
# Only the XML training variant registers extra tokens and overrides token ids.
if train_type == "xml":
    text_tokenizer = tokenizer.tokenizer

    # NOTE(review): these are "<doc_FIELD>" tokens, while format_data's XML
    # prompt uses "<FIELD>" tags built from the entity labels — confirm the
    # label text really contains the "doc_"-prefixed form, otherwise these
    # tokens never occur in the training data.
    for field in ("company", "date", "address", "total"):
        text_tokenizer.add_tokens([f"<doc_{field}>", f"</doc_{field}>"])

    text_tokenizer.add_tokens(["<kie>", "</kie>"])
    # Grow the embedding matrix to cover the newly added tokens.
    model.resize_token_embeddings(len(text_tokenizer))

    model_config = model.config
    model_config.pad_token_id = text_tokenizer.pad_token_id
    # NOTE(review): decoder_start_token_id is normally an encoder-decoder
    # setting — verify it has any effect for this model.
    model_config.decoder_start_token_id = text_tokenizer.convert_tokens_to_ids("<kie>")
    # Treat the closing root tag as the end-of-sequence token.
    model_config.eos_token_id = text_tokenizer.convert_tokens_to_ids("</kie>")