
CoRGII: Contextual Tokenization for Graph Inverted Indices

This README provides details on installation, training, inference, and retrieval, along with a basic walkthrough of the code.

Requirements

For each of the baselines (FAISS, DiskANN), please consult the respective repository for installation instructions and requirements. To set up a bare-bones environment for $\text{CoRGII}$, follow the instructions given in requirements.txt in the main directory. You may modify individual requirements to suit your system configuration.

Note that GPU access is required for timely training and evaluation of $\text{CoRGII}$; FAISS and DiskANN do not require a GPU.
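As a quick sanity check before launching any of the commands below, the standard PyTorch device-selection idiom applies (a generic snippet, not a function from this repository):

import torch

# Prefer the GPU when available; CoRGII training assumes one is present.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Running on: {device}")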

Code Walkthrough

Below is an overview of the sparse_retrieval module of the codebase.

sparse_retrieval/
├── common.py
├── embed_dump.py
├── __init__.py
├── retr_utils.py
├── run.py
├── sparsifiers
│   ├── base.py
│   ├── __init__.py
│   └── neural.py
└── training
    ├── base.py
    ├── __init__.py
    ├── models.py
    ├── score_postings_list.py
    ├── train_postings_list.py
    └── train.py

common.py: This file contains utility functions that are used in multiple places throughout the module, including code to set the internal task name and the logfile name, model initialization, and generation of the transport plans for the pre-trained backbone.
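For intuition on the transport plans mentioned above, here is a minimal Sinkhorn-style sketch in PyTorch. The function name, signature, and the use of plain log-space Sinkhorn iterations are illustrative assumptions, not the actual implementation in common.py:

import torch

def sinkhorn_transport_plan(scores: torch.Tensor, n_iters: int = 20,
                            temperature: float = 1.0) -> torch.Tensor:
    # Hypothetical sketch: turn an (n x n) node-pair score matrix into a
    # soft transport plan by alternating row/column normalization in log space.
    log_plan = scores / temperature
    for _ in range(n_iters):
        log_plan = log_plan - torch.logsumexp(log_plan, dim=1, keepdim=True)  # rows
        log_plan = log_plan - torch.logsumexp(log_plan, dim=0, keepdim=True)  # columns
    return log_plan.exp()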

embed_dump.py: This file contains code relevant to the pre-trained backbone and to generating the continuous embeddings.
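A hedged sketch of what such an embedding dump might look like (all names here are hypothetical; see embed_dump.py for the actual code):

import torch

def dump_embeddings(model, loader, path, device="cuda"):
    # Run the frozen pre-trained backbone over a dataset and save the
    # continuous embeddings for later bitcode generation.
    model.eval()
    chunks = []
    with torch.no_grad():
        for batch in loader:
            chunks.append(model(batch.to(device)).cpu())
    torch.save(torch.cat(chunks, dim=0), path)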

retr_utils.py: This file contains utilities for retrieval, specifically for $\text{S}_{\text{unif}}$ and its variations, including utility functions for thresholded retrieval as used in, e.g., Figure 7 of the paper.

run.py: The entry file for performing thresholded retrieval for $\text{S}_{\text{unif}}$. Use this to generate data for uniform aggregation.

sparsifiers/base.py: This contains the base class for inverted indexing and for probing the index. The index generation process is shared by $\text{S}_{\text{unif}}$ and $\text{S}_{\text{impact}}$ from the paper; the probing/searching code in this file is relevant to $\text{S}_{\text{unif}}$ only.
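To make the indexing/probing split concrete, below is a minimal sketch of an inverted index over discrete tokens, with a probe that scores corpus graphs by the number of shared query tokens (uniform aggregation in the spirit of $\text{S}_{\text{unif}}$). All names are illustrative; the repository's base class is considerably richer:

from collections import Counter, defaultdict

def build_inverted_index(corpus_tokens):
    # corpus_tokens: one set of tokens per corpus graph.
    # Returns token -> posting list of corpus graph ids.
    index = defaultdict(list)
    for gid, tokens in enumerate(corpus_tokens):
        for tok in tokens:
            index[tok].append(gid)
    return index

def probe_uniform(index, query_tokens, threshold=1):
    # Count, for each corpus graph, how many query tokens hit it,
    # and keep graphs whose hit count reaches the threshold.
    hits = Counter()
    for tok in query_tokens:
        for gid in index.get(tok, []):
            hits[gid] += 1
    return {gid: c for gid, c in hits.items() if c >= threshold}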

sparsifiers/neural.py: Contains the subclass for neural indexing, i.e., functions for generating corpus and query bitcodes using our neural models.
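For intuition, turning continuous embeddings into bitcodes can be as simple as sign thresholding followed by integer packing. The sketch below is a stand-in for the neural sparsifier's output stage, not its actual forward pass:

import torch

def to_bitcodes(embeddings):
    # Map continuous embeddings (n x d) to {0, 1} bitcodes (n x d).
    return (embeddings > 0).to(torch.uint8)

def bitcodes_to_tokens(bitcodes):
    # Pack each d-bit row into one integer token for inverted indexing.
    d = bitcodes.shape[1]
    weights = 2 ** torch.arange(d, dtype=torch.long)
    return (bitcodes.long() * weights).sum(dim=1).tolist()

With hashing.hcode_dim=10 as in the commands below, each token is one of at most 2^10 = 1024 values.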

training/base.py: Contains the base class with code common to training the $\text{GTNet}$ and $\text{Impact}_{\psi}$ networks. Includes utilities for scoring, validation, early stopping, etc.
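As an example of one such utility, here is a minimal patience-based early stopper, a generic pattern analogous to the training.patience=30 option used in the commands below (not the repository's exact class):

class EarlyStopper:
    def __init__(self, patience=30):
        self.patience = patience
        self.best = float("-inf")
        self.bad_epochs = 0

    def step(self, val_score):
        # Returns True when validation has not improved for `patience` epochs.
        if val_score > self.best:
            self.best = val_score
            self.bad_epochs = 0
        else:
            self.bad_epochs += 1
        return self.bad_epochs >= self.patience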

training/train_postings_list.py: Contains the subclass specific to training the $\text{Impact}_{\psi}$ network. Provided certain dump files (the bitcodes and the posting lists) are available, use this to train for $\text{S}_{\text{impact}}$ with various margins.

training/score_postings_list.py: A separate submodule for the score functions used to train $\text{S}_{\text{impact}}$ schemes, including, e.g., co-occurrence multiprobe training.
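As a loose illustration of co-occurrence statistics in this setting, the sketch below counts how often a query token and a corpus token appear together across known-relevant pairs; this is a hypothetical stand-in, not the repository's scoring function:

from collections import Counter

def cooccurrence_counts(query_tokens, corpus_tokens, relevant_pairs):
    # query_tokens / corpus_tokens: one token set per graph.
    # relevant_pairs: iterable of (query_id, corpus_id) relevance pairs.
    counts = Counter()
    for qid, cid in relevant_pairs:
        for qt in query_tokens[qid]:
            for ct in corpus_tokens[cid]:
                counts[(qt, ct)] += 1
    return counts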

training/train.py: Contains the subclass specific to training $\text{GTNet}$. Includes facilities for scalable, end-to-end training as dictated by the $\text{CoRGII}$ training protocol. Use this to train $\text{GTNet}$, which can then be used in the subsequent training stages.
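The post_process.z_training=min_ranking_loss and post_process.rl_margin overrides in the commands below suggest a margin-based ranking objective. A generic sketch of such a loss (an assumption for illustration, not the repository's exact loss):

import torch.nn.functional as F

def margin_ranking_loss(pos_scores, neg_scores, margin=10.0):
    # Penalize relevant/irrelevant score pairs whose gap is below the
    # margin; rl_margin values of 10.0 and 30.0 appear in the commands below.
    return F.relu(margin + neg_scores - pos_scores).mean()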

Further, below is an overview of the utilities used to generate plots for the paper.

main_notebooks/
├── __init__.py
├── notebook_eval_funcs.py
├── notebook_functions.py
└── notebook_plot_funcs.py

notebook_eval_funcs.py: Contains functions relevant to evaluating and generating data for $\text{S}_{\text{impact}}$ retrieval schemes. This includes both Hamming multiprobe and co-occurrence multiprobe.
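For reference, Hamming multiprobe expands a query bitcode into all buckets within a given bit distance (cf. the post_process.pt_bit_dist overrides in the retrieval commands below). A minimal enumeration sketch, assuming integer-packed tokens:

from itertools import combinations

def hamming_expand(token, n_bits, max_dist):
    # Yield every token within Hamming distance max_dist of `token`,
    # including the token itself (distance 0).
    for dist in range(max_dist + 1):
        for positions in combinations(range(n_bits), dist):
            flipped = token
            for pos in positions:
                flipped ^= 1 << pos  # flip one bit
            yield flipped

With n_bits=10 and max_dist=2, each query token expands into 1 + 10 + 45 = 56 probes.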

notebook_functions.py: Contains functions that are common to plotting and generation. This includes fetching dumped data from the pre-trained backbone, configuration management, and data loading for $\text{S}_{\text{unif}}$.

notebook_plot_funcs.py: Contains utilities for plotting the figures in the main and supplementary material, as well as their legends.

These modules are typically imported into the provided notebooks via IPython's autoreload extension.
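That is, each notebook typically begins with something like the following (the import line is illustrative):

%load_ext autoreload
%autoreload 2

import notebook_eval_funcs, notebook_functions, notebook_plot_funcs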

All the modules listed above import from internal utility modules in the ./utils directory, including utils.utils.py, utils.model_utils.py, utils.data_utils.py, utils.training_utils.py, and utils.dataset_loader.py.

Command Examples

Training $\text{GTNet}$

The following commands can be used to train $\text{GTNet}$.

python -m sparse_retrieval.training.train dataset.rel_mode=sub_iso dataset.name=ptc_fr model.name=NANL log_level=INFO hashing.hcode_dim=10 seed=42 post_process.mlp=nanl_container post_process.hidden_channels=[64] dataset.data_type=gmn_1d training.batch_size=50000 post_process.to_mask=1 post_process.regularizer=none post_process.scoring_style=continuous post_process.reg_multiplier=none post_process.reg_lambda=0.0 post_process.z_training=min_ranking_loss post_process.rl_margin=10.0 task.wandb_project=GHASH_SPARSE_12thMarch training.wandb_watch=True training.patience=30 training.scalable=True &

python -m sparse_retrieval.training.train dataset.rel_mode=sub_iso dataset.name=ptc_fr model.name=NANL log_level=INFO hashing.hcode_dim=10 seed=42 post_process.mlp=nanl_container post_process.hidden_channels=[64] dataset.data_type=gmn_1d training.batch_size=50000 post_process.to_mask=1 post_process.regularizer=none post_process.scoring_style=continuous post_process.reg_multiplier=none post_process.reg_lambda=0.0 post_process.z_training=min_ranking_loss post_process.rl_margin=30.0 task.wandb_project=GHASH_SPARSE_12thMarch training.wandb_watch=True training.patience=30 training.scalable=True &

Change the dataset name to one of ptc_fm, ptc_mr, or cox2 to train on the other datasets.

Training $\text{Impact}_{\psi}$ for $\text{S}_{\text{impact}}$

python -m sparse_retrieval.training.train_postings_list dataset.rel_mode=sub_iso dataset.name=ptc_fr model.name=NANL log_level=INFO hashing.hcode_dim=10 seed=42 post_process.mlp=nanl_container post_process.hidden_channels=[64] dataset.data_type=idx training.batch_size=3000 post_process.to_mask=1 post_process.regularizer=none post_process.scoring_style=continuous post_process.reg_multiplier=none post_process.reg_lambda=0.0 post_process.z_training=min_ranking_loss post_process.rl_margin=10.0 post_process.pl_training=one_one post_process.pl_margin=0.01 post_process.pl_hidden_dim=64 task.wandb_project=GHASH_SPARSE_PLTRAIN training.wandb_watch=True &

python -m sparse_retrieval.training.train_postings_list dataset.rel_mode=sub_iso dataset.name=ptc_fr model.name=NANL log_level=INFO hashing.hcode_dim=10 seed=42 post_process.mlp=nanl_container post_process.hidden_channels=[64] dataset.data_type=idx training.batch_size=3000 post_process.to_mask=1 post_process.regularizer=none post_process.scoring_style=continuous post_process.reg_multiplier=none post_process.reg_lambda=0.0 post_process.z_training=min_ranking_loss post_process.rl_margin=10.0 post_process.pl_training=one_one post_process.pl_margin=0.1 post_process.pl_hidden_dim=64 task.wandb_project=GHASH_SPARSE_PLTRAIN training.wandb_watch=True &

python -m sparse_retrieval.training.train_postings_list dataset.rel_mode=sub_iso dataset.name=ptc_fr model.name=NANL log_level=INFO hashing.hcode_dim=10 seed=42 post_process.mlp=nanl_container post_process.hidden_channels=[64] dataset.data_type=idx training.batch_size=3000 post_process.to_mask=1 post_process.regularizer=none post_process.scoring_style=continuous post_process.reg_multiplier=none post_process.reg_lambda=0.0 post_process.z_training=min_ranking_loss post_process.rl_margin=10.0 post_process.pl_training=one_one post_process.pl_margin=1.0 post_process.pl_hidden_dim=64 task.wandb_project=GHASH_SPARSE_PLTRAIN training.wandb_watch=True &

The above commands train with three different impact margins on ptc_fr. Modify other hyperparameters, such as the dataset name, accordingly.

Thresholded retrieval for $\text{S}_{\text{unif}}$

python -m sparse_retrieval.run dataset.rel_mode="sub_iso" dataset.name="cox2" model.name="NANL" log_level="INFO" plot=0 top_n=5 gpu=1 standardize=0 hashing.hcode_dim=10 seed=42 retrieval="focused" post_process.mlp="mlp_container" post_process.hidden_channels="[64]" dataset.data_type="idx" training.batch_size=10000 post_process.to_mask=1 post_process.regularizer="none" post_process.scoring_style="continuous" post_process.reg_multiplier="v_soft" post_process.reg_lambda=0.0 post_process.z_training="min_ranking_loss" post_process.rl_margin=30.0 post_process.pt_aggr="threshold" post_process.pt_aggr_threshold_stats=True &

python -m sparse_retrieval.run dataset.rel_mode="sub_iso" dataset.name="cox2" model.name="NANL" log_level="INFO" plot=0 top_n=5 gpu=1 standardize=0 hashing.hcode_dim=10 seed=42 retrieval="focused" post_process.mlp="mlp_container" post_process.hidden_channels="[64]" dataset.data_type="idx" training.batch_size=10000 post_process.to_mask=1 post_process.regularizer="none" post_process.scoring_style="continuous" post_process.reg_multiplier="v_soft" post_process.reg_lambda=0.0 post_process.z_training="min_ranking_loss" post_process.rl_margin=30.0 post_process.pt_aggr="threshold" post_process.pt_aggr_threshold_stats=True post_process.pt_bit_dist=1 &

python -m sparse_retrieval.run dataset.rel_mode="sub_iso" dataset.name="cox2" model.name="NANL" log_level="INFO" plot=0 top_n=5 gpu=1 standardize=0 hashing.hcode_dim=10 seed=42 retrieval="focused" post_process.mlp="mlp_container" post_process.hidden_channels="[64]" dataset.data_type="idx" training.batch_size=10000 post_process.to_mask=1 post_process.regularizer="none" post_process.scoring_style="continuous" post_process.reg_multiplier="v_soft" post_process.reg_lambda=0.0 post_process.z_training="min_ranking_loss" post_process.rl_margin=30.0 post_process.pt_aggr="threshold" post_process.pt_aggr_threshold_stats=True post_process.pt_bit_dist=2 &

python -m sparse_retrieval.run dataset.rel_mode="sub_iso" dataset.name="cox2" model.name="NANL" log_level="INFO" plot=0 top_n=5 gpu=1 standardize=0 hashing.hcode_dim=10 seed=42 retrieval="focused" post_process.mlp="mlp_container" post_process.hidden_channels="[64]" dataset.data_type="idx" training.batch_size=10000 post_process.to_mask=1 post_process.regularizer="none" post_process.scoring_style="continuous" post_process.reg_multiplier="v_soft" post_process.reg_lambda=0.0 post_process.z_training="min_ranking_loss" post_process.rl_margin=30.0 post_process.pt_aggr="threshold" post_process.pt_aggr_threshold_stats=True post_process.pt_bit_dist=3 &

python -m sparse_retrieval.run dataset.rel_mode="sub_iso" dataset.name="cox2" model.name="NANL" log_level="INFO" plot=0 top_n=5 gpu=1 standardize=0 hashing.hcode_dim=10 seed=42 retrieval="focused" post_process.mlp="mlp_container" post_process.hidden_channels="[64]" dataset.data_type="idx" training.batch_size=10000 post_process.to_mask=1 post_process.regularizer="none" post_process.scoring_style="continuous" post_process.reg_multiplier="v_soft" post_process.reg_lambda=0.0 post_process.z_training="min_ranking_loss" post_process.rl_margin=30.0 post_process.pt_aggr="threshold" post_process.pt_aggr_threshold_stats=True post_process.pt_bit_dist=5 &

python -m sparse_retrieval.run dataset.rel_mode="sub_iso" dataset.name="cox2" model.name="NANL" log_level="INFO" plot=0 top_n=5 gpu=1 standardize=0 hashing.hcode_dim=10 seed=42 retrieval="focused" post_process.mlp="mlp_container" post_process.hidden_channels="[64]" dataset.data_type="idx" training.batch_size=10000 post_process.to_mask=1 post_process.regularizer="none" post_process.scoring_style="continuous" post_process.reg_multiplier="v_soft" post_process.reg_lambda=0.0 post_process.z_training="min_ranking_loss" post_process.rl_margin=30.0 post_process.pt_aggr="threshold" post_process.pt_aggr_threshold_stats=True post_process.pt_bit_dist=7 &

python -m sparse_retrieval.run dataset.rel_mode="sub_iso" dataset.name="cox2" model.name="NANL" log_level="INFO" plot=0 top_n=5 gpu=1 standardize=0 hashing.hcode_dim=10 seed=42 retrieval="focused" post_process.mlp="mlp_container" post_process.hidden_channels="[64]" dataset.data_type="idx" training.batch_size=10000 post_process.to_mask=1 post_process.regularizer="none" post_process.scoring_style="continuous" post_process.reg_multiplier="v_soft" post_process.reg_lambda=0.0 post_process.z_training="min_ranking_loss" post_process.rl_margin=30.0 post_process.pt_aggr="threshold" post_process.pt_aggr_threshold_stats=True post_process.pt_bit_dist=10 &

The above commands generate data for thresholded retrieval under uniform aggregation, covering each of the Hamming expansions from the paper.

Note: Examples of thresholded retrieval for $\text{S}_{\text{impact}}$ are included in the notebooks. Further, each of the baseline evaluations is detailed in its own notebook.

Miscellaneous

Contact: pritish at cse dot iitb dot ac dot in

Link to assets (checkpoints, datasets): https://rebrand.ly/corgii
