In [1]:
import numpy as np

import redis
from redis.commands.search.field import VectorField
from redis.commands.search.indexDefinition import IndexDefinition, IndexType
from redis.commands.search.query import Query

In [2]:
class RedisIndex:
    """Thin wrapper around a Redis search index over JSON vector documents.

    The index covers keys that start with ``prefix`` and indexes the
    ``$.vec`` path of each JSON document as a FLAT / FLOAT32 / COSINE vector
    field, exposed to queries under the alias ``vector``.

    Args:
        conn: Redis client object; only ``conn.ft(name)`` is used.
        prefix: Key prefix of the documents the index should cover.
        name: Name of the search index.
        dim: Dimensionality declared for the vector field.
        create: When true, the index is created during construction.
    """

    def __init__(self, conn, prefix, name, dim, create=False):
        self.conn = conn
        self.prefix = prefix
        self.name = name
        self.dim = dim

        if create:
            self.create()

    def create(self):
        """Create the vector index on the server.

        NOTE(review): the server presumably rejects this call when an index
        with the same name already exists, so re-running it is not expected
        to be idempotent -- confirm.
        """
        schema = (
            VectorField(
                "$.vec",
                "FLAT",
                {
                    "TYPE": "FLOAT32",
                    "DISTANCE_METRIC": "COSINE",
                    "DIM": self.dim,
                },
                as_name="vector",
            ),
        )
        definition = IndexDefinition(prefix=[self.prefix], index_type=IndexType.JSON)
        self.conn.ft(self.name).create_index(fields=schema, definition=definition)

    def info(self):
        """Return whatever the search client reports for this index."""
        return self.conn.ft(self.name).info()

    def search(self, vec, n_top=3):
        """Run a KNN query for ``vec`` and return the matching documents.

        Args:
            vec: Query vector; it is converted to a float32 byte string
                before being sent as the ``query_vector`` parameter.
            n_top: Number of nearest neighbours to request.

        Returns:
            The ``docs`` list of the search result, sorted by the
            ``vector_score`` field.
        """
        query = (
            Query(f'(*)=>[KNN {n_top} @vector $query_vector AS vector_score]')
            .sort_by('vector_score')
            # Without an explicit page size the result is capped at the
            # client's default page of 10 hits, which silently truncates any
            # request with n_top > 10.
            .paging(0, n_top)
            .dialect(2)
        )
        query_data = {'query_vector': np.array(vec, dtype=np.float32).tobytes()}
        return self.conn.ft(self.name).search(query, query_data).docs

In [3]:
def dump_to_redis(r, vecs, dim, prefix="track:", batch_size=1000):
    """Store every non-zero row of ``vecs`` in Redis as a JSON document.

    Row ``j`` is written under the key ``f"{prefix}{j}"`` as
    ``{"vec": [...]}``. Rows equal to the all-zero vector of length ``dim``
    are skipped. Writes are queued on a non-transactional pipeline and
    flushed every ``batch_size`` documents, so the client does not pay one
    network round trip per vector.

    Args:
        r: Redis client; ``r.pipeline(transaction=False)`` is used.
        vecs: 2-D array-like with ``shape`` and row indexing (rows must
            support ``tolist()``).
        dim: Length of the zero vector that rows are compared against.
        prefix: Key prefix for the stored documents.
        batch_size: Number of queued writes per pipeline flush.

    Raises:
        ValueError: If ``batch_size`` is smaller than 1.
    """
    if batch_size < 1:
        raise ValueError("batch_size must be a positive integer")

    zero_vec = np.zeros(dim)
    with r.pipeline(transaction=False) as pipe:
        pending = 0
        for j in range(vecs.shape[0]):
            row = vecs[j]
            if np.array_equal(row, zero_vec):
                continue
            pipe.json().set(f"{prefix}{j}", "$", {"vec": row.tolist()})
            pending += 1
            if pending >= batch_size:
                pipe.execute()
                pending = 0
        # Flush the final, partially filled batch.
        if pending:
            pipe.execute()


In [4]:
# Client for the Redis server at localhost:6379; the cells below reuse it
# for loading the vectors and for building the search index.
r = redis.Redis(host="localhost", port=6379)

In [5]:
# Load the vectors from disk; the length of the second axis is used as the
# vector dimension everywhere below.
vecs = np.load("data/als_vecs.npy")
dim = vecs.shape[1]
# Write the non-zero rows to Redis as JSON documents (see dump_to_redis).
dump_to_redis(r, vecs, dim)
# NOTE(review): dbsize() presumably counts every key in the selected database,
# not only the documents written above -- confirm before reading this as the
# number of stored vectors.
print(f"n_vecs = {r.dbsize()}")

n_vecs = 67200


In [6]:
# Wrap the "idx:tracks" index over keys with the "track:" prefix; create=True
# makes the constructor issue the index creation call immediately.
# NOTE(review): re-running this cell while the index already exists will
# presumably fail on the server side -- confirm.
index = RedisIndex(r, prefix="track:", 
                   name="idx:tracks", 
                   dim=dim, create=True)

In [7]:
# Smoke-test the index with a random query vector, asking for 10 neighbours.
# NOTE(review): no random seed is set in the visible cells, so the returned
# documents will differ from run to run.
index.search(np.random.rand(dim), 10)

[Document {'id': 'track:34635', 'payload': None, 'vector_score': '0.141877353191', 'json': '{"vec":[5.369509524510985e-12,4.670116290211857e-12,4.510269761837105e-12,4.57807489834261e-12,5.901823300902276e-12,4.8693041786174166e-12,4.7758056190283504e-12,4.15315915339165e-12,5.721811913161901e-12,4.584008953673058e-12,4.14702777326581e-12,5.048150699543674e-12,4.146726365061859e-12,3.957699619416832e-12,5.088988692253382e-12,4.889532789070783e-12,4.746273686573321e-12,4.184715074462275e-12,4.060800740807169e-12,4.979643433711267e-12,4.455613395598634e-12,4.8085693413191236e-12,4.0206601069348036e-12,5.2993690165259455e-12,4.804473225511474e-12,5.217343050645251e-12,4.458408468799302e-12,3.384049984678228e-12,5.242693865842307e-12,4.9427532379520126e-12,4.217365172365373e-12,3.2522500317821997e-12,5.829917711780429e-12,4.210568959467365e-12,6.628225312360625e-12,3.355340311150812e-12,3.543829726199155e-12,3.7956786151649435e-12,4.409130612698098e-12,3.2906806619881213e-12,5.134618424884