In [None]:
import os
import sys
sys.path.append(os.path.dirname(os.getcwd()))

from dataset import *
from PIL import Image
import json
import time
import torch

# Unsloth switches, exported before the `unsloth` import that follows
# (presumably they are read at import time -- TODO confirm).
os.environ.update(
    UNSLOTH_RETURN_LOGITS="1",
    UNSLOTH_COMPILE_DISABLE="1",
)

from unsloth import FastVisionModel
from unsloth.trainer import UnslothVisionDataCollator
from trl import SFTTrainer, SFTConfig
from transformers import TrainerCallback
from unsloth import get_chat_template

In [None]:
from json2xml import json2xml
from json2xml.utils import readfromstring
from lxml import etree
import base64

# System turn placed at the start of every prompted ("normal" / "xml")
# conversation built by format_data.
system_message = """You are a highly advanced Vision Language Model (VLM), specialized in extracting visual data.
Your task is to process and extract meaningful insights from images that are asked by the user."""

# File names of the images seen by format_data, in call order. The inference
# loop zips this list with test_dataset to key its results by file name.
imgs_fn = []


def format_data(sample, train_type: str):
    """Turn one KIE sample into a chat-style record for training or evaluation.

    Side effect: the image's file name is appended to the module-level
    ``imgs_fn`` list, so the i-th record produced lines up with ``imgs_fn[i]``.

    Args:
        sample: Dataset item exposing ``image_path`` (str), ``entities``
            (iterable of objects with a ``label`` attribute) and
            ``to_json("kie")`` (JSON-serialisable ground truth).
        train_type: One of
            ``"normal"``    - system + user(image, JSON prompt) + fenced JSON answer,
            ``"no-prompt"`` - user(image only) + fenced JSON answer,
            ``"xml"``       - system + user(image, XML prompt) + XML answer.

    Returns:
        ``{"messages": [...]}`` where the last message is the assistant turn
        holding the ground-truth answer.

    Raises:
        ValueError: if ``train_type`` is not one of the supported values.
    """
    if train_type not in ("normal", "no-prompt", "xml"):
        # Reject bad input before opening the image or touching imgs_fn.
        raise ValueError(f"{train_type} value error")

    pil_image = Image.open(sample.image_path)
    imgs_fn.append(os.path.basename(sample.image_path))

    image_part = {"type": "image", "image": pil_image}
    kie_json = json.dumps(sample.to_json("kie"))

    if train_type == "no-prompt":
        # No instruction text at all: the model sees only the image.
        return {
            "messages": [
                {"role": "user", "content": [image_part]},
                {
                    "role": "assistant",
                    "content": [{"type": "text", "text": f"```json{kie_json}```"}],
                },
            ]
        }

    # De-duplicate first, then sort: the reverse order (set(sorted(...)))
    # throws the ordering away and makes the prompt differ between runs.
    field_names = sorted({entity.label for entity in sample.entities})

    if train_type == "xml":
        xml_fields = "".join(f"<{field}>..</{field}>" for field in field_names)
        prompt = "Extract the following {fields} from the above document. If a field is not present, return ''. Return the output in a valid XML format like {output_format}" \
            .format(
                fields=field_names,
                output_format=f"<kie>{xml_fields}</kie>",
            )

        label = json2xml.Json2xml(
            data=readfromstring(kie_json),
            wrapper="kie",
            pretty=False,
            attr_type=False,
        ).to_xml()
        # Round-trip through lxml to get a plain unicode string
        # (presumably this also drops the XML declaration -- TODO confirm).
        label = etree.tostring(
            etree.fromstring(label),
            encoding="unicode",
            pretty_print=False,
        )
    else:
        # json.dumps so the example shown to the model is real JSON (double
        # quotes), matching the format of the assistant answer below.
        example = json.dumps({field: ".." for field in field_names})
        prompt = "Extract the following {fields} from the above document. If a field is not present, return ''. Return the output in a valid JSON format like ```json\n{output_format}\n```" \
            .format(
                fields=field_names,
                output_format=example,
            )
        label = f"```json\n{kie_json}\n```"

    return {
        "messages": [
            {"role": "system", "content": [{"type": "text", "text": system_message}]},
            {"role": "user", "content": [image_part, {"type": "text", "text": prompt}]},
            {"role": "assistant", "content": [{"type": "text", "text": label}]},
        ]
    }

In [None]:
# Experiment switches: which conversation layout to train on and which corpus.
train_type = "normal"
dataset = "docile"

# Pick the raw splits first so that both output lists are always (re)assigned,
# whatever `dataset` is set to.
if dataset == "docile":
    train_samples = DocILE(tasks=["kie"], split="train")
    eval_samples = DocILE(tasks=["kie"], split="val")
elif dataset == "sroie":
    train_samples = SROIE(tasks=["kie"], split="train")
    eval_samples = SROIE(tasks=["kie"], split="test")
elif dataset == "nnts":
    # Only a test split is used for this corpus. Make the training set
    # explicitly empty so a stale `train_dataset` from an earlier run of this
    # cell cannot be picked up by the trainer.
    train_samples = []
    eval_samples = NNTS_KIE(tasks=["kie"], split="test")
else:
    raise Exception("Wrong dataset value")

train_dataset = [format_data(sample, train_type) for sample in train_samples]

# format_data records every file name in imgs_fn. Drop the training names so
# imgs_fn lines up one-to-one with test_dataset.
imgs_fn.clear()
test_dataset = [format_data(sample, train_type) for sample in eval_samples]

test_dataset[0]

In [None]:
# Load the Gemma-3 4B pretrained checkpoint through Unsloth, quantised to
# 4 bit and with Unsloth's gradient-checkpointing mode enabled.
model, processor = FastVisionModel.from_pretrained(
    "unsloth/gemma-3-4b-pt",
    use_gradient_checkpointing="unsloth",
    load_in_4bit=True,
)

# Attach the "gemma-3" chat template to the processor.
processor = get_chat_template(processor, "gemma-3")

In [None]:
# Wrap the base model with LoRA adapters. Every component group (vision,
# language, attention, MLP) is enabled, and lm_head / embed_tokens are listed
# in modules_to_save.
model = FastVisionModel.get_peft_model(
    model,
    # Which parts of the network receive adapters.
    finetune_vision_layers=True,
    finetune_language_layers=True,
    finetune_attention_modules=True,
    finetune_mlp_modules=True,
    # Adapter hyper-parameters (rank and alpha kept equal).
    r=16,
    lora_alpha=16,
    lora_dropout=0,
    bias="none",
    use_rslora=False,
    loftq_config=None,
    # random_state and target_modules are intentionally left at their defaults.
    modules_to_save=["lm_head", "embed_tokens"],
)

In [None]:
from transformers.trainer_utils import EvalPrediction
import editdistance
import xmltodict
import re


def compute_metrics(processor):
    """Build a ``compute_metrics`` callback bound to ``processor``.

    The returned function expects ``eval_pred.predictions`` to be the
    ``(pred_ids, labels)`` pair produced by ``preprocess_logits_for_metrics``.
    It decodes both, extracts the fenced ```` ```json ```` answer that follows
    the first ``model`` marker, and compares prediction and label field by
    field.

    Args:
        processor: Object exposing ``tokenizer.pad_token_id`` and
            ``tokenizer.decode(ids, skip_special_tokens=True)``.

    Returns:
        A callable ``inner_compute_metrics(eval_pred) -> dict`` with keys
        ``"Accuracy"`` (mean per-sample similarity, 0 for unparsable samples),
        ``"Not Parsable"`` (number of samples that raised while being parsed
        or scored) and ``"Edit Distance"`` (mean similarity over the fields
        present in both prediction and label).
    """
    json_fence = re.compile(r"```json\s*(.*?)\s*```", re.DOTALL)

    def _decode(token_ids):
        # Strip the loss-mask value (-100) and padding before decoding.
        ignored = {-100, processor.tokenizer.pad_token_id}
        kept = [token_id for token_id in token_ids if token_id not in ignored]
        return processor.tokenizer.decode(kept, skip_special_tokens=True)

    def _extract_answer(decoded):
        # Split only once: everything after the first 'model' marker is the
        # answer. An unbounded split would cut the answer at any later
        # occurrence of the word "model" inside a field value.
        answer = decoded.split('model', 1)[1]
        return json.loads(json_fence.search(answer).group(1).strip())

    def inner_compute_metrics(eval_pred: EvalPrediction):
        pred_ids, labels = eval_pred.predictions

        decoded_preds = [_decode(row) for row in pred_ids.tolist()]
        decoded_labels = [_decode(row) for row in labels.tolist()]

        similarities = []
        not_parsable = []
        edit_distance = []
        for decoded_pred, decoded_label in zip(decoded_preds, decoded_labels):
            try:
                pred = _extract_answer(decoded_pred)
                label = _extract_answer(decoded_label)

                field_sims = []
                for k in label.keys():
                    if k in pred:
                        dist = editdistance.eval(str(pred[k]), str(label[k]))
                        # Normalised by the label length only, so a much
                        # longer prediction can push the score below zero.
                        max_len = max(len(str(label[k])), 1)
                        sim = (1 - dist / max_len)
                        field_sims.append(sim)
                        edit_distance.append(sim)
                    else:
                        # Field missing from the prediction scores zero.
                        field_sims.append(0.0)

                similarities.append(sum(field_sims) / len(field_sims))
            except Exception:
                # Best effort: any failure (missing marker, missing fence,
                # invalid JSON, label without fields, ...) scores the sample 0.
                similarities.append(0.0)
                not_parsable.append(1.0)

        return {
            # Guarded like "Edit Distance" so an empty evaluation set does
            # not raise ZeroDivisionError.
            "Accuracy": sum(similarities) / len(similarities) if len(similarities) != 0 else 0,
            "Not Parsable": int(sum(not_parsable)),
            "Edit Distance": sum(edit_distance) / len(edit_distance) if len(edit_distance) != 0 else 0
        }
    return inner_compute_metrics

In [None]:
def preprocess_logits_for_metrics(logits: "tuple | torch.Tensor", labels: torch.Tensor):
    """Reduce logits to predicted token ids before they are accumulated.

    Keeping only the argmax over the last dimension avoids storing the full
    logits for every evaluation batch (the original note here: the stock
    Trainer may otherwise leak / hoard memory).

    Args:
        logits: Either the logits tensor itself, or a tuple/list of model
            outputs whose first element is the logits tensor.
        labels: Label ids, returned unchanged.

    Returns:
        ``(pred_ids, labels)`` where ``pred_ids`` is ``argmax(logits, dim=-1)``.
    """
    # Only unwrap real containers. Indexing a bare tensor with [0] would
    # silently drop the batch dimension and keep just the first sample.
    if isinstance(logits, (tuple, list)):
        logits = logits[0]
    pred_ids = torch.argmax(logits, dim=-1)
    return pred_ids, labels

In [None]:
class JSONLoggerCallback(TrainerCallback):
    """Trainer callback that writes ``state.log_history`` to a JSON file once training ends."""

    def __init__(self, log_path):
        """Store the destination file path for the training log."""
        self.log_path = log_path

    def on_train_end(self, args, state, control, **kwargs):
        """Dump the accumulated log history to ``self.log_path`` as indented JSON."""
        # Create the parent directory on demand. Without this a missing
        # folder raises only after the whole training run has finished,
        # and the log is lost.
        log_dir = os.path.dirname(self.log_path)
        if log_dir:
            os.makedirs(log_dir, exist_ok=True)

        with open(self.log_path, "w") as f:
            json.dump(state.log_history, f, indent=4)
        print(f"Log saved in {self.log_path}")

In [None]:
# Switch the Unsloth-wrapped model into training mode before building the trainer.
FastVisionModel.for_training(model) # Enable for training!

# Hyper-parameters for the supervised fine-tuning run.
training_args = SFTConfig(
    # 4 samples per step x 4 accumulation steps = 16 samples per optimizer update (per device).
    per_device_train_batch_size = 4,
    per_device_eval_batch_size=4,
    gradient_accumulation_steps = 4,
    #warmup_ratio = 0.03,
    max_grad_norm = 0.3,
    warmup_steps=40,
    # Short, step-bounded run; see the commented num_train_epochs below for full runs.
    max_steps = 180,
    #num_train_epochs = 1, # Set this instead of max_steps for full training runs
    # NOTE(review): the note says 1e-3 for docile but the value here is 2e-3 -- confirm which is intended.
    learning_rate = 2e-3, # for docile 1e-3, for sroie 1e-4
    logging_steps = 20,
    save_steps=20,
    optim = "adamw_torch_fused",
    #lr_scheduler_type="cosine",
    weight_decay = 0.01,
    #seed = 3407,
    # Checkpoints go next to the final LoRA weights saved later in the notebook.
    output_dir = f"result-{dataset}/gemma3_{train_type}",
    report_to = "none",     # No external tracker (e.g. Weights and Biases)

    # You MUST put the below items for vision finetuning:
    gradient_checkpointing_kwargs = {"use_reentrant": False},
    gradient_checkpointing = True,
    remove_unused_columns = False,
    dataset_text_field = "",
    dataset_kwargs = {"skip_prepare_dataset": True},
    max_length = None,
    ##
    # Evaluate and checkpoint on a step schedule.
    eval_strategy="steps",
    save_strategy="steps",
    #metric_for_best_model="Accuracy",
    #label_names=["labels"],
)

# The trainer consumes the already-formatted conversation lists; the Unsloth
# collator handles the image + chat batching.
trainer = SFTTrainer(
    model = model,
    tokenizer = processor,
    data_collator = UnslothVisionDataCollator(model, processor), # Must use!
    train_dataset = train_dataset,
    eval_dataset = test_dataset,
    args = training_args,
    callbacks=[
        # Writes the trainer log history to this JSON file when training ends.
        JSONLoggerCallback(f"log/gemma3-{train_type}-{dataset}.json")
    ],
    # The custom metric hooks defined above are currently not wired in.
    #compute_metrics=compute_metrics(processor),
    #preprocess_logits_for_metrics=preprocess_logits_for_metrics
)

In [None]:
# Run the fine-tuning loop; the value returned by train() is kept in trainer_stats.
trainer_stats = trainer.train()

In [None]:
# Persist the model and the processor side by side, in the directory the
# reload cell below reads from.
model.save_pretrained(f"result-{dataset}/gemma3_{train_type}/lora_model")
processor.save_pretrained(f"result-{dataset}/gemma3_{train_type}/lora_model")

In [None]:
# Switch the in-memory model to inference mode.
# NOTE(review): the next cell reloads the model from disk and calls this again,
# so this cell only matters when that reload is skipped -- confirm it is needed.
FastVisionModel.for_inference(model) # Enable for inference!

In [None]:
# Reload model and processor from the directory written by the save cell,
# again quantised to 4 bit.
model, processor = FastVisionModel.from_pretrained(
    load_in_4bit=True,
    model_name=f"result-{dataset}/gemma3_{train_type}/lora_model",
)

# Inference mode first, then re-attach the "gemma-3" chat template.
FastVisionModel.for_inference(model)

processor = get_chat_template(processor, "gemma-3")

In [None]:
# Start from an empty file-name log so imgs_fn matches test_dataset one-to-one.
imgs_fn = []

# Choose the evaluation split for the configured corpus ...
if dataset == "docile":
    eval_samples = DocILE(tasks=["kie"], split="val")
elif dataset == "sroie":
    eval_samples = SROIE(tasks=["kie"], split="test")
elif dataset == "nnts":
    eval_samples = NNTS_KIE(tasks=["kie"], split="test")
else:
    raise Exception("Wrong dataset value")

# ... and format it (format_data appends each file name to imgs_fn).
test_dataset = [format_data(sample, train_type) for sample in eval_samples]

In [None]:
# Maps image file name -> {"response": decoded text, "t": generation seconds}.
result = {}

os.makedirs(f"result-{dataset}", exist_ok=True)

# zip() would silently truncate or pair the wrong names if these ever diverge.
if len(test_dataset) != len(imgs_fn):
    raise RuntimeError(
        f"test_dataset ({len(test_dataset)}) and imgs_fn ({len(imgs_fn)}) are out of sync; "
        "re-run the cell that rebuilds test_dataset"
    )

for data, fn in zip(test_dataset, imgs_fn):
    # Keep every turn except the final assistant answer (the ground truth).
    # This works for both [system, user, assistant] and [user, assistant].
    # New dicts are built instead of editing test_dataset in place, so the
    # cell can be re-run without losing the images.
    new_data = []
    img = None
    for message in data["messages"][:-1]:
        content = []
        for part in message["content"]:
            if part.get("type") == "image":
                # The chat template only needs a placeholder; the pixels are
                # passed to the processor separately.
                img = part["image"]
                content.append({"type": "image"})
            else:
                content.append(part)
        new_data.append({**message, "content": content})

    input_text = processor.apply_chat_template(new_data, add_generation_prompt = True)
    inputs = processor(
        img,
        input_text,
        add_special_tokens = False,
        return_tensors = "pt",
    ).to("cuda")

    # perf_counter is monotonic, so the elapsed time cannot be skewed by
    # wall-clock adjustments.
    start = time.perf_counter()
    # NOTE(review): temperature / top_p / top_k only take effect when sampling
    # is enabled -- confirm do_sample is set by the model's generation config.
    output_ids = model.generate(
        **inputs,
        max_new_tokens=3000,
        use_cache=True,
        temperature = 0.7, top_p = 0.95, top_k = 64, repetition_penalty=1.3
    )
    end = time.perf_counter()

    # Decodes the full sequence returned by generate().
    output_text = processor.decode(output_ids[0], skip_special_tokens=True)

    result[fn] = dict(
        response = output_text,
        t = end - start
    )

    # Rewritten after every sample so partial results are on disk if the loop
    # is interrupted.
    # NOTE(review): the file name says "gemma3n-notrain" and ignores
    # train_type -- confirm this is the intended output path.
    with open(f"result-{dataset}/gemma3n-notrain.json", "w") as f:
        json.dump(result, f, indent = 4)
