###Setting Up

In [2]:
# !unzip '/content/LSR-VQ-main.zip' -d '/content/drive/MyDrive/My685'

In [1]:
cd '/content/drive/MyDrive/My685/LSR-VQ-main'

/content/drive/MyDrive/685/LSR-VQ-main


###Importing Libraries

In [5]:
!pip install torch ir_datasets wandb numpy scikit-learn sentence-transformers transformers tqdm scipy matplotlib rank-eval ranx
!pip install faiss-cpu

Collecting ir_datasets
  Downloading ir_datasets-0.5.10-py3-none-any.whl.metadata (12 kB)
Collecting rank-eval
  Downloading rank_eval-0.1.3-py3-none-any.whl.metadata (6.8 kB)
Collecting ranx
  Downloading ranx-0.3.20-py3-none-any.whl.metadata (17 kB)
Collecting nvidia-cuda-nvrtc-cu12==12.4.127 (from torch)
  Downloading nvidia_cuda_nvrtc_cu12-12.4.127-py3-none-manylinux2014_x86_64.whl.metadata (1.5 kB)
Collecting nvidia-cuda-runtime-cu12==12.4.127 (from torch)
  Downloading nvidia_cuda_runtime_cu12-12.4.127-py3-none-manylinux2014_x86_64.whl.metadata (1.5 kB)
Collecting nvidia-cuda-cupti-cu12==12.4.127 (from torch)
  Downloading nvidia_cuda_cupti_cu12-12.4.127-py3-none-manylinux2014_x86_64.whl.metadata (1.6 kB)
Collecting nvidia-cudnn-cu12==9.1.0.70 (from torch)
  Downloading nvidia_cudnn_cu12-9.1.0.70-py3-none-manylinux2014_x86_64.whl.metadata (1.6 kB)
Collecting nvidia-cublas-cu12==12.4.5.8 (from torch)
  Downloading nvidia_cublas_cu12-12.4.5.8-py3-none-manylinux2014_x86_64.whl.metad

In [6]:
# Import libraries
import pandas as pd
import csv

import torch
import ir_datasets
import faiss
import wandb
import heapq
import time
import sys
import random
import string
import os
import pickle
import math

import torch.nn as nn
from torch.nn import functional as F
import torch.optim as optim
import numpy as np

from sentence_transformers import SentenceTransformer
from transformers import get_linear_schedule_with_warmup
from torch.utils.data import DataLoader, Dataset
from transformers import AutoModel, AutoTokenizer
from tqdm import tqdm
from sklearn.metrics import ndcg_score, recall_score
from collections import defaultdict
from scipy.sparse import csr_matrix
from collections import defaultdict
from tqdm import tqdm
# from rank_eval import Qrels, Run, evaluate
from ranx import Qrels, Run, evaluate

from collections import Counter
import json

os.makedirs("embeddings/train/", exist_ok = True)
os.makedirs("embeddings/dev/", exist_ok = True)

###Download Dataset

In [7]:
# # Download collection (pId -> passage text)
# !wget -P data/raw/ https://msmarco.z22.web.core.windows.net/msmarcoranking/collection.tar.gz

# # Download queries (qId -> query text)
# !wget -P data/raw/ https://msmarco.z22.web.core.windows.net/msmarcoranking/queries.tar.gz

# # Download qRels Dev and Train
# !wget -P data/raw/ https://msmarco.z22.web.core.windows.net/msmarcoranking/qrels.dev.tsv
# !wget -P data/raw/ https://msmarco.z22.web.core.windows.net/msmarcoranking/qrels.train.tsv

# # Download training data - qId positive_pId and negative_pId
# !wget -P data/raw/ https://msmarco.z22.web.core.windows.net/msmarcoranking/qidpidtriples.train.full.2.tsv.gz

# # Download total dataset (This contains all the required files)
# !wget -P data/raw https://msmarco.z22.web.core.windows.net/msmarcoranking/collectionandqueries.tar.gz

In [8]:
!ls -R data/raw/


data/raw/:
collectionandqueries.tar.gz  qrels.dev.tsv	    queries.eval.small.tsv
collection.tar.gz	     qrels.train.tsv	    queries.eval.tsv
collection.tsv		     queries.dev.small.tsv  queries.tar.gz
qrels.dev.small.tsv	     queries.dev.tsv	    queries.train.tsv


In [9]:
# # Extract collection.tar.gz (contains passage ID and passage text)
# !tar -xvzf data/raw/collection.tar.gz -C data/raw/

# # Extract queries.tar.gz (contains query ID and query text)
# !tar -xvzf data/raw/queries.tar.gz -C data/raw/

# # Extract collectionandqueries.tar.gz (optional, if you want everything bundled)
# !tar -xvzf data/raw/collectionandqueries.tar.gz -C data/raw/

# # Decompress qidpidtriples.train.full.2.tsv.gz (triplets)
# !gunzip data/raw/qidpidtriples.train.full.2.tsv.gz


###Load Dataset

In [10]:
# Read the tsv file as a dictionary
def open_file(file_path, keys = [0, 1]):
    with open(file_path, mode = "r", encoding = "utf-8") as file:
        reader = csv.reader(file, delimiter= "\t")
        data = {}
        for row in reader:
            if row[keys[0]] in data:
                data[row[keys[0]]].append(row[keys[1]])
            else:
                data[row[keys[0]]] = [row[keys[1]]]
    return data

# Load and preprocess the dataset
def load_and_preprocess_dataset():
    # Load all passages
    passages = open_file("data/raw/collection.tsv")

    # Load all train, dev & eval queries
    queries_train = open_file("data/raw/queries.train.tsv")
    queries_dev = open_file("data/raw/queries.dev.small.tsv")
    # queries_eval = open_file("data/queries.eval.tsv")

    # Load qRels train & dev
    # These only have relevant files i.e binary relevance judgment score
    # will be 1 for all the entries
    qrels_train = open_file("data/raw/qrels.train.tsv", keys = [0, 2])
    qrels_dev = open_file("data/raw/qrels.dev.small.tsv", keys = [0, 2])

    return passages, queries_train, queries_dev, qrels_train, qrels_dev

In [11]:
passages, queries_train, queries_dev, qrels_train, qrels_dev = load_and_preprocess_dataset()

In [12]:
def print_samples(file_name, dict_, n = 2):
    print('-' * 15)
    print(f'Statistics for {file_name}:')
    print('Total number of samples:', len(dict_))
    samples = list(dict_.items())[:2]
    for entry in samples:
        print(entry)

print_samples("passages", passages)
print_samples("queries_train", queries_train)
print_samples("queries_dev", queries_dev)
print_samples("qrels_train", qrels_train)
print_samples("qrels_dev", qrels_dev)

---------------
Statistics for passages:
Total number of samples: 8841823
('0', ['The presence of communication amid scientific minds was equally important to the success of the Manhattan Project as scientific intellect was. The only cloud hanging over the impressive achievement of the atomic researchers and engineers is what their success truly meant; hundreds of thousands of innocent lives obliterated.'])
('1', ['The Manhattan Project and its atomic bomb helped bring an end to World War II. Its legacy of peaceful uses of atomic energy continues to have an impact on history and science.'])
---------------
Statistics for queries_train:
Total number of samples: 808731
('121352', ['define extreme'])
('634306', ['what does chattel mean on credit history'])
---------------
Statistics for queries_dev:
Total number of samples: 6980
('1048585', ["what is paula deen's brother"])
('2', [' Androgen receptor define'])
---------------
Statistics for qrels_train:
Total number of samples: 502939
('1

###Initializing Model

In [13]:
import torch
from transformers import AutoTokenizer, AutoModel

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print('Using device', device)

# Getting tokens
tokenizer = AutoTokenizer.from_pretrained('facebook/contriever-msmarco')
model = AutoModel.from_pretrained('facebook/contriever-msmarco').to(device)

# Mean pooling cause we get embedding for each token but we want for entire sentence
def mean_pooling(token_embeddings, mask):
    token_embeddings = token_embeddings.masked_fill(~mask[..., None].bool(), 0.)
    sentence_embeddings = token_embeddings.sum(dim = 1) / mask.sum(dim = 1)[..., None]
    return sentence_embeddings

Using device cuda


The secret `HF_TOKEN` does not exist in your Colab secrets.
To authenticate with the Hugging Face Hub, create a token in your settings tab (https://huggingface.co/settings/tokens), set it as secret in your Google Colab and restart your session.
You will be able to reuse this secret in all of your notebooks.
Please note that authentication is recommended but still optional to access public models or datasets.


tokenizer_config.json:   0%|          | 0.00/321 [00:00<?, ?B/s]

vocab.txt:   0%|          | 0.00/232k [00:00<?, ?B/s]

tokenizer.json:   0%|          | 0.00/466k [00:00<?, ?B/s]

special_tokens_map.json:   0%|          | 0.00/112 [00:00<?, ?B/s]

config.json:   0%|          | 0.00/619 [00:00<?, ?B/s]

pytorch_model.bin:   0%|          | 0.00/438M [00:00<?, ?B/s]

model.safetensors:   0%|          | 0.00/438M [00:00<?, ?B/s]

###Evaluate Model

In [14]:

# Evaluate model
def evaluate_model(model, passages, dev_queries, dev_qrels, tokenizer, device, batch_size_inference=128, min_weight=1e-5, query_topk=128, doc_topk=1024, quick_mode=True):
    base_model = model.module if isinstance(model, nn.DataParallel) else model

    model.eval()
    all_results = {}

    # Filter passages if in quick mode(Only keep passages that are marked relevant in dev_qrels)
    if quick_mode:
        relevant_passage_ids = set()
        for qid in dev_qrels:
            rels = dev_qrels[qid]
            if isinstance(rels, list):
                relevant_passage_ids.update(rels)
            else:
                relevant_passage_ids.add(rels)
        passages = {pid: passages[pid] for pid in relevant_passage_ids if pid in passages}
        print(f"Quick mode: using {len(passages)} passages for evaluation")

    # Filter queries and qrels if in quick mode(Limits to only one relevant passage per query.)
    if quick_mode:
        dev_qrels = {
            qid: [dev_qrels[qid][0]] if isinstance(dev_qrels[qid], list) else [dev_qrels[qid]]
            for qid in dev_queries if qid in dev_qrels
        }
        print(f"Quick mode: using {len(dev_queries)} queries for evaluation")

    if len(passages) == 0 or len(dev_queries) == 0:
        print("No data to evaluate.")
        return 0.0, {}, {}

    print("Building FAISS index...")
    passage_ids = list(passages.keys())
    passage_embeddings = []

    # Encode all passages
    for i in tqdm(range(0, len(passage_ids), batch_size_inference), desc="Encoding passages"):
        batch_passages = [passages[pid][0] for pid in passage_ids[i:i + batch_size_inference]]
        passage_inputs = tokenizer(batch_passages, padding=True, truncation=True, return_tensors='pt')
        passage_inputs = {k: v.to(device) for k, v in passage_inputs.items()}

        with torch.no_grad():
            outputs = model(**passage_inputs)
            batch_embeddings = mean_pooling(outputs.last_hidden_state, passage_inputs['attention_mask'])
            batch_embeddings = F.normalize(batch_embeddings, p=2, dim=1)
            passage_embeddings.append(batch_embeddings.cpu().numpy())

    passage_embeddings = np.vstack(passage_embeddings)

    dimension = passage_embeddings.shape[1]
    index = faiss.IndexFlatIP(dimension)
    index.add(passage_embeddings)

    # Evaluate queries
    with torch.no_grad():
        for qid, query in tqdm(dev_queries.items(), desc="Evaluating"):
            if qid not in dev_qrels:
                continue

            query_text = query[0]
            query_input = tokenizer(query_text, padding=True, truncation=True, return_tensors='pt')
            query_input = {k: v.to(device) for k, v in query_input.items()}
            output = model(**query_input)
            query_embedding = mean_pooling(output.last_hidden_state, query_input['attention_mask'])
            query_embedding = F.normalize(query_embedding, p=2, dim=1).cpu().numpy()

            scores, indices = index.search(query_embedding, 1000)
            search_results = [(passage_ids[idx], float(score)) for idx, score in zip(indices[0], scores[0])]
            all_results[qid] = search_results

    # Build Run and Qrels
    run_dict = {
        qid: {str(pid): float(score) for pid, score in results}
        for qid, results in all_results.items()
    }

    qrels_dict = {
        qid: {str(doc_id): 1 for doc_id in dev_qrels[qid]}
        for qid in dev_qrels
    }

    run = Run(run_dict)
    qrels = Qrels(qrels_dict)

    metrics = ["ndcg@10", "ndcg@100", "ndcg@1000", "recall@10", "recall@100", "recall@1000", "mrr@10"]
    results = evaluate(qrels, run, metrics)

    return (
        results["mrr@10"],
        {
            '10': results["ndcg@10"],
            '100': results["ndcg@100"],
            '1000': results["ndcg@1000"]
        },
        {
            '10': results["recall@10"],
            '100': results["recall@100"],
            '1000': results["recall@1000"]
        }
    )


In [15]:
mrr, ndcg_scores, recall_scores = evaluate_model(
    model=model,
    passages=passages,
    dev_queries=queries_dev,
    dev_qrels=qrels_dev,
    tokenizer=tokenizer,
    device=device,
    quick_mode=True
)


Quick mode: using 7433 passages for evaluation
Quick mode: using 6980 queries for evaluation
Building FAISS index...



Encoding passages:   0%|          | 0/59 [00:00<?, ?it/s][A
Encoding passages:   2%|▏         | 1/59 [00:03<03:26,  3.56s/it][A
Encoding passages:   3%|▎         | 2/59 [00:05<02:12,  2.33s/it][A
Encoding passages:   5%|▌         | 3/59 [00:06<01:38,  1.76s/it][A
Encoding passages:   7%|▋         | 4/59 [00:07<01:21,  1.48s/it][A
Encoding passages:   8%|▊         | 5/59 [00:08<01:11,  1.32s/it][A
Encoding passages:  10%|█         | 6/59 [00:09<01:05,  1.24s/it][A
Encoding passages:  12%|█▏        | 7/59 [00:10<01:01,  1.18s/it][A
Encoding passages:  14%|█▎        | 8/59 [00:11<00:57,  1.13s/it][A
Encoding passages:  15%|█▌        | 9/59 [00:12<01:00,  1.21s/it][A
Encoding passages:  17%|█▋        | 10/59 [00:13<00:59,  1.20s/it][A
Encoding passages:  19%|█▊        | 11/59 [00:14<00:53,  1.12s/it][A
Encoding passages:  20%|██        | 12/59 [00:16<00:57,  1.21s/it][A
Encoding passages:  22%|██▏       | 13/59 [00:17<00:54,  1.19s/it][A
Encoding passages:  24%|██▎       | 1

In [16]:
print(f"MRR@10: {mrr:.4f}")
print("NDCG Scores:", ndcg_scores)
print("Recall Scores:", recall_scores)

MRR@10: 0.9414
NDCG Scores: {'10': np.float64(0.9549372455575459), '100': np.float64(0.9560142436401606), '1000': np.float64(0.9560499593537618)}
Recall Scores: {'10': np.float64(0.9948424068767908), '100': np.float64(0.9995702005730659), '1000': np.float64(0.9998567335243553)}


###Get Query and Passage Embeddings

In [26]:
import torch
import torch.nn.functional as F
from tqdm import tqdm

def get_embeddings(
    model,
    passages,
    queries,
    qrels,
    tokenizer,
    device,
    batch_size_inference=128,
    quick_mode=True,
    output_batching=True,
    mode='train'
):
    base_model = model.module if isinstance(model, torch.nn.DataParallel) else model
    model.eval()

    # Filter passages if in quick mode
    if quick_mode:
        relevant_passage_ids = set()
        for qid in qrels:
            rels = qrels[qid]
            if isinstance(rels, list):
                relevant_passage_ids.update(rels)
            else:
                relevant_passage_ids.add(rels)

        filtered_passages = {}
        for pid in relevant_passage_ids:
            if pid in passages:
                text = passages[pid]
                if isinstance(text, tuple):
                    text = text[0]
                if not isinstance(text, str):
                    text = str(text)
                filtered_passages[pid] = text
        passages = filtered_passages
        print(f"Quick mode: using {len(passages)} {mode} passages")

    # Filter queries and qrels if in quick mode
    if quick_mode:
        queries = {qid: queries[qid] for qid in qrels if qid in queries}
        qrels = {qid: qrels[qid] for qid in queries}
        print(f"Quick mode: using {len(queries)} {mode} queries")

    passage_ids = list(passages.keys())
    passage_embeddings = []

    # Encode all passages
    for i in tqdm(range(0, len(passage_ids), batch_size_inference), desc="Encoding passages"):
        batch_passages = []
        for pid in passage_ids[i:i + batch_size_inference]:
            text = passages[pid]
            if isinstance(text, tuple):
                text = text[0]
            if not isinstance(text, str):
                text = str(text)
            batch_passages.append(text)

        passage_inputs = tokenizer(batch_passages, padding=True, truncation=True, return_tensors='pt')
        passage_inputs = {k: v.to(device) for k, v in passage_inputs.items()}

        with torch.no_grad():
            outputs = model(**passage_inputs)
            batch_embeddings = mean_pooling(outputs.last_hidden_state, passage_inputs['attention_mask'])
            batch_embeddings = F.normalize(batch_embeddings, p=2, dim=1)  # Normalize here
            passage_embeddings.append(batch_embeddings)

    query_ids = list(qrels.keys())
    query_embeddings = []

    # Encode all queries
    for i in tqdm(range(0, len(query_ids), batch_size_inference), desc="Encoding queries"):
        batch_queries = []
        for qid in query_ids[i:i + batch_size_inference]:
            text = queries[qid]
            if isinstance(text, tuple):
                text = text[0]
            if not isinstance(text, str):
                text = str(text)
            batch_queries.append(text)

        query_inputs = tokenizer(batch_queries, padding=True, truncation=True, return_tensors='pt')
        query_inputs = {k: v.to(device) for k, v in query_inputs.items()}

        with torch.no_grad():
            outputs = model(**query_inputs)
            batch_embeddings = mean_pooling(outputs.last_hidden_state, query_inputs['attention_mask'])
            batch_embeddings = F.normalize(batch_embeddings, p=2, dim=1)
            query_embeddings.append(batch_embeddings)

    if output_batching:
        passage_embeddings = torch.cat(passage_embeddings, dim=0)
        query_embeddings = torch.cat(query_embeddings, dim=0)
    else:
        print("Warning: output_batching=False, embeddings are returned as list of batches.")

    return {
        'embeddings': {
            'passage_embeddings': passage_embeddings,
            'query_embeddings': query_embeddings
        },
        'mappings': {
            'passage_ids': passage_ids,
            'query_ids': query_ids
        }
    }


In [27]:

def load_or_save_embeddings(
    model,
    passages,
    queries,
    qrels,
    tokenizer,
    device,
    batch_size=128,
    quick_mode=True,
    output_batching=True,
    mode='train',
    force_recompute=False
):
    save_dir = os.path.join('embeddings', mode)
    os.makedirs(save_dir, exist_ok=True)

    passage_embeddings_path = os.path.join(save_dir, 'passage_embeddings.pt')
    query_embeddings_path = os.path.join(save_dir, 'query_embeddings.pt')
    passage_ids_path = os.path.join(save_dir, 'passage_ids.json')
    query_ids_path = os.path.join(save_dir, 'query_ids.json')

    # Try loading if saved
    if not force_recompute and all(os.path.exists(p) for p in [passage_embeddings_path, query_embeddings_path, passage_ids_path, query_ids_path]):
        print(f" Loading cached embeddings and IDs from {save_dir}...")

        passage_embeddings = torch.load(passage_embeddings_path, map_location=device)
        query_embeddings = torch.load(query_embeddings_path, map_location=device)

        with open(passage_ids_path, "r") as f:
            passage_ids = json.load(f)

        with open(query_ids_path, "r") as f:
            query_ids = json.load(f)

        print(f" Successfully loaded embeddings: {passage_embeddings.size(0)} passages, {query_embeddings.size(0)} queries.")

        return {
            'embeddings': {
                'passage_embeddings': passage_embeddings,
                'query_embeddings': query_embeddings
            },
            'mappings': {
                'passage_ids': passage_ids,
                'query_ids': query_ids
            }
        }

    # Otherwise compute
    print(f" Computing embeddings for {mode} set...")

    obj = get_embeddings(
        model=model,
        passages=passages,
        queries=queries,
        qrels=qrels,
        tokenizer=tokenizer,
        device=device,
        batch_size_inference=batch_size,
        quick_mode=quick_mode,
        output_batching=output_batching,
        mode=mode
    )

    passage_embeddings = obj['embeddings']['passage_embeddings'].float()  # Save as float32
    query_embeddings = obj['embeddings']['query_embeddings'].float()
    passage_ids = obj['mappings']['passage_ids']
    query_ids = obj['mappings']['query_ids']

    # Save computed embeddings
    torch.save(passage_embeddings, passage_embeddings_path)
    torch.save(query_embeddings, query_embeddings_path)

    with open(passage_ids_path, "w") as f:
        json.dump(passage_ids, f)

    with open(query_ids_path, "w") as f:
        json.dump(query_ids, f)

    print(f" Saved embeddings to {save_dir}.")

    return {
        'embeddings': {
            'passage_embeddings': passage_embeddings,
            'query_embeddings': query_embeddings
        },
        'mappings': {
            'passage_ids': passage_ids,
            'query_ids': query_ids
        }
    }


In [28]:
# !rm -rf embeddings/train
# !rm -rf embeddings/dev

In [29]:
train_data = load_or_save_embeddings(
    model=model,
    passages=passages,
    queries=queries_train,
    qrels=qrels_train,
    tokenizer=tokenizer,
    device=device,
    batch_size=128,
    quick_mode=True,
    output_batching=True,
    mode='train'
)

 Computing embeddings for train set...
Quick mode: using 516472 train passages
Quick mode: using 502939 train queries


Encoding passages: 100%|██████████| 4035/4035 [1:39:30<00:00,  1.48s/it]
Encoding queries: 100%|██████████| 3930/3930 [11:45<00:00,  5.57it/s]


 Saved embeddings to embeddings/train.


In [44]:
dev_data = load_or_save_embeddings(
    model=model,
    passages=passages,
    queries=queries_dev,
    qrels=qrels_dev,
    tokenizer=tokenizer,
    device=device,
    batch_size=128,
    quick_mode=True,
    output_batching=True,
    mode='dev'
)

 Computing embeddings for dev set...
Quick mode: using 7433 dev passages
Quick mode: using 6980 dev queries


Encoding passages: 100%|██████████| 59/59 [01:23<00:00,  1.42s/it]
Encoding queries: 100%|██████████| 55/55 [00:10<00:00,  5.45it/s]


 Saved embeddings to embeddings/dev.


In [30]:
train_passage_embeddings = train_data['embeddings']['passage_embeddings']

# Extract query embeddings
train_query_embeddings = train_data['embeddings']['query_embeddings']

# Extract passage ids
train_passage_ids = train_data['mappings']['passage_ids']

# Extract query ids
train_query_ids = train_data['mappings']['query_ids']

In [45]:
dev_passage_embeddings = dev_data['embeddings']['passage_embeddings']
dev_query_embeddings = dev_data['embeddings']['query_embeddings']
dev_passage_ids = dev_data['mappings']['passage_ids']
dev_query_ids = dev_data['mappings']['query_ids']

###Vector Quantizaiton

In [31]:
class Quantize(nn.Module):
    def __init__(self, dim, num_clusters, decay = 0.99, eps = 1e-5):
        super().__init__()

        self.dim = dim
        self.num_clusters = num_clusters
        self.decay = decay
        self.eps = eps

        embed = torch.randn(dim, num_clusters)
        self.register_buffer("embed", embed)
        self.register_buffer("cluster_size", torch.zeros(num_clusters))
        self.register_buffer("embed_avg", embed.clone())

    def forward(self, input):
        flatten = input.reshape(-1, self.dim)
        dist = (
            flatten.pow(2).sum(1, keepdim = True)
            - 2 * flatten @ self.embed
            + self.embed.pow(2).sum(0, keepdim = True)
        )
        _, embed_ind = (-dist).max(1)
        embed_onehot = F.one_hot(embed_ind, self.num_clusters).type(flatten.dtype)
        embed_ind = embed_ind.view(*input.shape[:-1])
        quantize = self.embed_code(embed_ind)

        if self.training:
            embed_onehot_sum = embed_onehot.sum(0)
            embed_sum = flatten.transpose(0, 1) @ embed_onehot

            self.cluster_size.data.mul_(self.decay).add_(
                embed_onehot_sum, alpha=1 - self.decay
            )
            self.embed_avg.data.mul_(self.decay).add_(embed_sum, alpha=1 - self.decay)
            n = self.cluster_size.sum()
            cluster_size = (
                (self.cluster_size + self.eps) / (n + self.num_clusters * self.eps) * n
            )
            embed_normalized = self.embed_avg / cluster_size.unsqueeze(0)
            self.embed.data.copy_(embed_normalized)

        quantize = input + (quantize - input).detach()

        return quantize, embed_ind

    def embed_code(self, embed_id):
        return F.embedding(embed_id, self.embed.transpose(0, 1))

In [39]:
# Split embedding into k chunks
def split_embedding_into_chunks(embeddings, k_chunks):
    """
    Split the input embeddings along the feature dimension into k_chunks smaller parts.
    """
    bsz, dim = embeddings.shape
    assert dim % k_chunks == 0, "Embedding dimension must be divisible by k_chunks!"
    return embeddings.view(bsz * k_chunks, dim // k_chunks)


# Train the VQ Quantizer (only on passages)
def train_VQ(
    model,
    passages,
    queries,
    qrels,
    tokenizer,
    device,
    batch_size=128,
    num_clusters=256,
    k_chunks=16,
    quick_mode=True,
    output_batching=True,
    mode='train'
):
    """
    Function to train the codebook (Quantizer) using only passage embeddings.
    """
    # Load or save embeddings
    obj = load_or_save_embeddings(
        model=model,
        passages=passages,
        queries= queries,
        qrels=qrels,
        tokenizer=tokenizer,
        device=device,
        batch_size=batch_size,
        quick_mode=quick_mode,
        output_batching=output_batching,
        mode=mode
    )

    passage_embeddings = obj['embeddings']['passage_embeddings']  # [num_passages, dim]
    dim = passage_embeddings.shape[1] // k_chunks  # Dynamically infer dim per chunk

    passage_chunked_embeddings = split_embedding_into_chunks(passage_embeddings, k_chunks)

    # Train the codebook
    quantizer = Quantize(dim=dim, num_clusters=num_clusters).to(device=device)
    quantizer.train()  # Set to training mode

    for i in tqdm(range(0, passage_chunked_embeddings.shape[0], batch_size), desc="Training codebook vectors"):
        batch_embs = passage_chunked_embeddings[i:i + batch_size].to(device)
        _, code = quantizer(batch_embs)

    quantizer.eval()
    return quantizer

In [40]:
# Perform inference (quantization)
def perform_VQ(
    quantizer,
    embeddings,
    batch_size=32,
    k_chunks=16,
    embedding_type='passage'
):
    """
    Function to quantize embeddings into discrete token assignments.
    Works for passage embeddings or query embeddings.
    """
    quantizer.eval()  # No codebook updates during inference
    code_indices = []

    for i in tqdm(range(0, embeddings.shape[0], batch_size), desc=f"Quantizing {embedding_type} embeddings"):
        batch_embs = embeddings[i:i + batch_size]
        batch_chunked_embs = split_embedding_into_chunks(batch_embs, k_chunks)
        _, code = quantizer(batch_chunked_embs)
        code = code.view(-1, k_chunks)  # [batch_size, k_chunks]
        code_indices.append(code)

    code_indices = torch.cat(code_indices, dim=0)  # [num_examples, k_chunks]

    return code_indices

###Build Inverted Index from Passage Codes

In [50]:
def create_inverted_index(
    passage_ids,
    code_indices,
    vocab_size=None,
    batch_size=128,
    mode='dev'
):
    num_passages = len(passage_ids)

    if vocab_size is None:
        vocab_size = int(max(code.cpu().max().item() for code in code_indices)) + 1  # <-- FIX HERE

    rows, cols, data = [], [], []

    for start in tqdm(range(0, num_passages, batch_size), desc=f"Building {mode} passage matrix"):
        end = min(start + batch_size, num_passages)
        batch_codes = code_indices[start:end]

        for i, codes in enumerate(batch_codes):
            code_counts = Counter(codes.tolist())  # <-- safer to .tolist() here
            for code, freq in code_counts.items():
                rows.append(start + i)
                cols.append(code)
                data.append(freq)

    passage_matrix = csr_matrix(
        (data, (rows, cols)),
        shape=(num_passages, vocab_size),
        dtype=np.float32
    )

    return passage_matrix


In [42]:
import numpy as np
from tqdm import tqdm

def search_inverted_index(
    query_matrix,
    passage_matrix,
    passage_ids,
    query_ids,
    batch_size=128,
    topk=1000
):
    """
    Search over the inverted index using sparse matrix multiplication.
    """

    num_queries = query_matrix.shape[0]
    results = {}

    for start in tqdm(range(0, num_queries, batch_size), desc="Scoring queries in batches"):
        end = min(start + batch_size, num_queries)
        query_chunk = query_matrix[start:end]

        # Sparse matrix multiplication (very fast)
        scores_chunk = query_chunk @ passage_matrix.T  # shape: [batch_size, num_passages]

        scores_chunk = scores_chunk.toarray()  # Dense for indexing

        # For each query in the chunk
        for i in range(scores_chunk.shape[0]):
            passage_scores = scores_chunk[i]
            top_i = np.argpartition(-passage_scores, topk)[:topk]
            sorted_indices = top_i[np.argsort(-passage_scores[top_i])]

            query_id = query_ids[start + i]
            top_passages = [(passage_ids[idx], float(passage_scores[idx])) for idx in sorted_indices]
            results[query_id] = top_passages

    return results


###Implementation

In [43]:
quantizer = train_VQ(
    model=model,
    passages=passages,
    queries=queries_train,
    qrels=qrels_train,
    tokenizer=tokenizer,
    device=device,
    batch_size=128,
    mode='train'  # (train or dev)
)

 Loading cached embeddings and IDs from embeddings/train...
 Successfully loaded embeddings: 516472 passages, 502939 queries.


Training codebook vectors: 100%|██████████| 64559/64559 [00:47<00:00, 1367.10it/s]


In [47]:
# Perform VQ on train passages
train_passage_codes = perform_VQ(
    quantizer=quantizer,
    embeddings=train_passage_embeddings,
    batch_size=128,
    k_chunks=16,
    embedding_type='passage'
)


Quantizing passage embeddings: 100%|██████████| 4035/4035 [00:01<00:00, 2855.33it/s]


In [46]:
# # Perform VQ on dev passages
dev_passage_codes = perform_VQ(
    quantizer=quantizer,
    embeddings=dev_passage_embeddings,
    batch_size=128,
    k_chunks=16,
    embedding_type='passage'
)

Quantizing passage embeddings: 100%|██████████| 59/59 [00:00<00:00, 1236.49it/s]


In [48]:
# Perform VQ on dev queries
dev_query_codes = perform_VQ(
    quantizer=quantizer,
    embeddings=dev_query_embeddings,
    batch_size=128,
    k_chunks=16,
    embedding_type='query'
)

Quantizing query embeddings: 100%|██████████| 55/55 [00:00<00:00, 2731.41it/s]


In [51]:
dev_passage_matrix = create_inverted_index(
    passage_ids=dev_passage_ids,
    code_indices=dev_passage_codes,
    batch_size=128,
    mode='dev'
)

Building dev passage matrix: 100%|██████████| 59/59 [00:00<00:00, 350.78it/s]


###Metrics

In [52]:
from tqdm import tqdm
from ranx import Qrels, Run, evaluate
import numpy as np
from collections import Counter
from scipy.sparse import csr_matrix

def get_metrics(
    query_codes,
    passage_matrix,
    query_ids,
    passage_ids,
    qrels,
    batch_size=128,
    topk=1000,
    k_chunks=16
):
    """
    Compute retrieval metrics using precomputed query_codes and passage_matrix.
    """

    # 1. Build sparse query matrix
    rows, cols, data = [], [], []
    for i, codes in enumerate(query_codes):
        counts = Counter(codes.tolist())
        for code, freq in counts.items():
            rows.append(i)
            cols.append(code)
            data.append(freq)

    vocab_size = passage_matrix.shape[1]
    num_queries = len(query_codes)

    query_matrix = csr_matrix(
        (data, (rows, cols)),
        shape=(num_queries, vocab_size),
        dtype=np.float32
    )

    # 2. Search using sparse matrix multiplication
    all_results = {}

    for start in tqdm(range(0, num_queries, batch_size), desc="Scoring queries"):
        end = min(start + batch_size, num_queries)
        query_chunk = query_matrix[start:end]

        scores_chunk = query_chunk @ passage_matrix.T  # shape [batch, num_passages]
        scores_chunk = scores_chunk.toarray()

        for i in range(scores_chunk.shape[0]):
            passage_scores = scores_chunk[i]
            top_i = np.argpartition(-passage_scores, topk)[:topk]
            sorted_indices = top_i[np.argsort(-passage_scores[top_i])]

            all_results[query_ids[start + i]] = [
                (passage_ids[idx], float(passage_scores[idx]))
                for idx in sorted_indices
            ]

    # 3. Build Ranx Run and Qrels
    run_dict = {
        qid: {str(pid): score for pid, score in results}
        for qid, results in all_results.items()
    }
    run = Run(run_dict)

    qrels_dict = {
        qid: {str(pid): 1 for pid in qrels[qid]}
        for qid in qrels
    }
    qrels = Qrels(qrels_dict)

    # 4. Evaluate
    metrics = ["ndcg@10", "ndcg@100", "ndcg@1000", "recall@10", "recall@100", "recall@1000", "mrr@10"]
    results = evaluate(qrels, run, metrics)

    return (
        results["mrr@10"],
        {
            '10': results["ndcg@10"],
            '100': results["ndcg@100"],
            '1000': results["ndcg@1000"]
        },
        {
            '10': results["recall@10"],
            '100': results["recall@100"],
            '1000': results["recall@1000"]
        }
    )


In [53]:
mrr10, ndcg_scores, recall_scores = get_metrics(
    query_codes=dev_query_codes,
    passage_matrix=dev_passage_matrix,
    query_ids=dev_query_ids,
    passage_ids=dev_passage_ids,
    qrels=qrels_dev,
    batch_size=128,
    topk=1000,
    k_chunks=16
)


Scoring queries: 100%|██████████| 55/55 [00:07<00:00,  6.96it/s]


In [54]:
mrr10

np.float64(0.19142237458498204)

In [55]:
ndcg_scores

{'10': np.float64(0.22177720129924405),
 '100': np.float64(0.28476348408783864),
 '1000': np.float64(0.3162026289798771)}

In [56]:
recall_scores

{'10': np.float64(0.33225883476599805),
 '100': np.float64(0.642227793696275),
 '1000': np.float64(0.8905324737344795)}