This README provides details on installation, training, inference, and retrieval. It also provides a basic walkthrough of the code.
For each of the baselines (FAISS, DiskANN), please consult their respective repositories for installation instructions and requirements. Follow the instructions in requirements.txt in the main directory to create a bare-bones environment; you may adjust individual requirements to match your system configuration.
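A minimal setup might look as follows (the environment name here is arbitrary; adapt to conda or your preferred environment manager):

python -m venv sparse_env
source sparse_env/bin/activate
pip install -r requirements.txt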
Note that GPU access is required for timely training and evaluation of the models.
Below is an overview of the sparse_retrieval module of the codebase.
sparse_retrieval/
├── common.py
├── embed_dump.py
├── __init__.py
├── retr_utils.py
├── run.py
├── sparsifiers
│   ├── base.py
│   ├── __init__.py
│   └── neural.py
└── training
    ├── base.py
    ├── __init__.py
    ├── models.py
    ├── score_postings_list.py
    ├── train_postings_list.py
    └── train.py
common.py: This file contains utility functions used in multiple places throughout the module, including code to set the internal task name and logfile name, model initialization, and generation of the transport plans for the pre-trained backbone.
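The transport plans are, in spirit, soft correspondences between query and corpus node embeddings. As rough intuition only (this is not the routine in common.py; the similarity, temperature, and iteration count are assumptions), such a plan can be obtained with a few Sinkhorn normalization steps over a similarity matrix:

import torch

def transport_plan(query_emb, corpus_emb, n_iters=10, temperature=0.1):
    # query_emb: (n_q, d) and corpus_emb: (n_c, d) node embeddings from the backbone
    log_alpha = (query_emb @ corpus_emb.T) / temperature       # pairwise similarity scores
    for _ in range(n_iters):
        # alternate row/column normalization in log space (Sinkhorn iterations)
        log_alpha = log_alpha - torch.logsumexp(log_alpha, dim=1, keepdim=True)
        log_alpha = log_alpha - torch.logsumexp(log_alpha, dim=0, keepdim=True)
    return log_alpha.exp()                                      # soft transport plan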
embed_dump.py: This file contains code relevant to the pre-trained backbone and to generating the continuous embeddings.
retr_utils.py: This file contains utilities used during the retrieval stage.
run.py: The entry point for performing thresholded retrieval (see the retrieval commands below).
sparsifiers/base.py: This contains the base class for inverted indexing and probing the index. The index generation process is relevant to both $\text{S}_{\text{unif}}$ and $\text{S}_{\text{impact}}$ from the paper. The probing/searching code in this file handles index lookups at retrieval time.
sparsifiers/neural.py: Contains the subclass for neural indexing, i.e., functions for generating corpus and query bitcodes using our neural models.
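As a rough sketch of inverted indexing and probing over bitcodes (names and structure here are illustrative assumptions, not the classes in sparsifiers/base.py): each active bit of a corpus bitcode places that graph on the bit's posting list, and a query probes only the lists of its own active bits.

from collections import defaultdict

def build_inverted_index(corpus_bitcodes):
    # corpus_bitcodes: list of 0/1 sequences, one bitcode per corpus graph
    index = defaultdict(list)
    for doc_id, bits in enumerate(corpus_bitcodes):
        for pos, bit in enumerate(bits):
            if bit:
                index[pos].append(doc_id)      # posting list for this bit position
    return index

def probe(index, query_bits):
    # return candidate corpus ids that share at least one active bit with the query
    candidates = set()
    for pos, bit in enumerate(query_bits):
        if bit:
            candidates.update(index.get(pos, []))
    return candidates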
training/base.py: Contains the base class with code common to the training subclasses below.
training/train_postings_list.py: Contains the subclass specific to training the $\text{Impact}_{\psi}$ network. Provided certain dump files are available (the bitcodes and the posting lists), use this to train for $\text{S}_{\text{impact}}$ with various margins.
training/score_postings_list.py: A separate submodule containing the score functions used when training on the posting lists.
training/train.py: Contains the subclass for the main training run, invoked via the sparse_retrieval.training.train commands below.
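The training commands below select a margin-based ranking objective (post_process.z_training=min_ranking_loss, with rl_margin and pl_margin controlling the margins). As an illustration of this loss family only (not the exact objective implemented in these files):

import torch
import torch.nn.functional as F

def ranking_loss(pos_scores, neg_scores, margin=10.0):
    # pos_scores: scores for relevant corpus items, neg_scores: scores for irrelevant ones
    # penalize any negative that comes within `margin` of a positive
    target = torch.ones_like(pos_scores)   # positives should score higher than negatives
    return F.margin_ranking_loss(pos_scores, neg_scores, target, margin=margin)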
Further, below is an overview of the utilities used to generate plots for the paper.
main_notebooks/
├── __init__.py
├── notebook_eval_funcs.py
├── notebook_functions.py
└── notebook_plot_funcs.py
notebook_eval_funcs.py: Contains functions for evaluation and for generating the data behind the plots.
notebook_functions.py: Contains functions common to plotting and data generation, including fetching dumped data from the pre-trained backbone, configuration management, and data loading.
notebook_plot_funcs.py: Contains utilities for plotting the figures in the main and supplementary material, as well as their legends.
These modules are generally imported in the provided notebooks with autoreload enabled, for example as shown below.
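At the top of a notebook this typically looks like the following (standard IPython magics; the exact import path is an assumption and depends on where the notebook lives relative to main_notebooks/):

%load_ext autoreload
%autoreload 2
from main_notebooks import notebook_functions, notebook_eval_funcs, notebook_plot_funcs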
All the modules listed above import from certain internal utility modules in the ./utils directory, including utils.utils.py, utils.model_utils.py, utils.data_utils.py, utils.training_utils.py and utils.dataset_loader.py.
The following commands can be used to train the models:
python -m sparse_retrieval.training.train dataset.rel_mode=sub_iso dataset.name=ptc_fr model.name=NANL log_level=INFO hashing.hcode_dim=10 seed=42 post_process.mlp=nanl_container post_process.hidden_channels=[64] dataset.data_type=gmn_1d training.batch_size=50000 post_process.to_mask=1 post_process.regularizer=none post_process.scoring_style=continuous post_process.reg_multiplier=none post_process.reg_lambda=0.0 post_process.z_training=min_ranking_loss post_process.rl_margin=10.0 task.wandb_project=GHASH_SPARSE_12thMarch training.wandb_watch=True training.patience=30 training.scalable=True &
python -m sparse_retrieval.training.train dataset.rel_mode=sub_iso dataset.name=ptc_fr model.name=NANL log_level=INFO hashing.hcode_dim=10 seed=42 post_process.mlp=nanl_container post_process.hidden_channels=[64] dataset.data_type=gmn_1d training.batch_size=50000 post_process.to_mask=1 post_process.regularizer=none post_process.scoring_style=continuous post_process.reg_multiplier=none post_process.reg_lambda=0.0 post_process.z_training=min_ranking_loss post_process.rl_margin=30.0 task.wandb_project=GHASH_SPARSE_12thMarch training.wandb_watch=True training.patience=30 training.scalable=True &
Change the dataset name to one of ptc_fm, ptc_mr or cox2 to train on the other datasets.
python -m sparse_retrieval.training.train_postings_list dataset.rel_mode=sub_iso dataset.name=ptc_fr model.name=NANL log_level=INFO hashing.hcode_dim=10 seed=42 post_process.mlp=nanl_container post_process.hidden_channels=[64] dataset.data_type=idx training.batch_size=3000 post_process.to_mask=1 post_process.regularizer=none post_process.scoring_style=continuous post_process.reg_multiplier=none post_process.reg_lambda=0.0 post_process.z_training=min_ranking_loss post_process.rl_margin=10.0 post_process.pl_training=one_one post_process.pl_margin=0.01 post_process.pl_hidden_dim=64 task.wandb_project=GHASH_SPARSE_PLTRAIN training.wandb_watch=True &
python -m sparse_retrieval.training.train_postings_list dataset.rel_mode=sub_iso dataset.name=ptc_fr model.name=NANL log_level=INFO hashing.hcode_dim=10 seed=42 post_process.mlp=nanl_container post_process.hidden_channels=[64] dataset.data_type=idx training.batch_size=3000 post_process.to_mask=1 post_process.regularizer=none post_process.scoring_style=continuous post_process.reg_multiplier=none post_process.reg_lambda=0.0 post_process.z_training=min_ranking_loss post_process.rl_margin=10.0 post_process.pl_training=one_one post_process.pl_margin=0.1 post_process.pl_hidden_dim=64 task.wandb_project=GHASH_SPARSE_PLTRAIN training.wandb_watch=True &
python -m sparse_retrieval.training.train_postings_list dataset.rel_mode=sub_iso dataset.name=ptc_fr model.name=NANL log_level=INFO hashing.hcode_dim=10 seed=42 post_process.mlp=nanl_container post_process.hidden_channels=[64] dataset.data_type=idx training.batch_size=3000 post_process.to_mask=1 post_process.regularizer=none post_process.scoring_style=continuous post_process.reg_multiplier=none post_process.reg_lambda=0.0 post_process.z_training=min_ranking_loss post_process.rl_margin=10.0 post_process.pl_training=one_one post_process.pl_margin=1.0 post_process.pl_hidden_dim=64 task.wandb_project=GHASH_SPARSE_PLTRAIN training.wandb_watch=True &
The above commands train the posting-list network on ptc_fr with three different impact margins (post_process.pl_margin). Adjust other hyperparameters, such as the dataset name, as needed.
python -m sparse_retrieval.run dataset.rel_mode="sub_iso" dataset.name="cox2" model.name="NANL" log_level="INFO" plot=0 top_n=5 gpu=1 standardize=0 hashing.hcode_dim=10 seed=42 retrieval="focused" post_process.mlp="mlp_container" post_process.hidden_channels="[64]" dataset.data_type="idx" training.batch_size=10000 post_process.to_mask=1 post_process.regularizer="none" post_process.scoring_style="continuous" post_process.reg_multiplier="v_soft" post_process.reg_lambda=0.0 post_process.z_training="min_ranking_loss" post_process.rl_margin=30.0 post_process.pt_aggr="threshold" post_process.pt_aggr_threshold_stats=True &
python -m sparse_retrieval.run dataset.rel_mode="sub_iso" dataset.name="cox2" model.name="NANL" log_level="INFO" plot=0 top_n=5 gpu=1 standardize=0 hashing.hcode_dim=10 seed=42 retrieval="focused" post_process.mlp="mlp_container" post_process.hidden_channels="[64]" dataset.data_type="idx" training.batch_size=10000 post_process.to_mask=1 post_process.regularizer="none" post_process.scoring_style="continuous" post_process.reg_multiplier="v_soft" post_process.reg_lambda=0.0 post_process.z_training="min_ranking_loss" post_process.rl_margin=30.0 post_process.pt_aggr="threshold" post_process.pt_aggr_threshold_stats=True post_process.pt_bit_dist=1 &
python -m sparse_retrieval.run dataset.rel_mode="sub_iso" dataset.name="cox2" model.name="NANL" log_level="INFO" plot=0 top_n=5 gpu=1 standardize=0 hashing.hcode_dim=10 seed=42 retrieval="focused" post_process.mlp="mlp_container" post_process.hidden_channels="[64]" dataset.data_type="idx" training.batch_size=10000 post_process.to_mask=1 post_process.regularizer="none" post_process.scoring_style="continuous" post_process.reg_multiplier="v_soft" post_process.reg_lambda=0.0 post_process.z_training="min_ranking_loss" post_process.rl_margin=30.0 post_process.pt_aggr="threshold" post_process.pt_aggr_threshold_stats=True post_process.pt_bit_dist=2 &
python -m sparse_retrieval.run dataset.rel_mode="sub_iso" dataset.name="cox2" model.name="NANL" log_level="INFO" plot=0 top_n=5 gpu=1 standardize=0 hashing.hcode_dim=10 seed=42 retrieval="focused" post_process.mlp="mlp_container" post_process.hidden_channels="[64]" dataset.data_type="idx" training.batch_size=10000 post_process.to_mask=1 post_process.regularizer="none" post_process.scoring_style="continuous" post_process.reg_multiplier="v_soft" post_process.reg_lambda=0.0 post_process.z_training="min_ranking_loss" post_process.rl_margin=30.0 post_process.pt_aggr="threshold" post_process.pt_aggr_threshold_stats=True post_process.pt_bit_dist=3 &
python -m sparse_retrieval.run dataset.rel_mode="sub_iso" dataset.name="cox2" model.name="NANL" log_level="INFO" plot=0 top_n=5 gpu=1 standardize=0 hashing.hcode_dim=10 seed=42 retrieval="focused" post_process.mlp="mlp_container" post_process.hidden_channels="[64]" dataset.data_type="idx" training.batch_size=10000 post_process.to_mask=1 post_process.regularizer="none" post_process.scoring_style="continuous" post_process.reg_multiplier="v_soft" post_process.reg_lambda=0.0 post_process.z_training="min_ranking_loss" post_process.rl_margin=30.0 post_process.pt_aggr="threshold" post_process.pt_aggr_threshold_stats=True post_process.pt_bit_dist=5 &
python -m sparse_retrieval.run dataset.rel_mode="sub_iso" dataset.name="cox2" model.name="NANL" log_level="INFO" plot=0 top_n=5 gpu=1 standardize=0 hashing.hcode_dim=10 seed=42 retrieval="focused" post_process.mlp="mlp_container" post_process.hidden_channels="[64]" dataset.data_type="idx" training.batch_size=10000 post_process.to_mask=1 post_process.regularizer="none" post_process.scoring_style="continuous" post_process.reg_multiplier="v_soft" post_process.reg_lambda=0.0 post_process.z_training="min_ranking_loss" post_process.rl_margin=30.0 post_process.pt_aggr="threshold" post_process.pt_aggr_threshold_stats=True post_process.pt_bit_dist=7 &
python -m sparse_retrieval.run dataset.rel_mode="sub_iso" dataset.name="cox2" model.name="NANL" log_level="INFO" plot=0 top_n=5 gpu=1 standardize=0 hashing.hcode_dim=10 seed=42 retrieval="focused" post_process.mlp="mlp_container" post_process.hidden_channels="[64]" dataset.data_type="idx" training.batch_size=10000 post_process.to_mask=1 post_process.regularizer="none" post_process.scoring_style="continuous" post_process.reg_multiplier="v_soft" post_process.reg_lambda=0.0 post_process.z_training="min_ranking_loss" post_process.rl_margin=30.0 post_process.pt_aggr="threshold" post_process.pt_aggr_threshold_stats=True post_process.pt_bit_dist=10 &
The above commands generate data for thresholded retrieval in the case of uniform aggregation, covering each of the Hamming expansions (post_process.pt_bit_dist) used in the paper.
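For intuition, a Hamming expansion of radius d probes not only the query bitcode itself but every bitcode within Hamming distance d of it. A minimal sketch of such an expansion (illustrative only, not the implementation in run.py):

from itertools import combinations

def hamming_expansion(bits, radius):
    # yield every bitcode obtained by flipping at most `radius` positions of `bits`
    n = len(bits)
    for r in range(radius + 1):
        for positions in combinations(range(n), r):
            expanded = list(bits)
            for p in positions:
                expanded[p] = 1 - expanded[p]   # flip the selected bit
            yield tuple(expanded)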
Note: Examples of thresholded retrieval for the remaining configurations follow the same command structure as above.
Contact: pritish at cse dot iitb dot ac dot in
Link to assets (checkpoints, datasets): https://rebrand.ly/corgii