In [None]:
!pip install transformers
!pip install datasets
!pip install --upgrade accelerate
!pip install rank-eval
!pip install ranx
!pip install fsspec==2023.9.2

Collecting accelerate
  Downloading accelerate-1.7.0-py3-none-any.whl.metadata (19 kB)
Collecting nvidia-cuda-nvrtc-cu12==12.4.127 (from torch>=2.0.0->accelerate)
  Downloading nvidia_cuda_nvrtc_cu12-12.4.127-py3-none-manylinux2014_x86_64.whl.metadata (1.5 kB)
Collecting nvidia-cuda-runtime-cu12==12.4.127 (from torch>=2.0.0->accelerate)
  Downloading nvidia_cuda_runtime_cu12-12.4.127-py3-none-manylinux2014_x86_64.whl.metadata (1.5 kB)
Collecting nvidia-cuda-cupti-cu12==12.4.127 (from torch>=2.0.0->accelerate)
  Downloading nvidia_cuda_cupti_cu12-12.4.127-py3-none-manylinux2014_x86_64.whl.metadata (1.6 kB)
Collecting nvidia-cudnn-cu12==9.1.0.70 (from torch>=2.0.0->accelerate)
  Downloading nvidia_cudnn_cu12-9.1.0.70-py3-none-manylinux2014_x86_64.whl.metadata (1.6 kB)
Collecting nvidia-cublas-cu12==12.4.5.8 (from torch>=2.0.0->accelerate)
  Downloading nvidia_cublas_cu12-12.4.5.8-py3-none-manylinux2014_x86_64.whl.metadata (1.5 kB)
Collecting nvidia-cufft-cu12==11.2.1.3 (from torch>=2.0.0

In [None]:
# Mount Google Drive so the data, codebook and checkpoint paths under
# /content/drive/MyDrive/685 used later in this notebook resolve.
from google.colab import drive
drive.mount('/content/drive')

Mounted at /content/drive


In [None]:
from transformers import AutoModel, AutoTokenizer, Trainer, TrainingArguments, DataCollatorWithPadding
from datasets import Dataset
import torch
import torch.nn as nn
import torch.nn.functional as F
from tqdm.autonotebook import trange
import numpy as np
import json
import random
from torch.utils.data import DataLoader, RandomSampler
from tqdm import tqdm
from ranx import Qrels, Run, evaluate
import wandb

In [None]:
import os
import gc
from collections import defaultdict
import heapq
from scipy.sparse import csr_matrix
import pandas as pd

In [None]:
# from utils import tsv_to_dict_multiple, tsv_to_dict_unqiue, split_embedding_into_chunks
import csv

# Read the tsv file as a dictionary (each key has a single value)
def tsv_to_dict_unqiue(file_path, keys = (0, 1), quoting = csv.QUOTE_NONE):
    """Read a TSV file into a ``{row[keys[0]]: row[keys[1]]}`` dictionary.

    When a key appears on several rows, the last row wins.

    Args:
        file_path: path to a UTF-8 tab-separated file.
        keys: ``(key_column, value_column)`` indices into each row.
        quoting: ``csv`` quoting mode. Defaults to ``csv.QUOTE_NONE`` because
            raw text dumps (e.g. MS MARCO) contain literal ``"`` characters.
            With the csv default (``QUOTE_MINIMAL``), a field that starts with
            a quote is parsed as a quoted field. It then swallows tabs and
            newlines up to the next quote, corrupting or merging rows.

    Returns:
        dict mapping key-column strings to value-column strings.
    """
    key_col, value_col = keys[0], keys[1]
    # newline="" is what the csv module expects from the file object it wraps.
    with open(file_path, mode = "r", encoding = "utf-8", newline = "") as file:
        reader = csv.reader(file, delimiter = "\t", quoting = quoting)
        return {row[key_col]: row[value_col] for row in reader}


# Read the tsv file as a dictionary (each key has a multiple value)
def tsv_to_dict_multiple(file_path, keys = (0, 2), quoting = csv.QUOTE_NONE):
    """Read a TSV file into ``{row[keys[0]]: [row[keys[1]], ...]}``.

    Values for a repeated key are collected in file order.

    Args:
        file_path: path to a UTF-8 tab-separated file.
        keys: ``(key_column, value_column)`` indices into each row.
        quoting: ``csv`` quoting mode. Defaults to ``csv.QUOTE_NONE`` so that
            literal ``"`` characters are kept as text instead of starting a
            quoted field that can swallow tabs and newlines.

    Returns:
        dict mapping each key-column string to the list of its value-column strings.
    """
    key_col, value_col = keys[0], keys[1]
    data = {}
    # newline="" is what the csv module expects from the file object it wraps.
    with open(file_path, mode = "r", encoding = "utf-8", newline = "") as file:
        for row in csv.reader(file, delimiter = "\t", quoting = quoting):
            data.setdefault(row[key_col], []).append(row[value_col])
    return data

# Split the embedding to k chunks
def split_embedding_into_chunks(embeddings, k_chunks):
    """Split each row of a ``(bsz, dim)`` tensor into ``k_chunks`` equal sub-vectors.

    Input row ``i`` becomes output rows ``i*k_chunks`` .. ``i*k_chunks + k_chunks - 1``.

    Returns:
        Tensor of shape ``(bsz * k_chunks, dim // k_chunks)``.

    Raises:
        ValueError: if ``k_chunks`` is not positive or does not divide ``dim``.
    """
    bsz, dim = embeddings.shape
    # An explicit check (unlike `assert`) is not stripped when Python runs with -O.
    if k_chunks <= 0 or dim % k_chunks != 0:
        raise ValueError(f"embedding dim {dim} cannot be split into {k_chunks} equal chunks")
    # reshape (unlike view) also accepts non-contiguous inputs such as transposed tensors.
    return embeddings.reshape(bsz * k_chunks, dim // k_chunks)

In [None]:
# Optional - To clear GPU memory
gc.collect()
torch.cuda.empty_cache()

In [None]:
!nvidia-smi

Fri May 16 04:30:51 2025       
+-----------------------------------------------------------------------------------------+
| NVIDIA-SMI 550.54.15              Driver Version: 550.54.15      CUDA Version: 12.4     |
|-----------------------------------------+------------------------+----------------------+
| GPU  Name                 Persistence-M | Bus-Id          Disp.A | Volatile Uncorr. ECC |
| Fan  Temp   Perf          Pwr:Usage/Cap |           Memory-Usage | GPU-Util  Compute M. |
|                                         |                        |               MIG M. |
|   0  Tesla T4                       Off |   00000000:00:04.0 Off |                    0 |
| N/A   43C    P8              9W /   70W |       2MiB /  15360MiB |      0%      Default |
|                                         |                        |                  N/A |
+-----------------------------------------+------------------------+----------------------+
                                                

### Load raw training data

In [None]:
# base-model
model_name = "facebook/contriever-msmarco"
tokenizer = AutoTokenizer.from_pretrained(model_name)

The secret `HF_TOKEN` does not exist in your Colab secrets.
To authenticate with the Hugging Face Hub, create a token in your settings tab (https://huggingface.co/settings/tokens), set it as secret in your Google Colab and restart your session.
You will be able to reuse this secret in all of your notebooks.
Please note that authentication is recommended but still optional to access public models or datasets.


tokenizer_config.json:   0%|          | 0.00/321 [00:00<?, ?B/s]

vocab.txt:   0%|          | 0.00/232k [00:00<?, ?B/s]

tokenizer.json:   0%|          | 0.00/466k [00:00<?, ?B/s]

special_tokens_map.json:   0%|          | 0.00/112 [00:00<?, ?B/s]

In [None]:
# Raw training text: passage id -> passage text, training query id -> query text.
data_dir = '/content/drive/MyDrive/685/data'
passages = tsv_to_dict_unqiue(os.path.join(data_dir, "collection.tsv"))
queries = tsv_to_dict_unqiue(os.path.join(data_dir, "queries.train.tsv"))

### Create training triplets (BM25 hard negatives)

In [None]:
def create_dataset(address, n_negatives=2,cache_dir='hf',sample_count=1000):
    """Build a HF ``Dataset`` of (query, positive, negatives) training triplets.

    Each JSONL line is read for ``qid``, a ``pos`` list of passage ids and
    ``neg['bm25']``, a list of BM25 negative passage ids. This schema is
    inferred from the fields read below; confirm it against the data file.
    Texts are looked up in the module-level ``queries`` / ``passages`` dicts.

    Args:
        address: path to the JSONL file.
        n_negatives: BM25 negatives sampled per query. ``random.sample``
            raises if a line has fewer.
        cache_dir: cache directory passed to ``Dataset.from_generator``.
        sample_count: maximum number of lines to read; ``None`` reads all.

    Returns:
        ``datasets.Dataset`` with string columns ``query`` / ``positive`` and a
        list-of-strings column ``negatives``.
    """
    def train_gen():
        with open(address, 'r', encoding='utf-8') as f:
            for count, line in enumerate(f):
                # Checked before processing so exactly `sample_count` rows are produced.
                if sample_count is not None and count >= sample_count:
                    return
                data_sample = json.loads(line)
                neg_ids = random.sample(data_sample['neg']['bm25'], n_negatives)
                yield {
                    'query': queries[str(data_sample['qid'])],
                    'positive': passages[str(random.choice(data_sample['pos']))],
                    # Keep this a flat list of strings: the collator tokenizes each
                    # negative separately. A nested list made the tokenizer encode
                    # the negatives together as one text pair.
                    'negatives': [passages[str(neg_id)] for neg_id in neg_ids],
                }
    return Dataset.from_generator(train_gen, cache_dir = cache_dir)

In [None]:
# BM25 hard-negative triplets: 2 negatives per query, line cap of 1,000,000 (sample_count).
dataset = create_dataset(
    os.path.join('/content/drive/MyDrive/685/data', 'msmarco-hard-negatives-bm25_1k.jsonl'),
    n_negatives=2,
    cache_dir='hf',
    sample_count=1000000,
)

Generating train split: 0 examples [00:00, ? examples/s]

  table = cls._concat_blocks(blocks, axis=0)


### Create training triplets (BERT based hard negatives)

In [None]:
# def create_dataset(df, n_negatives=1,cache_dir='hf',sample_count=1000):
#     def train_gen():
#       samp_df = df.sample(n=sample_count, random_state=42)
#       for index, row in samp_df.iterrows():
#             yield {'query': queries[str(int(row[2]))],'positive':passages[str(int(row[3]))],'negatives':passages[str(int(row[4]))]}
#     return Dataset.from_generator(train_gen, cache_dir = cache_dir)

In [None]:
# df = pd.read_csv(os.path.join('/content/drive/MyDrive/685/data','bert_cat_ensemble_msmarcopassage_train_scores_ids.tsv'),sep='\t',header=None, index_col=None)

In [None]:
# dataset = create_dataset(df, 1, 'hf',5000000)

### Train Val Split

In [None]:
# Hold out 10% of the triplets for validation. The fixed seed makes the split,
# and therefore the reported validation loss, reproducible across runs.
split_dataset = dataset.train_test_split(test_size=0.1, seed=42)

train_dataset = split_dataset['train']
eval_dataset = split_dataset['test']

### Custom data collator

In [None]:
from typing import Any, Dict, List, Union

class CustomDataCollatorWithPadding(DataCollatorWithPadding):
    """Collate (query, positive, negatives) features into separately tokenized batches.

    Uses the ``tokenizer``, ``padding`` and ``max_length`` fields inherited from
    ``DataCollatorWithPadding``. The returned dict always holds
    ``tokenized_queries`` and ``tokenized_positives`` (one row per feature).
    When any feature carries negatives it also holds ``tokenized_negatives``:
    one row per negative text, with all features' negatives concatenated in order.
    """

    def _tokenize(self, texts: List[str]):
        """Tokenize a list of strings with this collator's padding/truncation settings."""
        return self.tokenizer(
                texts,
                max_length=self.max_length,
                padding=self.padding,
                truncation=True,
                return_tensors="pt",
                add_special_tokens=True)

    def __call__(self, features: List[Dict[str, Any]]) -> Dict[str, torch.Tensor]:
        query_texts = [feature['query'] for feature in features]
        pos_texts = [feature['positive'] for feature in features]

        neg_texts = []
        for feature in features:
            negatives = feature.get('negatives')
            if negatives is None:
                continue
            # A bare string is a single negative; iterating it would split it into characters.
            if isinstance(negatives, str):
                neg_texts.append(negatives)
            else:
                neg_texts.extend(negatives)

        tokenized_texts = {
            'tokenized_queries' : self._tokenize(query_texts),
            'tokenized_positives' : self._tokenize(pos_texts),
        }
        if neg_texts:
            tokenized_texts['tokenized_negatives'] = self._tokenize(neg_texts)

        return tokenized_texts

custom_data_collator = CustomDataCollatorWithPadding(
    tokenizer=tokenizer,
    padding='longest',
    max_length=tokenizer.model_max_length
)

In [None]:
out_train = custom_data_collator([train_dataset[i] for i in range(8)])
print(out_train['tokenized_queries']['input_ids'].shape)
print(out_train['tokenized_positives']['input_ids'].shape)
print(out_train['tokenized_negatives']['input_ids'].shape)

torch.Size([8, 10])
torch.Size([8, 119])
torch.Size([8, 282])


In [None]:
out_eval = custom_data_collator([eval_dataset[i] for i in range(2)])
print(out_eval['tokenized_queries']['input_ids'].shape)
print(out_eval['tokenized_positives']['input_ids'].shape)

torch.Size([2, 10])
torch.Size([2, 87])


### Contrastive loss

In [None]:
def get_contriever_emb(model_output, attention_mask):
    """Mean-pool token states over non-padding positions.

    Args:
        model_output: backbone output indexable by ``"last_hidden_state"``;
            assumed ``(batch, seq_len, hidden)``.
        attention_mask: ``(batch, seq_len)`` mask, non-zero for real tokens.

    Returns:
        ``(batch, hidden)`` tensor: the sum of unmasked token states divided by
        the mask sum for each row.
    """
    token_states = model_output["last_hidden_state"]
    keep = attention_mask[..., None].bool()
    summed = token_states.masked_fill(~keep, 0.0).sum(dim=1)
    n_tokens = attention_mask.sum(dim=1)[..., None]
    return summed / n_tokens

def contrastive_loss(query_embs,pos_embs,neg_embs,temp=0.05):
    """InfoNCE loss with in-batch negatives and optional shared hard negatives.

    Query ``i`` is scored against every positive and every negative. Its own
    positive (column ``i``) is the target class.

    Args:
        query_embs: ``(B, D)`` query representations.
        pos_embs: ``(B, D)`` positives, row-aligned with the queries.
        neg_embs: ``(N, D)`` extra negatives scored against every query, or ``None``.
        temp: temperature dividing the dot-product scores.

    Returns:
        Scalar mean cross-entropy loss.
    """
    candidates = pos_embs if neg_embs is None else torch.cat((pos_embs, neg_embs), dim=0)
    similarities = (query_embs @ candidates.T) / temp
    # Build the targets on the scores' device (previously hard-coded to 'cuda').
    targets = torch.arange(query_embs.shape[0], device=similarities.device)
    return F.cross_entropy(similarities, targets)




## Model

### Vector Quantization module

In [None]:
class Quantize(nn.Module):
    """EMA vector quantizer over a codebook of ``n_embed`` codewords of size ``dim``.

    Inputs are flattened to ``(-1, dim)`` rows. Each row is replaced by its
    nearest codeword under squared Euclidean distance. The codebook is stored
    in buffers, not parameters, so gradients never update it. When
    ``forward(..., training=True)``, it is updated by exponential moving
    averages and Laplace-smoothed counts instead. Codewords whose EMA usage
    falls below 0.1 are then re-seeded from random rows of the current batch.

    Buffers:
        embed: ``(dim, n_embed)`` codebook, one codeword per column.
        cluster_size: EMA of the number of rows assigned to each codeword.
        embed_avg: EMA of the sum of rows assigned to each codeword.

    Args:
        dim: size of each vector to quantize.
        n_embed: number of codewords.
        decay: EMA decay factor.
        eps: smoothing constant used when normalizing cluster sizes.
        random_initialization: if True, draw the codebook from a standard
            normal; otherwise load ``codebook_vectors.pt`` from Google Drive.
    """
    def __init__(self, dim, n_embed, decay=0.99, eps=1e-5, random_initialization=False):
        super().__init__()
        self.dim = dim
        self.n_embed = n_embed
        self.decay = decay
        self.eps = eps

        # Initialize with pretrained codebook vectors for better training (the author's choice over random init)
        if random_initialization:
          embed = torch.randn(dim, n_embed)
        else:
          # NOTE(review): assumes the saved tensor is (n_embed, dim) so that .T yields (dim, n_embed); confirm against the script that wrote it.
          embed = torch.load(os.path.join('/content/drive/MyDrive/685/code','codebook_vectors.pt')).T

        self.register_buffer("embed", embed)
        # Counts start at 1, so no codeword is below the 0.1 dead threshold at the start,
        # and embed_avg / cluster_size initially reproduces embed.
        self.register_buffer("cluster_size", torch.ones(n_embed))
        self.register_buffer("embed_avg", embed.clone())

    def forward(self, input,training=True):
        """Quantize ``input`` and, if ``training``, update the codebook.

        Args:
            input: tensor whose last dimension is ``dim``.
            training: if True, apply this batch's EMA update and dead-code re-seeding.

        Returns:
            quantize: same shape as ``input``. The forward value is the selected
                codeword; gradients pass straight through to ``input``. The
                codewords are looked up before this batch's EMA update.
            diff: mean squared distance between ``input`` and its detached
                codeword (commitment loss).
            embed_ind: selected codeword indices, shape ``input.shape[:-1]``.
        """
        flatten = input.reshape(-1, self.dim)
        # Squared Euclidean distance ||x||^2 - 2 x.e + ||e||^2 -> (rows, n_embed)
        dist = (
            flatten.pow(2).sum(1, keepdim=True)
            - 2 * flatten @ self.embed
            + self.embed.pow(2).sum(0, keepdim=True)
        )
        _, embed_ind = (-dist).max(1)
        embed_onehot = F.one_hot(embed_ind, self.n_embed).type(flatten.dtype)
        embed_ind = embed_ind.view(*input.shape[:-1])
        quantize = self.embed_code(embed_ind)

        if training:
            # Per-codeword assignment counts and sums of assigned rows for this batch.
            embed_onehot_sum = embed_onehot.sum(0)
            embed_sum = flatten.transpose(0, 1) @ embed_onehot

            # EMA updates written through .data so they stay outside autograd.
            self.cluster_size.data.mul_(self.decay).add_(
                embed_onehot_sum, alpha=1 - self.decay
            )
            self.embed_avg.data.mul_(self.decay).add_(embed_sum, alpha=1 - self.decay)

            # Laplace-smoothed counts (total preserved) avoid division by zero for unused codewords.
            n = self.cluster_size.sum()
            cluster_size = (self.cluster_size + self.eps) / (n + self.n_embed * self.eps) * n
            embed_normalized = self.embed_avg / cluster_size.unsqueeze(0)
            self.embed.data.copy_(embed_normalized)

            # --- Dead code reinitialization ---
            # Codewords whose EMA usage fell below 0.1 are replaced by random rows of
            # this batch, and their counts are reset to 1.
            with torch.no_grad():
                dead_mask = self.cluster_size < 0.1
                num_dead = dead_mask.sum().item()
                if num_dead > 0 and flatten.size(0) > 0:
                    random_indices = torch.randint(0, flatten.size(0), (num_dead,))
                    self.embed[:, dead_mask] = flatten[random_indices].t()
                    self.embed_avg[:, dead_mask] = flatten[random_indices].t()
                    self.cluster_size[dead_mask] = 1.0  # or a small positive value

        diff = (quantize.detach() - input).pow(2).mean()
        # Straight-through estimator: forward value is the codeword, gradient is identity w.r.t. input.
        quantize = input + (quantize - input).detach()

        return quantize, diff, embed_ind

    def embed_code(self, embed_id):
        """Look up codeword vectors for integer indices (appends a trailing ``dim`` axis)."""
        return F.embedding(embed_id, self.embed.transpose(0, 1))

### Sparse Layer Module

In [None]:
class Sparse_Layer(nn.Module):
    """Project a dense vector to a non-negative, vocabulary-sized weight vector.

    ``weights = min(log(1 + relu(W x + b)), max_weight)``. The ReLU zeroes
    negative activations, the log dampens large ones, and the cap bounds any
    single weight.

    Args:
        d: input (dense) dimension.
        D: output (sparse vocabulary) dimension.
        max_weight: upper bound applied to every output weight (default 10).
    """

    def __init__(self,d,D,max_weight=10):
        super().__init__()

        self.max_weight = max_weight
        self.up_projection = nn.Linear(d,D)
        # Xavier-uniform weights with a zero bias.
        nn.init.xavier_uniform_(self.up_projection.weight)
        nn.init.zeros_(self.up_projection.bias)

    def forward(self,embeds):
        S = self.up_projection(embeds)
        # log1p stays accurate for tiny activations, where log(1 + x) rounds 1 + x
        # and distorts the small weights that later min-weight thresholds compare against.
        S = torch.log1p(torch.relu(S))
        return torch.clamp(S, max=self.max_weight)

### LSR_VQ - End-to-End Pipeline

In [None]:
class LSR_VQ(nn.Module):
    """Learned sparse retrieval model: backbone -> optional VQ -> sparse projection.

    Dense vectors are mean-pooled backbone states. When
    ``config['use_quantization']`` is set, each dense vector is split into
    ``num_chunks`` sub-vectors, which are vector-quantized and reassembled.
    ``Sparse_Layer`` then maps the result to a ``D``-dimensional weight vector.

    Args:
        model_name: Hugging Face checkpoint for the backbone encoder.
        emb_dim: size of each quantized chunk (dense size // num_chunks).
        num_clusters: number of codebook entries.
        num_chunks: chunks per dense vector.
        d: dense embedding size (sparse-layer input).
        D: sparse output size.
        config: dict with ``freeze_contriver``, ``use_quantization`` and ``freeze_vq`` flags.
    """
    def __init__(self, model_name, emb_dim, num_clusters, num_chunks, d, D, config):
        super().__init__()
        self.model = AutoModel.from_pretrained(model_name,device_map='auto')
        if config['freeze_contriver']:
            for param in self.model.parameters():
                param.requires_grad = False
        self.quantizer = Quantize(emb_dim,num_clusters)
        self.sparse_layer = Sparse_Layer(d,D)
        self.num_chunks = num_chunks
        self.config = config

    def _encode(self, tokenized):
        """Run the backbone on one tokenized batch and mean-pool to one dense vector per row."""
        return get_contriever_emb(self.model(**tokenized), tokenized['attention_mask'])

    def _quantize(self, dense, update_codebook):
        """Chunk, vector-quantize and reassemble ``dense``; returns ``(quantized, commitment_loss)``."""
        chunks = split_embedding_into_chunks(dense, self.num_chunks)
        quantized_chunks, commitment_loss, _ = self.quantizer(chunks, update_codebook)
        # Undo the chunking: each group of num_chunks consecutive rows is one dense row.
        return quantized_chunks.reshape(dense.shape[0], -1), commitment_loss

    def forward(self, inputs):
        """Encode a collated batch.

        Args:
            inputs: dict with ``tokenized_queries``, ``tokenized_positives`` and,
                optionally, ``tokenized_negatives``.

        Returns:
            ``(query_dense, query_sparse, pos_dense, pos_sparse, neg_dense,
            neg_sparse, codebook_loss)``. The ``neg_*`` entries are ``None`` when
            the batch has no negatives. ``codebook_loss`` is always a tensor
            (zero when quantization is disabled).
        """
        query_dense = self._encode(inputs['tokenized_queries'])
        pos_dense = self._encode(inputs['tokenized_positives'])
        neg_dense = self._encode(inputs['tokenized_negatives']) if 'tokenized_negatives' in inputs else None

        if self.config['use_quantization']:
            # Quantize everything in one call so the codebook gets one EMA update per batch.
            parts = [query_dense, pos_dense] if neg_dense is None else [query_dense, pos_dense, neg_dense]
            quantized, codebook_loss = self._quantize(torch.cat(parts, dim=0), not self.config['freeze_vq'])
            n_query, n_pos = query_dense.shape[0], pos_dense.shape[0]
            query_q = quantized[:n_query]
            pos_q = quantized[n_query:n_query + n_pos]
            neg_q = None if neg_dense is None else quantized[n_query + n_pos:]
        else:
            # A tensor rather than the int 0, so callers can add it to the loss and call .item().
            codebook_loss = query_dense.new_zeros(())
            query_q, pos_q, neg_q = query_dense, pos_dense, neg_dense

        query_sparse = self.sparse_layer(query_q)
        pos_sparse = self.sparse_layer(pos_q)
        neg_sparse = None if neg_q is None else self.sparse_layer(neg_q)

        return query_dense,query_sparse,pos_dense,pos_sparse,neg_dense,neg_sparse,codebook_loss

    def get_sparse_rep(self,inputs):
        """Sparse vectors for one tokenized batch (inference path; the codebook is never updated)."""
        dense = self._encode(inputs)
        if self.config['use_quantization']:
            dense, _ = self._quantize(dense, False)
        return self.sparse_layer(dense)

In [None]:
# Model design choices (passed to LSR_VQ when the model is constructed)

num_chunks = 16          # each dense embedding is split into this many sub-vectors before VQ
vocab_size = 8192        # sparse output dimension D (width of the inverted index)
cluster_size = 6000      # number of codebook entries (LSR_VQ num_clusters); unrelated to Quantize's cluster_size buffer
dense_embed_size = 768   # dense size d; must be divisible by num_chunks and presumably equals the backbone hidden size

In [None]:
# Config setting (read by LSR_VQ)
#   use_quantization: route dense embeddings through the VQ codebook before the sparse layer
#   freeze_contriver: if True, backbone parameters get requires_grad=False
#   freeze_vq: if True, the codebook is not EMA-updated during training

model_config = {
    'use_quantization': True,
    'freeze_contriver': False,
    'freeze_vq': False
}

In [None]:
# Weights and biases reporting to monitor training and results

wandb.init(project = "LSR-VQ", config = model_config)

<IPython.core.display.Javascript object>

[34m[1mwandb[0m: Logging into wandb.ai. (Learn how to deploy a W&B server locally: https://wandb.me/wandb-server)
[34m[1mwandb[0m: You can find your API key in your browser here: https://wandb.ai/authorize?ref=models
wandb: Paste an API key from your profile and hit enter:

 ··········


[34m[1mwandb[0m: No netrc file found, creating one.
[34m[1mwandb[0m: Appending key for api.wandb.ai to your netrc file: /root/.netrc
[34m[1mwandb[0m: Currently logged in as: [33msensai99[0m ([33mteam-layout[0m) to [32mhttps://api.wandb.ai[0m. Use [1m`wandb login --relogin`[0m to force relogin


In [None]:
model = LSR_VQ(
    model_name,
    emb_dim=dense_embed_size // num_chunks,  # size of each quantized chunk
    num_clusters=cluster_size,
    num_chunks=num_chunks,
    d=dense_embed_size,
    D=vocab_size,
    config=model_config,
)

In [None]:
# To load trained weights
# This cell runs before training, so the fine-tuning below starts from these weights.
# NOTE(review): the checkpoint path says "NoVQ" while model_config sets use_quantization=True;
# confirm this pairing is intended.

from safetensors.torch import load_file
state_dict = load_file("/content/drive/MyDrive/685/checkpoints/NoVQ/checkpoint-35000/model.safetensors")
model.load_state_dict(state_dict)

<All keys matched successfully>

### Training

In [None]:
# Custom trainer
class LSRVQ_Trainer(Trainer):
    """HF ``Trainer`` whose loss combines contrastive, sparsity and VQ terms.

    ``trainer_config`` keys read here: ``query_lambda``, ``passage_lambda``,
    ``query_sparsity_loss`` (bool), ``codebook_loss`` (bool), ``verbose`` (bool).
    """

    def __init__(self, *args, trainer_config,**kwargs):
        super().__init__(*args, **kwargs)
        self.config = trainer_config
        self.query_reg_lambda = self.config['query_lambda']
        self.passage_reg_lambda = self.config['passage_lambda']

    @staticmethod
    def _flops_reg(sparse):
        """FLOPS regularizer: sum over dimensions of the squared mean absolute weight."""
        return torch.sum(torch.mean(torch.abs(sparse), dim=0) ** 2)

    def compute_loss(self, model, inputs, return_outputs=False,num_items_in_batch=None):
        # Use the model passed in by the Trainer; it may be a wrapped version of self.model.
        query_dense, query_sparse, pos_dense, pos_sparse, neg_dense, neg_sparse, codebook_loss = model(inputs)

        c_loss = contrastive_loss(query_sparse,pos_sparse,neg_sparse)
        loss = c_loss
        logs = {"contrastive_loss": c_loss.item()}

        # Query sparsity loss
        if self.config['query_sparsity_loss']:
            q_loss = self.query_reg_lambda * self._flops_reg(query_sparse)
            loss = loss + q_loss
            logs["query_sparsity_loss"] = q_loss.item()

        # Document sparsity loss (always applied)
        passage_sparse = pos_sparse if neg_sparse is None else torch.cat([pos_sparse,neg_sparse], dim=0)
        p_loss = self.passage_reg_lambda * self._flops_reg(passage_sparse)
        loss = loss + p_loss
        logs["passage_sparsity_loss"] = p_loss.item()

        # VQ codebook loss
        if self.config['codebook_loss']:
            loss = loss + codebook_loss
            # float() works for both a tensor and a plain number.
            logs["codebook_loss"] = float(codebook_loss)

        logs["total_loss"] = loss.item()

        if self.config['verbose']:
            for k, v in logs.items():
                print(f"{k}: {v:.4f}")

        # compute_loss also runs during evaluation; prefix those batches so they
        # do not land on the training-loss curves.
        if wandb.run is not None:
            if not model.training:
                logs = {f"eval_{k}": v for k, v in logs.items()}
            wandb.log(logs)

        return (loss,torch.zeros(1)) if return_outputs else loss


In [None]:
# Hugging Face training configuration. Evaluation and checkpointing both run every
# 1000 steps, so the checkpoint with the lowest eval_loss can be restored at the end.
training_args = TrainingArguments(
    output_dir="./weights",
    overwrite_output_dir=True,
    num_train_epochs=5,
    per_device_train_batch_size=64,
    per_device_eval_batch_size=64,
    warmup_steps=500,
    weight_decay=0.01,
    logging_dir="./logs",
    logging_steps=1000,
    eval_steps=1000,
    eval_strategy="steps",
    save_total_limit=50,
    save_steps=1000,
    load_best_model_at_end=True,
    greater_is_better=False,  # lower eval_loss is better
    learning_rate=1e-4,
    lr_scheduler_type = "linear",
    fp16=True,  # mixed-precision training
    report_to='none',  # built-in reporters off; LSRVQ_Trainer.compute_loss logs to wandb itself
    metric_for_best_model="eval_loss",
    remove_unused_columns=False,  # presumably so query/positive/negatives reach the custom collator
)

In [None]:
# Loss configuration read by LSRVQ_Trainer:
#   query_lambda / passage_lambda: weights of the query / passage sparsity regularizers
#   query_sparsity_loss: add the query regularizer (the passage regularizer is always added)
#   codebook_loss: add the VQ commitment loss returned by the model
#   verbose: print every loss term on each compute_loss call
trainer_config = {
    'query_lambda': 0.5,
    'passage_lambda': 0.5,
    'query_sparsity_loss': True,
    'codebook_loss': True,
    'verbose': False
}

In [None]:
trainer = LSRVQ_Trainer(
        trainer_config=trainer_config,
        model = model,
        args = training_args,
        data_collator = custom_data_collator,
        train_dataset = train_dataset,
        eval_dataset = eval_dataset,
        )

# NOTE(review): presumably set so the Trainer computes an eval loss even though batches
# carry no `labels` key; confirm against the installed transformers version.
trainer.can_return_loss=True

def custom_eval_dataloader(eval_set=None):
    """Replacement for ``trainer.get_eval_dataloader`` that visits ``eval_set`` in random order.

    Args:
        eval_set: dataset to evaluate; ``None`` falls back to ``trainer.eval_dataset``,
            matching the optional argument of the Trainer method it replaces.

    Returns:
        ``DataLoader`` using the eval batch size and the trainer's collator.
    """
    if eval_set is None:
        eval_set = trainer.eval_dataset
    return DataLoader(
        eval_set,
        batch_size=training_args.eval_batch_size,
        # Sample indices from the dataset actually being loaded; the global
        # eval_dataset was used here before, regardless of the argument.
        sampler=RandomSampler(eval_set),
        collate_fn=trainer.data_collator,
    )

trainer.get_eval_dataloader = custom_eval_dataloader

In [None]:
trainer.train()

Step,Training Loss,Validation Loss
1000,1.2653,1.427986
2000,0.7929,1.232941
3000,0.756,1.303543
4000,0.7412,0.856673
5000,0.6984,1.020975
6000,0.6905,0.78189
7000,0.6655,0.697751
8000,0.5752,0.716633
9000,0.5731,0.659197
10000,0.5439,0.666202


TrainOutput(global_step=35360, training_loss=0.4392120920155383, metrics={'train_runtime': 12094.8808, 'train_samples_per_second': 187.104, 'train_steps_per_second': 2.924, 'total_flos': 0.0, 'train_loss': 0.4392120920155383, 'epoch': 5.0})

### Inference

In [None]:
# Dev evaluation inputs: query text, relevance judgments, and the ids of the passages to index.
dev_data_dir = '/content/drive/MyDrive/685/data'
queries_dev = tsv_to_dict_unqiue(os.path.join(dev_data_dir, "queries.dev.small.tsv"))
# qrels rows: query id in column 0, relevant passage id in column 2
qrels_dev = tsv_to_dict_multiple(os.path.join(dev_data_dir, "qrels.dev.small.tsv"), keys = [0, 2])

with open('/content/drive/MyDrive/685/embeddings/dev/passage_ids.json') as f:
    inference_passage_ids = json.load(f)

In [None]:
# Compute sparse embeddings 

from datasets import load_dataset

def encode_texts(model, tokenizer, text_ds, batch_size, verbose=False):
    """Encode every ``text`` in ``text_ds`` with ``model.get_sparse_rep``.

    Args:
        model: module exposing ``get_sparse_rep(tokenized_batch)``.
        tokenizer: HF tokenizer for the model's backbone.
        text_ds: dataset with ``_id`` and ``text`` fields, processed in order.
        batch_size: texts per forward pass.
        verbose: print per-batch tokenization diagnostics.

    Returns:
        ``(embeddings, ids)``. ``embeddings`` stacks one row per text and stays
        on the encoding device; ``ids`` are the matching ``_id`` values.

    Raises:
        ValueError: if ``text_ds`` is empty (torch.cat would otherwise fail cryptically).
    """
    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    text_ids = [row['_id'] for row in text_ds]
    if not text_ids:
        raise ValueError("encode_texts received an empty dataset")
    dataloader = DataLoader(text_ds, batch_size=batch_size, shuffle=False)
    allemb = []

    model.to(device)
    model.eval()

    with torch.no_grad():
        for batch in tqdm(dataloader):
            texts = list(batch['text'])
            # Calling the tokenizer directly is the supported API (batch_encode_plus is legacy).
            tokenized_texts = tokenizer(
                texts,
                max_length=tokenizer.model_max_length,
                padding=True,
                truncation=True,
                return_tensors="pt",
                add_special_tokens=True,
            ).to(device)
            if verbose:
                print("Batch size:", len(batch['text']))
                print(f"Tokenized texts shape: {tokenized_texts['input_ids'].shape}")
                print(f"Decode text sample {tokenizer.decode(tokenized_texts['input_ids'][0])}")
            allemb.append(model.get_sparse_rep(tokenized_texts))
    allemb = torch.cat(allemb, dim=0)
    return allemb, text_ids

In [None]:
# Corpus and dev-query datasets for encode_texts (fields: _id, text)

corpus_ds = Dataset.from_dict({
    '_id': inference_passage_ids,
    'text': [passages[pid] for pid in inference_passage_ids],
})
queries_ds = Dataset.from_dict({
    '_id': list(queries_dev),
    'text': list(queries_dev.values()),
})

### Top-k Sparsification

In [None]:
passage_topk = 1024  # max number of weights kept per passage (top-k by magnitude)
query_topk = 16      # max number of weights kept per query
min_weight = 1e-5    # query weights below this magnitude are zeroed; index entries must exceed it

In [None]:
# Passage sparsification: encode the corpus, then keep only the passage_topk
# largest-magnitude weights per passage (all other entries become 0).

p_sparse_embeddings, passage_ids = encode_texts(model, tokenizer, corpus_ds, batch_size=1024)

# torch.topk raises if k exceeds the vector width, so cap it at the vocabulary size.
k_passage = min(passage_topk, p_sparse_embeddings.shape[1])
_, p_topk_indices = torch.topk(p_sparse_embeddings.abs(), k=k_passage, dim=1)
# Cell-specific names so this cell does not share mutable globals with the query cell.
p_topk_embeddings = torch.zeros_like(p_sparse_embeddings)
p_topk_embeddings.scatter_(1, p_topk_indices, p_sparse_embeddings.gather(1, p_topk_indices))

passage_sparse_embeddings = p_topk_embeddings.detach().cpu().numpy()

100%|██████████| 8/8 [01:29<00:00, 11.22s/it]


In [None]:
# Query sparsification: encode the dev queries, keep the query_topk largest-magnitude
# weights per query, then zero any surviving weight below min_weight.

q_sparse_embeddings, query_ids = encode_texts(model, tokenizer, queries_ds, batch_size=128)

# torch.topk raises if k exceeds the vector width, so cap it at the vocabulary size.
k_query = min(query_topk, q_sparse_embeddings.shape[1])
_, q_topk_indices = torch.topk(q_sparse_embeddings.abs(), k=k_query, dim=1)
q_topk_embeddings = torch.zeros_like(q_sparse_embeddings)
q_topk_embeddings.scatter_(1, q_topk_indices, q_sparse_embeddings.gather(1, q_topk_indices))
q_topk_embeddings[q_topk_embeddings.abs() < min_weight] = 0
query_sparse_embeddings = q_topk_embeddings.detach().cpu().numpy()

100%|██████████| 55/55 [00:08<00:00,  6.44it/s]


### Inverted Index

In [None]:
# Inverted index: a passages x vocabulary CSR matrix holding only weights strictly
# greater than min_weight (zeros, negatives and tiny weights are dropped).
# np.nonzero replaces the per-entry Python loop. It returns coordinates in row-major
# order, the same order the loop visited them, so the matrix is identical.
rows, cols = np.nonzero(passage_sparse_embeddings > min_weight)
data = passage_sparse_embeddings[rows, cols].astype(np.float32)

optimized_index = csr_matrix((data, (rows, cols)), shape = (len(passage_ids), vocab_size), dtype = np.float32)

### Search

In [None]:
# Score the dev queries against the index in chunks and keep, per query, the top
# retrieval_depth passages sorted by descending score.
retrieval_depth = 1000
batch_size = 512
all_results = {}
n_passages = optimized_index.shape[0]
# argpartition needs kth < n_passages, so never request more passages than exist.
depth = min(retrieval_depth, n_passages)
for start in tqdm(range(0, len(queries_ds), batch_size), desc = "Evaluating queries (csr)..."):
    end = min(start + batch_size, len(queries_ds))
    query_chunk = query_sparse_embeddings[start:end]
    scores_chunk = np.asarray(query_chunk @ optimized_index.T)  # shape: [chunk, num_passages]

    # Unordered top-`depth` candidates per row; only those are then sorted.
    topk_idx = np.argpartition(-scores_chunk, depth - 1, axis = 1)[:, :depth]
    for i in range(scores_chunk.shape[0]):
        passage_scores = scores_chunk[i]
        top_i = topk_idx[i]
        top_indices = top_i[np.argsort(-passage_scores[top_i])]
        all_results[query_ids[start + i]] = [(passage_ids[t], passage_scores[t]) for t in top_indices]

Evaluating queries (csr)...: 100%|██████████| 14/14 [00:04<00:00,  2.81it/s]


### Metrics

In [None]:
# ranx Run: query id -> {passage id: score} for every retrieved passage.
run_dict = {
    qid: {str(passage_id): float(score) for passage_id, score in ranked}
    for qid, ranked in all_results.items()
}
run = Run(run_dict)

# Binary relevance (1) for every judged passage of each retrieved query.
qrels_dict = {
    qid: {str(passage_id): 1 for passage_id in qrels_dev[qid]}
    for qid in all_results
}
qrels = Qrels(qrels_dict)

# Evaluate using ranx
metrics = ["ndcg@10", "ndcg@100", "ndcg@1000", "recall@10", "recall@100", "recall@1000", "mrr@10"]
results = evaluate(qrels, run, metrics)

# (MRR@10, {cutoff: nDCG}, {cutoff: recall})
m = (
    results["mrr@10"],
    {cutoff: results[f"ndcg@{cutoff}"] for cutoff in ("10", "100", "1000")},
    {cutoff: results[f"recall@{cutoff}"] for cutoff in ("10", "100", "1000")},
)

In [None]:
m

(np.float64(0.3973082958111612),
 {'10': np.float64(0.42941978161620187),
  '100': np.float64(0.47334865342734245),
  '1000': np.float64(0.49263936881957143)},
 {'10': np.float64(0.5501552053486151),
  '100': np.float64(0.7603271251193887),
  '1000': np.float64(0.9148519579751672)})

In [None]:
m

(np.float64(0.29163420430254244),
 {'10': np.float64(0.3340381478056171),
  '100': np.float64(0.39728924184075215),
  '1000': np.float64(0.41999863894721956)},
 {'10': np.float64(0.48475405921681),
  '100': np.float64(0.7870582617000954),
  '1000': np.float64(0.9635148042024833)})

In [None]:
m

(np.float64(0.225375676536135),
 {'10': np.float64(0.26898959359941194),
  '100': np.float64(0.34772644249372625),
  '1000': np.float64(0.37047270527950577)},
 {'10': np.float64(0.4211437440305635),
  '100': np.float64(0.7980659025787966),
  '1000': np.float64(0.9747731614135625)})

In [None]:
# 4 epochs
# all losses, high lambdas
m

(np.float64(0.21746043116386957),
 {'10': np.float64(0.2660949538941461),
  '100': np.float64(0.3493895628275042),
  '1000': np.float64(0.3683083365871593)},
 {'10': np.float64(0.4358046800382044),
  '100': np.float64(0.8304560649474689),
  '1000': np.float64(0.9763849092645653)})

In [None]:
# 3 epochs 5L
# all losses, high lambda
m

(np.float64(0.6774686746713968),
 {'10': np.float64(0.7210872497108043),
  '100': np.float64(0.7418761424412927),
  '1000': np.float64(0.7458216151356812)},
 {'10': np.float64(0.8706542502387774),
  '100': np.float64(0.9649116523400192),
  '1000': np.float64(0.9950453677172876)})

In [None]:
# 5 epochs
# all losses, high lambda
m

(np.float64(0.7131263360167371),
 {'10': np.float64(0.7521769361768146),
  '100': np.float64(0.7711645309332931),
  '1000': np.float64(0.774010873260608)},
 {'10': np.float64(0.88810888252149),
  '100': np.float64(0.9743791786055397),
  '1000': np.float64(0.9963705826170008)})

In [None]:
# 5L evaluation - VQ
m

(np.float64(0.2301698162550598),
 {'10': np.float64(0.27255224748544954),
  '100': np.float64(0.3363970047951509),
  '1000': np.float64(0.3606111138861944)},
 {'10': np.float64(0.41793218720152814),
  '100': np.float64(0.7232449856733524),
  '1000': np.float64(0.9129297994269341)})

In [None]:
# 5L evaluation - without VQ
m

(np.float64(0.46837011643243737),
 {'10': np.float64(0.5200407078961042),
  '100': np.float64(0.5662247216678938),
  '1000': np.float64(0.5754124549094598)},
 {'10': np.float64(0.6971227316141355),
  '100': np.float64(0.9091690544412607),
  '1000': np.float64(0.9800740210124164)})

In [None]:
# VQ with KMeans initialization
m

(np.float64(0.720491767862828),
 {'10': np.float64(0.7599282328907433),
  '100': np.float64(0.7776452977217517),
  '1000': np.float64(0.7803430710930395)},
 {'10': np.float64(0.8947946513849093),
  '100': np.float64(0.9746895893027699),
  '1000': np.float64(0.9954990448901623)})

In [None]:
# 5L evaluation - VQ with KMeans initialization
m

(np.float64(0.2298591781507254),
 {'10': np.float64(0.2724794560437748),
  '100': np.float64(0.3377962385283097),
  '1000': np.float64(0.3615194341372948)},
 {'10': np.float64(0.4181470869149952),
  '100': np.float64(0.7311962750716332),
  '1000': np.float64(0.9163681948424068)})

### Check sparsity

In [None]:
# check query sparsity: average number of strictly positive weights per query
avg_sparsity = int(np.count_nonzero(query_sparse_embeddings > 0))
print('avg sparsity', avg_sparsity/len(query_sparse_embeddings))

avg sparsity 16.0


In [None]:
# check passage sparsity: average number of stored index entries per passage
# (the total stored-entry count equals the sum of per-row counts)
avg_sparsity = optimized_index.nnz
print('average sparsity',avg_sparsity/optimized_index.shape[0])

average sparsity 64.0
