# Introduction

Using the Qwen2-VL 7B model to generate ground truth OCR data for SROIEv2 dataset.

In [1]:
from transformers import Qwen2VLForConditionalGeneration, AutoTokenizer, AutoProcessor
from qwen_vl_utils import process_vision_info
from transformers import BitsAndBytesConfig
from tqdm.auto import tqdm

import torch
import glob
import os

In [2]:
# flash_attention_2 for better acceleration and memory saving. Great for batched inference.
model = Qwen2VLForConditionalGeneration.from_pretrained(
    "Qwen/Qwen2-VL-7B-Instruct",
    torch_dtype=torch.bfloat16,
    attn_implementation="flash_attention_2",
    device_map="auto"
)

# Load processor
processor = AutoProcessor.from_pretrained("Qwen/Qwen2-VL-7B-Instruct")

Loading checkpoint shards:   0%|          | 0/5 [00:00<?, ?it/s]

Using a slow image processor as `use_fast` is unset and a slow processor was saved with this model. `use_fast=True` will be the default behavior in v4.48, even if the model was saved with a slow processor. This will result in minor differences in outputs. You'll still be able to use a slow processor with `use_fast=False`.


In [3]:
all_images = glob.glob('../input/sroie_v2/SROIE2019/test/img/*.jpg')
# all_images = glob.glob('../input/sroie_v2/SROIE2019/train/img/*.jpg')

In [4]:
out_dir = '../input/qwen2_vl_7b_sroiev2_test_annots'
# out_dir = '../input/qwen2_vl_7b_sroiev2_train_annots'

os.makedirs(out_dir, exist_ok=True)

In [5]:
print(len(all_images))

347


## Batch Inference

Batch processing example
```python
messages1 = [
    {
        "role": "user",
        "content": [
            {"type": "image", "image": "file:///path/to/image1.jpg"},
            {"type": "image", "image": "file:///path/to/image2.jpg"},
            {"type": "text", "text": "What are the common elements in these pictures?"},
        ],
    }
]
messages2 = [
    {"role": "system", "content": "You are a helpful assistant."},
    {"role": "user", "content": "Who are you?"},
]
# Combine messages for batch processing
messages = [messages1, messages1]
```

In [6]:
def batch_infer(messages):
    # Preparation for inference
    texts = [
            processor.apply_chat_template(
            msg, tokenize=False, add_generation_prompt=True
        )
        for msg in messages
    ]
    
    image_inputs, video_inputs = process_vision_info(messages)
    
    inputs = processor(
        text=texts,
        images=image_inputs,
        videos=video_inputs,
        padding=True,
        return_tensors="pt"
    )
    inputs = inputs.to("cuda")
    
    # Inference: Generation of the output
    generated_ids = model.generate(**inputs, max_new_tokens=1024)
    generated_ids_trimmed = [
        out_ids[len(in_ids) :] for in_ids, out_ids in zip(inputs.input_ids, generated_ids)
    ]
    output_text = processor.batch_decode(
        generated_ids_trimmed, skip_special_tokens=True, clean_up_tokenization_spaces=False
    )
    # print(output_text)
    return output_text

In [7]:
batch_size = 8

In [8]:
from torch.utils.data import Dataset, DataLoader

In [9]:
class BatchedDataset(Dataset):
    def __init__(self, all_images):
        self.all_images = all_images

    def __len__(self):
        return len(self.all_images)

    def __getitem__(self, idx):
        return self.all_images[idx]

In [10]:
custom_dataset = BatchedDataset(all_images)

In [11]:
batch_dl = DataLoader(custom_dataset, batch_size=batch_size, shuffle=False)

In [12]:
for i in batch_dl:
    print(i)
    break

['../input/sroie_v2/SROIE2019/test/img/X51005724625.jpg', '../input/sroie_v2/SROIE2019/test/img/X51005447859.jpg', '../input/sroie_v2/SROIE2019/test/img/X51005711446.jpg', '../input/sroie_v2/SROIE2019/test/img/X51006335818.jpg', '../input/sroie_v2/SROIE2019/test/img/X51007231274.jpg', '../input/sroie_v2/SROIE2019/test/img/X51006556610.jpg', '../input/sroie_v2/SROIE2019/test/img/X51006414427.jpg', '../input/sroie_v2/SROIE2019/test/img/X51005677327.jpg']


In [13]:
for batch in tqdm(batch_dl, total=len(batch_dl)):
    messages = []
    
    for image_path in batch:
        message = [
            {
                "role": "user",
                "content": [
                    {
                        "type": "image",
                        "image": image_path,
                        "resized_height": 768,
                        "resized_width": 512,
                    },
                    {"type": "text", "text": "Give the OCR text from this image and nothing else."},
                ],
            }
        ]
        messages.append(message)

    texts = batch_infer(messages)

    for text, image_path in zip(texts, batch):
        # print(text)
        with open(os.path.join(out_dir, image_path.split(os.path.sep)[-1].split('.jpg')[0]+'.txt'), 'w') as f:
            f.write(text)

  0%|          | 0/44 [00:00<?, ?it/s]