In [4]:
# !pip install -U transformers torch

In [8]:
from transformers import AutoModelForMaskedLM, AutoTokenizer
import torch

In [9]:
model_id = 'naver/splade-cocondenser-ensembledistil'

tokenizer = AutoTokenizer.from_pretrained(model_id)
model = AutoModelForMaskedLM.from_pretrained(model_id)

# test text

In [25]:
text = 'test text'

In [48]:
tokens = tokenizer(text, return_tensors='pt')
print(f' tokenized text: {tokenizer.convert_ids_to_tokens(tokens["input_ids"][0])}')
output = model(**tokens)
output

 tokenized text: ['[CLS]', 'test', 'text', '[SEP]']


MaskedLMOutput(loss=None, logits=tensor([[[ -6.1641,  -8.0746,  -7.6957,  ...,  -7.7311,  -7.7196,  -5.3543],
         [-35.6677, -25.3951, -22.7048,  ..., -24.4138, -25.9929, -26.7460],
         [-22.1844, -18.0183, -19.8789,  ..., -17.4275, -16.6500, -19.5127],
         [-18.3353, -15.3200, -14.5211,  ..., -14.9485, -14.1121, -15.8339]]],
       grad_fn=<ViewBackward0>), hidden_states=None, attentions=None)

In [27]:
output.logits.shape

torch.Size([1, 4, 30522])

Now we have a probability distribution over all token, but we want it over the entire text, the splade paper does this by:

In [28]:


vec = torch.max(
    torch.log(
        1 + torch.relu(output.logits)
    ) * tokens.attention_mask.unsqueeze(-1),
dim=1)[0].squeeze()

vec.shape

torch.Size([30522])

In [29]:
vec

tensor([0., 0., 0.,  ..., 0., 0., 0.], grad_fn=<SqueezeBackward0>)

In [32]:
# extract non-zero positions
cols = vec.nonzero().squeeze().cpu().tolist()
print(f"amount of non-zero values: {len(cols)}")

# extract the non-zero values
weights = vec[cols].cpu().tolist()
# use to create a dictionary of token ID to weight
sparse_dict = dict(zip(cols, weights))

print("the non-zero values:")
sparse_dict

amount of non-zero values: 33
the non-zero values:


{2726: 0.20849496126174927,
 2773: 0.09160938858985901,
 3076: 0.20714879035949707,
 3189: 0.3412291705608368,
 3231: 3.0339627265930176,
 3648: 0.0031361228320747614,
 3661: 0.3299848735332489,
 3752: 0.2602483034133911,
 3793: 2.7784647941589355,
 3836: 0.06602185219526291,
 4180: 0.4668448269367218,
 4405: 0.01899118907749653,
 4471: 0.3733828067779541,
 4918: 0.025326814502477646,
 5074: 0.1128750592470169,
 5604: 2.3174471855163574,
 5852: 0.5542181730270386,
 6140: 0.0016638495726510882,
 6254: 0.31116071343421936,
 6845: 0.12719713151454926,
 6981: 1.8357378244400024,
 7099: 0.3066442012786865,
 7551: 0.09702988713979721,
 7667: 0.4796116352081299,
 8744: 0.2446175217628479,
 8785: 0.2586182653903961,
 10618: 0.13413843512535095,
 11360: 1.4634112119674683,
 12874: 0.5858922004699707,
 14686: 0.283345490694046,
 19461: 0.05838468298316002,
 22498: 0.3845391273498535,
 28770: 0.2040555477142334}

These tokens do not tell us much lets map them back to tokens:

In [35]:
# extract the ID position to text token mappings
idx2token = {
    idx: token for token, idx in tokenizer.get_vocab().items()
}

In [36]:
# map token IDs to human-readable tokens
sparse_dict_tokens = {
    idx2token[idx]: round(weight, 2) for idx, weight in zip(cols, weights)
}
# sort so we can see most relevant tokens first
sparse_dict_tokens = {
    k: v for k, v in sorted(
        sparse_dict_tokens.items(),
        key=lambda item: item[1],
        reverse=True
    )
}
sparse_dict_tokens

{'test': 3.03,
 'text': 2.78,
 'testing': 2.32,
 'texts': 1.84,
 'exam': 1.46,
 'pearson': 0.59,
 'tests': 0.55,
 'assessment': 0.48,
 'content': 0.47,
 'abbreviation': 0.38,
 'message': 0.37,
 'report': 0.34,
 'letter': 0.33,
 'document': 0.31,
 'sample': 0.31,
 'quote': 0.28,
 'reading': 0.26,
 'math': 0.26,
 'blank': 0.24,
 'thomas': 0.21,
 'student': 0.21,
 'proctor': 0.2,
 'lab': 0.13,
 'certification': 0.13,
 'roger': 0.11,
 'experiment': 0.1,
 'word': 0.09,
 'teacher': 0.07,
 'quiz': 0.06,
 'charlie': 0.03,
 'pilot': 0.02,
 'judge': 0.0,
 'print': 0.0}

# Comparing vectors

We will now compare 3 pieces of text to eachother to see how that works:

In [42]:
texts = [
   "information retrieval is hard to understand, but lovely when you understand it.",
   "I love going to the University of Amsterdam",
   "I don't want to go to school mum... we need to do information retrieval"
]

In [43]:
tokens = tokenizer(
    texts, return_tensors='pt',
    padding=True, truncation=True
)
output = model(**tokens)
# aggregate the token-level vecs and transform to sparse
vecs = torch.max(
    torch.log(1 + torch.relu(output.logits)) * tokens.attention_mask.unsqueeze(-1), dim=1
)[0].squeeze().detach().cpu().numpy()
vecs.shape

(3, 30522)

In [44]:
import numpy as np

sim = np.zeros((vecs.shape[0], vecs.shape[0]))

for i, vec in enumerate(vecs):
    sim[i,:] = np.dot(vec, vecs.T) / (
        np.linalg.norm(vec) * np.linalg.norm(vecs, axis=1)
    )

In [45]:
sim

array([[1.00000012, 0.01163663, 0.33802783],
       [0.01163663, 1.        , 0.17227599],
       [0.33802783, 0.17227599, 1.        ]])