In [2]:
import torch
import numpy as np
from sentence_transformers import SentenceTransformer
from onnx_converter.converter import sentence_transformers_onnx

In [3]:
# load the model
model = SentenceTransformer.load('./results/domain_adaptation_model')

In [4]:
# Transform the model to use onnx format
onnx_model = sentence_transformers_onnx(
    model,
    output_path="triton/model_repository/domain_adapter/1/model",
    config_path="results/domain_adaptation_model",
    device=torch.device("cpu")
)
onnx_model

  "token_embeddings": torch.Tensor(hidden_state[0]),
  "attention_mask": torch.Tensor(attention_mask),
  if sentence_embedding.shape[0] == 1:


SentenceTransformerModel(
  (encoder): BertModel(
    (embeddings): BertEmbeddings(
      (word_embeddings): Embedding(30522, 384, padding_idx=0)
      (position_embeddings): Embedding(512, 384)
      (token_type_embeddings): Embedding(2, 384)
      (LayerNorm): LayerNorm((384,), eps=1e-12, elementwise_affine=True)
      (dropout): Dropout(p=0.1, inplace=False)
    )
    (encoder): BertEncoder(
      (layer): ModuleList(
        (0-11): 12 x BertLayer(
          (attention): BertAttention(
            (self): BertSelfAttention(
              (query): Linear(in_features=384, out_features=384, bias=True)
              (key): Linear(in_features=384, out_features=384, bias=True)
              (value): Linear(in_features=384, out_features=384, bias=True)
              (dropout): Dropout(p=0.1, inplace=False)
            )
            (output): BertSelfOutput(
              (dense): Linear(in_features=384, out_features=384, bias=True)
              (LayerNorm): LayerNorm((384,), eps=1e-12, e

In [5]:
# Compute embeddings for two textual contents and compute dot product
with torch.no_grad():
    tokens_1 = model.tokenize(["Composable Lightweight Processors"])
    embedding_1 = onnx_model(tokens_1["input_ids"], tokens_1["attention_mask"], tokens_1["token_type_ids"]).detach().numpy()
    
    tokens_2 = model.tokenize(["ocean"])
    embedding_2 = onnx_model(tokens_2["input_ids"], tokens_2["attention_mask"], tokens_2["token_type_ids"]).detach().numpy()

# Same results as Pytorch-based model - so conversion seems accurate
np.dot(embedding_1, embedding_2)

-0.13528061

In [20]:
import numpy as np
from torchvision import transforms
from PIL import Image
import tritonclient.http as httpclient
from tritonclient.utils import triton_to_np_dtype

# tokenize sentence
sentence=["Composable Lightweight Processors"]
inputs = model.tokenize(sentence)

input_ids = inputs['input_ids'].numpy()
token_type_ids = inputs['token_type_ids'].numpy()
attention_mask = inputs['attention_mask'].numpy()
input_ids.shape

(1, 7)

In [8]:
# Setting up client
client = httpclient.InferenceServerClient(url="localhost:8000")

input_ids_triton = httpclient.InferInput("input_ids", input_ids.shape, datatype="INT64")
input_ids_triton.set_data_from_numpy(input_ids.astype(np.int64))

token_type_ids_triton = httpclient.InferInput("token_type_ids", token_type_ids.shape, datatype="INT64")
token_type_ids_triton.set_data_from_numpy(token_type_ids.astype(np.int64))

attention_mask_triton = httpclient.InferInput("attention_mask", attention_mask.shape, datatype="INT64")
attention_mask_triton.set_data_from_numpy(attention_mask.astype(np.int64))

output = httpclient.InferRequestedOutput("1770")

# Querying the server
results = client.infer(model_name="domain_adapter", inputs=[input_ids_triton, token_type_ids_triton, attention_mask_triton], outputs=[output])
results

<tritonclient.http._infer_result.InferResult at 0x73496ba481c0>

In [9]:
inference_output = results.as_numpy('1770')
inference_output

array([ 0.0256958 ,  0.01481628,  0.07751465,  0.01100922, -0.08972168,
       -0.03074646, -0.09466553, -0.03234863,  0.00525665, -0.01855469,
        0.01269531, -0.04891968, -0.02073669,  0.00374222,  0.09515381,
       -0.15673828,  0.01867676,  0.04284668,  0.08343506,  0.07293701,
       -0.05923462, -0.07788086, -0.04202271,  0.0297699 ,  0.05596924,
        0.03842163,  0.01280975, -0.07324219,  0.10656738,  0.03051758,
        0.00310135, -0.04406738, -0.0335083 , -0.01245117,  0.03692627,
        0.03503418, -0.01293945, -0.03292847, -0.04998779, -0.10693359,
        0.01838684,  0.00383759,  0.03805542,  0.05883789,  0.06506348,
        0.04418945,  0.05438232,  0.02909851, -0.04785156, -0.03979492,
        0.00611115,  0.07922363,  0.09942627, -0.01474762, -0.00510788,
       -0.0413208 ,  0.08270264, -0.01644897,  0.000669  , -0.0647583 ,
        0.10211182, -0.06185913, -0.0216217 , -0.01271057, -0.01105499,
       -0.05075073,  0.04049683, -0.01849365, -0.03598022,  0.02