### Import libraries

In [22]:
from PIL import Image
import os
import json

In [24]:
system_message = """You are a highly advanced Vision Language Model (VLM), specialized in extracting visual data. 
Your task is to process and extract meaningful insights from images, 
leveraging multimodal understanding to provide accurate and contextually relevant information."""

### Read and format data

In [14]:
def read_sroie_data(split: str):
    root_path = "../../data/sroie"
    imgs_folder_path = f"{root_path}/{split}/img"
    labels_folder_path = f"../../data/sroie/{split}/entities"

    data = []
    for label_fn in os.listdir(labels_folder_path):
        with open(os.path.join(labels_folder_path, label_fn), "r") as f:
            json_f: dict = json.load(f)

        data.append(dict(
            img_path = os.path.join(imgs_folder_path, label_fn.replace(".txt", ".jpg")),
            label = json_f
        ))

    return data

In [19]:
train_dataset = read_sroie_data("train")
test_dataset = read_sroie_data("test")

train_dataset[0]

{'img_path': '../../data/sroie/train/img/X51005447840.jpg',
 'label': {'company': 'PASARAYA BORONG PINTAR SDN BHD',
  'date': '14/03/2018',
  'address': 'NO 19-G& 19-1& 19-2 JALAN TASIK UTAMA 4, MEDAN NIAGA TASIK DAMAI',
  'total': '8.20'}}

In [25]:
def format_data(sample: dict):
    pil_image = Image.open(sample["img_path"])

    field_names = list(sample["label"].keys())
    output_format = {field: ".." for field in field_names}

    prompt = "Extract the following {fields} from the above document. If a field is not present, return ''. Return the output in a valid JSON format like {output_format}" \
        .format(
            fields = field_names,
            output_format = output_format
        )

    return [
        {
            "role": "system",
            "content": [{"type": "text", "text": system_message}]
        },
        {
            "role": "user",
            "content": [
                { "type": "image", "image": pil_image },
                { "type": "text", "text": prompt }
            ]
        },
        {
            "role": "assistant",
            "content": [{ "type": "text", "text": json.dumps(sample["label"])}]
        }
    ]

In [None]:
train_dataset = [format_data(sample) for sample in train_dataset]
test_dataset = [format_data(sample) for sample in test_dataset]

In [27]:
train_dataset[0]

[{'role': 'system',
  'content': [{'type': 'text',
    'text': 'You are a highly advanced Vision Language Model (VLM), specialized in extracting visual data. \nYour task is to process and extract meaningful insights from images, \nleveraging multimodal understanding to provide accurate and contextually relevant information.'}]},
 {'role': 'user',
  'content': [{'type': 'image',
    'image': <PIL.JpegImagePlugin.JpegImageFile image mode=RGB size=932x2160>},
   {'type': 'text',
    'text': "Extract the following ['company', 'date', 'address', 'total'] from the above document. If a field is not present, return ''. Return the output in a valid JSON format like {'company': '..', 'date': '..', 'address': '..', 'total': '..'}"}]},
 {'role': 'assistant',
  'content': [{'type': 'text',
    'text': '{"company": "PASARAYA BORONG PINTAR SDN BHD", "date": "14/03/2018", "address": "NO 19-G& 19-1& 19-2 JALAN TASIK UTAMA 4, MEDAN NIAGA TASIK DAMAI", "total": "8.20"}'}]}]

### Training pipeline

In [None]:
# define model

import torch
from transformers import Idefics3ForConditionalGeneration, AutoProcessor
from transformers import BitsAndBytesConfig

model_id = "HuggingFaceTB/SmolVLM-Instruct"

bnb_config = BitsAndBytesConfig(
    load_in_4bit=True, bnb_4bit_use_double_quant=True, bnb_4bit_quant_type="nf4", bnb_4bit_compute_dtype=torch.bfloat16
)

model = Idefics3ForConditionalGeneration.from_pretrained(
    model_id,
    device_map="auto",
    torch_dtype=torch.bfloat16,
    quantization_config=bnb_config
)

processor = AutoProcessor.from_pretrained(model_id)

In [None]:
from peft import LoraConfig, get_peft_model

peft_config = LoraConfig(
    r=8,
    lora_alpha=8,
    lora_dropout=0.1,
    target_modules=["down_proj", "o_proj", "k_proj", "q_proj", "gate_proj", "up_proj", "v_proj"],
    use_dora=True,
    init_lora_weights="gaussian",
)

peft_model = get_peft_model(model, peft_config)

peft_model.print_trainable_parameters()

In [None]:
image_token_id = processor.tokenizer.additional_special_tokens_ids[
    processor.tokenizer.additional_special_tokens.index("<image>")
]


def collate_fn(examples):
    texts = [processor.apply_chat_template(example, tokenize=False) for example in examples]

    image_inputs = []
    for example in examples:
        image = example[1]["content"][0]["image"]
        if image.mode != "RGB":
            image = image.convert("RGB")
        image_inputs.append([image])

    batch = processor(text=texts, images=image_inputs, return_tensors="pt", padding=True)
    labels = batch["input_ids"].clone()
    labels[labels == processor.tokenizer.pad_token_id] = -100  # Mask padding tokens in labels
    labels[labels == image_token_id] = -100  # Mask image token IDs in labels

    batch["labels"] = labels

    return batch

In [None]:
from trl import SFTConfig

# Configure training arguments using SFTConfig
training_args = SFTConfig(
    output_dir="smolvlm-instruct-trl-sft-sroie",
    num_train_epochs=1,
    per_device_train_batch_size=4,
    gradient_accumulation_steps=4,
    warmup_steps=50,
    learning_rate=1e-4,
    weight_decay=0.01,
    logging_steps=25,
    save_strategy="steps",
    save_steps=25,
    save_total_limit=1,
    optim="adamw_torch_fused",
    bf16=True,
    remove_unused_columns=False,
    gradient_checkpointing=True,
    dataset_text_field="",
    dataset_kwargs={"skip_prepare_dataset": True},
)

In [None]:
from trl import SFTTrainer

trainer = SFTTrainer(
    model=model,
    args=training_args,
    train_dataset=train_dataset,
    eval_dataset=test_dataset,
    data_collator=collate_fn,
    peft_config=peft_config,
    processing_class=processor.tokenizer,
)

In [None]:
trainer.train()