## This notebook adapts the scripts in https://github.com/RizwanMunawar/yolov7-object-tracking for Pixeltable. Run it in the Python environment created for that repository so that its `models` and `utils` packages are importable.

In [None]:
import models, utils
from models.experimental import attempt_load
from utils.torch_utils import select_device
from utils.general import non_max_suppression

In [None]:
# Torch device for inference, chosen by yolov7's select_device helper.
# NOTE(review): '' presumably means "auto-select" (GPU if available, else
# CPU) — confirm against utils.torch_utils.select_device.
device = select_device('')
# Pretrained YOLOv7 checkpoint, loaded in the next cell.
weights_file_name = 'yolov7.pt'

In [None]:
# Load the YOLOv7 checkpoint onto `device` (FP32 weights).
model = attempt_load(
    weights_file_name,
    map_location=device,
)

## detect() takes a PIL.Image and returns a numpy array of detections
Each detection is a row of 6 floats, containing
* the bounding box as x1, y1, x2, y2 (xyxy), in pixel coordinates of the original image
* the confidence score
* the class index

In [None]:
def detect(img, *, expected_img_size=640, stride=32):
    """Run the YOLOv7 model on a PIL image and return its detections.

    The image is converted to RGB and resized (aspect ratio preserved) so
    that its longer side is ``expected_img_size``. It is then padded on the
    right/bottom so both sides are multiples of ``stride``. Because only the
    right/bottom edges are padded, the top-left origin is unchanged, and boxes
    map back to the original image by undoing the resize alone.

    Args:
        img: a PIL.Image in any mode. It is converted to RGB because the
            model expects 3 input channels.
        expected_img_size: length of the longer side after resizing.
        stride: both padded dimensions are rounded up to a multiple of this.
            32 is assumed to be the largest stride of the yolov7 network —
            confirm when using a different checkpoint.

    Returns:
        A numpy array with one row of 6 floats per detection: the box as
        x1, y1, x2, y2 in original-image pixels (clipped to the image), the
        confidence, and the class index.
    """
    import numpy as np
    import torch
    import torch.nn.functional as F

    orig_width, orig_height = img.width, img.height
    resize_ratio = min(expected_img_size / orig_width, expected_img_size / orig_height)
    # round() rather than int(): float error can make the exact target size
    # come out as e.g. 639.9999, which int() would truncate to 639.
    # max(1, ...) keeps extreme aspect ratios from producing a 0-pixel side.
    new_width = max(1, round(resize_ratio * orig_width))
    new_height = max(1, round(resize_ratio * orig_height))
    img = img.convert('RGB').resize((new_width, new_height))

    # HWC uint8 -> CHW float32 in [0, 1]
    pixels = np.asarray(img, dtype=np.float32) / 255.0
    img_tensor = torch.from_numpy(pixels).permute(2, 0, 1)

    # YOLOv7's own preprocessing pads inputs to a multiple of the stride;
    # other sizes can produce mismatched feature maps when they are
    # upsampled and concatenated. Fill with gray (114), the value YOLO
    # letterboxing conventionally uses.
    pad_right = (-new_width) % stride
    pad_bottom = (-new_height) % stride
    img_tensor = F.pad(img_tensor, (0, pad_right, 0, pad_bottom), value=114 / 255.0)

    # Batch of 1, placed on the same device as the model's weights.
    model_device = next(model.parameters()).device
    img_tensor = img_tensor.unsqueeze(0).to(model_device)

    # Inference only: don't build an autograd graph.
    with torch.no_grad():
        pred = model(img_tensor)[0]
    detections = non_max_suppression(pred)[0].numpy(force=True)

    # Undo the resize per axis (rounding makes the two scales differ slightly)
    # and clip boxes that extend into the padded region.
    x_scale = new_width / orig_width
    y_scale = new_height / orig_height
    detections[:, [0, 2]] = np.clip(detections[:, [0, 2]] / x_scale, 0, orig_width)
    detections[:, [1, 3]] = np.clip(detections[:, [1, 3]] / y_scale, 0, orig_height)
    return detections

# Store detect() in Pixeltable as a UDF

In [None]:
import os
import sys

# Make a local Pixeltable checkout importable. Set the PIXELTABLE_SRC
# environment variable to point elsewhere instead of editing this cell; the
# default is the original author's checkout.
pixeltable_src = os.environ.get('PIXELTABLE_SRC', '/home/marcel/pixeltable')
# Re-running this cell must not keep appending the same entry.
if pixeltable_src not in sys.path:
    sys.path.append(pixeltable_src)

import pixeltable as pt
from pixeltable.function import Function
from pixeltable.type_system import ArrayType, ImageType, ColumnType

# Client and the 'functions' database that the UDF is stored in below.
cl = pt.Client()
db = cl.get_db('functions')

In [None]:
# Wrap detect() as a Pixeltable Function: one image parameter, returning a
# float array with a variable number of rows and 6 columns.
_detections_type = ArrayType((None, 6), dtype=ColumnType.Type.FLOAT)
_param_types = [ImageType()]
yolov7_udf = Function(_detections_type, _param_types, eval_fn=detect)

# Register it in the 'functions' database under the name 'yolo.yolov7'.
db.create_function('yolo.yolov7', yolov7_udf)

### Load the function back from Pixeltable to verify that it was stored

We start with a fresh client, so the function has to come from Pixeltable rather than from the in-memory object created above.

In [None]:
cl = pt.Client()  # brand-new client: nothing reused from the cells above
db = cl.get_db('functions')
# Round-trip: fetch the UDF by the name it was registered under.
yolov7_udf = db.load_function('yolo.yolov7')

In [None]:
import PIL.Image

# Sample image from Pixeltable's test data (imagenette2-160).
img_file_path = '/home/marcel/pixeltable/pixeltable/tests/data/imagenette2-160/n03445777_2563.JPEG'
# Open the path defined above; the original referenced an undefined `img_file`.
img = PIL.Image.open(img_file_path)

In [None]:
# Call the reloaded UDF's Python implementation directly on the test image,
# then echo the resulting detection array as the cell's output.
detections = yolov7_udf.eval_fn(img)
detections