## Google Drive setup

In [1]:
from google.colab import drive
drive.mount('/content/drive', force_remount=True)

Mounted at /content/drive


In [2]:
cd /content/drive/MyDrive/UMass/"685 PROJECT"/"LSR-VQ"

/content/drive/MyDrive/UMass/685 PROJECT/LSR-VQ


## Install required libraries

In [3]:
# Install required libraries
!pip install torch ir_datasets wandb numpy scikit-learn sentence-transformers transformers tqdm scipy matplotlib rank-eval ranx
!pip install faiss-cpu
# !pip uninstall faiss-gpu-cu11

Collecting ir_datasets
  Downloading ir_datasets-0.5.10-py3-none-any.whl.metadata (12 kB)
Collecting rank-eval
  Downloading rank_eval-0.1.3-py3-none-any.whl.metadata (6.8 kB)
Collecting ranx
  Downloading ranx-0.3.20-py3-none-any.whl.metadata (17 kB)
Collecting nvidia-cuda-nvrtc-cu12==12.4.127 (from torch)
  Downloading nvidia_cuda_nvrtc_cu12-12.4.127-py3-none-manylinux2014_x86_64.whl.metadata (1.5 kB)
Collecting nvidia-cuda-runtime-cu12==12.4.127 (from torch)
  Downloading nvidia_cuda_runtime_cu12-12.4.127-py3-none-manylinux2014_x86_64.whl.metadata (1.5 kB)
Collecting nvidia-cuda-cupti-cu12==12.4.127 (from torch)
  Downloading nvidia_cuda_cupti_cu12-12.4.127-py3-none-manylinux2014_x86_64.whl.metadata (1.6 kB)
Collecting nvidia-cudnn-cu12==9.1.0.70 (from torch)
  Downloading nvidia_cudnn_cu12-9.1.0.70-py3-none-manylinux2014_x86_64.whl.metadata (1.6 kB)
Collecting nvidia-cublas-cu12==12.4.5.8 (from torch)
  Downloading nvidia_cublas_cu12-12.4.5.8-py3-none-manylinux2014_x86_64.whl.metad

## Import Libraries

In [149]:
# Import libraries
import pandas as pd
import csv

import torch
import ir_datasets
import faiss
import wandb
import heapq
import time
import sys
import random
import string
import os
import pickle
import math

import torch.nn as nn
from torch.nn import functional as F
import torch.optim as optim
import numpy as np

from sentence_transformers import SentenceTransformer
from transformers import get_linear_schedule_with_warmup
from torch.utils.data import DataLoader, Dataset
from transformers import AutoModel, AutoTokenizer
from tqdm import tqdm
from sklearn.metrics import ndcg_score, recall_score
from collections import defaultdict
from scipy.sparse import csr_matrix
from collections import defaultdict
from tqdm import tqdm
# from rank_eval import Qrels, Run, evaluate
from ranx import Qrels, Run, evaluate

from collections import Counter
import json

os.makedirs("embeddings/train/", exist_ok = True)
os.makedirs("embeddings/dev/", exist_ok = True)

## Download Dataset (Don't run!!! Data is already downloaded)

In [145]:
# Download collection (pId -> passage text)
# !wget -P data/raw/ https://msmarco.z22.web.core.windows.net/msmarcoranking/collection.tar.gz

# Download queries (qId -> query text)
# !wget -P data/raw/ https://msmarco.z22.web.core.windows.net/msmarcoranking/queries.tar.gz

# Download qRels Dev and Train
# !wget -P data/raw/ https://msmarco.z22.web.core.windows.net/msmarcoranking/qrels.dev.tsv
# !wget -P data/raw/ https://msmarco.z22.web.core.windows.net/msmarcoranking/qrels.train.tsv

# Download training data - qId positive_pId and negative_pId
# !wget -P data/raw/ https://msmarco.z22.web.core.windows.net/msmarcoranking/qidpidtriples.train.full.2.tsv.gz

# Download total dataset (This contains all the required files)
# !wget -P data/raw https://msmarco.z22.web.core.windows.net/msmarcoranking/collectionandqueries.tar.gz

In [146]:
# # unzip files
# !tar -xf "data/raw/queries.tar.gz"
# !tar -xf "data/raw/collection.tar.gz"
# !tar -xf "data/raw/qidpidtriples.train.full.2.tsv.gz"
# !tar -xf "data/raw/collectionandqueries.tar.gz"

## Load Dataset

In [7]:
# Read the tsv file as a dictionary
def open_file(file_path, keys = [0, 1]):
    with open(file_path, mode = "r", encoding = "utf-8") as file:
        reader = csv.reader(file, delimiter= "\t")

        data = {}
        for row in reader:
            if row[keys[0]] in data:
                data[row[keys[0]]].append(row[keys[1]])
            else:
                data[row[keys[0]]] = [row[keys[1]]]

    return data

# Load and preprocess the dataset
def load_and_preprocess_dataset():
    # Load all passages
    passages = open_file("data/collection.tsv")

    # Load all train, dev & eval queries
    queries_train = open_file("data/queries.train.tsv")
    queries_dev = open_file("data/queries.dev.small.tsv")
    # queries_eval = open_file("data/queries.eval.tsv")

    # Load qRels train & dev
    # These only have relevant files i.e binary relevance judgment score
    # will be 1 for all the entries
    qrels_train = open_file("data/qrels.train.tsv", keys = [0, 2])
    qrels_dev = open_file("data/qrels.dev.small.tsv", keys = [0, 2])

    return passages, queries_train, queries_dev, qrels_train, qrels_dev

In [8]:
passages, queries_train, queries_dev, qrels_train, qrels_dev = load_and_preprocess_dataset()

In [9]:
def print_samples(file_name, dict_, n = 2):
    print('-' * 15)
    print(f'Statistics for {file_name}:')
    print('Total number of samples:', len(dict_))
    samples = list(dict_.items())[:2]
    for entry in samples:
        print(entry)

print_samples("passages", passages)
print_samples("queries_train", queries_train)
print_samples("queries_dev", queries_dev)
print_samples("qrels_train", qrels_train)
print_samples("qrels_dev", qrels_dev)

---------------
Statistics for passages:
Total number of samples: 8841823
('0', ['The presence of communication amid scientific minds was equally important to the success of the Manhattan Project as scientific intellect was. The only cloud hanging over the impressive achievement of the atomic researchers and engineers is what their success truly meant; hundreds of thousands of innocent lives obliterated.'])
('1', ['The Manhattan Project and its atomic bomb helped bring an end to World War II. Its legacy of peaceful uses of atomic energy continues to have an impact on history and science.'])
---------------
Statistics for queries_train:
Total number of samples: 808731
('121352', ['define extreme'])
('634306', ['what does chattel mean on credit history'])
---------------
Statistics for queries_dev:
Total number of samples: 6980
('1048585', ["what is paula deen's brother"])
('2', [' Androgen receptor define'])
---------------
Statistics for qrels_train:
Total number of samples: 502939
('1

## Initialize Model

In [10]:
import torch
from transformers import AutoTokenizer, AutoModel

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print('Using device', device)

tokenizer = AutoTokenizer.from_pretrained('facebook/contriever-msmarco')
model = AutoModel.from_pretrained('facebook/contriever-msmarco').to(device)

# Mean pooling
def mean_pooling(token_embeddings, mask):
    token_embeddings = token_embeddings.masked_fill(~mask[..., None].bool(), 0.)
    sentence_embeddings = token_embeddings.sum(dim = 1) / mask.sum(dim = 1)[..., None]
    return sentence_embeddings

The secret `HF_TOKEN` does not exist in your Colab secrets.
To authenticate with the Hugging Face Hub, create a token in your settings tab (https://huggingface.co/settings/tokens), set it as secret in your Google Colab and restart your session.
You will be able to reuse this secret in all of your notebooks.
Please note that authentication is recommended but still optional to access public models or datasets.


tokenizer_config.json:   0%|          | 0.00/321 [00:00<?, ?B/s]

vocab.txt:   0%|          | 0.00/232k [00:00<?, ?B/s]

tokenizer.json:   0%|          | 0.00/466k [00:00<?, ?B/s]

special_tokens_map.json:   0%|          | 0.00/112 [00:00<?, ?B/s]

config.json:   0%|          | 0.00/619 [00:00<?, ?B/s]

pytorch_model.bin:   0%|          | 0.00/438M [00:00<?, ?B/s]

model.safetensors:   0%|          | 0.00/438M [00:00<?, ?B/s]

## Evaluate Model

In [138]:
# Evaluate model
def evaluate_model(model, passages, dev_queries, dev_qrels, tokenizer, device, batch_size_inference = 128, min_weight = 1e-5, query_topk = 128, doc_topk = 1024, quick_mode = True):
    base_model = model.module if isinstance(model, nn.DataParallel) else model

    model.eval()
    all_results = {}  # Store results for each query

    # Filter passages if in quick mode
    if quick_mode:
        relevant_passage_ids = set()
        for qid in dev_qrels:
            relevant_passage_ids.update(dev_qrels[qid])
        passages = {passage_id: passages[passage_id] for passage_id in relevant_passage_ids}
        print(f"Quick mode: using {len(passages)} passages for evaluation")

    # Filter queries if in quick mode
    if quick_mode:
        # dev_queries = dict(list(dev_queries.items())[:200])
        dev_qrels = {qid: dev_qrels[qid][0] for qid in dev_queries if qid in dev_qrels}
        print(f"Quick mode: using {len(dev_queries)} queries for evaluation")

    # Dense retrieval using FAISS
    print("Building FAISS index...")
    passage_ids = list(passages.keys())
    passage_embeddings = []

    # Encode all passages
    for i in tqdm(range(0, len(passage_ids), batch_size_inference), desc = "Encoding passages"):
        batch_passages = [passages[pid][0] for pid in passage_ids[i:i + batch_size_inference]]

        # Pad till the model's configured max_len (512)
        passage_inputs = tokenizer(batch_passages, padding = True, truncation = True, return_tensors = 'pt')
        passage_inputs = {k: v.to(device) for k, v in passage_inputs.items()}

        with torch.no_grad():
            outputs = model(passage_inputs["input_ids"], passage_inputs["attention_mask"])
            batch_embeddings = mean_pooling(outputs[0], passage_inputs['attention_mask'])
            passage_embeddings.append(batch_embeddings.cpu().numpy())

    passage_embeddings = np.vstack(passage_embeddings)

    # Build FAISS index
    dimension = passage_embeddings.shape[1]
    index = faiss.IndexFlatIP(dimension)
    index.add(passage_embeddings)

    # Evaluate queries
    with torch.no_grad():
        for qid, query in tqdm(dev_queries.items(), desc = "Evaluating"):
            if qid not in dev_qrels:
                continue

            query = query[0]
            # Pad till the model's configured max_len (512)
            query_input = tokenizer(query, padding = True, truncation = True, return_tensors = 'pt')
            query_input = {k: v.to(device) for k, v in query_input.items()}
            output = model(query_input["input_ids"], query_input["attention_mask"])
            query_embedding = mean_pooling(output[0], query_input['attention_mask'])
            query_embedding = query_embedding.cpu().numpy()

            # Search using FAISS
            scores, indices = index.search(query_embedding, 1000)
            search_results = [(passage_ids[idx], float(score)) for idx, score in zip(indices[0], scores[0])]
            all_results[qid] = search_results

    # Create rank_eval Run and Qrels objects
    run_dict = {}
    for qid, results in all_results.items():
        run_dict[qid] = {
            str(passage_id): float(score)
            for passage_id, score in results
        }
    run = Run(run_dict)

    qrels_dict = {
        qid: {str(doc_id): 1 for doc_id in dev_qrels[qid]}
        for qid in dev_qrels
    }
    qrels = Qrels(qrels_dict)



    # Create rank_eval Run and Qrels objects
    # run = Run()
    # qrels = Qrels()

    # Add results to Run object
    # for qid in all_results:
    #     passage_ids = [str(passage_id) for passage_id, score in all_results[qid]]
    #     scores = [float(score) for _, score in all_results[qid]]
    #     run.add(qid, passage_ids, scores)

    # # Add relevance judgments to Qrels object
    # for qid in dev_qrels:
    #     qrels.add(qid, [str(doc_id) for doc_id in dev_qrels[qid]], [1] * len(dev_qrels[qid]))

    # Evaluate using rankx
    metrics = ["ndcg@10", "ndcg@100", "ndcg@1000", "recall@10", "recall@100", "recall@1000", "mrr@10"]
    results = evaluate(qrels, run, metrics)

    return (
        results["mrr@10"],
        {
            '10': results["ndcg@10"],
            '100': results["ndcg@100"],
            '1000': results["ndcg@1000"]
        },
        {
            '10': results["recall@10"],
            '100': results["recall@100"],
            '1000': results["recall@1000"]
        }
    )


In [14]:
mrr_10, ndcg, recall = evaluate_model(model = model, passages = passages, dev_queries = queries_dev, dev_qrels = qrels_dev, tokenizer = tokenizer, device = device)

Quick mode: using 107 passages for evaluation
Quick mode: using 200 queries for evaluation
Building FAISS index...


Encoding passages: 100%|██████████| 1/1 [00:01<00:00,  1.53s/it]
Evaluating: 100%|██████████| 200/200 [00:01<00:00, 104.65it/s]
  scores[i] = _ndcg(qrels[i], run[i], k, rel_lvl, jarvelin)


In [15]:
# Metrics
print("Final Evaluation")
print(f"MRR@10: {mrr_10:.4f}")
print(f"nDCG@10: {ndcg['10']:.4f}, nDCG@100: {ndcg['100']:.4f}, nDCG@1000: {ndcg['1000']:.4f}")
print(f"Recall@10: {recall['10']:.4f}, Recall@100: {recall['100']:.4f}, Recall@1000: {recall['1000']:.4f}")

Final Evaluation
MRR@10: 0.0000
nDCG@10: 0.0000, nDCG@100: 0.0000, nDCG@1000: 0.0000
Recall@10: 0.0000, Recall@100: 0.0000, Recall@1000: 0.0000


## Get Query/Passage embeddings

In [203]:
def get_embeddings(model, passages, queries, qrels, tokenizer, device, batch_size_inference = 128, min_weight = 1e-5, query_topk = 128, doc_topk = 1024, quick_mode = True, output_batching = False, mode = 'train'):
    base_model = model.module if isinstance(model, nn.DataParallel) else model

    # Filter passages if in quick mode
    if quick_mode:
        relevant_passage_ids = set()
        # for qid in list(qrels.keys())[:1]:
        for qid in qrels:
            relevant_passage_ids.update(qrels[qid])
        passages = {passage_id: passages[passage_id][0] for passage_id in relevant_passage_ids}
        print(f"Quick mode: using {len(passages)} {mode} passages")

    # Filter queries if in quick mode
    if quick_mode:
        # queries = dict(list(queries.items())[:5000])
        qrels = {qid: qrels[qid] for qid in queries if qid in qrels}
        print(f"Quick mode: using {len(queries)} {mode} queries")

    passage_ids = list(passages.keys())
    passage_embeddings = [] # (batch_size * k_chunks, emb_dim // k_chunks)

    # Encode all passages
    for i in tqdm(range(0, len(passage_ids), batch_size_inference), desc = "Encoding passages"):
        batch_passages = [passages[pid] for pid in passage_ids[i:i + batch_size_inference]]

        # Pad till the model's configured max_len (512)
        passage_inputs = tokenizer(batch_passages, padding = True, truncation = True, return_tensors = 'pt')
        passage_inputs = {k: v.to(device) for k, v in passage_inputs.items()}

        with torch.no_grad():
            outputs = model(passage_inputs["input_ids"], passage_inputs["attention_mask"])
            batch_embeddings = mean_pooling(outputs[0], passage_inputs['attention_mask'])
            # batch_chunks = split_embedding_into_chunks(batch_embeddings, k_chunks = 4)
            # passage_embeddings.append(batch_chunks)
            passage_embeddings.append(batch_embeddings)


    query_ids = list(qrels.keys())
    query_embeddings = [] # (batch_size * k_chunks, emb_dim // k_chunks)

    # Encode all queries
    for i in tqdm(range(0, len(query_ids), batch_size_inference), desc = "Encoding queries"):
        batch_queries = [queries[qid][0] for qid in query_ids[i:i + batch_size_inference]]

        # Pad till the model's configured max_len (512)
        query_inputs = tokenizer(batch_queries, padding = True, truncation = True, return_tensors = 'pt')
        query_inputs = {k: v.to(device) for k, v in query_inputs.items()}

        with torch.no_grad():
            outputs = model(query_inputs["input_ids"], query_inputs["attention_mask"])
            batch_embeddings = mean_pooling(outputs[0], query_inputs['attention_mask'])
            # batch_chunks = split_embedding_into_chunks(batch_embeddings, k_chunks = 4)
            # query_embeddings.append(batch_chunks)
            query_embeddings.append(batch_embeddings)

    if output_batching == True:
        passage_embeddings = torch.cat(passage_embeddings, dim = 0)
        query_embeddings = torch.cat(query_embeddings, dim = 0)

    return {
            'embeddings':
                {
                    'passage_embeddings': passage_embeddings,
                    'query_embeddings': query_embeddings
                },
            'mappings': {
                    'passage_ids': passage_ids,
                    'query_ids': query_ids
                }
            }

In [86]:
obj = get_embeddings(model, passages, queries_dev, qrels_dev, tokenizer = tokenizer, device = device, output_batching = True)
passage_embeddings, query_embeddings = obj['embeddings']['passage_embeddings'], obj['embeddings']['query_embeddings']

Quick mode: using 1 train passages
Quick mode: using 5000 train queries


Encoding passages: 100%|██████████| 1/1 [00:00<00:00, 63.01it/s]
Encoding queries: 100%|██████████| 40/40 [00:05<00:00,  7.13it/s]


In [194]:
!rm -rf embeddings/train
!rm -rf embeddings/dev

In [202]:
def load_or_save_embeddings(model, passages, queries, qrels, tokenizer, device, batch_size = 128, quick_mode = True, output_batching = True, mode = 'train'):
    passage_embeddings_path = 'embeddings/' + mode + '/passage_embeddings.pt'
    query_embeddings_path = 'embeddings/' + mode + '/query_embeddings.pt'

    passage_ids_path = 'embeddings/' + mode + '/passage_ids.json'
    query_ids_path = 'embeddings/' + mode + '/query_ids.json'

    # Try loading if saved
    if os.path.exists(passage_embeddings_path) and os.path.exists(query_embeddings_path) and os.path.exists(passage_ids_path) and os.path.exists(query_ids_path):
        print("Loading cached embeddings & ids...")
        passage_embeddings = torch.load(passage_embeddings_path).to(device = device)
        query_embeddings = torch.load(query_embeddings_path).to(device = device)

        with open(passage_ids_path, "r") as f:
            passage_ids = json.load(f)

        with open(query_ids_path, "r") as f:
            query_ids = json.load(f)

        return {
            'embeddings':
                {
                    'passage_embeddings': passage_embeddings,
                    'query_embeddings': query_embeddings
                },
            'mappings': {
                    'passage_ids': passage_ids,
                    'query_ids': query_ids
                }
            }

    # Compute if not saved
    print("Computing embeddings...")
    obj = get_embeddings(
        model = model,
        passages = passages,
        queries = queries,
        qrels = qrels,
        tokenizer = tokenizer,
        device = device,
        batch_size_inference = batch_size,
        quick_mode = quick_mode,
        output_batching = output_batching,
    )

    passage_embeddings, query_embeddings = obj['embeddings']['passage_embeddings'], obj['embeddings']['query_embeddings']
    passage_ids, query_ids = obj['mappings']['passage_ids'], obj['mappings']['query_ids']

    # Save embeddings to the appropriate path
    torch.save(passage_embeddings, passage_embeddings_path)
    torch.save(query_embeddings, query_embeddings_path)

    # Save ID mappings
    with open(passage_ids_path, "w") as f:
        json.dump(passage_ids, f)

    with open(query_ids_path, "w") as f:
        json.dump(query_ids, f)

    print("Embeddings & Mappings saved.")

    return {
            'embeddings':
                {
                    'passage_embeddings': passage_embeddings,
                    'query_embeddings': query_embeddings
                },
            'mappings': {
                    'passage_ids': passage_ids,
                    'query_ids': query_ids
                }
            }

In [196]:
# Train set embeddings
load_or_save_embeddings(model, passages, queries_train, qrels_train, tokenizer = tokenizer, device = device, batch_size_inference = 128, quick_mode = True, output_batching = True, mode = 'train')

Computing embeddings...
Quick mode: using 516472 train passages
Quick mode: using 808731 train queries


Encoding passages:   1%|▏         | 51/4035 [01:04<1:23:57,  1.26s/it]


KeyboardInterrupt: 

In [None]:
# Dev set embeddings
load_or_save_embeddings(model, passages, queries_dev, qrels_dev, tokenizer = tokenizer, device = device, batch_size_inference = 128, quick_mode = True, output_batching = True, mode = 'dev')

## Vector Quantizer

In [197]:
class Quantize(nn.Module):
    def __init__(self, dim, n_embed, decay = 0.99, eps = 1e-5):
        super().__init__()

        self.dim = dim
        self.n_embed = n_embed
        self.decay = decay
        self.eps = eps

        embed = torch.randn(dim, n_embed)
        self.register_buffer("embed", embed)
        self.register_buffer("cluster_size", torch.zeros(n_embed))
        self.register_buffer("embed_avg", embed.clone())

    def forward(self, input):
        flatten = input.reshape(-1, self.dim)
        dist = (
            flatten.pow(2).sum(1, keepdim = True)
            - 2 * flatten @ self.embed
            + self.embed.pow(2).sum(0, keepdim = True)
        )
        _, embed_ind = (-dist).max(1)
        embed_onehot = F.one_hot(embed_ind, self.n_embed).type(flatten.dtype)
        embed_ind = embed_ind.view(*input.shape[:-1])
        quantize = self.embed_code(embed_ind)

        if self.training:
            embed_onehot_sum = embed_onehot.sum(0)
            embed_sum = flatten.transpose(0, 1) @ embed_onehot

            self.cluster_size.data.mul_(self.decay).add_(
                embed_onehot_sum, alpha=1 - self.decay
            )
            self.embed_avg.data.mul_(self.decay).add_(embed_sum, alpha=1 - self.decay)
            n = self.cluster_size.sum()
            cluster_size = (
                (self.cluster_size + self.eps) / (n + self.n_embed * self.eps) * n
            )
            embed_normalized = self.embed_avg / cluster_size.unsqueeze(0)
            self.embed.data.copy_(embed_normalized)

        quantize = input + (quantize - input).detach()

        return quantize, embed_ind

    def embed_code(self, embed_id):
        return F.embedding(embed_id, self.embed.transpose(0, 1))

## Train VQ

In [201]:
# Split the embedding to k chunks
def split_embedding_into_chunks(embeddings, k_chunks):
    bsz, dim = embeddings.shape
    assert dim % k_chunks == 0
    return embeddings.view(bsz * k_chunks, dim // k_chunks)

def train_VQ(model, passages, train_queries, train_qrels, dev_queries, dev_qrels, tokenizer, device, batch_size = 128, eval_only = False):
    """
    Function to run the training

    Args:
        eval_only: If True, only run evaluation using a pre-trained model
    Note: dev_qrels & dev_queries -> These are not yet used in this function
    """

    # Model configuration
    k_chunks = 16
    dim = 768 // k_chunks

    obj = load_or_save_embeddings(model, passages, train_queries, train_qrels, tokenizer = tokenizer, device = device, batch_size = batch_size, quick_mode = True, output_batching = True, mode = 'train')
    passage_embeddings, query_embeddings = obj['embeddings']['passage_embeddings'], obj['embeddings']['query_embeddings']
    query_chunked_embeddings = split_embedding_into_chunks(query_embeddings, k_chunks)
    passage_chunked_embeddings = split_embedding_into_chunks(passage_embeddings, k_chunks)

    embeddings = torch.cat((query_chunked_embeddings, passage_chunked_embeddings), dim = 0)

    # Train the codebook vectors
    # n_embed -> number of clusters
    # dim -> each chunked embedding dimension
    quantizer = Quantize(dim = dim, n_embed = 256).to(device = device)
    quantizer.training = True
    for i in tqdm(range(0, embeddings.shape[0], batch_size), desc = "Training codebook vectors"):
        batch_embs = embeddings[i:i + batch_size]
        _, code = quantizer(batch_embs)

    return quantizer

In [156]:
quantizer = train_VQ(model, passages, queries_train, qrels_train, queries_dev, qrels_dev, tokenizer, device)

Loading cached embeddings & ids...


Training codebook vectors: 100%|██████████| 1802/1802 [00:01<00:00, 1731.64it/s]


## VQ Embeddings & Create Inverted Index

In [200]:
def perform_VQ(quantizer, passage_embeddings, batch_size = 32):
    """
    Function to run the inference on VQ
    """
    k_chunks = 16

    # Get the code book vectors for each passage and build the inverted index
    quantizer.training = False
    code_indices = []
    for i in tqdm(range(0, passage_embeddings.shape[0], batch_size), desc = "Vector quantizing..."):
        batch_embs = passage_embeddings[i:i + batch_size]
        batch_chunked_embs = split_embedding_into_chunks(batch_embs, k_chunks)
        _, code = quantizer(batch_chunked_embs)
        code = code.view(-1, k_chunks)
        code_indices.append(code)

    code_indices = torch.cat(code_indices, dim = 0)

    return code_indices


# Create an inverted index
def create_inverted_index(model, quantizer, passages, queries, q_rels, tokenizer, device, batch_size = 128, mode = 'dev'):
    # Get the embeddings for the dev set
    # embeddings are list of ems for each query/passage
    obj = load_or_save_embeddings(model, passages, queries, q_rels, tokenizer = tokenizer, device = device, batch_size_inference = batch_size, quick_mode = True, output_batching = False, mode = mode)
    passage_embeddings, query_embeddings = obj['embeddings']['passage_embeddings'], obj['embeddings']['query_embeddings']
    passage_ids = obj['mappings']['passage_ids']

    code_indices = perform_VQ(quantizer, passage_embeddings, batch_size)

    inverted_index = defaultdict(list)
    num_passages = len(passage_ids)

    for i in tqdm(range(0, num_passages), desc = "Building inverted index"):
        code_index_list = code_indices[i].tolist()
        weights = Counter(code_index_list)
        for code_index in list(set(code_index_list)):
            inverted_index[int(code_index)].append((passage_ids[i], float(weights[code_index])))  # Ensure integer keys

    # Sort postings lists by weight for each term
    for idx in inverted_index:
        inverted_index[idx] = sorted(inverted_index[idx], key = lambda x: abs(x[1]), reverse = True)

    # Convert to more efficient data structure
    optimized_index = {
        idx: (
            np.array([passage_id for passage_id, _ in postings], dtype = np.int32),
            np.array([weight for _, weight in postings], dtype = np.float32)
        )
        for idx, postings in inverted_index.items()
    }

    return optimized_index

In [166]:
# Create the inverted index on the development dataset
optimized_index = create_inverted_index(model, quantizer, passages, queries_dev, qrels_dev, tokenizer, device, batch_size = 128, mode = 'dev')

Loading cached embeddings & ids...


Training codebook vectors: 100%|██████████| 233/233 [00:00<00:00, 2340.11it/s]
Building inverted index: 100%|██████████| 7433/7433 [00:00<00:00, 30014.06it/s]


## Search Inverted Index

In [199]:
# Optimized search function for sparse retrieval using the inverted index
def search_inverted_index(query_code_index_list, inverted_index, query_topk = 128):
    scores = defaultdict(float)
    seen_passages = set()

    # # Get top-k query dimensions by weight
    # weights = [(idx, weight) for idx, weight in enumerate(query_embedding) if abs(weight) > min_weight]
    # top_weights = heapq.nlargest(query_topk, weights, key = lambda x: abs(x[1]))

    # Process each query term
    weights = Counter(query_code_index_list)
    query_code_index_list = list(set(query_code_index_list))
    for code_index in query_code_index_list:
        if code_index not in inverted_index:
            print('Unexpected!!!')
            continue

        passage_ids, passage_weights = inverted_index[code_index]
        query_weight = weights[code_index]

        # Only process top documents per term
        for passage_id, passage_weight in zip(passage_ids, passage_weights):
            scores[passage_id] += query_weight * passage_weight
            seen_passages.add(passage_id)

    # Use numpy for final scoring
    if seen_passages:
        passage_ids = np.array(list(seen_passages))
        passage_scores = np.array([scores[passage_id] for passage_id in passage_ids])

        # Get top 1000 results efficiently
        top_k = min(1000, len(passage_scores))
        top_indices = np.argpartition(passage_scores, -top_k)[-top_k:]
        top_indices = top_indices[np.argsort(-passage_scores[top_indices])]

        return [(passage_ids[i], passage_scores[i]) for i in top_indices]

    return []


def get_metrics(model, quantizer, inverted_index, passages, queries, q_rels, tokenizer, device, batch_size = 128, mode = 'dev'):
    all_results = {}

    # Get the embeddings for the dev set
    # embeddings are list of ems for each query/passage
    obj = load_or_save_embeddings(model, passages, queries, q_rels, tokenizer = tokenizer, device = device, batch_size_inference = batch_size, quick_mode = True, output_batching = False, mode = mode)
    query_embeddings = obj['embeddings']['query_embeddings']
    query_ids = obj['mappings']['query_ids']
    num_queries = len(query_ids)

    code_indices = perform_VQ(quantizer, query_embeddings, batch_size)

    for i in tqdm(range(0, num_queries), desc = "Evaluating queries"):
        code_index_list = code_indices[i].tolist()
        search_results = search_inverted_index(code_index_list, inverted_index)
        all_results[query_ids[i]] = search_results

    # Create rank_eval Run and Qrels objects
    run_dict = {}
    for qid, results in all_results.items():
        run_dict[qid] = {
            str(passage_id): float(score)
            for passage_id, score in results
        }
    run = Run(run_dict)

    qrels_dict = {
        qid: {str(passage_id): 1 for passage_id in q_rels[qid]}
        for qid in q_rels
    }
    qrels = Qrels(qrels_dict)

    # Evaluate using rankx
    metrics = ["ndcg@10", "ndcg@100", "ndcg@1000", "recall@10", "recall@100", "recall@1000", "mrr@10"]
    results = evaluate(qrels, run, metrics)

    return (
        results["mrr@10"],
        {
            '10': results["ndcg@10"],
            '100': results["ndcg@100"],
            '1000': results["ndcg@1000"]
        },
        {
            '10': results["recall@10"],
            '100': results["recall@100"],
            '1000': results["recall@1000"]
        }
    )

In [193]:
# Get the metrics on the development set (queries)
get_metrics(model, quantizer, optimized_index, passages, queries_dev, qrels_dev, tokenizer, device, batch_size = 128, mode = 'dev')

Loading cached embeddings & ids...


Training codebook vectors: 100%|██████████| 219/219 [00:00<00:00, 2791.77it/s]
Evaluating queries: 100%|██████████| 6980/6980 [02:21<00:00, 49.18it/s]


(np.float64(0.1107232114431255),
 {'10': np.float64(0.1304878218490855),
  '100': np.float64(0.17783046561884971),
  '1000': np.float64(0.21894447635230774)},
 {'10': np.float64(0.20450095510983765),
  '100': np.float64(0.439207258834766),
  '1000': np.float64(0.7739016236867239)})