From the shared reference point [proof of concept](http://localhost:8888/lab/tree/vector_search/Shared%20Reference%20Proof%20of%20Concept.ipynb), we saw OK-ish recall using a few thousand reference points and querying with 100 of them. We would like to:

1. Increase number of reference points
2. Decrease the number needed at query time

In a system built on this approach, that combination would increase recall with the least impact on query performance.

This notebook performs a grid search over both numbers: reference points at index time and reference points used at query time.

## Load sentences

As a reminder, we use MiniLM-encoded sentences from Wikipedia, downsampled by 50% due to memory constraints.

In [39]:
import numpy as np

def load_sentences():
    # From
    # https://www.kaggle.com/datasets/softwaredoug/wikipedia-sentences-all-minilm-l6-v2
    with open('wikisent2_all.npz', 'rb') as f:
        # Read the array while the file is still open (npz loads lazily)
        wiki_vects = np.load(f)['arr_0']

    # Sanity check: MiniLM vectors should already be unit length
    norms = np.linalg.norm(wiki_vects, axis=1)
    all_normed = (norms > 0.99) & (norms < 1.01)
    assert all_normed.all(), "Something is wrong - vectors are not normalized!"

    with open('wikisent2.txt', 'rt') as f:
        wiki_sentences = f.readlines()

    return wiki_sentences, wiki_vects

sentences, vects = load_sentences()

# Shrink by 50% for the RAM savings
sentences = sentences[::2]
vects = vects[::2]
vects.shape, len(sentences)

((3935913, 384), 3935913)

## Build index

As in the proof of concept, we need:

- A function to generate random unit vectors to serve as reference points
- An index storing every vector's dot product with every reference point
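Both pieces can also be written in vectorized NumPy form. The sketch below is an assumption-labeled alternative, not the notebook's code: `build_index_vectorized` is a hypothetical name, and it draws all reference points in one Gaussian call before normalizing each row, then computes the whole index as one matrix product.

```python
import numpy as np


def build_index_vectorized(vects, num_refs=1000, seed=None):
    """Hypothetical vectorized variant of build_index.

    Draws every reference point at once from a Gaussian, normalizes each row
    to unit length, then stores every vector's dot product with every ref.
    """
    rng = np.random.default_rng(seed)
    refs = rng.normal(size=(num_refs, vects.shape[1])).astype(np.float32)
    refs /= np.linalg.norm(refs, axis=1, keepdims=True)
    # Row i of the index holds vects[i]'s dot product with each reference point
    index = vects @ refs.T
    return refs, index


# Tiny example: 5 unit vectors in 8 dimensions, 3 reference points
rng = np.random.default_rng(0)
toy = rng.normal(size=(5, 8)).astype(np.float32)
toy /= np.linalg.norm(toy, axis=1, keepdims=True)
toy_refs, toy_index = build_index_vectorized(toy, num_refs=3, seed=1)
print(toy_refs.shape, toy_index.shape)  # (3, 8) (5, 3)
```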

In [40]:
def random_vector(num_dims=768):
    """Sample a unit vector uniformly from the sphere in N dimensions.

    It's important that the samples are Gaussian before normalizing, see
    https://stackoverflow.com/questions/59954810/generate-random-points-on-10-dimensional-unit-sphere
    i.e. don't do this (every direction lands in the all-positive orthant):
        projection = np.random.random_sample(size=num_dims)
        projection /= np.linalg.norm(projection)
    """
    projection = np.random.normal(size=num_dims)
    projection /= np.linalg.norm(projection)
    return projection

random_vector(num_dims=vects.shape[1])

array([ 0.00666564, -0.0722622 ,  0.0428574 , -0.02779602, -0.01125637,
        0.00735623,  0.00019494,  0.03934172, -0.04226361, -0.00964803,
        0.05044241,  0.02423316, -0.05395501,  0.02432852, -0.01583304,
        0.0223122 , -0.03727536, -0.06231928,  0.02635215, -0.08106241,
        0.00476535,  0.02151036,  0.04897824,  0.05952181,  0.06735643,
       -0.06717623, -0.04716457, -0.07152112, -0.0193187 ,  0.04654273,
       -0.07812368,  0.09264148, -0.0109842 ,  0.00154082,  0.00237992,
        0.00701564,  0.09279916,  0.00235389,  0.03603774, -0.01469911,
        0.00309564,  0.05114861, -0.01276528,  0.03691551, -0.03083302,
       -0.02454237, -0.00770771, -0.01561039,  0.06417661, -0.04109687,
       -0.00830815,  0.06693096, -0.06680911, -0.02094927, -0.01184418,
        0.07182279,  0.01817778,  0.00846186,  0.00025429,  0.03275067,
       -0.05605074, -0.05458292, -0.02311123, -0.01631234, -0.06567901,
        0.10800846, -0.00620358,  0.01366184,  0.008013  ,  0.03
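To see why the docstring warns against normalizing `random_sample` output, one quick check is the length of the average of many sampled unit vectors: near 0 when directions are spread over the sphere, near 1 when they cluster. This is a sketch for illustration; `mean_direction_norm` is a hypothetical helper, and the sample sizes are arbitrary choices.

```python
import numpy as np


def mean_direction_norm(unit_vectors):
    """Length of the average of a set of unit vectors.

    Near 0 means the directions are spread evenly around the sphere;
    near 1 means they all point roughly the same way.
    """
    return float(np.linalg.norm(unit_vectors.mean(axis=0)))


rng = np.random.default_rng(42)
num_samples, num_dims = 2000, 384

# Gaussian samples, normalized: uniform on the sphere
gaussian = rng.normal(size=(num_samples, num_dims))
gaussian /= np.linalg.norm(gaussian, axis=1, keepdims=True)

# The "don't do this" version: uniform [0, 1) samples, normalized
cube = rng.random(size=(num_samples, num_dims))
cube /= np.linalg.norm(cube, axis=1, keepdims=True)

print(mean_direction_norm(gaussian))  # small, roughly 1 / sqrt(num_samples)
print(mean_direction_norm(cube))      # large, roughly sqrt(3) / 2
```

Reference points that all lean one way cover only a small patch of the sphere, so most queries would have no nearby references.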

In [44]:
def build_index(vects, num_refs=1000, refs_factory=random_vector):
    """Generate num_refs reference points and every vector's dot product with them."""
    refs = np.zeros((num_refs, vects.shape[1]), dtype=np.float32)

    for ref_ord in range(num_refs):
        refs[ref_ord] = refs_factory(num_dims=vects.shape[1])

    # Most of the memory goes here: a (num_vects x num_refs) matrix
    index = np.dot(vects, refs.T)

    return refs, index
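A back-of-the-envelope sketch of how large that index gets, assuming float32 storage (4 bytes per value; this holds if `vects` is float32 like `refs`, which is an assumption here). `index_size_bytes` is a hypothetical helper.

```python
def index_size_bytes(num_vects, num_refs, bytes_per_value=4):
    """Approximate size of the (num_vects x num_refs) dot-product index.

    Assumes float32 values (4 bytes each); ignores NumPy's small
    per-array overhead.
    """
    return num_vects * num_refs * bytes_per_value


for num_refs in (250, 1000, 1500):
    gb = index_size_bytes(3_935_913, num_refs) / 1e9
    print(f"{num_refs:>5} refs -> {gb:.1f} GB")
```

For the 3,935,913 sampled vectors, 1,500 refs come to about 23.6 GB. In `grid_search`, the name `index` still refers to the previous index while `build_index` creates the next one, so peak memory during that step is roughly two indexes' worth.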

## Search ground truth

The ground truth is an exact search: the query is encoded with MiniLM (the same model that encoded the vectors) and dotted against every vector.

In [42]:
from sentence_transformers import SentenceTransformer
model = SentenceTransformer('all-MiniLM-L6-v2')

query = "mary had a little lamb"

def search_ground_truth(vects, query, at=10):
    """Exact top-`at` search: dot the query vector against every vector."""
    query_vector = model.encode(query)
    nn = np.dot(vects, query_vector)
    # argpartition finds the top `at` without a full sort; then sort just those
    top_n = np.argpartition(-nn, at)[:at]
    top_n = top_n[np.argsort(-nn[top_n])]
    return list(zip(top_n, nn[top_n]))

gt_ords = set()
for vect_ord, score in search_ground_truth(vects, query):
    gt_ords.add(vect_ord)
    print(vect_ord, score, sentences[vect_ord])

1996387 0.700519 "Mary Had a Little Lamb", who wrote the novel under the name of Sara J. Hale.

1997224 0.6153624 Mary then went into labor.

1418816 0.61314523 It begins with the melody of the popular children's song "Mary Had a Little Lamb" and then cuts into the main riff, punctuated with a high trumpet trill.

1887627 0.5563892 Lamb and is wife, Sara, have one son.

775918 0.54064065 For the more domestic and intimate iconic representations of Mary with the infant Jesus on her lap, see Madonna and Child.

1393341 0.5362675 In this variant he shows the Christ Child in Mary's lap.

611431 0.5288173 Did Jesus Have a Dog?

3447108 0.52027595 The songs exclusive to this release are "Call Me Claus," "Mary Had a Little Lamb," and "'Zat You, Santa Claus?".

2991842 0.5185638 The final two lines detail the former's lamb feast, which resuscitates it.

3120137 0.5133202 The "Lady" is the Virgin Mary.
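The top-k step in `search_ground_truth` can be isolated and checked on a toy array. This is a small sketch; `top_k` is a hypothetical helper, not part of the notebook. `np.argpartition` places the k smallest values of `-scores` (the k largest scores) in the first k slots without fully sorting, and only those k are then sorted.

```python
import numpy as np


def top_k(scores, k):
    """Indices of the k highest scores, best first (requires k < len(scores))."""
    # The first k slots hold the k largest scores, in arbitrary order
    candidates = np.argpartition(-scores, k)[:k]
    # Sort only those k candidates, highest score first
    return candidates[np.argsort(-scores[candidates])]


scores = np.array([0.1, 0.9, 0.3, 0.7, 0.5, 0.2])
print(top_k(scores, 3))  # [1 3 4]
```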



## (Inefficient) search function

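We use the most accurate (though least efficient) form of reference point search. Each vector's score is the sum, over the query's `num_refs` closest reference points, of the vector's dot product with that reference point times the query's dot product with it. This scores every vector in the index on every query.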
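Why this scoring can approximate the true dot product: for reference points drawn uniformly on the unit sphere in d dimensions, the average of r rᵀ is I/d, so summing (v · r)(q · r) over m such refs and scaling by d/m gives an unbiased estimate of v · q. The sketch below checks this with all refs and the d/m scale; the notebook's `search` instead keeps only the refs closest to the query and skips the scale (a constant that does not change the ranking). The helper names are hypothetical and the sizes are arbitrary.

```python
import numpy as np


def unit_rows(x):
    """Normalize each row of x to unit length."""
    return x / np.linalg.norm(x, axis=1, keepdims=True)


def ref_point_estimate(v, q, refs):
    """Estimate v . q from dot products with unit reference points.

    For refs uniform on the sphere, E[r r^T] = I / d, so
    (d / m) * sum_r (v . r)(q . r) is an unbiased estimate of v . q.
    """
    m, d = refs.shape
    return (d / m) * np.sum((refs @ v) * (refs @ q))


rng = np.random.default_rng(7)
d, m = 64, 5000
v, q = unit_rows(rng.normal(size=(2, d)))
refs = unit_rows(rng.normal(size=(m, d)))

# Exact dot product vs. reference point estimate (close, not identical)
print(float(v @ q), float(ref_point_estimate(v, q, refs)))
```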

In [43]:
def best_refs(refs, query_vector, num_refs=200):
    """The num_refs reference points closest to the query, with their dot products."""
    dotted = np.dot(refs, query_vector)
    best_ref_ords = np.argsort(-dotted)[:num_refs]
    return best_ref_ords, dotted[best_ref_ords]

def search(index, refs, query, num_refs=200, at=10):

    query_vector = model.encode(query)

    best_ref_ords, ref_scores = best_refs(refs, query_vector, num_refs=num_refs)

    # Score every vector: sum over chosen refs of (vect . ref) * (query . ref)
    vects_scored = index[:, best_ref_ords] @ ref_scores

    best_vect_ords = np.argsort(-vects_scored)[:at]

    return list(zip(best_vect_ords, vects_scored[best_vect_ords]))

refs, index = build_index(vects, num_refs=100)
search(index, refs, query="mary had a litle lamb", num_refs=10)

(3935913, 384)


[(3642364, 0.069736),
 (3748866, 0.06701364),
 (554469, 0.0658752),
 (2430308, 0.065219864),
 (3253087, 0.06410813),
 (114717, 0.06364536),
 (554489, 0.06337762),
 (146384, 0.063208565),
 (2934351, 0.06192884),
 (1606922, 0.061899275)]

## Search over sample of queries

Using a handful of test queries, let's run a grid search varying:

* `num_query_refs` - the number of reference points closest to the query vector that are used to score vectors at query time
* `num_index_refs` - the number of reference points generated when building the index
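The two parameters affect query cost differently. A rough multiply-add count for one call to `search` (a sketch under stated simplifications; both helpers are hypothetical): the query is dotted against all `num_index_refs` reference points, then each vector combines `num_query_refs` index columns.

```python
def query_multiply_adds(num_vects, num_dims, num_index_refs, num_query_refs):
    """Rough multiply-add count for one call to search().

    - dot the query with every reference point: num_index_refs * num_dims
    - combine num_query_refs index columns per vector: num_vects * num_query_refs
    Ignores the final argsort and the copy made by index[:, best_ref_ords].
    """
    return num_index_refs * num_dims + num_vects * num_query_refs


def brute_force_multiply_adds(num_vects, num_dims):
    """Exact search: one num_dims-long dot product per vector."""
    return num_vects * num_dims


n, dims = 3_935_913, 384
print(brute_force_multiply_adds(n, dims))      # 1511390592
print(query_multiply_adds(n, dims, 1500, 100))  # 394167300
print(query_multiply_adds(n, dims, 1500, 10))   # 39935130
```

By this count, `num_index_refs` adds only a small per-query term (its main cost is index memory), while the per-vector work scales with `num_query_refs`. At 100 query refs the work is already about a quarter of brute force for 384-dimensional vectors.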

### Generate ground truths

Compute the ground truth for each test query, so we can measure recall against it.

In [6]:
from collections import defaultdict

test_queries = ["what is a cat", "where is spain", "what is the capital of spain",
                "who framed roger rabbit", "free willy", "bed bath and beyond",
                "hats and stuff", "bed bath beyond", "do you even paginate bro?",
                "mary had a little lamb"]

ground_truths = defaultdict(set)
for query in test_queries:
    for vect_ord, score in search_ground_truth(vects, query):
        ground_truths[query].add(vect_ord)
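Recall here means the fraction of a query's ground-truth ids (10 per query) that the approximate search also returns. As a standalone sketch, with `recall_at_k` as a hypothetical helper:

```python
def recall_at_k(found, relevant):
    """Fraction of the relevant (ground-truth) ids that the search returned."""
    relevant = set(relevant)
    if not relevant:
        raise ValueError("ground truth is empty")
    return len(set(found) & relevant) / len(relevant)


print(recall_at_k([3, 7, 9, 11], {3, 9, 12, 15, 20}))  # 0.4
```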


In [None]:
import pandas as pd
from statistics import mean


def grid_search(refs_factory=random_vector):
    """Mean/max/min recall over test_queries for each
    (num_index_refs, num_query_refs) combination."""
    results = []

    search_index_refs = [1500, 1250, 1000, 750, 500, 250]
    for num_index_refs in search_index_refs:
        refs, index = build_index(vects, num_index_refs,
                                  refs_factory=refs_factory)
        for num_query_refs in [10, 20, 30, 40, 100, 200]:
            recalls = []
            for query in test_queries:
                found = {vect_ord for vect_ord, _ in
                         search(index, refs, query, num_refs=num_query_refs)}
                # Recall: share of this query's ground truth that was found
                truth = ground_truths[query]
                recalls.append(len(found & truth) / len(truth))

            print(num_index_refs, num_query_refs, mean(recalls))
            results.append({'num_index_refs': num_index_refs,
                            'num_query_refs': num_query_refs,
                            'mean': mean(recalls),
                            'max': max(recalls),
                            'min': min(recalls)})

    return results

pd.DataFrame(grid_search())

1500 10 0.19
1500 20 0.31
1500 30 0.41000000000000003
1500 40 0.45
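The long list of result rows is easier to scan as a grid with one row per `num_index_refs` and one column per `num_query_refs`. A sketch using pandas `DataFrame.pivot`; `recall_table` is a hypothetical helper, and the usage only feeds it the four mean recalls printed above.

```python
import pandas as pd


def recall_table(results, value='mean'):
    """Pivot grid_search() rows into num_index_refs x num_query_refs."""
    return (pd.DataFrame(results)
              .pivot(index='num_index_refs', columns='num_query_refs', values=value)
              .sort_index(ascending=False))


# The four mean recalls printed above for 1,500 index refs
partial = [{'num_index_refs': 1500, 'num_query_refs': k, 'mean': recall}
           for k, recall in [(10, 0.19), (20, 0.31), (30, 0.41), (40, 0.45)]]
print(recall_table(partial))
```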


## Try building refs from text outside the index

Since we indexed only every other sentence, what happens when we use sentences left out of the index as our reference points?
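Because the index keeps the even rows (`vects[::2]`), only the odd rows of the full array are truly outside it. Shuffling the full array first and then taking every other row does not preserve that split: roughly half of the picked rows will be indexed ones. A sketch on row ids alone (the sizes are arbitrary and `fraction_indexed` is a hypothetical helper):

```python
import numpy as np


def fraction_indexed(sample_ids):
    """Share of sampled row ids that the index kept (even ids, as in vects[::2])."""
    return float(np.mean(np.asarray(sample_ids) % 2 == 0))


rng = np.random.default_rng(0)
row_ids = np.arange(20_000)

# Shuffle first, then take every other row: held-out property is lost
shuffled_then_strided = rng.permutation(row_ids)[1::2]
# Take the held-out (odd) rows first, then sample among them
held_out_sample = rng.choice(row_ids[1::2], size=5_000, replace=False)

print(fraction_indexed(shuffled_then_strided))  # about 0.5
print(fraction_indexed(held_out_sample))        # 0.0
```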

In [23]:
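_, all_vects = load_sentences()
del _
# The index holds the even rows (vects[::2]); the odd rows were never indexed
held_out = all_vects[1::2]
# Sample 10,000 held-out vectors to draw reference points from
sample_ords = np.random.choice(len(held_out), size=10000, replace=False)
vects_sample = held_out[sample_ords]
del all_vects, held_out

def vectors_from_text(num_dims):
    """Use a randomly chosen held-out sentence vector as a reference point."""
    ref_ord = np.random.randint(0, len(vects_sample))
    return np.array(vects_sample[ref_ord])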

grid_search(refs_factory=vectors_from_text)

1500 10 0.0
1500 20 0.0
1500 30 0.0
1500 40 0.0
1500 100 0.0
1500 200 0.0
1250 10 0.0
1250 20 0.0
1250 30 0.0
1250 40 0.0
1250 100 0.0
1250 200 0.0
1000 10 0.0
1000 20 0.0
1000 30 0.0
1000 40 0.0
1000 100 0.0
1000 200 0.0
750 10 0.0
750 20 0.0
750 30 0.0
750 40 0.0
750 100 0.0
750 200 0.0
500 10 0.0
500 20 0.0
500 30 0.0
500 40 0.0
500 100 0.0
500 200 0.0
250 10 0.0
250 20 0.0
250 30 0.0
250 40 0.0
250 100 0.0
250 200 0.0



In [14]:
vects_sample.shape

(10000, 384)
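The notebook does not explain the 0.0 recall. One hypothesis worth checking (an assumption, not something measured here): sentence embeddings may share a common direction, so reference points drawn from them would not cover the sphere evenly, while the scoring in `search` relies on refs that roughly do. The hypothetical `isotropy_report` below gives two quick measures; the usage compares Gaussian refs against a synthetic offset set standing in for clustered embeddings, not against the real data.

```python
import numpy as np


def isotropy_report(refs):
    """Two quick checks of how evenly reference points cover the sphere.

    - mean_norm: length of the average unit ref (0 = balanced, 1 = all identical)
    - top_eig_share: largest eigenvalue of refs^T refs over their sum
      (about 1 / num_dims for isotropic refs, 1.0 if all refs are collinear)
    """
    refs = refs / np.linalg.norm(refs, axis=1, keepdims=True)
    eigvals = np.linalg.eigvalsh(refs.T @ refs)  # ascending order
    return {'mean_norm': float(np.linalg.norm(refs.mean(axis=0))),
            'top_eig_share': float(eigvals[-1] / eigvals.sum())}


rng = np.random.default_rng(3)
num_refs, num_dims = 5000, 384
gaussian_refs = rng.normal(size=(num_refs, num_dims))

# Synthetic stand-in for clustered embeddings: a shared offset plus noise
shared = np.zeros(num_dims)
shared[0] = 1.0
offset_refs = rng.normal(size=(num_refs, num_dims)) + 20.0 * shared

print(isotropy_report(gaussian_refs))  # both numbers near 0
print(isotropy_report(offset_refs))    # both numbers well above 0
```

On the real data, the comparable check would be `isotropy_report(vects_sample)` next to a report on refs built with `random_vector`.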