In [1]:
!pip install transformers
!pip install datasets
!pip install --upgrade accelerate
!pip install rank-eval
!pip install ranx

Collecting datasets
  Downloading datasets-3.5.1-py3-none-any.whl.metadata (19 kB)
Collecting dill<0.3.9,>=0.3.0 (from datasets)
  Downloading dill-0.3.8-py3-none-any.whl.metadata (10 kB)
Collecting xxhash (from datasets)
  Downloading xxhash-3.5.0-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (12 kB)
Collecting multiprocess<0.70.17 (from datasets)
  Downloading multiprocess-0.70.16-py311-none-any.whl.metadata (7.2 kB)
Collecting fsspec<=2025.3.0,>=2023.1.0 (from fsspec[http]<=2025.3.0,>=2023.1.0->datasets)
  Downloading fsspec-2025.3.0-py3-none-any.whl.metadata (11 kB)
Downloading datasets-3.5.1-py3-none-any.whl (491 kB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m491.4/491.4 kB[0m [31m8.7 MB/s[0m eta [36m0:00:00[0m
[?25hDownloading dill-0.3.8-py3-none-any.whl (116 kB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m116.3/116.3 kB[0m [31m9.6 MB/s[0m eta [36m0:00:00[0m
[?25hDownloading fsspec-2025.3.0-py3-none-any.whl (1

In [2]:
from google.colab import drive
drive.mount('/content/drive')

Mounted at /content/drive


In [3]:
from transformers import AutoModel, AutoTokenizer, Trainer, TrainingArguments, DataCollatorWithPadding
from datasets import Dataset
import torch
import torch.nn as nn
import torch.nn.functional as F
from tqdm.autonotebook import trange
import numpy as np
import json
import random
from torch.utils.data import DataLoader, RandomSampler
from tqdm import tqdm
from ranx import Qrels, Run, evaluate

In [4]:
import os
import gc
from collections import defaultdict
import heapq
from scipy.sparse import csr_matrix
import pandas as pd

In [5]:
from utils import tsv_to_dict_multiple, tsv_to_dict_unqiue, split_embedding_into_chunks

In [23]:
# model = None
# trainer = None
gc.collect()
torch.cuda.empty_cache()

In [24]:
!nvidia-smi

Wed May  7 03:30:28 2025       
+-----------------------------------------------------------------------------------------+
| NVIDIA-SMI 550.54.15              Driver Version: 550.54.15      CUDA Version: 12.4     |
|-----------------------------------------+------------------------+----------------------+
| GPU  Name                 Persistence-M | Bus-Id          Disp.A | Volatile Uncorr. ECC |
| Fan  Temp   Perf          Pwr:Usage/Cap |           Memory-Usage | GPU-Util  Compute M. |
|                                         |                        |               MIG M. |
|   0  NVIDIA A100-SXM4-40GB          Off |   00000000:00:04.0 Off |                    0 |
| N/A   31C    P0             48W /  400W |   30367MiB /  40960MiB |      0%      Default |
|                                         |                        |             Disabled |
+-----------------------------------------+------------------------+----------------------+
                                                

### Load raw training data

In [6]:
# base-model
# model_name = 'facebook/contriever'
# fine-tuned MS-MARCO
model_name = "facebook/contriever-msmarco"
tokenizer = AutoTokenizer.from_pretrained(model_name)

The secret `HF_TOKEN` does not exist in your Colab secrets.
To authenticate with the Hugging Face Hub, create a token in your settings tab (https://huggingface.co/settings/tokens), set it as secret in your Google Colab and restart your session.
You will be able to reuse this secret in all of your notebooks.
Please note that authentication is recommended but still optional to access public models or datasets.


tokenizer_config.json:   0%|          | 0.00/321 [00:00<?, ?B/s]

vocab.txt:   0%|          | 0.00/232k [00:00<?, ?B/s]

tokenizer.json:   0%|          | 0.00/466k [00:00<?, ?B/s]

special_tokens_map.json:   0%|          | 0.00/112 [00:00<?, ?B/s]

In [7]:
passages = tsv_to_dict_unqiue(os.path.join('/content/drive/MyDrive/685/data', "collection.tsv"))
queries = tsv_to_dict_unqiue(os.path.join('/content/drive/MyDrive/685/data', "queries.train.tsv"))

In [None]:
def create_dataset(address, n_negatives=2,cache_dir='hf',sample_count=1000):
    def train_gen():
      with open(address,'r') as f:
          count = 0
          for line in f:
              count += 1
              if sample_count != None and count >= sample_count: return
              data_sample = json.loads(line)
              negatives = []
              for neg_id in random.sample(data_sample['neg']['bm25'],n_negatives): negatives.append(passages[str(neg_id)])
              yield {'query': queries[str(data_sample['qid'])],'positive':passages[str(random.sample(data_sample['pos'],1)[0])],'negatives':[negatives]}
    return Dataset.from_generator(train_gen, cache_dir = cache_dir)

In [None]:
dataset = create_dataset(os.path.join('/content/drive/MyDrive/685/data','msmarco-hard-negatives-bm25_1k.jsonl'), 2, 'hf',1000000)

Generating train split: 0 examples [00:00, ? examples/s]

In [None]:
def create_dataset(df, n_negatives=1,cache_dir='hf',sample_count=1000):
    def train_gen():
      samp_df = df.sample(n=sample_count, random_state=42)
      for index, row in samp_df.iterrows():
            yield {'query': queries[str(int(row[2]))],'positive':passages[str(int(row[3]))],'negatives':passages[str(int(row[4]))]}
    return Dataset.from_generator(train_gen, cache_dir = cache_dir)

In [None]:
df = pd.read_csv(os.path.join('/content/drive/MyDrive/685/data','bert_cat_ensemble_msmarcopassage_train_scores_ids.tsv'),sep='\t',header=None, index_col=None)

In [None]:
dataset = create_dataset(df, 1, 'hf',50)

Generating train split: 0 examples [00:00, ? examples/s]

In [None]:
split_dataset = dataset.train_test_split(test_size=0.1)

# There might be a split over
train_dataset = split_dataset['train']
eval_dataset = split_dataset['test']

### Custom data collator

In [None]:
from typing import Any, Dict, List, Union

class CustomDataCollatorWithPadding(DataCollatorWithPadding):
    def __call__(self, features: List[Dict[str, Any]]) -> Dict[str, torch.Tensor]:
        query_texts = []
        pos_texts = []
        neg_texts = []

        for feature in features:
            query_texts.append(feature['query'])
            pos_texts.append(feature['positive'])
            if 'negatives' in feature.keys():
                for neg_text in feature['negatives']: neg_texts.append(neg_text)

        tokenized_query_texts = self.tokenizer(
                query_texts,
                max_length=self.max_length,
                padding=self.padding,
                truncation=True,
                return_tensors="pt",
                add_special_tokens=True)

        tokenized_pos_texts = self.tokenizer(
                pos_texts,
                max_length=self.max_length,
                padding=self.padding,
                truncation=True,
                return_tensors="pt",
                add_special_tokens=True)

        tokenized_texts = {
            'tokenized_queries' : tokenized_query_texts,
            'tokenized_positives' : tokenized_pos_texts,
        }


        if len(neg_texts) > 0:
            tokenized_neg_texts = self.tokenizer(
                  neg_texts,
                  max_length=self.max_length,
                  padding=self.padding,
                  truncation=True,
                  return_tensors="pt",
                  add_special_tokens=True)
            tokenized_texts['tokenized_negatives'] = tokenized_neg_texts

        return tokenized_texts

custom_data_collator = CustomDataCollatorWithPadding(
    tokenizer=tokenizer,
    padding='longest',
    max_length=tokenizer.model_max_length
)

In [None]:
out_train = custom_data_collator([train_dataset[i] for i in range(8)])
print(out_train['tokenized_queries']['input_ids'].shape)
print(out_train['tokenized_positives']['input_ids'].shape)
print(out_train['tokenized_negatives']['input_ids'].shape)

torch.Size([8, 36])
torch.Size([8, 139])
torch.Size([8, 191])


In [None]:
out_eval = custom_data_collator([eval_dataset[i] for i in range(2)])
print(out_eval['tokenized_queries']['input_ids'].shape)
print(out_eval['tokenized_positives']['input_ids'].shape)

torch.Size([2, 14])
torch.Size([2, 120])


### Contrastive loss

In [8]:
#enter your code here
def get_contriever_emb(model_output, attention_mask):
    last_hidden = model_output["last_hidden_state"]
    last_hidden = last_hidden.masked_fill(~attention_mask[..., None].bool(), 0.0)
    emb = last_hidden.sum(dim=1) / attention_mask.sum(dim=1)[..., None]
    return emb

def contrastive_loss(query_embs,pos_embs,neg_embs):
    temp = 0.05
    query_embs = query_embs
    similarities_pos = torch.mm(query_embs,torch.transpose(pos_embs,0,1))/temp
    similarities = similarities_pos
    if neg_embs != None:
        similarities_neg = torch.mm(query_embs,torch.transpose(neg_embs,0,1))/temp
        similarities = torch.cat((similarities_pos,similarities_neg),dim=1)
    return F.cross_entropy(similarities,torch.arange(0,query_embs.shape[0]).to('cuda'))




### Models

In [9]:
class Quantize(nn.Module):
    def __init__(self, dim, n_embed, decay=0.99, eps=1e-5, random_initialization=False):
        super().__init__()
        self.dim = dim
        self.n_embed = n_embed
        self.decay = decay
        self.eps = eps


        if random_initialization:
          embed = torch.randn(dim, n_embed)
        else:
          embed = torch.load(os.path.join('/content/drive/MyDrive/685/code','codebook_vectors.pt')).T

        self.register_buffer("embed", embed)
        self.register_buffer("cluster_size", torch.zeros(n_embed))
        self.register_buffer("embed_avg", embed.clone())

    def forward(self, input,training=True):
        flatten = input.reshape(-1, self.dim)
        dist = (
            flatten.pow(2).sum(1, keepdim=True)
            - 2 * flatten @ self.embed
            + self.embed.pow(2).sum(0, keepdim=True)
        )
        _, embed_ind = (-dist).max(1)
        embed_onehot = F.one_hot(embed_ind, self.n_embed).type(flatten.dtype)
        embed_ind = embed_ind.view(*input.shape[:-1])
        quantize = self.embed_code(embed_ind)

        if training:
            embed_onehot_sum = embed_onehot.sum(0)
            embed_sum = flatten.transpose(0, 1) @ embed_onehot

            # If not using distributed training, comment these out
            # dist_fn.all_reduce(embed_onehot_sum)
            # dist_fn.all_reduce(embed_sum)

            self.cluster_size.data.mul_(self.decay).add_(
                embed_onehot_sum, alpha=1 - self.decay
            )
            self.embed_avg.data.mul_(self.decay).add_(embed_sum, alpha=1 - self.decay)

            n = self.cluster_size.sum()
            cluster_size = (self.cluster_size + self.eps) / (n + self.n_embed * self.eps) * n
            embed_normalized = self.embed_avg / cluster_size.unsqueeze(0)
            self.embed.data.copy_(embed_normalized)

            # --- Dead code reinitialization ---
            with torch.no_grad():
                dead_mask = self.cluster_size < 0.1
                num_dead = dead_mask.sum().item()
                if num_dead > 0 and flatten.size(0) > 0:
                    random_indices = torch.randint(0, flatten.size(0), (num_dead,))
                    self.embed[:, dead_mask] = flatten[random_indices].t()
                    self.embed_avg[:, dead_mask] = flatten[random_indices].t()
                    self.cluster_size[dead_mask] = 1.0  # or a small positive value

        diff = (quantize.detach() - input).pow(2).mean()
        quantize = input + (quantize - input).detach()

        return quantize, diff, embed_ind

    def embed_code(self, embed_id):
        return F.embedding(embed_id, self.embed.transpose(0, 1))

In [10]:
class Sparse_Layer(nn.Module):
    def __init__(self,d,D):
        super().__init__()

        self.up_projection = nn.Linear(d,D)
        nn.init.xavier_uniform_(self.up_projection.weight)
        nn.init.zeros_(self.up_projection.bias)

    def forward(self,embeds):
        S = self.up_projection(embeds)
        S = torch.log(1 + torch.relu(S))
        S = torch.clamp(S, max=10)
        return S

In [11]:
class LSR_VQ(nn.Module):
    def __init__(self, model_name, emb_dim, num_clusters, num_chunks, d, D, config):
        super().__init__()
        self.model = AutoModel.from_pretrained(model_name,device_map='auto')
        if config['freeze_contriver']:
            for param in self.model.parameters():
                param.requires_grad = False
        self.quantizer = Quantize(emb_dim,num_clusters)
        self.sparse_layer = Sparse_Layer(d,D)
        self.num_chunks = num_chunks
        self.config = config

    def forward(self, inputs):
        query_embs = self.model(**inputs['tokenized_queries'])
        query_dense = get_contriever_emb(query_embs,inputs['tokenized_queries']['attention_mask'])

        pos_embs = self.model(**inputs['tokenized_positives'])
        pos_dense = get_contriever_emb(pos_embs,inputs['tokenized_positives']['attention_mask'])

        neg_embs = self.model(**inputs['tokenized_negatives'])
        neg_dense = get_contriever_emb(neg_embs,inputs['tokenized_negatives']['attention_mask'])

        codebook_loss = 0

        if self.config['use_quantization'] :

            embeds_to_quantize = torch.cat([query_dense,pos_dense,neg_dense],dim=0)

            embeds = split_embedding_into_chunks(embeds_to_quantize,self.num_chunks)

            quantized_embs_chunks, codebook_loss, _ = self.quantizer(embeds,not self.config['freeze_vq'])

            bsz_k, dim_chunk = quantized_embs_chunks.shape
            bsz = bsz_k // self.num_chunks
            quantized_embs = quantized_embs_chunks.view(bsz, dim_chunk * self.num_chunks)

            query_offset = 0
            pos_offset = query_dense.shape[0]
            neg_offset = query_dense.shape[0] + pos_dense.shape[0]

            query_quantized_embs = quantized_embs[:pos_offset]
            pos_quantized_embs = quantized_embs[pos_offset:neg_offset]
            neg_quantized_embs = quantized_embs[neg_offset:]

        else:
          query_quantized_embs = query_dense
          pos_quantized_embs = pos_dense
          neg_quantized_embs = neg_dense

        query_sparse = self.sparse_layer(query_quantized_embs)
        pos_sparse = self.sparse_layer(pos_quantized_embs)
        neg_sparse = self.sparse_layer(neg_quantized_embs)

        return query_dense,query_sparse,pos_dense,pos_sparse,neg_dense,neg_sparse,codebook_loss

    def get_sparse_rep(self,inputs):
        embs = self.model(**inputs)
        dense_embs = get_contriever_emb(embs,inputs['attention_mask'])

        if self.config['use_quantization']:
            embeds = split_embedding_into_chunks(dense_embs,self.num_chunks)

            quantized_embs_chunks, diff, _ = self.quantizer(embeds,False)

            bsz_k, dim_chunk = quantized_embs_chunks.shape
            bsz = bsz_k // self.num_chunks
            quantized_embs = quantized_embs_chunks.view(bsz, dim_chunk * self.num_chunks)
        else:
            quantized_embs = dense_embs

        sparse_rep = self.sparse_layer(quantized_embs)

        return sparse_rep

In [12]:
num_chunks = 16
vocab_size = 8192
cluster_size = 6000

In [13]:
model_config = {
    'use_quantization': True,
    'freeze_contriver': False,
    'freeze_vq': False
}

In [14]:
model = LSR_VQ(model_name,768//num_chunks,cluster_size,num_chunks,768,vocab_size,model_config)

config.json:   0%|          | 0.00/619 [00:00<?, ?B/s]

pytorch_model.bin:   0%|          | 0.00/438M [00:00<?, ?B/s]

model.safetensors:   0%|          | 0.00/438M [00:00<?, ?B/s]

In [15]:
# load trained model if needed
from safetensors.torch import load_file

state_dict = load_file("/content/drive/MyDrive/685/checkpoints/BM25_hard_4epochs_5L/checkpoint-20000/model.safetensors")
model.load_state_dict(state_dict)

<All keys matched successfully>

### Training

In [None]:
#enter your code here to define trainer
class LSRVQ_Trainer(Trainer):

    def __init__(self, *args, trainer_config,**kwargs):
        super().__init__(*args, **kwargs)
        self.config = trainer_config
        self.query_reg_lambda = self.config['query_lambda']
        self.passage_reg_lambda = self.config['passage_lambda']

    def compute_loss(self, model, inputs, return_outputs=False,num_items_in_batch=None):

        query_dense, query_sparse, pos_dense, pos_sparse, neg_dense, neg_sparse, codebook_loss = self.model(inputs)

        c_loss = contrastive_loss(query_sparse,pos_sparse,neg_sparse)

        loss = c_loss

        # Query sparsity loss
        if self.config['query_sparsity_loss']:
          query_reg_loss = torch.sum(torch.mean(torch.abs(query_sparse), dim=0) ** 2)
          q_loss = self.query_reg_lambda*query_reg_loss
          loss += q_loss

        # Document sparsity loss
        passage_sparse = torch.cat([pos_sparse,neg_sparse], dim=0)
        passage_reg_loss = torch.sum(torch.mean(torch.abs(passage_sparse), dim=0) ** 2)
        # passage_reg_loss = passage_sparse.abs().mean()
        p_loss = self.passage_reg_lambda*passage_reg_loss

        loss += p_loss

        if self.config['codebook_loss']:
          loss += codebook_loss

        if self.config['verbose']:
          print('c_loss',c_loss)
          if self.config['query_sparsity_loss']: print('q_loss',q_loss)
          print('p_loss',p_loss)
          if self.config['codebook_loss']: print('codebook_loss',codebook_loss)

        return (loss,torch.zeros(1)) if return_outputs else loss


In [None]:
training_args = TrainingArguments(
    output_dir="./weights",
    overwrite_output_dir=True,
    num_train_epochs=4,
    per_device_train_batch_size=64,
    per_device_eval_batch_size=64,
    warmup_steps=500,
    weight_decay=0.01,
    logging_dir="./logs",
    logging_steps=1000,
    eval_steps=1000,
    eval_strategy="steps",
    save_total_limit=50,
    save_steps=1000,
    load_best_model_at_end=True,
    greater_is_better=False,
    learning_rate=1e-4,
    lr_scheduler_type = "linear",
    fp16=True,
    report_to='none',
    metric_for_best_model="eval_loss",
    remove_unused_columns=False,
)

In [None]:
trainer_config = {
    'query_lambda': 0.5,
    'passage_lambda': 0.5,
    'query_sparsity_loss': True,
    'codebook_loss': True,
    'verbose': False
}

In [None]:
trainer = LSRVQ_Trainer(
        trainer_config=trainer_config,
        model = model,
        args = training_args,
        data_collator = custom_data_collator,
        train_dataset = train_dataset,
        eval_dataset = eval_dataset,
        )

trainer.can_return_loss=True

def custom_eval_dataloader(eval_set):
    return DataLoader(
    eval_set,
    batch_size=training_args.eval_batch_size,
    sampler=RandomSampler(eval_dataset),
    collate_fn=trainer.data_collator,
)

trainer.get_eval_dataloader = custom_eval_dataloader

In [None]:
trainer.train()

Step,Training Loss,Validation Loss
1000,1.3781,1.854664
2000,0.804,1.274487
3000,0.7543,1.085198
4000,0.7248,0.90699
5000,0.6846,0.837574
6000,0.6706,0.7995
7000,0.6386,0.706165
8000,0.5384,0.837233
9000,0.529,0.657306
10000,0.5138,0.656339


Step,Training Loss,Validation Loss
1000,1.3781,1.854664
2000,0.804,1.274487
3000,0.7543,1.085198
4000,0.7248,0.90699
5000,0.6846,0.837574
6000,0.6706,0.7995
7000,0.6386,0.706165
8000,0.5384,0.837233
9000,0.529,0.657306
10000,0.5138,0.656339


KeyboardInterrupt: 

In [None]:
cp -R './weights/checkpoint-14000' '/content/drive/MyDrive/685/checkpoints/BM25_hard_4epochs_5L/'

### Inference

In [16]:
queries_dev = tsv_to_dict_unqiue(os.path.join('/content/drive/MyDrive/685/data', "queries.dev.small.tsv"))
qrels_dev = tsv_to_dict_multiple(os.path.join('/content/drive/MyDrive/685/data', "qrels.dev.small.tsv"), keys = [0, 2])

with open('/content/drive/MyDrive/685/embeddings/train/5L_wrong_query/passage_ids.json') as f:
    inference_passage_ids = json.load(f)

In [None]:
import ast
q = pd.read_csv(os.path.join('/content/drive/MyDrive/685/data', "small_qrels.csv"))
qrels_dev = {}
for index,row in q.iterrows():
  qrels_dev[str(row['qid'])] = ast.literal_eval(row['pid'])

queries_small = {}
for qid in qrels_dev:
  queries_small[qid] = queries_dev[qid]
queries_dev = queries_small

inference_passage_ids = set()
for i in qrels_dev.values():
  for pid in i: inference_passage_ids.add(pid)
inference_passage_ids = list(inference_passage_ids)

In [17]:
from datasets import load_dataset

def encode_texts(model, tokenizer, text_ds, batch_size, verbose=False):
    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    text_ids = [i['_id'] for i in text_ds]
    dataloader = DataLoader(text_ds, batch_size=batch_size, shuffle=False)
    allemb = []

    model.to(device)
    model.eval()

    with torch.no_grad():
        for batch in tqdm(dataloader):
            prepare_batch = [batch['text'][i] for i in range(len(batch['text']))]
            tokenized_texts = tokenizer.batch_encode_plus(
                prepare_batch,
                max_length=tokenizer.model_max_length,
                padding=True,
                truncation=True,
                return_tensors="pt",
                add_special_tokens=True,
            ).to(device)
            if verbose:
                print("Batch size:", len(batch['text']))
                print(f"Tokenized texts shape: {tokenized_texts['input_ids'].shape}")
                print(f"Decode text sample {tokenizer.decode(tokenized_texts['input_ids'][0])}")
            emb = model.get_sparse_rep(tokenized_texts)
            # allemb.append(emb.cpu())
            allemb.append(emb)
    allemb = torch.cat(allemb, dim=0)
    # allemb = allemb.cpu().numpy()
    return allemb, text_ids

In [18]:
corpus_ds = Dataset.from_dict({'_id':inference_passage_ids,'text':[passages[_id] for _id in inference_passage_ids]})
queries_ds = Dataset.from_dict({'_id':list(queries_dev.keys()),'text':list(queries_dev.values())})

In [19]:
len(corpus_ds)

516472

### Top-k Sparsification

In [25]:
passage_topk = 1024

p_sparse_embeddings, passage_ids = encode_texts(model, tokenizer, corpus_ds, batch_size=1024)


 69%|██████▊   | 346/505 [14:16<06:33,  2.47s/it]


OutOfMemoryError: CUDA out of memory. Tried to allocate 3.83 GiB. GPU 0 has a total capacity of 39.56 GiB of which 1.45 GiB is free. Process 3208 has 38.10 GiB memory in use. Of the allocated memory 23.50 GiB is allocated by PyTorch, and 14.11 GiB is reserved by PyTorch but unallocated. If reserved but unallocated memory is large try setting PYTORCH_CUDA_ALLOC_CONF=expandable_segments:True to avoid fragmentation.  See documentation for Memory Management  (https://pytorch.org/docs/stable/notes/cuda.html#environment-variables)

In [None]:
# Apply top-k sparsification
values, indices = torch.topk(p_sparse_embeddings.abs(), k=passage_topk, dim=1)
sparse_embeddings = torch.zeros_like(p_sparse_embeddings)
sparse_embeddings.scatter_(1, indices, p_sparse_embeddings.gather(1, indices))

passage_sparse_embeddings = sparse_embeddings.detach().cpu().numpy()

### Inverted Index

In [None]:
rows, cols, data = [], [], []

min_weight = 1e-5
for j, doc_id in enumerate(passage_ids):
    doc_embedding = passage_sparse_embeddings[j]
    # Get non-zero indices and their values
    nonzero_indices = np.nonzero(doc_embedding)[0]  # Ensure indices are integers
    for idx in nonzero_indices:
        weight = doc_embedding[idx]
        if weight > min_weight:
          rows.append(j)
          cols.append(idx)
          data.append(float(weight))

optimized_index = csr_matrix((data, (rows, cols)), shape = (len(passage_ids), vocab_size), dtype = np.float32)

### Search

In [None]:
q_sparse_embeddings, query_ids = encode_texts(model, tokenizer, queries_ds, batch_size=128)

100%|██████████| 55/55 [00:02<00:00, 27.19it/s]


In [None]:
query_topk = 256
values, indices = torch.topk(q_sparse_embeddings.abs(), k=query_topk, dim=1)
sparse_embeddings = torch.zeros_like(q_sparse_embeddings)
sparse_embeddings.scatter_(1, indices, q_sparse_embeddings.gather(1, indices))
sparse_embeddings[sparse_embeddings.abs() < min_weight] = 0
query_sparse_embeddings = sparse_embeddings.detach().cpu().numpy()

In [None]:
batch_size = 512
all_results = {}
for start in tqdm(range(0, len(queries_ds), batch_size), desc = "Evaluating queries (csr)..."):
    end = min(start + batch_size, len(queries_ds))
    query_chunk = query_sparse_embeddings[start:end]
    scores_chunk = query_chunk @ optimized_index.T  # shape: [batch_size, num_passages]

    topk_idx = np.argpartition(-scores_chunk, 1000, axis = 1)[:, :1000]
    for i in range(scores_chunk.shape[0]):
        passage_scores = scores_chunk[i]
        top_i = topk_idx[i]
        top_indices = top_i[np.argsort(-passage_scores[top_i])]

        # Ensure self.passage_ids is referring to the correct mode (train/dev)
        all_results[query_ids[start + i]] = [(passage_ids[t], passage_scores[t]) for t in top_indices]

Evaluating queries (csr)...: 100%|██████████| 14/14 [00:05<00:00,  2.64it/s]


In [None]:
len(all_results)

6980

### Metrics

In [None]:
run_dict = {}
for qid, results in all_results.items():
    run_dict[qid] = {
        str(passage_id): float(score)
        for passage_id, score in results
    }
run = Run(run_dict)

qrels_dict = {
    qid: {str(passage_id): 1 for passage_id in qrels_dev[qid]}
    for qid in all_results.keys()
}
qrels = Qrels(qrels_dict)

# Evaluate using rankx
metrics = ["ndcg@10", "ndcg@100", "ndcg@1000", "recall@10", "recall@100", "recall@1000", "mrr@10"]
results = evaluate(qrels, run, metrics)

m = (results["mrr@10"],
    {
        '10': results["ndcg@10"],
        '100': results["ndcg@100"],
        '1000': results["ndcg@1000"]
    },
    {
        '10': results["recall@10"],
        '100': results["recall@100"],
        '1000': results["recall@1000"]
    }
)

  scores[i] = _ndcg(qrels[i], run[i], k, rel_lvl, jarvelin)


In [None]:
m

(np.float64(0.45659094010096873),
 {'10': np.float64(0.4909917763933433),
  '100': np.float64(0.5370616818207461),
  '1000': np.float64(0.5541936891033905)},
 {'10': np.float64(0.6196155682903535),
  '100': np.float64(0.8396012416427888),
  '1000': np.float64(0.9742120343839542)})

In [None]:
m

(np.float64(0.29163420430254244),
 {'10': np.float64(0.3340381478056171),
  '100': np.float64(0.39728924184075215),
  '1000': np.float64(0.41999863894721956)},
 {'10': np.float64(0.48475405921681),
  '100': np.float64(0.7870582617000954),
  '1000': np.float64(0.9635148042024833)})

In [None]:
m

(np.float64(0.225375676536135),
 {'10': np.float64(0.26898959359941194),
  '100': np.float64(0.34772644249372625),
  '1000': np.float64(0.37047270527950577)},
 {'10': np.float64(0.4211437440305635),
  '100': np.float64(0.7980659025787966),
  '1000': np.float64(0.9747731614135625)})

In [None]:
# 4 epochs
# all losses, high lambdas
m

(np.float64(0.21746043116386957),
 {'10': np.float64(0.2660949538941461),
  '100': np.float64(0.3493895628275042),
  '1000': np.float64(0.3683083365871593)},
 {'10': np.float64(0.4358046800382044),
  '100': np.float64(0.8304560649474689),
  '1000': np.float64(0.9763849092645653)})

In [None]:
# 3 epochs 5L
# all losses, high lambda
m

(np.float64(0.6774686746713968),
 {'10': np.float64(0.7210872497108043),
  '100': np.float64(0.7418761424412927),
  '1000': np.float64(0.7458216151356812)},
 {'10': np.float64(0.8706542502387774),
  '100': np.float64(0.9649116523400192),
  '1000': np.float64(0.9950453677172876)})

In [None]:
#check query sparsity
avg_sparsity = 0
for i in range(len(query_sparse_embeddings)):
  avg_sparsity += np.where(query_sparse_embeddings[i] > 0)[0].shape[0]
  print(np.where(query_sparse_embeddings[i] > 0)[0].shape)
print('avg sparsity', avg_sparsity/len(query_sparse_embeddings))

(45,)
(35,)
(29,)
(45,)
(26,)
(66,)
(47,)
(36,)
(55,)
(35,)
(46,)
(29,)
(57,)
(56,)
(57,)
(44,)
(57,)
(22,)
(36,)
(54,)
(30,)
(31,)
(22,)
(55,)
(38,)
(32,)
(47,)
(41,)
(34,)
(40,)
(43,)
(34,)
(43,)
(60,)
(33,)
(37,)
(37,)
(44,)
(49,)
(19,)
(44,)
(36,)
(30,)
(29,)
(40,)
(40,)
(39,)
(46,)
(33,)
(38,)
(54,)
(40,)
(40,)
(43,)
(51,)
(38,)
(48,)
(63,)
(51,)
(46,)
(40,)
(41,)
(38,)
(32,)
(36,)
(16,)
(50,)
(38,)
(33,)
(40,)
(41,)
(45,)
(44,)
(42,)
(32,)
(39,)
(61,)
(54,)
(41,)
(32,)
(33,)
(27,)
(39,)
(42,)
(39,)
(47,)
(64,)
(47,)
(43,)
(67,)
(81,)
(38,)
(84,)
(33,)
(17,)
(61,)
(52,)
(57,)
(41,)
(30,)
(52,)
(59,)
(40,)
(49,)
(32,)
(50,)
(31,)
(33,)
(41,)
(53,)
(56,)
(57,)
(37,)
(37,)
(24,)
(101,)
(16,)
(50,)
(32,)
(57,)
(60,)
(53,)
(49,)
(23,)
(52,)
(33,)
(25,)
(41,)
(40,)
(35,)
(30,)
(44,)
(42,)
(52,)
(35,)
(36,)
(52,)
(32,)
(54,)
(41,)
(32,)
(49,)
(34,)
(76,)
(36,)
(47,)
(29,)
(43,)
(47,)
(39,)
(40,)
(47,)
(48,)
(41,)
(83,)
(44,)
(52,)
(32,)
(47,)
(42,)
(40,)
(40,)
(40,)
(51,)
(43,)
(40,)
(38

In [None]:
# check passage sparsity
avg_sparsity = 0
for row_idx in range(optimized_index.shape[0]):
    row_data = optimized_index.getrow(row_idx)
    non_zero_count = row_data.nnz
    avg_sparsity += non_zero_count
    print(non_zero_count)
print('average sparsity',avg_sparsity/optimized_index.shape[0])

[1;30;43mStreaming output truncated to the last 5000 lines.[0m
52
107
100
106
83
82
96
63
96
77
123
109
70
101
63
49
97
85
107
49
86
96
123
62
112
52
95
69
120
63
109
80
110
116
68
64
59
46
60
72
100
98
62
64
124
93
118
89
95
73
82
50
76
96
81
77
110
90
85
118
103
104
117
73
59
125
118
60
89
107
90
95
101
57
75
79
91
119
48
103
95
65
75
102
106
69
96
100
48
119
102
88
66
75
100
59
71
43
75
91
100
60
99
123
73
73
106
137
124
58
104
109
80
90
97
109
66
79
102
81
96
128
80
87
82
60
76
49
105
75
64
59
91
55
93
75
79
75
86
84
107
133
114
55
115
109
80
87
68
99
94
85
67
109
81
67
66
98
88
75
61
84
48
96
56
80
59
98
116
71
78
175
88
65
50
35
55
54
55
141
75
110
79
81
68
58
56
78
103
67
69
82
76
102
100
52
78
69
55
55
55
72
69
96
67
57
77
45
53
62
70
76
78
60
58
120
89
79
68
87
102
98
44
121
61
82
105
77
86
71
100
96
98
104
113
100
89
79
115
54
83
76
107
117
88
61
104
102
105
108
94
84
87
64
97
72
106
91
68
101
62
97
138
83
88
82
53
43
114
81
77
140
65
66
81
112
106
86
71
82
85
57
102
73
93
