# Log in to HuggingFace (only needed once)

In [None]:
# Log in to the Hugging Face Hub from this interpreter session. Later cells
# push a dataset and the fine-tuned model to the Hub.
from huggingface_hub import interpreter_login
interpreter_login()

# Collect Menu Image Datasets
- Use `metadata.jsonl` to label each image's ground truth. You can visit [here](https://github.com/ryanlinjui/menu-text-detection/tree/main/examples) to see the examples.
- After finishing, push to HuggingFace Datasets.
- For labeling:
    - [Google AI Studio](https://aistudio.google.com) or [OpenAI ChatGPT](https://chatgpt.com).
    - Use function calling via the API: start the Gradio app locally or visit [here](https://huggingface.co/spaces/ryanlinjui/menu-text-detection).

In [None]:
from datasets import load_dataset

# Where the labeled menu dataset lives locally, and the Hub repository it is
# published to.
LOCAL_DATASET_PATH = "datasets/menu-zh-TW"
HUB_DATASET_REPO_ID = "ryanlinjui/menu-zh-TW"

# Load the local dataset, then upload it to the Hugging Face Hub.
dataset = load_dataset(path=LOCAL_DATASET_PATH)
dataset.push_to_hub(repo_id=HUB_DATASET_REPO_ID)

# Setup for Fine-tuning

In [None]:
from datasets import load_dataset
from transformers import DonutProcessor, VisionEncoderDecoderModel, VisionEncoderDecoderConfig
from menu.donut import DonutDatasets

DATASETS_REPO_ID = "ryanlinjui/menu-zh-TW"
PRETRAINED_MODEL_REPO_ID = "naver-clova-ix/donut-base"
TASK_PROMPT_NAME = "<s_menu>"  # task token; also used as the decoder start token below
MAX_LENGTH = 768
# NOTE(review): presumably (height, width); the processor below is given the
# reversed order -- confirm against the Donut image processor documentation.
IMAGE_SIZE = [1280, 960]

raw_datasets = load_dataset(DATASETS_REPO_ID)

# Config: load the pretrained encoder-decoder configuration, then override the
# encoder input resolution and the decoder's maximum sequence length.
config = VisionEncoderDecoderConfig.from_pretrained(PRETRAINED_MODEL_REPO_ID)
config.encoder.image_size = IMAGE_SIZE
config.decoder.max_length = MAX_LENGTH

# Processor: image preprocessing and text post-processing.
# `processor.feature_extractor` is a deprecated alias of
# `processor.image_processor`, so the image processor is configured directly.
processor = DonutProcessor.from_pretrained(PRETRAINED_MODEL_REPO_ID)
processor.image_processor.size = IMAGE_SIZE[::-1]
processor.image_processor.do_align_long_axis = False

# Donut datasets: wrap the raw dataset with the processor, reading images from
# the "image" column and annotations from the "menu" column, using an
# 80/10/10 train/validation/test split.
# NOTE(review): `prompt_end_token` is set to the same token as
# `task_start_token` -- confirm this is what DonutDatasets expects.
datasets = DonutDatasets(
    datasets=raw_datasets,
    processor=processor,
    image_column="image",
    annotation_column="menu",
    task_start_token=TASK_PROMPT_NAME,
    prompt_end_token=TASK_PROMPT_NAME,
    train_split=0.8,
    validation_split=0.1,
    test_split=0.1
)

# Model: load the pretrained weights with the adjusted config.
model = VisionEncoderDecoderModel.from_pretrained(PRETRAINED_MODEL_REPO_ID, config=config)
# Resize the decoder embeddings to the tokenizer's current vocabulary size
# (presumably DonutDatasets added task/special tokens to it -- verify).
model.decoder.resize_token_embeddings(len(processor.tokenizer))
model.config.pad_token_id = processor.tokenizer.pad_token_id
# Decoding starts from the task prompt token.
model.config.decoder_start_token_id = processor.tokenizer.convert_tokens_to_ids([TASK_PROMPT_NAME])[0]

# Start Fine-tuning

In [None]:
import torch
from transformers import Seq2SeqTrainer, Seq2SeqTrainingArguments

# Hub repository that receives the checkpoints; a falsy value disables pushing.
HUGGINGFACE_MODEL_ID = "ryanlinjui/donut-base-finetuned-menu"
EPOCHS = 30
TRAIN_BATCH_SIZE = 1

# Report which device is available and move the model to the GPU when present.
if torch.cuda.is_available():
    print("Using GPU")
    model.to("cuda")
else:
    print("Using default device")

training_args = Seq2SeqTrainingArguments(
    num_train_epochs=EPOCHS,
    per_device_train_batch_size=TRAIN_BATCH_SIZE,
    learning_rate=3e-5,
    per_device_eval_batch_size=1,
    output_dir="./.checkpoints",
    seed=2022,
    warmup_steps=30,
    max_steps=-1,  # -1: run length is governed by num_train_epochs
    eval_strategy="steps",
    eval_steps=500,
    logging_strategy="steps",
    logging_steps=100,
    save_strategy="steps",
    save_steps=200,
    push_to_hub=bool(HUGGINGFACE_MODEL_ID),
    # The target repository is given only through `hub_model_id`. The
    # deprecated `push_to_hub_model_id` must not be passed as well: it
    # overrides `hub_model_id` and has the user namespace prepended, which
    # breaks an id that already contains "<user>/".
    hub_model_id=HUGGINGFACE_MODEL_ID,
    hub_strategy="every_save",
)
trainer = Seq2SeqTrainer(
    model=model,
    args=training_args,
    train_dataset=datasets["train"],
    # NOTE(review): periodic evaluation runs on the "test" split although a
    # validation split is also configured -- confirm whether the validation
    # split was intended here.
    eval_dataset=datasets["test"]
)
trainer.train()

# Plot the results

In [None]:
# Test the fine-tuned model: run a single menu image through a saved
# checkpoint and print the parsed result.
import re

import torch
from transformers import VisionEncoderDecoderModel
from PIL import Image

image = Image.open("/root/menu-text-detection/examples/menu-hd.jpg").convert("RGB")

# Checkpoints are written under the training `output_dir` ("./.checkpoints"),
# so load from that directory; adjust the step number to an existing checkpoint.
model = VisionEncoderDecoderModel.from_pretrained(".checkpoints/checkpoint-2000")
device = "cuda" if torch.cuda.is_available() else "cpu"

model.eval()
model.to(device)

# Preprocess the image with the same processor that was configured for training.
pixel_values = processor(image, return_tensors="pt").pixel_values
pixel_values = pixel_values.to(device)

# Seed the decoder with the task prompt token.
task_prompt = "<s_menu>"
decoder_input_ids = processor.tokenizer(task_prompt, add_special_tokens=False, return_tensors="pt").input_ids
decoder_input_ids = decoder_input_ids.to(device)

# Greedy decoding (num_beams=1). `early_stopping` only applies to beam search,
# so it is not passed here.
outputs = model.generate(
    pixel_values,
    decoder_input_ids=decoder_input_ids,
    max_length=model.decoder.config.max_position_embeddings,
    pad_token_id=processor.tokenizer.pad_token_id,
    eos_token_id=processor.tokenizer.eos_token_id,
    use_cache=True,
    num_beams=1,
    bad_words_ids=[[processor.tokenizer.unk_token_id]],  # never emit the unknown token
    return_dict_in_generate=True,
)

# Decode, strip EOS/pad tokens and the leading task token, then convert the
# remaining tag sequence to JSON.
seq = processor.batch_decode(outputs.sequences)[0]
seq = seq.replace(processor.tokenizer.eos_token, "").replace(processor.tokenizer.pad_token, "")
seq = re.sub(r"<.*?>", "", seq, count=1).strip()  # remove first task start token
seq = processor.token2json(seq)
print(seq)