In [None]:
import json
import os
import sys
from concurrent.futures import ProcessPoolExecutor, ThreadPoolExecutor, as_completed
from functools import singledispatch
from multiprocessing import cpu_count

import numpy as np
import tritonclient.grpc as grpcclient
from PIL import Image, ImageDraw
from tqdm import tqdm

if sys.version_info >= (3, 0):
    import queue
else:
    import Queue as queue

In [None]:
def load_image(img_path: str):
    """
    Read a file's raw (still encoded) bytes into a uint8 array.

    The bytes are not decoded; a leading axis of length 1 is added so the
    result has shape (1, number_of_bytes).
    """
    raw_bytes = np.fromfile(img_path, dtype="uint8")
    return raw_bytes[np.newaxis, :]


def render_image(filename, image_wise_bboxes, outline_color=(118, 185, 0), linewidth=3):
    """Open an image and outline every bounding box on it.

    Args:
        filename: Path of the image to open; it is converted to RGB.
        image_wise_bboxes: Iterable of boxes indexed as (x1, y1, x2, y2).
        outline_color: Colour passed to ``ImageDraw`` for the outlines.
        linewidth: Requested outline thickness. It is divided by
            ``736 / image_width``, so wider images get thicker outlines.

    Returns:
        The RGB image with the outlines drawn on it.
    """
    image = Image.open(filename).convert("RGB")
    w, h = image.size
    draw = ImageDraw.Draw(image)
    wpercent = 736 / float(w)
    # Always draw at least the box itself, even when the scaled thickness
    # truncates to zero on narrow images.
    n_outlines = max(1, int(linewidth / wpercent))
    for box in image_wise_bboxes:
        # Plain floats so the drawing call does not depend on the box's
        # container or scalar type.
        left, top, right, bottom = (
            float(box[0]),
            float(box[1]),
            float(box[2]),
            float(box[3]),
        )
        # Skip boxes whose corners are inverted.
        if (right - left) >= 0 and (bottom - top) >= 0:
            for i in range(n_outlines):
                # Grow the rectangle outwards by i pixels, clamped to the
                # image bounds. The original computed these corners but
                # redrew the unexpanded box, so the thickness never changed.
                x1 = max(0, left - i)
                y1 = max(0, top - i)
                x2 = min(w, right + i)
                y2 = min(h, bottom + i)
                draw.rectangle((x1, y1, x2, y2), outline=outline_color)
    return image


def submit_to_triton(image_data, input_name, output_names, request_id=None):
    """Send one inference request through the module-level ``triton_client``.

    Args:
        image_data: Array-like with a ``shape`` attribute; sent as the single
            "UINT8" input tensor.
        input_name: Name of the model input to bind ``image_data`` to.
        output_names: Names of the outputs to request (each requested with
            ``class_count=0``).
        request_id: Identifier attached to the request. ``None`` is sent as
            an empty string, because the client expects a string here.

    Returns:
        Whatever ``triton_client.infer`` returns for the module-level
        ``model_name``.
    """
    infer_input = grpcclient.InferInput(input_name, image_data.shape, "UINT8")
    infer_input.set_data_from_numpy(image_data)

    requested_outputs = [
        grpcclient.InferRequestedOutput(output_name, class_count=0)
        for output_name in output_names
    ]

    # Never forward None: the client treats request_id as a string.
    effective_id = "" if request_id is None else request_id
    return triton_client.infer(
        model_name, [infer_input], outputs=requested_outputs, request_id=effective_id
    )


def parse_model_grpc(model_metadata, model_config):
    """
    Validate that the model exposes exactly one input and pull out the
    fields this client needs.

    Both the metadata and the configuration must declare a single input;
    otherwise an Exception is raised. Returns a tuple of
    (input name, output metadata, max batch size).
    """
    metadata_input_count = len(model_metadata.inputs)
    if metadata_input_count != 1:
        raise Exception("expecting 1 input, got {}".format(metadata_input_count))

    config_input_count = len(model_config.input)
    if config_input_count != 1:
        message = "expecting 1 input in model configuration, got {}".format(
            config_input_count
        )
        raise Exception(message)

    sole_input = model_metadata.inputs[0]
    return sole_input.name, model_metadata.outputs, model_config.max_batch_size


@singledispatch
def to_serializable(val):
    """Fallback handler: represent any value as its ``str()`` text."""
    as_text = str(val)
    return as_text

In [None]:
# --- Run configuration and Triton connection setup ---------------------------
# Gates the rendering cell further down (`if render:`).
render = True
n_processes = cpu_count()
# Used as max_workers for both the process pool and the thread pool below.
n_threads = (n_processes * 4) - 1
model_name = "facenet_ensemble"
model_version = "1"
# Address handed to the Triton gRPC client.
url = "172.25.0.42:8001"
image_folder = "/workspace/sample-imgs"

# Destination of the JSON export written by the last cell.
json_data_filename = "/workspace/facedetect_data.json"

# Every regular file directly inside image_folder; the whole list is then
# repeated n_threads times, so each path appears n_threads times.
# NOTE(review): presumably this is to generate extra request load — confirm.
filenames = [
    os.path.join(image_folder, f)
    for f in os.listdir(image_folder)
    if os.path.isfile(os.path.join(image_folder, f))
] * n_threads

triton_client = grpcclient.InferenceServerClient(url=url, verbose=False)

# Fetch the model's metadata and configuration from the server.
model_metadata = triton_client.get_model_metadata(
    model_name=model_name, model_version=model_version
)

model_config = triton_client.get_model_config(
    model_name=model_name, model_version=model_version
).config

# parse_model_grpc raises unless the model declares exactly one input.
input_name, output_metadata, batch_size = parse_model_grpc(model_metadata, model_config)
output_names = [i.name for i in output_metadata]

In [None]:
# Maps request id (the image path) -> {first output name: (N, 4) box array}.
results = {}

# Read every file in worker processes; executor.map returns results in the
# same order as `filenames`.
with ProcessPoolExecutor(max_workers=n_threads) as executor:
    p_images_data = list(executor.map(load_image, filenames))

# Pair each payload with its path; the path is also used as the request id.
process_pool_data = zip(p_images_data, filenames)

with ThreadPoolExecutor(max_workers=n_threads) as executor:
    with tqdm(total=len(filenames)) as progress:
        future_to_request = [
            executor.submit(
                submit_to_triton, t_image_data, input_name, output_names, t_request_id
            )
            for t_image_data, t_request_id in process_pool_data
        ]

        for future in as_completed(future_to_request):
            # as_completed only yields finished futures, so advance the bar
            # directly instead of registering a done-callback that would run
            # immediately anyway.
            progress.update()
            infer_result = future.result()
            this_id = infer_result.get_response().id
            image_wise_bboxes = infer_result.as_numpy(output_names[0]).reshape(-1, 4)
            # Second output reshaped to one column; not stored in `results`.
            image_probas = infer_result.as_numpy(output_names[1]).reshape(-1, 1)
            # Paths repeat in `filenames`, so later responses for the same id
            # overwrite earlier ones.
            results[this_id] = {output_names[0]: image_wise_bboxes}

In [None]:
# Draw the detected boxes on each distinct image, resize it to a fixed width
# and show it.
if render:
    print("Rendering unique files in filenames...")
    rendered_images = []
    img_w = 800
    for this_id in results:
        image_wise_bboxes = results[this_id][output_names[0]]
        img = render_image(this_id, image_wise_bboxes)
        # Scale the height by the same factor as the width to keep the
        # aspect ratio.
        wpercent = img_w / float(img.size[0])
        hsize = int(float(img.size[1]) * float(wpercent))
        # Image.ANTIALIAS was removed in Pillow 10; LANCZOS is the same filter.
        img = img.resize((img_w, hsize), Image.LANCZOS)
        # Keep the resized image so the list is actually populated.
        rendered_images.append(img)
        img.show()

In [None]:
# Build {image path: [{"face_<idx>": bbox}, ...]} and write it out as JSON.
json_data_file = {}

for file in results:
    face_list = []
    # `results` stores boxes under output_names[0]; read them back with the
    # same key rather than a hard-coded output name.
    for idx, bbox in enumerate(results[file][output_names[0]]):
        face_idx = "face_{}".format(idx)
        face_list.append({face_idx: bbox})
    json_data_file[file] = face_list

# Values json cannot encode natively are passed through to_serializable.
with open(json_data_filename, "w") as outfile:
    json.dump(json_data_file, outfile, default=to_serializable)

___