# LibVQ 
This demo shows how to use LibVQ to train a product quantizer (PQ) for [SPTAG](https://github.com/microsoft/SPTAG).

## Install
```bash
git clone https://github.com/staoxiao/LibVQ.git
cd LibVQ
pip install .
```

## Overview
There are two modes for training the PQ:
- **LearnableIndex**: train the codebooks with fixed embeddings
- **LearnableIndexWithEncoder**: jointly train the codebooks and query/doc encoder
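
Both modes end up with a set of PQ codebooks. As a rough illustration of what a codebook does, here is a minimal pure-Python sketch of PQ encoding and decoding. The codebooks are hand-written toy values and the helper names (`pq_encode`, `pq_decode`) are hypothetical; this is not LibVQ's API, only the general technique of assigning each subvector to its nearest centroid.

```python
# Toy product quantization with hand-written codebooks (illustrative only).

def squared_distance(a, b):
    """Squared Euclidean distance between two equal-length lists."""
    return sum((x - y) ** 2 for x, y in zip(a, b))

def pq_encode(vector, codebooks):
    """Split the vector into len(codebooks) subvectors and return,
    for each one, the id of its nearest centroid."""
    sub_dim = len(vector) // len(codebooks)
    codes = []
    for m, codebook in enumerate(codebooks):
        sub = vector[m * sub_dim:(m + 1) * sub_dim]
        nearest = min(range(len(codebook)),
                      key=lambda k: squared_distance(sub, codebook[k]))
        codes.append(nearest)
    return codes

def pq_decode(codes, codebooks):
    """Rebuild an approximate vector by concatenating the chosen centroids."""
    approx = []
    for code, codebook in zip(codes, codebooks):
        approx.extend(codebook[code])
    return approx

# Two subspaces of dimension 2, each with two centroids (1 bit per subvector).
codebooks = [
    [[0.0, 0.0], [1.0, 1.0]],  # subspace 0
    [[0.0, 1.0], [1.0, 0.0]],  # subspace 1
]
vector = [0.9, 1.2, 0.8, 0.1]

codes = pq_encode(vector, codebooks)
approx = pq_decode(codes, codebooks)
print(codes, approx)
```

Training, in either mode, amounts to adjusting the centroid values so that the approximate vectors preserve retrieval quality; the second mode additionally updates the encoders that produce the vectors.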


## Prepare
### Download data
We take the MSMARCO dataset as an example.

In [1]:
! bash download_data.sh

--2022-06-12 12:41:33--  https://rocketqa.bj.bcebos.com/corpus/marco.tar.gz
Resolving rocketqa.bj.bcebos.com (rocketqa.bj.bcebos.com)... 103.235.46.61, 2409:8c04:1001:1002:0:ff:b001:368a
Connecting to rocketqa.bj.bcebos.com (rocketqa.bj.bcebos.com)|103.235.46.61|:443... connected.
HTTP request sent, awaiting response... 200 OK
Length: 1140510742 (1.1G) [application/x-gzip]
Saving to: ‘marco.tar.gz’


2022-06-12 12:43:20 (10.3 MB/s) - ‘marco.tar.gz’ saved [1140510742/1140510742]

--2022-06-12 12:45:02--  https://msmarco.blob.core.windows.net/msmarcoranking/collectionandqueries.tar.gz
Resolving msmarco.blob.core.windows.net (msmarco.blob.core.windows.net)... 20.150.34.4
Connecting to msmarco.blob.core.windows.net (msmarco.blob.core.windows.net)|20.150.34.4|:443... connected.
HTTP request sent, awaiting response... 200 OK
Length: 1057717952 (1009M) [application/gzip]
Saving to: ‘collectionandqueries.tar.gz’


2022-06-12 12:45:57 (19.0 MB/s) - ‘collectionandqueries.tar.gz’ saved [105771795

### Generate Embeddings  

In [3]:
preprocess_dir = './data/preprocessed_dataset'
embedding_dir = './Results/ARG'
max_query_length = 32
max_doc_length = 256

# preprocess
from LibVQ.dataset.preprocess import preprocess_data
from transformers import AutoTokenizer
preprocess_data(data_dir='./data/dataset/',
                output_dir=preprocess_dir,
                text_tokenizer=AutoTokenizer.from_pretrained('Shitao/msmarco_query_encoder'),
                add_cls_tokens=True,
                max_doc_length=max_doc_length,
                max_query_length=max_query_length,
                workers_num=64)


import numpy as np
from LibVQ.inference import get_embeddings
from LibVQ.models import Encoder, TransformerModel

# Load encoder
query_encoder = TransformerModel.from_pretrained('Shitao/msmarco_query_encoder')
doc_encoder = TransformerModel.from_pretrained('Shitao/msmarco_doc_encoder')
text_encoder = Encoder(query_encoder, doc_encoder)
emb_size = query_encoder.encoder.config.hidden_size

# generate embeddings
doc_embeddings, dev_query, train_query = get_embeddings(data_dir=preprocess_dir,
               encoder=text_encoder,
               max_doc_length=max_doc_length,
               max_query_length=max_query_length,
               output_dir=embedding_dir,
               batch_size=10240)

100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 864/864 [29:10<00:00,  2.03s/it]
100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 1/1 [00:00<00:00,  2.06it/s]
100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 79/79 [01:03<00:00,  1.25it/s]


In [7]:
# convert to the format for SPTAG
import struct, os
from LibVQ.dataset.dataset import load_rel

os.makedirs('input_for_SPTAG', exist_ok=True)
with open(os.path.join('input_for_SPTAG', 'doc_vectors.bin'), 'wb') as f:
    f.write(struct.pack('ii', doc_embeddings.shape[0], doc_embeddings.shape[1] ))
    f.write(doc_embeddings.tobytes())

with open(os.path.join('input_for_SPTAG', 'query_vectors.bin'), 'wb') as f:
    f.write(struct.pack('ii', dev_query.shape[0], dev_query.shape[1] ))
    f.write(dev_query.tobytes())

data = load_rel('./data/preprocessed_dataset/dev-rels.tsv')
max_len = max([len(x) for x in data.values()])
with open('input_for_SPTAG/test_rels.txt', 'w') as f:
    for i in range(len(data)):
        ns = list(data[i])
        ns = ns + [-1]*(max_len-len(ns))
        f.write(' '.join([str(x) for x in ns]) + '\n')

dev-rels.tsv: 7437it [00:00, 784296.46it/s]
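
The vector files written above consist of an 8-byte header (row count and dimension, each packed as a native `int32` by `struct.pack('ii', ...)`) followed by the raw row-major array bytes. The sketch below round-trips a tiny file in that layout using only the standard library. It assumes the embeddings are `float32`, and the helper names `write_vectors` and `read_vectors` are hypothetical, not part of LibVQ or SPTAG.

```python
import os
import struct
import tempfile

def write_vectors(path, rows):
    """Write rows as: int32 count, int32 dim, then row-major float32 values."""
    dim = len(rows[0])
    with open(path, 'wb') as f:
        f.write(struct.pack('ii', len(rows), dim))
        for row in rows:
            f.write(struct.pack(f'{dim}f', *row))

def read_vectors(path):
    """Read the header, then read `count` rows of `dim` float32 values."""
    with open(path, 'rb') as f:
        count, dim = struct.unpack('ii', f.read(8))
        rows = [list(struct.unpack(f'{dim}f', f.read(4 * dim)))
                for _ in range(count)]
    return count, dim, rows

# Values chosen to be exactly representable in float32.
rows = [[0.5, -1.0, 2.0], [1.5, 0.25, -3.0]]

with tempfile.TemporaryDirectory() as tmp:
    path = os.path.join(tmp, 'toy_vectors.bin')
    write_vectors(path, rows)
    file_size = os.path.getsize(path)  # 8-byte header + 2 * 3 * 4 data bytes
    count, dim, restored = read_vectors(path)

print(count, dim, file_size)
```

Reading back the header of `doc_vectors.bin` in the same way is a quick check that the row count and dimension match `doc_embeddings.shape` before handing the file to SPTAG.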


## Training of PQ

In [1]:
# settings:
index_method = 'opq'
ivf_centers_num = -1
subvector_num = 32
subvector_bits = 8
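
To make these settings concrete, the following sketch works out the size of a PQ code. The value `emb_size = 768` is an assumption for illustration (in the notebook it comes from `query_encoder.encoder.config.hidden_size`), as is the `float32` storage of the uncompressed embeddings.

```python
subvector_num = 32   # number of subvectors per embedding
subvector_bits = 8   # bits used to index a centroid in each subspace
emb_size = 768       # assumed embedding dimension, for illustration only

centroids_per_subspace = 2 ** subvector_bits      # 256 centroids per codebook
subvector_dim = emb_size // subvector_num         # 24 dimensions per subvector
code_bytes = subvector_num * subvector_bits // 8  # 32 bytes per compressed doc
raw_bytes = emb_size * 4                          # 3072 bytes as float32
compression_ratio = raw_bytes / code_bytes        # 96.0

print(centroids_per_subspace, subvector_dim, code_bytes, compression_ratio)
```

Under these assumptions each document is stored in 32 bytes instead of 3072, which is why the quality of the codebooks matters so much for recall.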

### Faiss PQ

In [5]:
import faiss
from LibVQ.utils import save_to_SPTAG_binary_file
from LibVQ.base_index import FaissIndex


# Create Faiss index
faiss.omp_set_num_threads(32)
index = FaissIndex(index_method=index_method,
                   emb_size=len(doc_embeddings[0]),
                   ivf_centers_num=ivf_centers_num,
                   subvector_num=subvector_num,
                   subvector_bits=subvector_bits,
                   dist_mode='ip')

index.fit(doc_embeddings)
index.add(doc_embeddings)

os.makedirs('input_for_SPTAG/OPQ', exist_ok=True)
save_to_SPTAG_binary_file(index, save_dir='input_for_SPTAG/OPQ')

### Distill-VQ

In [4]:
import os
import pickle
import gc
from torch.optim import AdamW
from LibVQ.dataset.dataset import load_rel, write_rel
from LibVQ.learnable_index import LearnableIndex

faiss.omp_set_num_threads(32)

doc_embeddings_file = os.path.join(embedding_dir, 'docs.memmap')
query_embeddings_file = os.path.join(embedding_dir, 'train-queries.memmap')
init_index_file = os.path.join(embedding_dir, f'{index_method}_ivf{ivf_centers_num}_pq{subvector_num}x{subvector_bits}.index')
save_ckpt_dir = f'./saved_ckpts/distill-VQ/'

# Load embeddings of train queries
train_query_embeddings = np.memmap(query_embeddings_file, dtype=np.float32, mode="r")
train_query_embeddings = train_query_embeddings.reshape(-1, emb_size)


# Create Index
learnable_index = LearnableIndex(index_method=index_method,
                                 init_index_file=init_index_file,
                                 doc_embeddings=doc_embeddings,
                                 ivf_centers_num=ivf_centers_num,
                                 subvector_num=subvector_num,
                                 subvector_bits=subvector_bits)

    
# If there is no relevance data, you can set rel_file/rel_data to None and
# the training data will be generated automatically.
# You can also generate the data manually, as follows:
if not os.path.exists(os.path.join(embedding_dir, 'train-virtual_rel.tsv')):
    from LibVQ.dataset.preprocess import generate_virtual_traindata
    generate_virtual_traindata(
            doc_embeddings,
            train_query,
            output_dir = embedding_dir,
            use_gpu=False,
            topk= 400,
            index_method = 'opq',
            subvector_num=32,
            subvector_bits=8,
            dist_mode='ip')

# distill with no label data
learnable_index.fit_with_multi_gpus(rel_file=os.path.join(embedding_dir, 'train-virtual_rel.tsv'),
                                    neg_file=os.path.join(embedding_dir,
                                                          f"train-queries-virtual_hardneg.pickle"),
                                    query_embeddings_file=query_embeddings_file,
                                    doc_embeddings_file=doc_embeddings_file,
                                    emb_size=emb_size,
                                    per_query_neg_num=1,
                                    checkpoint_path=save_ckpt_dir,
                                    logging_steps=100,
                                    per_device_train_batch_size=512,
                                    loss_weight={'encoder_weight': 0.0, 'pq_weight': 1.0,
                                                 'ivf_weight': 0.0},
                                    lr_params={'encoder_lr': 0.0, 'pq_lr': 1e-4, 'ivf_lr': 0.0},
                                    loss_method='distill',
                                    epochs=30)


os.makedirs('input_for_SPTAG/LearnableIndex', exist_ok=True)
save_to_SPTAG_binary_file(learnable_index, save_dir='input_for_SPTAG/LearnableIndex')

### Distill-VQ with Encoder

In [18]:
from LibVQ.learnable_index import LearnableIndexWithEncoder

from transformers import BertModel


save_ckpt_dir = f'./saved_ckpts/distill-VQ-Encoder/'

learnable_index_with_encoder = LearnableIndexWithEncoder(index_method=index_method,
                                 encoder=text_encoder,
                                 init_index_file=init_index_file,
                                 doc_embeddings=doc_embeddings,
                                 ivf_centers_num=ivf_centers_num,
                                 subvector_num=subvector_num,
                                 subvector_bits=subvector_bits)


learnable_index_with_encoder.fit(rel_data=os.path.join(embedding_dir, 'train-virtual_rel.tsv'),
                                neg_data=os.path.join(embedding_dir,
                                                      f"train-queries-virtual_hardneg.pickle"),
                                query_data_dir=preprocess_dir,
                                max_query_length=max_query_length,
                                query_embeddings=query_embeddings_file,
                                doc_embeddings=doc_embeddings_file,
                                emb_size=emb_size,
                                per_query_neg_num=1,
                                checkpoint_path=save_ckpt_dir,
                                logging_steps=100,
                                per_device_train_batch_size=512,
                                loss_weight={'encoder_weight': 1.0, 'pq_weight': 1.0,
                                             'ivf_weight': 0.0},
                                lr_params={'encoder_lr': 1e-5, 'pq_lr': 1e-4, 'ivf_lr': 0.0},
                                loss_method='distill',
                                epochs=10)

os.makedirs('input_for_SPTAG/LearnableIndexWithEncoder', exist_ok=True)

new_query_embeddings = learnable_index_with_encoder.encode(data_dir=preprocess_dir,
                                              prefix='dev-queries',
                                              max_length=max_query_length,
                                              output_dir='input_for_SPTAG/LearnableIndexWithEncoder',
                                              batch_size=8196,
                                              is_query=True,
                                              return_vecs=True
                                              )
    

save_to_SPTAG_binary_file(learnable_index_with_encoder, save_dir='input_for_SPTAG/LearnableIndexWithEncoder')
with open(os.path.join('input_for_SPTAG/LearnableIndexWithEncoder', 'new_query.bin'), 'wb') as f:
    f.write(struct.pack('ii', new_query_embeddings.shape[0], new_query_embeddings.shape[1] ))
    f.write(new_query_embeddings.tobytes())

# Use PQ in SPTAG Index

## Create SPTAG index
- Install SPTAG by following the instructions in its [repository](https://github.com/microsoft/SPTAG).
- Update the paths in create.ini:
  - VectorPath: the path to the quantized doc vectors, e.g., ./input_for_SPTAG/LearnableIndex/quantized__vectors.bin
  - QuantizerFilePath: the path to the quantizer parameters (rotation matrix and codebooks), e.g., ./input_for_SPTAG/LearnableIndex/index_parameters.bin
  - IndexDirectory: the directory in which to save the index
- Run: ./ssdserving create.ini

## Search via SPTAG
- Update the paths in search.ini:
  - VectorPath: defaults to None. If set to the path of the uncompressed doc embeddings, the index reranks candidates using the uncompressed embeddings.
  - QuantizerFilePath: the path to the quantizer parameters (rotation matrix and codebooks), e.g., ./input_for_SPTAG/LearnableIndex/index_parameters.bin
  - IndexDirectory: the directory where the index was saved
  - QueryPath/WarmupPath: the path to the query embeddings. Note that if you use LearnableIndexWithEncoder, you must use the new query embeddings, e.g., ./input_for_SPTAG/LearnableIndexWithEncoder/new_query.bin
- Run: ./ssdserving search.ini


## Results

Methods | Recall@10 |
------- | ------- |
OPQ | 0.541595 |
LearnableIndex | 0.559169 |
LearnableIndexWithEncoder | 0.580540 |
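
For reference, here is a minimal sketch of how a Recall@10 figure can be computed from a ground-truth file padded with `-1`, like the `test_rels.txt` written earlier. It uses one common definition (the fraction of a query's relevant documents found in its top-k results, averaged over queries); the exact metric SPTAG reports may be defined differently, and the function names and toy data are hypothetical.

```python
def parse_rels_line(line):
    """Parse one line of a -1 padded relevance file into a set of doc ids."""
    return {int(x) for x in line.split() if int(x) != -1}

def recall_at_k(retrieved, relevant, k=10):
    """Average over queries of |top-k results ∩ relevant| / |relevant|.
    Queries without any relevant document are skipped."""
    scores = []
    for top, rel in zip(retrieved, relevant):
        if not rel:
            continue
        hits = len(set(top[:k]) & rel)
        scores.append(hits / len(rel))
    return sum(scores) / len(scores)

# Toy ground truth in the same padded layout as test_rels.txt.
rels_lines = ["3 7 -1", "5 -1 -1"]
relevant = [parse_rels_line(line) for line in rels_lines]

# Toy search results: ranked doc ids for each query.
retrieved = [[7, 1, 2], [9, 8, 4]]

score = recall_at_k(retrieved, relevant, k=10)
print(score)
```

In this toy case the first query finds one of its two relevant documents and the second finds none, so the averaged score is 0.25.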