In [1]:
import numpy as np
import torch
from flair.data import Sentence
from flair.models.sequence_tagger_utils.bioes import get_spans_from_bio
import tritonclient.grpc as grpcclient

from flair_model_surgery import TritonFastNERTagger

In [2]:
# Endpoint and model identity handed to the Triton gRPC client below.
TRITON_SERVER_URL = "172.25.4.42:8001"
MODEL_NAME = "flair-ner-english-fast"
MODEL_VERSION = "1"

# Use the first CUDA device when PyTorch reports one, otherwise stay on the CPU.
if torch.cuda.is_available():
    DEVICE = "cuda:0"
else:
    DEVICE = "cpu"

In [3]:
class ClientDecoder:
    """Run NER by combining a local tagger with a model served by Triton.

    The local ``tagger`` turns text into input tensors and performs the final
    Viterbi decoding; the tensors in between are computed remotely by the
    Triton model identified by ``model_name`` / ``model_version``.
    """

    def __init__(self, tagger, triton_server_url, model_name, model_version):
        """Connect to Triton and fetch metadata/config for the target model.

        Args:
            tagger: Object exposing ``forward(sentences)`` and a
                ``viterbi_decoder`` attribute with a ``decode`` method.
            triton_server_url: ``host:port`` of the Triton gRPC endpoint.
            model_name: Name of the model deployed on the server.
            model_version: Version of that model to query and run.
        """
        self.tagger = tagger
        self.triton_client = grpcclient.InferenceServerClient(
            url=triton_server_url, verbose=False
        )

        self.model_metadata = self.triton_client.get_model_metadata(
            model_name=model_name, model_version=model_version
        )

        self.model_config = self.triton_client.get_model_config(
            model_name=model_name, model_version=model_version
        ).config
        self.model_name = model_name
        # Remembered so inference targets the same version that was inspected
        # above instead of whatever version the server picks by default.
        self.model_version = model_version

    def submit(self, string, device="cpu"):
        """Tag a single piece of text and return the annotated sentence.

        Args:
            string: Raw text to analyse.
            device: Torch device on which the tensors returned by Triton are
                materialised before decoding.

        Returns:
            The ``Sentence`` built from ``string``, with one ``"ner"`` label
            added for every span produced by ``get_spans_from_bio``.
        """
        sentences = [Sentence(string)]

        # The instance's own tagger prepares the two tensors sent to Triton.
        sorted_lengths, sentence_tensor = self.tagger.forward(sentences)

        lengths_input = grpcclient.InferInput(
            "INPUT__0", sorted_lengths.shape, "INT64"
        )
        lengths_input.set_data_from_numpy(sorted_lengths.cpu().numpy())

        embedding_input = grpcclient.InferInput(
            "INPUT__1", sentence_tensor.shape, "FP32"
        )
        embedding_input.set_data_from_numpy(sentence_tensor.cpu().numpy())

        requested_outputs = [
            grpcclient.InferRequestedOutput(output_name)
            for output_name in ("OUTPUT__0", "OUTPUT__1", "OUTPUT__2")
        ]

        response = self.triton_client.infer(
            self.model_name,
            [lengths_input, embedding_input],
            model_version=self.model_version,
            outputs=requested_outputs,
        )

        # Honour the caller-supplied device rather than a notebook global.
        features = torch.tensor(response.as_numpy("OUTPUT__0"), device=device)
        output_lengths = torch.tensor(response.as_numpy("OUTPUT__1"), device=device)
        transitions = torch.tensor(response.as_numpy("OUTPUT__2"), device=device)

        # Decode with the tagger held by this instance. The second return
        # value is not needed here.
        # NOTE(review): the positional ``True`` is kept from the original call;
        # presumably it requests per-class probabilities -- confirm.
        predictions, _ = self.tagger.viterbi_decoder.decode(
            (features, output_lengths, transitions), True, sentences
        )

        for sentence, sentence_predictions in zip(sentences, predictions):
            sentence_tags = [label[0] for label in sentence_predictions]
            sentence_scores = [label[1] for label in sentence_predictions]
            for predicted_span in get_spans_from_bio(sentence_tags, sentence_scores):
                # predicted_span[0] holds token indices; slice them inclusively.
                first_token = predicted_span[0][0]
                last_token = predicted_span[0][-1]
                span = sentence[first_token : last_token + 1]
                span.add_label("ner", value=predicted_span[2], score=predicted_span[1])

        return sentences[0]

In [4]:
# Restore the serialized tagger onto DEVICE.
# NOTE(review): torch.load unpickles arbitrary objects -- only load a
# "tagger.bin" that comes from a trusted source.
tagger = torch.load(
    "tagger.bin",
    map_location=torch.device(DEVICE),
)

# Wire the local tagger to the remote Triton model.
client_decoder = ClientDecoder(
    tagger=tagger,
    triton_server_url=TRITON_SERVER_URL,
    model_name=MODEL_NAME,
    model_version=MODEL_VERSION,
)

In [5]:
# Sample English sentences used as input for the NER demo; each entry is
# submitted to the client decoder as a separate request.
string_list = [
    "With the belief that the PC one day would become a consumer device for enjoying games and multimedia, NVIDIA is founded by Jensen Huang, Chris Malachowsky and Curtis Priem.",
    "At the time, there were more than two dozen graphics chips companies, a number that would soar to 70 three years later.",
    "By 2006, NVIDIA was the only independent still operating.",
    "NVIDIA's reaches its first strategic partnership with SGS-Thomson Microelectronics to manufacture the company's single-chip graphical-user interface accelerator.",
    "Diamond Multimedia Systems is selected to install the chips in multimedia accelerator boards.",
    "Sega, the leader in arcade games, ports Virtual Fighter to be the first 3D game to run on NVIDIA graphics.",
    "NVIDIA unveils its first Microsoft DirectX drivers with support for Direct3D, an API used to render 3D graphics where performance is critical.",
    "The company introduces RIVA 128, the world's first 128-bit 3D processor.",
    "It receives OEM acceptance, and more than one million units are shipped within its first four months.",
]

In [6]:
# Tag every sample text, one request per string, keeping the input order.
sentences = [
    client_decoder.submit(text, DEVICE)
    for text in string_list
]

In [7]:
# Map each tagged sentence's text to a list of plain-dict entity records.
# Scores are scaled to a percentage and truncated to an integer by int().
dict_format = {
    tagged.text: [
        {
            "entity_group": span.tag,
            "start": span.start_position,
            "word": span.text,
            "end": span.end_position,
            "score": int(span.score * 100),
        }
        for span in tagged.get_spans("ner")
    ]
    for tagged in sentences
}

dict_format

{'With the belief that the PC one day would become a consumer device for enjoying games and multimedia , NVIDIA is founded by Jensen Huang , Chris Malachowsky and Curtis Priem .': [{'entity_group': 'ORG',
   'start': 25,
   'word': 'PC',
   'end': 27,
   'score': 60},
  {'entity_group': 'ORG',
   'start': 102,
   'word': 'NVIDIA',
   'end': 108,
   'score': 98},
  {'entity_group': 'PER',
   'start': 123,
   'word': 'Jensen Huang',
   'end': 135,
   'score': 99},
  {'entity_group': 'PER',
   'start': 137,
   'word': 'Chris Malachowsky',
   'end': 154,
   'score': 99},
  {'entity_group': 'PER',
   'start': 159,
   'word': 'Curtis Priem',
   'end': 171,
   'score': 99}],
 'At the time , there were more than two dozen graphics chips companies , a number that would soar to 70 three years later .': [],
 'By 2006 , NVIDIA was the only independent still operating .': [{'entity_group': 'ORG',
   'start': 9,
   'word': 'NVIDIA',
   'end': 15,
   'score': 99}],
 "NVIDIA 's reaches its first strat