In [2]:
import torch
import numpy as np
from sentence_transformers import SentenceTransformer
from onnx_converter.converter import sentence_transformers_onnx

In [3]:
# Restore the SentenceTransformer checkpoint stored under
# results/domain_adaptation_model; the cells below use it for tokenization
# and hand it to the ONNX converter.
model = SentenceTransformer.load(
    './results/domain_adaptation_model'
)

In [4]:
# Export the loaded model through the ONNX converter. The call receives the
# Triton model-repository location as `output_path`, the directory the model
# was loaded from as `config_path`, and a CPU device.
onnx_model = sentence_transformers_onnx(
    model,
    device=torch.device("cpu"),
    config_path="results/domain_adaptation_model",
    output_path="triton/model_repository/domain_adapter/1/model",
)

# Bare last expression: display the object returned by the converter.
onnx_model

  "token_embeddings": torch.Tensor(hidden_state[0]),
  "attention_mask": torch.Tensor(attention_mask),
  if sentence_embedding.shape[0] == 1:


SentenceTransformerModel(
  (encoder): BertModel(
    (embeddings): BertEmbeddings(
      (word_embeddings): Embedding(30522, 384, padding_idx=0)
      (position_embeddings): Embedding(512, 384)
      (token_type_embeddings): Embedding(2, 384)
      (LayerNorm): LayerNorm((384,), eps=1e-12, elementwise_affine=True)
      (dropout): Dropout(p=0.1, inplace=False)
    )
    (encoder): BertEncoder(
      (layer): ModuleList(
        (0-11): 12 x BertLayer(
          (attention): BertAttention(
            (self): BertSelfAttention(
              (query): Linear(in_features=384, out_features=384, bias=True)
              (key): Linear(in_features=384, out_features=384, bias=True)
              (value): Linear(in_features=384, out_features=384, bias=True)
              (dropout): Dropout(p=0.1, inplace=False)
            )
            (output): BertSelfOutput(
              (dense): Linear(in_features=384, out_features=384, bias=True)
              (LayerNorm): LayerNorm((384,), eps=1e-12, e

In [5]:
# Compute embeddings for two textual contents and compute dot product
def embed_with_onnx(text):
    """Embed a single text with the exported ONNX model.

    The text is tokenized by the original ``model`` and the resulting
    ``input_ids``, ``attention_mask`` and ``token_type_ids`` are passed, in
    that order, to ``onnx_model``.

    Args:
        text: The string to embed; it is tokenized as a one-element batch.

    Returns:
        The ``onnx_model`` output converted to a NumPy array.
    """
    tokens = model.tokenize([text])
    # Inference only, so autograd bookkeeping is switched off for the call.
    with torch.no_grad():
        embedding = onnx_model(
            tokens["input_ids"],
            tokens["attention_mask"],
            tokens["token_type_ids"],
        )
    return embedding.detach().numpy()


embedding_1 = embed_with_onnx("Composable Lightweight Processors")
embedding_2 = embed_with_onnx("ocean")

# Same results as Pytorch-based model - so conversion seems accurate
# NOTE(review): the PyTorch reference value is not computed in this cell, so
# the statement above is not checked here - confirm against `model.encode`.
np.dot(embedding_1, embedding_2)

-0.13528061

In [21]:
import numpy as np
from torchvision import transforms
from PIL import Image
import tritonclient.http as httpclient
from tritonclient.utils import triton_to_np_dtype

# Tokenize a one-sentence batch with the original model's tokenizer.
sentence = ["Composable Lightweight Processors"]
inputs = model.tokenize(sentence)

# Convert each tokenizer tensor to NumPy for the Triton client
# (same order as before: ids, token types, attention mask).
input_ids, token_type_ids, attention_mask = (
    inputs[field].numpy()
    for field in ("input_ids", "token_type_ids", "attention_mask")
)

# Show the (batch, tokens) shape of the id matrix as the cell result.
input_ids.shape

(1, 7)

In [22]:
# Setting up client
client = httpclient.InferenceServerClient(url="localhost:8000")


def make_int64_input(name, array):
    """Wrap a NumPy array as an INT64 Triton input tensor.

    Args:
        name: Input tensor name to send to the server.
        array: NumPy array holding the values; it is cast to ``np.int64`` so
            the payload matches the declared ``INT64`` datatype.

    Returns:
        An ``httpclient.InferInput`` with its data already attached.
    """
    infer_input = httpclient.InferInput(name, array.shape, datatype="INT64")
    infer_input.set_data_from_numpy(array.astype(np.int64))
    return infer_input


input_ids_triton = make_int64_input("input_ids", input_ids)
token_type_ids_triton = make_int64_input("token_type_ids", token_type_ids)
attention_mask_triton = make_int64_input("attention_mask", attention_mask)

# NOTE(review): "1770" is the output tensor name requested from the server;
# presumably the auto-generated name of the exported graph's output - confirm
# against the deployed model's configuration.
output = httpclient.InferRequestedOutput("1770")

# Querying the server
results = client.infer(
    model_name="domain_adapter",
    inputs=[input_ids_triton, token_type_ids_triton, attention_mask_triton],
    outputs=[output],
)
results

<tritonclient.http._infer_result.InferResult at 0x73495d3eb070>

In [23]:
# Read the tensor returned under the output name "1770" as a NumPy array and
# display it as the cell result.
inference_output = results.as_numpy("1770")
inference_output

array([ 1.52457105e-02,  2.90012565e-02,  8.03360790e-02,  2.02961098e-02,
       -8.09904709e-02, -4.72616404e-02, -1.10501960e-01, -4.12367210e-02,
       -1.34621616e-02, -4.28459011e-02,  9.64514352e-03, -4.44887765e-02,
       -1.08766332e-02, -2.58614197e-02,  6.40462413e-02, -1.47566214e-01,
        4.42133732e-02,  4.39511351e-02,  6.85257614e-02,  5.24772704e-02,
       -3.27424370e-02, -8.71217400e-02, -3.72922085e-02,  3.93077284e-02,
        3.54806781e-02,  3.22242863e-02,  2.48503555e-02, -4.15298976e-02,
        1.17560692e-01,  2.10655238e-02, -1.58806201e-02, -3.35653722e-02,
       -4.31694724e-02, -1.00515196e-02,  5.43338433e-02,  2.15568524e-02,
        3.07858200e-03, -2.57533323e-02, -4.77282293e-02, -1.08492963e-01,
        1.15535231e-02,  3.36473882e-02,  2.48221010e-02,  6.98642582e-02,
        5.24503961e-02,  6.22404143e-02,  4.54207361e-02,  2.84772459e-02,
       -4.67003733e-02, -3.11974715e-02,  1.38023775e-03,  9.28664878e-02,
        5.40195741e-02, -