In [1]:
import numpy as np
import torch
from flair.data import Sentence
from flair.models.sequence_tagger_utils.bioes import get_spans_from_bio
import tritonclient.grpc as grpcclient
from tqdm.auto import tqdm
import json
from concurrent.futures import ThreadPoolExecutor, as_completed

In [2]:
def string_to_array(string, encoding="utf-8"):
    """Encode ``string`` and return its byte values as a 1-D int64 array.

    Args:
        string: Text to encode.
        encoding: Codec used to turn the text into bytes.

    Returns:
        ``np.ndarray`` of dtype int64 holding one element per encoded byte.
        An empty string yields an empty int64 array.
    """
    # The dtype is pinned because numpy would otherwise infer float64 for an
    # empty list (and a narrower int on platforms whose default int is 32-bit),
    # which does not match the "INT64" input declared in ClientDecoder.submit.
    return np.asarray(list(bytes(string, encoding)), dtype=np.int64)


def bytes_to_string(byte_list, encoding="utf-8"):
    """Decode an array of byte values back into text.

    Counterpart of ``string_to_array``; pass the same ``encoding`` that was
    used to produce the array.

    Args:
        byte_list: Array-like exposing ``.tolist()`` whose elements are
            integers in ``range(256)``.
        encoding: Codec used to decode the bytes.

    Returns:
        The decoded ``str``.
    """
    return bytes(byte_list.tolist()).decode(encoding)


class ClientDecoder:
    """gRPC client that sends encoded sentences to a Triton model.

    On construction the client connects to the server, fetches the model's
    metadata and config for the requested version, and loads a serialized
    Viterbi decoder from disk.
    """

    # Default on-disk location of the serialized Viterbi decoder.
    DEFAULT_VITERBI_DECODER_PATH = (
        "/workspace/triton-models/flair-ner-english-fast-tokenization/1/viterbi_decoder.bin"
    )

    def __init__(
        self,
        triton_server_url,
        model_name,
        model_version,
        viterbi_decoder_path=DEFAULT_VITERBI_DECODER_PATH,
        map_location=None,
    ):
        """Connect to Triton and load the local decoder artifact.

        Args:
            triton_server_url: ``host:port`` of the Triton gRPC endpoint.
            model_name: Name of the model to query.
            model_version: Version used for the metadata/config lookups and
                for every inference request.
            viterbi_decoder_path: File passed to ``torch.load``.
            map_location: Forwarded to ``torch.load``; ``None`` keeps the
                devices recorded in the file, ``"cpu"`` allows loading on a
                host without the original accelerator.
        """
        self.triton_client = grpcclient.InferenceServerClient(
            url=triton_server_url, verbose=False
        )

        self.model_metadata = self.triton_client.get_model_metadata(
            model_name=model_name, model_version=model_version
        )

        self.model_config = self.triton_client.get_model_config(
            model_name=model_name, model_version=model_version
        ).config
        self.model_name = model_name
        # Remembered so inference targets the same version that was inspected.
        self.model_version = model_version
        # NOTE: torch.load unpickles the file; only load artifacts from a
        # trusted location.
        self.viterbi_decoder = torch.load(
            viterbi_decoder_path, map_location=map_location
        )

    def submit(self, sentence_bytes, device="cpu"):
        """Run one inference request and return the decoded result.

        Args:
            sentence_bytes: Array of byte values for one sentence (as built
                by ``string_to_array``); sent as the ``sentence_bytes`` input.
            device: Accepted for backward compatibility and ignored. The
                response is decoded on the host, so no tensor is created.

        Returns:
            Whatever Python object the ``tagged_sentences`` payload evaluates
            to (its structure is defined by the server-side model).
        """
        # The input is declared as INT64, so make sure the data really is.
        sentence_bytes = np.asarray(sentence_bytes, dtype=np.int64)

        inputs = [
            grpcclient.InferInput("sentence_bytes", sentence_bytes.shape, "INT64"),
        ]
        inputs[0].set_data_from_numpy(sentence_bytes)

        outputs = [grpcclient.InferRequestedOutput("tagged_sentences")]

        response = self.triton_client.infer(
            self.model_name,
            inputs,
            model_version=self.model_version,
            outputs=outputs,
        )

        # Decode straight from the numpy output; going through a torch tensor
        # (possibly on a GPU) only to build a bytes object adds cost and made
        # this method depend on a notebook-level global.
        raw_output = response.as_numpy("tagged_sentences")
        payload = bytes(raw_output.ravel().tolist()).decode()

        # SECURITY: eval() executes whatever text the server returns. This is
        # only acceptable against a trusted server. If the payload is always a
        # plain Python literal (TODO confirm), switch to ast.literal_eval.
        return eval(payload)

In [3]:
# Connection and model settings for the Triton server under test.
TRITON_SERVER_URL = "172.25.4.42:8001"
MODEL_NAME = "flair-ner-english-fast-ensemble"
MODEL_VERSION = "1"
DEVICE = "cuda:0" if torch.cuda.is_available() else "cpu"

# How many times the sample sentences are repeated to build the workload.
MULTIPLIER = 128

# Read the samples with an explicit encoding and close the file afterwards.
with open("strings_list.txt", "r", encoding="utf-8") as sample_file:
    SAMPLE_TEXTS = sample_file.read()

# One sentence per line. Blank lines (including the empty entry produced by a
# trailing newline) are skipped so no empty request is sent to the server.
STRING_LIST = [line for line in SAMPLE_TEXTS.splitlines() if line.strip()] * MULTIPLIER

In [4]:
# Directory holding the serialized tokenization-model artifacts.
MODEL_DIR = "/workspace/triton-models/flair-ner-english-fast-tokenization/1"

# Encode the sentences first: this is purely local work, so it is still
# available if connecting to the server below fails.
sentence_bytes = [string_to_array(string) for string in STRING_LIST]
# Kept as an alias for code that still refers to the old name.
requests = sentence_bytes

# NOTE: torch.load unpickles the files; only load artifacts from a trusted
# location.
embeddings = torch.load(
    MODEL_DIR + "/embeddings.bin",
    map_location=torch.device(DEVICE),
)

viterbi_decoder = torch.load(
    MODEL_DIR + "/viterbi_decoder.bin",
    map_location=torch.device(DEVICE),
)

# Network step last: raises InferenceServerException if the server is down.
client_decoder = ClientDecoder(TRITON_SERVER_URL, MODEL_NAME, MODEL_VERSION)

InferenceServerException: [StatusCode.UNAVAILABLE] failed to connect to all addresses

In [5]:
# Encode every sample sentence into its array of byte values, then show the
# first one as a quick check.
sentence_bytes = list(map(string_to_array, STRING_LIST))
sentence_bytes[0]

array([ 78,  86,  73,  68,  73,  65,  32, 105, 115,  32, 102, 111, 117,
       110, 100, 101, 100,  32,  98, 121,  32,  74, 101, 110, 115, 101,
       110,  32,  72, 117,  97, 110, 103,  44,  32,  67, 104, 114, 105,
       115,  32,  77,  97, 108,  97,  99, 104, 111, 119, 115, 107, 121,
        32,  97, 110, 100,  32,  67, 117, 114, 116, 105, 115,  32,  80,
       114, 105, 101, 109,  46])

In [None]:
est_total = len(sentence_bytes)

# Results are stored by input position so responses[i] always belongs to
# sentence_bytes[i], regardless of the order in which requests finish.
responses = [None] * est_total

with tqdm(
    total=est_total,
    desc="Submitting sentences to {} at {}".format(MODEL_NAME, TRITON_SERVER_URL),
) as pbar, ThreadPoolExecutor() as executor:
    # Queue every request before waiting on any of them; submitting and
    # awaiting one future at a time would serialize the whole run.
    # NOTE(review): this shares one gRPC client across worker threads;
    # confirm against the tritonclient docs that infer() may be called
    # concurrently on a single client.
    future_to_index = {
        executor.submit(client_decoder.submit, sentence_byte, DEVICE): index
        for index, sentence_byte in enumerate(sentence_bytes)
    }

    for future in as_completed(future_to_index):
        # result() re-raises any exception from the worker thread.
        responses[future_to_index[future]] = future.result()
        pbar.update()

In [None]:
# Display the first collected inference result as a quick sanity check.
responses[0]