# Install the required packages

In [None]:
%%capture
# Install the packages this notebook imports; %%capture suppresses the cell's pip output.
!pip install transformers sentencepiece

# Load the Donut model

In [None]:
import torch

# Choose where the model runs: prefer the Apple Metal (MPS) backend, then a
# CUDA GPU, and fall back to the CPU when no accelerator is available.
if torch.backends.mps.is_available():
    device = "mps"
elif torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"

In [None]:
from transformers import DonutProcessor, VisionEncoderDecoderModel

# Hugging Face Hub checkpoint of Donut fine-tuned for document VQA (DocVQA).
checkpoint_name = "naver-clova-ix/donut-base-finetuned-docvqa"

# The processor supplies both the image preprocessing and the tokenizer used
# later to build prompts and decode generated tokens.
processor = DonutProcessor.from_pretrained(checkpoint_name)

# Load the encoder-decoder weights and move them to the device chosen above.
model = VisionEncoderDecoderModel.from_pretrained(checkpoint_name)
model = model.to(device)

# Load the document image

In [None]:
from PIL import Image

path = "../assets/donut_receipt_example.jpg"

# Open the document and convert it to RGB. convert() returns a new image,
# so the source file can be closed as soon as the with-block ends.
with Image.open(path) as source_image:
    image = source_image.convert("RGB")

display(image)

# Run inference

In [None]:
import re

questions = [
    "How much is the Chocolate Soft Cookie?",
    "What is the TAX ID?",
    "What is the total quantity of Cold Brew Latte?",
]

# DocVQA prompt: the question is wrapped in task tokens, and generation
# continues from the trailing <s_answer> token.
task_prompt = "<s_docvqa><s_question>{user_input}</s_question><s_answer>"
# Non-greedy match of a single <...> tag; used to strip the leading task token.
first_tag_pattern = re.compile(r"<.*?>")

# The image is the same for every question, so preprocess it and move it to
# the device once, not once per question.
pixel_values = processor(image, return_tensors="pt").pixel_values.to(device)

for question in questions:
    prompt = task_prompt.format(user_input=question)
    # Tokenize the prompt as-is (the prompt already carries its own task
    # tokens, so the tokenizer must not add special tokens of its own).
    decoder_input_ids = processor.tokenizer(
        prompt, add_special_tokens=False, return_tensors="pt"
    ).input_ids.to(device)

    outputs = model.generate(
        pixel_values,
        decoder_input_ids=decoder_input_ids,
        # Allow generation up to the decoder's maximum supported length.
        max_length=model.decoder.config.max_position_embeddings,
        pad_token_id=processor.tokenizer.pad_token_id,
        eos_token_id=processor.tokenizer.eos_token_id,
        use_cache=True,
        # Forbid the unknown token from appearing in the output.
        bad_words_ids=[[processor.tokenizer.unk_token_id]],
        return_dict_in_generate=True,
    )

    sequence = processor.batch_decode(outputs.sequences)[0]
    # Drop EOS/padding tokens, then remove the first tag (the task start token).
    sequence = sequence.replace(processor.tokenizer.eos_token, "").replace(
        processor.tokenizer.pad_token, ""
    )
    sequence = first_tag_pattern.sub("", sequence, count=1).strip()
    print(processor.token2json(sequence))

## Quiz

- ลอง prompt ด้วยคำสั่งด้านบน แต่ใช้ภาพเล่มทะเบียนแทน
- ตรวจสอบผลว่าแม่นยำขนาดไหน (hint: Donut ยังทำงานได้ไม่ค่อยดีกับภาพเอกสารภาษาไทย)