In [1]:
import json
import os
import queue
import random
import time
from concurrent.futures import ThreadPoolExecutor, as_completed
from functools import partial

import numpy as np
import torch
import tritonclient.grpc as grpcclient
from flair.data import Sentence
from flair.models.sequence_tagger_utils.bioes import get_spans_from_bio
from tqdm.auto import tqdm

In [2]:
def string_to_array(string, encoding="utf-8", batch=False):
    """Encode ``string`` into an int64 array holding one element per encoded byte.

    Args:
        string: Text to encode.
        encoding: Codec passed to ``bytes(string, encoding)``.
        batch: If True, return shape ``(1, n_bytes)`` (a leading batch axis);
            otherwise return a 1-D array of shape ``(n_bytes,)``.

    Returns:
        np.ndarray of dtype int64 with the byte values (0-255).
    """
    # Pin the dtype: np.asarray(list_of_ints) infers int32 on some platforms
    # (e.g. Windows with NumPy < 2) and float64 for an empty list, while the
    # Triton input this feeds is declared as "INT64".
    byte_values = np.asarray(list(bytes(string, encoding)), dtype=np.int64)
    return byte_values.reshape(1, -1) if batch else byte_values


def bytes_to_string(byte_list, encoding="utf-8"):
    """Decode an array (or list) of byte values back into text.

    This is the inverse of ``string_to_array``. The input is flattened first,
    so both the 1-D and the ``batch=True`` ``(1, n)`` layouts decode.

    Args:
        byte_list: Array-like of integers in 0-255.
        encoding: Codec used to decode the bytes; defaults to UTF-8, matching
            ``string_to_array``.

    Returns:
        The decoded ``str``.
    """
    return bytes(np.ravel(byte_list).tolist()).decode(encoding)


class ClientDecoder:
    """Synchronous Triton gRPC client: one blocking ``infer`` request per sentence.

    The constructor contacts the server immediately (model metadata and config
    requests).
    """

    def __init__(self, triton_server_url, model_name, model_version):
        """Connect to ``triton_server_url`` and fetch metadata/config for the model.

        Args:
            triton_server_url: ``host:port`` of the Triton gRPC endpoint.
            model_name: Name of the deployed model.
            model_version: Version used for metadata, config and inference.
        """
        self.triton_client = grpcclient.InferenceServerClient(
            url=triton_server_url, verbose=False
        )

        self.model_metadata = self.triton_client.get_model_metadata(
            model_name=model_name, model_version=model_version
        )

        self.model_config = self.triton_client.get_model_config(
            model_name=model_name, model_version=model_version
        ).config
        self.model_name = model_name
        # Stored so inference targets the same version the metadata/config
        # were fetched for (previously infer() used the server's default).
        self.model_version = model_version

    def submit(self, sentence_bytes, device="cpu"):
        """Run inference on one encoded sentence and return the decoded result.

        Args:
            sentence_bytes: Array of byte values (e.g. from ``string_to_array``).
                Its shape is sent as-is; it should be int64 to match the
                declared "INT64" input datatype.
            device: Unused; kept for backward compatibility. The response is
                decoded on the host, so no device placement is needed.

        Returns:
            The Python object produced by evaluating the server's
            "tagged_sentences" output as text.
        """
        inputs = [
            grpcclient.InferInput("sentence_bytes", sentence_bytes.shape, "INT64"),
        ]
        inputs[0].set_data_from_numpy(sentence_bytes)

        outputs = [grpcclient.InferRequestedOutput("tagged_sentences")]

        response = self.triton_client.infer(
            self.model_name,
            inputs,
            model_version=self.model_version,
            outputs=outputs,
        )
        return self._decode_output(response.as_numpy("tagged_sentences"))

    @staticmethod
    def _decode_output(byte_values):
        """Convert the server's byte-valued output array into a Python object."""
        # Decode straight from numpy. The former torch.tensor(..., device=DEVICE)
        # hop copied the data (to the GPU when available) only to iterate it
        # element by element in bytes(), and relied on a global from a later cell.
        text = bytes(np.ravel(byte_values).tolist()).decode()
        # NOTE(review): eval() executes whatever the server sends back. If the
        # payload is always a plain Python literal, ast.literal_eval is safer.
        return eval(text)


class ClientDecoderAsync:
    """Triton gRPC client that fires one ``async_infer`` per sentence and waits for all.

    The constructor contacts the server immediately (model metadata and config
    requests).
    """

    def __init__(self, triton_server_url, model_name, model_version):
        """Connect to ``triton_server_url`` and fetch metadata/config for the model.

        Args:
            triton_server_url: ``host:port`` of the Triton gRPC endpoint.
            model_name: Name of the deployed model.
            model_version: Version used for metadata, config and inference.
        """
        self.triton_client = grpcclient.InferenceServerClient(
            url=triton_server_url, verbose=False
        )

        self.model_metadata = self.triton_client.get_model_metadata(
            model_name=model_name, model_version=model_version
        )

        self.model_config = self.triton_client.get_model_config(
            model_name=model_name, model_version=model_version
        ).config
        self.model_name = model_name
        # Stored so inference targets the same version the metadata/config
        # were fetched for (previously async_infer() used the server's default).
        self.model_version = model_version

    def callback(self, user_data, result, error):
        """Append ``error`` if it is truthy, otherwise ``result``, to ``user_data``.

        Not used by ``submit`` (which needs per-request indices); kept so
        existing callers keep working.
        """
        user_data.append(error if error else result)

    @staticmethod
    def _on_complete(done, index, result, error):
        """Per-request callback: hand ``(index, result, error)`` to ``submit``."""
        done.put((index, result, error))

    @staticmethod
    def _decode_output(byte_values):
        """Convert the server's byte-valued output array into a Python object."""
        text = bytes(np.ravel(byte_values).tolist()).decode()
        # NOTE(review): eval() executes whatever the server sends back. If the
        # payload is always a plain Python literal, ast.literal_eval is safer.
        return eval(text)

    def submit(self, sentence_bytes_list, device="cpu"):
        """Run inference on every encoded sentence concurrently.

        Args:
            sentence_bytes_list: Iterable of byte-value arrays (e.g. from
                ``string_to_array``); each should be int64 to match the
                declared "INT64" input datatype.
            device: Unused; kept for backward compatibility. Responses are
                decoded on the host, so no device placement is needed.

        Returns:
            A list of decoded results, one per input, in the same order as
            ``sentence_bytes_list``.

        Raises:
            The first error delivered to a request callback, raised only after
            every request has reported back.
        """
        triton_inputs = []
        for sentence_bytes in sentence_bytes_list:
            triton_input = grpcclient.InferInput(
                "sentence_bytes", sentence_bytes.shape, "INT64"
            )
            triton_input.set_data_from_numpy(sentence_bytes)
            triton_inputs.append([triton_input])

        outputs = [grpcclient.InferRequestedOutput("tagged_sentences")]

        # Callbacks may fire from another thread and complete in any order, so
        # each one reports its request index through a thread-safe queue.
        # Appending results on arrival would return them out of input order.
        done = queue.Queue()
        for index, triton_input in enumerate(triton_inputs):
            self.triton_client.async_infer(
                model_name=self.model_name,
                inputs=triton_input,
                callback=partial(self._on_complete, done, index),
                model_version=self.model_version,
                outputs=outputs,
            )

        responses = [None] * len(triton_inputs)
        errors = []
        for _ in range(len(triton_inputs)):
            index, result, error = done.get()  # blocks until a callback fires
            if error:
                errors.append(error)
            else:
                responses[index] = result

        if errors:
            raise errors[0]

        return [
            self._decode_output(response.as_numpy("tagged_sentences"))
            for response in responses
        ]

In [12]:
# Benchmark configuration. The server address can be overridden through the
# TRITON_SERVER_URL environment variable instead of editing the notebook.
TRITON_SERVER_URL = os.environ.get("TRITON_SERVER_URL", "172.25.4.42:8001")
MODEL_NAME = "flair-ner-english-fast-ensemble"
MODEL_VERSION = "1"
DEVICE = "cuda:0" if torch.cuda.is_available() else "cpu"
MULTIPLIER = 128  # number of times the sample sentences are repeated
SHUFFLE_SEED = 0  # fixed so every run benchmarks the same request order

# Explicit UTF-8 (the strings are encoded as UTF-8 by string_to_array below),
# and a context manager so the file handle is closed deterministically.
with open("strings_list.txt", "r", encoding="utf-8") as sample_file:
    SAMPLE_TEXTS = sample_file.read()

# Drop empty lines (e.g. a trailing newline) before repeating; this yields the
# same list as repeating first and filtering afterwards.
STRING_LIST = [
    sentence for sentence in SAMPLE_TEXTS.split("\n") if sentence
] * MULTIPLIER

sentence_bytes = [string_to_array(string, batch=False) for string in STRING_LIST]
random.Random(SHUFFLE_SEED).shuffle(sentence_bytes)

In [4]:
# Synchronous client: each submit() call issues one blocking infer request.
client_decoder = ClientDecoder(
    triton_server_url=TRITON_SERVER_URL,
    model_name=MODEL_NAME,
    model_version=MODEL_VERSION,
)

In [5]:
# Async client: submit() fans out one async_infer request per sentence.
client_decoder_async = ClientDecoderAsync(
    triton_server_url=TRITON_SERVER_URL,
    model_name=MODEL_NAME,
    model_version=MODEL_VERSION,
)

In [9]:
# Sequential baseline: one blocking request per sentence.
# time.perf_counter is monotonic and high-resolution, so it is the right clock
# for elapsed-time measurement (time.time can jump with wall-clock changes).
start = time.perf_counter()
responses = [
    client_decoder.submit(sentence_byte, DEVICE) for sentence_byte in sentence_bytes
]
runtime = time.perf_counter() - start
print(runtime)

23.843220710754395


In [10]:
# Concurrent run over the same inputs; compare with the sequential runtime above.
# time.perf_counter is monotonic and high-resolution, unlike time.time.
start = time.perf_counter()
async_responses = client_decoder_async.submit(sentence_bytes)
runtime = time.perf_counter() - start
print(runtime)

12.779734373092651
