# Trace Cross Encoder

In [1]:
import torch
from sentence_transformers import CrossEncoder
from transformers import AutoTokenizer, AutoModelForSequenceClassification
import torch

# Checkpoint used throughout the notebook.
model_id = "cross-encoder/ms-marco-electra-base"

# Tokenizer and sequence-classification model are loaded from the same checkpoint.
tokenizer = AutoTokenizer.from_pretrained(model_id)
model = AutoModelForSequenceClassification.from_pretrained(model_id)
model.eval()  # eval mode before the model is traced further down

# Small placeholder batch of (query, paragraph) pairs; its tensors are used
# as the example inputs when tracing.
features = tokenizer(
    [
        ('Query', 'Paragraph1'),
        ('Query', 'Paragraph2'),
        ('Query', 'Paragraph3'),
    ],
    padding=True,
    truncation=True,
    return_tensors="pt",
)

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
# Display the tokenizer output (input_ids, token_type_ids, attention_mask).
features

{'input_ids': tensor([[  101, 23032,   102, 20423,  2487,   102],
        [  101, 23032,   102, 20423,  2475,   102],
        [  101, 23032,   102, 20423,  2509,   102]]), 'token_type_ids': tensor([[0, 0, 0, 1, 1, 1],
        [0, 0, 0, 1, 1, 1],
        [0, 0, 0, 1, 1, 1]]), 'attention_mask': tensor([[1, 1, 1, 1, 1, 1],
        [1, 1, 1, 1, 1, 1],
        [1, 1, 1, 1, 1, 1]])}

In [3]:
# Example inputs for tracing. torch.jit.trace binds a tuple positionally to
# forward(), and the transformers forward signature is
# (input_ids, attention_mask, token_type_ids, ...). The attention mask must
# therefore come before the token type ids; otherwise the model is traced
# with the two tensors swapped.
ex_input = (
    features["input_ids"],
    features["attention_mask"],
    features["token_type_ids"],
)
# strict=False lets the tracer accept the model's dict-like output.
traced_model = torch.jit.trace(model, ex_input, strict=False)
# Persist the traced module to disk; it is reloaded below with torch.jit.load.
torch.jit.save(traced_model, "traced_cross_encoder.pt")

In [4]:
# Reload the traced module from disk and switch it to eval mode. Because
# eval() is the last expression, the cell also displays the module structure.
loaded_model = torch.jit.load("traced_cross_encoder.pt")
loaded_model.eval()

RecursiveScriptModule(
  original_name=ElectraForSequenceClassification
  (electra): RecursiveScriptModule(
    original_name=ElectraModel
    (embeddings): RecursiveScriptModule(
      original_name=ElectraEmbeddings
      (word_embeddings): RecursiveScriptModule(original_name=Embedding)
      (position_embeddings): RecursiveScriptModule(original_name=Embedding)
      (token_type_embeddings): RecursiveScriptModule(original_name=Embedding)
      (LayerNorm): RecursiveScriptModule(original_name=LayerNorm)
      (dropout): RecursiveScriptModule(original_name=Dropout)
    )
    (encoder): RecursiveScriptModule(
      original_name=ElectraEncoder
      (layer): RecursiveScriptModule(
        original_name=ModuleList
        (0): RecursiveScriptModule(
          original_name=ElectraLayer
          (attention): RecursiveScriptModule(
            original_name=ElectraAttention
            (self): RecursiveScriptModule(
              original_name=ElectraSelfAttention
              (query): R

In [5]:
# Two text pairs used to compare the traced model with the original one.
# NOTE(review): the first pair repeats the question and the second pair holds
# the two passages — confirm whether (question, passage) pairs were intended.
test_sentences = [
    (
        'How many people live in Berlin?',
        'How many people live in Berlin?',
    ),
    (
        'Berlin has a population of 3,520,031 registered inhabitants in an area of 891.82 square kilometers.',
        'New York City is famous for the Metropolitan Museum of Art.',
    ),
]
test_features = tokenizer(
    test_sentences,
    padding=True,
    truncation=True,
    return_tensors="pt",
)
# The tokenizer output is unpacked into keyword arguments for the traced module.
pt_prediction = loaded_model(**test_features)
pt_prediction

{'logits': tensor([[ -0.7667],
         [-10.7274]], grad_fn=<AddmmBackward0>)}

# Compare Output

In [6]:
from sentence_transformers import CrossEncoder
# Rebinds `model` to a sentence-transformers CrossEncoder for the same checkpoint.
model = CrossEncoder(model_id)
# Run the same text pairs through CrossEncoder.predict.
# NOTE(review): despite the variable name, the displayed result holds one
# value per pair, which looks like scores rather than embeddings — confirm.
original_embedding = model.predict(test_sentences)
original_embedding

array([3.1719896e-01, 2.1936203e-05], dtype=float32)

In [7]:
from transformers import AutoConfig, AutoTokenizer, AutoModelForSequenceClassification
import torch

# Rebuild tokenizer and model directly from transformers, passing an explicit config.
tokenizer = AutoTokenizer.from_pretrained(model_id)
config = AutoConfig.from_pretrained(model_id)
model = AutoModelForSequenceClassification.from_pretrained(model_id, config=config)
model.eval()

# Tokenize the test pairs and keep the resulting tensors on the CPU.
features = tokenizer(
    test_sentences,
    padding=True,
    truncation='longest_first',
    return_tensors="pt",
    max_length=512,
).to('cpu')

# Forward pass without gradient tracking; only the logits are kept.
with torch.no_grad():
    scores = model(**features).logits
print(scores)

tensor([[ -0.7667],
        [-10.7274]])


In [8]:
from sentence_transformers import CrossEncoder
model = CrossEncoder(model_id)
# Call the wrapped `model.model` directly on the tokenized test pairs from
# the traced-model cell; with return_dict=True the displayed result is a
# SequenceClassifierOutput whose logits can be compared with the values above.
model.model(**test_features, return_dict=True)

SequenceClassifierOutput(loss=None, logits=tensor([[ -0.7667],
        [-10.7274]], grad_fn=<AddmmBackward0>), hidden_states=None, attentions=None)

In [9]:
#     def smart_batching_collate_text_only(self, batch):
#         texts = [[] for _ in range(len(batch[0]))]

#         for example in batch:
#             for idx, text in enumerate(example):
#                 texts[idx].append(text.strip())

#         tokenized = self.tokenizer(*texts, padding=True, truncation='longest_first', return_tensors="pt", max_length=self.max_length)

#         for name in tokenized:
#             tokenized[name] = tokenized[name].to(self._target_device)

#         return tokenized