In [None]:
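# Dependencies (run once; prefix with "!" to execute in a notebook cell).
# transformers, pytorch-lightning and timm are pinned, presumably for
# compatibility with donut-python — install the pinned versions only.
# pip install torch pillow donut-python
# pip install transformers==4.25.1
# pip install pytorch-lightning==1.6.4
# pip install timm==0.5.4
# pip install gradio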

In [None]:
from google.colab import files
from PIL import Image

# Prompt for an image upload in Colab; the result is keyed by uploaded filename.
uploaded = files.upload()
if not uploaded:
    # The picker was cancelled or no file was chosen; fail with a clear hint
    # instead of an IndexError from indexing an empty key list.
    raise RuntimeError("No file was uploaded; re-run this cell and choose an image.")
# Only the first uploaded file is used if several are selected.
image_path = next(iter(uploaded))

In [None]:
import os
from io import BytesIO

import numpy as np
import torch
from PIL import Image

from donut import DonutModel

class DonutInference:
    """Thin wrapper around a pretrained Donut model for single-image inference.

    Args:
        task_name: Donut task identifier. ``"docvqa"`` selects a question/answer
            prompt template; any other value becomes the prompt ``<s_{task_name}>``.
        pretrained_path: Model identifier or local path passed to
            ``DonutModel.from_pretrained``. The default checkpoint name refers to
            a CORD-v2 fine-tune; pass a checkpoint matching ``task_name`` for
            other tasks.
    """

    # Placeholder in the DocVQA prompt template that is replaced by the question.
    _QUESTION_PLACEHOLDER = "{user_input}"

    def __init__(self, task_name="cord-v2", pretrained_path="naver-clova-ix/donut-base-finetuned-cord-v2"):
        self.task_name = task_name
        # Prompt *template*: for docvqa it still contains the question placeholder,
        # which is filled per call in _build_prompt().
        if task_name == "docvqa":
            self.task_prompt = "<s_docvqa><s_question>{user_input}</s_question><s_answer>"
        else:
            self.task_prompt = f"<s_{task_name}>"
        self.pretrained_model = self._load_model(pretrained_path)

    def _load_model(self, pretrained_path):
        """Load the Donut checkpoint and switch it to evaluation mode.

        When CUDA is available the weights are cast to fp16 and moved to the GPU;
        otherwise the model stays on CPU in its loaded precision.
        """
        model = DonutModel.from_pretrained(pretrained_path)
        if torch.cuda.is_available():
            # fp16 is applied only together with the move to GPU.
            model.half()
            model.to(torch.device("cuda"))
        model.eval()
        return model

    def _build_prompt(self, question=None):
        """Return the decoder prompt for one call, filling in the question if needed.

        Raises:
            ValueError: if the template expects a question and none was given.
        """
        if self._QUESTION_PLACEHOLDER not in self.task_prompt:
            # Non-question tasks use the fixed prompt; ``question`` is ignored.
            return self.task_prompt
        if question is None:
            raise ValueError(
                f"Task '{self.task_name}' requires a question; "
                "pass question=... to run_inference()."
            )
        # str.replace (not str.format) so braces inside the question are kept verbatim.
        return self.task_prompt.replace(self._QUESTION_PLACEHOLDER, question)

    @staticmethod
    def _to_rgb_image(image):
        """Normalize a supported image input to an RGB ``PIL.Image.Image``.

        Raises:
            TypeError: if ``image`` is not one of the supported input types.
        """
        if isinstance(image, Image.Image):
            return image.convert("RGB")
        if isinstance(image, (bytes, bytearray)):
            image = BytesIO(image)
        if isinstance(image, (str, os.PathLike, BytesIO)):
            # Context manager releases the file handle once pixels are copied out.
            with Image.open(image) as opened:
                return opened.convert("RGB")
        if isinstance(image, np.ndarray):
            return Image.fromarray(image).convert("RGB")
        raise TypeError(f"Unsupported image type: {type(image).__name__}")

    def run_inference(self, image, question=None):
        """Run the model on one image and return its first prediction.

        Args:
            image: A PIL image, a file path (``str`` or ``os.PathLike``), a
                ``BytesIO``/``bytes`` payload, or a NumPy array accepted by
                ``PIL.Image.fromarray``.
            question: Question text. Required when ``task_name == "docvqa"``,
                ignored for other tasks.

        Returns:
            ``predictions[0]`` from ``DonutModel.inference``.

        Raises:
            ValueError: if the task needs a question and none was given.
            TypeError: if ``image`` has an unsupported type.
        """
        # Validate the prompt first so a missing question fails before any image I/O.
        prompt = self._build_prompt(question)
        rgb_image = self._to_rgb_image(image)
        # Inference only: do not record autograd state.
        with torch.no_grad():
            result = self.pretrained_model.inference(image=rgb_image, prompt=prompt)
        return result["predictions"][0]


if __name__ == "__main__":
    # Demo: parse the file chosen in the upload cell with the default CORD model.
    image = Image.open(image_path)
    image = image.convert("RGB")
    # NOTE(review): PIL's show() delegates to an external viewer; in a hosted
    # notebook it may display nothing — IPython's display(image) may be preferable.
    image.show()

    donut = DonutInference()
    prediction = donut.run_inference(image)
    print(prediction)
