In [1]:
import os
from pathlib import Path
from dataclasses import dataclass
from typing import Optional, Union, Mapping, OrderedDict

import torch
from transformers.onnx import export
from transformers.onnx import OnnxConfig
from transformers.utils import ModelOutput
from sentence_transformers.models import Dense
from transformers import AutoTokenizer, AutoModel, DistilBertModel

# get with SentenceTransformer('sentence-transformers/distiluse-base-multilingual-cased-v2', cache_folder=".")
model_ckpt = "./sentence-transformers_distiluse-base-multilingual-cased-v2"

class SBertOnnxConfig(OnnxConfig):
    @property
    def inputs(self) -> Mapping[str, Mapping[int, str]]:
        return OrderedDict([
            ("input_ids", {0: "batch", 1: "sequence"}),
            ("attention_mask", {0: "batch", 1: "sequence"})
        ])
    @property
    def outputs(self) -> Mapping[str, Mapping[int, str]]:
        return OrderedDict([
                ("last_hidden_state", {0: "batch", 1: "sequence"})
        ])

@dataclass
class EmbeddingOutput(ModelOutput):
    last_hidden_state: Optional[torch.FloatTensor] = None

class OwnSBert(DistilBertModel):
    @classmethod
    def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.PathLike]],  *model_args, **kwargs):
        _model = super().from_pretrained(pretrained_model_name_or_path, *model_args, **kwargs)
        additional_layer = Dense.load(kwargs.get("path_to_additional_layer"))
        _model.additional_layer_linear = additional_layer.linear
        _model.additional_layer_activation = additional_layer.activation_function
        return _model

    def forward(
        self,
        input_ids: Optional[torch.Tensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        head_mask: Optional[torch.Tensor] = None,
        inputs_embeds: Optional[torch.Tensor] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ):
        embeddings = super().forward(input_ids=input_ids,
                               attention_mask=attention_mask,
                               head_mask=head_mask,
                               inputs_embeds=inputs_embeds,
                               output_attentions=True,
                               output_hidden_states=True,
                               return_dict=True)

        mean_embedding = embeddings.last_hidden_state.mean(dim=1)
        last_hidden_state = self.additional_layer_activation(self.additional_layer_linear(mean_embedding))
        return EmbeddingOutput(last_hidden_state=last_hidden_state)


tokenizer = AutoTokenizer.from_pretrained(model_ckpt)
base_model = OwnSBert.from_pretrained(model_ckpt, path_to_additional_layer="./sentence-transformers_distiluse-base-multilingual-cased-v2/2_Dense")

# print(base_model(**tokenizer([sentences[0], sentences[1]], padding="longest", truncation=True, return_tensors="pt")))

onnx_path = Path("exported_model/model.onnx")
onnx_config = SBertOnnxConfig.from_model_config(base_model.config)
onnx_inputs, onnx_outputs = export(tokenizer, base_model, onnx_config, onnx_config.default_onnx_opset, onnx_path)
base_model.config.save_pretrained("./exported_model/")

In [6]:
from sentence_transformers import SentenceTransformer

model = SentenceTransformer("BAAI/bge-small-en", device="cuda")
model.max_seq_length = 384
model.half()

SentenceTransformer(
  (0): Transformer({'max_seq_length': 384, 'do_lower_case': True}) with Transformer model: BertModel 
  (1): Pooling({'word_embedding_dimension': 384, 'pooling_mode_cls_token': True, 'pooling_mode_mean_tokens': False, 'pooling_mode_max_tokens': False, 'pooling_mode_mean_sqrt_len_tokens': False})
  (2): Normalize()
)

In [8]:
import pandas as pd

df = pd.read_csv("../preprocessed/334_tfidf_gpu/000/data2.csv")

In [9]:
df

Unnamed: 0,prompt,A,B,C,D,E,answer,context
0,Which of the following statements accurately d...,MOND is a theory that reduces the observed mis...,MOND is a theory that increases the discrepanc...,MOND is a theory that explains the missing bar...,MOND is a theory that reduces the discrepancy ...,MOND is a theory that eliminates the observed ...,D,Modified Newtonian dynamics > Modified Newtoni...
1,Which of the following is an accurate definiti...,Dynamic scaling refers to the evolution of sel...,Dynamic scaling refers to the non-evolution of...,Dynamic scaling refers to the evolution of sel...,Dynamic scaling refers to the non-evolution of...,Dynamic scaling refers to the evolution of sel...,A,Dynamic scaling > Dynamic scaling > Here the e...
2,Which of the following statements accurately d...,The triskeles symbol was reconstructed as a fe...,The triskeles symbol is a representation of th...,The triskeles symbol is a representation of a ...,The triskeles symbol represents three interloc...,The triskeles symbol is a representation of th...,A,Triskelion > Use in European antiquity > Class...
3,What is the significance of regularization in ...,Regularizing the mass-energy of an electron wi...,Regularizing the mass-energy of an electron wi...,Regularizing the mass-energy of an electron wi...,Regularizing the mass-energy of an electron wi...,Regularizing the mass-energy of an electron wi...,C,Regularization (physics) > Classical physics e...
4,Which of the following statements accurately d...,The angular spacing of features in the diffrac...,The angular spacing of features in the diffrac...,The angular spacing of features in the diffrac...,The angular spacing of features in the diffrac...,The angular spacing of features in the diffrac...,D,Diffraction > Patterns > Several qualitative o...
...,...,...,...,...,...,...,...,...
1195,What is the systematic name of the enzyme clas...,Medium-chain acyl-CoA hydrolase,Medium-chain hydrolase,Medium-chain acyl coenzyme A hydrolase,Medium-chain acyl-thioester hydrolase,ADP-dependent-medium-chain-acyl-CoA hydrolase,E,ADP-dependent medium-chain-acyl-CoA hydrolase ...
1196,Which of the following statements accurately d...,A polygon covering problem refers to finding a...,A polygon covering problem refers to finding a...,A polygon covering problem refers to finding a...,A polygon covering problem refers to finding a...,A polygon covering problem refers to finding a...,A,Polygon covering > Polygon covering > In geome...
1197,What is V1936 Aquilae?,V1936 Aquilae is a blue supergiant and candida...,V1936 Aquilae is a red supergiant and candidat...,V1936 Aquilae is a blue supergiant and candida...,V1936 Aquilae is a yellow dwarf star located i...,V1936 Aquilae is a binary star system located ...,A,TT Aquilae > TT Aquilae > TT Aquilae (TT Aql) ...
1198,How does Pirlimycin hydrochloride act against ...,Pirlimycin hydrochloride inhibits bacterial pr...,Pirlimycin hydrochloride disrupts the cell mem...,Pirlimycin hydrochloride inhibits bacterial pr...,Pirlimycin hydrochloride inhibits DNA replicat...,Pirlimycin hydrochloride blocks bacterial cell...,A,Pirlimycin > Pirlimycin > Pirlimycin hydrochlo...


In [18]:
from tqdm import tqdm

In [19]:
%%time
embeddings = []

for i, row in tqdm(df.iterrows()):
    embedding = embed_model.get_text_embedding(row["context"])
    embeddings.append(embedding)

1200it [01:38, 12.24it/s]

CPU times: user 14min 42s, sys: 1.09 s, total: 14min 43s
Wall time: 1min 38s





In [13]:
%%time
section_embeddings = model.encode(
    df.context.tolist(),
    batch_size=32,
    device="cuda",
    show_progress_bar=True,
    convert_to_tensor=False,
    normalize_embeddings=True,
)

Batches:   0%|          | 0/38 [00:00<?, ?it/s]

CPU times: user 11.7 s, sys: 2.98 s, total: 14.7 s
Wall time: 5.39 s


In [15]:
section_embeddings[0]

array([-4.1138e-02, -8.3389e-03, -2.9587e-02, -3.0533e-02, -3.4561e-03,
       -1.3138e-02, -1.6403e-02,  2.1881e-02, -2.1255e-02,  6.1264e-03,
        2.1072e-02, -5.3802e-02,  3.1261e-03, -2.5635e-02,  1.0086e-02,
       -3.5645e-02,  2.7237e-02,  6.6605e-03, -4.1748e-02,  1.7120e-02,
        1.7502e-02, -3.4058e-02, -2.2564e-03, -3.8300e-02,  7.4097e-02,
        3.8818e-02,  1.2383e-02, -3.3447e-02, -2.5192e-02, -2.6831e-01,
        2.7298e-02, -4.6234e-02,  6.6299e-03, -1.8204e-02,  8.7433e-03,
        1.0551e-02,  3.0624e-02,  1.9180e-02, -1.8768e-02,  1.4267e-02,
        5.0964e-02,  1.5747e-02,  2.8427e-02, -9.4604e-03, -1.2840e-02,
        4.2686e-03,  1.1292e-02, -2.0813e-02, -1.8280e-02, -1.3222e-02,
       -4.7035e-03, -1.8463e-02, -2.6901e-02,  3.7018e-02,  2.3376e-02,
        5.1361e-02,  1.9592e-02, -7.3090e-03,  2.7786e-02, -3.5339e-02,
        3.5736e-02,  1.6174e-02, -2.0190e-01,  3.6774e-02,  3.7598e-02,
        2.0065e-02,  8.8358e-04, -6.1737e-02,  2.0828e-02,  2.80