In [1]:
import numpy as np
import torch
import tritonclient.grpc as grpcclient
from transformers import BertTokenizer

In [2]:
TRITON_SERVER_URL = "172.25.4.42:8001"
MODEL_NAME = "bert-base-uncased"
MODEL_VERSION = "1"

In [3]:
tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")

In [4]:
triton_client = grpcclient.InferenceServerClient(url=TRITON_SERVER_URL, verbose=False)

model_metadata = triton_client.get_model_metadata(
    model_name=MODEL_NAME, model_version=MODEL_VERSION
)

model_config = triton_client.get_model_config(
    model_name=MODEL_NAME, model_version=MODEL_VERSION
).config

In [5]:
text_0 = "Who are the founders of NVIDIA?"
text_1 = "NVIDIA is founded by Jensen Huang, Chris Malachowsky and Curtis Priem."

tokenized_tensor_0 = tokenizer(text_0, add_special_tokens=True, return_tensors="pt")
tokenized_tensor_1 = tokenizer(text_1, add_special_tokens=True, return_tensors="pt")
tokens_tensor = torch.concat(
    (tokenized_tensor_0["input_ids"], tokenized_tensor_1["input_ids"]), axis=1
)
segments_tensors = torch.concat(
    (tokenized_tensor_0["token_type_ids"], tokenized_tensor_1["attention_mask"]), axis=1
)

inputs = [
    grpcclient.InferInput("INPUT__0", tokens_tensor.shape, "INT64"),
    grpcclient.InferInput("INPUT__1", segments_tensors.shape, "INT64"),
]
inputs[0].set_data_from_numpy(tokens_tensor.numpy())
inputs[1].set_data_from_numpy(segments_tensors.numpy())

outputs = [grpcclient.InferRequestedOutput("OUTPUT__0")]

response = triton_client.infer(MODEL_NAME, inputs, outputs=outputs)
response.as_numpy("OUTPUT__0")

array([[[-0.08405367,  0.3233205 , -0.11031957, ..., -0.11474957,
          0.22691588,  0.30055675],
        [-0.6977316 , -0.56306213,  0.3189223 , ...,  0.59553444,
          0.318297  ,  0.37281016],
        [-0.09979279, -0.2772794 ,  0.00791831, ...,  0.2773895 ,
          0.8962006 ,  0.22550069],
        ...,
        [-0.10681662, -0.2938968 , -0.5867815 , ...,  0.93991256,
          0.2358838 , -0.25276625],
        [ 0.7438755 , -0.01677273, -0.3546151 , ...,  0.16310963,
         -0.52852833, -0.32508034],
        [ 0.56116414,  0.17081466, -0.17988236, ...,  0.22354889,
         -0.5551076 , -0.30939323]]], dtype=float32)