In [None]:
!pip3 install pytesseract

In [None]:
# # Windows
# import pytesseract
# pytesseract.pytesseract.tesseract_cmd = r"<install path>\tesseract.exe"

In [None]:
# # macOS
# brew install tesseract

In [None]:
# # Linux/Google Colab
# sudo apt install tesseract-ocr

In [None]:
import io
from PIL import Image
from datasets import load_dataset
from transformers import LayoutLMv3FeatureExtractor

def get_ocr_words_and_boxes(sample):
    """Add OCR words, their boxes and the preprocessed image to one dataset row.

    The row's ``"image"`` entry is wrapped in ``io.BytesIO`` and opened with
    PIL, so it is expected to hold the encoded image bytes. The image is then
    passed through the module-level ``feature_extractor``, which must be
    defined before this function is called.

    Args:
        sample: A single dataset row (dict-like) with an ``"image"`` entry.

    Returns:
        The same ``sample`` with three new keys: ``"words"``, ``"boxes"`` and
        ``"pixel_values"``, each taken from the first (only) item of the
        feature extractor output.
    """
    # Force three channels: scanned documents are often stored as grayscale or
    # palette images, while the rest of the pipeline expects RGB input.
    image = Image.open(io.BytesIO(sample["image"])).convert("RGB")

    encoded_inputs = feature_extractor(image)
    # A single image was passed, so index 0 selects its results.
    sample["words"] = encoded_inputs.words[0]
    sample["boxes"] = encoded_inputs.boxes[0]
    sample["pixel_values"] = encoded_inputs.pixel_values[0]
    return sample

dataset = load_dataset("s076923/docvqa-train")
model_name = "microsoft/layoutlmv3-base"
# Load the preprocessing configuration that belongs to the checkpoint.
# Calling the constructor directly with `model_name` would pass the string as
# the first constructor argument rather than load anything from the hub.
feature_extractor = LayoutLMv3FeatureExtractor.from_pretrained(model_name)
dataset_with_ocr = dataset["train"].map(get_ocr_words_and_boxes)

# Fetch the preview row once; every `dataset_with_ocr[1]` access rebuilds the
# whole row, including its pixel values.
preview_sample = dataset_with_ocr[1]
print(preview_sample.keys())
print("question :", preview_sample["question"])
print("answers :", preview_sample["answers"])
print("words :", preview_sample["words"])
print("boxes :", preview_sample["boxes"])

In [None]:
def find_sublist(word_list, target_list):
    """Locate a whitespace-separated phrase inside a list of words.

    Both sides are lower-cased before comparison, so matching is
    case-insensitive.

    Args:
        word_list: List of word strings to search in.
        target_list: The phrase to look for, given as a single string; it is
            split on whitespace into words.

    Returns:
        A tuple ``(match, start, end)``. On success ``match`` is the list of
        lower-cased target words and ``start``/``end`` are the inclusive word
        indices of the first occurrence. If nothing matches, the result is
        ``(None, 0, 0)``.
    """
    haystack = [token.lower() for token in word_list]
    needle = target_list.lower().split()
    span = len(needle)

    # Slide a window of `span` words over the haystack, left to right.
    start = 0
    last_start = len(haystack) - span
    while start <= last_start:
        if haystack[start : start + span] == needle:
            return needle, start, start + span - 1
        start += 1
    return None, 0, 0

# Fetch row 10 once; each `dataset_with_ocr[10]` access rebuilds the whole row
# (including its pixel values), so three separate lookups are wasted work.
match_sample = dataset_with_ocr[10]
question = match_sample["question"]
words = match_sample["words"]
answers = match_sample["answers"]
print(question)
print(words)
print(answers)
print()

# Show where each reference answer is found among the OCR words.
for answer in answers:
    match, word_idx_start, word_idx_end = find_sublist(words, answer)
    print("Match :", match)
    print("Word idx start :", word_idx_start)
    print("Word idx end :", word_idx_end)
    print()

In [None]:
from transformers import LayoutLMv3TokenizerFast
from datasets import Features, Sequence, Value, Array2D, Array3D

def find_answer_match(words, answers):
    """Find the first answer that occurs in ``words``, allowing one dropped character.

    Matching is delegated to ``find_sublist`` and runs in two passes:

    1. Every answer is tried as-is.
    2. If none matched, every answer is retried with exactly one character
       removed, for each character position in turn.

    Args:
        words: List of word strings to search in.
        answers: Iterable of candidate answer strings.

    Returns:
        The ``(match, start, end)`` tuple of the first successful
        ``find_sublist`` call, or ``(False, None, None)`` when no candidate
        matches.
    """
    # Pass 1: exact phrases.
    for candidate in answers:
        result = find_sublist(words, candidate)
        if result[0]:
            return result

    # Pass 2: the same phrases with a single character deleted.
    for candidate in answers:
        for drop_at in range(len(candidate)):
            shortened = candidate[:drop_at] + candidate[drop_at + 1 :]
            result = find_sublist(words, shortened)
            if result[0]:
                return result

    return False, None, None

def encode_dataset(examples, processor, max_length=512):
    """Tokenize one example and attach answer start/end token positions.

    The question, OCR words and boxes are encoded with ``processor`` (padded
    and truncated to ``max_length``). The answer is located among the OCR
    words with ``find_answer_match`` and its word span is translated into
    token positions.

    Both positions point at the CLS token when the answer cannot be labelled,
    i.e. when no answer matches the OCR words or when the matched words are
    not present in the encoding (for example because truncation removed them).

    Args:
        examples: A single dataset row with ``"question"``, ``"words"``,
            ``"boxes"``, ``"answers"`` and ``"pixel_values"`` entries.
        processor: Tokenizer-like callable exposing ``cls_token_id`` and
            returning an encoding with ``input_ids``, ``sequence_ids`` and
            ``word_ids``.
        max_length: Length every encoding is padded/truncated to.

    Returns:
        The encoding, extended with ``"image"``, ``"start_positions"`` and
        ``"end_positions"``.
    """
    encoding = processor(
        examples["question"],
        examples["words"],
        examples["boxes"],
        max_length=max_length,
        padding="max_length",
        truncation=True
    )

    # Default label: the CLS token.
    cls_index = encoding.input_ids.index(processor.cls_token_id)
    start_position = end_position = cls_index

    match, word_idx_start, word_idx_end = find_answer_match(
        examples["words"], examples["answers"]
    )

    if match:
        sequence_ids = encoding.sequence_ids(0)
        # Tokens with sequence id 1 belong to the second input (the OCR words).
        if 1 in sequence_ids:
            token_start_index = sequence_ids.index(1)
            token_end_index = len(encoding.input_ids) - 1 - sequence_ids[::-1].index(1)
            word_ids = encoding.word_ids()[token_start_index : token_end_index + 1]

            # Truncation may have cut the answer (or part of it) out of the
            # encoding; only label the span when both ends survived, otherwise
            # `list.index` would raise and abort the whole `map` call.
            if word_idx_start in word_ids and word_idx_end in word_ids:
                # First token of the first answer word.
                start_position = token_start_index + word_ids.index(word_idx_start)
                # Last token of the last answer word (searched from the end).
                end_position = token_end_index - word_ids[::-1].index(word_idx_end)

    encoding["image"] = examples["pixel_values"]
    encoding["start_positions"] = start_position
    encoding["end_positions"] = end_position
    return encoding

# Note: despite its name, `processor` is the fast LayoutLMv3 tokenizer here.
processor = LayoutLMv3TokenizerFast.from_pretrained(model_name)

# Explicit schema of the encoded rows produced by `encode_dataset`.
encoded_features = Features(
    {
        "input_ids": Sequence(feature=Value(dtype="int64")),
        "bbox": Array2D(dtype="int64", shape=(512, 4)),
        "attention_mask": Sequence(feature=Value(dtype="int64")),
        "image": Array3D(dtype="float32", shape=(3, 224, 224)),
        "start_positions": Value(dtype="int64"),
        "end_positions": Value(dtype="int64"),
    }
)

# Encode every row and drop all original columns, keeping only the features above.
encoded_dataset = dataset_with_ocr.map(
    lambda x: encode_dataset(x, processor),
    remove_columns=dataset_with_ocr.column_names,
    features=encoded_features,
)
print(encoded_dataset)

In [None]:
from transformers import TrainingArguments, Trainer
from transformers import LayoutLMv3ForQuestionAnswering

# Question-answering model initialised from the base checkpoint.
model = LayoutLMv3ForQuestionAnswering.from_pretrained(model_name)

# Training configuration: where to write outputs, the optimisation schedule,
# how often to log, and the seed.
training_args = TrainingArguments(
    output_dir="DocVQA",
    seed=42,
    num_train_epochs=20,
    per_device_train_batch_size=4,
    learning_rate=5e-5,
    weight_decay=0.01,
    warmup_steps=100,
    logging_strategy="steps",
    logging_steps=20,
)

trainer = Trainer(model=model, args=training_args, train_dataset=encoded_dataset)

trainer.train()

In [None]:
import torch
from transformers import LayoutLMv3Processor

index = 5
processor = LayoutLMv3Processor.from_pretrained(model_name)

# Fetch each row once. Indexing by row (`ds[index]`) reads a single example,
# whereas `ds["column"][index]` may load the whole column first.
inference_sample = dataset_with_ocr[index]
encoded_sample = encoded_dataset[index]

image_bytes = io.BytesIO(inference_sample["image"])
# Force three channels so grayscale/palette scans are handled as well.
image = Image.open(image_bytes).convert("RGB")

full_text = processor.decode(encoded_sample["input_ids"])
print("Full text:", full_text)

question = inference_sample["question"]
print("Question:", question)

# Reference answer: decode the labelled token span of the encoded example.
start_position = encoded_sample["start_positions"]
end_position = encoded_sample["end_positions"]
answer = processor.decode(
    encoded_sample["input_ids"][start_position : end_position + 1]
)
print("Answer:", answer)

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# Truncate to the same 512-token limit used when encoding the training data,
# so long documents do not exceed the length the model was trained with.
encoded_inputs = processor(
    image, question, truncation=True, max_length=512, return_tensors="pt"
)
encoded_inputs = {k: v.to(device) for k, v in encoded_inputs.items()}
print("Encoded input keys:", encoded_inputs.keys())

In [None]:
# Move the model to the inputs' device and switch to evaluation mode.
model.to(device).eval()

# Forward pass without gradient tracking.
with torch.no_grad():
    outputs = model(**encoded_inputs)

start_logits = outputs.start_logits
end_logits = outputs.end_logits

# The predicted span runs from the highest-scoring start token to the
# highest-scoring end token.
start_index = start_logits.argmax(-1).item()
end_index = end_logits.argmax(-1).item()

token_ids = encoded_inputs["input_ids"].squeeze()
predicted_answer = processor.decode(token_ids[start_index : end_index + 1])

for label, value in (
    ("Predicted start_index:", start_index),
    ("Predicted end_index:", end_index),
    ("predicted_answer:", predicted_answer),
):
    print(label, value)