### Import libraries

In [1]:
from PIL import Image
import os
import json
from training.dataset import DocILE

In [2]:
system_message = """You are a highly advanced Vision Language Model (VLM), specialized in extracting visual data. 
Your task is to process and extract meaningful insights from images, leveraging multimodal understanding
to provide accurate and contextually relevant information."""

### Format data

In [3]:
train_dataset = DocILE(tasks=["kie"], split="train")
test_dataset = DocILE(tasks=["kie"], split="val")

In [4]:
def format_data(sample):
    pil_image = Image.open(sample.image_path)

    field_names = set([entity.label for entity in sample.entities])
    output_format = {field: ".." for field in field_names}

    prompt = "Extract the following {fields} from the above document. If a field is not present, return ''. Return the output in a valid JSON format like {output_format}" \
        .format(
            fields = list(field_names),
            output_format = output_format
        )

    return [
        {
            "role": "system",
            "content": [{"type": "text", "text": system_message}]
        },
        {
            "role": "user",
            "content": [
                { "type": "image", "image": pil_image },
                { "type": "text", "text": prompt }
            ]
        },
        {
            "role": "assistant",
            "content": [{ "type": "text", "text": json.dumps(sample.to_json("kie"))}]
        }
    ]



In [5]:
train_dataset = [format_data(sample) for sample in train_dataset]
test_dataset = [format_data(sample) for sample in test_dataset]

In [6]:
train_dataset[0]

[{'role': 'system',
  'content': [{'type': 'text',
    'text': 'You are a highly advanced Vision Language Model (VLM), specialized in extracting visual data. \nYour task is to process and extract meaningful insights from images, leveraging multimodal understanding\nto provide accurate and contextually relevant information.'}]},
 {'role': 'user',
  'content': [{'type': 'image',
    'image': <PIL.JpegImagePlugin.JpegImageFile image mode=RGB size=1692x2245>},
   {'type': 'text',
    'text': "Extract the following ['document_id', 'customer_id', 'vendor_name', 'amount_due', 'payment_reference', 'currency_code_amount_due', 'amount_total_gross', 'vendor_address', 'customer_order_id', 'payment_terms', 'bank_num', 'account_num', 'date_issue', 'customer_billing_address', 'customer_billing_name'] from the above document. If a field is not present, return ''. Return the output in a valid JSON format like {'document_id': '..', 'customer_id': '..', 'vendor_name': '..', 'amount_due': '..', 'payment_r

### Training pipeline

In [7]:
# define model

import torch
from transformers import AutoModelForImageTextToText
from transformers import AutoProcessor
from transformers import BitsAndBytesConfig

model_id = "HuggingFaceTB/SmolVLM-Instruct"

bnb_config = BitsAndBytesConfig(
    load_in_4bit=True, bnb_4bit_use_double_quant=True, bnb_4bit_quant_type="nf4", bnb_4bit_compute_dtype=torch.bfloat16
)

model = AutoModelForImageTextToText.from_pretrained(
    model_id,
    device_map="auto",
    torch_dtype="auto",
    quantization_config=bnb_config
)

processor = AutoProcessor.from_pretrained(model_id)

  from .autonotebook import tqdm as notebook_tqdm


In [8]:
from peft import LoraConfig, get_peft_model

peft_config = LoraConfig(
    r=8,
    lora_alpha=8,
    lora_dropout=0.1,
    target_modules=["down_proj", "o_proj", "k_proj", "q_proj", "gate_proj", "up_proj", "v_proj"],
    use_dora=True,
    init_lora_weights="gaussian",
)
peft_config.inference_mode = False # new

peft_model = get_peft_model(model, peft_config)

peft_model.print_trainable_parameters()

trainable params: 11,269,248 || all params: 2,257,542,128 || trainable%: 0.4992


In [9]:
image_token_id = processor.tokenizer.additional_special_tokens_ids[
    processor.tokenizer.additional_special_tokens.index("<image>")
]


def collate_fn(examples):
    texts = [processor.apply_chat_template(example, tokenize=False) for example in examples]

    image_inputs = []
    for example in examples:
        image = example[1]["content"][0]["image"]
        if image.mode != "RGB":
            image = image.convert("RGB")
        image_inputs.append([image])

    batch = processor(text=texts, images=image_inputs, return_tensors="pt", padding=True)
    labels = batch["input_ids"].clone()
    labels[labels == processor.tokenizer.pad_token_id] = -100  # Mask padding tokens in labels
    labels[labels == image_token_id] = -100  # Mask image token IDs in labels

    batch["labels"] = labels

    return batch

In [None]:
from trl import SFTConfig

# Configure training arguments using SFTConfig
training_args = SFTConfig(
    output_dir="training/smolvlm2-docile",
    num_train_epochs=2,
    per_device_train_batch_size=4,
    gradient_accumulation_steps=4,
    warmup_steps=50,
    learning_rate=1e-4,
    weight_decay=0.01,
    logging_steps=5,
    save_strategy="steps",
    save_steps=20,
    save_total_limit=1,
    optim="adamw_torch_fused",
    bf16=True,
    remove_unused_columns=False,
    gradient_checkpointing=True,
    dataset_text_field="",
    dataset_kwargs={"skip_prepare_dataset": True},
)

In [None]:
from trl import SFTTrainer

trainer = SFTTrainer(
    model=model,
    args=training_args,
    train_dataset=train_dataset,
    eval_dataset=test_dataset,
    data_collator=collate_fn,
    peft_config=peft_config
    #processing_class=processor.tokenizers
)

No label_names provided for model class `PeftModel`. Since `PeftModel` hides base models input arguments, if label_names is not given, label_names can't be set automatically within `Trainer`. Note that empty label_names list will be used instead.


: 

In [None]:
trainer.train()

`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`...
`loss_type=None` was set in the config but it is unrecognised.Using the default loss: `ForCausalLMLoss`.


Step,Training Loss
5,1.1541
10,1.1149
15,1.0154
20,0.929
25,0.8056
30,0.7051
35,0.5767
40,0.4393
45,0.3313
50,0.2473


In [None]:
trainer.save_model(output_dir="training/smolvlm2-docile/final")