# Google Drive Setup

In [1]:
from google.colab import drive
drive.mount('/content/drive', force_remount=True)

Mounted at /content/drive


In [2]:
cd /content/drive/MyDrive/685/code

/content/drive/MyDrive/685/code


# Install Required Libraries

In [3]:
# Install required libraries
!pip install torch ir_datasets wandb numpy scikit-learn sentence-transformers transformers tqdm scipy matplotlib rank-eval ranx
!pip install faiss-cpu
# !pip uninstall faiss-gpu-cu11

Collecting ir_datasets
  Downloading ir_datasets-0.5.10-py3-none-any.whl.metadata (12 kB)
Collecting rank-eval
  Downloading rank_eval-0.1.3-py3-none-any.whl.metadata (6.8 kB)
Collecting ranx
  Downloading ranx-0.3.20-py3-none-any.whl.metadata (17 kB)
Collecting nvidia-cuda-nvrtc-cu12==12.4.127 (from torch)
  Downloading nvidia_cuda_nvrtc_cu12-12.4.127-py3-none-manylinux2014_x86_64.whl.metadata (1.5 kB)
Collecting nvidia-cuda-runtime-cu12==12.4.127 (from torch)
  Downloading nvidia_cuda_runtime_cu12-12.4.127-py3-none-manylinux2014_x86_64.whl.metadata (1.5 kB)
Collecting nvidia-cuda-cupti-cu12==12.4.127 (from torch)
  Downloading nvidia_cuda_cupti_cu12-12.4.127-py3-none-manylinux2014_x86_64.whl.metadata (1.6 kB)
Collecting nvidia-cudnn-cu12==9.1.0.70 (from torch)
  Downloading nvidia_cudnn_cu12-9.1.0.70-py3-none-manylinux2014_x86_64.whl.metadata (1.6 kB)
Collecting nvidia-cublas-cu12==12.4.5.8 (from torch)
  Downloading nvidia_cublas_cu12-12.4.5.8-py3-none-manylinux2014_x86_64.whl.metad

# Import Libraries

In [4]:
import os
import json

import torch.nn as nn
import torch
import torch.nn.functional as F
from tqdm import tqdm
from transformers import AutoTokenizer, AutoModel

from utils import split_embedding_into_chunks
from dataset import DataProcessor

# Define Classes

## Embedding Processor

In [94]:
class EmbeddingProcessor:

    def __init__(self, data_processor, model, tokenizer, emb_root_dir, batch_size = 128, device = 'cpu') -> None:
        self.data_processor = data_processor
        self.model = model
        self.tokenizer = tokenizer
        self.emb_dim = None
        self.emb_root_dir = emb_root_dir
        self.device = device

        self.embeddings_path = {
            'train': {
                'passage_embs': os.path.join(self.emb_root_dir, 'train', 'passage_embeddings.pt'),
                'query_embs': os.path.join(self.emb_root_dir, 'train', 'query_embeddings.pt'),
                'passage_ids': os.path.join(self.emb_root_dir, 'train', 'passage_ids.json'),
                'query_ids': os.path.join(self.emb_root_dir, 'train', 'query_ids.json')
            },
            'dev': {
                'passage_embs': os.path.join(self.emb_root_dir, 'dev', 'passage_embeddings.pt'),
                'query_embs': os.path.join(self.emb_root_dir, 'dev', 'query_embeddings.pt'),
                'passage_ids': os.path.join(self.emb_root_dir, 'dev', 'passage_ids.json'),
                'query_ids': os.path.join(self.emb_root_dir, 'dev', 'query_ids.json')
            }
        }
        self.train_passage_embs, self.train_query_embs, self.train_passage_ids, self.train_query_ids = None, None, None, None
        self.dev_passage_embs, self.dev_query_embs, self.dev_passage_ids, self.dev_query_ids = None, None, None, None

        self.passages, self.queries_train, self.queries_dev, self.qrels_train, self.qrels_dev = None, None, None, None, None

        self.batch_size = batch_size
        return

    # Used to combine the embeddings of all the tokens
    # Contriever model
    def mean_pooling(self, token_embeddings, mask):
        token_embeddings = token_embeddings.masked_fill(~mask[..., None].bool(), 0.)
        sentence_embeddings = token_embeddings.sum(dim = 1) / mask.sum(dim = 1)[..., None]
        return sentence_embeddings

    def filter_data(self, passages, queries, qrels):
        relevant_passage_ids = set()
        for qid in qrels:
            relevant_passage_ids.update(qrels[qid])
        passages = {passage_id: passages[passage_id] for passage_id in relevant_passage_ids}
        passage_ids = list(passages.keys())

        qrels = {qid: qrels[qid] for qid in queries if qid in qrels}

        qrels_ = {}
        queries_ = {}
        for qid in queries:
            if qid in qrels:
                qrels_[qid] = qrels[qid]
                queries_[qid] = queries[qid]

        query_ids = list(qrels_.keys())

        return {'passages': passages, 'queries': queries_, 'qrels': qrels_, 'passage_ids': passage_ids, 'query_ids': query_ids}

    def get_filtered_data(self, mode = 'train'):
        if mode not in ['train', 'dev']:
            raise ValueError('Invalid mode!')

        if self.passages == None:
            # If they are not loaded yet, load them
            raw_data = self.data_processor.get_data()
            self.passages, self.queries_train, self.queries_dev, self.qrels_train, self.qrels_dev = raw_data['passages'], raw_data['queries_train'], raw_data['queries_dev'], raw_data['qrels_train'], raw_data['qrels_dev']

        filtered_data = {
            'passage': None,
            'query': None
        }
        if mode == 'train':
            data = self.filter_data(passages = self.passages, queries = self.queries_train, qrels = self.qrels_train)
            filtered_data['passage'] = {
                'passages': data['passages'],
                'passage_ids': data['passage_ids']
            }
            filtered_data['query'] = {
                'queries': data['queries'],
                'qrels': data['qrels'],
                'query_ids': data['query_ids']
            }

        elif mode == 'dev':
            data = self.filter_data(passages = self.passages, queries = self.queries_dev, qrels = self.qrels_dev)

            # In case of development set, passages would be the entire collection (instead of the filtered ids using qrels)
            filtered_data['passage'] = {
                'passages': self.passages,
                'passage_ids': list(self.passages.keys())
            }
            filtered_data['query'] = {
                'queries': data['queries'],
                'qrels': data['qrels'],
                'query_ids': data['query_ids']
            }

        return filtered_data

    def compute_embeddings(self, type = 'passage', mode = 'train', start = None, limit = None):
        if type == None:
            print('Embedding type not provided, Not computing embeddings!')
            return

        if type not in ['passage', 'query']:
            raise ValueError('Invalid embedding type!')

        data = self.get_filtered_data(mode)[type]
        if type == 'query':
            data = data['queries']
        elif type == 'passage':
            data = data['passages']

        if limit is not None:
            data = dict(list(data.items())[start: start + limit])
            print('Number of passages:', len(data))

        data_embeddings = []

        ids = list(data.keys())
        for i in tqdm(range(0, len(ids), self.batch_size), desc = f"Encoding {type}"):
            batch_data = [data[id] for id in ids[i:i + self.batch_size]]

            # Pad till the model's configured max_len (512)
            batch_inputs = self.tokenizer(batch_data, padding = True, truncation = True, return_tensors = 'pt')
            batch_inputs = {k: v.to(self.device) for k, v in batch_inputs.items()}

            with torch.no_grad():
                outputs = self.model(batch_inputs["input_ids"], batch_inputs["attention_mask"])
                batch_embeddings = self.mean_pooling(outputs[0], batch_inputs['attention_mask'])
                data_embeddings.append(batch_embeddings)

        data_embeddings = torch.cat(data_embeddings, dim = 0)

        return data_embeddings, list(ids)

    # TODO: Change this function
    def load_or_save_passage_embeddings(self, mode = 'train', start= None, limit = None):
        pass_embs_path = self.embeddings_path[mode]['passage_embs']
        pass_ids_path = self.embeddings_path[mode]['passage_ids']

        if os.path.exists(pass_embs_path) and os.path.exists(pass_ids_path):
            print("Loading cached passage embeddings & ids...")
            passage_embeddings = torch.load(pass_embs_path).to(device = self.device)
            self.emb_dim = passage_embeddings.shape[-1]

            with open(pass_ids_path, "r") as f:
                passage_ids = json.load(f)

            return {
                'embeddings':
                    {
                        'passage_embeddings': passage_embeddings
                    },
                'mappings': {
                        'passage_ids': passage_ids
                    }
                }

        passage_embeddings, passage_ids = self.compute_embeddings(type = 'passage', mode = mode, start = start, limit = limit)
        self.emb_dim = passage_embeddings.shape[-1]

        # Save embeddings to the appropriate path
        torch.save(passage_embeddings, pass_embs_path)

        # Save ID mappings
        with open(pass_ids_path, "w") as f:
            json.dump(passage_ids, f)

        print("Embeddings & Mappings saved.")

        return {
            'embeddings':
                {
                    'passage_embeddings': passage_embeddings
                },
            'mappings': {
                    'passage_ids': passage_ids
                }
            }

    # TODO: Change this function
    def load_or_save_query_embeddings(self, mode = 'train'):
        query_embs_path = self.embeddings_path[mode]['query_embs']
        query_ids_path = self.embeddings_path[mode]['query_ids']

        if os.path.exists(query_embs_path) and os.path.exists(query_ids_path):
            print("Loading cached query embeddings & ids...")
            query_embeddings = torch.load(query_embs_path).to(device = self.device)
            self.emb_dim = query_embeddings.shape[-1]

            with open(query_ids_path, "r") as f:
                query_ids = json.load(f)

            return {
                'embeddings':
                    {
                        'query_embeddings': query_embeddings
                    },
                'mappings': {
                        'query_ids': query_ids
                    }
                }

        query_embeddings, query_ids = self.compute_embeddings(type = 'query', mode = mode)
        self.emb_dim = query_embeddings.shape[-1]

        # Save embeddings to the appropriate path
        torch.save(query_embeddings, query_embs_path)

        with open(query_ids_path, "w") as f:
            json.dump(query_ids, f)

        print("Embeddings & Mappings saved.")

        return {
            'embeddings':
                {
                    'query_embeddings': query_embeddings
                },
            'mappings': {
                    'query_ids': query_ids
                }
            }

    def load_or_save_embeddings(self, mode = 'train'):
        # If the embeddings are already cached, return them
        if (not (mode == 'train' and self.train_passage_embs is not None)) and (not (mode == 'dev' and self.dev_passage_embs is not None)):
          pass_embs_path = self.embeddings_path[mode]['passage_embs']
          query_embs_path = self.embeddings_path[mode]['query_embs']
          pass_ids_path = self.embeddings_path[mode]['passage_ids']
          query_ids_path = self.embeddings_path[mode]['query_ids']

          if os.path.exists(pass_embs_path) and os.path.exists(query_embs_path) and os.path.exists(pass_ids_path) and os.path.exists(query_ids_path):
              print("Loading saved passage-query embeddings & ids from the provided paths...")
              passage_embeddings = torch.load(pass_embs_path, map_location = 'cpu')
              query_embeddings = torch.load(query_embs_path, map_location = 'cpu')
              self.emb_dim = passage_embeddings.shape[-1]

              with open(pass_ids_path, "r") as f:
                  passage_ids = json.load(f)

              with open(query_ids_path, "r") as f:
                  query_ids = json.load(f)
          else:
            passage_embeddings, passage_ids = self.compute_embeddings(type = 'passage', mode = mode)
            query_embeddings, query_ids = self.compute_embeddings(type = 'query', mode = mode)
            self.emb_dim = passage_embeddings.shape[-1]

            # Save embeddings to the appropriate path
            torch.save(passage_embeddings, pass_embs_path)
            torch.save(query_embeddings, query_embs_path)

            # Save ID mappings
            with open(pass_ids_path, "w") as f:
              json.dump(passage_ids, f)

            with open(query_ids_path, "w") as f:
              json.dump(query_ids, f)

            print("Query-passage embeddings & mappings saved!")

          assert passage_embeddings.shape[0] == len(passage_ids), f"Mismatch: {passage_embeddings.shape[0]} embeddings vs {len(passage_ids)} IDs"
          # TODO: uncomment this
          # assert query_embeddings.shape[0] == len(query_ids), f"Mismatch: {query_embeddings.shape[0]} embeddings vs {len(query_ids)} IDs"

          # Cache the embeddings
          if mode == 'train':
            self.train_passage_embs, self.train_query_embs, self.train_passage_ids, self.train_query_ids = passage_embeddings, query_embeddings, passage_ids, query_ids
          elif mode == 'dev':
            self.dev_passage_embs, self.dev_query_embs, self.dev_passage_ids, self.dev_query_ids = passage_embeddings, query_embeddings, passage_ids, query_ids

        if mode == 'train':
          return {
              'embeddings':
                  {
                      'passage_embeddings': self.train_passage_embs,
                      'query_embeddings': self.train_query_embs
                  },
              'mappings': {
                      'passage_ids': self.train_passage_ids,
                      'query_ids': self.train_query_ids
                  }
              }
        elif mode == 'dev':
          return {
              'embeddings':
                  {
                      'passage_embeddings': self.dev_passage_embs,
                      'query_embeddings': self.dev_query_embs
                  },
              'mappings': {
                      'passage_ids': self.dev_passage_ids,
                      'query_ids': self.dev_query_ids
                  }
              }

    def get_emb_dim(self):
        if self.emb_dim is None:
            raise ValueError('Embedding dimension not found!')

        return self.emb_dim

## Vector Quantizer

In [89]:
class Quantize(nn.Module):
    def __init__(self, codebook_vector_dim, num_clusters, decay = 0.99, eps = 1e-5):
        super().__init__()

        self.dim = codebook_vector_dim
        self.num_clusters = num_clusters
        self.decay = decay
        self.eps = eps

        embed = torch.randn(codebook_vector_dim, num_clusters)
        self.register_buffer("embed", embed)
        self.register_buffer("cluster_size", torch.zeros(num_clusters))
        self.register_buffer("embed_avg", embed.clone())

    def forward(self, input):
        flatten = input.reshape(-1, self.dim)
        dist = (
            flatten.pow(2).sum(1, keepdim = True)
            - 2 * flatten @ self.embed
            + self.embed.pow(2).sum(0, keepdim = True)
        )
        _, embed_ind = (-dist).max(1)
        embed_onehot = F.one_hot(embed_ind, self.num_clusters).type(flatten.dtype)
        embed_ind = embed_ind.view(*input.shape[:-1])
        quantize = self.embed_code(embed_ind)

        if self.training:
            embed_onehot_sum = embed_onehot.sum(0)
            embed_sum = flatten.transpose(0, 1) @ embed_onehot

            self.cluster_size.data.mul_(self.decay).add_(
                embed_onehot_sum, alpha=1 - self.decay
            )
            self.embed_avg.data.mul_(self.decay).add_(embed_sum, alpha=1 - self.decay)
            n = self.cluster_size.sum()
            cluster_size = (
                (self.cluster_size + self.eps) / (n + self.num_clusters * self.eps) * n
            )
            embed_normalized = self.embed_avg / cluster_size.unsqueeze(0)
            self.embed.data.copy_(embed_normalized)

        quantize = input + (quantize - input).detach()

        return quantize, embed_ind

    def embed_code(self, embed_id):
        return F.embedding(embed_id, self.embed.transpose(0, 1))

## Vector Quantizer Hanlder

In [91]:
class VQHandler:
    def __init__(self, embedding_processor, quantizer = None, emb_dim = None, num_clusters = None, num_chunks = 32, batch_size = 512, device = 'cpu'):

        if quantizer is not None:
            self.quantizer = quantizer
        else:
            self.quantizer = Quantize(dim = emb_dim, num_clusters = num_clusters)

        train_embeddings = embedding_processor.load_or_save_embeddings(mode = 'train')['embeddings']
        dev_embeddings = embedding_processor.load_or_save_embeddings(mode = 'dev')['embeddings']

        self.train_query_embeddings = train_embeddings['query_embeddings']
        self.train_passage_embeddings = train_embeddings['passage_embeddings']
        self.dev_query_embeddings = dev_embeddings['query_embeddings']
        self.dev_passage_embeddings = dev_embeddings['passage_embeddings']

        # Number of chunks each emb to be divided into
        self.num_chunks = num_chunks
        self.device = device
        self.batch_size = batch_size

    def train(self, mode = 'train'):
        self.quantizer.training = True

        # Note: Not using query embeddings for training
        if mode == 'train':
          pass_chunked_embs = split_embedding_into_chunks(self.train_passage_embeddings, self.num_chunks)
        elif mode == 'dev':
          pass_chunked_embs = split_embedding_into_chunks(self.dev_passage_embeddings, self.num_chunks)

        for i in tqdm(range(0, pass_chunked_embs.shape[0], self.batch_size), desc = "Training codebook vectors..."):
            batch_embs = pass_chunked_embs[i:i + self.batch_size].to(device = self.device)
            _, code = self.quantizer(batch_embs)

        self.quantizer.training = False
        return self.quantizer

    def inference(self, type = 'passage', mode = 'dev'):
        embeddings = None
        if mode == 'dev':
            if type == 'passage':
                embeddings = self.dev_passage_embeddings
            elif type == 'query':
                embeddings = self.dev_query_embeddings

        elif mode == 'train':
            if type == 'passage':
                embeddings = self.train_passage_embeddings
            elif type == 'query':
                # embeddings = self.train_query_embeddings
                # Why would you use this, huh?
                raise Exception("Bye bye!")

        # Get the code book vectors for each passage
        self.quantizer.training = False
        code_indices = []
        for i in tqdm(range(0, embeddings.shape[0], self.batch_size), desc = f"Vector quantizing {type} embeddings..."):
            batch_embs = embeddings[i:i + self.batch_size].to(device = self.device)
            batch_chunked_embs = split_embedding_into_chunks(batch_embs, self.num_chunks)
            _, code = self.quantizer(batch_chunked_embs)
            code = code.view(-1, self.num_chunks)
            code_indices.append(code)

        code_indices = torch.cat(code_indices, dim = 0)

        return code_indices

## Inverted Index Handler

In [183]:
from collections import defaultdict, Counter
import numpy as np
from scipy.sparse import csr_matrix

class InvertedIndexHandler:
    def __init__(self, embedding_processor, vocab_size = None):
        self.embedding_processor = embedding_processor
        self.train_mapppings = self.embedding_processor.load_or_save_embeddings(mode = 'train')['mappings']
        self.dev_mappings = self.embedding_processor.load_or_save_embeddings(mode = 'dev')['mappings']
        self.passage_ids = None
        self.vocab_size = vocab_size
        self.optimized_index = None
        return

    # Build the index
    def create_inverted_index(self, code_indices, mode = 'dev', index_type = 'csr'):
        self.passage_ids = self.dev_mappings['passage_ids']
        if mode == 'train':
            self.passage_ids = self.train_mapppings['passage_ids']
        elif mode != 'dev':
            raise NotImplementedError(f"Inverted index for {mode} not implemented!")

        assert code_indices.shape[0] == len(self.passage_ids), f"Mismatch: {code_indices.shape[0]} passage code indices vs {len(self.passage_ids)} passage IDs"

        if index_type == 'csr':
          if self.vocab_size == None:
            raise RuntimeError(f"vocab_size is {self.vocab_size}, can't build csr index!")

          rows, cols, data = [], [], []

          for i, codes in enumerate(tqdm(code_indices, desc = "Building passage CSR matrix...")):
              weights = Counter(codes.tolist())
              for code, freq in weights.items():
                  rows.append(i)
                  cols.append(code)
                  data.append(freq)

          # vocab_size = len(np.unique(code)) - I think we shouldn't do this
          self.optimized_index = csr_matrix((data, (rows, cols)), shape = (len(code_indices), self.vocab_size), dtype = np.float32)

        elif index_type == 'inverted_index':
          inverted_index = defaultdict(list)
          num_passages = len(self.passage_ids)
          for i in tqdm(range(0, num_passages), desc = "Building inverted index..."):
              code_index_list = code_indices[i].tolist()
              weights = Counter(code_index_list)
              for code_index in list(set(code_index_list)):
                  inverted_index[int(code_index)].append((self.passage_ids[i], float(weights[code_index])))  # Ensure integer keys

          # Sort postings lists by weight for each term
          for idx in inverted_index:
              inverted_index[idx] = sorted(inverted_index[idx], key = lambda x: abs(x[1]), reverse=True)

          # Convert to more efficient data structure
          self.optimized_index = {
              idx: (
                  np.array([passage_id for passage_id, _ in postings], dtype = np.int32),
                  np.array([weight for _, weight in postings], dtype = np.float32)
              )
              for idx, postings in inverted_index.items()
          }

        return self.optimized_index

    # Computes the results given a query code index list
    # Uses the inverted index
    def compute_results_inv_index(self, code_index_list):
      scores = defaultdict(float)
      seen_passages = set()

      # Process each query term
      weights = Counter(code_index_list)
      for code_index in list(set(code_index_list)):
          if code_index not in self.optimized_index:
              continue

          passage_ids, passage_weights = self.optimized_index[code_index]
          query_weight = weights[code_index]

          # Compute scores against all the documents this code_index is associated with
          for passage_id, passage_weight in zip(passage_ids, passage_weights):
              scores[passage_id] += query_weight * passage_weight
              seen_passages.add(passage_id)

      # Use numpy for final scoring
      if seen_passages:
          passage_ids = np.array(list(seen_passages))
          passage_scores = np.array([scores[passage_id] for passage_id in passage_ids])

          # Get top 1000 results efficiently
          top_k = min(1000, len(passage_scores))
          top_indices = np.argpartition(passage_scores, -top_k)[-top_k:]
          top_indices = top_indices[np.argsort(-passage_scores[top_indices])]

          return [(passage_ids[i], passage_scores[i]) for i in top_indices]

      print("Found no matching code indices b/w query and passage!!!")
      return []

    # Optimized search function for sparse retrieval using the inverted index
    def search_inverted_index(self, query_code_repr, query_ids, index_type = 'csr', batch_size = 512):
        num_queries = query_code_repr.shape[0]
        assert num_queries == len(query_ids), f"Mismatch: {query_code_repr.shape[0]} query code indices vs {len(query_ids)} query IDs"

        all_results = {}
        if index_type == 'csr':
          for start in tqdm(range(0, num_queries, batch_size), desc = "Evaluating queries (csr)..."):
              end = min(start + batch_size, num_queries)
              query_chunk = query_code_repr[start:end]
              scores_chunk = query_chunk @ self.optimized_index.T  # shape: [batch_size, num_passages]

              topk_idx = np.argpartition(-scores_chunk, 1000, axis = 1)[:, :1000]
              for i in range(scores_chunk.shape[0]):
                  passage_scores = scores_chunk[i]
                  top_i = topk_idx[i]
                  top_indices = top_i[np.argsort(-passage_scores[top_i])]

                  # Ensure self.passage_ids is referring to the correct mode (train/dev)
                  all_results[query_ids[start + i]] = [(self.passage_ids[t], passage_scores[t]) for t in top_indices]

        elif index_type == 'inverted_index':
          for i in tqdm(range(0, num_queries), desc = "Evaluating queries (inverted index)..."):
              code_index_list = query_code_repr[i].tolist()
              search_results = self.compute_results_inv_index(code_index_list)
              all_results[query_ids[i]] = search_results

        return all_results

## Metrics

In [116]:
from tqdm import tqdm
from ranx import Qrels, Run, evaluate
import faiss

class MetricsGenerator:
    def __init__(self, vq_handler, inverted_index_handler, embedding_processor, dev_qrels, vocab_size = None):
        self.vq_handler = vq_handler
        self.inverted_index_handler = inverted_index_handler
        self.embedding_processor = embedding_processor

        self.train_data = self.embedding_processor.load_or_save_embeddings(mode = 'train')
        self.dev_data = self.embedding_processor.load_or_save_embeddings(mode = 'dev')

        self.train_query_ids = self.train_data['mappings']['query_ids']
        self.dev_query_ids = self.dev_data['mappings']['query_ids']

        self.dev_qrels = dev_qrels

        self.query_matrix = None
        self.vocab_size = vocab_size
        return

    def build_query_matrix(self, query_code_indices):
      if self.vocab_size == None:
        raise RuntimeError(f"vocab_size is {self.vocab_size}, can't build query csr matrix!")

      self.query_matrix = np.zeros((len(query_code_indices), self.vocab_size), dtype = np.float32)

      for i, codes in enumerate(query_code_indices):
          weights = Counter(codes)
          for code, freq in weights.items():
              self.query_matrix[i, code] = freq

      return self.query_matrix

    def batch_score_queries(self, top_k = 1000):
      scores_matrix = self.query_matrix @ self.inverted_index.T  # shape: [num_queries, num_passages]
      top_k_indices = np.argpartition(-scores_matrix, top_k, axis=1)[:, :top_k]

      results = {}
      for i in range(scores_matrix.shape[0]):
          row = scores_matrix[i]
          top_idx = top_k_indices[i]
          sorted_idx = top_idx[np.argsort(-row[top_idx])]
          results.append(list(zip(sorted_idx, row[sorted_idx])))

      return results

    # Performs faiss (always on the development set queries)
    def perform_faiss(self, mode = 'dev'):
        query_embeddings = self.dev_data['embeddings']['query_embeddings']

        if mode == 'dev':
          passage_ids = self.dev_data['mappings']['passage_ids']
          passage_embeddings = self.dev_data['embeddings']['passage_embeddings']
        elif mode == 'train':
          passage_ids = self.train_data['mappings']['passage_ids']
          passage_embeddings = self.train_data['embeddings']['passage_embeddings']

        # Dense retrieval using FAISS
        print("Building FAISS index...")

        # Build FAISS index
        dimension = passage_embeddings.shape[1]
        index = faiss.IndexFlatIP(dimension)
        index.add(passage_embeddings.cpu().numpy())

        all_results = {}
        # Search using FAISS for each query
        for i, qid in enumerate(tqdm(self.dev_query_ids, desc = "Evaluating queries (FAISS)...", total = len(self.dev_query_ids))):
            if qid not in self.dev_qrels:
                continue

            # Get top 1000 results for this query
            scores, indices = index.search(query_embeddings[i:i + 1], 1000)
            search_results = [(passage_ids[idx], float(score)) for idx, score in zip(indices[0], scores[0])]
            all_results[qid] = search_results

        return all_results

    def get_metrics(self, method = 'vq', mode = 'dev', index_type = 'csr'):
        # mode param - Determines which passages to be used for faiss (dev or train)

        assert len(self.dev_qrels) == len(self.dev_query_ids), f"Mismatch: {len(self.dev_qrels)} development qrels vs {len(self.dev_query_ids)} dev query ids"

        if method == 'faiss':
            all_results = self.perform_faiss(mode = mode)
        elif method == 'vq':
            query_code_indices = self.vq_handler.inference(type = 'query', mode = 'dev')

            if index_type == 'csr':
              self.build_query_matrix(query_code_indices)
              all_results = self.inverted_index_handler.search_inverted_index(self.query_matrix, self.dev_query_ids, index_type = 'csr', batch_size = 128)
            else:
              all_results = self.inverted_index_handler.search_inverted_index(query_code_indices, self.dev_query_ids, index_type = 'inverted_index')

        # Create rank_eval Run and Qrels objects
        run_dict = {}
        for qid, results in all_results.items():
            run_dict[qid] = {
                str(passage_id): float(score)
                for passage_id, score in results
            }
        run = Run(run_dict)

        qrels_dict = {
            qid: {str(passage_id): 1 for passage_id in self.dev_qrels[qid]}
            for qid in self.dev_qrels
        }
        qrels = Qrels(qrels_dict)

        # Evaluate using rankx
        metrics = ["ndcg@10", "ndcg@100", "ndcg@1000", "recall@10", "recall@100", "recall@1000", "mrr@10"]
        results = evaluate(qrels, run, metrics)

        return (
            results["mrr@10"],
            {
                '10': results["ndcg@10"],
                '100': results["ndcg@100"],
                '1000': results["ndcg@1000"]
            },
            {
                '10': results["recall@10"],
                '100': results["recall@100"],
                '1000': results["recall@1000"]
            }
        )

# Main

### Load the dataset

In [8]:
'''
    - Load the raw dataset
'''
data_processor = DataProcessor(data_root_dir = '../data')
data = data_processor.get_data()
passages, queries_train, queries_dev, qrels_train, qrels_dev = data['passages'], data['queries_train'], data['queries_dev'], data['qrels_train'], data['qrels_dev']
data_processor.print_samples()

Passages:
0: The presence of communication amid scientific minds was equally important to the success of the Manhattan Project as scientific intellect was. The only cloud hanging over the impressive achievement of the atomic researchers and engineers is what their success truly meant; hundreds of thousands of innocent lives obliterated.
1: The Manhattan Project and its atomic bomb helped bring an end to World War II. Its legacy of peaceful uses of atomic energy continues to have an impact on history and science.
2: Essay on The Manhattan Project - The Manhattan Project The Manhattan Project was to see if making an atomic bomb possible. The success of this project would forever change the world forever making it known that something this powerful can be manmade.
3: The Manhattan Project was the name for a project conducted during World War II, to develop the first atomic bomb. It refers specifically to the period of the project from 194 â¦ 2-1946 under the control of the U.S. Army Corp

### Instantiate the model

In [9]:
'''
    - Initialize the model(Contriever)
'''
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print('Using device', device)

tokenizer = AutoTokenizer.from_pretrained('facebook/contriever-msmarco')
model = AutoModel.from_pretrained('facebook/contriever-msmarco').to(device)

Using device cpu


The secret `HF_TOKEN` does not exist in your Colab secrets.
To authenticate with the Hugging Face Hub, create a token in your settings tab (https://huggingface.co/settings/tokens), set it as secret in your Google Colab and restart your session.
You will be able to reuse this secret in all of your notebooks.
Please note that authentication is recommended but still optional to access public models or datasets.


tokenizer_config.json:   0%|          | 0.00/321 [00:00<?, ?B/s]

vocab.txt:   0%|          | 0.00/232k [00:00<?, ?B/s]

tokenizer.json:   0%|          | 0.00/466k [00:00<?, ?B/s]

special_tokens_map.json:   0%|          | 0.00/112 [00:00<?, ?B/s]

config.json:   0%|          | 0.00/619 [00:00<?, ?B/s]

pytorch_model.bin:   0%|          | 0.00/438M [00:00<?, ?B/s]

### Load the embeddings

In [95]:
'''
    - Load/Save the embeddings
'''
embedding_processor = EmbeddingProcessor(data_processor = data_processor, model = model, tokenizer = tokenizer, emb_root_dir = '../embeddings', batch_size = 128, device = device)

In [96]:
obj_train = embedding_processor.load_or_save_embeddings(mode = 'train')
obj_dev = embedding_processor.load_or_save_embeddings(mode = 'dev')

Loading saved passage-query embeddings & ids from the provided paths...
Loading saved passage-query embeddings & ids from the provided paths...


In [97]:
'''
  - Embeddings & mappings stats
'''
print('Train passage stats: ')
print('-> Number of passage embeddings: ', obj_train['embeddings']['passage_embeddings'].shape)
print('-> Number of passage ids: ', len(obj_train['mappings']['passage_ids']))
print('-> Number of query embeddings: ', obj_train['embeddings']['query_embeddings'].shape)
print('-> Number of query ids: ', len(obj_train['mappings']['query_ids']))

print('Dev passage stats: ')
print('-> Number of passage embeddings: ', obj_dev['embeddings']['passage_embeddings'].shape)
print('-> Number of passage ids: ', len(obj_dev['mappings']['passage_ids']))
print('-> Number of query embeddings: ', obj_dev['embeddings']['query_embeddings'].shape)
print('-> Number of query ids: ', len(obj_dev['mappings']['query_ids']))

Train passage stats: 
-> Number of passage embeddings:  torch.Size([516472, 768])
-> Number of passage ids:  516472
-> Number of query embeddings:  torch.Size([6980, 768])
-> Number of query ids:  502939
Dev passage stats: 
-> Number of passage embeddings:  torch.Size([7433, 768])
-> Number of passage ids:  7433
-> Number of query embeddings:  torch.Size([6980, 768])
-> Number of query ids:  6980


### Train VQ

In [240]:
'''
    - Train the Vector Quantizer
    - Vector quantize the embeddings & Get the code indices
'''
vocab_size = 6000
quantizer = Quantize(codebook_vector_dim = 48, num_clusters = vocab_size).to(device = device)
vq_hanlder = VQHandler(embedding_processor = embedding_processor, quantizer = quantizer, num_chunks = 16, batch_size = 2048, device = device)
quantizer = vq_hanlder.train(mode = 'train')

Training codebook vectors...: 100%|██████████| 4035/4035 [03:06<00:00, 21.60it/s]


In [241]:
# Get the code indices for each passage
# These code indices are used to build the inverted index
# If mode is train, we would use the train passages for building the index (In that case, while evaluating using dev queries, train passages should also include the dev passages in it)
passage_code_indices = vq_hanlder.inference(type = 'passage', mode = 'dev').cpu().numpy()
print('\nPassage code indices shape (will be used for building the inverted index)', passage_code_indices.shape)

Vector quantizing passage embeddings...: 100%|██████████| 4/4 [00:04<00:00,  1.16s/it]


Passage code indices shape (will be used for building the inverted index) (7433, 16)





### Manage the inverted index

In [242]:
'''
    - Create the inverted index with the vector quantized indices
'''
inverted_index_handler = InvertedIndexHandler(embedding_processor = embedding_processor, vocab_size = vocab_size)

In [243]:
# Code indices of the passage set that you are working on
# Changes based on the mode (dev/train)
obj = inverted_index_handler.create_inverted_index(code_indices = passage_code_indices, mode = 'dev', index_type = 'csr')

Building passage CSR matrix...: 100%|██████████| 7433/7433 [00:00<00:00, 264178.74it/s]


### Generate metrics

In [244]:
'''
    - Generate Metrics using the inverted index built
'''

# Pass qrels to calculate the metrics
metrics_generator = MetricsGenerator(vq_handler = vq_hanlder, inverted_index_handler = inverted_index_handler, embedding_processor = embedding_processor, dev_qrels = qrels_dev, vocab_size = vocab_size)

In [245]:
# Get results for the VQ approach
results_vq = metrics_generator.get_metrics(method = 'vq', index_type = 'csr')
results_vq

Vector quantizing query embeddings...: 100%|██████████| 4/4 [00:03<00:00,  1.04it/s]
Evaluating queries (csr)...: 100%|██████████| 55/55 [00:02<00:00, 19.36it/s]


(np.float64(0.36619411470414337),
 {'10': np.float64(0.410098592873994),
  '100': np.float64(0.4485607516396251),
  '1000': np.float64(0.46062582109266687)},
 {'10': np.float64(0.5637535816618912),
  '100': np.float64(0.7493314231136581),
  '1000': np.float64(0.8412368672397327)})

In [246]:
# Get results for the FAISS approach
# results_faiss = metrics_generator.get_metrics(method = 'faiss', mode = 'dev')
# results_faiss

## Results (using TF-TF)

In [247]:
'''
  - VQ
  - 5L passage embeddings for training VQ
  - 7.4K passages for building the inverted index
  - 6.9k dev queries to evaluate
'''

# (codebook_vector_dim = 24, num_clusters = 256, num_chunks = 32)
(np.float64(0.17119007822804383),
{'10': np.float64(0.19149957537796625),
'100': np.float64(0.2353862277616052),
'1000': np.float64(0.27115998415045417)},
{'10': np.float64(0.2693409742120344),
'100': np.float64(0.48398997134670485),
'1000': np.float64(0.7758237822349571)})

# (codebook_vector_dim = 48, num_clusters = 128, num_chunks = 16)
(np.float64(0.12466878155273571),
{'10': np.float64(0.14767703404097327),
'100': np.float64(0.20209698406944923),
'1000': np.float64(0.24292927147100815)},
{'10': np.float64(0.23182903533906402),
'100': np.float64(0.5002865329512894),
'1000': np.float64(0.8279130850047755)})

# (codebook_vector_dim = 48, num_clusters = 256, num_chunks = 16)
(np.float64(0.1691290307909219),
{'10': np.float64(0.19559153049706235),
'100': np.float64(0.2556197314476698),
'1000': np.float64(0.2884573843470742)},
{'10': np.float64(0.2923829990448901),
'100': np.float64(0.587714899713467),
'1000': np.float64(0.8500358166189111)})

# (codebook_vector_dim = 48, num_clusters = 512, num_chunks = 16)
(np.float64(0.22092099877200164),
{'10': np.float64(0.251955542931182),
'100': np.float64(0.31416249725051193),
'1000': np.float64(0.33961830864204745)},
{'10': np.float64(0.364541547277937),
'100': np.float64(0.6663085004775549),
'1000': np.float64(0.8676337153772683)})

# (codebook_vector_dim = 64, num_clusters = 2000, num_chunks = 12) **
(np.float64(0.2716147041433574),
{'10': np.float64(0.31473508373769676),
'100': np.float64(0.37432307035851536),
'1000': np.float64(0.39284564479797796)},
{'10': np.float64(0.4700214899713467),
'100': np.float64(0.7453438395415473),
'1000': np.float64(0.8907234957020057)})

# (codebook_vector_dim = 48, num_clusters = 5000, num_chunks = 16) **
(np.float64(0.3283248055669259),
{'10': np.float64(0.3688863406910524),
'100': np.float64(0.4119222205754149),
'1000': np.float64(0.43068227869497544)},
{'10': np.float64(0.5117597898758357),
'100': np.float64(0.7116404011461318),
'1000': np.float64(0.8668338108882522)})

# (codebook_vector_dim = 64, num_clusters = 5000, num_chunks = 12) **
(np.float64(0.32318983945058444),
{'10': np.float64(0.3699945665521756),
'100': np.float64(0.42161033636510026),
'1000': np.float64(0.4310999911021888)},
{'10': np.float64(0.5348853868194843),
'100': np.float64(0.7776981852913084),
'1000': np.float64(0.8545964660936008)})

# (codebook_vector_dim = 24, num_clusters = 5000, num_chunks = 32)
(np.float64(0.3284150293355165),
{'10': np.float64(0.3521708719785304),
'100': np.float64(0.38744525258572193),
'1000': np.float64(0.4112226004970137)},
{'10': np.float64(0.44363658070678125),
'100': np.float64(0.621095988538682),
'1000': np.float64(0.8201170009551099)})

# (codebook_vector_dim = 64, num_clusters = 6000, num_chunks = 12) **
(np.float64(0.33638928685132125),
{'10': np.float64(0.3837062708306875),
'100': np.float64(0.4307398449636524),
'1000': np.float64(0.4435182294671534)},
{'10': np.float64(0.5484121298949379),
'100': np.float64(0.7686127029608404),
'1000': np.float64(0.8630730659025788)})

# (codebook_vector_dim = 48, num_clusters = 6000, num_chunks = 16) **
(np.float64(0.36619411470414337),
{'10': np.float64(0.410098592873994),
'100': np.float64(0.4485607516396251),
'1000': np.float64(0.46062582109266687)},
{'10': np.float64(0.5637535816618912),
'100': np.float64(0.7493314231136581),
'1000': np.float64(0.8412368672397327)})

# (codebook_vector_dim = 48, num_clusters = 10000, num_chunks = 16)
(np.float64(0.3285774548596898),
{'10': np.float64(0.3642851923026182),
'100': np.float64(0.40186488099180434),
'1000': np.float64(0.42212981926836546)},
{'10': np.float64(0.49356494746895896),
'100': np.float64(0.6669054441260746),
'1000': np.float64(0.835398758357211)})

# (codebook_vector_dim = 24, num_clusters = 10000, num_chunks = 32)
(np.float64(0.3010531791513167),
{'10': np.float64(0.31786488929181994),
'100': np.float64(0.354913812481216),
'1000': np.float64(0.3830926676609372)},
{'10': np.float64(0.3889565425023878),
'100': np.float64(0.5662129894937917),
'1000': np.float64(0.795642311365807)})

'''
  - FAISS
  - 7.4K passages for building the inverted index
  - 6.9k dev queries to evaluate
'''

(np.float64(0.9675002274071043),
{'10': np.float64(0.9743810319851177),
'100': np.float64(0.9749237822949888),
'1000': np.float64(0.9749527047512263)},
{'10': np.float64(0.9973734479465137),
'100': np.float64(0.9996418338108882),
'1000': np.float64(0.9998567335243553)})


(np.float64(0.9675002274071043),
 {'10': np.float64(0.9743810319851177),
  '100': np.float64(0.9749237822949888),
  '1000': np.float64(0.9749527047512263)},
 {'10': np.float64(0.9973734479465137),
  '100': np.float64(0.9996418338108882),
  '1000': np.float64(0.9998567335243553)})