### Install Requirements

In [None]:
# Mount Google Drive inside the Colab VM so the dataset archives stored there can be copied locally.
from google.colab import drive
drive.mount("/content/drive")

Mounted at /content/drive


In [None]:
!pip install transformers==4.53.3 \
             trl \
             datasets \
             bitsandbytes \
             peft \
             accelerate \
             pdf2image \
             json2xml \
             num2words

#!pip install -q flash-attn --no-build-isolation

Collecting transformers==4.53.3
  Downloading transformers-4.53.3-py3-none-any.whl.metadata (40 kB)
[?25l     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m0.0/40.9 kB[0m [31m?[0m eta [36m-:--:--[0m[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m40.9/40.9 kB[0m [31m4.1 MB/s[0m eta [36m0:00:00[0m
[?25hCollecting trl
  Downloading trl-0.21.0-py3-none-any.whl.metadata (11 kB)
Collecting bitsandbytes
  Downloading bitsandbytes-0.47.0-py3-none-manylinux_2_24_x86_64.whl.metadata (11 kB)
Collecting pdf2image
  Downloading pdf2image-1.17.0-py3-none-any.whl.metadata (6.2 kB)
Collecting json2xml
  Downloading json2xml-5.2.1-py3-none-any.whl.metadata (8.1 kB)
Collecting num2words
  Downloading num2words-0.5.14-py3-none-any.whl.metadata (13 kB)
INFO: pip is looking at multiple versions of trl to determine which version is compatible with other requirements. This could take a while.
Collecting trl
  Downloading trl-0.20.0-py3-none-any.whl.metadata (11 kB)
Colle

### Prepare Training Data

In [None]:
!mkdir data
#!cp -r /content/drive/MyDrive/Dataset/Tirocinio/docile.zip /content/data
!cp -r /content/drive/MyDrive/Dataset/Tirocinio/sroie.zip /content/data
#!cd data ; unzip docile.zip
!cd data ; unzip sroie.zip
#!cp -r /content/drive/MyDrive/Dataset/Tirocinio/kie.zip /content/data

Archive:  sroie.zip
   creating: sroie/
   creating: sroie/train/
   creating: sroie/train/img/
  inflating: sroie/train/img/X51006441474.jpg  
  inflating: sroie/train/img/X51005433533.jpg  
  inflating: sroie/train/img/X51006714065.jpg  
  inflating: sroie/train/img/X51005712017.jpg  
  inflating: sroie/train/img/X51005433494.jpg  
  inflating: sroie/train/img/X51005745183.jpg  
  inflating: sroie/train/img/X51007846307.jpg  
  inflating: sroie/train/img/X51007339156.jpg  
  inflating: sroie/train/img/X51005757286.jpg  
  inflating: sroie/train/img/X51007339135.jpg  
  inflating: sroie/train/img/X51006414703.jpg  
  inflating: sroie/train/img/X51005361946.jpg  
  inflating: sroie/train/img/X51008142032.jpg  
  inflating: sroie/train/img/X51006713996.jpg  
  inflating: sroie/train/img/X51006557178.jpg  
  inflating: sroie/train/img/X51005806702.jpg  
  inflating: sroie/train/img/X51006913055.jpg  
  inflating: sroie/train/img/X51006414679.jpg  
  inflating: sroie/train/img/X5100733909

In [None]:
from enum import Enum
from pydantic import BaseModel
from typing import ClassVar
from pdf2image import convert_from_path
from PIL import Image
import json
import os


class Task(Enum):
    """Annotation types a dataset can be asked to load (see Dataset._convert_to_format)."""
    CLS = "cls"  # converted into a Classification
    KIE = "kie"  # converted into a Field holding a label/value pair
    OCR = "ocr"  # converted into a Field labelled "text" with a bounding box
    VQA = "vqa"  # converted into a VQA question/answer pair
    OBJ = "obj"  # converted into a Field labelled "object"

class BBox(BaseModel):
    """Rectangle described by four integer coordinates."""
    x1: int
    y1: int
    x2: int
    y2: int

    def get_coords(self) -> list[int]:
        """Return the coordinates as the list ``[x1, y1, x2, y2]``."""
        coordinate_names = ("x1", "y1", "x2", "y2")
        return [getattr(self, name) for name in coordinate_names]

class Field(BaseModel):
    """A labelled text value, optionally located on the page by a bounding box."""
    label: str  # field name ("text" for OCR words, "object" for objects, the entity key for KIE)
    value: str  # textual content of the field
    bbox: BBox | None = None  # position of the value; the KIE conversion leaves it as None

class VQA(BaseModel):
    """A single question/answer pair."""
    question: str
    answer: str

class Classification(BaseModel):
    """Classification annotation: a document type plus a list of labels."""
    doc_type: str
    labels: list[str]

class Data(BaseModel):
    """Annotations attached to one document image.

    Every annotation group is optional and stays ``None`` when it was not loaded.
    """
    # Path of the image file on disk.
    image_path: str
    # OCR words (filled by the dataset loaders when the OCR task is requested).
    fields: list[Field] | None = None
    # Key/value entities (filled by the dataset loaders when the KIE task is requested).
    entities: list[Field] | None = None
    objects: list[Field] | None = None
    vqa: list[VQA] | None = None
    cls: Classification | None = None

    def to_json(self, task: str | Task) -> dict[str, str]:
        """Return the annotations of ``task`` as a plain dictionary.

        Args:
            task: Task name (``"kie"``) or the matching ``Task`` member.

        Returns:
            For the KIE task, a ``{label: value}`` mapping built from
            ``entities`` (a later duplicate label overwrites an earlier one).
            An empty dict for any other task, or when no entities were loaded.
        """
        # Accept the enum member as well as its string value.
        task_name = task.value if isinstance(task, Enum) else task

        json_result = {}
        if task_name == "kie":
            # `entities` defaults to None, which must not be iterated.
            for entity in self.entities or []:
                json_result[entity.label] = entity.value

        return json_result

class Dataset(BaseModel):
    """Common base of the dataset loaders.

    Subclasses fill ``data`` with one ``Data`` entry per image; iterating a
    dataset yields those entries.
    """
    # Tasks whose annotations should be loaded.
    tasks: list[Task] = []
    # Name of the split; the subclasses use it to build file paths.
    split: str
    # Samples collected by the subclass loaders.
    data: list[Data] = []

    def read_folder(self, path: str) -> list[str]:
        """Return the entry names found in ``path``, sorted."""
        return sorted(os.listdir(path))

    def __iter__(self):
        """Iterate over the loaded samples stored in ``data``."""
        return iter(self.data)

    def _convert_to_format(self, task: Task, item: dict) -> Field | VQA | Classification:
        """Convert a raw ``item`` dict into the model that matches ``task``.

        Args:
            task: Task the item belongs to.
            item: Raw annotation. Keys read per task:
                CLS -> ``doc_type``, ``labels``; KIE -> ``label``, ``value``;
                OCR -> ``bbox`` (four ints), ``text``; VQA -> ``question``,
                ``answer``; OBJ -> ``bbox`` (four ints) and optional ``value``.

        Returns:
            A ``Classification`` (CLS), ``VQA`` (VQA) or ``Field`` (KIE, OCR, OBJ).

        Raises:
            ValueError: If ``task`` is not a known task.
            KeyError: If ``item`` lacks a key required by ``task``.
        """
        if task == Task.CLS:
            return Classification(
                doc_type=item["doc_type"],
                labels=item["labels"]
            )

        if task == Task.KIE:
            return Field(
                label=item["label"],
                value=item["value"],
                bbox=None
            )

        if task == Task.OCR:
            x1, y1, x2, y2 = item["bbox"]
            return Field(
                label="text",
                value=item["text"],
                bbox=BBox(x1=x1, y1=y1, x2=x2, y2=y2)
            )

        if task == Task.VQA:
            return VQA(
                question=item["question"],
                answer=item["answer"]
            )

        if task == Task.OBJ:
            # The previous stub passed value=None and an empty BBox(), which can
            # never validate. The box is read like in the OCR branch.
            # NOTE(review): the "value" key is an assumption, no loader emits OBJ items yet.
            x1, y1, x2, y2 = item["bbox"]
            return Field(
                label="object",
                value=str(item.get("value", "")),
                bbox=BBox(x1=x1, y1=y1, x2=x2, y2=y2)
            )

        raise ValueError(f"Task {task} does not exist")

class DocILE(Dataset):
    """Loader for the DocILE files stored under ``./data/docile``.

    Reads ``<split>.json`` (a list of document ids), the per-document OCR files
    in ``ocr/``, the annotations in ``annotations/`` and the page images in
    ``pdfs/`` (PDFs found there are converted to JPEG first).
    """
    TASKS: ClassVar[list[Task]] = [Task.OCR, Task.KIE]

    def __init__(
        self,
        tasks: list[Task],
        split: str
    ) -> None:
        super().__init__(tasks=tasks, split=split)
        self._convert_pdf_to_img()
        self._load_data()

        # NOTE(review): the rename happens after loading, so "test" still reads
        # test.json above; confirm whether val.json was meant to be loaded instead.
        if split == "test":
            self.split = "val"

    def _convert_pdf_to_img(self):
        """Replace every PDF in ``./data/docile/pdfs`` by a JPEG of its first page."""
        for fn in os.listdir(f"./data/docile/pdfs"):
            if not fn.endswith(".pdf"):
                continue

            # Only the first page is kept, so only that page is rendered.
            images = convert_from_path(
                pdf_path=f"./data/docile/pdfs/{fn}",
                dpi=200,
                fmt="jpg",
                first_page=1,
                last_page=1
            )
            # Remove the suffix only: str.strip('.pdf') would also eat leading and
            # trailing '.', 'p', 'd', 'f' characters of the document id.
            stem = fn[:-len(".pdf")]
            images[0].save(f"./data/docile/pdfs/{stem}.jpg")
            os.remove(f"./data/docile/pdfs/{fn}")

    def _load_data(self) -> None:
        """Fill ``data`` with the requested annotations of every document in the split.

        Documents without a converted image, or without any loaded annotation,
        are skipped.
        """
        with open(f"./data/docile/{self.split}.json", "r") as f:
            split_file: list[str] = json.load(f)
        split_file.sort()

        for img_fn in split_file:
            image_path = f"./data/docile/pdfs/{img_fn}.jpg"
            fields, entities = [], []

            if Task.OCR in self.tasks:
                with open(f"./data/docile/ocr/{img_fn}.json", "r") as f:
                    label: dict = json.load(f)
                if os.path.exists(image_path):
                    # Only the size is needed; close the image file right away.
                    with Image.open(image_path) as img:
                        width, height = img.size

                    # Only the first page is read, matching the single converted image.
                    for block in label["pages"][0]["blocks"]:
                        for line in block["lines"]:
                            for word in line["words"]:
                                x1, y1 = word["geometry"][0]
                                x2, y2 = word["geometry"][1]

                                # The geometry is scaled by the image size to get pixel coordinates.
                                fields.append(self._convert_to_format(
                                    task=Task.OCR,
                                    item=dict(
                                        bbox=[int(x1 * width), int(y1 * height), int(x2 * width), int(y2 * height)],
                                        text=word["value"]
                                    )
                                ))

            if Task.KIE in self.tasks:
                with open(f"./data/docile/annotations/{img_fn}.json", "r") as f:
                    label: dict = json.load(f)
                if os.path.exists(image_path):
                    for extraction in label["field_extractions"]:
                        entities.append(self._convert_to_format(
                            task=Task.KIE,
                            item=dict(
                                label=extraction["fieldtype"],
                                value=extraction["text"]
                            )
                        ))

            if len(fields) > 0 or len(entities) > 0:
                self.data.append(Data(
                    image_path=image_path,
                    fields=fields if fields else None,
                    entities=entities if entities else None
                ))

class SROIE(Dataset):
    """Loader for the SROIE files stored under ``./data/sroie/<split>``.

    Each image in ``img/`` has a same-named ``.txt`` file in ``box/`` (one
    comma-separated OCR line per row) and in ``entities/`` (a JSON object of
    key/value entities).
    """
    TASKS: ClassVar[list[Task]] = [Task.OCR, Task.KIE]

    def __init__(
        self,
        tasks: list[Task],
        split: str
    ) -> None:
        super().__init__(tasks=tasks, split=split)
        self._load_data()

    def __extract_bbox_and_text(self, line: list[str]) -> tuple[list[int], str]:
        """Split one parsed box row into a bounding box and its text.

        Args:
            line: The row already split on commas: eight coordinate strings
                (four x/y points) optionally followed by the text.

        Returns:
            ``([x_min, y_min, x_max, y_max], text)``; ``text`` is ``""`` when
            the row has no ninth column.
        """
        points = [int(x) for x in line[:8]]
        text = line[8] if len(line) > 8 else ""

        # Enclose all four points. The previous version overwrote the point list
        # with copies of the first point and therefore always returned a
        # zero-area box.
        xs, ys = points[0::2], points[1::2]
        return [min(xs), min(ys), max(xs), max(ys)], text

    def _load_data(self) -> None:
        """Fill ``data`` with one entry per image of the split."""
        images = self.read_folder(f"./data/sroie/{self.split}/img")

        for image in images:
            label = image.replace(".jpg", ".txt")
            fields, entities = [], []

            # For OCR task
            if Task.OCR in self.tasks:
                with open(f"./data/sroie/{self.split}/box/{label}", "r") as f:
                    rows = f.readlines()
                rows.sort()
                for row in rows:
                    # maxsplit=8 keeps commas that belong to the text itself.
                    row = row.strip("\n").split(',', 8)
                    if len(row) >= 8:
                        coords, text = self.__extract_bbox_and_text(row)

                        fields.append(self._convert_to_format(
                            task=Task.OCR,
                            item=dict(
                                bbox=coords,
                                text=text
                            )
                        ))

            # For KIE task
            if Task.KIE in self.tasks:
                with open(f"./data/sroie/{self.split}/entities/{label}", "r") as f:
                    json_f: dict = json.load(f)
                for key, item in json_f.items():
                    entities.append(self._convert_to_format(
                        task=Task.KIE,
                        item=dict(
                            label=key,
                            value=item
                        )
                    ))

            self.data.append(Data(
                image_path=f"./data/sroie/{self.split}/img/{image}",
                fields=fields if fields else None,
                entities=entities if entities else None
            ))

class MultiDataset():
    """Concatenation of the samples of several datasets for one split."""

    def __init__(
        self,
        selections: list[tuple[type[Dataset], list[Task]]],
        split: str
    ) -> None:
        """Load every selected dataset.

        Args:
            selections: ``(dataset_class, tasks)`` pairs. Each class is called
                as ``dataset_class(tasks, split)`` and must expose the loaded
                samples through a ``data`` attribute.
            split: Split name forwarded to every dataset class.
        """
        self.data: list[Data] = []
        self._load_data(selections, split)

    def __iter__(self):
        """Iterate over the combined samples."""
        return iter(self.data)

    def _load_data(
        self,
        selections: list[tuple[type[Dataset], list[Task]]],
        split: str
    ) -> None:
        """Instantiate each dataset class and append its samples, in selection order."""
        for selection in selections:
            dataset_cls, tasks = selection[0], selection[1]
            self.data.extend(dataset_cls(tasks, split).data)

In [None]:
# System prompt placed as the first turn of the "normal" and "xml" conversations built by format_data.
system_message = """You are a highly advanced Vision Language Model (VLM), specialized in extracting visual data.
Your task is to process and extract meaningful insights from images, leveraging multimodal understanding
to provide accurate and contextually relevant information."""

In [None]:
from json2xml import json2xml
from json2xml.utils import readfromstring
from lxml import etree
import base64
from io import BytesIO

def image_to_base64(pil_image):
    """Encode an image as a base64 string of its JPEG bytes.

    Args:
        pil_image: Image object exposing ``mode``, ``convert`` and ``save``
            (a PIL image).

    Returns:
        The JPEG-encoded image as a UTF-8 base64 string.
    """
    # Modes the JPEG writer accepts; anything else (RGBA, P, LA, ...) makes
    # `save` raise, so those images are converted to RGB first.
    jpeg_modes = ("1", "L", "RGB", "RGBX", "CMYK", "YCbCr")
    if pil_image.mode not in jpeg_modes:
        pil_image = pil_image.convert("RGB")

    buf = BytesIO()
    pil_image.save(buf, format="JPEG")
    return base64.b64encode(buf.getvalue()).decode("utf-8")

def format_data(sample: Data, train_type: str):
    """Turn one sample into a chat conversation for supervised fine-tuning.

    Args:
        sample: Sample whose image and KIE entities are used.
        train_type: Conversation layout:
            ``"normal"``    - system + user (image, JSON prompt) + assistant (JSON answer);
            ``"no-prompt"`` - user (image only) + assistant (JSON answer);
            ``"xml"``       - system + user (image, XML prompt) + assistant (XML answer).

    Returns:
        The conversation as a list of ``{"role", "content"}`` messages; the
        assistant message is always the last one.

    Raises:
        ValueError: If ``train_type`` is not one of the values above.
    """
    if train_type not in ("normal", "no-prompt", "xml"):
        raise ValueError(f"{train_type} value error")

    pil_image = Image.open(sample.image_path)

    entities = sample.entities or []
    # dict.fromkeys removes duplicates but keeps first-seen order, so the prompt
    # is identical from run to run (iterating a set of strings is not).
    field_names = list(dict.fromkeys(entity.label for entity in entities))
    answer = sample.to_json("kie") if entities else {}

    def text_part(text: str) -> dict:
        """Wrap ``text`` as a chat-template text item."""
        return {"type": "text", "text": text}

    user_content = [{"type": "image", "image": pil_image}]

    if train_type != "no-prompt":
        if train_type == "xml":
            format_name = "XML"
            xml_fields = "".join(f"<{field}>..</{field}>" for field in field_names)
            output_format = f"<kie>{xml_fields}</kie>"
        else:
            format_name = "JSON"
            # Serialised with json.dumps so the example shown to the model is
            # valid JSON, like the expected answer.
            output_format = json.dumps({field: ".." for field in field_names})

        prompt = "Extract the following {fields} from the above document. If a field is not present, return ''. Return the output in a valid {format_name} format like {output_format}" \
            .format(
                fields=field_names,
                format_name=format_name,
                output_format=output_format
            )
        user_content.append(text_part(prompt))

    if train_type == "xml":
        label = json2xml.Json2xml(
            data=readfromstring(json.dumps(answer)),
            wrapper="kie",
            pretty=False,
            attr_type=False
        ).to_xml()
        # Re-serialise through lxml to get the XML as a plain unicode string.
        label = etree.tostring(
            etree.fromstring(label),
            encoding="unicode",
            pretty_print=False
        )
    else:
        label = json.dumps(answer)

    conversation = []
    if train_type != "no-prompt":
        conversation.append({"role": "system", "content": [text_part(system_message)]})
    conversation.append({"role": "user", "content": user_content})
    conversation.append({"role": "assistant", "content": [text_part(label)]})

    return conversation

In [None]:
# Conversation layout built by format_data: "normal", "no-prompt" or "xml".
train_type = "normal"

# Chat-formatted samples of both SROIE splits, loading only the KIE annotations.
train_dataset = [format_data(sample, train_type) for sample in SROIE(tasks=["kie"], split="train")]
test_dataset = [format_data(sample, train_type) for sample in SROIE(tasks=["kie"], split="test")]

# Display the first training conversation as a sanity check.
train_dataset[0]

[{'role': 'system',
  'content': [{'type': 'text',
    'text': 'You are a highly advanced Vision Language Model (VLM), specialized in extracting visual data.\nYour task is to process and extract meaningful insights from images, leveraging multimodal understanding\nto provide accurate and contextually relevant information.'}]},
 {'role': 'user',
  'content': [{'type': 'image',
    'image': <PIL.JpegImagePlugin.JpegImageFile image mode=RGB size=463x1013>},
   {'type': 'text',
    'text': "Extract the following ['date', 'company', 'total', 'address'] from the above document. If a field is not present, return ''. Return the output in a valid JSON format like {'date': '..', 'company': '..', 'total': '..', 'address': '..'}"}]},
 {'role': 'assistant',
  'content': [{'type': 'text',
    'text': '{"company": "BOOK TA .K (TAMAN DAYA) SDN BHD", "date": "25/12/2018", "address": "NO.53 55,57 & 59, JALAN SAGU 18, TAMAN DAYA, 81100 JOHOR BAHRU, JOHOR.", "total": "9.00"}'}]}]

### Training Pipeline

In [None]:
import gc
import time


def clear_memory():
    """Drop the training objects from the notebook namespace and release GPU memory.

    Removes the globals ``inputs``, ``model``, ``processor``, ``trainer``,
    ``peft_model`` and ``bnb_config`` when they exist, runs the garbage
    collector, empties the CUDA cache and prints the allocated/reserved GPU
    memory. The call pauses two seconds between the individual steps.
    """
    pause_seconds = 2
    namespace = globals()

    # Remove whichever of the heavy objects currently exist.
    for name in ("inputs", "model", "processor", "trainer", "peft_model", "bnb_config"):
        namespace.pop(name, None)
    time.sleep(pause_seconds)

    # Collect garbage, release the cached CUDA blocks, then collect once more.
    gc.collect()
    time.sleep(pause_seconds)
    torch.cuda.empty_cache()
    torch.cuda.synchronize()
    time.sleep(pause_seconds)
    gc.collect()
    time.sleep(pause_seconds)

    gib = 1024**3
    print(f"GPU allocated memory: {torch.cuda.memory_allocated() / gib:.2f} GB")
    print(f"GPU reserved memory: {torch.cuda.memory_reserved() / gib:.2f} GB")

In [None]:
import torch
from peft import LoraConfig, prepare_model_for_kbit_training, get_peft_model
from transformers import AutoProcessor, BitsAndBytesConfig, AutoModelForImageTextToText
import os

In [None]:
# Hugging Face Hub identifier of the checkpoint that is loaded and fine-tuned below.
model_id = "HuggingFaceTB/SmolVLM2-2.2B-Instruct"

In [None]:
# 4-bit NF4 quantization with double quantization; computation runs in bfloat16.
bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_use_double_quant=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_compute_dtype=torch.bfloat16
)

# Load the quantized model with eager attention; device placement is left to device_map="auto".
model = AutoModelForImageTextToText.from_pretrained(
    model_id,
    quantization_config=bnb_config,
    _attn_implementation="eager",
    device_map="auto",
    torch_dtype=torch.bfloat16
)

# Processor of the same checkpoint; used below for chat templating, batching and decoding.
processor = AutoProcessor.from_pretrained(model_id)

The secret `HF_TOKEN` does not exist in your Colab secrets.
To authenticate with the Hugging Face Hub, create a token in your settings tab (https://huggingface.co/settings/tokens), set it as secret in your Google Colab and restart your session.
You will be able to reuse this secret in all of your notebooks.
Please note that authentication is recommended but still optional to access public models or datasets.


config.json: 0.00B [00:00, ?B/s]

model.safetensors.index.json: 0.00B [00:00, ?B/s]

Fetching 2 files:   0%|          | 0/2 [00:00<?, ?it/s]

model-00001-of-00002.safetensors:   0%|          | 0.00/4.96G [00:00<?, ?B/s]

model-00002-of-00002.safetensors:   0%|          | 0.00/4.03G [00:00<?, ?B/s]

Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]

generation_config.json:   0%|          | 0.00/136 [00:00<?, ?B/s]

processor_config.json:   0%|          | 0.00/67.0 [00:00<?, ?B/s]

chat_template.json:   0%|          | 0.00/430 [00:00<?, ?B/s]

preprocessor_config.json:   0%|          | 0.00/599 [00:00<?, ?B/s]

tokenizer_config.json: 0.00B [00:00, ?B/s]

vocab.json: 0.00B [00:00, ?B/s]

merges.txt: 0.00B [00:00, ?B/s]

tokenizer.json: 0.00B [00:00, ?B/s]

added_tokens.json: 0.00B [00:00, ?B/s]

special_tokens_map.json:   0%|          | 0.00/868 [00:00<?, ?B/s]

You have video processor config saved in `preprocessor.json` file which is deprecated. Video processor configs should be saved in their own `video_preprocessor.json` file. You can rename the file or load and save the processor back which renames it automatically. Loading from `preprocessor.json` will be removed in v5.0.


In [None]:
USE_LORA = False
USE_QLORA = True

# Adapter settings: rank-8 LoRA on the listed projection layers.
# DoRA is switched on only when QLoRA is not used.
lora_config = LoraConfig(
    r=8,
    lora_alpha=8,
    lora_dropout=0.1,
    target_modules=[
        'down_proj',
        'o_proj',
        'k_proj',
        'q_proj',
        'gate_proj',
        'up_proj',
        'v_proj',
    ],
    use_dora=not USE_QLORA,
    init_lora_weights="gaussian",
)

In [None]:
# Only for XML training: register the markup tags as tokens and make the model
# start/stop generation on the <kie> wrapper.
if train_type == "xml":
    # NOTE(review): these tokens are named <doc_company>, <doc_date>, ... while format_data
    # emits plain <company>, <date>, ... tags, so the tokens added here never occur in the
    # training text - confirm which naming is intended.
    for x in ['company', 'date', 'address', 'total']:
        processor.tokenizer.add_tokens([f"<doc_{x}>", f"</doc_{x}>"])

    processor.tokenizer.add_tokens(["<kie>", "</kie>"])
    # The embedding matrix has to grow to cover the newly added token ids.
    model.resize_token_embeddings(len(processor.tokenizer))

    model.config.pad_token_id = processor.tokenizer.pad_token_id
    # Generation is configured to start at <kie> and to stop at </kie>.
    model.config.decoder_start_token_id = processor.tokenizer.convert_tokens_to_ids("<kie>")
    model.config.eos_token_id = processor.tokenizer.convert_tokens_to_ids("</kie>")

In [None]:
# Id of the "<image>" placeholder token, looked up among the tokenizer's additional special
# tokens (raises ValueError if "<image>" is not one of them). collate_fn masks it in the labels.
image_token_id = processor.tokenizer.additional_special_tokens_ids[
    processor.tokenizer.additional_special_tokens.index("<image>")
]


def collate_fn(examples):
    """Batch a list of conversations into model inputs with training labels.

    Args:
        examples: Conversations as produced by ``format_data``.

    Returns:
        The processor output (``input_ids``, image tensors, ...) extended with
        ``labels``: a copy of ``input_ids`` where padding and image tokens are
        set to -100 so they are ignored by the loss.
    """
    texts = [processor.apply_chat_template(example, tokenize=False) for example in examples]

    image_inputs = []
    for example in examples:
        # Search every message for image items instead of assuming the user
        # turn is at index 1: "no-prompt" conversations have no system message,
        # so their image sits in the first message.
        images = []
        for message in example:
            content = message["content"]
            if isinstance(content, str):
                continue
            for part in content:
                if part.get("type") != "image":
                    continue
                image = part["image"]
                if image.mode != "RGB":
                    image = image.convert("RGB")
                images.append(image)
        image_inputs.append(images)

    batch = processor(text=texts, images=image_inputs, return_tensors="pt", padding=True)
    labels = batch["input_ids"].clone()
    labels[labels == processor.tokenizer.pad_token_id] = -100  # Mask padding tokens in labels
    labels[labels == image_token_id] = -100  # Mask image token IDs in labels

    batch["labels"] = labels

    return batch

In [None]:
from trl import SFTConfig


# Short run: 30 optimizer steps with an effective batch size of 2 x 4 = 8.
training_args = SFTConfig(
    #num_train_epochs=1,
    per_device_train_batch_size=2,
    gradient_accumulation_steps=4,
    warmup_steps=5,
    max_steps=30,
    learning_rate=1e-4,
    logging_steps=1,
    save_steps=5,
    optim="adamw_torch_fused",
    weight_decay=0.01,
    output_dir=f"training/smolvlm2-{train_type}",
    bf16=False,
    remove_unused_columns=False,
    gradient_checkpointing=True,
    # The custom collator does all preprocessing, so TRL must not prepare the dataset.
    dataset_text_field="",
    dataset_kwargs={"skip_prepare_dataset": True},
    report_to="none",
    eval_strategy="steps",  # the comma was missing here, which made the cell a SyntaxError
    # "best" saves a checkpoint whenever the tracked metric improves, so the
    # metric to track has to be named explicitly.
    save_strategy="best",
    metric_for_best_model="eval_loss",
    greater_is_better=False
)

In [None]:
def compute_metrics(tokenizer):
    """Build the metric callback handed to the trainer.

    Placeholder implementation: the returned callback prints the types and the
    decoded text of predictions and labels for inspection and always reports
    zero for both metrics.

    Args:
        tokenizer: Object providing ``batch_decode(ids, skip_special_tokens=...)``.

    Returns:
        A function taking the evaluation prediction (unpackable into
        predictions and label ids) and returning ``{"accuracy": 0, "f1": 0}``.
    """
    def inner_compute_metrics(eval_pred):
        predictions, references = eval_pred

        # Debug output: container types first, decoded texts afterwards.
        for ids in (predictions, references):
            print(type(ids))

        decoded = [
            tokenizer.batch_decode(ids, skip_special_tokens=True)
            for ids in (predictions, references)
        ]
        for texts in decoded:
            print(texts)

        return dict(accuracy=0, f1=0)

    return inner_compute_metrics

In [None]:
from trl import SFTTrainer

# Supervised fine-tuning with the LoRA configuration passed as `peft_config`;
# batches are built by the custom collate_fn.
trainer = SFTTrainer(
    model=model,
    args=training_args,
    train_dataset=train_dataset,
    eval_dataset=test_dataset[:10],  # evaluation uses only the first 10 test conversations
    data_collator=collate_fn,
    peft_config=lora_config,
    processing_class=processor.tokenizer,
    compute_metrics=compute_metrics(processor)  # NOTE(review): the whole processor is passed as the `tokenizer` argument
)
# NOTE(review): overrides the trainer's `can_return_loss` flag by hand; presumably so that an
# evaluation loss is computed although no label names were detected (see the warning below) - confirm.
trainer.can_return_loss = True

No label_names provided for model class `PeftModel`. Since `PeftModel` hides base models input arguments, if label_names is not given, label_names can't be set automatically within `Trainer`. Note that empty label_names list will be used instead.


In [None]:
# Run the fine-tuning loop configured in training_args.
trainer.train()

Step,Training Loss,Validation Loss


<class 'tuple'>
<class 'tuple'>


In [None]:
# Save the trained model to the "final" folder that the inference section loads the adapter from.
trainer.save_model(output_dir=f"training/smolvlm2-{train_type}/final")

In [None]:
# Release the training objects and cached GPU memory before the model is loaded again for inference.
clear_memory()

In [None]:
# Recreate the quantization config, base model and processor for inference
# (clear_memory() removed the previous ones); same settings as the first load.
bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_use_double_quant=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_compute_dtype=torch.bfloat16
)

model = AutoModelForImageTextToText.from_pretrained(
    model_id,
    quantization_config=bnb_config,
    _attn_implementation="eager",
    device_map="auto",
    torch_dtype=torch.bfloat16
)

processor = AutoProcessor.from_pretrained(model_id)

In [None]:
# Attach the adapter stored in the "final" folder written by trainer.save_model.
model.load_adapter(f"training/smolvlm2-{train_type}/final")

In [None]:
# Generate an answer for every test image and store the text plus the generation time.
results = {}

# The results file is written into this folder; create it so the first write cannot fail.
os.makedirs("result", exist_ok=True)

# test_dataset was built from the same sorted folder listing, so the i-th
# conversation belongs to the i-th file name.
test_images = sorted(os.listdir("data/sroie/test/img"))
for data, fn in zip(test_dataset, test_images):
    # Drop the assistant turn: it contains the ground-truth answer, which must
    # not be part of the prompt the model is evaluated on.
    prompt_messages = [message for message in data if message["role"] != "assistant"]

    inputs = processor.apply_chat_template(
        prompt_messages,
        add_generation_prompt=True,
        tokenize=True,
        return_dict=True,
        return_tensors="pt",
    ).to("cuda", dtype=torch.bfloat16)

    # Only the generate() call is timed.
    start = time.time()
    generated_ids = model.generate(**inputs, do_sample=False, max_new_tokens=1000)
    end = time.time()

    # NOTE(review): the whole generated sequence is decoded, so the stored response
    # presumably still starts with the prompt text - confirm what the evaluation expects.
    generated_texts = processor.batch_decode(
        generated_ids,
        skip_special_tokens=True,
    )
    results[fn] = dict(
        response = generated_texts[0],
        inference_time = end - start
    )

    # Rewritten after every sample so partial results survive an interrupted run.
    with open(f"result/smolvlm2-{train_type}.json", "w") as f:
        json.dump(results, f, indent=4)