# Inference with FT-DINOSAUR

## Preliminaries

In [None]:
# Install dependencies when running this on Google Colab
try:
    # This import only succeeds inside Google Colab; anywhere else it raises
    # ImportError and the install below is skipped.
    import google.colab  # noqa: F401

    # IPython shell escape: pip-installs the package (with its "notebook" extra)
    # directly from the GitHub repository.
    !pip install "ftdinosaur-inference[notebook] @ git+https://github.com/ft-dinosaur/ftdinosaur-inference.git"
except ImportError:
    # Not running on Colab: nothing is installed by this cell.
    pass

In [None]:
import matplotlib
import numpy as np
import PIL
import torch
from IPython.display import display
from PIL import Image
from torchvision.utils import draw_segmentation_masks

from ftdinosaur_inference import build_dinosaur, utils

In [None]:
def get_cmap(num_classes, cmap="tab10"):
    """Return ``num_classes`` RGB colors sampled from a matplotlib colormap.

    Args:
        num_classes: Number of colors to return. Non-positive values yield an
            empty list.
        cmap: Name of a colormap registered in ``matplotlib.colormaps``.

    Returns:
        A list of ``num_classes`` ``(r, g, b)`` tuples of Python ints in
        [0, 255], as passed to ``draw_segmentation_masks(colors=...)``.
    """
    if num_classes <= 0:
        return []

    colormap = matplotlib.colormaps[cmap].resampled(num_classes)
    # Calling a colormap with integer indices looks colors up in its table and
    # returns an (N, 4) array of RGBA floats in [0, 1].
    rgba = colormap(range(num_classes))
    # Round instead of truncating: products like 255 * (31 / 255) can land a
    # hair below the integer and would otherwise be off by one. Cast to plain
    # ints so the tuples do not carry numpy scalar types.
    return [tuple(int(round(255 * float(channel))) for channel in color[:3]) for color in rgba]


def overlay_masks_on_image(
    img: Image.Image, masks: torch.Tensor, num_masks: int, alpha: float = 0.6
) -> Image.Image:
    """Draw hard segmentation masks on top of an image.

    Args:
        img: Input PIL image. It is converted to RGB first, so the channel
            permutation below always sees an H x W x 3 array.
        masks: Slot masks at patch resolution, expected as 1 x K x P (batch of
            one, K masks, P patches). They are resized to the image size with
            ``utils.resize_patches_to_image`` and then passed through
            ``utils.soft_masks_to_one_hot``. Presumably this yields boolean
            one-hot masks, as ``draw_segmentation_masks`` requires; confirm
            against ``utils``.
        num_masks: Number of colors to generate; should be at least K.
        alpha: Opacity of the mask colors (0 = fully transparent, 1 = opaque).

    Returns:
        A new RGB PIL image with the masks drawn over ``img``.
    """
    rgb = np.array(img.convert("RGB"))  # H x W x 3, uint8
    img_tensor = torch.from_numpy(rgb).permute(2, 0, 1)  # C x H x W
    height, width = img_tensor.shape[1:]

    # Need to resize masks to image (1 x K x P -> 1 x K x H x W)
    masks_as_image = utils.resize_patches_to_image(masks, size=(height, width))
    masks_as_image = utils.soft_masks_to_one_hot(masks_as_image).squeeze(0)

    # Overlay masks on image
    masks_on_image = draw_segmentation_masks(
        img_tensor, masks_as_image, alpha=alpha, colors=get_cmap(num_masks)
    )

    # Convert back to PIL. draw_segmentation_masks keeps the input dtype.
    # Here the input is uint8, so the result is already in [0, 255]. Scaling it
    # by 255 would overflow the uint8 array and wrap the pixel values.
    # A float result in [0, 1] is still rescaled, for robustness.
    array = masks_on_image.permute(1, 2, 0).cpu().numpy()
    if np.issubdtype(array.dtype, np.floating):
        array = np.clip(np.rint(array * 255), 0, 255)
    return Image.fromarray(array.astype(np.uint8))


def load_model(checkpoint_path: str = ""):
    """Build the model and load weights from a checkpoint file.

    The architecture comes from
    ``build_dinosaur.build_dinosaur_small_patch14_224_topk3(pretrained=False)``.
    Its weights are replaced by the checkpoint after conversion with
    ``utils.convert_checkpoint_from_oclf``, which presumably expects an
    oclf-format training checkpoint.

    Args:
        checkpoint_path: Path to the checkpoint file. There is no default
            checkpoint, so this must be provided.

    Returns:
        The model with the checkpoint loaded, switched to eval mode.

    Raises:
        FileNotFoundError: If ``checkpoint_path`` is empty.
    """
    if not checkpoint_path:
        raise FileNotFoundError(
            "No checkpoint given: pass `checkpoint_path` to load_model() "
            "pointing to an FT-DINOSAUR checkpoint file."
        )

    model = build_dinosaur.build_dinosaur_small_patch14_224_topk3(pretrained=False)
    # map_location="cpu" lets checkpoints saved from GPU tensors load on
    # CPU-only machines (e.g. a free Colab runtime).
    # NOTE(review): newer torch versions default to weights_only=True; if the
    # checkpoint stores non-tensor objects this load may need adjusting.
    checkpoint = torch.load(checkpoint_path, map_location="cpu")
    model.load_state_dict(utils.convert_checkpoint_from_oclf(checkpoint))
    model.eval()
    return model

## Define and run model

Load the model and define preprocessing. Make sure that `load_model` is pointed at a valid checkpoint file (`checkpoint_path`), and that `input_size` matches the resolution the model was trained on.

In [None]:
# Build the model and load its checkpoint; load_model needs a valid
# checkpoint path, otherwise loading fails.
model = load_model()
# Preprocessing pipeline; input_size should match the model's training resolution.
preproc = utils.build_preprocessing(input_size=224)

Load an example image.

In [None]:
# convert("RGB") guarantees a 3-channel image; overlay_masks_on_image permutes
# the pixel array as H x W x C, so it needs the channel axis.
image = Image.open("./images/gizmos.jpg").convert("RGB")
display(image)

Run the model. We can flexibly choose the number of slots using the `num_slots` argument.

In [None]:
# Inference only: disable gradient tracking.
with torch.no_grad():
    # unsqueeze(0) adds a leading batch dimension (a batch of one image).
    inp = preproc(image).unsqueeze(0)
    # num_slots chooses the number of slots used for this forward pass.
    outp = model(inp, num_slots=6)

Display the slot masks!

In [None]:
# outp["masks"] is assumed to be 1 x K x P (see overlay_masks_on_image);
# dim 1 (K) is used as the number of mask colors. TODO confirm shape.
masks_with_image = overlay_masks_on_image(
    image, outp["masks"], num_masks=outp["masks"].shape[1]
)
display(masks_with_image)

Note that we used an image with a non-square aspect ratio here.
This works because the preprocessing pipeline resizes the image to a square.
The resulting masks are therefore square as well; `overlay_masks_on_image` resizes them back to the original image size.
Since the model was trained on square images, it generally works better with square inputs.