# Inference with FT-DINOSAUR

## Preliminaries

In [None]:
# Install dependencies when running this on Google Colab
try:
    import google.colab  # noqa: F401

    !pip install "ftdinosaur_inference[notebook] @ git+https://github.com/rw-ocrl/ftdinosaur-inference.git"
except ImportError:
    pass

In [None]:
from typing import Optional

import matplotlib
import numpy as np
import PIL
import torch
from IPython.display import display
from PIL import Image
from torchvision.utils import draw_segmentation_masks

from ftdinosaur_inference import build_dinosaur, utils

In [None]:
def get_cmap(num_classes, cmap="tab10"):
    """Return ``num_classes`` RGB colors sampled from a matplotlib colormap.

    Args:
        num_classes: Number of colors to produce (one per mask/slot).
        cmap: Name of a colormap registered in ``matplotlib.colormaps``.

    Returns:
        A list of ``num_classes`` ``(r, g, b)`` tuples of Python ints in [0, 255],
        usable as the ``colors`` argument of ``draw_segmentation_masks``.
        An empty list if ``num_classes < 1``.

    Note:
        Qualitative maps such as ``tab10`` only contain 10 distinct colors, so
        resampling them to more than 10 entries makes some colors repeat.
    """
    if num_classes < 1:
        return []
    colormap = matplotlib.colormaps[cmap].resampled(num_classes)
    rgba = colormap(range(num_classes))  # float RGBA rows in [0, 1]
    # Round instead of truncating: floating-point error can put 255 * c just
    # below the intended integer, which a plain astype(int) would floor.
    return [tuple(int(v) for v in np.round(255 * color[:3])) for color in rgba]


def overlay_masks_on_image(
    img: Image.Image,
    masks: torch.Tensor,
    num_masks: Optional[int] = None,
    alpha: float = 0.6,
) -> Image.Image:
    """Draw hard slot masks on top of an image.

    Args:
        img: Image to draw on. It is converted to RGB first, because
            ``draw_segmentation_masks`` requires a 3-channel image.
        masks: Soft masks over patches, laid out as 1 x K x P
            (one batch element, K masks, P patches).
        num_masks: Number of colors to generate. Defaults to ``masks.shape[1]`` (K).
        alpha: Opacity of the overlay (0 = fully transparent, 1 = opaque).

    Returns:
        A new RGB PIL image with the masks drawn over ``img``.
    """
    rgb = np.array(img.convert("RGB"))  # H x W x 3, uint8
    img_tensor = torch.from_numpy(rgb).permute(2, 0, 1)  # C x H x W
    height, width = img_tensor.shape[1:]

    if num_masks is None:
        num_masks = masks.shape[1]

    # Need to resize masks to image (1 x K x P -> 1 x K x H x W)
    masks_as_image = utils.resize_patches_to_image(masks, size=(height, width))
    # Harden the soft masks (draw_segmentation_masks expects boolean masks),
    # then drop the batch dimension -> K x H x W.
    masks_as_image = utils.soft_masks_to_one_hot(masks_as_image).squeeze(0)

    # Overlay masks on image
    masks_on_image = draw_segmentation_masks(
        img_tensor, masks_as_image, alpha=alpha, colors=get_cmap(num_masks)
    )

    # Convert back to PIL (C x H x W -> H x W x C)
    masks_on_image = masks_on_image.permute(1, 2, 0).numpy()
    return Image.fromarray(masks_on_image.astype(np.uint8))


def load_model(model_name):
    """Build a pretrained FT-DINOSAUR model and put it into evaluation mode.

    Args:
        model_name: Checkpoint identifier passed to ``build_dinosaur.build``.

    Returns:
        The model returned by ``build_dinosaur.build``, after ``.eval()`` was called.
    """
    net = build_dinosaur.build(model_name, pretrained=True)
    net.eval()  # switch to inference behaviour before it is used
    return net

## Define and run model

List the available model checkpoints.

In [None]:
# Checkpoint names; presumably any of these can be used as `model_name` below.
available_checkpoints = build_dinosaur.list_checkpoints()
available_checkpoints

Load the pretrained model and create the matching preprocessing pipeline.

In [None]:
# The same checkpoint name selects both the weights and the matching preprocessing.
model_name = "dinosaur_base_patch14_518_topk3.coco_dv2_ft_s7_300k+10k"
preproc = build_dinosaur.build_preprocessing(model_name)
model = load_model(model_name)

Load an example image.

In [None]:
# Use a context manager so the file handle is closed; convert() loads the pixel
# data and returns a new image that does not depend on the open file.
with Image.open("./images/example.jpg") as example_file:
    image = example_file.convert("RGB")
display(image)

Run the model. The number of slots can be chosen freely via the `num_slots` argument. Note that the model was trained with 7 slots, so values close to 7 tend to work best.

In [None]:
# Inference only: no autograd graph is needed.
with torch.no_grad():
    batch = preproc(image)[None]  # prepend a batch dimension of size 1
    outp = model(batch, num_slots=5)

Display the slot masks!

In [None]:
slot_masks = outp["masks"]
# One color per slot: the slot count is the size of dimension 1 of the masks.
masks_with_image = overlay_masks_on_image(image, slot_masks, num_masks=slot_masks.shape[1])
display(masks_with_image)

We used an image with a square aspect ratio here. This is the aspect ratio the model was trained with, and it generally works best.
Non-square images are also supported: the preprocessing pipeline resizes them to a square.
The resulting masks are therefore square as well; `overlay_masks_on_image` resizes them back to the size of the original image.

## Upload a custom image (Google Colab only)

In [None]:
# The cells below rely on Colab's upload widget, so fail fast anywhere else.
try:
    from google.colab import files
except ImportError as err:
    raise ImportError("Need to run the following in Google Colab") from err

In [None]:
# @title Upload image
uploaded = files.upload()
# Fail with a clear message instead of an IndexError when nothing was uploaded.
if not uploaded:
    raise ValueError("No file was uploaded; re-run this cell and select an image.")
# If several files were selected, only the first one is used.
file_path = next(iter(uploaded))
with Image.open(file_path) as uploaded_file:
    image = uploaded_file.convert("RGB")
display(image)

In [None]:
# @title Select number of slots
num_slots = 7  # @param {type:"slider", min:1, max:24, step:1}
# The model was trained with 7 slots, so values close to 7 tend to work best.

In [None]:
# @title Run model and display results
with torch.no_grad():  # inference only, no autograd graph needed
    batch = preproc(image)[None]  # prepend a batch dimension of size 1
    outp = model(batch, num_slots=num_slots)

slot_masks = outp["masks"]
masks_with_image = overlay_masks_on_image(image, slot_masks, num_masks=slot_masks.shape[1])
display(masks_with_image)