In [None]:
import os
import torch as th
from ultralytics import YOLO


In [None]:
class DetectHandler(object):
    """TorchServe-style handler wrapping an Ultralytics YOLO detection model.

    ``initialize`` resolves the serialized model file from the context and
    builds the model; ``handle`` crops every detected box out of the original
    image and runs ``self.model`` on each crop.
    """

    def __init__(self):
        self._context = None
        self.manifest = None
        self.initialized = False
        self.model = None
        self.device = None

    def initialize(self, context):
        """
        Load the model described by the context.
        :param context: object exposing ``manifest`` (with
            ``manifest['model']['serializedFile']``) and ``system_properties``
            (with a ``"model_dir"`` entry).
        :raises RuntimeError: if the serialized model file does not exist.
        """
        self._context = context
        self.manifest = context.manifest
        properties = context.system_properties
        model_dir = properties.get("model_dir")
        self.device = th.device("cpu")

        # Resolve the serialized model file and fail early if it is missing.
        serialized_file = self.manifest['model']['serializedFile']
        model_pt_path = os.path.join(model_dir, serialized_file)
        if not os.path.isfile(model_pt_path):
            raise RuntimeError("Missing the model.pt file")

        # The file used to be loaded with th.jit.load() first and the result
        # was immediately overwritten by the YOLO wrapper below; that dead
        # load has been dropped so the weights are not read twice.
        # TODO:
        #   this implementation can be not fast enough
        #   for more performance try https://github.com/louisoutin/yolov5_torchserve
        self.model = YOLO(
            model_pt_path,  # "../weights/detect.torchscript"
            task="detect"
        )
        self.initialized = True

    def handle(self, data, context):
        """
        Invoked by TorchServe for a prediction request.
        Crops every detected box from the original image, scales the crop to
        [0, 1], runs the model on it and collects the per-box outputs.
        :param data: sequence whose first item exposes ``boxes.data`` (rows
            starting with x1, y1, x2, y2) and ``orig_img`` (an H x W x C
            array) -- presumably an Ultralytics ``Results`` object; confirm
            against the caller.
        :param context: Initial context contains model server system properties
            (not used here).
        :return: list with one prediction (nested list) per detected box;
            empty when there is no input or no box.
        """
        if data is None or len(data) == 0:
            return []

        detection = data[0]
        image = detection.orig_img
        result = []
        for row in detection.boxes.data.tolist():
            # Only the first four values are pixel coordinates; trailing
            # columns (confidence, class, ...) are not needed for cropping.
            x1, y1, x2, y2 = (int(value) for value in row[:4])
            crop = image[y1:y2, x1:x2, :] / 255.0
            # H x W x C -> C x H x W, as float32.
            crop_tensor = th.as_tensor(crop, dtype=th.float32).permute(2, 0, 1)
            # NOTE(review): this assumes the model returns a tensor with at
            # least three dimensions -- confirm for the YOLO wrapper.
            pred_out = self.model(crop_tensor).argmax(2).T.tolist()
            result.append(pred_out)
        # Return every per-box prediction, not just the last one.
        return result

In [None]:
class DetectContext:
    """Minimal stand-in for a model-server context used to run the handler locally.

    Exposes the two attributes the handler reads: ``manifest`` (naming the
    serialized model file) and ``system_properties`` (naming the model
    directory).
    """

    def __init__(self):
        model_entry = {"serializedFile": "detect.torchscript"}
        self.manifest = {"model": model_entry}

        props = {}
        props["model_dir"] = "../weights/"
        self.system_properties = props

In [None]:
# Build the local stub context and the handler, then load the model weights
# named by the context into the handler.
det_ctx = DetectContext()
detect = DetectHandler()
detect.initialize(context=det_ctx)