## Google Drive setup

In [1]:
from google.colab import drive
drive.mount('/content/drive', force_remount=True)

Mounted at /content/drive


In [2]:
cd /content/drive/MyDrive/685/code

/content/drive/MyDrive/685/code


## Install required libraries

In [3]:
# Install required libraries
!pip install torch ir_datasets wandb numpy scikit-learn sentence-transformers transformers tqdm scipy matplotlib rank-eval ranx
!pip install faiss-cpu
# !pip uninstall faiss-gpu-cu11

Collecting ir_datasets
  Downloading ir_datasets-0.5.10-py3-none-any.whl.metadata (12 kB)
Collecting rank-eval
  Downloading rank_eval-0.1.3-py3-none-any.whl.metadata (6.8 kB)
Collecting ranx
  Downloading ranx-0.3.20-py3-none-any.whl.metadata (17 kB)
Collecting nvidia-cuda-nvrtc-cu12==12.4.127 (from torch)
  Downloading nvidia_cuda_nvrtc_cu12-12.4.127-py3-none-manylinux2014_x86_64.whl.metadata (1.5 kB)
Collecting nvidia-cuda-runtime-cu12==12.4.127 (from torch)
  Downloading nvidia_cuda_runtime_cu12-12.4.127-py3-none-manylinux2014_x86_64.whl.metadata (1.5 kB)
Collecting nvidia-cuda-cupti-cu12==12.4.127 (from torch)
  Downloading nvidia_cuda_cupti_cu12-12.4.127-py3-none-manylinux2014_x86_64.whl.metadata (1.6 kB)
Collecting nvidia-cudnn-cu12==9.1.0.70 (from torch)
  Downloading nvidia_cudnn_cu12-9.1.0.70-py3-none-manylinux2014_x86_64.whl.metadata (1.6 kB)
Collecting nvidia-cublas-cu12==12.4.5.8 (from torch)
  Downloading nvidia_cublas_cu12-12.4.5.8-py3-none-manylinux2014_x86_64.whl.metad

## Import Libraries

In [4]:
import os
import json

import torch.nn as nn
import torch
import torch.nn.functional as F
from tqdm import tqdm
from transformers import AutoTokenizer, AutoModel

from utils import split_embedding_into_chunks
from dataset import DataProcessor

## Embedding Processor

In [5]:
class EmbeddingProcessor:

    def __init__(self, data_processor, model, tokenizer, emb_root_dir, batch_size = 128, device = 'cpu') -> None:
        self.data_processor = data_processor
        self.model = model
        self.tokenizer = tokenizer
        self.emb_dim = None
        self.emb_root_dir = emb_root_dir
        self.device = device

        self.embeddings_path = {
            'train': {
                'passage_embs': os.path.join(self.emb_root_dir, 'train', 'passage_embeddings.pt'),
                'query_embs': os.path.join(self.emb_root_dir, 'dev', 'query_embeddings.pt'),
                'passage_ids': os.path.join(self.emb_root_dir, 'train', 'passage_ids.json'),
                'query_ids': os.path.join(self.emb_root_dir, 'dev', 'query_ids.json')
            },
            'dev': {
                'passage_embs': os.path.join(self.emb_root_dir, 'dev', 'passage_embeddings.pt'),
                'query_embs': os.path.join(self.emb_root_dir, 'dev', 'query_embeddings.pt'),
                'passage_ids': os.path.join(self.emb_root_dir, 'dev', 'passage_ids.json'),
                'query_ids': os.path.join(self.emb_root_dir, 'dev', 'query_ids.json')
            }
        }
        self.passages, self.queries_train, self.queries_dev, self.qrels_train, self.qrels_dev = None, None, None, None, None

        self.batch_size = batch_size
        return

    # Used to combine the embeddings of all the tokens
    # Contriever model
    def mean_pooling(self, token_embeddings, mask):
        token_embeddings = token_embeddings.masked_fill(~mask[..., None].bool(), 0.)
        sentence_embeddings = token_embeddings.sum(dim = 1) / mask.sum(dim = 1)[..., None]
        return sentence_embeddings

    def filter_data(self, passages, queries, qrels):
        relevant_passage_ids = set()
        for qid in qrels:
            relevant_passage_ids.update(qrels[qid])
        passages = {passage_id: passages[passage_id] for passage_id in relevant_passage_ids}
        passage_ids = list(passages.keys())

        qrels = {qid: qrels[qid] for qid in queries if qid in qrels}

        qrels_ = {}
        queries_ = {}
        for qid in queries:
            if qid in qrels:
                qrels_[qid] = qrels[qid]
                queries_[qid] = queries[qid]

        query_ids = list(qrels_.keys())

        return {'passages': passages, 'queries': queries_, 'qrels': qrels_, 'passage_ids': passage_ids, 'query_ids': query_ids}

    def get_filtered_data(self, mode = 'train'):
        if mode not in ['train', 'dev']:
            raise ValueError('Invalid mode!')

        if self.passages == None:
            # If they are not loaded yet, load them
            raw_data = self.data_processor.get_data()
            self.passages, self.queries_train, self.queries_dev, self.qrels_train, self.qrels_dev = raw_data['passages'], raw_data['queries_train'], raw_data['queries_dev'], raw_data['qrels_train'], raw_data['qrels_dev']

        filtered_data = {
            'passage': None,
            'query': None
        }
        if mode == 'train':
            data = self.filter_data(passages = self.passages, queries = self.queries_train, qrels = self.qrels_train)
            filtered_data['passage'] = {
                'passages': data['passages'],
                'passage_ids': data['passage_ids']
            }
            filtered_data['query'] = {
                'queries': data['queries'],
                'qrels': data['qrels'],
                'query_ids': data['query_ids']
            }

        elif mode == 'dev':
            data = self.filter_data(passages = self.passages, queries = self.queries_dev, qrels = self.qrels_dev)

            # In case of development set, passages would be the entire collection (instead of the filtered ids using qrels)
            filtered_data['passage'] = {
                'passages': self.passages,
                'passage_ids': list(self.passages.keys())
            }
            filtered_data['query'] = {
                'queries': data['queries'],
                'qrels': data['qrels'],
                'query_ids': data['query_ids']
            }

        return filtered_data

    def compute_embeddings(self, type = 'passage', mode = 'train', start = None, limit = None):
        if type == None:
            print('Embedding type not provided, Not computing embeddings!')
            return

        if type not in ['passage', 'query']:
            raise ValueError('Invalid embedding type!')

        data = self.get_filtered_data(mode)[type]
        if type == 'query':
            data = data['queries']
        elif type == 'passage':
            data = data['passages']

        if limit is not None:
            data = dict(list(data.items())[start: start + limit])
            print('Number of passages:', len(data))

        data_embeddings = []

        ids = list(data.keys())
        for i in tqdm(range(0, len(ids), self.batch_size), desc = f"Encoding {type}"):
            batch_data = [data[id] for id in ids[i:i + self.batch_size]]

            # Pad till the model's configured max_len (512)
            batch_inputs = self.tokenizer(batch_data, padding = True, truncation = True, return_tensors = 'pt')
            batch_inputs = {k: v.to(self.device) for k, v in batch_inputs.items()}

            with torch.no_grad():
                outputs = self.model(batch_inputs["input_ids"], batch_inputs["attention_mask"])
                batch_embeddings = self.mean_pooling(outputs[0], batch_inputs['attention_mask'])
                data_embeddings.append(batch_embeddings)

        data_embeddings = torch.cat(data_embeddings, dim = 0)

        return data_embeddings, list(ids)

    def load_or_save_passage_embeddings(self, mode = 'train', start= None, limit = None):
        pass_embs_path = self.embeddings_path[mode]['passage_embs']
        pass_ids_path = self.embeddings_path[mode]['passage_ids']

        if os.path.exists(pass_embs_path) and os.path.exists(pass_ids_path):
            print("Loading cached passage embeddings & ids...")
            passage_embeddings = torch.load(pass_embs_path).to(device = self.device)
            self.emb_dim = passage_embeddings.shape[-1]

            with open(pass_ids_path, "r") as f:
                passage_ids = json.load(f)

            return {
                'embeddings':
                    {
                        'passage_embeddings': passage_embeddings
                    },
                'mappings': {
                        'passage_ids': passage_ids
                    }
                }

        passage_embeddings, passage_ids = self.compute_embeddings(type = 'passage', mode = mode, start = start, limit = limit)
        self.emb_dim = passage_embeddings.shape[-1]

        # Save embeddings to the appropriate path
        torch.save(passage_embeddings, pass_embs_path)

        # Save ID mappings
        with open(pass_ids_path, "w") as f:
            json.dump(passage_ids, f)

        print("Embeddings & Mappings saved.")

        return {
            'embeddings':
                {
                    'passage_embeddings': passage_embeddings
                },
            'mappings': {
                    'passage_ids': passage_ids
                }
            }


    def load_or_save_query_embeddings(self, mode = 'train'):
        query_embs_path = self.embeddings_path[mode]['query_embs']
        query_ids_path = self.embeddings_path[mode]['query_ids']

        if os.path.exists(query_embs_path) and os.path.exists(query_ids_path):
            print("Loading cached query embeddings & ids...")
            query_embeddings = torch.load(query_embs_path).to(device = self.device)
            self.emb_dim = query_embeddings.shape[-1]

            with open(query_ids_path, "r") as f:
                query_ids = json.load(f)

            return {
                'embeddings':
                    {
                        'query_embeddings': query_embeddings
                    },
                'mappings': {
                        'query_ids': query_ids
                    }
                }

        query_embeddings, query_ids = self.compute_embeddings(type = 'query', mode = mode)
        self.emb_dim = query_embeddings.shape[-1]

        # Save embeddings to the appropriate path
        torch.save(query_embeddings, query_embs_path)

        with open(query_ids_path, "w") as f:
            json.dump(query_ids, f)

        print("Embeddings & Mappings saved.")

        return {
            'embeddings':
                {
                    'query_embeddings': query_embeddings
                },
            'mappings': {
                    'query_ids': query_ids
                }
            }

    def load_or_save_embeddings(self, mode = 'train'):
        pass_embs_path = self.embeddings_path[mode]['passage_embs']
        query_embs_path = self.embeddings_path[mode]['query_embs']
        pass_ids_path = self.embeddings_path[mode]['passage_ids']
        query_ids_path = self.embeddings_path[mode]['query_ids']
        print(pass_embs_path)

        if os.path.exists(pass_embs_path) and os.path.exists(query_embs_path) and os.path.exists(pass_ids_path) and os.path.exists(query_ids_path):
            print("Loading cached embeddings & ids...")
            passage_embeddings = torch.load(pass_embs_path, map_location = 'cpu')
            query_embeddings = torch.load(query_embs_path, map_location = 'cpu')
            # passage_embeddings = torch.load(pass_embs_path).to(device = self.device)
            # query_embeddings = torch.load(query_embs_path).to(device = self.device)
            self.emb_dim = passage_embeddings.shape[-1]

            with open(pass_ids_path, "r") as f:
                passage_ids = json.load(f)

            with open(query_ids_path, "r") as f:
                query_ids = json.load(f)

            return {
                'embeddings':
                    {
                        'passage_embeddings': passage_embeddings,
                        'query_embeddings': query_embeddings
                    },
                'mappings': {
                        'passage_ids': passage_ids,
                        'query_ids': query_ids
                    }
                }

        passage_embeddings, passage_ids = self.compute_embeddings(type = 'passage', mode = mode)
        query_embeddings, query_ids = self.compute_embeddings(type = 'query', mode = mode)
        self.emb_dim = passage_embeddings.shape[-1]

        # Save embeddings to the appropriate path
        torch.save(passage_embeddings, pass_embs_path)
        torch.save(query_embeddings, query_embs_path)

        # Save ID mappings
        with open(pass_ids_path, "w") as f:
            json.dump(passage_ids, f)

        with open(query_ids_path, "w") as f:
            json.dump(query_ids, f)

        print("Embeddings & Mappings saved.")

        return {
            'embeddings':
                {
                    'passage_embeddings': passage_embeddings,
                    'query_embeddings': query_embeddings
                },
            'mappings': {
                    'passage_ids': passage_ids,
                    'query_ids': query_ids
                }
            }

    def get_emb_dim(self):
        if self.emb_dim is None:
            raise ValueError('Embedding dimension not found!')

        return self.emb_dim

## Vector Quantizer

In [6]:
class Quantize(nn.Module):
    def __init__(self, dim, num_clusters, decay = 0.99, eps = 1e-5):
        super().__init__()

        self.dim = dim
        self.num_clusters = num_clusters
        self.decay = decay
        self.eps = eps

        embed = torch.randn(dim, num_clusters)
        self.register_buffer("embed", embed)
        self.register_buffer("cluster_size", torch.zeros(num_clusters))
        self.register_buffer("embed_avg", embed.clone())

    def forward(self, input):
        flatten = input.reshape(-1, self.dim)
        dist = (
            flatten.pow(2).sum(1, keepdim = True)
            - 2 * flatten @ self.embed
            + self.embed.pow(2).sum(0, keepdim = True)
        )
        _, embed_ind = (-dist).max(1)
        embed_onehot = F.one_hot(embed_ind, self.num_clusters).type(flatten.dtype)
        embed_ind = embed_ind.view(*input.shape[:-1])
        quantize = self.embed_code(embed_ind)

        if self.training:
            embed_onehot_sum = embed_onehot.sum(0)
            embed_sum = flatten.transpose(0, 1) @ embed_onehot

            self.cluster_size.data.mul_(self.decay).add_(
                embed_onehot_sum, alpha=1 - self.decay
            )
            self.embed_avg.data.mul_(self.decay).add_(embed_sum, alpha=1 - self.decay)
            n = self.cluster_size.sum()
            cluster_size = (
                (self.cluster_size + self.eps) / (n + self.num_clusters * self.eps) * n
            )
            embed_normalized = self.embed_avg / cluster_size.unsqueeze(0)
            self.embed.data.copy_(embed_normalized)

        quantize = input + (quantize - input).detach()

        return quantize, embed_ind

    def embed_code(self, embed_id):
        return F.embedding(embed_id, self.embed.transpose(0, 1))

## Vector Quantizer Hanlder

In [7]:
class VQHandler:
    def __init__(self, embedding_processor, quantizer = None, emb_dim = None, num_clusters = None, num_chunks = 32, batch_size = 512, device = 'cpu'):

        if quantizer is not None:
            self.quantizer = quantizer
        else:
            self.quantizer = Quantize(dim = emb_dim, num_clusters = num_clusters)

        self.embedding_processor = embedding_processor
        train_embeddings = self.embedding_processor.load_or_save_embeddings(mode = 'train')['embeddings']
        dev_embeddings = self.embedding_processor.load_or_save_embeddings(mode = 'dev')['embeddings']

        self.train_query_embeddings = train_embeddings['query_embeddings']
        self.train_passage_embeddings = dev_embeddings['passage_embeddings']
        self.dev_query_embeddings = dev_embeddings['query_embeddings']
        self.dev_passage_embeddings = dev_embeddings['passage_embeddings']

        # Number of chunks each emb to be divided into
        self.num_chunks = num_chunks
        self.device = device
        self.batch_size = batch_size

    def train(self):
        self.quantizer.training = True
        # train_query_chunked_embs = split_embedding_into_chunks(self.train_query_embeddings, self.num_chunks)
        train_pass_chunked_embs = split_embedding_into_chunks(self.train_passage_embeddings, self.num_chunks)

        # embeddings = torch.cat((train_query_chunked_embs, train_pass_chunked_embs), dim = 0)

        for i in tqdm(range(0, train_pass_chunked_embs.shape[0], self.batch_size), desc = "Training codebook vectors"):
            batch_embs = train_pass_chunked_embs[i:i + self.batch_size].to(device = self.device)
            _, code = self.quantizer(batch_embs)

        self.quantizer.training = False
        return self.quantizer

    def inference(self, type = 'passage', mode = 'dev'):

        embeddings = None
        if mode == 'dev':
            if type == 'passage':
                embeddings = self.dev_passage_embeddings
            elif type == 'query':
                embeddings = self.dev_query_embeddings

        if mode == 'train':
            if type == 'passage':
                embeddings = self.train_passage_embeddings
            elif type == 'query':
                embeddings = self.train_query_embeddings

        # Get the code book vectors for each passage in the devlopment set
        self.quantizer.training = False
        code_indices = []
        for i in tqdm(range(0, embeddings.shape[0], self.batch_size), desc = "Vector quantizing..."):
            batch_embs = embeddings[i:i + self.batch_size].to(device = self.device)
            batch_chunked_embs = split_embedding_into_chunks(batch_embs, self.num_chunks)
            _, code = self.quantizer(batch_chunked_embs)
            code = code.view(-1, self.num_chunks)
            code_indices.append(code)

        code_indices = torch.cat(code_indices, dim = 0)

        return code_indices

## Main.py

In [8]:
'''
    - Load the raw dataset
'''
data_processor = DataProcessor(data_root_dir = '../data')
data = data_processor.get_data()
passages, queries_train, queries_dev, qrels_train, qrels_dev = data['passages'], data['queries_train'], data['queries_dev'], data['qrels_train'], data['qrels_dev']
data_processor.print_samples()

Passages:
0: The presence of communication amid scientific minds was equally important to the success of the Manhattan Project as scientific intellect was. The only cloud hanging over the impressive achievement of the atomic researchers and engineers is what their success truly meant; hundreds of thousands of innocent lives obliterated.
1: The Manhattan Project and its atomic bomb helped bring an end to World War II. Its legacy of peaceful uses of atomic energy continues to have an impact on history and science.
2: Essay on The Manhattan Project - The Manhattan Project The Manhattan Project was to see if making an atomic bomb possible. The success of this project would forever change the world forever making it known that something this powerful can be manmade.
3: The Manhattan Project was the name for a project conducted during World War II, to develop the first atomic bomb. It refers specifically to the period of the project from 194 â¦ 2-1946 under the control of the U.S. Army Corp

In [9]:
'''
    - Initialize the model(Contriever)
'''
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print('Using device', device)

# tokenizer = AutoTokenizer.from_pretrained('facebook/contriever-msmarco')
# model = AutoModel.from_pretrained('facebook/contriever-msmarco').to(device)
model = None
tokenizer = None

Using device cuda


In [10]:
'''
    - Load/Save the embeddings
'''
embedding_processor = EmbeddingProcessor(data_processor = data_processor, model = model, tokenizer = tokenizer, emb_root_dir = '../embeddings', batch_size = 128, device = device)
# emb_dim = embedding_processor.get_emb_dim()

In [11]:
'''
    - Train the Vector Quantizer
    - Vector quantize the embeddings & Get the code indices
'''
quantizer = Quantize(dim = 6, num_clusters = 10).to(device = device)
vq_hanlder = VQHandler(embedding_processor = embedding_processor, quantizer = quantizer, num_chunks = 128, device = device)
quantizer = vq_hanlder.train()

# Get the code indices for each passage in development set
# These code indices are used to build the inverted index
# code_indices = vq_hanlder.inference(type = 'passage', mode = 'dev')

asdflkasd;la
../embeddings/train/passage_embeddings.pt
Loading cached embeddings & ids...
asdflkasd;la
../embeddings/dev/passage_embeddings.pt
Loading cached embeddings & ids...


Training codebook vectors: 100%|██████████| 1859/1859 [00:02<00:00, 877.83it/s] 


In [12]:
code_indices = vq_hanlder.inference(type = 'passage', mode = 'dev').cpu().numpy()

Vector quantizing...: 100%|██████████| 15/15 [00:00<00:00, 685.66it/s]


In [13]:
print(code_indices)

[[9 7 3 ... 8 3 4]
 [7 3 2 ... 8 6 9]
 [0 7 4 ... 8 3 4]
 ...
 [3 7 6 ... 8 4 1]
 [4 7 1 ... 8 1 2]
 [1 3 1 ... 8 1 9]]


In [14]:
print(code_indices.shape)

(7433, 128)


In [15]:
from collections import defaultdict, Counter
import numpy as np
from scipy.sparse import csr_matrix

class InvertedIndexHandler:
    def __init__(self, embedding_processor):
        self.embedding_processor = embedding_processor
        self.train_mapppings = self.embedding_processor.load_or_save_embeddings(mode = 'train')['mappings']
        self.dev_mappings = self.embedding_processor.load_or_save_embeddings(mode = 'dev')['mappings']

        # self.train_passage_ids = train_mapppings['passage_ids']
        # self.dev_passage_ids = dev_mappings['passage_ids']
        self.passage_ids = None
        self.optimized_index = None
        return

    def create_inverted_index(self, code_indices, mode = 'dev'):
        self.passage_ids = self.dev_mappings['passage_ids']
        if mode == 'train':
            self.passage_ids = self.train_mapppings['passage_ids']
        elif mode != 'dev':
            raise NotImplementedError(f"Inverted index for {mode} not implemented!")

        inverted_index = defaultdict(list)
        num_passages = len(self.passage_ids)

        rows, cols, data = [], [], []
        vocab_size = 10

        for i, codes in enumerate(tqdm(code_indices, desc="Building passage matrix")):
            weights = Counter(codes)
            for code, freq in weights.items():
                rows.append(i)
                # rows.append(passage_ids[i])
                cols.append(code)
                data.append(freq)

        self.optimized_index = csr_matrix((data, (rows, cols)), shape=(len(code_indices), vocab_size), dtype=np.float32)
        return self.optimized_index

        # for i in tqdm(range(0, num_passages), desc = "Building inverted index"):
        #     code_index_list = code_indices[i].tolist()
        #     weights = Counter(code_index_list)
        #     for code_index in list(set(code_index_list)):
        #         inverted_index[int(code_index)].append((self.passage_ids[i], float(weights[code_index])))  # Ensure integer keys

        # # Sort postings lists by weight for each term
        # for idx in inverted_index:
        #     inverted_index[idx] = sorted(inverted_index[idx], key = lambda x: abs(x[1]), reverse=True)

        # # Convert to more efficient data structure
        # self.optimized_index = {
        #     idx: (
        #         np.array([passage_id for passage_id, _ in postings], dtype = np.int32),
        #         np.array([weight for _, weight in postings], dtype = np.float32)
        #     )
        #     for idx, postings in inverted_index.items()
        # }

        # return self.optimized_index

    # Optimized search function for sparse retrieval using the inverted index
    def search_inverted_index(self, query_matrix, query_ids):
    # def search_inverted_index(self, query_code_index_list):
        # scores = defaultdict(float)
        # seen_passages = set()

        # # Process each query term
        # weights = Counter(query_code_index_list)
        # query_code_index_list = list(set(query_code_index_list))
        # for code_index in query_code_index_list:
        #     if code_index not in self.optimized_index:
        #         # print('Unexpected!!!')
        #         continue

        #     passage_ids, passage_weights = self.optimized_index[code_index]
        #     query_weight = weights[code_index]

        #     # Only process top documents per term
        #     for passage_id, passage_weight in zip(passage_ids, passage_weights):
        #         scores[passage_id] += query_weight * passage_weight
        #         seen_passages.add(passage_id)

        # # Use numpy for final scoring
        # if seen_passages:
        #     passage_ids = np.array(list(seen_passages))
        #     passage_scores = np.array([scores[passage_id] for passage_id in passage_ids])

        #     # Get top 1000 results efficiently
        #     top_k = min(1000, len(passage_scores))
        #     top_indices = np.argpartition(passage_scores, -top_k)[-top_k:]
        #     top_indices = top_indices[np.argsort(-passage_scores[top_indices])]

        #     return [(passage_ids[i], passage_scores[i]) for i in top_indices]

        # return []
        num_queries = query_matrix.shape[0]

        all_results = {}
        for start in tqdm(range(0, num_queries, 128), desc="Scoring queries in chunks"):
            end = min(start + 128, num_queries)
            query_chunk = query_matrix[start:end]
            scores_chunk = query_chunk @ self.optimized_index.T  # shape: [batch_size, num_passages]

            topk_idx = np.argpartition(-scores_chunk, 1000, axis=1)[:, :1000]
            for i in range(scores_chunk.shape[0]):
                passage_scores = scores_chunk[i]
                top_i = topk_idx[i]
                top_indices = top_i[np.argsort(-passage_scores[top_i])]

                all_results[query_ids[start + i]] = [(self.passage_ids[t], passage_scores[t]) for t in top_indices]

        return all_results

In [16]:
'''
    - Create the inverted index with the vector quantized indices
'''
inverted_index_handler = InvertedIndexHandler(embedding_processor = embedding_processor)

# Code indices of the passage set that you are working on
# Changes based on the mode (dev/train)
obj = inverted_index_handler.create_inverted_index(code_indices = code_indices, mode = 'dev')

asdflkasd;la
../embeddings/train/passage_embeddings.pt
Loading cached embeddings & ids...
asdflkasd;la
../embeddings/dev/passage_embeddings.pt
Loading cached embeddings & ids...


Building passage matrix: 100%|██████████| 7433/7433 [00:00<00:00, 24249.77it/s]


In [17]:
# import numpy as np

# all_passage_ids = np.concatenate([postings[0] for postings in inverted_index_handler.optimized_index.values()])
# num_unique_passages = len(np.unique(all_passage_ids))
# print("Unique passage IDs:", num_unique_passages)
inverted_index_handler.optimized_index.shape
# len(inverted_index_handler.optimized_index)

(7433, 10)

In [18]:
from tqdm import tqdm
from ranx import Qrels, Run, evaluate

class MetricsGenerator:
    def __init__(self, inverted_index_handler, embedding_processor, qrels):
        self.inverted_index_handler = inverted_index_handler

        self.embedding_processor = embedding_processor
        train_mapppings = self.embedding_processor.load_or_save_embeddings(mode = 'train')['mappings']
        dev_mappings = self.embedding_processor.load_or_save_embeddings(mode = 'dev')['mappings']

        self.train_query_ids = train_mapppings['query_ids']
        self.dev_query_ids = dev_mappings['query_ids']

        self.train_qrels = qrels['train']
        self.dev_qrels = qrels['dev']
        self.query_matrix = None
        return

    def build_query_matrix(self, query_code_indices):
      vocab_size = 10
      self.query_matrix = np.zeros((len(query_code_indices), vocab_size), dtype=np.float32)

      for i, codes in enumerate(query_code_indices):
          weights = Counter(codes)
          for code, freq in weights.items():
              self.query_matrix[i, code] = freq

      return self.query_matrix

    def batch_score_queries(self, top_k = 1000):
      scores_matrix = self.query_matrix @ self.inverted_index.T  # shape: [num_queries, num_passages]
      top_k_indices = np.argpartition(-scores_matrix, top_k, axis=1)[:, :top_k]

      results = {}
      for i in range(scores_matrix.shape[0]):
          row = scores_matrix[i]
          top_idx = top_k_indices[i]
          sorted_idx = top_idx[np.argsort(-row[top_idx])]
          results.append(list(zip(sorted_idx, row[sorted_idx])))

      return results

    def get_metrics(self, code_indices, mode = 'dev', batch_size = 128):
        query_ids = self.dev_query_ids
        q_rels = self.dev_qrels
        if mode == 'train':
            query_ids = self.train_query_ids
            qrels = self.train_qrels
        elif mode != 'dev':
            raise NotImplementedError(f"Metrics calculator not implemented for {mode}!")

        # all_results = {}
        # num_queries = len(query_ids)
        # for i in tqdm(range(0, num_queries), desc = "Evaluating queries"):
        #     code_index_list = code_indices[i].tolist()
        #     search_results = self.inverted_index_handler.search_inverted_index(code_index_list)
        #     all_results[query_ids[i]] = search_results

        all_results = self.inverted_index_handler.search_inverted_index(self.query_matrix, query_ids)

        # Create rank_eval Run and Qrels objects
        run_dict = {}
        for qid, results in all_results.items():
            run_dict[qid] = {
                str(passage_id): float(score)
                for passage_id, score in results
            }
        run = Run(run_dict)

        qrels_dict = {
            qid: {str(passage_id): 1 for passage_id in q_rels[qid]}
            for qid in q_rels
        }
        qrels = Qrels(qrels_dict)

        # Evaluate using rankx
        metrics = ["ndcg@10", "ndcg@100", "ndcg@1000", "recall@10", "recall@100", "recall@1000", "mrr@10"]
        results = evaluate(qrels, run, metrics)

        return (
            results["mrr@10"],
            {
                '10': results["ndcg@10"],
                '100': results["ndcg@100"],
                '1000': results["ndcg@1000"]
            },
            {
                '10': results["recall@10"],
                '100': results["recall@100"],
                '1000': results["recall@1000"]
            }
        )

In [19]:
'''
    - Generate Metrics using the inverted index built
'''

# Pass qrels to calculate the metrics
qrels = {
    'train': qrels_train,
    'dev': qrels_dev
}

# Get the code indices for each query in the development set
# These code indices are used to calculate the scores and metrics
query_code_indices = vq_hanlder.inference(type = 'query', mode = 'dev')

metrics_generator = MetricsGenerator(inverted_index_handler = inverted_index_handler, embedding_processor = embedding_processor, qrels = qrels)
metrics_generator.build_query_matrix(query_code_indices)
results = metrics_generator.get_metrics(code_indices = query_code_indices)

Vector quantizing...: 100%|██████████| 14/14 [00:00<00:00, 1017.70it/s]


asdflkasd;la
../embeddings/train/passage_embeddings.pt
Loading cached embeddings & ids...
asdflkasd;la
../embeddings/dev/passage_embeddings.pt
Loading cached embeddings & ids...


Scoring queries in chunks: 100%|██████████| 55/55 [00:02<00:00, 19.16it/s]
  scores[i] = _ndcg(qrels[i], run[i], k, rel_lvl, jarvelin)


In [20]:
query_code_indices

tensor([[6, 1, 4,  ..., 8, 3, 2],
        [7, 4, 4,  ..., 8, 4, 7],
        [3, 0, 2,  ..., 8, 1, 2],
        ...,
        [7, 1, 4,  ..., 8, 0, 7],
        [7, 4, 0,  ..., 8, 0, 0],
        [0, 1, 7,  ..., 8, 0, 4]], device='cuda:0')

In [21]:
query_code_indices.shape, code_indices.shape

(torch.Size([6980, 128]), (7433, 128))

In [22]:
results

(np.float64(0.0004196229590212398),
 {'10': np.float64(0.000595835526254673),
  '100': np.float64(0.0029155282810198815),
  '1000': np.float64(0.01688858877149606)},
 {'10': np.float64(0.001325214899713467),
  '100': np.float64(0.013861031518624643),
  '1000': np.float64(0.13527936962750717)})

In [64]:
results

(np.float64(0.021673773138672848),
 {'10': np.float64(0.02566012687020493),
  '100': np.float64(0.037404539590592756),
  '1000': np.float64(0.05318459512457225)},
 {'10': np.float64(0.040568290353390636),
  '100': np.float64(0.1004297994269341),
  '1000': np.float64(0.23006208213944604)})

In [22]:
results

(0.0004196229590212398,
 {'10': 0.0006232282141076039,
  '100': 0.0028432873958526925,
  '1000': 0.01679376817421632},
 {'10': 0.0013610315186246419,
  '100': 0.013347659980897802,
  '1000': 0.1346585482330468})

In [85]:
results

(0.0004196229590212398,
 {'10': 0.0006232282141076039,
  '100': 0.0028432873958526925,
  '1000': 0.01679376817421632},
 {'10': 0.0013610315186246419,
  '100': 0.013347659980897802,
  '1000': 0.1346585482330468})