# Install the required packages

In [None]:
%%capture
!pip install transformers sentencepiece

# Load the LoolooOCR model

In [None]:
import torch

# Backend preference order: Apple Metal (MPS), then CUDA, otherwise the CPU.
if torch.backends.mps.is_available():
    device = "mps"
elif torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"

In [None]:
from transformers import TrOCRProcessor, VisionEncoderDecoderModel

# Checkpoint id or local directory handed to `from_pretrained`.
# NOTE(review): the value looks like a placeholder for a private checkpoint — replace with the real one.
checkpoint_name = "our_closed_source_model"

# The processor wraps both the image preprocessing and the tokenizer used to decode outputs.
processor = TrOCRProcessor.from_pretrained(checkpoint_name)

model = VisionEncoderDecoderModel.from_pretrained(checkpoint_name)
model = model.to(device)

# Load the text detector model

In [None]:
from easyocr import Reader

# EasyOCR is used only for text *detection*; recognition is done by the model loaded above.
# `recognizer=False` skips loading (and possibly downloading) EasyOCR's own recognition
# model, which would otherwise occupy memory without ever being used. The language list
# only selects the recognition model, so it can stay empty.
detector = Reader([], recognizer=False)

# Load document

In [None]:
from PIL import Image

# Document to run OCR on (a Thai-language document, judging by the file name).
image_path = "../datasets/Srisawad Deep Learning/mc_รายการจดทะเบียน_1.jpg"

# Decode and normalise to 3-channel RGB; the context manager releases the file handle.
with Image.open(image_path) as source_image:
    image = source_image.convert("RGB")

display(image)

# Run the text detector
We first need the bounding box of every text region, so that each region can be cropped and passed to the recognizer.

In [None]:
import numpy as np
from PIL import ImageDraw

drawn_image = image.copy()
draw = ImageDraw.Draw(drawn_image)

# Detect on the exact pixels we crop from below. Passing `image_path` would make EasyOCR
# decode the file again with its own loader, which is not guaranteed to match `image`
# (e.g. EXIF orientation handling), so boxes could end up misaligned with the crops.
# The RGB array matches what EasyOCR's own path loader produces.
# The second return value (non-axis-aligned boxes) is ignored, so skewed text is skipped.
batch_regions, _ = detector.detect(np.array(image))

# Because we only pass in one image, keep only its boxes.
regions = batch_regions[0]

textboxes = []
for region in regions:
    # EasyOCR boxes are [x_min, x_max, y_min, y_max]; PIL wants (left, upper, right, lower).
    box = [region[0], region[2], region[1], region[3]]

    # Outline the detected region on the preview image.
    draw.rectangle(box, outline="blue", width=2)

    textboxes.append(image.crop(box))

display(drawn_image)

Let's display a few of the cropped text boxes.

In [None]:
# Preview only the first few crops; the full list can be long.
num_examples = 3
for textbox in textboxes[:num_examples]:
    display(textbox)

# Run inference

In [None]:
from tqdm import tqdm

batch_size = 4

predictions = []

# Step through the crops in fixed-size slices; the final slice may be shorter.
batch_starts = range(0, len(textboxes), batch_size)
for start in tqdm(batch_starts, desc="Batch Inferencing"):
    image_batch = textboxes[start : start + batch_size]

    pixel_values = processor(image_batch, return_tensors="pt").pixel_values

    # NOTE(review): generation length comes from the checkpoint's generation config —
    # confirm it is long enough for the longest text line.
    generated_ids = model.generate(pixel_values.to(device))

    # Turn generated token ids back into text, dropping special tokens.
    predictions += processor.batch_decode(generated_ids, skip_special_tokens=True)

In [None]:
from PIL import ImageDraw, ImageFont

# TrueType font for the prediction labels (TH Sarabun, which includes Thai glyphs).
FONT = ImageFont.truetype("../assets/THSarabun.ttf", 20)
# Vertical gap (pixels) between a label and the top of its box.
label_offset = 20

# zip() would silently drop pairs if the detector cell was re-run without re-running inference.
assert len(regions) == len(predictions), "expected one prediction per detected region"

drawn_image = image.copy()
draw = ImageDraw.Draw(drawn_image)

for region, prediction in zip(regions, predictions):
    # EasyOCR boxes are [x_min, x_max, y_min, y_max].
    left, right, top, bottom = region
    draw.rectangle([left, top, right, bottom], outline="blue", width=2)

    # Place the label above the box, or just below it when there is no room above,
    # so labels of boxes near the top edge are not clipped off the image.
    label_y = top - label_offset if top >= label_offset else bottom
    draw.text((left, label_y), prediction, fill="red", font=FONT)

display(drawn_image)