# Introduction

This notebook calculates the character error rate (CER) of the fine-tuned SmolVLM model by running it on the SROIE test images and comparing its output with the Qwen2.5-VL 3B annotations, which serve as the reference text.

In [1]:
import glob
import jiwer
import torch
import os

from transformers import AutoModelForImageTextToText, AutoProcessor
from tqdm.auto import tqdm
from PIL import Image

In [2]:
# Switch off the tokenizers library's own parallelism for this process.
# NOTE(review): presumably this avoids fork-related warnings once the DataLoader
# worker processes start further down — confirm.
os.environ.update({'TOKENIZERS_PARALLELISM': 'false'})

## Calculate CER for Test Data using Trained Model

In [3]:
# Read the reference text from the Qwen VL annotation files.
# Paths are sorted so the order is deterministic (the image paths are sorted the
# same way further down; presumably the file names line up one-to-one — verify).
all_vlm_txt_test_paths = sorted(
    glob.glob('../input/qwen2_5_vl_3b_annots/test_annots/*.txt')
)

# Lower-cased so the later CER comparison is case-insensitive.
vlm_data = []
for file_path in all_vlm_txt_test_paths:
    # The context manager closes each file as soon as it has been read,
    # instead of leaving hundreds of handles for the garbage collector.
    with open(file_path) as annot_file:
        vlm_data.append(annot_file.read().lower())

In [4]:
# Directory of the fine-tuned SmolVLM checkpoint; both the model and the
# processor are loaded from this path below.
model_id = '../notebooks/trained_models/full_ft/smolvlm2_256m_fullft_qwen2_5_vl_3b_gt_20250505/'

In [5]:
# Load the fine-tuned checkpoint in bfloat16; `device_map='auto'` leaves device
# placement to the loader.
model = AutoModelForImageTextToText.from_pretrained(
    model_id,
    device_map='auto',
    torch_dtype=torch.bfloat16,
    _attn_implementation='flash_attention_2' # Use `flash_attention_2` on Ampere GPUs and above and `eager` on older GPUs.
    # _attn_implementation='eager', # Alternative for older (pre-Ampere) GPUs.
    # NOTE(review): the documented keyword is `attn_implementation` (no leading
    # underscore) — confirm the private spelling is honoured by the installed
    # transformers version.
)

# The processor is loaded from the same checkpoint directory as the model.
processor = AutoProcessor.from_pretrained(model_id)

## Inference Function

In [6]:
def test(model, processor, batch, max_new_tokens=500, device='cuda'):
    messages = []

    for i, data in enumerate(batch):
        message = [
            {
                'role': 'user',
                'content': [
                    {'type': 'image', 'url': data},
                    {'type': 'text', 'text': 'OCR this image accurately'}
                ]
            },
        ]
        messages.append(message)
    
    # Prepare the text input by applying the chat template
    model_inputs = processor.apply_chat_template(
        messages,  # Use the sample without the system message
        add_generation_prompt=True,
        padding=True,
        padding_side='left',
        return_tensors='pt',
        tokenize=True,
        return_dict=True
    ).to(device, dtype=torch.bfloat16)

    # Generate text with the model
    generated_ids = model.generate(**model_inputs, max_new_tokens=max_new_tokens)

    # Trim the generated ids to remove the input ids
    trimmed_generated_ids = [
        out_ids[len(in_ids):] for in_ids, out_ids in zip(model_inputs.input_ids, generated_ids)
    ]

    # Decode the output text
    output_text = processor.batch_decode(
        trimmed_generated_ids,
        skip_special_tokens=True,
        clean_up_tokenization_spaces=False
    )

    return output_text

## Create Batches

In [7]:
from torch.utils.data import Dataset, DataLoader

In [8]:
# Gather the SROIE test images from the original dataset, sorted so the
# order is deterministic, and report how many were found.
all_image_paths = sorted(glob.glob('../input/sroie_v2/SROIE2019/test/img/*.jpg'))
print(len(all_image_paths))

347


In [9]:
class CustomData(Dataset):
    """Dataset that serves image paths; it does not open or decode the images."""

    def __init__(self, image_paths):
        """
        :param image_paths: Indexable collection of image paths, stored as given.
        """
        self.image_paths = image_paths

    def __len__(self):
        """Return the number of stored image paths."""
        return len(self.image_paths)

    def __getitem__(self, idx):
        """Return the image path stored at position `idx`."""
        return self.image_paths[idx]

In [10]:
# Number of image paths the DataLoader groups into each batch.
batch_size = 16

In [11]:
# Serve the image paths in their original order (no shuffling), `batch_size`
# at a time, using four worker processes.
dataset = CustomData(all_image_paths)
batched_dl = DataLoader(dataset, batch_size=batch_size, shuffle=False, num_workers=4)

In [12]:
# Number of batches the DataLoader will yield.
print(len(batched_dl))

22


In [13]:
# Run OCR batch by batch and collect the lower-cased predictions in the order
# the DataLoader yields them.
inference_results = []

for batch in tqdm(batched_dl, total=len(batched_dl)):
    predictions = test(model, processor, batch)
    inference_results.extend(text.lower() for text in predictions)

  0%|          | 0/22 [00:00<?, ?it/s]

In [14]:
# Sanity check: the number of predictions collected by the inference loop.
print(len(inference_results))

347


## Function to Calculate CER

In [15]:
def calculate_cer(ground_truth, results):
    """
    Compute, print and return the character error rate (CER) of `results`
    measured against `ground_truth`.

    Pairs whose ground truth is an empty string are left out of the
    computation. Neither input list is modified.

    :param ground_truth: List containing the ground truth data
        e.g. ['tan woon yann\nbook ta.k', 'are not returnable or']
    :param results: VLM generated results, aligned index-by-index with
        `ground_truth`
        e.g. ['tan woon yann\nbook ta.k', 'are not returnable or']
    :return: The value returned by `jiwer.cer` for the retained pairs.
    :raises ValueError: If the two lists differ in length, or if no
        non-empty ground truth string remains after filtering.
    """
    if len(ground_truth) != len(results):
        raise ValueError(
            f"ground_truth and results must have the same length, "
            f"got {len(ground_truth)} and {len(results)}"
        )

    # Build new lists instead of popping while iterating: popping shifts the
    # remaining items left, so the element right after each removed one is never
    # inspected (consecutive empty strings survive), and it also alters the
    # caller's lists, making a second run of the cell behave differently.
    kept_pairs = [
        (reference, hypothesis)
        for reference, hypothesis in zip(ground_truth, results)
        if len(reference) != 0
    ]
    if not kept_pairs:
        raise ValueError("No non-empty ground truth strings to score")

    references = [reference for reference, _ in kept_pairs]
    hypotheses = [hypothesis for _, hypothesis in kept_pairs]

    error = jiwer.cer(references, hypotheses)
    print(f"CER: {error}")
    return error

In [16]:
# Score the model's predictions against the loaded annotation text.
calculate_cer(vlm_data, inference_results)

CER: 0.2399265156154317
