# Experiment Notebook
Load the exported `.onnx` model and verify its embeddings without the ML-Commons API. This determines whether the problem lies in the ML-Commons API or in the `.onnx` file itself.

In [None]:
import os
import sys

# Make the repository root (three directories above this notebook) importable.
# NOTE(review): `append` places it *after* site-packages, so an installed
# opensearch_py_ml would still shadow the local checkout. Use
# sys.path.insert(0, ...) if the local copy is meant to take precedence.
_repo_root = os.path.abspath(os.path.join("..", "..", ".."))
sys.path.append(_repo_root)

In [None]:
import warnings
warnings.filterwarnings('ignore', category=DeprecationWarning)
warnings.filterwarnings('ignore', category=FutureWarning)
warnings.filterwarnings("ignore", message="Unverified HTTPS request")
warnings.filterwarnings("ignore", message="TracerWarning: torch.tensor")
warnings.filterwarnings("ignore", message="using SSL with verify_certs=False is insecure.")

import opensearch_py_ml as oml
from opensearchpy import OpenSearch
from opensearch_py_ml.ml_models import SentenceTransformerModel
# import mlcommon to later register the model to OpenSearch Cluster
from opensearch_py_ml.ml_commons import MLCommonClient

In [None]:
# Endpoint of the local OpenSearch cluster (HTTPS on the default port 9200).
CLUSTER_URL = "https://localhost:9200"

## Trace the Model to ONNX Using `save_as_onnx`
See `opensearch_py_ml/ml_models/sentencetransformermodel.py`

In [None]:
# Below is the function we use in save_as_onnx. Only the path computation is
# executed here, so the exported file can be opened directly in later cells.
# The conversion itself stays commented out because the next cell runs
# save_as_onnx.
import os

# from transformers.convert_graph_to_onnx import convert
# from pathlib import Path

# model_id and model_name must be defined: this cell and the following ones use them.
model_id = "sentence-transformers/distiluse-base-multilingual-cased-v1"
# model = SentenceTransformer(model_id)

# NOTE(review): "onxx" looks like a typo for "onnx". It is left unchanged so
# files already exported into this directory are still found.
folder_path = 'sentence-transformers-onxx/distiluse-base-multilingual-cased-v1'

# Exported file name: last component of the model id plus ".onnx".
model_name = str(model_id.split("/")[-1] + ".onnx")

# NOTE(review): assumed to be where save_as_onnx writes the file. The value it
# returns in the next cell (model_path_onnx) may help confirm this.
model_path = os.path.join(folder_path, "onnx", model_name)

# convert(
#     framework="pt",
#     model=model_id,
#     output=Path(model_path),
#     opset=15,
# )

In [None]:
# Wrap the model with opensearch-py-ml's SentenceTransformerModel and export it
# through save_as_onnx. Its output is what the cells below examine.
# overwrite=True presumably allows replacing existing files in folder_path.
pre_trained_model = SentenceTransformerModel(
    model_id=model_id,
    folder_path=folder_path,
    overwrite=True,
)
model_path_onnx = pre_trained_model.save_as_onnx(model_id=model_id)

## Load the ONNX Model to Check the `.onnx` File

In [None]:
import onnx

# Validate the exported file on its own, independent of ML-Commons.
# check_model raises if the graph is not well formed.
onnx_model = onnx.load(model_path)
onnx.checker.check_model(onnx_model)

# Uncomment for a human-readable dump of the graph:
# print(onnx.helper.printable_graph(onnx_model.graph))

## Verify Embeddings

In [None]:
import onnxruntime as ort

# Name the execution provider explicitly. Since onnxruntime 1.9, builds that ship
# more than one provider (e.g. GPU builds) raise unless `providers` is given.
# CPUExecutionProvider is available in every build.
ort_session = ort.InferenceSession(model_path, providers=["CPUExecutionProvider"])

In [None]:
from transformers import AutoTokenizer

# Instantiate the tokenizer for the exported model_id. Importing the class alone
# does not create the `autotokenizer` object used below.
autotokenizer = AutoTokenizer.from_pretrained(model_id)

input_sentences = ["first sentence", "second sentence", "very very long dksfml smflskdm"]
# padding=True pads every sentence to the longest one in the batch. The
# returned attention_mask marks which positions are real tokens.
auto_features = autotokenizer(
    input_sentences, return_tensors="pt", padding=True, truncation=True
)
auto_features

In [None]:
def to_numpy(tensor):
    """Return ``tensor`` as a NumPy array on the CPU.

    Tensors that require grad are detached first, because ``.numpy()`` refuses
    to convert them directly.
    """
    if tensor.requires_grad:
        tensor = tensor.detach()
    return tensor.cpu().numpy()

# compute ONNX Runtime output prediction
# Feed each graph input by name rather than by position. A different input order
# in the exported graph then cannot silently swap input_ids and attention_mask;
# a graph input the tokenizer did not produce raises KeyError instead.
ort_inputs = {
    graph_input.name: to_numpy(auto_features[graph_input.name])
    for graph_input in ort_session.get_inputs()
}
ort_outs = ort_session.run(None, ort_inputs)

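## Wrong Output Shape
The ONNX output does not have the expected sentence-embedding shape. The cells below inspect it and compare it with the original SentenceTransformer embeddings.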

In [None]:
# Length of the first ONNX output along its first axis. The comparison below
# indexes it once per input sentence, so compare this with len(input_sentences).
len(ort_outs[0])

In [None]:
# Shape of the ONNX output for the first sentence. Compare it with
# original_embedding_data[0].shape from the SentenceTransformer reference below.
# NOTE(review): if this is 2-D (tokens x hidden), the exported graph presumably
# returns token-level outputs without the sentence-transformers pooling step.
ort_outs[0][0].shape

In [None]:
import numpy as np

from sentence_transformers import SentenceTransformer

original_pre_trained_model = SentenceTransformer(model_id) # From Huggingface
# Encode with the SentenceTransformer just loaded, not the opensearch-py-ml
# wrapper (`pre_trained_model`). This keeps the reference independent of the
# export path being tested. One array per sentence.
original_embedding_data = list(
    original_pre_trained_model.encode(input_sentences, convert_to_numpy=True)
)

In [None]:
# Take the first ONNX output's entry for each input sentence, in order.
onnx_batch_output = ort_outs[0]
embedding_data_onnx = []
for sentence_idx, _ in enumerate(input_sentences):
    embedding_data_onnx.append(onnx_batch_output[sentence_idx])

In [None]:
import numpy as np

# Compare each ONNX embedding with its SentenceTransformer reference.
# assert_allclose returns None and raises AssertionError on any shape or value
# mismatch, so success is reported explicitly instead of printing its result.
for i in range(len(input_sentences)):
    np.testing.assert_allclose(
        original_embedding_data[i],
        embedding_data_onnx[i],
        rtol=1e-03,
        atol=1e-05,
        err_msg=f"embedding mismatch for sentence {i}: {input_sentences[i]!r}",
    )
    print(f"sentence {i}: embeddings match")

## More Info

In [None]:
# Declared outputs of the exported ONNX graph (names with type/shape info).
# Useful for seeing which tensor ort_outs[0] corresponds to.
onnx_model.graph.output

In [None]:
# load_graph_from_args lives in the same transformers.convert_graph_to_onnx module
# as the `convert` function used by save_as_onnx.
# NOTE(review): that module is deprecated in transformers; if this import fails,
# pin an older transformers version.
from transformers.convert_graph_to_onnx import load_graph_from_args

# Build the "feature-extraction" PyTorch pipeline for model_id to inspect its model.
nlp = load_graph_from_args("feature-extraction", "pt", model_id, None)

In [None]:
# `modules` is a method, so referencing it uncalled only displays a bound-method
# wrapper. The module's own repr already lists its full submodule tree.
nlp.model

In [None]:
# References:
# https://huggingface.co/docs/transformers/serialization
# https://github.com/oborchers/sentence-transformers/blob/master/examples/onnx_inference/onnx_inference.ipynb
# https://github.com/UKPLab/sentence-transformers/pull/668