# LibVQ 
This demo shows how to use LibVQ to train a product quantization (PQ) index for [Faiss](https://github.com/facebookresearch/faiss).

## Install
```bash
git clone https://github.com/staoxiao/LibVQ.git
cd LibVQ
pip install .
```

## Overview
There are two modes for training the PQ:
- **LearnableIndex**: train the codebooks with fixed embeddings
- **LearnableIndexWithEncoder**: jointly train the codebooks and query/doc encoder


## Prepare
### Data for LearnableIndex

To train only the index, provide the embeddings and the relevance files:

- `query_embeddings` and `doc_embeddings`: embedding matrices, either as `numpy.array` objects or saved as `.npy` files.
- `rel_data` and `neg_data`: labels in the format `{query_id: [doc1_id, doc2_id, ...]}`, where an id is a row index into the corresponding embedding matrix. If `neg_data` is not provided, negatives are randomly sampled from the corpus.
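As a minimal sketch of these inputs, the snippet below builds toy embedding matrices, saves them as `.npy` files, and writes `rel_data` in the `{query_id: [doc_id, ...]}` format. The helper `sample_random_negatives` is hypothetical: it only illustrates the idea of drawing random negatives from the corpus, and it excludes labeled positives, which LibVQ's own sampler may or may not do.

```python
import os
import tempfile

import numpy as np

# Toy sizes and random vectors, purely for illustration
num_queries, num_docs, emb_dim = 4, 100, 8
rng = np.random.default_rng(0)

# Row i of each matrix is the embedding of query i / doc i
query_embeddings = rng.standard_normal((num_queries, emb_dim)).astype(np.float32)
doc_embeddings = rng.standard_normal((num_docs, emb_dim)).astype(np.float32)

# Either keep the arrays in memory or save them as .npy files
out_dir = tempfile.mkdtemp()
np.save(os.path.join(out_dir, "query_embeddings.npy"), query_embeddings)
np.save(os.path.join(out_dir, "doc_embeddings.npy"), doc_embeddings)
loaded_docs = np.load(os.path.join(out_dir, "doc_embeddings.npy"))

# rel_data: {query_id: [doc_id, ...]}, ids are row indices into the matrices above
rel_data = {0: [3], 1: [10, 42], 2: [7], 3: [99]}


def sample_random_negatives(rel_data, num_docs, per_query_neg_num, seed=0):
    """Hypothetical helper: draw distinct random docs that are not labeled relevant."""
    rng = np.random.default_rng(seed)
    neg_data = {}
    for qid, pos_ids in rel_data.items():
        excluded = set(pos_ids)
        negs = []
        while len(negs) < per_query_neg_num:
            cand = int(rng.integers(num_docs))
            if cand not in excluded:
                excluded.add(cand)
                negs.append(cand)
        neg_data[qid] = negs
    return neg_data


neg_data = sample_random_negatives(rel_data, num_docs, per_query_neg_num=5)
print(neg_data)
```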

For more information, see [Embeddings and Relevance label](https://github.com/staoxiao/LibVQ/blob/master/LibVQ/dataset/README.md).



### Data for LearnableIndexWithEncoder
To train the index and the encoder jointly, you need to provide the raw text data in addition to the embeddings and relevance files.
We use the MSMARCO dataset as an example of the data preprocessing workflow:

**1. Download data**

For other datasets, follow the preparation steps described [here](https://github.com/staoxiao/LibVQ/blob/master/LibVQ/dataset/README.md).

```bash
bash download_data.sh
```

```
--2022-07-02 14:55:02--  https://rocketqa.bj.bcebos.com/corpus/marco.tar.gz
Resolving rocketqa.bj.bcebos.com (rocketqa.bj.bcebos.com)... 103.235.46.61, 2409:8c04:1001:1002:0:ff:b001:368a
Connecting to rocketqa.bj.bcebos.com (rocketqa.bj.bcebos.com)|103.235.46.61|:443... connected.
HTTP request sent, awaiting response... 200 OK
Length: 1140510742 (1.1G) [application/x-gzip]
Saving to: ‘marco.tar.gz’

2022-07-02 14:57:32 (7.36 MB/s) - ‘marco.tar.gz’ saved [1140510742/1140510742]

--2022-07-02 14:59:09--  https://msmarco.blob.core.windows.net/msmarcoranking/collectionandqueries.tar.gz
Resolving msmarco.blob.core.windows.net (msmarco.blob.core.windows.net)... 20.150.34.4
Connecting to msmarco.blob.core.windows.net (msmarco.blob.core.windows.net)|20.150.34.4|:443... connected.
HTTP request sent, awaiting response... 200 OK
Length: 1057717952 (1009M) [application/gzip]
Saving to: ‘collectionandqueries.tar.gz’

2022-07-02 14:59:44 (30.3 MB/s) - ‘collectionandqueries.tar.gz’ saved [105771795
```

### Preprocess data with your tokenizer
`text_tokenizer` should be a Hugging Face Transformers tokenizer, i.e. a subclass of `PreTrainedTokenizer` (or its fast counterpart `PreTrainedTokenizerFast`, which `AutoTokenizer` returns by default).
For example:
- `BertTokenizer.from_pretrained('bert-base-uncased')`
- `BertTokenizer(your_vocab_file)`

Note that `add_cls_tokens=True` adds the special `[CLS]` token to each sequence.

```python
from LibVQ.dataset.preprocess import preprocess_data
from transformers import AutoTokenizer

preprocess_dir = './data/preprocessed_dataset'
embedding_dir = './Results/ARG'
max_query_length = 32
max_doc_length = 256

# Tokenize the raw text and write the preprocessed data to preprocess_dir
preprocess_data(data_dir='./data/dataset/',
                output_dir=preprocess_dir,
                text_tokenizer=AutoTokenizer.from_pretrained('Shitao/msmarco_query_encoder'),
                add_cls_tokens=True,
                max_doc_length=max_doc_length,
                max_query_length=max_query_length,
                workers_num=64)
```

### Generate Embeddings with your encoder
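Wrap your own encoder model in a class that takes a token sequence as input and outputs the sentence embedding:

```python
class YourCustomDEModel:
    def forward(self, input_ids, attention_mask):
        # input_ids, attention_mask: token ids and mask of shape [batch_size, seq_len]
        # returns: sentence embeddings of shape [batch_size, emb_size]
        embeddings = ...  # your encoding logic here
        return embeddings
```
LibVQ provides a simple implementation of such an encoder, `TransformerModel`.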
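To illustrate the `forward(input_ids, attention_mask)` contract, here is a toy encoder that looks up token vectors and mean-pools them over the non-padding positions. It is written in NumPy only to stay self-contained; `MeanPoolingEncoder` is a hypothetical name, and a real encoder used with LibVQ's `Encoder` is presumably a PyTorch module like `TransformerModel`, whose pooling strategy may differ.

```python
import numpy as np


class MeanPoolingEncoder:
    """Hypothetical toy encoder following the forward(input_ids, attention_mask) interface."""

    def __init__(self, vocab_size, emb_size, seed=0):
        rng = np.random.default_rng(seed)
        # Random token embedding table standing in for a trained model
        self.token_embeddings = rng.standard_normal((vocab_size, emb_size)).astype(np.float32)

    def forward(self, input_ids, attention_mask):
        # input_ids, attention_mask: int arrays of shape [batch_size, seq_len]
        token_vecs = self.token_embeddings[input_ids]          # [batch, seq_len, emb_size]
        mask = attention_mask[..., None].astype(np.float32)    # [batch, seq_len, 1]
        summed = (token_vecs * mask).sum(axis=1)               # padding positions contribute 0
        counts = np.clip(mask.sum(axis=1), 1.0, None)          # avoid division by zero
        return summed / counts                                 # [batch, emb_size]


encoder = MeanPoolingEncoder(vocab_size=1000, emb_size=16)
input_ids = np.array([[101, 7, 8, 0], [101, 9, 0, 0]])
attention_mask = np.array([[1, 1, 1, 0], [1, 1, 0, 0]])
embeddings = encoder.forward(input_ids, attention_mask)
print(embeddings.shape)  # (2, 16)
```

Masking before pooling makes the embedding independent of whatever token ids sit in the padded positions.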

```python
import numpy as np
from LibVQ.inference import get_embeddings
from LibVQ.models import Encoder, TransformerModel

# Load the query and doc encoders
query_encoder = TransformerModel.from_pretrained('Shitao/msmarco_query_encoder')
doc_encoder = TransformerModel.from_pretrained('Shitao/msmarco_doc_encoder')
text_encoder = Encoder(query_encoder, doc_encoder)
emb_size = query_encoder.encoder.config.hidden_size

# Generate embeddings for the corpus, the dev queries, and the train queries
doc_embeddings, dev_query, train_query = get_embeddings(data_dir=preprocess_dir,
                                                        encoder=text_encoder,
                                                        max_doc_length=max_doc_length,
                                                        max_query_length=max_query_length,
                                                        output_dir=embedding_dir,
                                                        batch_size=10240)
```

```
start to generate embeddings for corpus
100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 864/864 [51:39<00:00,  3.59s/it]

start to generate embeddings for dev queries
100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 1/1 [00:00<00:00,  1.46it/s]

start to generate embeddings for train queries
100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 79/79 [01:15<00:00,  1.04it/s]
```
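`get_embeddings` also writes the embeddings to `output_dir`; the training cells later in this notebook read `docs.memmap` and `train-queries.memmap` from there. If you need to reopen such a file in a later session, a sketch like the one below should work, assuming the file holds raw row-major `float32` values with no header (an assumption; check the format written by your LibVQ version). `load_memmap_embeddings` is a hypothetical helper.

```python
import os
import tempfile

import numpy as np

emb_size = 8  # in the notebook this comes from query_encoder.encoder.config.hidden_size
path = os.path.join(tempfile.mkdtemp(), "docs.memmap")

# Simulate a file written by get_embeddings (assumed: raw float32, row-major, no header)
vecs = np.random.default_rng(0).standard_normal((5, emb_size)).astype(np.float32)
vecs.tofile(path)


def load_memmap_embeddings(path, emb_size, dtype=np.float32):
    """Hypothetical helper: view a flat embedding file as [num_vectors, emb_size] without reading it all into RAM."""
    flat = np.memmap(path, dtype=dtype, mode="r")
    return flat.reshape(-1, emb_size)


doc_embeddings = load_memmap_embeddings(path, emb_size)
print(doc_embeddings.shape)  # (5, 8)
```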


### Generate data for distillation
If no labeled data is available, you can retrieve the top-k docs for each train query and use them as training data.

You can use either a flat index or an OPQ index (more efficient but less accurate).
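The idea behind this virtual training data can be sketched with an exact (flat) inner-product search, matching `dist_mode='ip'`. How `generate_virtual_traindata` splits each top-k list into relevance labels and hard negatives is not stated here, so the top-1/rest split in the hypothetical `make_virtual_traindata` is an assumption for illustration only.

```python
import numpy as np


def flat_ip_topk(query_embeddings, doc_embeddings, topk):
    """Exact (flat) inner-product search; returns doc ids sorted by descending score."""
    scores = query_embeddings @ doc_embeddings.T                 # [num_queries, num_docs]
    # argpartition selects the top-k without a full sort; then only those k are sorted
    part = np.argpartition(-scores, topk - 1, axis=1)[:, :topk]
    part_scores = np.take_along_axis(scores, part, axis=1)
    order = np.argsort(-part_scores, axis=1)
    return np.take_along_axis(part, order, axis=1)


def make_virtual_traindata(topk_ids):
    """Hypothetical split: top-1 doc as pseudo-positive, the remaining top-k docs as hard negatives."""
    rel_data, neg_data = {}, {}
    for qid, ids in enumerate(topk_ids.tolist()):
        rel_data[qid] = ids[:1]
        neg_data[qid] = ids[1:]
    return rel_data, neg_data


toy_docs = np.eye(4, dtype=np.float32)
toy_queries = np.array([[0.1, 0.9, 0.5, 0.2]], dtype=np.float32)
topk_ids = flat_ip_topk(toy_queries, toy_docs, topk=3)
rel_data, neg_data = make_virtual_traindata(topk_ids)
print(rel_data, neg_data)  # {0: [1]} {0: [2, 3]}
```

A flat search like this is exact but scans every doc; the OPQ option trades some accuracy for speed, as noted above.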

```python
from LibVQ.dataset.preprocess import generate_virtual_traindata

generate_virtual_traindata(doc_embeddings,
                           train_query,
                           output_dir=embedding_dir,
                           use_gpu=False,
                           topk=400,
                           index_method='opq',
                           subvector_num=32,
                           subvector_bits=8,
                           dist_mode='ip')
```

## Training of PQ
Training takes two steps:
- Create an index: `LearnableIndex()`
- Train the index: `LearnableIndex.fit_with_multi_gpus()`

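```python
# Index settings
index_method = 'opq'
ivf_centers_num = -1
subvector_num = 32   # number of sub-vectors each embedding is split into
subvector_bits = 8   # bits per code, i.e. 2**8 = 256 centroids per codebook
```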
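To make `subvector_num` and `subvector_bits` concrete: PQ splits each embedding into `subvector_num` sub-vectors and replaces each one with the index of its nearest centroid in a codebook of `2**subvector_bits` entries. With the settings above, every document is stored as 32 codes of 8 bits, i.e. 32 bytes; for a 768-dimensional `float32` embedding (an assumed size, 3,072 bytes) that is a 96× reduction. The sketch below builds the codebooks with plain k-means per sub-space, which is the classic PQ recipe and not LibVQ's learned training; OPQ additionally applies a learned rotation before splitting, which is omitted here.

```python
import numpy as np


def train_codebooks(x, subvector_num, subvector_bits, iters=10, seed=0):
    """Plain k-means in each sub-space (a stand-in for LibVQ's learned codebooks)."""
    n, dim = x.shape
    assert dim % subvector_num == 0, "embedding size must be divisible by subvector_num"
    sub_dim = dim // subvector_num
    k = 2 ** subvector_bits  # centroids per codebook
    rng = np.random.default_rng(seed)
    codebooks = np.empty((subvector_num, k, sub_dim), dtype=x.dtype)
    for m in range(subvector_num):
        sub = x[:, m * sub_dim:(m + 1) * sub_dim]
        cents = sub[rng.choice(n, size=k, replace=False)].copy()
        for _ in range(iters):
            assign = ((sub[:, None, :] - cents[None]) ** 2).sum(-1).argmin(1)
            for c in range(k):
                members = sub[assign == c]
                if len(members):  # keep the old centroid if the cluster is empty
                    cents[c] = members.mean(0)
        codebooks[m] = cents
    return codebooks


def pq_encode(x, codebooks):
    """Replace each sub-vector with the index of its nearest centroid."""
    subvector_num, _, sub_dim = codebooks.shape
    codes = np.empty((x.shape[0], subvector_num), dtype=np.uint8)  # uint8 fits subvector_bits <= 8
    for m in range(subvector_num):
        sub = x[:, m * sub_dim:(m + 1) * sub_dim]
        codes[:, m] = ((sub[:, None, :] - codebooks[m][None]) ** 2).sum(-1).argmin(1)
    return codes


def pq_decode(codes, codebooks):
    """Approximate the original vectors by concatenating the selected centroids."""
    return np.concatenate([codebooks[m][codes[:, m]] for m in range(codebooks.shape[0])], axis=1)


rng = np.random.default_rng(1)
x = rng.standard_normal((500, 16)).astype(np.float32)
codebooks = train_codebooks(x, subvector_num=4, subvector_bits=4)
codes = pq_encode(x, codebooks)      # shape (500, 4): one small integer per sub-vector
x_hat = pq_decode(codes, codebooks)  # shape (500, 16)
print(codes.shape, x_hat.shape, float(np.mean((x - x_hat) ** 2)))
```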

### Distill-VQ
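The cell below trains with `loss_method='distill'`. The general idea of distillation here is that scores computed with the original dense doc embeddings act as the teacher and scores computed with quantized doc embeddings act as the student, so the virtual training data does not need human labels. Below is a sketch of one common listwise form, the KL divergence between the two softmax distributions over each query's candidate docs; `listwise_distill_loss` is a hypothetical helper, and LibVQ's exact loss, temperature, and candidate set may differ.

```python
import numpy as np


def log_softmax(x, axis=-1):
    x = x - x.max(axis=axis, keepdims=True)  # shift for numerical stability
    return x - np.log(np.exp(x).sum(axis=axis, keepdims=True))


def listwise_distill_loss(query_emb, doc_emb, quantized_doc_emb, temperature=1.0):
    """Mean KL(teacher || student) over each query's candidate docs.

    query_emb: [batch, dim]
    doc_emb, quantized_doc_emb: [batch, num_candidates, dim]
    """
    teacher = np.einsum("bd,bnd->bn", query_emb, doc_emb) / temperature
    student = np.einsum("bd,bnd->bn", query_emb, quantized_doc_emb) / temperature
    log_p_t, log_p_s = log_softmax(teacher), log_softmax(student)
    kl = (np.exp(log_p_t) * (log_p_t - log_p_s)).sum(axis=-1)
    return float(kl.mean())


rng = np.random.default_rng(0)
q = rng.standard_normal((2, 8))
d = rng.standard_normal((2, 5, 8))                    # e.g. one positive + four negatives per query
d_quantized = d + 0.3 * rng.standard_normal(d.shape)  # stand-in for PQ-reconstructed docs
print(listwise_distill_loss(q, d, d), listwise_distill_loss(q, d, d_quantized))
```

The loss is zero when the quantized docs rank the candidates exactly like the dense ones, so minimizing it pushes the codebooks to preserve the teacher's ranking.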

```python
import os

import faiss
from LibVQ.dataset.dataset import load_rel
from LibVQ.learnable_index import LearnableIndex

faiss.omp_set_num_threads(32)

doc_embeddings_file = os.path.join(embedding_dir, 'docs.memmap')
query_embeddings_file = os.path.join(embedding_dir, 'train-queries.memmap')
save_ckpt_dir = './saved_ckpts/distill-VQ/'

# Create the index
learnable_index = LearnableIndex(index_method=index_method,
                                 doc_embeddings=doc_embeddings,
                                 ivf_centers_num=ivf_centers_num,
                                 subvector_num=subvector_num,
                                 subvector_bits=subvector_bits)

# Distill with the generated (virtual) training data
learnable_index.fit_with_multi_gpus(rel_file=os.path.join(embedding_dir, 'train-virtual_rel.tsv'),
                                    neg_file=os.path.join(embedding_dir,
                                                          'train-queries-virtual_hardneg.pickle'),
                                    query_embeddings_file=query_embeddings_file,
                                    doc_embeddings_file=doc_embeddings_file,
                                    emb_size=emb_size,
                                    per_query_neg_num=1,
                                    checkpoint_path=save_ckpt_dir,
                                    logging_steps=100,
                                    per_device_train_batch_size=512,
                                    loss_weight={'encoder_weight': 0.0, 'pq_weight': 1.0,
                                                 'ivf_weight': 0.0},
                                    lr_params={'encoder_lr': 0.0, 'pq_lr': 1e-4, 'ivf_lr': 0.0},
                                    loss_method='distill',
                                    epochs=30)

# Evaluate on the dev queries
ground_truths = load_rel(os.path.join(preprocess_dir, 'dev-rels.tsv'))
learnable_index.test(dev_query, ground_truths, topk=1000, batch_size=64,
                     MRR_cutoffs=[10, 100], Recall_cutoffs=[10, 30, 50, 100])
```

```
dev-rels.tsv: 7437it [00:00, 847153.50it/s]
100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 110/110 [00:31<00:00,  3.49it/s]

number of query:6980,  searching time per query: 0.004520652458114406
6980 matching queries found
MRR@10:0.35032035975803927
MRR@100:0.3615070309827782
Recall@10:0.6261103151862463
Recall@30:0.7776981852913089
Recall@50:0.8307425978987586
Recall@100:0.8865926456542501
```
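For reference, the sketch below shows the standard definitions behind the reported metrics: MRR@k averages the reciprocal rank of the first relevant doc within the top k (0 if none appears), and Recall@k averages the fraction of each query's relevant docs found in the top k. Both helpers are hypothetical, and LibVQ's `test` may define Recall slightly differently (for example, counting a query as a hit if any relevant doc appears).

```python
def mrr_at_k(ranked, ground_truths, k):
    """ranked: {query_id: [doc_id, ...]} sorted by score; ground_truths: {query_id: [relevant doc_id, ...]}."""
    total = 0.0
    for qid, rel_ids in ground_truths.items():
        rel = set(rel_ids)
        for rank, doc_id in enumerate(ranked.get(qid, [])[:k], start=1):
            if doc_id in rel:
                total += 1.0 / rank  # only the first relevant hit counts
                break
    return total / len(ground_truths)


def recall_at_k(ranked, ground_truths, k):
    total = 0.0
    for qid, rel_ids in ground_truths.items():
        rel = set(rel_ids)
        hits = len(rel & set(ranked.get(qid, [])[:k]))
        total += hits / len(rel)
    return total / len(ground_truths)


ranked = {0: [5, 3, 9], 1: [2, 8, 1]}
ground_truths = {0: [3], 1: [7, 1]}
print(mrr_at_k(ranked, ground_truths, 2), recall_at_k(ranked, ground_truths, 3))  # 0.25 0.75
```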


### Distill-VQ with Encoder

```python
import os

from LibVQ.dataset.dataset import load_rel
from LibVQ.learnable_index import LearnableIndexWithEncoder

save_ckpt_dir = './saved_ckpts/distill-VQ-Encoder/'

learnable_index_with_encoder = LearnableIndexWithEncoder(index_method=index_method,
                                                         encoder=text_encoder,
                                                         doc_embeddings=doc_embeddings,
                                                         subvector_num=subvector_num,
                                                         subvector_bits=subvector_bits)

learnable_index_with_encoder.fit(rel_data=os.path.join(embedding_dir, 'train-virtual_rel.tsv'),
                                 neg_data=os.path.join(embedding_dir,
                                                       'train-queries-virtual_hardneg.pickle'),
                                 query_data_dir=preprocess_dir,
                                 max_query_length=max_query_length,
                                 query_embeddings=query_embeddings_file,
                                 doc_embeddings=doc_embeddings_file,
                                 emb_size=emb_size,
                                 per_query_neg_num=1,
                                 checkpoint_path=save_ckpt_dir,
                                 logging_steps=100,
                                 per_device_train_batch_size=512,
                                 loss_weight={'encoder_weight': 1.0, 'pq_weight': 1.0,
                                              'ivf_weight': 0.0},
                                 lr_params={'encoder_lr': 1e-5, 'pq_lr': 1e-4, 'ivf_lr': 0.0},
                                 loss_method='distill',
                                 epochs=10)

# Re-encode the dev queries, because the encoder is updated during training
new_query_embeddings = learnable_index_with_encoder.encode(data_dir=preprocess_dir,
                                                           prefix='dev-queries',
                                                           max_length=max_query_length,
                                                           output_dir='LearnableIndexWithEncoder',
                                                           batch_size=8196,
                                                           is_query=True,
                                                           return_vecs=True)

# Evaluate with the new query embeddings
ground_truths = load_rel(os.path.join(preprocess_dir, 'dev-rels.tsv'))
learnable_index_with_encoder.test(new_query_embeddings, ground_truths, topk=1000, batch_size=64,
                                  MRR_cutoffs=[10, 100], Recall_cutoffs=[10, 30, 50, 100])
```
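The notebook re-encodes only the dev queries after training; at search time a float query embedding is scored against the PQ codes of the documents. The sketch below shows the usual technique for this, asymmetric distance computation (ADC) for inner product: precompute a table of query-centroid dot products for each sub-space, then sum one table entry per sub-space for each doc. It illustrates the technique only, not LibVQ's or Faiss's implementation; `adc_ip_search` is a hypothetical helper and the codebooks and codes are random.

```python
import numpy as np


def adc_ip_search(query, codes, codebooks, topk):
    """Inner-product search over PQ codes with per-sub-space lookup tables (ADC)."""
    subvector_num, _, sub_dim = codebooks.shape
    sub_queries = query.reshape(subvector_num, sub_dim)
    # lut[m, c] = <m-th query sub-vector, c-th centroid of codebook m>
    lut = np.einsum("md,mkd->mk", sub_queries, codebooks)
    # A doc's score is the sum of the table entries selected by its codes
    scores = lut[np.arange(subvector_num), codes].sum(axis=1)
    top = np.argsort(-scores)[:topk]
    return top, scores[top]


rng = np.random.default_rng(0)
codebooks = rng.standard_normal((4, 16, 3))  # 4 sub-spaces, 16 centroids, 3 dims each
codes = rng.integers(0, 16, size=(100, 4))   # PQ codes of 100 docs
query = rng.standard_normal(12)              # float query embedding, 4 * 3 dims
top_ids, top_scores = adc_ip_search(query, codes, codebooks, topk=5)
print(top_ids, top_scores)
```

Each doc costs `subvector_num` table lookups instead of a full dot product, and the result equals the dot product with the doc's PQ-reconstructed vector.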