In [None]:
import os
from io import BytesIO
from unittest.mock import patch

import matplotlib.patches as patches
import matplotlib.pyplot as plt
import requests
from PIL import Image
from transformers import AutoModelForCausalLM, AutoProcessor
from transformers.dynamic_module_utils import get_imports

In [None]:
# Mac solution => https://huggingface.co/microsoft/Florence-2-large-ft/discussions/4
def fixed_get_imports(filename: str | os.PathLike) -> list[str]:
    """Work around for https://huggingface.co/microsoft/phi-1_5/discussions/72."""
    if not str(filename).endswith("/modeling_florence2.py"):
        return get_imports(filename)
    imports = get_imports(filename)
    imports.remove("flash_attn")
    return imports


# Checkpoint used for both the model and its processor; named once so the two
# loads cannot drift apart.
MODEL_ID = "medieval-data/florence2-medieval-bbox-zone-detection"

# Load while `get_imports` is patched so the `flash_attn` import declared by
# modeling_florence2.py is not treated as a required dependency.
# NOTE(review): trust_remote_code=True runs Python code shipped with the
# checkpoint repository -- only use it with repositories you trust.
with patch("transformers.dynamic_module_utils.get_imports", fixed_get_imports):
    model = AutoModelForCausalLM.from_pretrained(MODEL_ID, trust_remote_code=True)
    processor = AutoProcessor.from_pretrained(MODEL_ID, trust_remote_code=True)


In [None]:
def process_image(url, timeout=30.0):
    """Download an image and run the ``<OD>`` detection task on it.

    Uses the module-level ``processor`` and ``model`` objects.

    Args:
        url: HTTP(S) address of the image to analyse.
        timeout: Seconds to wait for the server before the download is
            abandoned. Defaults to 30.

    Returns:
        A ``(result, image)`` tuple: ``result`` is whatever
        ``processor.post_process_generation`` returns for the ``<OD>`` task,
        ``image`` is the downloaded picture converted to RGB.

    Raises:
        requests.HTTPError: If the server answers with an error status.
        requests.RequestException: For other download failures, including
            timeouts.
    """
    prompt = "<OD>"

    # Read the whole body up front: the connection is released immediately and
    # a failed download surfaces as an HTTP error instead of a PIL decode error.
    response = requests.get(url, timeout=timeout)
    response.raise_for_status()
    # Normalise to RGB so greyscale, palette or RGBA downloads are handled the
    # same way as ordinary JPEGs.
    image = Image.open(BytesIO(response.content)).convert("RGB")

    inputs = processor(text=prompt, images=image, return_tensors="pt")

    generated_ids = model.generate(
        input_ids=inputs["input_ids"],
        pixel_values=inputs["pixel_values"],
        max_new_tokens=1024,
        do_sample=False,
        num_beams=3,
    )
    # Special tokens are kept because the post-processing step parses them.
    generated_text = processor.batch_decode(generated_ids, skip_special_tokens=False)[0]

    result = processor.post_process_generation(
        generated_text,
        task=prompt,
        image_size=(image.width, image.height),
    )
    return result, image

In [None]:
# Candidate input images, kept as individually named URLs so any one of them
# can be passed to process_image().

# Pages from the CATMuS/medieval-segmentation dataset on the Hugging Face Hub
# (URLs under its dev, test and train folders).
image1 = "https://huggingface.co/datasets/CATMuS/medieval-segmentation/resolve/main/data/dev/london-british-library-egerton-821/page-002-of-004.jpg"
image2 = "https://huggingface.co/datasets/CATMuS/medieval-segmentation/resolve/main/data/dev/paris-bnf-lat-12449/page-002-of-003.jpg"
image3 = "https://huggingface.co/datasets/CATMuS/medieval-segmentation/resolve/main/data/dev/paris-bnf-nal-1909/page-009-of-012.jpg"
image4 = "https://huggingface.co/datasets/CATMuS/medieval-segmentation/resolve/main/data/test/paris-bnf-fr-574/page-001-of-003.jpg"
image5 = "https://huggingface.co/datasets/CATMuS/medieval-segmentation/resolve/main/data/train/oxford-bodleian-library-ms-span-d-2-1/page-001-of-001.jpg"
image6 = "https://huggingface.co/datasets/CATMuS/medieval-segmentation/resolve/main/data/train/munich-bayerische-staatsbibliothek-clm-23343/page-001-of-001.jpeg"
image7 = "https://huggingface.co/datasets/CATMuS/medieval-segmentation/resolve/main/data/train/leipzig-universitats-bibliothek-leipzig-ms-758/page-008-of-015.jpg"
image8 = "https://huggingface.co/datasets/CATMuS/medieval-segmentation/resolve/main/data/train/cambridge-corpus-christi-college-ms-111/page-002-of-003.jpg"

# Images hosted outside that dataset (Alamy and Reddit URLs).
image9 = "https://c8.alamy.com/comp/PPJNB8/latin-manuscript-signatura-vitr-14-5-jurisdiction-sheet-2-vo-quadratic-table-medieval-document-exhibition-the-scientific-legacy-of-al-andalus-location-national-library-PPJNB8.jpg"
image10 = "https://preview.redd.it/tengwar-table-in-the-medieval-byzantine-ot-armenian-style-v0-5htv9xg28ypa1.jpg?width=640&crop=smart&auto=webp&s=a26faaa57fed97b9c04157043f628a4214928fa1"

# Downloads from e-codices.unifr.ch ("medium" size endpoint).
image11 = "https://www.e-codices.unifr.ch/en/download/ubb-A-II-0012_0006v/medium"
image12 = "https://www.e-codices.unifr.ch/en/download/ubb-A-IX-0014_0002r/medium"
image13 = "https://www.e-codices.unifr.ch/en/download/ubb-A-IX-0014_0002v/medium"

# Run detection on one of the URLs above; swap the variable to try another page.
result, image = process_image(image13)
detections = result["<OD>"]

fig, ax = plt.subplots(1, figsize=(15, 15))
ax.imshow(image)
ax.set_title("Detected zones")

# Add bounding boxes and labels to the plot
for (x_min, y_min, x_max, y_max), label in zip(detections["bboxes"], detections["labels"]):
    # Boxes are given as two corners; Rectangle takes an origin plus a size.
    rect = patches.Rectangle(
        (x_min, y_min),
        x_max - x_min,
        y_max - y_min,
        linewidth=2,
        edgecolor="r",
        facecolor="none",
    )
    ax.add_patch(rect)
    # Draw through the Axes object rather than pyplot's "current axes" so the
    # label always lands on this figure.
    ax.text(x_min, y_min, label, fontsize=12, bbox=dict(facecolor="yellow", alpha=0.5))

# Display the plot
plt.show()