In [1]:
import torch
# torch.set_num_threads(1)
from datasets import load_dataset
from transformers import AutoModelForMaskedLM, AutoTokenizer

In [2]:
dataset = load_dataset("KShivendu/dbpedia-entities-openai-1M", split="train")

Resolving data files:   0%|          | 0/26 [00:00<?, ?it/s]

# Single Encoder for Both Query and Document

In [44]:
model_id = "naver/efficient-splade-VI-BT-large-doc"

tokenizer = AutoTokenizer.from_pretrained(model_id)
model = AutoModelForMaskedLM.from_pretrained(model_id)

tokenizer_config.json:   0%|          | 0.00/449 [00:00<?, ?B/s]

vocab.txt:   0%|          | 0.00/232k [00:00<?, ?B/s]

tokenizer.json:   0%|          | 0.00/466k [00:00<?, ?B/s]

special_tokens_map.json:   0%|          | 0.00/112 [00:00<?, ?B/s]

config.json:   0%|          | 0.00/620 [00:00<?, ?B/s]

pytorch_model.bin:   0%|          | 0.00/268M [00:00<?, ?B/s]

In [45]:
def compute_vector(text, tokenizer, model, max_token_length: int = 512):
    """
    Computes a vector from logits and attention mask using ReLU, log, and max operations.
    If RuntimeError occurs due to token length, it truncates the number of tokens to 512.
    """
    # Tokenize the text and truncate if needed
    tokens = tokenizer(text, return_tensors="pt")
    if tokens.input_ids.size(1) > max_token_length:
        print(f"Truncating to 512 tokens for:\n{text}")
        tokens = tokenizer(text, return_tensors="pt", max_length=max_token_length, truncation=True)

    output = model(**tokens)
    logits, attention_mask = output.logits, tokens.attention_mask
    relu_log = torch.log(1 + torch.relu(logits))
    weighted_log = relu_log * attention_mask.unsqueeze(-1)
    max_val, _ = torch.max(weighted_log, dim=1)
    vec = max_val.squeeze()

    return vec, tokens

vec, tokens = compute_vector("""Hello World!""", tokenizer=tokenizer, model=model)

In [46]:
def extract_and_map_sparse_vector(vector, tokenizer):
    """
    Extracts non-zero elements from a given vector and maps these elements to their human-readable tokens using a tokenizer. The function creates and returns a sorted dictionary where keys are the tokens corresponding to non-zero elements in the vector, and values are the weights of these elements, sorted in descending order of weights.
    This function is useful in NLP tasks where you need to understand the significance of different tokens based on a model's output vector. It first identifies non-zero values in the vector, maps them to tokens, and sorts them by weight for better interpretability.
    Args:
    vector (torch.Tensor): A PyTorch tensor from which to extract non-zero elements.
    tokenizer: The tokenizer used for tokenization in the model, providing the mapping from tokens to indices.
    Returns:
    dict: A sorted dictionary mapping human-readable tokens to their corresponding non-zero weights.
    """

    # Extract indices and values of non-zero elements in the vector
    cols = vector.nonzero().squeeze().cpu().tolist()
    weights = vector[cols].cpu().tolist()

    # Map indices to tokens and create a dictionary
    idx2token = {idx: token for token, idx in tokenizer.get_vocab().items()}
    token_weight_dict = {
        idx2token[idx]: round(weight, 2) for idx, weight in zip(cols, weights)
    }

    # Sort the dictionary by weights in descending order
    sorted_token_weight_dict = {
        k: v
        for k, v in sorted(
            token_weight_dict.items(), key=lambda item: item[1], reverse=True
        )
    }

    return sorted_token_weight_dict

extract_and_map_sparse_vector(vector=vec, tokenizer=tokenizer)

{'hello': 2.21,
 'world': 1.25,
 '!': 1.1,
 'greeting': 0.9,
 'welcome': 0.84,
 'cheers': 0.64,
 'merry': 0.6,
 'goodbye': 0.56,
 'lyrics': 0.44,
 'eclipse': 0.43,
 'monde': 0.27,
 'congratulations': 0.23,
 'translation': 0.21,
 'paris': 0.2,
 'you': 0.17,
 'hepburn': 0.09,
 'japan': 0.08,
 'please': 0.08,
 'thanks': 0.08,
 'fifa': 0.05,
 'wow': 0.05,
 'exclaimed': 0.04,
 'collins': 0.03,
 'many': 0.02,
 'moon': 0.02,
 'tokyo': 0.02,
 'worlds': 0.01,
 'disney': 0.0,
 'physicist': 0.0}

In [47]:
# new_dataset = dataset.map(lambda example: {"embed_text": example['title'] + example['text']}, num_proc=4)

In [52]:
ds_100K = new_dataset.train_test_split(test_size=0.1, seed=37)['test']
ds_100K

Dataset({
    features: ['_id', 'title', 'text', 'openai', 'embed_text'],
    num_rows: 100000
})

In [53]:
ds = ds_100K.map(lambda example: {"vec": compute_vector(example['embed_text'], tokenizer=tokenizer, model=model)[0]})

Map:   0%|          | 0/100000 [00:00<?, ? examples/s]

Truncating to 512 tokens for:
Thrill Jockey discographyThis is the discography of the record label Thrill Jockeythrill 378 - Liturgy - The Ark Workthrill 339 - Wrekmeister Harmonies - "You've Always Meant So Much to Me" (2013)thrill 315 - Matmos - "The Ganzfeld EP" (2012)thrill 313 - Guardian Alien - "See the World Given to a One Love Entity" (2012)thrill 286 - High Places - "Original Colors" (2011)thrill 285 - Future Islands - "On the Water" (2011)thrill 284 - Future Islands - "Before the Bridge" b/w "Find Love" 7" (2011)thrill 283 - Luke Roberts - "Big Bells & Dime Songs" (2011)thrill 282 - Tunnels - "The Blackout" (2011)thrill 280 - Barn Owl - "Lost in the Glare" (2011)thrill 279 - Wooden Shjips - "West" (2011)thrill 278 - The Sea and Cake - "The Moonlight Butterfly" (2011)thrill 277 - Barn Owl - "Shadowland" (2011)thrill 276 - White Hills - "H-p1" (2011)thrill 275 - Pontiak - "Comecrudos" (2011)thrill 274 - Mountains - "Air Museum" (2011)thrill 273 - Liturgy - "Aesthethica" (2011)t

In [60]:
# ds = ds.rename_column("vec", "splade")
# ds = ds.remove_columns("embed_text")

In [62]:
len(ds)

100000

In [61]:
# !huggingface-cli login --token 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
ds.push_to_hub("nirantk/dbpedia-entities-efficient-splade-100K")

Pushing dataset shards to the dataset hub:   0%|          | 0/26 [00:00<?, ?it/s]

Creating parquet from Arrow format:   0%|          | 0/4 [00:00<?, ?ba/s]

Creating parquet from Arrow format:   0%|          | 0/4 [00:00<?, ?ba/s]

Creating parquet from Arrow format:   0%|          | 0/4 [00:00<?, ?ba/s]

Creating parquet from Arrow format:   0%|          | 0/4 [00:00<?, ?ba/s]

Creating parquet from Arrow format:   0%|          | 0/4 [00:00<?, ?ba/s]

Creating parquet from Arrow format:   0%|          | 0/4 [00:00<?, ?ba/s]

Creating parquet from Arrow format:   0%|          | 0/4 [00:00<?, ?ba/s]

Creating parquet from Arrow format:   0%|          | 0/4 [00:00<?, ?ba/s]

Creating parquet from Arrow format:   0%|          | 0/4 [00:00<?, ?ba/s]

Creating parquet from Arrow format:   0%|          | 0/4 [00:00<?, ?ba/s]

Creating parquet from Arrow format:   0%|          | 0/4 [00:00<?, ?ba/s]

Creating parquet from Arrow format:   0%|          | 0/4 [00:00<?, ?ba/s]

Creating parquet from Arrow format:   0%|          | 0/4 [00:00<?, ?ba/s]

Creating parquet from Arrow format:   0%|          | 0/4 [00:00<?, ?ba/s]

Creating parquet from Arrow format:   0%|          | 0/4 [00:00<?, ?ba/s]

Creating parquet from Arrow format:   0%|          | 0/4 [00:00<?, ?ba/s]

Creating parquet from Arrow format:   0%|          | 0/4 [00:00<?, ?ba/s]

Creating parquet from Arrow format:   0%|          | 0/4 [00:00<?, ?ba/s]

Creating parquet from Arrow format:   0%|          | 0/4 [00:00<?, ?ba/s]

Creating parquet from Arrow format:   0%|          | 0/4 [00:00<?, ?ba/s]

Creating parquet from Arrow format:   0%|          | 0/4 [00:00<?, ?ba/s]

Creating parquet from Arrow format:   0%|          | 0/4 [00:00<?, ?ba/s]

Creating parquet from Arrow format:   0%|          | 0/4 [00:00<?, ?ba/s]

Creating parquet from Arrow format:   0%|          | 0/4 [00:00<?, ?ba/s]

Creating parquet from Arrow format:   0%|          | 0/4 [00:00<?, ?ba/s]

Creating parquet from Arrow format:   0%|          | 0/4 [00:00<?, ?ba/s]

In [42]:
import numpy as np
vec = np.array(ds[0]['vec'])

sparse_indices = vec.nonzero()
sparse_values = vec[sparse_indices]
sparse_indices, sparse_values

((array([ 1059,  2057,  2103,  2198,  2220,  2221,  2237,  2284,  2304,
          2313,  2314,  2316,  2318,  2328,  2352,  2631,  2688,  2783,
          2796,  2835,  2912,  3007,  3060,  3146,  3295,  3516,  3586,
          3636,  3683,  3741,  3842,  3876,  4075,  4135,  4213,  5111,
          5194,  5858,  5917,  6946,  7734,  7929,  8124,  8129,  8545,
          8807,  9424, 10930, 11333, 11382, 11905, 12155, 12849, 13796,
         15068, 15536, 24053, 24185, 28909]),),
 array([0.2877084 , 2.20456076, 0.88390034, 0.52624637, 0.0764193 ,
        1.86921799, 0.74949855, 0.5232088 , 0.47974336, 1.41573536,
        0.04795758, 0.07491256, 0.47559032, 0.27835336, 0.24559918,
        1.29121816, 0.05531136, 0.29036081, 0.32312509, 0.49300152,
        2.35762143, 1.04617584, 0.01916639, 0.04127081, 1.50662303,
        0.67788154, 1.03206241, 0.02812318, 0.01505031, 0.11226109,
        0.83307409, 0.42092898, 0.10255694, 0.24148026, 0.10010952,
        0.37311473, 0.92574584, 1.99337435, 