In [1]:
import os

import numpy as np
import torch
from transformers import AutoModelForMaskedLM, AutoTokenizer

## Running the model with Transformers and Torch

In [2]:
# Example inputs that are embedded by both the PyTorch and the ONNX runs below.
sentences = ["Hello World", "Built by Nirant Kasliwal"]

## PyTorch Code from the [SPLADERunner](https://github.com/PrithivirajDamodaran/SPLADERunner) library

In [3]:
# Hugging Face access token. Read it from the environment instead of
# hard-coding a credential in the notebook; an unset or empty variable yields
# None, which lets `from_pretrained(..., token=None)` fall back to its default.
hf_token = os.environ.get("HF_TOKEN") or None

In [4]:
# Download the model and tokenizer; run on the first GPU when one is available.
device = "cuda:0" if torch.cuda.is_available() else "cpu"
tokenizer = AutoTokenizer.from_pretrained("prithivida/Splade_PP_en_v1", token=hf_token)
# Inverse vocabulary (token id -> token string), used to label sparse dimensions.
reverse_voc = {v: k for k, v in tokenizer.get_vocab().items()}
model = AutoModelForMaskedLM.from_pretrained("prithivida/Splade_PP_en_v1", token=hf_token)
model.to(device)

# Tokenize the input and move every tensor to the model's device.
inputs = tokenizer(sentences, return_tensors="pt", padding=True, truncation=True, max_length=512)
inputs = {key: val.to(device) for key, val in inputs.items()}
input_ids = inputs["input_ids"]
attention_mask = inputs["attention_mask"]
token_type_ids = inputs["token_type_ids"]

# Run the model. This is inference only, so skip autograd bookkeeping.
with torch.no_grad():
    outputs = model(**inputs)
logits = outputs.logits
print("Output Logits shape: ", logits.shape)
print("Output Attention mask shape: ", attention_mask.shape)

# Sparse vector: log(1 + ReLU(logits)), zeroed where the attention mask is 0,
# then max-pooled over the sequence axis (dim=1).
relu_log = torch.log(1 + torch.relu(logits))
weighted_log = relu_log * attention_mask.unsqueeze(-1)
max_val, _ = torch.max(weighted_log, dim=1)
# Keep the batch axis as is: a bare squeeze() would drop it for a single
# sentence and the per-sentence loops below would then iterate over scalars.
vector = max_val
print("Sparse Vector shape: ", vector.shape)

# Per sentence: indices of the non-zero dimensions and their weights.
# squeeze(-1) removes only the trailing axis of nonzero()'s (k, 1) result, so a
# single non-zero entry still produces a one-element list rather than a scalar.
cols = [vec.nonzero().squeeze(-1).cpu().tolist() for vec in vector]
weights = [vec[col].cpu().tolist() for vec, col in zip(vector, cols)]

idx = 1
cols, weights = cols[idx], weights[idx]
# Print the BOW representation, highest weight first.
d = dict(zip(cols, weights))
sorted_d = dict(sorted(d.items(), key=lambda item: item[1], reverse=True))
bow_rep = [(reverse_voc[k], round(v, 2)) for k, v in sorted_d.items()]
print(f"SPLADE BOW rep for sentence:\t{sentences[idx]}\n{bow_rep}")

Output Logits shape:  torch.Size([2, 10, 30522])
Output Attention mask shape:  torch.Size([2, 10])
Sparse Vector shape:  torch.Size([2, 30522])
SPLADE BOW rep for sentence:	Built by Nirant Kasliwal
[('##rant', 2.02), ('built', 1.94), ('##wal', 1.79), ('##sl', 1.69), ('build', 1.57), ('ka', 1.4), ('ni', 1.26), ('made', 0.93), ('architect', 0.76), ('was', 0.69), ('who', 0.61), ('his', 0.5), ('wrote', 0.47), ('india', 0.45), ('company', 0.41), ('##i', 0.41), ('he', 0.37), ('manufacturer', 0.36), ('by', 0.35), ('engineer', 0.33), ('architecture', 0.33), ('ko', 0.23), ('him', 0.22), ('invented', 0.19), ('said', 0.14), ('k', 0.11), ('man', 0.11), ('statue', 0.11), ('bomb', 0.1), ('##wa', 0.1), ('builder', 0.09), ('.', 0.07), ('started', 0.06), (',', 0.04), ('ku', 0.03)]


## Export with output_attentions and logits

In [5]:
from transformers import AutoTokenizer

# Hub repository id and the local folder derived from it ("/" becomes "_").
model_id = "nirantk/SPLADE_PP_en_v1"
output_dir = "models/" + model_id.replace("/", "_")
# Options meant for the export call that is currently commented out below.
model_kwargs = dict(output_attentions=True, return_dict=True)

print(f"Exporting model to {output_dir}")
# NOTE(review): the export call is disabled and `main_export` is not imported in
# the visible cells; only the tokenizer files are written by this cell.
# main_export(
#     model_id,
#     output=output_dir,
#     no_post_process=True,
#     model_kwargs=model_kwargs,
#     token=hf_token,
# )
# Last expression of the cell: its return value is shown as the cell output.
tokenizer.save_pretrained(output_dir)

Exporting model to models/nirantk_SPLADE_PP_en_v1


('models/nirantk_SPLADE_PP_en_v1/tokenizer_config.json',
 'models/nirantk_SPLADE_PP_en_v1/special_tokens_map.json',
 'models/nirantk_SPLADE_PP_en_v1/vocab.txt',
 'models/nirantk_SPLADE_PP_en_v1/added_tokens.json',
 'models/nirantk_SPLADE_PP_en_v1/tokenizer.json')

## Running the model with ONNX

In [6]:
from optimum.onnxruntime import ORTModelForMaskedLM

# Load the ONNX masked-LM model and its tokenizer from the same Hub repository.
# Both names are rebound here, replacing the objects created by earlier cells.
onnx_repo_id = "nirantk/SPLADE_PP_en_v1"
model = ORTModelForMaskedLM.from_pretrained(onnx_repo_id)
tokenizer = AutoTokenizer.from_pretrained(onnx_repo_id)

In [7]:
# Tokenize on the CPU. The ONNX model below is called with NumPy arrays, so
# moving the tensors to `device` (possibly CUDA) only forced a copy straight
# back, and it left `attention_mask` on the GPU where NumPy cannot read it.
inputs = tokenizer(sentences, return_tensors="pt", padding=True, truncation=True, max_length=512)
inputs = dict(inputs.items())
input_ids = inputs["input_ids"]
attention_mask = inputs["attention_mask"]
token_type_ids = inputs["token_type_ids"]

# NumPy views of the tokenizer output, keyed by the model's input names.
onnx_input = {
    "input_ids": input_ids.numpy(),
    "attention_mask": attention_mask.numpy(),
    "token_type_ids": token_type_ids.numpy(),
}

logits = model(**onnx_input).logits

In [8]:
# Shape of the logits returned by the ONNX model call, shown as the cell output.
logits.shape

(2, 10, 30522)

In [9]:
print("Output Logits shape: ", logits.shape)

# SPLADE activation, element-wise: log(1 + ReLU(logits)).
# The log(1 + x) form is kept on purpose so it matches the PyTorch cell.
relu_log = np.log(1 + np.maximum(logits, 0))

# Equivalent to relu_log * attention_mask.unsqueeze(-1): the mask gets a
# trailing axis so it broadcasts over the vocabulary axis.
# Use the NumPy mask that was fed to the ONNX model. The torch `attention_mask`
# may live on a CUDA device, which NumPy cannot convert.
weighted_log = relu_log * np.expand_dims(onnx_input["attention_mask"], axis=-1)

# Equivalent to torch.max(weighted_log, dim=1): np.max returns only the values.
max_val = np.max(weighted_log, axis=1)

# Keep the batch axis as is: np.squeeze would drop it for a single sentence and
# the per-sentence loops below would then iterate over scalars.
vector = max_val
print("Sparse Vector shape: ", vector.shape)

# Per sentence: indices of the non-zero dimensions and their weights.
# flatnonzero always returns a 1-D index array, so a single non-zero entry
# still produces a one-element list rather than a scalar.
cols = [np.flatnonzero(vec).tolist() for vec in vector]
weights = [vec[col].tolist() for vec, col in zip(vector, cols)]

idx = 1
cols, weights = cols[idx], weights[idx]
# Print the BOW representation, highest weight first.
d = dict(zip(cols, weights))
sorted_d = dict(sorted(d.items(), key=lambda item: item[1], reverse=True))
bow_rep = [(reverse_voc[k], round(v, 2)) for k, v in sorted_d.items()]
print(f"SPLADE BOW rep for sentence:\t{sentences[idx]}\n{bow_rep}")

Output Logits shape:  (2, 10, 30522)
Sparse Vector shape:  (2, 30522)
SPLADE BOW rep for sentence:	Built by Nirant Kasliwal
[('##rant', 2.02), ('built', 1.94), ('##wal', 1.79), ('##sl', 1.69), ('build', 1.57), ('ka', 1.4), ('ni', 1.26), ('made', 0.93), ('architect', 0.76), ('was', 0.69), ('who', 0.61), ('his', 0.5), ('wrote', 0.47), ('india', 0.45), ('company', 0.41), ('##i', 0.41), ('he', 0.37), ('manufacturer', 0.36), ('by', 0.35), ('engineer', 0.33), ('architecture', 0.33), ('ko', 0.23), ('him', 0.22), ('invented', 0.19), ('said', 0.14), ('k', 0.11), ('man', 0.11), ('statue', 0.11), ('bomb', 0.1), ('##wa', 0.1), ('builder', 0.09), ('.', 0.07), ('started', 0.06), (',', 0.04), ('ku', 0.03)]


In [10]:
# Number of non-zero dimensions in the sparse vector of the selected sentence.
len(cols)

35

In [11]:
# Indices of the non-zero dimensions for the selected sentence (the keys looked up in `reverse_voc`).
cols

[1010,
 1012,
 1047,
 2001,
 2002,
 2010,
 2011,
 2032,
 2040,
 2056,
 2072,
 2081,
 2158,
 2194,
 2318,
 2328,
 2626,
 2634,
 3857,
 3992,
 4213,
 4294,
 4944,
 5968,
 6231,
 7751,
 8826,
 9152,
 10556,
 12508,
 12849,
 13476,
 13970,
 14540,
 17884]