In [1]:
!pip install transformers
!pip install datasets
!pip install --upgrade accelerate
!pip install rank-eval
!pip install ranx

Collecting datasets
  Downloading datasets-3.5.1-py3-none-any.whl.metadata (19 kB)
Collecting dill<0.3.9,>=0.3.0 (from datasets)
  Downloading dill-0.3.8-py3-none-any.whl.metadata (10 kB)
Collecting xxhash (from datasets)
  Downloading xxhash-3.5.0-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (12 kB)
Collecting multiprocess<0.70.17 (from datasets)
  Downloading multiprocess-0.70.16-py311-none-any.whl.metadata (7.2 kB)
Collecting fsspec<=2025.3.0,>=2023.1.0 (from fsspec[http]<=2025.3.0,>=2023.1.0->datasets)
  Downloading fsspec-2025.3.0-py3-none-any.whl.metadata (11 kB)
Downloading datasets-3.5.1-py3-none-any.whl (491 kB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m491.4/491.4 kB[0m [31m10.4 MB/s[0m eta [36m0:00:00[0m
[?25hDownloading dill-0.3.8-py3-none-any.whl (116 kB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m116.3/116.3 kB[0m [31m11.0 MB/s[0m eta [36m0:00:00[0m
[?25hDownloading fsspec-2025.3.0-py3-none-any.whl 

In [2]:
from google.colab import drive
drive.mount('/content/drive')

Mounted at /content/drive


In [3]:
from transformers import AutoModel, AutoTokenizer, Trainer, TrainingArguments, DataCollatorWithPadding
from datasets import Dataset
import torch
import torch.nn as nn
import torch.nn.functional as F
from tqdm.autonotebook import trange
import numpy as np
import json
import random
from torch.utils.data import DataLoader, RandomSampler
from tqdm import tqdm
from ranx import Qrels, Run, evaluate

In [4]:
import os
import gc

In [5]:
import sys
sys.path.append('/content/drive/MyDrive/685/code')
from utils import tsv_to_dict_multiple, tsv_to_dict_unqiue

In [6]:
model = None
gc.collect()
torch.cuda.empty_cache()

In [7]:
!nvidia-smi

Fri May  2 03:49:20 2025       
+-----------------------------------------------------------------------------------------+
| NVIDIA-SMI 550.54.15              Driver Version: 550.54.15      CUDA Version: 12.4     |
|-----------------------------------------+------------------------+----------------------+
| GPU  Name                 Persistence-M | Bus-Id          Disp.A | Volatile Uncorr. ECC |
| Fan  Temp   Perf          Pwr:Usage/Cap |           Memory-Usage | GPU-Util  Compute M. |
|                                         |                        |               MIG M. |
|   0  NVIDIA A100-SXM4-40GB          Off |   00000000:00:04.0 Off |                    0 |
| N/A   30C    P0             51W /  400W |       0MiB /  40960MiB |      0%      Default |
|                                         |                        |             Disabled |
+-----------------------------------------+------------------------+----------------------+
                                                

### Load raw training data

In [8]:
# base-model
# model_name = 'facebook/contriever'
# fine-tuned MS-MARCO
model_name = "facebook/contriever-msmarco"
tokenizer = AutoTokenizer.from_pretrained(model_name)

The secret `HF_TOKEN` does not exist in your Colab secrets.
To authenticate with the Hugging Face Hub, create a token in your settings tab (https://huggingface.co/settings/tokens), set it as secret in your Google Colab and restart your session.
You will be able to reuse this secret in all of your notebooks.
Please note that authentication is recommended but still optional to access public models or datasets.


tokenizer_config.json:   0%|          | 0.00/321 [00:00<?, ?B/s]

vocab.txt:   0%|          | 0.00/232k [00:00<?, ?B/s]

tokenizer.json:   0%|          | 0.00/466k [00:00<?, ?B/s]

special_tokens_map.json:   0%|          | 0.00/112 [00:00<?, ?B/s]

In [9]:
passages = tsv_to_dict_unqiue(os.path.join('/content/drive/MyDrive/685/data', "collection.tsv"))
queries = tsv_to_dict_unqiue(os.path.join('/content/drive/MyDrive/685/data', "queries.train.tsv"))

In [10]:
def create_dataset(address, n_negatives=2,cache_dir='hf',sample_count=1000):
    def train_gen():
      with open(address,'r') as f:
          count = 0
          for line in f:
              count += 1
              if sample_count != None and count >= sample_count: return
              data_sample = json.loads(line)
              negatives = []
              for neg_id in random.sample(data_sample['neg']['bm25'],n_negatives): negatives.append(passages[str(neg_id)])
              yield {'query': queries[str(data_sample['qid'])],'positive':passages[str(random.sample(data_sample['pos'],1)[0])],'negatives':negatives}
    return Dataset.from_generator(train_gen, cache_dir = cache_dir)

In [11]:
dataset = create_dataset(os.path.join('/content/drive/MyDrive/685/data','msmarco-hard-negatives-bm25_1k.jsonl'), 2, 'hf',2000)

Generating train split: 0 examples [00:00, ? examples/s]

In [12]:
split_dataset = dataset.train_test_split(test_size=0.1)

# There might be a split over
train_dataset = split_dataset['train']
eval_dataset = split_dataset['test']

### Custom data collator

In [13]:
from typing import Any, Dict, List, Union

class CustomDataCollatorWithPadding(DataCollatorWithPadding):
    def __call__(self, features: List[Dict[str, Any]]) -> Dict[str, torch.Tensor]:
        query_texts = []
        pos_texts = []
        neg_texts = []

        for feature in features:
            query_texts.append(feature['query'])
            pos_texts.append(feature['positive'])
            if 'negatives' in feature.keys():
                for neg_text in feature['negatives']: neg_texts.append(neg_text)

        tokenized_query_texts = self.tokenizer(
                query_texts,
                max_length=self.max_length,
                padding=self.padding,
                truncation=True,
                return_tensors="pt",
                add_special_tokens=True)

        tokenized_pos_texts = self.tokenizer(
                pos_texts,
                max_length=self.max_length,
                padding=self.padding,
                truncation=True,
                return_tensors="pt",
                add_special_tokens=True)

        tokenized_texts = {
            'tokenized_queries' : tokenized_query_texts,
            'tokenized_positives' : tokenized_pos_texts,
        }


        if len(neg_texts) > 0:
            tokenized_neg_texts = self.tokenizer(
                  neg_texts,
                  max_length=self.max_length,
                  padding=self.padding,
                  truncation=True,
                  return_tensors="pt",
                  add_special_tokens=True)
            tokenized_texts['tokenized_negatives'] = tokenized_neg_texts

        return tokenized_texts

custom_data_collator = CustomDataCollatorWithPadding(
    tokenizer=tokenizer,
    padding='longest',
    max_length=tokenizer.model_max_length
)

In [14]:
out_train = custom_data_collator([train_dataset[i] for i in range(8)])
print(out_train['tokenized_queries']['input_ids'].shape)
print(out_train['tokenized_positives']['input_ids'].shape)
print(out_train['tokenized_negatives']['input_ids'].shape)

torch.Size([8, 9])
torch.Size([8, 160])
torch.Size([16, 162])


In [15]:
out_eval = custom_data_collator([eval_dataset[i] for i in range(2)])
print(out_eval['tokenized_queries']['input_ids'].shape)
print(out_eval['tokenized_positives']['input_ids'].shape)

torch.Size([2, 7])
torch.Size([2, 81])


### Contrastive loss

In [16]:
#enter your code here
def get_contriever_emb(model_output, attention_mask):
    last_hidden = model_output
    last_hidden = last_hidden.masked_fill(~attention_mask[..., None].bool(), 0.0)
    emb = last_hidden.sum(dim=1) / attention_mask.sum(dim=1)[..., None]
    return emb

def contrastive_loss(query_embs,pos_embs,neg_embs):
    temp = 0.05
    query_embs = query_embs/temp
    similarities_pos = torch.mm(query_embs,torch.transpose(pos_embs,0,1))
    similarities = similarities_pos
    if neg_embs != None:
        similarities_neg = torch.mm(query_embs,torch.transpose(neg_embs,0,1))
        similarities = torch.cat((similarities_pos,similarities_neg),dim=1)
    return F.cross_entropy(similarities,torch.arange(0,query_embs.shape[0]).to('cuda'))

# def contrastive_loss(query_embs, pos_embs, neg_embs=None, temperature=0.05):
#     device = query_embs.device
#     #Only the queries are temperature-scaled. Both query and key (positive/negative) should ideally be scaled
#     query_embs = query_embs / temperature
#     pos_embs = pos_embs / temperature

#     sim_pos = torch.mm(query_embs, pos_embs.T)

#   #If your negatives are reused from another forward pass, they might retain gradients accidentally. Best to detach.
#     if neg_embs is not None:
#         neg_embs = neg_embs.detach() / temperature  # Prevent gradient through negatives
#         sim_neg = torch.mm(query_embs, neg_embs.T)
#         logits = torch.cat([sim_pos, sim_neg], dim=1)
#     else:
#         logits = sim_pos

#     targets = torch.arange(query_embs.size(0), device=device)
#     loss = F.cross_entropy(logits, targets)
#     return loss



### Models

In [17]:
class Quantize(nn.Module):
    def __init__(self, codebook_vector_dim, num_clusters, decay = 0.99, eps = 1e-5):
        super().__init__()

        self.dim = codebook_vector_dim
        self.num_clusters = num_clusters
        self.decay = decay
        self.eps = eps

        embed = torch.randn(codebook_vector_dim, num_clusters)
        self.register_buffer("embed", embed)
        self.register_buffer("cluster_size", torch.zeros(num_clusters))
        self.register_buffer("embed_avg", embed.clone())

    def forward(self, input, training = True):
        input = torch.nn.functional.normalize(input, p=2, dim=1)
        flatten = input.reshape(-1, self.dim)
        dist = (
            flatten.pow(2).sum(1, keepdim = True)
            - 2 * flatten @ self.embed
            + self.embed.pow(2).sum(0, keepdim = True)
        )
        _, embed_ind = (-dist).max(1)
        embed_onehot = F.one_hot(embed_ind, self.num_clusters).type(flatten.dtype)
        embed_ind = embed_ind.view(*input.shape[:-1])
        quantize = self.embed_code(embed_ind)

        if training:
            embed_onehot_sum = embed_onehot.sum(0)
            embed_sum = flatten.transpose(0, 1) @ embed_onehot

            self.cluster_size.data.mul_(self.decay).add_(
                embed_onehot_sum, alpha=1 - self.decay
            )
            self.embed_avg.data.mul_(self.decay).add_(embed_sum, alpha=1 - self.decay)
            n = self.cluster_size.sum()
            cluster_size = (
                (self.cluster_size + self.eps) / (n + self.num_clusters * self.eps) * n
            )
            embed_normalized = self.embed_avg / cluster_size.unsqueeze(0)
            self.embed.data.copy_(embed_normalized)

        quantize = input + (quantize - input).detach()
        return quantize, embed_ind

    def embed_code(self, embed_id):
        return F.embedding(embed_id, self.embed.transpose(0, 1))

In [18]:
class Sparse_Layer(nn.Module):
    def __init__(self,d,D):
        super().__init__()

        self.up_projection = nn.Linear(d,D)
        nn.init.xavier_uniform_(self.up_projection.weight)
        nn.init.zeros_(self.up_projection.bias)


    # def forward(self,embeds):
    #     # Normalize input to prevent large values
    #     embeds = torch.nn.functional.normalize(embeds, p=2, dim=1)

    #     print("NaNs before matmul:", torch.isnan(embeds).any())
    #     S = self.up_projection(embeds)
    #     print("NaNs after projection:", torch.isnan(S).any())
    #     # S = torch.log(1 + torch.relu(S))
    #     # S = torch.log1p(F.relu(S))
    #     # print("NaNs after logrelu:", torch.isnan(S).any())
    #     # S = torch.clamp(S, max=10)
    #     S = torch.clamp(F.relu(S), max=100)  # Clamp before log1p to avoid NaNs from large values
    #     print("NaNs after clamp:", torch.isnan(S).any())
    #     S = torch.log1p(S)
    #     print("NaNs after logrelu:", torch.isnan(S).any())
    #     # print("NaNs after clamp:", torch.isnan(S).any())
    #     S = torch.nn.functional.normalize(S, p=2, dim=1,eps=1e-8)
    #     print("NaNs after normalize:", torch.isnan(S).any())
    #     return S

    def forward(self, embeds):
        # Normalize input embeddings to unit norm
        embeds = torch.nn.functional.normalize(embeds, p=2, dim=1)

        # print("NaNs before projection:", torch.isnan(embeds).any())

        # Step 2: Apply linear projection
        S = self.up_projection(embeds)
        # print("NaNs after projection:", torch.isnan(S).any())

        # Step 3: Apply ReLU to keep only positive activations
        S = F.relu(S)

        # Step 4: Clamp large activations before log1p to avoid NaNs from log(Inf)
        S = torch.clamp(S, max=100)
        # print("NaNs after clamp:", torch.isnan(S).any())

        # Step 5: Apply log1p to compress dynamic range
        S = torch.log1p(S)
        # print("NaNs after log1p:", torch.isnan(S).any())

        # Step 6: Final normalization to ensure stable gradient flow
        S = torch.nn.functional.normalize(S, p=2, dim=1, eps=1e-8)
        # print("NaNs after final normalize:", torch.isnan(S).any())

        return S


In [19]:
class LSR_VQ(nn.Module):
    def __init__(self, model_name, emb_dim, num_clusters, d, D, freeze_contriver=False):
        super().__init__()
        self.model = AutoModel.from_pretrained(model_name,device_map='auto')
        if freeze_contriver:
            for param in self.model.parameters():
                param.requires_grad = False
        self.quantizer = Quantize(emb_dim,num_clusters)
        self.sparse_layer = Sparse_Layer(d,D)

    def forward(self, inputs):
        query_embs = self.model(**inputs['tokenized_queries']).last_hidden_state
        query_dense = get_contriever_emb(query_embs,inputs['tokenized_queries']['attention_mask'])

        pos_embs = self.model(**inputs['tokenized_positives']).last_hidden_state
        pos_dense = get_contriever_emb(pos_embs,inputs['tokenized_positives']['attention_mask'])

        neg_embs = self.model(**inputs['tokenized_negatives']).last_hidden_state
        neg_dense = get_contriever_emb(neg_embs,inputs['tokenized_negatives']['attention_mask'])

        quantized_embs, _ = self.quantizer(torch.cat([query_dense,pos_dense,neg_dense],dim=0),True)

        query_offset = 0
        pos_offset = query_dense.shape[0]
        neg_offset = query_dense.shape[0] + pos_dense.shape[0]

        query_quantized_embs = quantized_embs[:pos_offset]
        pos_quantized_embs = quantized_embs[pos_offset:neg_offset]
        neg_quantized_embs = quantized_embs[neg_offset:]

        query_sparse = self.sparse_layer(query_quantized_embs)
        pos_sparse = self.sparse_layer(pos_quantized_embs)
        neg_sparse = self.sparse_layer(neg_quantized_embs)

        return query_dense,query_sparse,pos_dense,pos_sparse,neg_dense,neg_sparse

    def get_sparse_rep(self,input_ids,attention_mask):
        # embs = self.model(input_ids)
        embs = self.model(input_ids=input_ids, attention_mask=attention_mask).last_hidden_state
        dense_embs = get_contriever_emb(embs,attention_mask)

        quantized_embs, codes = self.quantizer(dense_embs,False)

        sparse_rep = self.sparse_layer(quantized_embs)

        return sparse_rep

In [20]:
model = LSR_VQ(model_name,768,2048,768,8192)

config.json:   0%|          | 0.00/619 [00:00<?, ?B/s]

pytorch_model.bin:   0%|          | 0.00/438M [00:00<?, ?B/s]

model.safetensors:   0%|          | 0.00/438M [00:00<?, ?B/s]

### Training

In [21]:
#enter your code here to define trainer
class LSRVQ_Trainer(Trainer):

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.query_reg_lambda = 0.1
        self.passage_reg_lambda = 0.1

    def compute_loss(self, model, inputs, return_outputs=False,num_items_in_batch=None):

        query_dense, query_sparse, pos_dense, pos_sparse, neg_dense, neg_sparse = self.model(inputs)

        # Detach negatives to avoid unintended gradient retention
        if neg_dense is not None:
            neg_dense = neg_dense.detach()

        # Constrastive loss
        c_loss = contrastive_loss(query_dense,pos_dense,neg_dense)

        print('c_loss',c_loss)

        # Query sparsity loss
        # query_reg_loss = torch.sum(torch.mean(torch.abs(query_sparse), dim=0) ** 2)
        query_reg_loss = torch.mean(torch.abs(query_sparse))
        print('q_loss',query_reg_loss)

        # Document sparsity loss
        passage_sparse = torch.cat([pos_sparse,neg_sparse], dim=0)
        # passage_reg_loss = torch.sum(torch.mean(torch.abs(passage_sparse), dim=0) ** 2)
        passage_reg_loss = torch.mean(torch.abs(passage_sparse))

        print('p_loss',passage_reg_loss)

        loss = c_loss + self.query_reg_lambda*query_reg_loss + self.passage_reg_lambda*passage_reg_loss

        return (loss,torch.zeros(1)) if return_outputs else loss



In [22]:
training_args = TrainingArguments(
    output_dir="./weights",
    overwrite_output_dir=True,
    num_train_epochs=4,
    per_device_train_batch_size=32,
    per_device_eval_batch_size=32,
    warmup_steps=0,
    weight_decay=0.01,
    logging_dir="./logs",
    logging_steps=5,
    eval_steps=5,
    eval_strategy="steps",
    save_total_limit=36,
    save_steps=5,
    load_best_model_at_end=True,
    greater_is_better=False,
    learning_rate=1e-4,
    lr_scheduler_type = "linear",
    fp16=True,
    report_to='none',
    metric_for_best_model="eval_loss",
    remove_unused_columns=False,
    # optim = "adamw_8bit",
)

In [23]:
trainer = LSRVQ_Trainer(
        model = model,
        args = training_args,
        data_collator = custom_data_collator,
        train_dataset = train_dataset,
        eval_dataset = eval_dataset)

trainer.can_return_loss=True

def custom_eval_dataloader(eval_set):
    return DataLoader(
    eval_set,
    batch_size=training_args.eval_batch_size,
    sampler=RandomSampler(eval_dataset),
    collate_fn=trainer.data_collator,
)

trainer.get_eval_dataloader = custom_eval_dataloader

In [24]:
trainer.train()

c_loss tensor(0.1271, device='cuda:0', grad_fn=<NllLossBackward0>)
q_loss tensor(0.0062, device='cuda:0', grad_fn=<MeanBackward0>)
p_loss tensor(0.0062, device='cuda:0', grad_fn=<MeanBackward0>)


Step,Training Loss,Validation Loss
5,0.1242,0.258736
10,1.1088,1.933208
15,2.2472,1.900115
20,5.1416,2.169508
25,1.6672,1.274425
30,1.0229,1.374936
35,1.2964,0.852485
40,1.5419,0.782983
45,1.2184,0.84958
50,0.7799,0.484162


c_loss tensor(0.0627, device='cuda:0', grad_fn=<NllLossBackward0>)
q_loss tensor(0.0063, device='cuda:0', grad_fn=<MeanBackward0>)
p_loss tensor(0.0063, device='cuda:0', grad_fn=<MeanBackward0>)
c_loss tensor(0.0403, device='cuda:0', grad_fn=<NllLossBackward0>)
q_loss tensor(0.0063, device='cuda:0', grad_fn=<MeanBackward0>)
p_loss tensor(0.0063, device='cuda:0', grad_fn=<MeanBackward0>)
c_loss tensor(0.1231, device='cuda:0', grad_fn=<NllLossBackward0>)
q_loss tensor(0.0063, device='cuda:0', grad_fn=<MeanBackward0>)
p_loss tensor(0.0063, device='cuda:0', grad_fn=<MeanBackward0>)
c_loss tensor(0.2614, device='cuda:0', grad_fn=<NllLossBackward0>)
q_loss tensor(0.0063, device='cuda:0', grad_fn=<MeanBackward0>)
p_loss tensor(0.0063, device='cuda:0', grad_fn=<MeanBackward0>)
c_loss tensor(0.4155, device='cuda:0')
q_loss tensor(0.0062, device='cuda:0')
p_loss tensor(0.0062, device='cuda:0')
c_loss tensor(0.1962, device='cuda:0')
q_loss tensor(0.0063, device='cuda:0')
p_loss tensor(0.0063, dev

TrainOutput(global_step=228, training_loss=0.587065102695896, metrics={'train_runtime': 200.8527, 'train_samples_per_second': 35.827, 'train_steps_per_second': 1.135, 'total_flos': 0.0, 'train_loss': 0.587065102695896, 'epoch': 4.0})

### Inference

In [25]:
queries_dev = tsv_to_dict_unqiue(os.path.join('/content/drive/MyDrive/685/data', "queries.dev.small.tsv")) # Changed file path to reflect the actual location in Google Drive
qrels_dev = tsv_to_dict_multiple(os.path.join('/content/drive/MyDrive/685/data', "qrels.dev.small.tsv"), keys = [0, 2]) # Changed file path to reflect the actual location in Google Drive

with open('/content/drive/MyDrive/685/embeddings/train/passage_ids.json') as f: # Changed file path to be an absolute path
    inference_passage_ids = json.load(f)

In [26]:
from datasets import load_dataset

def encode_texts(model, tokenizer, text_ds, batch_size, verbose=False):
    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    text_ids = [i['_id'] for i in text_ds]
    dataloader = DataLoader(text_ds, batch_size=batch_size, shuffle=False)
    allemb = []

    model.to(device)
    model.eval()

    with torch.no_grad():
        for batch in tqdm(dataloader):
            prepare_batch = [batch['text'][i] for i in range(len(batch['text']))]
            tokenized_texts = tokenizer.batch_encode_plus(
                prepare_batch,
                max_length=tokenizer.model_max_length,
                padding=True,
                truncation=True,
                return_tensors="pt",
                add_special_tokens=True,
            ).to(device)
            if verbose:
                print("Batch size:", len(batch['text']))
from datasets import load_dataset
from torch.utils.data import DataLoader
from tqdm import tqdm
import torch

def encode_texts(model, tokenizer, text_ds, batch_size, verbose=False):
    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    text_ids = [i['_id'] for i in text_ds]
    dataloader = DataLoader(text_ds, batch_size=batch_size, shuffle=False)
    allemb = []

    model.to(device)
    model.eval()

    with torch.no_grad():
        for batch in tqdm(dataloader):
            prepare_batch = [batch['text'][i] for i in range(len(batch['text']))]
            tokenized_texts = tokenizer.batch_encode_plus(
                prepare_batch,
                max_length=tokenizer.model_max_length,
                padding=True,
                truncation=True,
                return_tensors="pt",
                add_special_tokens=True,
            ).to(device)
            if verbose:
                print("Batch size:", len(batch['text']))
                print(f"Tokenized texts shape: {tokenized_texts['input_ids'].shape}")
                print(f"Decode text sample {tokenizer.decode(tokenized_texts['input_ids'][0])}")
            # The change is in the following line: removed attention_mask from the explicit keyword arguments as it is already present in tokenized_texts
            emb = model.get_sparse_rep(input_ids=tokenized_texts['input_ids'], attention_mask=tokenized_texts['attention_mask'])
            allemb.append(emb.cpu())
    allemb = torch.cat(allemb, dim=0)
    allemb = allemb.cpu().numpy()
    return allemb, text_idsmpy()
    return allemb, text_ids

In [27]:
corpus_ds = Dataset.from_dict({'_id':inference_passage_ids,'text':[passages[_id] for _id in inference_passage_ids]})
queries_ds = Dataset.from_dict({'_id':list(queries_dev.keys()),'text':list(queries_dev.values())})

### Top-k Sparsification

In [None]:
import numpy as np
import os
from tqdm import tqdm
import torch

passage_topk = 1024
save_path = "/content/my_data"
os.makedirs(save_path, exist_ok=True)

embeddings_file = os.path.join(save_path, "passage_sparse_embeddings_topk1024.npy")
ids_file = os.path.join(save_path, "passage_ids.npy")

first_batch = True
device = 'cuda' if torch.cuda.is_available() else 'cpu'

for i in tqdm(range(0, len(corpus_ds), 64)):
    batch = corpus_ds[i:i+64]
    texts = batch['text']
    ids = batch['_id']

    tokenized = tokenizer(texts, padding=True, truncation=True, return_tensors='pt').to(device)

    with torch.no_grad():
        batch_sparse = model.get_sparse_rep(tokenized['input_ids'], tokenized['attention_mask'])

    values, indices = torch.topk(batch_sparse.abs(), k=passage_topk, dim=1)
    sparse_embeddings = torch.zeros_like(batch_sparse)
    sparse_embeddings.scatter_(1, indices, batch_sparse.gather(1, indices))

    sparse_np = sparse_embeddings.cpu().numpy()
    ids_np = np.array(ids)

    if first_batch:
        np.save(embeddings_file, sparse_np)
        np.save(ids_file, ids_np)
        first_batch = False
    else:
        with open(embeddings_file, 'ab') as f:
            np.save(f, sparse_np)
        with open(ids_file, 'ab') as f:
            np.save(f, ids_np)

 91%|█████████▏| 28559/31250 [1:04:36<05:38,  7.96it/s]

In [None]:
embeddings = []
with open(embeddings_file, 'rb') as f:
    while True:
        try:
            embeddings.append(np.load(f))
        except ValueError:
            break
passage_sparse_embeddings = np.concatenate(embeddings, axis=0)


In [None]:
# passage_topk = 1024

# passage_sparse_embeddings, passage_ids = encode_texts(model, tokenizer, corpus_ds, batch_size=64)

# # Apply top-k sparsification
# values, indices = torch.topk(passage_sparse_embeddings.abs(), k=passage_topk, dim=1)
# sparse_embeddings = torch.zeros_like(passage_sparse_embeddings)
# sparse_embeddings.scatter_(1, indices, passage_sparse_embeddings.gather(1, indices))

# passage_sparse_embeddings = sparse_embeddings.detach().cpu().numpy()


### Inverted Index

In [None]:
from collections import defaultdict
import heapq


inverted_index = defaultdict(list)
min_weight = 1e-5
for j, doc_id in enumerate(passage_ids):
    doc_embedding = passage_sparse_embeddings[j]
    # Get non-zero indices and their values
    nonzero_indices = np.nonzero(doc_embedding)[0]  # Ensure indices are integers
    for idx in nonzero_indices:
        weight = doc_embedding[idx]
        if weight > min_weight:
            inverted_index[int(idx)].append((doc_id, float(weight)))

 # Sort postings lists by weight for each term
for idx in inverted_index:
    inverted_index[idx] = sorted(inverted_index[idx], key=lambda x: abs(x[1]), reverse=True)

# Convert to more efficient data structure
optimized_index = {
    idx: (
        np.array([doc_id for doc_id, _ in postings], dtype=np.int32),
        np.array([weight for _, weight in postings], dtype=np.float32)
    )
    for idx, postings in inverted_index.items()
}

### Search Index

In [None]:
def search_inverted_index(query_embedding, inverted_index, query_topk=128, min_weight=1e-5):
    scores = defaultdict(float)
    seen_docs = set()

    # Get top-k query dimensions by weight
    weights = [(idx, weight) for idx, weight in enumerate(query_embedding) if abs(weight) > min_weight]
    top_weights = heapq.nlargest(query_topk, weights, key=lambda x: abs(x[1]))

    # Process each query term
    for idx, query_weight in top_weights:
        if idx not in inverted_index:
            continue

        doc_ids, doc_weights = inverted_index[idx]

        # Only process top documents per term
        for doc_id, doc_weight in zip(doc_ids, doc_weights):
            scores[doc_id] += query_weight * doc_weight
            seen_docs.add(doc_id)

    # Use numpy for final scoring
    if seen_docs:
        doc_ids = np.array(list(seen_docs))
        doc_scores = np.array([scores[doc_id] for doc_id in doc_ids])

        # Get top 1000 results efficiently
        top_k = min(1000, len(doc_scores))
        top_indices = np.argpartition(doc_scores, -top_k)[-top_k:]
        top_indices = top_indices[np.argsort(-doc_scores[top_indices])]

        return [(doc_ids[i], doc_scores[i]) for i in top_indices]

    return []

### Search

In [None]:
query_sparse_embeddings, query_ids = encode_texts(model, tokenizer, queries_ds, batch_size=32)

In [None]:
query_topk = 128

all_results = {}

for i,query_id in enumerate(query_ids):
    query_embedding = query_sparse_embeddings[i].cpu().numpy()[0]
    search_results = search_inverted_index(
                query_embedding,
                optimized_index,
                query_topk=query_topk,
                min_weight=min_weight
            )
    all_results[query_id] = search_results

### Metrics

In [None]:
# Create rank_eval Run and Qrels objects
run = Run()
qrels = Qrels()

# Add results to Run object
for qid in all_results:
    doc_ids = [str(doc_id) for doc_id, score in all_results[qid]]
    scores = [float(score) for _, score in all_results[qid]]
    run.add(qid, doc_ids, scores)

# Add relevance judgments to Qrels object
for qid in qrels_dev:
    qrels.add(qid, [str(doc_id) for doc_id in qrels_dev[qid]], [1] * len(qrels_dev[qid]))

# Evaluate using rank_eval
metrics = ["ndcg@10", "ndcg@100", "ndcg@1000", "recall@10", "recall@100", "recall@1000", "mrr@10"]
results = evaluate(qrels, run, metrics)

metrics = (
    results["mrr@10"],
    {
        '10': results["ndcg@10"],
        '100': results["ndcg@100"],
        '1000': results["ndcg@1000"]
    },
    {
        '10': results["recall@10"],
        '100': results["recall@100"],
        '1000': results["recall@1000"]
    }
)