In [None]:
# Here we use the Donut Model for document classification
import torch
from transformers import DonutProcessor, VisionEncoderDecoderModel
from PIL import Image
import json
from tqdm import tqdm

# Load the Donut processor and model weights from the same RVL-CDIP fine-tuned checkpoint.
donut_checkpoint = "naver-clova-ix/donut-base-finetuned-rvlcdip"
processor = DonutProcessor.from_pretrained(donut_checkpoint)
model = VisionEncoderDecoderModel.from_pretrained(donut_checkpoint)

In [None]:
# Use the MPS backend when torch reports it as available; otherwise stay on the CPU.
if torch.backends.mps.is_available():
    device = "mps"
else:
    device = "cpu"

# Move the model to the chosen device (as the cell's last expression, its result is echoed).
model.to(device)


In [17]:
def classify_document(image_path):
    """
    Classify a document image using the Donut model.

    Args:
        image_path: Path of the image file to classify.

    Returns:
        tuple (predicted_class, confidence_score, scores) -- always three items:
            predicted_class: decoded text of the generated sequence with
                surrounding whitespace stripped, or "unknown" on failure.
            confidence_score: highest softmax probability at the first
                generation step (a float in [0, 1]), or 0.0 on failure.
            scores: the raw score tensor of the first generation step
                (``outputs.scores[0]``), or None on failure.

    Any exception raised while loading or classifying is caught, reported via
    ``print`` and turned into the ("unknown", 0.0, None) fallback.
    """
    try:
        # Load and preprocess the image; the context manager releases the file
        # handle as soon as the RGB copy has been made.
        with Image.open(image_path) as raw_image:
            image = raw_image.convert("RGB")
        pixel_values = processor(image, return_tensors="pt").pixel_values.to(device)

        # Generate predictions
        outputs = model.generate(
            pixel_values,
            max_length=64,
            return_dict_in_generate=True,
            output_scores=True
        )

        # Decode the prediction
        prediction = processor.batch_decode(outputs.sequences)[0]

        # Normalise the first-step scores so the confidence is a probability
        # rather than an unbounded raw logit.
        # NOTE(review): this is the confidence of the first generated token,
        # which is not necessarily the token that carries the class label.
        first_step_scores = outputs.scores[0]
        confidence = torch.softmax(first_step_scores.float(), dim=-1).max().item()

        return prediction.strip(), confidence, first_step_scores

    except Exception as e:
        print(f"Error classifying document {image_path}: {str(e)}")
        # Keep the same arity as the success path so callers that unpack three
        # values do not fail with a ValueError on the error path.
        return "unknown", 0.0, None

In [None]:
# Run the Donut classifier on one sample page and inspect everything it returns.
sample_image_path = "/Volumes/MyDataDrive/thesis/code-2/src/fireworks/image_assets/png_images/ffdd0138/ffdd0138_page4.png"

classification = classify_document(sample_image_path)
pred, conf, scores = classification

# Predicted text and its confidence value on one line, then the score tensor.
print(pred, conf)

print(scores)


In [None]:
# Show the dimensions of the score tensor returned by classify_document.
print(scores.shape)

# From the official implementation

In [1]:
from donut import DonutModel
import torch
from PIL import Image
import numpy as np

  from .autonotebook import tqdm as notebook_tqdm


In [8]:
import re
from transformers import DonutProcessor, VisionEncoderDecoderModel
from datasets import load_dataset
import torch

# Processor and model both come from the same RVL-CDIP fine-tuned Donut checkpoint.
checkpoint_name = "naver-clova-ix/donut-base-finetuned-rvlcdip"
processor = DonutProcessor.from_pretrained(checkpoint_name)
model = VisionEncoderDecoderModel.from_pretrained(checkpoint_name)

# Use the MPS backend when torch reports it as available; otherwise stay on the CPU.
if torch.backends.mps.is_available():
    device = "mps"
else:
    device = "cpu"

# As the last expression of the cell, this echoes the model's module tree below.
model.to(device)


VisionEncoderDecoderModel(
  (encoder): DonutSwinModel(
    (embeddings): DonutSwinEmbeddings(
      (patch_embeddings): DonutSwinPatchEmbeddings(
        (projection): Conv2d(3, 128, kernel_size=(4, 4), stride=(4, 4))
      )
      (norm): LayerNorm((128,), eps=1e-05, elementwise_affine=True)
      (dropout): Dropout(p=0.0, inplace=False)
    )
    (encoder): DonutSwinEncoder(
      (layers): ModuleList(
        (0): DonutSwinStage(
          (blocks): ModuleList(
            (0): DonutSwinLayer(
              (layernorm_before): LayerNorm((128,), eps=1e-05, elementwise_affine=True)
              (attention): DonutSwinAttention(
                (self): DonutSwinSelfAttention(
                  (query): Linear(in_features=128, out_features=128, bias=True)
                  (key): Linear(in_features=128, out_features=128, bias=True)
                  (value): Linear(in_features=128, out_features=128, bias=True)
                  (dropout): Dropout(p=0.0, inplace=False)
                )

In [9]:
# Prepare the decoder inputs: generation is seeded with the RVL-CDIP task start token.
task_prompt = "<s_rvlcdip>"
tokenizer = processor.tokenizer
decoder_input_ids = tokenizer(task_prompt, add_special_tokens=False, return_tensors="pt").input_ids

sample_image_path = "/Volumes/MyDataDrive/thesis/code-2/src/fireworks/image_assets/png_images/flxp0006/flxp0006_page1.png"

# Load the page as RGB and turn it into the pixel tensor the model consumes.
image = Image.open(sample_image_path).convert("RGB")
pixel_values = processor(image, return_tensors="pt").pixel_values

outputs = model.generate(
    pixel_values.to(device),
    decoder_input_ids=decoder_input_ids.to(device),
    max_length=model.decoder.config.max_position_embeddings,
    pad_token_id=tokenizer.pad_token_id,
    eos_token_id=tokenizer.eos_token_id,
    use_cache=True,
    bad_words_ids=[[tokenizer.unk_token_id]],  # never emit the unknown token
    return_dict_in_generate=True,
)

# Decode the generated ids, then strip EOS/PAD markers from the text.
sequence = processor.batch_decode(outputs.sequences)[0]
for special_token in (tokenizer.eos_token, tokenizer.pad_token):
    sequence = sequence.replace(special_token, "")

# count=1 removes only the first <...> tag, i.e. the task start token.
sequence = re.sub(r"<.*?>", "", sequence, count=1).strip()
print(processor.token2json(sequence))

{'class': 'email'}


In [3]:
# Classify a sample page through `demo_process`.
# NOTE(review): `demo_process` is not defined in any cell visible in this notebook;
# it presumably belongs to the official `donut` package workflow -- confirm where it
# should come from before relying on this cell after a kernel restart.
# The recorded run of this cell failed with:
#   AttributeError: 'SwinTransformer' object has no attribute 'pos_drop'
sample_image_path = "/Volumes/MyDataDrive/thesis/code-2/src/fireworks/image_assets/png_images/ffdd0138/ffdd0138_page4.png"

pred = demo_process(sample_image_path)
print(pred)

AttributeError: 'SwinTransformer' object has no attribute 'pos_drop'

In [23]:
# Here check the cloud hosting

import requests

# Upload the sample page to the remotely hosted classifier and print the predicted class.
# NOTE(review): the endpoint is a hard-coded IP reached over plain HTTP; consider
# moving it to a configuration cell or environment variable.
predict_url = "http://102.210.171.164:8000/predict"

with open(sample_image_path, "rb") as f:
    # Without a timeout, requests waits indefinitely if the host never answers.
    response = requests.post(predict_url, files={"file": f}, timeout=30)

# Surface HTTP errors explicitly instead of trying to parse an error page as JSON.
response.raise_for_status()
res = response.json()
print(res['prediction']['class'])


email
