# Creating a UDF for the YoloV7 Model

This adapts the object detection script in https://github.com/RizwanMunawar/yolov7-object-tracking for Pixeltable.

This script needs to be run in the environment created for that repo. The UDF is stored in a database `functions` and can subsequently be used in queries and computed columns without having access to the modules used in the script.

In [None]:
import numpy as np
import torch
import torchvision
import cv2
from collections import defaultdict
import os

import models, utils, PIL, thop, filterpy
from models.experimental import attempt_load
from utils.torch_utils import select_device, TracedModel
from utils.general import non_max_suppression, scale_coords
from utils.datasets import letterbox
from utils.download_weights import download
import sort_track as sort
%load_ext autoreload
%autoreload 2

In [None]:
# Make sure the weights file is present locally: download() is only called
# (with the current directory as its argument) when the file does not exist yet.
weights_file_name = 'yolov7.pt'
if not os.path.exists(weights_file_name):
    download('./')

# NOTE(review): an empty device string presumably lets select_device() pick
# the device automatically -- confirm against utils.torch_utils.
device = select_device('')

model = attempt_load(weights_file_name, map_location=device)  # load FP32 model
# Largest stride reported by the model; detect() passes it to letterbox().
stride = int(model.stride.max())

**`detect()` takes a PIL.Image and returns a numpy array of detections**

Each detection is a numpy array of 6 floats, containing
* the bounding box (as xyxy)
* the confidence
* the class

In [None]:
def detect(img):
    """Run the loaded YoloV7 model on a single image.

    Relies on the module-level `model`, `device` and `stride` set up earlier
    in this notebook.

    Args:
        img: a PIL.Image. Images that are not in RGB mode (e.g. grayscale or
            RGBA) are converted to RGB before being handed to OpenCV.

    Returns:
        A numpy array with one row of 6 floats per detection: the bounding
        box (as xyxy, rescaled by scale_coords() from the letterboxed model
        input to the shape of the input image, then rounded), the confidence,
        and the class.
    """
    expected_img_size = 640
    # The RGB2BGR conversion below needs a 3-channel RGB array; a grayscale
    # image would yield a 2-d array and make cv2.cvtColor() fail.
    if img.mode != 'RGB':
        img = img.convert('RGB')
    img_array = np.array(img)
    img_array = cv2.cvtColor(img_array, cv2.COLOR_RGB2BGR)
    orig_shape = img_array.shape
    img_array = letterbox(img_array, expected_img_size, stride=stride)[0]
    img_array = img_array[:, :, ::-1].transpose(2, 0, 1)  # BGR to RGB, channels go first
    # the reversed/transposed view has negative strides; torch.from_numpy()
    # needs a contiguous array
    img_array = np.ascontiguousarray(img_array)
    img_tensor = torch.from_numpy(img_array).to(device).float()
    img_tensor /= 255.0  # scale pixel values from 0..255 to 0..1
    img_tensor = img_tensor.unsqueeze(0)  # add the batch dimension

    with torch.no_grad():
        model_output = model(img_tensor)
        pred = model_output[0]
        pred = non_max_suppression(pred)
        # the batch holds a single image, so only the first result is relevant
        detections = pred[0]
        # map the boxes from the letterboxed input back onto the original image
        detections[:, :4] = scale_coords(img_tensor.shape[2:], detections[:, :4], orig_shape).round()
        return detections.numpy(force=True)

Sanity check

In [None]:
# Machine-specific sample image; point img_file at any local image file to
# reproduce this cell on another machine.
#img_file = '/home/marcel/.pixeltable/images/frame_1_0_0_00001.jpg'
img_file = '/home/marcel/pixeltable/pixeltable/tests/data/imagenette2-160/n03888257_50622.JPEG'

img = PIL.Image.open(img_file)

# Print the image size and render the image inline.
print(img.size)
display(img)

In [None]:
# Run detection on the sample image; as the cell's last expression, the
# returned array is shown as the cell output.
detect(img)

We create the database `functions` (or get a handle to it if it already exists)

In [None]:
import sys
sys.path.append('/home/marcel/pixeltable')

import pixeltable as pt
from pixeltable.function import Function, FunctionRegistry
from pixeltable.type_system import ArrayType, ImageType, ColumnType, JsonType

# Open a Pixeltable client and get a handle to the 'functions' db,
# creating the db if the lookup fails with UnknownEntityError.
cl = pt.Client()
try:
    db = cl.get_db('functions')
except pt.UnknownEntityError:
    db = cl.create_db('functions')

We then store `detect()` as a named UDF 'yolov7' in the db

In [None]:
from cloudpickle import register_pickle_by_value
# Have cloudpickle serialize these modules by value, so that the pickled UDF
# carries their code along and can be used without the modules being importable.
for pickled_module in (models, utils, thop, sort, filterpy):
    register_pickle_by_value(pickled_module)

In [None]:
# Wrap detect() as a Pixeltable function.
# signature: (image) -> float array with 6 columns and a variable number of rows
yolov7_udf = Function(
    ArrayType((None, 6), dtype=ColumnType.Type.FLOAT), [ImageType()], eval_fn=detect)

In [None]:
# Store the UDF under the name 'yolov7'. If creating it fails (e.g. because
# a function of that name is already stored), overwrite the stored function.
try:
    db.create_function('yolov7', yolov7_udf)
except Exception:  # not a bare except: KeyboardInterrupt/SystemExit must propagate
    db.update_function('yolov7', yolov7_udf)

# Tracking

Object tracking assigns object IDs to the detections in each frame, based on what was detected in previous frames. It can be easily expressed as a windowed aggregate function.

In [None]:
class SortTracker:
    """Aggregation state that feeds per-frame detections into a SORT tracker.

    The three pieces used by the aggregate UDF are make_instance() (create
    the state), update() (consume one frame's detections) and value()
    (report the current tracks).
    """

    def __init__(self):
        # last object id handed out by next_id()
        self.current_id = 0
        # NOTE(review): next_id is presumably called by Sort whenever it
        # needs an id for a new track -- confirm against sort_track.
        self.tracker = sort.Sort(self.next_id, max_age=5, min_hits=2, iou_threshold=0.2)

    @classmethod
    def make_instance(cls):
        """Return a new, empty tracker state."""
        return cls()

    def next_id(self):
        """Increment the id counter and return the new value (the first id is 1)."""
        self.current_id += 1
        return self.current_id

    def update(self, detections):
        """Pass one frame's detections on to the underlying tracker."""
        self.tracker.update(detections)

    def value(self):
        """Return one dict per track with its id, bbox, confidence and class.

        The values are taken from the last entry of each track's bbox_history,
        i.e. the most recently recorded bounding box.
        """
        results = []
        for track in self.tracker.getTrackers():
            latest = track.bbox_history[-1]
            results.append({
                'id': track.id,
                'bbox': latest[:4].astype(int).tolist(),
                # .item() turns numpy scalars into plain Python types
                'conf': latest[4].item(),
                'class': latest[5].item(),
            })
        return results

In [None]:
# Wrap SortTracker as an aggregate function: init_fn creates the state,
# update_fn feeds it one row's detections, value_fn reads out the result.
# signature: (float array with 6 columns) -> json
sort_track_udf = Function(
    JsonType(),
    [ArrayType((None, 6), dtype=ColumnType.Type.FLOAT)],
    init_fn=SortTracker.make_instance, update_fn=SortTracker.update, value_fn=SortTracker.value)

In [None]:
# Store the UDF under the name 'sort_track'. If creating it fails (e.g. because
# a function of that name is already stored), overwrite the stored function.
try:
    db.create_function('sort_track', sort_track_udf)
except Exception:  # not a bare except: KeyboardInterrupt/SystemExit must propagate
    db.update_function('sort_track', sort_track_udf)

Let's verify that it worked.

We're starting with a fresh client to make sure we're not simply referencing cached data.

In [None]:
# Look the functions up through the new client and its own db handle, so
# that nothing comes from the client/db objects created earlier.
cl2 = pt.Client()
db2 = cl2.get_db('functions')
yolov7 = db2.get_function('yolov7')
sort_track = db2.get_function('sort_track')

In [None]:
# Run the function retrieved from the db on the sample image.
detection = yolov7.eval_fn(img)
# Show only the first four columns (the bounding-box coordinates).
detection[:, :4]

In [None]:
# Drive the aggregate by hand: create the state, feed it the detections of
# one frame, then read out the resulting track info.
state = sort_track.init_fn()
sort_track.update_fn(state, detection)
track_info = sort_track.value_fn(state)

In [None]:
# Show the tracker output for the single frame fed in above.
track_info

# Visualization

One way of visualizing the output of the tracking algorithm is by drawing a line to track the bounding boxes of identified objects over time. This is again easily expressed as a windowed aggregate function, which can keep track of the objects and their centroids over time.

In [None]:
class TrackingViz:
    """Aggregation state that draws tracking results onto the latest frame.

    Each update() records the centroid of every identified bounding box and
    renders, on a copy of the given image, the centroid trail of every object
    seen so far plus the current bounding boxes with their id labels.
    value() returns that rendering.
    """

    def __init__(self):
        self.centroid_history = defaultdict(list)  # id -> list of centroids
        self.viz = None  # last image with overlaid visualizations

    @classmethod
    def make_instance(cls):
        """Return a new, empty visualization state."""
        return cls()

    def update(self, img, bboxes, ids=None):
        """Record the boxes of one frame and render the visualization for it.

        Args:
            img: the frame to draw on; it is copied via np.array(), so the
                argument itself is not modified.
            bboxes: sequence of boxes, each with four numbers (x1, y1, x2, y2).
            ids: optional sequence of object ids, one per box. When None, no
                centroids are recorded (there is no identity to attach them
                to) and every box is labelled 0.

        Raises:
            ValueError: if ids is given and its length differs from bboxes.
        """
        if ids is not None:
            if len(bboxes) != len(ids):
                raise ValueError(
                    f'number of bounding boxes ({len(bboxes)}) does not match '
                    f'number of ids ({len(ids)})')
            for track_id, bbox in zip(ids, bboxes):
                centroid = int((bbox[0] + bbox[2]) // 2), int((bbox[1] + bbox[3]) // 2)
                self.centroid_history[track_id].append(centroid)

        # create image with visualizations
        self.viz = np.array(img)
        # draw per-object track (straight lines between consecutive centroids)
        for centroids in self.centroid_history.values():
            for start, end in zip(centroids, centroids[1:]):
                cv2.line(self.viz, start, end, (255, 0, 0), thickness=2)

        for i, box in enumerate(bboxes):
            x1, y1, x2, y2 = [int(coord) for coord in box]
            track_id = int(ids[i]) if ids is not None else 0
            label = str(track_id)
            # only the text width is used; the label background height is fixed at 20px
            (w, h), _ = cv2.getTextSize(label, cv2.FONT_HERSHEY_SIMPLEX, 0.6, 1)
            # bounding box
            cv2.rectangle(self.viz, (x1, y1), (x2, y2), (255, 0, 20), 2)
            # label rectangle (filled), sitting on top of the box's upper edge
            cv2.rectangle(self.viz, (x1, y1 - 20), (x1 + w, y1), (255, 144, 30), -1)
            # label
            cv2.putText(self.viz, label, (x1, y1 - 5), cv2.FONT_HERSHEY_SIMPLEX, 0.6, [255, 255, 255], 1)

    def value(self):
        """Return the most recent rendering as a PIL image."""
        return PIL.Image.fromarray(self.viz)

In [None]:
# Wrap TrackingViz as an aggregate function: init_fn creates the state,
# update_fn consumes one row, value_fn returns the rendered image.
track_viz_udf = Function(
    # signature: (image, bounding_boxes, ids) -> image
    ImageType(), [ImageType(), JsonType(), JsonType()],
    init_fn=TrackingViz.make_instance, update_fn=TrackingViz.update, value_fn=TrackingViz.value)

In [None]:
# Store the UDF under the name 'track_viz'. If creating it fails (e.g. because
# a function of that name is already stored), overwrite the stored function.
try:
    db.create_function('track_viz', track_viz_udf)
except Exception:  # not a bare except: KeyboardInterrupt/SystemExit must propagate
    db.update_function('track_viz', track_viz_udf)

Let's see what that looks like:

In [None]:
# Try the visualization by hand on the sample image and its detections.
viz_state = track_viz_udf.init_fn()
boxes = detection[:, :4].tolist()
# One placeholder id (1, 2, ...) per box, so that the number of ids matches
# the number of boxes no matter how many objects were detected.
track_viz_udf.update_fn(viz_state, img, boxes, list(range(1, len(boxes) + 1)))
track_viz_udf.value_fn(viz_state)