# Trace Cross Encoder

In [1]:
model_id = "cross-encoder/ms-marco-electra-base"
test_sentences = [('How many people live in Berlin?', 'How many people live in Berlin?'), ('Berlin has a population of 3,520,031 registered inhabitants in an area of 891.82 square kilometers.', 'New York City is famous for the Metropolitan Museum of Art.')]

In [2]:
import torch
from sentence_transformers import CrossEncoder
from transformers import AutoTokenizer, AutoModelForSequenceClassification
import torch

model = AutoModelForSequenceClassification.from_pretrained(model_id)
model.eval()
tokenizer = AutoTokenizer.from_pretrained(model_id)
features = tokenizer([('Query', 'Paragraph1'), ('Query', 'Paragraph2') , ('Query', 'Paragraph3')],  padding=True, truncation=True, return_tensors="pt")

In [3]:
features

{'input_ids': tensor([[  101, 23032,   102, 20423,  2487,   102],
        [  101, 23032,   102, 20423,  2475,   102],
        [  101, 23032,   102, 20423,  2509,   102]]), 'token_type_ids': tensor([[0, 0, 0, 1, 1, 1],
        [0, 0, 0, 1, 1, 1],
        [0, 0, 0, 1, 1, 1]]), 'attention_mask': tensor([[1, 1, 1, 1, 1, 1],
        [1, 1, 1, 1, 1, 1],
        [1, 1, 1, 1, 1, 1]])}

In [4]:
ex_input = (features["input_ids"], features["token_type_ids"], features["attention_mask"])
traced_model = torch.jit.trace(model, ex_input, strict=False)
torch.jit.save(traced_model, "traced_cross_encoder.pt")

In [5]:
loaded_model = torch.jit.load("traced_cross_encoder.pt")
loaded_model.eval()

RecursiveScriptModule(
  original_name=ElectraForSequenceClassification
  (electra): RecursiveScriptModule(
    original_name=ElectraModel
    (embeddings): RecursiveScriptModule(
      original_name=ElectraEmbeddings
      (word_embeddings): RecursiveScriptModule(original_name=Embedding)
      (position_embeddings): RecursiveScriptModule(original_name=Embedding)
      (token_type_embeddings): RecursiveScriptModule(original_name=Embedding)
      (LayerNorm): RecursiveScriptModule(original_name=LayerNorm)
      (dropout): RecursiveScriptModule(original_name=Dropout)
    )
    (encoder): RecursiveScriptModule(
      original_name=ElectraEncoder
      (layer): RecursiveScriptModule(
        original_name=ModuleList
        (0): RecursiveScriptModule(
          original_name=ElectraLayer
          (attention): RecursiveScriptModule(
            original_name=ElectraAttention
            (self): RecursiveScriptModule(
              original_name=ElectraSelfAttention
              (query): R

In [6]:
test_features = tokenizer(test_sentences,  padding=True, truncation=True, return_tensors="pt")
pt_prediction = loaded_model(**test_features)
pt_prediction

{'logits': tensor([[ -0.7667],
         [-10.7274]], grad_fn=<AddmmBackward0>)}

In [7]:
from torch import nn
default_activation_function = nn.Sigmoid() #if config.num_labels == 1 else nn.Identity()
activation_fct = default_activation_function
logits = activation_fct(pt_prediction['logits'])

In [8]:
logits

tensor([[3.1720e-01],
        [2.1936e-05]], grad_fn=<SigmoidBackward0>)

# Compare Output

In [9]:
from sentence_transformers import CrossEncoder
model = CrossEncoder(model_id)
original_embedding = model.predict(test_sentences)
original_embedding

array([3.1719887e-01, 2.1936203e-05], dtype=float32)

In [10]:
from transformers import AutoConfig, AutoTokenizer, AutoModelForSequenceClassification
import torch

config = AutoConfig.from_pretrained(model_id)
model = AutoModelForSequenceClassification.from_pretrained(model_id, config=config)
tokenizer = AutoTokenizer.from_pretrained(model_id)
max_length = 512

In [11]:
from torch.utils.data import DataLoader

def smart_batching_collate_text_only(batch):
    texts = [[] for _ in range(len(batch[0]))]

    for example in batch:
        for idx, text in enumerate(example):
            texts[idx].append(text.strip())

    tokenized = tokenizer(*texts, padding=True, truncation='longest_first', return_tensors="pt", max_length=max_length)

    for name in tokenized:
        tokenized[name] = tokenized[name].to('cpu')

    return tokenized

inp_dataloader = DataLoader(test_sentences, batch_size=32, collate_fn=smart_batching_collate_text_only, num_workers=0, shuffle=False)

In [12]:
inp_dataloader

<torch.utils.data.dataloader.DataLoader at 0x17560e2d0>

In [13]:
from torch import nn
default_activation_function = nn.Sigmoid() #if config.num_labels == 1 else nn.Identity()

In [14]:
iterator = inp_dataloader
activation_fct = None

if activation_fct is None:
    activation_fct = default_activation_function

pred_scores = []
model.eval()
model.to('cpu')
with torch.no_grad():
    for features in iterator:
        model_predictions = model(**features, return_dict=True)
        logits = activation_fct(model_predictions.logits)

#         if apply_softmax and len(logits[0]) > 1:
#             logits = torch.nn.functional.softmax(logits, dim=1)
        pred_scores.extend(logits)

In [15]:
pred_scores

[tensor([0.3172]), tensor([2.1936e-05])]