# How to Finetune Florence-2 Model for Object Detection
Florence-2 has a great capability of detecting various objects in a zero-shot setting with the task prompt "<OD>". 
However, if you want to detect specific objects that the base model is not able to in its current form, this notebook shows how you can finetune the model to perform this task with your custom data. 
Here I show how to finetune the model to detect tables in a given image, but a similar process can be applied to detect any objects. 
For finetuning Florence-2 in VQA from documents, please check HuggingFace Florence-2 Finetuning blog post and for inference Florence-2-Large Sample Inference. 
This notebook is heavily inspired by them.

In [1]:
%pip install -q -U git+https://github.com/huggingface/transformers.git accelerate datasets

Note: you may need to restart the kernel to use updated packages.


In [None]:
%pip install flash_attn timm einops

In [4]:
from datasets import load_dataset
dataset = load_dataset("ucsahin/pubtables-detection-1500-samples")

Downloading readme:   0%|          | 0.00/1.45k [00:00<?, ?B/s]

Downloading data:   0%|          | 0.00/225M [00:00<?, ?B/s]

Generating train split:   0%|          | 0/1500 [00:00<?, ? examples/s]

In [5]:
dataset["train"][0]

{'image': <PIL.JpegImagePlugin.JpegImageFile image mode=RGB size=771x1000>,
 'objects': {'bbox': [[75.53573789752897,
    805.4076686049953,
    580.396252241291,
    882.670855281329]],
  'categories': 'table'}}

In [6]:
import torch

def preprocess_fnc(examples):
    bbox_formatted_list = []
    for objects, image in zip(examples["objects"], examples["image"]):
        width, height = image.size
        bins_w, bins_h = [1000, 1000]  # Quantization bins.
        size_per_bin_w = width / bins_w
        size_per_bin_h = height / bins_h

        bboxes = objects["bbox"]
        bbox_str = ""
        for bbox in bboxes:
            # if you are to detect multiple objects, bbox_str should contain category labels before the bounding boxes
            if len(bbox_str) == 0:
                bbox_str += "table"
            bbox = bbox.copy()

            xmin, ymin, xmax, ymax = torch.tensor(bbox).split(1, dim=-1)
            quantized_xmin = (
                xmin / size_per_bin_w).floor().clamp(0, bins_w - 1)
            quantized_ymin = (
                ymin / size_per_bin_h).floor().clamp(0, bins_h - 1)
            quantized_xmax = (
                xmax / size_per_bin_w).floor().clamp(0, bins_w - 1)
            quantized_ymax = (
                ymax / size_per_bin_h).floor().clamp(0, bins_h - 1)

            quantized_boxes = torch.cat(
                (quantized_xmin, quantized_ymin, quantized_xmax, quantized_ymax), dim=-1
            ).int()

            bbox_str += f"<loc_{quantized_boxes[0]}><loc_{quantized_boxes[1]}><loc_{quantized_boxes[2]}><loc_{quantized_boxes[3]}>"

        bbox_formatted_list.append(bbox_str)

    examples["bbox_str"] = bbox_formatted_list
    return examples

In [7]:
processed_dataset = dataset.map(preprocess_fnc, batched=True)


Map:   0%|          | 0/1500 [00:00<?, ? examples/s]

In [8]:
split_dataset = processed_dataset["train"].train_test_split(test_size=0.1, shuffle=True)
train_dataset = split_dataset["train"]
eval_dataset = split_dataset["test"]

print("Len train dataset: ", len(train_dataset))
print("Len eval dataset: ", len(eval_dataset))

Len train dataset:  1350
Len eval dataset:  150


In [10]:
from transformers import AutoProcessor, AutoModelForCausalLM

model = AutoModelForCausalLM.from_pretrained("microsoft/Florence-2-large-ft", revision="refs/pr/10", trust_remote_code=True) # load the model on GPU
processor = AutoProcessor.from_pretrained("microsoft/Florence-2-large-ft", revision="refs/pr/10", trust_remote_code=True)

In [14]:
def run_example(task_prompt, image, max_new_tokens=128):
    prompt = task_prompt
    inputs = processor(text=prompt, images=image, return_tensors="pt")
    generated_ids = model.generate(
      input_ids=inputs["input_ids"],
      pixel_values=inputs["pixel_values"],
      max_new_tokens=max_new_tokens,
      early_stopping=False,
      do_sample=False,
      num_beams=3,
    )
    generated_text = processor.batch_decode(generated_ids, skip_special_tokens=False)[0]
    parsed_answer = processor.post_process_generation(
        generated_text,
        task=task_prompt,
        image_size=(image.width, image.height)
    )
    return parsed_answer

In [15]:
import matplotlib.pyplot as plt
import matplotlib.patches as patches

def plot_bbox(image, data):
   # Create a figure and axes
    fig, ax = plt.subplots()

    # Display the image
    ax.imshow(image)

    # Plot each bounding box
    for bbox, label in zip(data['bboxes'], data['labels']):
        # Unpack the bounding box coordinates
        x1, y1, x2, y2 = bbox
        # Create a Rectangle patch
        rect = patches.Rectangle((x1, y1), x2-x1, y2-y1, linewidth=1, edgecolor='r', facecolor='none')
        # Add the rectangle to the Axes
        ax.add_patch(rect)
        # Annotate the label
        plt.text(x1, y1, label, color='white', fontsize=8, bbox=dict(facecolor='red', alpha=0.5))

    # Remove the axis ticks and labels
    ax.axis('off')

    # Show the plot
    plt.show()

In [None]:
example_id = 250
image = processed_dataset["train"][example_id]["image"]

# notice here that <OD> task prompt is used. This task prompt is already used in training the Florence-2 model checkpoints for object detection.
parsed_answer = run_example(task_prompt="<OD>", image=image)

plot_bbox(image, parsed_answer["<OD>"])

In [17]:
for param in model.vision_tower.parameters():
  param.requires_grad = False

In [18]:
model_total_params = sum(p.numel() for p in model.parameters())
model_train_params = sum(p.numel() for p in model.parameters() if p.requires_grad)

print(f"Number of trainable parameters {model_train_params} out of {model_total_params}, rate: {model_train_params/model_total_params:0.3f}")

Number of trainable parameters 462061568 out of 822693888, rate: 0.562


In [19]:
IGNORE_ID = -100 # Pytorch ignore index when computing loss
MAX_LENGTH = 512

def collate_fn(examples):
    task_prompt = "<OD>"

    prompt_texts = [task_prompt for _ in examples]
    label_texts = [example["bbox_str"] for example in examples]
    images = [example["image"] for example in examples]

    inputs = processor(
        images=images,
        text=prompt_texts,
        return_tensors="pt",
        padding="longest",
        max_length=MAX_LENGTH,
    )

    labels = processor.tokenizer(
        label_texts,
        return_tensors="pt",
        padding="longest",
        max_length=MAX_LENGTH,
        return_token_type_ids=False, # no need to set this to True since BART does not use token type ids
    )["input_ids"]

    labels[labels == processor.tokenizer.pad_token_id] = IGNORE_ID # do not learn to predict pad tokens during training

    return_data = {**inputs, "labels": labels}
    return return_data

In [20]:
collated_examples = collate_fn([train_dataset[0], train_dataset[6]])

In [None]:
collated_examples

In [22]:
from transformers import TrainingArguments

args=TrainingArguments(
    output_dir="./ft_models/Florence-2-large-TableDetection",
    num_train_epochs=10,
    learning_rate=1e-6,
    per_device_train_batch_size=8,
    per_device_eval_batch_size=8,
    save_strategy="epoch",
    logging_strategy="epoch",
    eval_strategy="epoch",
    save_total_limit=5,
    load_best_model_at_end=False, # we will manually push model to the hub at the end of training
    label_names=["labels"],
    report_to="tensorboard",
    remove_unused_columns=False,  # needed for data collator
    push_to_hub=True,
    hub_model_id="ucsahin/Florence-2-large-TableDetection",
)

In [23]:
from transformers import Trainer

trainer = Trainer(
    model=model,
    tokenizer=processor,
    train_dataset=train_dataset,
    eval_dataset=eval_dataset,
    data_collator=collate_fn, # dont forget to add custom data collator
    args=args
)

RuntimeError: TensorBoardCallback requires tensorboard to be installed. Either update your PyTorch version or install tensorboardX.