# Creating a UDF for the YOLOv7 Model

This adapts the object detection script in https://github.com/RizwanMunawar/yolov7-object-tracking for Pixeltable.

This notebook must be run in the Python environment created for that repo. The UDF is stored in the database `functions` and can subsequently be used in queries and computed columns without access to the modules imported here.

In [None]:
import numpy as np
import torch
import torchvision

import models, utils, PIL, thop
from models.experimental import attempt_load
from utils.torch_utils import select_device
from utils.general import non_max_suppression, scale_coords
from utils.datasets import letterbox
%load_ext autoreload
%autoreload 2

In [None]:
# Pick the compute device first; the empty string leaves the choice to
# select_device() (presumably the default GPU if available, else CPU — see
# utils.torch_utils to confirm).
device = select_device('')

# Pretrained YOLOv7 checkpoint, loaded onto the selected device.
weights_file_name = 'yolov7.pt'
model = attempt_load(weights_file_name, map_location=device)  # load FP32 model

# Largest model stride as a plain int; detect() passes it to letterbox().
stride = int(model.stride.max())

**`detect()` takes a PIL.Image and returns a numpy array of detections**

The result has one row per detection; each row holds 6 floats, containing
* the bounding box (as xyxy)
* the confidence
* the class

In [None]:
def detect(img):
    """Run YOLOv7 on a single PIL image and return its detections.

    The image is letterboxed to 640 px (padded according to the model stride),
    run through the model without gradient tracking, filtered with
    non-max suppression, and the resulting boxes are mapped back to the
    original image's pixel coordinates.

    Args:
        img: a PIL.Image. Images in other modes (grayscale, RGBA, palette, ...)
            are converted to RGB first, since the model takes 3 channels.

    Returns:
        numpy array with one row of 6 floats per detection:
        x1, y1, x2, y2 (original-image pixels), confidence, class.
    """
    expected_img_size = 640
    # PIL yields RGB (unlike cv2.imread, which yields BGR). Force 3-channel RGB
    # so non-RGB images work too.
    img_array = np.array(img.convert('RGB'))
    orig_shape = img_array.shape
    img_array = letterbox(img_array, expected_img_size, stride=stride)[0]
    # HWC -> CHW only. No channel flip here: the upstream script reverses the
    # channels because cv2 loads BGR, but this array is already RGB.
    img_array = np.ascontiguousarray(img_array.transpose(2, 0, 1))
    img_tensor = torch.from_numpy(img_array).to(device).float()
    img_tensor /= 255.0  # 0..255 pixel values -> 0..1
    img_tensor = img_tensor.unsqueeze(0)  # add a batch dimension of size 1

    # Inference only: don't build an autograd graph.
    with torch.no_grad():
        model_output = model(img_tensor)
    pred = non_max_suppression(model_output[0])
    detections = pred[0]  # results for the single image in the batch
    # Map boxes from letterboxed-input coordinates back to the original image.
    detections[:, :4] = scale_coords(img_tensor.shape[2:], detections[:, :4], orig_shape).round()
    return detections.numpy(force=True)

Sanity check: run `detect()` on a sample image.

In [None]:
# `import PIL` by itself does not load the PIL.Image submodule; import it
# explicitly instead of relying on some other module having done so.
import PIL.Image

# Sample frame to run the detector on. An alternative test image is
# '/home/marcel/pixeltable/pixeltable/tests/data/imagenette2-160/n03888257_50622.JPEG'.
img_file = '/home/marcel/.pixeltable/images/frame_1_0_0_00001.jpg'

img = PIL.Image.open(img_file)

print(img.size)  # (width, height)
display(img)  # IPython/Jupyter display helper

In [None]:
# Run the detector once on the sample image; the last expression is shown as the cell output.
sample_detections = detect(img)
sample_detections

We create the database `functions` (or get a handle to it if it already exists).

In [None]:
import sys
sys.path.append('/home/marcel/pixeltable')

import pixeltable as pt
from pixeltable.function import Function, FunctionRegistry
from pixeltable.type_system import ArrayType, ImageType, ColumnType, JsonType

def _open_or_create_db(client, db_name):
    """Return the database `db_name`, creating it if it does not exist yet."""
    try:
        return client.get_db(db_name)
    except pt.UnknownEntityError:
        # First run: nothing to open, so create it.
        return client.create_db(db_name)


cl = pt.Client()
db = _open_or_create_db(cl, 'functions')

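We then store `detect()` in the database as the named UDF `yolov7`. The repo modules it depends on (`models`, `utils`, `thop`) are registered with `FunctionRegistry` first, so the stored UDF can be used without those modules being available.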

In [None]:
# Register, in this order, the repo modules that detect() depends on so they
# are stored along with the pickled UDF.
for pickled_module in (models, utils, thop):
    FunctionRegistry.register_pickled_module(pickled_module)

In [None]:
# UDF signature: takes one image, returns a float array with 6 columns
# (one row per detection); evaluation is delegated to detect().
yolov7_udf = Function(ArrayType((None, 6), dtype=ColumnType.Type.FLOAT), [ImageType()], eval_fn=detect)

# Replace any previously stored version. On a first run there is nothing to
# drop, which must not abort the cell.
# NOTE(review): assumes drop_function raises pt.UnknownEntityError for a
# missing function, as get_db does for a missing database — confirm.
try:
    db.drop_function('yolov7')
except pt.UnknownEntityError:
    pass
db.create_function('yolov7', yolov7_udf)

Let's verify that it worked.

We start with a fresh client to make sure we're not simply reading back data cached in this session.

In [None]:
# Re-open the database through a brand-new client, so the UDF has to be loaded
# from the database rather than reused from objects created earlier in this session.
cl = pt.Client()
db = cl.get_db('functions')
yolov7 = db.load_function('yolov7')  # the UDF stored under this name earlier in the notebook

In [None]:
# Evaluate the reloaded UDF on the same sample image; this is the same function
# as detect(), so the output is expected to match the earlier sanity check.
loaded_detections = yolov7.eval_fn(img)
loaded_detections