# Florence-2 for Table Extraction

Fine-tuned Florence-2 model: https://huggingface.co/ucsahin/Florence-2-large-TableDetection

This model is a fine-tuned version of __Florence-2-large-ft__ on the __ucsahin/pubtables-detection-1500-samples__ dataset.

## System setup

In [None]:
# Install additional runtime dependencies. PIL, matplotlib and transformers are
# imported by the next cell but are not installed here.
!pip install accelerate
!pip install flash_attn einops timm

In [None]:
from PIL import Image
import matplotlib.pyplot as plt
import matplotlib.patches as patches
from transformers import AutoProcessor, AutoModelForCausalLM

## Load fine-tuned Florence-2 model

In [None]:
# Hugging Face Hub repository holding the fine-tuned table-detection checkpoint.
model_id = "ucsahin/Florence-2-large-TableDetection"

# trust_remote_code=True allows transformers to execute the modeling code that
# ships inside the model repository; device_map="cuda" places the weights on GPU.
model = AutoModelForCausalLM.from_pretrained(
    model_id,
    trust_remote_code=True,
    device_map="cuda",
)

# Matching processor (text and image preprocessing) from the same repository.
processor = AutoProcessor.from_pretrained(
    model_id,
    trust_remote_code=True,
)

## Function to detect tables in images

In [None]:
def run_example(task_prompt, image, max_new_tokens=256):
    """Run one Florence-2 task on an image and return the parsed result.

    Uses the module-level ``processor`` and ``model``.

    Args:
        task_prompt: Task token passed to the model as the text prompt
            (this notebook uses ``"<OD>"``).
        image: PIL image to analyse.
        max_new_tokens: Upper bound on the number of tokens to generate.

    Returns:
        Whatever ``processor.post_process_generation`` returns for the task.
        This notebook indexes it by the task prompt and reads the
        ``'bboxes'`` and ``'labels'`` entries.
    """
    # Normalise to RGB: scanned documents are often grayscale, palette or RGBA
    # PNGs, which the image processor may reject.
    rgb_image = image if image.mode == "RGB" else image.convert("RGB")
    inputs = processor(text=task_prompt, images=rgb_image, return_tensors="pt")

    # Follow the model's actual device and dtype instead of hard-coding
    # .cuda(), so the function keeps working if the model is loaded on another
    # device or in reduced precision.
    generated_ids = model.generate(
        input_ids=inputs["input_ids"].to(model.device),
        pixel_values=inputs["pixel_values"].to(device=model.device, dtype=model.dtype),
        max_new_tokens=max_new_tokens,
        early_stopping=False,
        do_sample=False,
        num_beams=3,
    )

    # Special tokens are kept in the decoded text; presumably the
    # post-processor needs them to parse the output (confirm against the
    # processor implementation).
    generated_text = processor.batch_decode(generated_ids, skip_special_tokens=False)[0]

    # Conversion to RGB does not change the size, so the original image's
    # dimensions are used when the result is post-processed.
    return processor.post_process_generation(
        generated_text,
        task=task_prompt,
        image_size=(image.width, image.height),
    )

## Function to plot bounding boxes around tables

In [None]:
def plot_bbox(image, data):
    """Display an image with labelled bounding boxes drawn on top of it.

    Args:
        image: Image to show; passed straight to ``Axes.imshow``.
        data: Mapping with a ``'bboxes'`` entry (each box unpacks to
            ``x1, y1, x2, y2``) and a parallel ``'labels'`` entry.
    """
    # Only the axes are needed; everything is drawn through them.
    _, ax = plt.subplots(figsize=(12, 10))
    ax.imshow(image)

    for bbox, label in zip(data['bboxes'], data['labels']):
        x1, y1, x2, y2 = bbox
        # Rectangle takes the anchor corner plus width and height.
        rect = patches.Rectangle(
            (x1, y1), x2 - x1, y2 - y1,
            linewidth=1, edgecolor='r', facecolor='none',
        )
        ax.add_patch(rect)
        # Draw the label on this function's own axes rather than on whatever
        # pyplot considers the current axes.
        ax.text(x1, y1, label, color='white', fontsize=8,
                bbox=dict(facecolor='red', alpha=0.5))

    # Hide ticks and axis labels so that only the annotated image is shown.
    ax.axis('off')
    plt.show()

## Detect tables in images

In [None]:
# Load the first sample document; the path is relative to the working directory.
image = Image.open("scanned_doc_1.png")

In [None]:
# Bare expression: shown as the cell's output when run in a notebook.
image

In [None]:
# Run the "<OD>" task on the first document; the result is indexed by the
# same task prompt when it is plotted.
parsed_answer = run_example("<OD>", image=image)

In [None]:
# Draw the detected boxes and their labels on top of the first document.
plot_bbox(image, parsed_answer["<OD>"])

In [None]:
# Load the second sample document; the path is relative to the working directory.
image2 = Image.open("scanned_doc_2.png")

# Bare expression: shown as the cell's output when run in a notebook.
image2

In [None]:
# Run the "<OD>" task on the second document and plot the detected boxes.
parsed_answer2 = run_example("<OD>", image=image2)
plot_bbox(image2, parsed_answer2["<OD>"])


## Crop detected tables

In [None]:
# Inspect the raw "<OD>" result for the second document.
parsed_answer2["<OD>"]

In [None]:
# First detected box; the crop step unpacks it as left, top, right, bottom.
parsed_answer2["<OD>"]['bboxes'][0]

In [None]:
# Fail with a clear message, instead of a bare IndexError, when nothing was detected.
bboxes = parsed_answer2["<OD>"]['bboxes']
if len(bboxes) == 0:
    raise ValueError("No tables were detected in image2; nothing to crop.")

# Ensure the coordinates of the first detected box are integers.
left, top, right, bottom = map(int, bboxes[0])

# Crop the image with 10 px of extra horizontal margin. The margin is clamped
# to the image width so the crop box never extends past the left or right edge.
padding = 10
cropped_img = image2.crop((
    max(left - padding, 0),
    top,
    min(right + padding, image2.width),
    bottom,
))

In [None]:
# Bare expression: shows the cropped table as the cell's output.
cropped_img