### Import libraries

In [1]:
import sys
sys.path.append('/home/rainer/Code/VLM-benchmark')

from dataset import *
from PIL import Image
import json

  from .autonotebook import tqdm as notebook_tqdm


### Prepare Training Data

In [2]:
system_message = """You are a highly advanced Vision Language Model (VLM), specialized in extracting visual data. 
Your task is to process and extract meaningful insights from images, 
leveraging multimodal understanding to provide accurate and contextually relevant information."""

In [None]:
from json2xml import json2xml
from json2xml.utils import readfromstring
from lxml import etree
import base64
from io import BytesIO

def image_to_base64(pil_image):
    buf = BytesIO()
    pil_image.save(buf, format="JPEG")
    return base64.b64encode(buf.getvalue()).decode("utf-8")

def format_data(sample, train_type: str):
    pil_image = Image.open(sample.image_path)

    field_names = set([entity.label for entity in sample.entities])
    if train_type == "xml":
        xml_fields = "".join([f"<{field}>..</{field}>" for field in field_names])
        output_format = f"<kie>{xml_fields}</kie>"
        prompt = "Extract the following {fields} from the above document. If a field is not present, return ''. Return the output in a valid XML format like {output_format}" \
            .format(
                fields = list(field_names),
                output_format = output_format
            )
    else:
        output_format = {field: ".." for field in field_names}

        prompt = "Extract the following {fields} from the above document. If a field is not present, return ''. Return the output in a valid JSON format like {output_format}" \
            .format(
                fields = list(field_names),
                output_train_type = "normal"

train_dataset = [format_data(sample, train_type) for sample in SROIE(tasks=["kie"], split="train")]
test_dataset = [format_data(sample, train_type) for sample in SROIE(tasks=["kie"], split="test")]

train_dataset[0]format = output_format
            )

    if train_type == "normal":
        conversation = [
            {
                "role": "system",
                "content": [{"type": "text", "text": system_message}]
            },
            {
                "role": "user",
                "content": [
                    { "type": "image", "image": pil_image },
                    { "type": "text", "text": prompt }
                ]
            },
            {
                "role": "assistant",
                "content": [{
                    "type": "text",
                    "text": sample.to_json("kie")
                }]
            }
        ]
    elif train_type == "no-prompt":
        conversation = [
            {
                "role": "user",
                "content": [
                    { "type": "image", "image": pil_image }
                ]
            },
            {
                "role": "assistant",
                "content": [{
                    "type": "text",
                    "text": sample.to_json("kie")
                }]
            }
        ]
    elif train_type == "xml":
        label = json2xml.Json2xml(
            data=readfromstring(json.dumps(sample.to_json("kie"))),
            wrapper="kie",
            pretty=False,
            attr_type=False
        ).to_xml()
        label = etree.tostring(
            etree.fromstring(label),
            encoding="unicode",
            pretty_print=False
        )

        conversation = [
            {
                "role": "system",
                "content": [{"type": "text", "text": system_message}]
            },
            {
                "role": "user",
                "content": [
                    { "type": "image", "image": pil_image },
                    { "type": "text", "text": prompt }
                ]
            },
            {
                "role": "assistant",
                "content": [{
                    "type": "text",
                    "text": label
                }]
            }
        ]
    else:
        raise Exception(f"{train_type} value error")

    return conversation

In [10]:
train_type = "normal"

train_dataset = [format_data(sample, train_type) for sample in SROIE(tasks=["kie"], split="train")]
test_dataset = [format_data(sample, train_type) for sample in SROIE(tasks=["kie"], split="test")]

train_dataset[0]

[{'role': 'system',
  'content': [{'type': 'text',
    'text': 'You are a highly advanced Vision Language Model (VLM), specialized in extracting visual data. \nYour task is to process and extract meaningful insights from images, \nleveraging multimodal understanding to provide accurate and contextually relevant information.'}]},
 {'role': 'user',
  'content': [{'type': 'image',
    'image': <PIL.JpegImagePlugin.JpegImageFile image mode=RGB size=463x1013>},
   {'type': 'text',
    'text': "Extract the following ['date', 'address', 'total', 'company'] from the above document. If a field is not present, return ''. Return the output in a valid XML format like <kie><date>..</date><address>..</address><total>..</total><company>..</company></kie>"}]},
 {'role': 'assistant',
  'content': [{'type': 'text',
    'text': '<kie><company>BOOK TA .K (TAMAN DAYA) SDN BHD</company><date>25/12/2018</date><address>NO.53 55,57 &amp; 59, JALAN SAGU 18, TAMAN DAYA, 81100 JOHOR BAHRU, JOHOR.</address><total>

### Training pipeline

In [None]:
# define model

import torch
from transformers import Idefics3ForConditionalGeneration, AutoProcessor
from transformers import BitsAndBytesConfig

model_id = "HuggingFaceTB/SmolVLM-Instruct"

bnb_config = BitsAndBytesConfig(
    load_in_4bit=True, bnb_4bit_use_double_quant=True, bnb_4bit_quant_type="nf4", bnb_4bit_compute_dtype=torch.bfloat16
)

model = Idefics3ForConditionalGeneration.from_pretrained(
    model_id,
    device_map="auto",
    torch_dtype=torch.bfloat16,
    quantization_config=bnb_config
)

processor = AutoProcessor.from_pretrained(model_id)

In [None]:
from peft import LoraConfig, get_peft_model

peft_config = LoraConfig(
    r=8,
    lora_alpha=8,
    lora_dropout=0.1,
    target_modules=["down_proj", "o_proj", "k_proj", "q_proj", "gate_proj", "up_proj", "v_proj"],
    use_dora=True,
    init_lora_weights="gaussian",
)

peft_model = get_peft_model(model, peft_config)

peft_model.print_trainable_parameters()

In [None]:
image_token_id = processor.tokenizer.additional_special_tokens_ids[
    processor.tokenizer.additional_special_tokens.index("<image>")
]


def collate_fn(examples):
    texts = [processor.apply_chat_template(example, tokenize=False) for example in examples]

    image_inputs = []
    for example in examples:
        image = example[1]["content"][0]["image"]
        if image.mode != "RGB":
            image = image.convert("RGB")
        image_inputs.append([image])

    batch = processor(text=texts, images=image_inputs, return_tensors="pt", padding=True)
    labels = batch["input_ids"].clone()
    labels[labels == processor.tokenizer.pad_token_id] = -100  # Mask padding tokens in labels
    labels[labels == image_token_id] = -100  # Mask image token IDs in labels

    batch["labels"] = labels

    return batch

In [None]:
from trl import SFTConfig

# Configure training arguments using SFTConfig
training_args = SFTConfig(
    output_dir="smolvlm-instruct-trl-sft-sroie",
    num_train_epochs=1,
    per_device_train_batch_size=4,
    gradient_accumulation_steps=4,
    warmup_steps=50,
    learning_rate=1e-4,
    weight_decay=0.01,
    logging_steps=25,
    save_strategy="steps",
    save_steps=25,
    save_total_limit=1,
    optim="adamw_torch_fused",
    bf16=True,
    remove_unused_columns=False,
    gradient_checkpointing=True,
    dataset_text_field="",
    dataset_kwargs={"skip_prepare_dataset": True},
)

In [None]:
from trl import SFTTrainer

trainer = SFTTrainer(
    model=model,
    args=training_args,
    train_dataset=train_dataset,
    eval_dataset=test_dataset,
    data_collator=collate_fn,
    peft_config=peft_config,
    processing_class=processor.tokenizer,
)

In [None]:
trainer.train()