In [1]:
import torch
# torch.set_num_threads(1)
from datasets import load_dataset
from transformers import AutoModelForMaskedLM, AutoTokenizer

In [2]:
# Corpus to encode: DBpedia entity pages from the Hugging Face Hub (train split).
dataset = load_dataset(path="KShivendu/dbpedia-entities-openai-1M", split="train")

Resolving data files:   0%|          | 0/26 [00:00<?, ?it/s]

# Single Encoder for Both Query and Document

In [3]:
model_id = "naver/splade-cocondenser-ensembledistil"

# The masked-LM logits are what compute_vector pools into term weights, so the
# tokenizer must come from the same checkpoint for its ids to match those logits.
model = AutoModelForMaskedLM.from_pretrained(model_id)
tokenizer = AutoTokenizer.from_pretrained(model_id)

In [4]:
def compute_vector(text, tokenizer, model):
    """
    Encode ``text`` into a SPLADE-style weight vector over the model's output logits.

    The masked-LM logits are passed through log(1 + ReLU(x)), positions whose
    attention mask is 0 are zeroed out, and the result is max-pooled over the
    sequence dimension (dim=1).

    Args:
        text (str): Input text to encode.
        tokenizer: Hugging Face tokenizer matching ``model``.
        model: Masked-LM model whose output exposes ``.logits``.

    Returns:
        tuple: ``(vec, tokens)`` where ``vec`` is the pooled vector with singleton
        dimensions squeezed away (1-D for a single string), computed without an
        autograd graph, and ``tokens`` is the tokenizer output fed to the model.
    """
    # Truncate to the tokenizer's maximum length; inputs longer than the model's
    # maximum sequence length would otherwise fail inside the forward pass.
    tokens = tokenizer(text, return_tensors="pt", truncation=True)
    # Inference only: do not build an autograd graph (saves memory and time).
    with torch.no_grad():
        output = model(**tokens)
    logits, attention_mask = output.logits, tokens.attention_mask
    # log1p(x) == log(1 + x), but computed more accurately for small x.
    relu_log = torch.log1p(torch.relu(logits))
    # Zero out masked positions before pooling; unsqueeze(-1) broadcasts the mask
    # across the last (vocabulary) dimension of the logits.
    weighted_log = relu_log * attention_mask.unsqueeze(-1)
    max_val, _ = torch.max(weighted_log, dim=1)
    vec = max_val.squeeze()

    return vec, tokens

vec, tokens = compute_vector(text="Hello World!", tokenizer=tokenizer, model=model)

In [5]:
# Show the pooled vector alongside the tokenizer output that produced it.
(vec, tokens)

(tensor([0., 0., 0.,  ..., 0., 0., 0.], grad_fn=<SqueezeBackward0>),
 {'input_ids': tensor([[ 101, 7592, 2088,  999,  102]]), 'token_type_ids': tensor([[0, 0, 0, 0, 0]]), 'attention_mask': tensor([[1, 1, 1, 1, 1]])})

In [6]:
def extract_and_map_sparse_vector(vector, tokenizer):
    """
    Turn a vocabulary-indexed weight vector into a readable ``{token: weight}`` dict.

    Every non-zero entry of ``vector`` is mapped back to its token string via the
    tokenizer's vocabulary, its weight is rounded to 2 decimals, and the result is
    ordered from the largest to the smallest weight.

    Args:
        vector (torch.Tensor): 1-D tensor whose positions are vocabulary ids, e.g.
            the ``vec`` returned by ``compute_vector`` for a single string.
        tokenizer: Tokenizer whose ``get_vocab()`` maps token strings to those ids.

    Returns:
        dict: Token -> rounded weight for the non-zero entries, sorted by
        descending weight. Empty if ``vector`` has no non-zero entries.

    Raises:
        ValueError: If ``vector`` is not 1-D.
    """
    if vector.dim() != 1:
        raise ValueError(f"expected a 1-D vector, got shape {tuple(vector.shape)}")

    # as_tuple=True always yields a 1-D index tensor. A plain nonzero().squeeze()
    # collapses to a 0-d tensor when exactly one entry is non-zero, so tolist()
    # would return a bare int and the zip() below would raise TypeError.
    (nonzero_ids,) = torch.nonzero(vector, as_tuple=True)
    cols = nonzero_ids.cpu().tolist()
    weights = vector[nonzero_ids].cpu().tolist()

    # Invert the vocabulary (token -> id) so ids can be mapped back to tokens.
    idx2token = {idx: token for token, idx in tokenizer.get_vocab().items()}
    token_weight_dict = {
        idx2token[idx]: round(weight, 2) for idx, weight in zip(cols, weights)
    }

    # dict preserves insertion order, so building it from the sorted items
    # yields a dictionary ordered by descending weight.
    return dict(
        sorted(token_weight_dict.items(), key=lambda item: item[1], reverse=True)
    )

extract_and_map_sparse_vector(vec, tokenizer)

{'hello': 3.02,
 'world': 2.53,
 'hi': 1.64,
 '##world': 1.45,
 '!': 1.28,
 'worlds': 0.78,
 'happiness': 0.71,
 'hawkins': 0.59,
 'birthday': 0.53,
 'message': 0.51,
 'greeting': 0.42,
 'song': 0.41,
 'party': 0.37,
 'language': 0.33,
 'jay': 0.33,
 'marty': 0.33,
 'peace': 0.32,
 'welcome': 0.31,
 'jerry': 0.29,
 'gabriel': 0.27,
 'dave': 0.26,
 'winston': 0.24,
 'global': 0.23,
 'music': 0.21,
 'thomas': 0.21,
 'surprise': 0.21,
 'new': 0.2,
 'hey': 0.2,
 'spencer': 0.2,
 '?': 0.19,
 'daniel': 0.19,
 'stanley': 0.19,
 'sound': 0.18,
 'universal': 0.17,
 'roger': 0.16,
 'nelson': 0.16,
 'simon': 0.15,
 'event': 0.14,
 'alien': 0.13,
 'arnold': 0.13,
 'harry': 0.12,
 'justin': 0.12,
 'fuzzy': 0.12,
 'alex': 0.11,
 'baby': 0.1,
 'god': 0.08,
 'graham': 0.08,
 'chorus': 0.08,
 'wave': 0.07,
 'doll': 0.07,
 'ariel': 0.06,
 'carter': 0.05,
 'noah': 0.05,
 'international': 0.04,
 'mia': 0.03,
 'robot': 0.02}

In [None]:
def add_sparse_vector(example):
    """
    Return ``{"sparse": vec}`` for one dataset row, where ``vec`` is the
    ``compute_vector`` encoding of the row's title followed by its text.
    """
    # Join with a space: plain concatenation glued the last word of the title to
    # the first word of the text (e.g. "FooBar"), changing how both are tokenized.
    text = f"{example['title']} {example['text']}"
    # Inference only: avoid building an autograd graph for every row.
    with torch.no_grad():
        vec, _ = compute_vector(text, tokenizer, model)
    return {"sparse": vec}


# NOTE(review): rows are encoded one at a time and each "sparse" value is stored
# densely (one weight per vocabulary id), so this is slow and produces a large
# Arrow cache over the full split. Consider storing only non-zero ids/weights.
new_dataset = dataset.map(add_sparse_vector)

Map:   0%|          | 0/1000000 [00:00<?, ? examples/s]