# Exporting Transformer Models to ONNX

This notebook demonstrates how to export three different Transformer models for serving in Vespa. See [MS Marco Passage ranking](https://github.com/vespa-engine/sample-apps/blob/master/msmarco-ranking/passage-ranking-README.md) for how the exported models are used.


In [None]:
!python3 -m pip install torch numpy transformers onnx onnxruntime "protobuf >=4.24.2, <=4.24.2"

In [None]:
from transformers import AutoModel, AutoTokenizer, BertTokenizer, BertPreTrainedModel, BertModel
import transformers
import torch 
from pathlib import Path
import torch.nn as nn

# Sentence Transformer (bi-encoder) for dense retrieval using approximate nearest neighbor search 

We create a wrapper model so that the mean pooling over the output is computed inside the ONNX graph.
Almost all sentence-transformer models use mean pooling over the output layer.
We also normalize the pooled vector so that we can use inner-product distance instead of regular angular distance, which speeds up the distance calculations in nearest neighbor search.

In [None]:
class MeanPoolingEncoderONNX(BertPreTrainedModel):
    """BERT wrapper that returns one mean-pooled, L2-normalized vector per input.

    Pooling and normalization live inside ``forward`` so that they become part
    of the graph when this module is exported to ONNX.
    """

    def __init__(self, config):
        """Build the underlying ``BertModel`` from ``config`` and initialize weights."""
        super().__init__(config)
        self.bert = BertModel(config)
        self.init_weights()

    def forward(self, input_ids, attention_mask, token_type_ids=None):
        """Encode the inputs into unit-length pooled embeddings.

        Args:
            input_ids: token ids handed to the underlying BERT model.
            attention_mask: mask handed to BERT and reused as the pooling
                weights; positions with value 0 add nothing to the sum.
            token_type_ids: accepted for signature compatibility but unused;
                it is not forwarded to BERT.

        Returns:
            The attention-masked mean over dimension 1 of BERT's first output,
            L2-normalized along dimension 1.
        """
        hidden_states = self.bert(input_ids, attention_mask=attention_mask)[0]

        # Stretch the mask to the shape of the hidden states so that it can be
        # multiplied element-wise with them.
        weights = attention_mask.unsqueeze(-1).expand(hidden_states.size()).float()
        weighted_sum = (hidden_states * weights).sum(1)

        # The clamp keeps the divisor strictly positive even for an all-zero mask.
        weight_totals = weights.sum(1).clamp(min=1e-9)
        pooled = weighted_sum / weight_totals

        return torch.nn.functional.normalize(pooled, p=2, dim=1)

In [None]:
# Load the pretrained bi-encoder weights into the pooling wrapper and make sure
# the module is in inference mode before it is traced.
encoder = MeanPoolingEncoderONNX.from_pretrained("sentence-transformers/msmarco-MiniLM-L-6-v3")
encoder = encoder.eval()

input_names = ["input_ids", "attention_mask"]
output_names = ["contextual"]
# Dummy inputs used only to trace the graph; the axes listed in `dynamic_axes`
# below stay variable in the exported model.
input_ids = torch.ones(1, 32, dtype=torch.int64)
attention_mask = torch.ones(1, 32, dtype=torch.int64)
args = (input_ids, attention_mask)
torch.onnx.export(encoder,
                args=args,
                f="sentence-msmarco-MiniLM-L-6-v3.onnx",
                do_constant_folding=True,
                input_names = input_names,
                output_names = output_names,
                # Axis 0 and axis 1 vary independently, so each gets its own
                # symbolic name. Reusing one name for both would declare them
                # to be the same dimension in the exported graph.
                dynamic_axes = {
                    "input_ids": {0: "batch", 1: "sequence"},
                    "attention_mask": {0: "batch", 1: "sequence"},
                    "contextual": {0: "batch"},
                },
                opset_version=12)

In [None]:
from onnxruntime.quantization import quantize_dynamic, QuantType

# Dynamically quantize the exported bi-encoder; the result is written to the
# `model_output` path next to the original file.
quantized_model = quantize_dynamic(
    model_input="sentence-msmarco-MiniLM-L-6-v3.onnx",
    model_output="sentence-msmarco-MiniLM-L-6-v3-quantized.onnx",
)


## Vespa ColBERT model (Late interaction model)

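Here we also define a small wrapper model: it projects each BERT output vector down to 32 dimensions and L2-normalizes it, so that the projection and normalization are part of the exported ONNX graph.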


In [None]:
class VespaColBERT(BertPreTrainedModel):
    """BERT encoder followed by a bias-free linear projection to 32 dimensions.

    The projected vectors are L2-normalized inside ``forward`` so that the
    normalization is part of the graph when the module is exported to ONNX.
    """

    def __init__(self, config):
        """Create the BERT encoder and the projection layer, then initialize weights."""
        super().__init__(config)
        self.bert = BertModel(config)
        self.linear = nn.Linear(
            in_features=config.hidden_size,
            out_features=32,
            bias=False,
        )
        self.init_weights()

    def forward(self, input_ids, attention_mask):
        """Return the projected BERT outputs, L2-normalized along dimension 2.

        Args:
            input_ids: token ids handed to the underlying BERT model.
            attention_mask: attention mask handed to the underlying BERT model.
        """
        encoded = self.bert(input_ids, attention_mask=attention_mask)[0]
        projected = self.linear(encoded)
        return torch.nn.functional.normalize(projected, p=2, dim=2)

In [None]:
# Load the pretrained ColBERT weights into the wrapper and explicitly select
# inference mode before tracing.
colbert_query_encoder = VespaColBERT.from_pretrained("vespa-engine/col-minilm")
colbert_query_encoder = colbert_query_encoder.eval()

input_names = ["input_ids", "attention_mask"]
output_names = ["contextual"]
# Dummy inputs used only to trace the graph; the axes listed in `dynamic_axes`
# below stay variable in the exported model.
input_ids = torch.ones(1, 32, dtype=torch.int64)
attention_mask = torch.ones(1, 32, dtype=torch.int64)
args = (input_ids, attention_mask)
torch.onnx.export(colbert_query_encoder,
                args=args,
                f="vespa-colMiniLM-L-6.onnx",
                do_constant_folding=True,
                input_names = input_names,
                output_names = output_names,
                # Axis 0 and axis 1 vary independently, so each gets its own
                # symbolic name. Reusing one name for both would declare them
                # to be the same dimension in the exported graph.
                dynamic_axes = {
                    "input_ids": {0: "batch", 1: "sequence"},
                    "attention_mask": {0: "batch", 1: "sequence"},
                    "contextual": {0: "batch", 1: "sequence"},
                },
                opset_version=12)

In [None]:
# Dynamically quantize the exported ColBERT encoder; the result is written to
# the `model_output` path next to the original file.
quantized_model = quantize_dynamic(
    model_input="vespa-colMiniLM-L-6.onnx",
    model_output="vespa-colMiniLM-L-6-quantized.onnx",
)

## All-to-all cross-attention model

In [None]:
from transformers import AutoTokenizer, AutoModelForSequenceClassification
import transformers.convert_graph_to_onnx as onnx_convert

# Hugging Face model id of the cross-encoder and the file it is exported to.
cross_model = "cross-encoder/ms-marco-MiniLM-L-6-v2"
output_file = "ms-marco-MiniLM-L-6-v2.onnx"

tokenizer = AutoTokenizer.from_pretrained(cross_model)
# `eval()` returns the module itself, so it can be chained onto the load.
model = AutoModelForSequenceClassification.from_pretrained(cross_model).eval()

# The converter works on a pipeline object rather than on the bare model.
pipeline = transformers.pipeline(
    task="text-classification",
    model=model,
    tokenizer=tokenizer,
)
onnx_convert.convert_pytorch(
    pipeline,
    opset=12,
    output=Path(output_file),
    use_external_format=False,
)
# Quantize the ONNX file that was just exported.
onnx_convert.quantize(Path(output_file))