In [1]:
# Colab-only setup: mount Google Drive under /content/drive.
from google.colab import drive
drive.mount('/content/drive')
# Switch the working directory to the project folder on Drive.
# NOTE(review): the `dessert_pytorch` import in the next cell is presumably
# resolved from this folder - confirm.
%cd /content/drive/MyDrive/dessert/

Mounted at /content/drive
/content/drive/MyDrive/dessert


In [2]:
from dessert_pytorch import *

In [3]:
import numpy as np
import random
import torch

def get_build_and_run_functions_random(num_docs=100, num_queries=100):
    """Generate a synthetic retrieval benchmark and return ``(index_func, query_func)``.

    ``index_func`` takes no arguments and returns a document retrieval index
    with all generated documents added. ``query_func`` takes a document
    retrieval index, queries it with the generated queries, asserts that the
    top result is as expected, and returns all results.

    The general idea for this test is that each word is a normal distribution
    somewhere in the vector space. A doc is made up of one vector from each of
    ``words_per_doc`` word distributions. A ground-truth query is made up of
    some words from a single doc's word distributions and some random words.

    Args:
        num_docs: Number of synthetic documents to generate.
        num_queries: Number of queries to generate. Query ``i`` is built from
            the words of doc ``i``, so this must not exceed ``num_docs``.

    Returns:
        A tuple ``(index_func, query_func)`` as described above.

    Raises:
        ValueError: If ``num_queries`` is greater than ``num_docs``.
    """
    # Query i reuses words from doc i; without this check the mismatch only
    # surfaced later as an opaque IndexError while assembling the queries.
    if num_queries > num_docs:
        raise ValueError(
            f"num_queries ({num_queries}) must not exceed num_docs ({num_docs}): "
            "the ground truth for query i is doc i"
        )

    hashes_per_table = 7
    num_tables = 32
    data_dim = 100
    vocab_size = 10000
    words_per_doc = 200
    words_per_query_random = 5
    words_per_query_from_doc = 10
    words_per_query = words_per_query_random + words_per_query_from_doc
    between_word_std = 1
    within_word_std = 0.1

    # The global RNGs are seeded (as before) so the generated data is
    # reproducible. The order of the draws below must not change, otherwise
    # the generated data changes.
    np.random.seed(42)
    random.seed(42)

    # Generate word centers
    word_centers = np.random.normal(size=(vocab_size, data_dim), scale=between_word_std)

    # Generate docs: every word vector is its word's center plus a small offset.
    doc_word_ids = [
        random.sample(range(vocab_size), words_per_doc) for _ in range(num_docs)
    ]
    doc_offsets = np.random.normal(
        size=(num_docs, words_per_doc, data_dim), scale=within_word_std
    )
    # The explicit dtype + reshape keeps the gather valid when num_docs == 0.
    doc_id_matrix = np.asarray(doc_word_ids, dtype=np.intp).reshape(
        num_docs, words_per_doc
    )
    # One (words_per_doc, data_dim) array per doc. Building each doc as a single
    # ndarray (instead of a Python list of 1-D arrays) lets torch.tensor()
    # convert it without the slow list-of-ndarrays path and its warning.
    docs = list(doc_offsets + word_centers[doc_id_matrix])

    # Generate queries. GT for query i is doc i: the first
    # words_per_query_from_doc words come from doc i, the rest are random.
    query_random_word_ids = [
        random.sample(range(vocab_size), words_per_query_random)
        for _ in range(num_queries)
    ]
    query_word_ids = [
        doc_ids[:words_per_query_from_doc] + random_ids
        for doc_ids, random_ids in zip(doc_word_ids[:num_queries], query_random_word_ids)
    ]
    query_offsets = np.random.normal(
        size=(num_queries, words_per_query, data_dim), scale=within_word_std
    )
    query_id_matrix = np.asarray(query_word_ids, dtype=np.intp).reshape(
        num_queries, words_per_query
    )
    queries = list(query_offsets + word_centers[query_id_matrix])

    def index_func():
        # Builds a fresh index containing every generated doc.
        return _build_index_random(
            docs, hashes_per_table, num_tables, data_dim, word_centers, words_per_doc
        )

    def query_func(index):
        # Runs every generated query against `index` and checks the top hit.
        return _do_queries_random(index, queries, num_docs)

    return index_func, query_func


def _build_index_random(docs, hashes_per_table, num_tables, data_dim, centroids, words_per_doc):
    index = DocRetrieval(
        hashes_per_table=hashes_per_table,
        num_tables=num_tables,
        dense_input_dimension=data_dim,
        centroids=torch.tensor(centroids),
        max_doc_size=words_per_doc,
        device="cpu" #three options for device, which are "cpu", "cuda:0" (running on just one gpu), and "cuda" (multiple gpus), although, "cuda:0" appears to be faster than "cuda", so far "cuda:0" fastest
    )
    for i, doc in enumerate(docs):
        index.add_doc(doc_id=str(i), doc_embeddings=torch.tensor(doc))
    return index


def _do_queries_random(index, queries, num_docs):
    result = []
    for gt, query in enumerate(queries):
        query_result = index.query(torch.tensor(query), top_k=10, num_to_rerank=10)
        result += query_result
        print(query_result)
        assert int(query_result[0]) == gt
    return result

In [4]:
def test_random_docs():
    index_func, query_func = get_build_and_run_functions_random()
    index = index_func()
    results = query_func(index)

In [5]:
# Build the synthetic index and run every generated query. Each query's result
# list is printed, and an AssertionError is raised if a query's top hit is not
# the doc it was generated from.
test_random_docs()

  index.add_doc(doc_id=str(i), doc_embeddings=torch.tensor(doc))


['0', '61', '88', '84', '73', '43', '30', '50', '79', '63']
['1', '15', '23', '34', '79', '52', '62', '12', '81', '43']
['2', '51', '84', '60', '47', '19', '54', '16', '56', '98']
['3', '89', '51', '87', '38', '95', '78', '65', '16', '0']
['4', '69', '10', '90', '89', '17', '85', '79', '26', '95']
['5', '98', '71', '40', '61', '14', '33', '26', '38', '82']
['6', '89', '24', '57', '54', '34', '41', '77', '88', '39']
['7', '90', '91', '23', '74', '60', '12', '25', '71', '63']
['8', '23', '31', '41', '42', '15', '26', '71', '60', '20']
['9', '64', '21', '44', '32', '84', '50', '15', '81', '76']
['10', '11', '65', '1', '19', '22', '57', '56', '74', '71']
['11', '78', '58', '8', '70', '54', '72', '20', '76', '92']
['12', '52', '53', '15', '50', '75', '16', '67', '20', '36']
['13', '95', '87', '31', '92', '32', '52', '33', '12', '55']
['14', '77', '28', '90', '24', '54', '57', '7', '70', '23']
['15', '66', '52', '43', '96', '68', '56', '84', '20', '97']
['16', '1', '75', '10', '65', '82', '6