## 1. Import necessary packages

In [None]:
import cv2
import torch

from matplotlib import pyplot as plt
from util.utils import load_state_dict
from util.visualize import plot_bounding_boxes_on_image_cv2

## 2. Load a model and class information

The checkpoint also stores class information, so we load it with our own `load_state_dict` helper, which handles that extra entry in addition to the weights.

In [None]:
# The config module exposes the `model` object that the checkpoint is loaded into.
from configs.canet.canet_resnet50_800_1333 import model

# map_location="cpu" lets a checkpoint that was saved from GPU tensors load on a
# CPU-only machine; move the model to another device afterwards if desired.
# NOTE: torch.load unpickles the file -- only load checkpoints from trusted sources.
weight = torch.load("checkpoint.pth", map_location="cpu")
# The project's load_state_dict also handles the class information stored in
# the checkpoint (a plain nn.Module.load_state_dict would not).
load_state_dict(model, weight)
# Switch to evaluation mode (changes the behaviour of layers such as dropout /
# batch norm); .eval() returns the module itself.
model = model.eval()

In [None]:
image_path = "data/coco/val2017/000000000724.jpg"
# cv2.imread returns an HxWx3 uint8 array in BGR channel order.
image = cv2.imread(image_path)
if image is None:
    # cv2.imread does not raise on a missing/unreadable file -- it returns None,
    # which would otherwise surface later as an opaque AttributeError.
    raise FileNotFoundError(f"Could not read image: {image_path}")
# HWC -> CHW channel-first layout for the model. No RGB conversion or
# normalization is done here; presumably the model handles preprocessing
# internally -- TODO confirm against the model config.
torch_image = torch.tensor(image.transpose(2, 0, 1))

## 3. Inference on the image

In [None]:
# Run inference without autograd: no gradients are needed, and otherwise the
# outputs derived from model parameters would carry requires_grad=True (keeping
# the graph alive and breaking any later .numpy() call on them).
with torch.no_grad():
    # The model takes a list of CHW image tensors; [0] is the first image's result.
    predictions = model([torch_image])[0]
# Passing a batched tensor also works:
# predictions = model(torch_image.unsqueeze(0))[0]

## 4. Visualize results

In [None]:
# OpenCV decoded the image as BGR; convert to RGB so colours render correctly
# when the result is shown with matplotlib.
rgb_image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
pred_boxes, pred_labels, pred_scores = (
    predictions["boxes"],
    predictions["labels"],
    predictions["scores"],
)
image_for_show = plot_bounding_boxes_on_image_cv2(
    image=rgb_image,
    boxes=pred_boxes,
    labels=pred_labels,
    scores=pred_scores,
    classes=model.CLASSES,  # class information loaded from the checkpoint
    show_conf=0.5,  # presumably a score threshold for drawing -- confirm in util.visualize
    font_scale=0.5,
    box_thick=2,
    text_alpha=0.75,
)

In [None]:
# The annotated image is already RGB, which is what matplotlib expects.
plt.imshow(image_for_show)
# Needed for the figure to appear when run as a plain script; harmless in Jupyter.
plt.show()