In [1]:
import time
import pprint
import multiprocessing
from pathlib import Path

import onnx
import torch
import transformers

import numpy as np
import onnxruntime as rt

from sentence_transformers import SentenceTransformer
from transformers import convert_graph_to_onnx

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
span = "first sentence"


model_id = "sentence-transformers/distiluse-base-multilingual-cased-v2"
model_raw = SentenceTransformer(model_id)

In [3]:
model_pipeline = transformers.FeatureExtractionPipeline(
    model=transformers.AutoModel.from_pretrained(model_id),
    tokenizer=transformers.AutoTokenizer.from_pretrained(model_id, use_fast=True),
    framework="pt",
    device=-1
)

config = model_pipeline.model.config
tokenizer = model_pipeline.tokenizer

with torch.no_grad():
    input_names, output_names, dynamic_axes, tokens = convert_graph_to_onnx.infer_shapes(
        model_pipeline, 
        "pt"
    )
    ordered_input_names, model_args = convert_graph_to_onnx.ensure_valid_input(
        model_pipeline.model, tokens, input_names
    )

Found input input_ids with shape: {0: 'batch', 1: 'sequence'}
Found input attention_mask with shape: {0: 'batch', 1: 'sequence'}
Found output output_0 with shape: {0: 'batch', 1: 'sequence'}
Found output output_1 with shape: {0: 'batch', 1: 'sequence'}
Found output output_2 with shape: {0: 'batch', 1: 'sequence'}
Found output output_3 with shape: {0: 'batch', 1: 'sequence'}
Found output output_4 with shape: {0: 'batch', 1: 'sequence'}
Found output output_5 with shape: {0: 'batch', 1: 'sequence'}
Found output output_6 with shape: {0: 'batch', 1: 'sequence'}
Found output output_7 with shape: {0: 'batch', 1: 'sequence'}
Ensuring inputs are in correct order
head_mask is not present in the generated input list.
Generated inputs order: ['input_ids', 'attention_mask']


In [4]:
print(input_names)
print(output_names)
print(dynamic_axes)
print(tokens)
print(ordered_input_names)
print(model_args)

['input_ids', 'attention_mask']
['output_0', 'output_1', 'output_2', 'output_3', 'output_4', 'output_5', 'output_6', 'output_7']
{'input_ids': {0: 'batch', 1: 'sequence'}, 'attention_mask': {0: 'batch', 1: 'sequence'}, 'output_0': {0: 'batch', 1: 'sequence'}, 'output_1': {0: 'batch', 1: 'sequence'}, 'output_2': {0: 'batch', 1: 'sequence'}, 'output_3': {0: 'batch', 1: 'sequence'}, 'output_4': {0: 'batch', 1: 'sequence'}, 'output_5': {0: 'batch', 1: 'sequence'}, 'output_6': {0: 'batch', 1: 'sequence'}, 'output_7': {0: 'batch', 1: 'sequence'}}
{'input_ids': tensor([[  101, 10747, 10124,   169, 45700, 37131,   102]]), 'attention_mask': tensor([[1, 1, 1, 1, 1, 1, 1]])}
['input_ids', 'attention_mask']
(tensor([[  101, 10747, 10124,   169, 45700, 37131,   102]]), tensor([[1, 1, 1, 1, 1, 1, 1]]))


In [5]:
for i in range(8):
    del dynamic_axes[f"output_{i}"] # Delete unused output

output_names = ["sentence_embedding"]
dynamic_axes["sentence_embedding"] = {0: 'batch'}

# Check that everything worked
print(output_names)
print(dynamic_axes)

['sentence_embedding']
{'input_ids': {0: 'batch', 1: 'sequence'}, 'attention_mask': {0: 'batch', 1: 'sequence'}, 'sentence_embedding': {0: 'batch'}}


In [6]:
import torch
from sentence_transformers.models import Dense

class SentenceTransformer(transformers.DistilBertModel):
    def __init__(self, config):
        super().__init__(config)
        # Naming alias for ONNX output specification
        # Makes it easier to identify the layer
        self.sentence_embedding = torch.nn.Identity()

    def forward(self, input_ids, attention_mask):
        # Get the token embeddings from the base model
        token_embeddings = super().forward(
            input_ids, 
            attention_mask=attention_mask, 
        )[0]
        # Stack the pooling layer on top of it
        input_mask_expanded = attention_mask.unsqueeze(-1).expand(token_embeddings.size())
        sum_embeddings = torch.sum(token_embeddings * input_mask_expanded, 1)
        sum_mask = torch.clamp(input_mask_expanded.sum(1), min=1e-9)
        # Add dense layer
        intermediate_embeddings = self.sentence_embedding(sum_embeddings / sum_mask)
        
        feature_out = {'sentence_embedding': torch.FloatTensor(intermediate_embeddings)}
        dense_layer = Dense(768, 512, bias=True, activation_function=torch.nn.modules.activation.Tanh())
        dense_layer.forward(feature_out)
        return feature_out['sentence_embedding']

# Create the new model based on the config of the original pipeline
model = SentenceTransformer(config=config).from_pretrained(model_id)

In [7]:
np.testing.assert_allclose(
    model_raw.encode(span),
    model(**tokenizer(span, return_tensors="pt")).squeeze().detach().numpy(),
    atol=1e-6,
)

AssertionError: 
Not equal to tolerance rtol=1e-07, atol=1e-06

Mismatched elements: 512 / 512 (100%)
Max absolute difference: 0.19937615
Max relative difference: 353.38086
 x: array([ 2.924567e-02,  6.141232e-02, -4.720753e-02,  7.542608e-02,
       -1.127941e-02, -2.926223e-02, -9.203118e-04,  7.731913e-03,
        9.389930e-03, -5.170759e-02,  1.561492e-02, -1.805862e-02,...
 y: array([ 3.328726e-02, -1.804699e-02,  1.051623e-02, -6.104988e-02,
        1.422415e-03, -3.669605e-03, -1.427884e-02,  4.272969e-03,
       -4.546718e-02, -4.240783e-02,  1.901109e-02, -1.150605e-01,...