<a href="https://colab.research.google.com/github/Wolfffff/gpuhackathon-sleap/blob/main/colab_fullmodel_trt_triton.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# A Simple Triton Example with a SLEAP Model

For this to work, you'll need a running Triton Inference Server that serves your model of interest. In this case, we're using one exposed through [ngrok](https://ngrok.com/). We can simply pull our data of interest and send it directly to the inference server using gRPC -- or HTTP if we wanted.

In [8]:
%%capture

# Download the Triton v2.10.0 client bundle and extract only the tritonclient wheel from it.
# --strip-components 1 drops the leading python/ directory so the wheel lands in the working directory,
# which is where the install cell below expects to find it.
!wget https://github.com/triton-inference-server/server/releases/download/v2.10.0/v2.10.0_ubuntu2004.clients.tar.gz
!tar -zxvf v2.10.0_ubuntu2004.clients.tar.gz --wildcards python/tritonclient-2.10.0-py3-none-manylinux1_x86_64.whl --strip-components 1
# Helper module from the hackathon repository.
# NOTE(review): no cell shown in this notebook imports triton_utils -- confirm it is still needed.
!wget https://raw.githubusercontent.com/Wolfffff/gpuhackathon-sleap/main/triton/triton_utils.py
# Example video and its SLEAP labels file, saved under data/ (%40 in the URL is an encoded "@";
# later cells open these files using "@" in the file name).
!wget -P data https://storage.googleapis.com/sleap-data/reference/flies13/190719_090330_wt_18159206_rig1.2%4015000-17560.mp4
!wget -P data https://storage.googleapis.com/sleap-data/reference/flies13/190719_090330_wt_18159206_rig1.2%4015000-17560.slp

In [None]:
%%capture
# Install SLEAP, then the Triton client from the wheel extracted by the download cell above.
!pip install sleap
# For some reason the triton client install only works from their .whl, so we pull the release, extract, and install.
# The [all] extra is requested because later cells import both tritonclient.grpc and tritonclient.http.
!pip install tritonclient-2.10.0-py3-none-manylinux1_x86_64.whl[all]

# Restart to deal with a package import issue: this sends signal 9 to the kernel's own process,
# so the kernel dies here on purpose. Nothing after this cell runs until the runtime is back up;
# continue from the next cell once it has restarted.
import os
os.kill(os.getpid(), 9)

In [None]:
import sleap
import tritonclient.grpc as grpcclient
import palettable
import cv2
import time
import tensorflow as tf

In [None]:
# SLEAP labels file fetched into data/ by the download cell at the top of the notebook.
slp_path = "data/190719_090330_wt_18159206_rig1.2@15000-17560.slp"
# Loaded once up front; the plotting cell reads `slp_labels.skeleton` from it when it
# builds instances out of the predicted points.
slp_labels = sleap.load_file(slp_path)

In [None]:
def read_frames(video_path, fidxs=None, grayscale=True):
    """Read frames from a video file.

    Each requested frame is fetched by seeking the capture to its index and
    decoding a single image. The capture handle is always released before
    returning, including when an error is raised.

    Args:
        video_path: Path to MP4
        fidxs: Iterable of frame indices or None to read all frames (default: None)
        grayscale: Keep only the first channel of the images (default: True).
            No color conversion is applied; channel 0 is simply sliced out.

    Returns:
        Loaded images in array of shape (n_frames, height, width, channels) and dtype uint8.

    Raises:
        IOError: If the video cannot be opened or a requested frame cannot be read.
        ValueError: If no frames were requested (raised by `np.stack` on an empty list).
    """
    vr = cv2.VideoCapture(video_path)
    try:
        if not vr.isOpened():
            raise IOError("Could not open video: {}".format(video_path))

        if fidxs is None:
            # The frame count is reported as a float; convert it so indices are integers.
            fidxs = range(int(vr.get(cv2.CAP_PROP_FRAME_COUNT)))

        frames = []
        for fidx in fidxs:
            vr.set(cv2.CAP_PROP_POS_FRAMES, int(fidx))
            ok, img = vr.read()
            if not ok or img is None:
                # Without this check a failed read surfaces later as an opaque
                # TypeError when the missing image is indexed.
                raise IOError(
                    "Could not read frame {} from video: {}".format(fidx, video_path))
            if grayscale:
                # Index with a list so the channel axis is kept (height, width, 1).
                img = img[:, :, [0]]
            frames.append(img)
    finally:
        # Free the underlying file/decoder handle no matter how we leave.
        vr.release()

    return np.stack(frames, axis=0)

In [None]:
# The server's localhost is exposed through ngrok -- this tunnel is the main restricting factor here.
# ngrok seems to be exceptionally slow but is probably fine for this example...
video_path = "data/190719_090330_wt_18159206_rig1.2@15000-17560.mp4"

# This capture is opened only to learn how many frames the video has (the count
# bounds the frame slider further down), so release the handle right away
# instead of keeping the file open for the rest of the session.
cap = cv2.VideoCapture(video_path)
frame_count = cap.get(cv2.CAP_PROP_FRAME_COUNT)
cap.release()

# Connection settings for the Triton server and the model to query.
triton_url = '2.tcp.ngrok.io:19256'
model_name = "ex"
model_version="1"
protocol = 'grpc'
triton_client = grpcclient.InferenceServerClient(url=triton_url)
# Running count of requests sent; used as the request id and incremented by query_triton.
sent_count = 0

In [None]:
# Built off of 
# https://github.com/triton-inference-server/client/blob/main/src/python/examples/image_client.py
import argparse
from functools import partial
import os
import sys

import numpy as np

import tritonclient.grpc as grpcclient
import tritonclient.http as httpclient


def parse_model(model_metadata, model_config):
  """Validate the served model's signature and extract what is needed to build requests.

  The model is required to expose exactly one input (datatype UINT8, four
  dimensions) and exactly four outputs.

  Args:
    model_metadata: Model metadata object providing `name`, `inputs` and
      `outputs`; each input/output entry provides `name`, `datatype` and `shape`.
    model_config: Model configuration object providing `max_batch_size` and
      `input` (whose single entry provides `format`).

  Returns:
    Tuple `(max_batch_size, input_name, output_name, c, h, w, input_format,
    input_datatype)`. `output_name` is the name of the FIRST output only, and
    `h`, `w`, `c` are taken from positions 1, 2 and 3 of the input shape.

  Raises:
    Exception: If the number of inputs/outputs, the input datatype, or the
      number of input dimensions does not match the expectations above.
  """
  if len(model_metadata.inputs) != 1:
    raise Exception("expecting 1 input, got {}".format(
        len(model_metadata.inputs)))
  if len(model_metadata.outputs) != 4:
    raise Exception("expecting 4 outputs, got {}".format(
        len(model_metadata.outputs)))

  if len(model_config.input) != 1:
    raise Exception(
        "expecting 1 input in model configuration, got {}".format(
            len(model_config.input)))

  input_metadata = model_metadata.inputs[0]
  if input_metadata.datatype != "UINT8":
    # This checks the INPUT; the message used to say "output", which was misleading.
    raise Exception("expecting input datatype to be UINT8, model '" +
                    model_metadata.name + "' input type is " +
                    input_metadata.datatype)
  input_config = model_config.input[0]
  output_metadata = model_metadata.outputs[0]

  # Count the non-singleton dimensions of the first output, skipping the leading
  # (batch) dimension when the model supports batching. With a limit of 100 this
  # check is effectively disabled; it is kept so behaviour is unchanged.
  # NOTE(review): presumably carried over from the image_client example this
  # cell is built off of -- confirm whether it can simply be removed.
  output_dims = list(output_metadata.shape)
  if model_config.max_batch_size > 0:
    output_dims = output_dims[1:]
  if sum(1 for dim in output_dims if dim > 1) > 100:
    raise Exception("expecting model output to be a vector")

  # The input is expected to be 4-D, e.g. [-1, 1024, 1024, 1] per the original
  # note: the leading dimension is part of the input shape itself rather than a
  # Triton batch dimension, so a whole stack of frames is sent as a single input.
  expected_input_dims = 4
  if len(input_metadata.shape) != expected_input_dims:
    raise Exception(
        "expecting input to have {} dimensions, model '{}' input has {}".
        format(expected_input_dims, model_metadata.name,
               len(input_metadata.shape)))

  # Position 0 of the shape (the leading dimension) is not needed by callers.
  h = input_metadata.shape[1]
  w = input_metadata.shape[2]
  c = input_metadata.shape[3]

  return (model_config.max_batch_size, input_metadata.name,
          output_metadata.name, c, h, w, input_config.format,
          input_metadata.datatype)

def requestGenerator(batched_image_data, input_name, output_names, dtype, protocol, model_name, model_version):
  """Yield one inference request for the given image data.

  Args:
    batched_image_data: Numpy array holding the image data; its `.shape` is
      used as the shape of the request input.
    input_name: Name of the model input to fill.
    output_names: Names of the model outputs to request.
    dtype: Datatype string for the input.
    protocol: "grpc" selects the gRPC client module; any other value selects HTTP.
    model_name: Passed through unchanged in the yielded tuple.
    model_version: Passed through unchanged in the yielded tuple.

  Yields:
    Exactly one tuple `(inputs, outputs, model_name, model_version)` where
    `inputs` is a one-element list holding the populated input and `outputs`
    has one requested-output object per entry of `output_names`.
  """
  # Both client modules expose the same InferInput / InferRequestedOutput names.
  client = grpcclient if protocol == "grpc" else httpclient

  # Describe the single input and attach the actual pixel data to it.
  image_input = client.InferInput(input_name, batched_image_data.shape, dtype)
  image_input.set_data_from_numpy(batched_image_data)

  requested_outputs = [client.InferRequestedOutput(name) for name in output_names]

  yield [image_input], requested_outputs, model_name, model_version

In [None]:
# Ask the server how the model is defined: first its metadata, then its configuration
# (the configuration itself sits on the `.config` attribute of the response).
model_metadata = triton_client.get_model_metadata(model_name=model_name, model_version=model_version)
model_config = triton_client.get_model_config(model_name=model_name, model_version=model_version).config

# Base model info. `output_name` is only the first output; see `output_names` below.
# NOTE: `format` shadows the Python builtin; the name is kept because it is module-level state.
(max_batch_size, input_name, output_name,
 c, h, w, format, dtype) = parse_model(model_metadata, model_config)

# The model has several outputs, so collect all of their names for the requests.
output_names = [output.name for output in model_metadata.outputs]

Pair all of our model info, generate request, and send it off!

In [None]:
# Pair request generator 
def query_triton(frame):
  """Send image data to the Triton server and collect the responses.

  Relies on module-level state set up by earlier cells: `triton_client`,
  `input_name`, `output_names`, `dtype`, `protocol`, `model_name`,
  `model_version` and the running request counter `sent_count`.

  Args:
    frame: Image data passed to `requestGenerator` as the input of the request.

  Returns:
    List with one response per request produced by `requestGenerator`.
  """
  # Bad practice... but it saves some time. Only the counter is rebound in
  # here; every other module-level name is merely read, so it needs no
  # `global` declaration.
  global sent_count

  responses = []
  request_args = (frame, input_name, output_names, dtype, protocol,
                  model_name, model_version)
  # Use distinct loop names: unpacking into `model_name` / `model_version`
  # would rebind the module-level settings on every request.
  for inputs, outputs, request_model, request_version in requestGenerator(*request_args):
    responses.append(triton_client.infer(request_model,
                                         inputs,
                                         request_id=str(sent_count),
                                         model_version=request_version,
                                         outputs=outputs))
    sent_count += 1

  return responses

In [None]:
%matplotlib inline

import matplotlib.pyplot as plt
import matplotlib.cm as cm
from ipywidgets import interactive, fixed,widgets
import time


def plot_frame(video_path, fidx=0):
  """Run one video frame through the Triton server and draw the result.

  The frame is read with `read_frames`, sent with `query_triton`, and the first
  response is turned into SLEAP instances that are drawn over the image.

  Args:
    video_path: Path of the video to read the frame from.
    fidx: Index of the frame to process (default: 0).
  """
  t_start = time.time()

  # Fetch the image corresponding to the frame index
  frame = read_frames(video_path=video_path, fidxs=[fidx])

  # Send to Triton server for inference!
  t_request = time.time()
  result = query_triton(frame)[0]
  triton_inference_time = time.time() - t_request

  # First output: points of each predicted instance.
  points = result.as_numpy(output_names[0])

  # Third output: class probabilities; the most likely class picks the instance color.
  class_ids = np.argmax(result.as_numpy(output_names[2]), axis=1)
  palette = np.array(palettable.wesanderson.Royal1_4.mpl_colors)

  # Build sleap instances from response
  instances = [
      sleap.Instance.from_numpy(points[idx, :, :], skeleton=slp_labels.skeleton)
      for idx in range(len(points))
  ]

  sleap.nn.viz.plot_img(frame, scale=0.5)
  sleap.nn.viz.plot_instances(instances, cmap=palette[class_ids], alpha=0.5)
  plt.show()

  # # Log timing if you want!
  # print (f'Processed frame: {fidx}')
  # print ('\nTriton request time: %.2f s.' % triton_inference_time)
  # print ('\nFrame processing time: %.2f s.' % (time.time() - t_start))

# Build and show the viewer: a slider over every valid frame index drives plot_frame,
# while the video path is held fixed so no widget is created for it.
max_frame_idx = int(frame_count) - 1
interactive_plot = interactive(
    plot_frame,
    fidx=widgets.IntSlider(min=0, max=max_frame_idx, step=1, value=0, description='Frame ID'),
    video_path=fixed(video_path),
)
# Optional: give the output area a fixed size so the text doesn't force scrolling.
# output = interactive_plot.children[-1]
# output.layout.height = '1124px'
interactive_plot