In [None]:
! pip install redis wget pandas openai python-dotenv plotly matplotlib scipy scikit-learn

In [44]:
import os
import openai
from dotenv import load_dotenv
import pandas as pd
import numpy as np

# Load variables from a local .env file into the process environment, then
# hand the OpenAI key to the client library (None if the variable is unset).
load_dotenv()
openai.api_key = os.environ.get("OPENAI_API_KEY")

In [2]:
# Embedding model parameters.

# Second-generation embedding model (the best one at the time of writing).
embedding_model = "text-embedding-ada-002"

# Latest tokenizer for second-generation models.
embedding_encoding = "cl100k_base"

# Token budget per input; the limit for the model/tokenizer above is 8191.
max_tokens = 8000

In [4]:
# Read the precomputed embeddings table from disk into a DataFrame.
df = pd.read_csv(filepath_or_buffer="data/wmd_1452_embeddings.csv")

In [16]:
# Only needed when the frame was read from CSV: the embedding column then holds
# the string form of each vector and must be turned back into numpy arrays.
#
# ast.literal_eval replaces eval(): it accepts Python literals only, so text in
# the CSV can never execute code. The isinstance guard makes the cell safe to
# re-run, because values that are already parsed are passed through unchanged.
import ast

df["embedding"] = df["embedding"].apply(
    lambda value: np.array(ast.literal_eval(value)) if isinstance(value, str) else np.asarray(value)
)

In [27]:
# start redis using the docker-compose file in the same folder
! docker compose up -d

Container redis-vector-db-1  Creating
Container redis-vector-db-1  Created
Container redis-vector-db-1  Starting
Container redis-vector-db-1  Started


In [21]:
# connect to redis
import redis
from redis.commands.search.indexDefinition import (
    IndexDefinition,
    IndexType
)
from redis.commands.search.query import Query
from redis.commands.search.field import (
    TextField,
    VectorField
)

# Connection settings are read from the environment (for example the .env file
# loaded by load_dotenv()) so that no host or credential is stored in the notebook.
# NOTE: a password was previously hardcoded in this cell; treat it as leaked and rotate it.
REDIS_HOST = os.getenv("REDIS_HOST", "localhost")
REDIS_PORT = int(os.getenv("REDIS_PORT", "6379"))
REDIS_PASSWORD = os.getenv("REDIS_PASSWORD") or None  # None => passwordless Redis

# Connect to Redis
redis_client = redis.Redis(
    host=REDIS_HOST,
    port=REDIS_PORT,
    password=REDIS_PASSWORD,
)
# Round-trip to the server so that a wrong host/port/password fails in this cell.
redis_client.ping()

True

In [46]:
# Settings used to create the Redis search index.

# Length of the vectors, measured on the first embedding in the frame.
VECTOR_DIM = len(df["embedding"].iloc[0])

# Initial number of vectors: the row count of the first 1000 rows (fewer if the frame is shorter).
VECTOR_NUMBER = len(df.iloc[:1000])

# Name of the search index.
INDEX_NAME = "embeddings-wmd-index"

# Prefix for the document keys.
PREFIX = "doc"

# Distance metric for the vectors (ex. COSINE, IP, L2).
DISTANCE_METRIC = "COSINE"

In [47]:
# RediSearch schema for the dataset columns.

# One equally-weighted full-text field per text column.
topic, overview, symptoms, url = (
    TextField(name=column, weight=1.0)
    for column in ("topic", "overview", "symptoms", "url")
)

# Vector field for the embedding column, using the FLAT algorithm and the
# dimension / metric / initial capacity configured in the constants cell.
embedding = VectorField(
    name="embedding",
    algorithm="FLAT",
    attributes={
        "TYPE": "FLOAT32",
        "DIM": VECTOR_DIM,
        "DISTANCE_METRIC": DISTANCE_METRIC,
        "INITIAL_CAP": VECTOR_NUMBER,
    },
)

fields = [topic, overview, symptoms, url, embedding]

In [48]:
# Create the search index only when it does not exist yet.
# Asking for the info of an unknown index makes the server reply with an error,
# which redis-py raises as ResponseError. Only that error is caught (instead of a
# bare except) so that connection problems or a keyboard interrupt are not
# mistaken for "index missing".
try:
    redis_client.ft(INDEX_NAME).info()
    print("Index already exists")
except redis.exceptions.ResponseError:
    # Create RediSearch Index over all hashes whose key starts with PREFIX
    redis_client.ft(INDEX_NAME).create_index(
        fields=fields,
        definition=IndexDefinition(prefix=[PREFIX], index_type=IndexType.HASH),
    )

Index already exists


In [49]:
def index_documents(client: redis.Redis, prefix: str, documents: pd.DataFrame, batch_size: int = 500):
    """
    Store every row of ``documents`` as a Redis hash.

    Each row is written under the key ``"{prefix}:{i}"`` where ``i`` is the row's
    position in ``documents``. The ``embedding`` value is stored as raw float32
    bytes; all other columns are stored as they are.

    Writes are sent through a non-transactional pipeline and flushed every
    ``batch_size`` rows, so loading costs one network round trip per batch
    instead of one per document.

    :param client: Redis client
    :param prefix: Prefix for the document keys
    :param documents: DataFrame with an ``embedding`` column holding one vector per row
    :param batch_size: Number of documents sent to Redis per round trip
    :raises ValueError: If ``batch_size`` is smaller than 1
    """
    if batch_size < 1:
        raise ValueError("batch_size must be a positive integer")

    # to_dict("records") builds fresh dicts, so replacing the embedding below
    # does not touch the caller's DataFrame.
    records = documents.to_dict("records")
    pipe = client.pipeline(transaction=False)
    for i, doc in enumerate(records):
        key = f"{prefix}:{i}"

        # replace the sequence of floats with its float32 byte representation
        doc["embedding"] = np.array(doc["embedding"], dtype=np.float32).tobytes()

        pipe.hset(key, mapping=doc)
        if (i + 1) % batch_size == 0:
            pipe.execute()

    # flush whatever is left from the last, possibly partial, batch
    pipe.execute()

In [51]:
index_documents(redis_client, PREFIX, df[:1000])
# Report the document count of the search index itself. The keyspace statistics
# (info()['db0']['keys']) count every key in the database, indexed or not, and
# the 'db0' entry is absent altogether when the database is empty.
num_docs = redis_client.ft(INDEX_NAME).info()["num_docs"]
print(f"Loaded {num_docs} documents in Redis search index with name: {INDEX_NAME}")

Loaded 930 documents in Redis search index with name: embeddings-wmd-index


In [52]:
# Run a search query and return the results
from typing import List
def search_redis(
        redis_client: redis.Redis,
        user_query: str,
        index_name: str = INDEX_NAME,
        vector_field: str = "embedding",
        return_fields: list = ["topic", "overview", "symptoms", "url", "vector_score"],
        hybrid_fields = "*",
        k: int = 20,
        print_results: bool = False,
        model: str = "text-embedding-ada-002",
) -> List[dict]:
    """
    Embed a text query with OpenAI and run a KNN vector search against Redis.

    :param redis_client: Redis client
    :param user_query: Query string
    :param index_name: Name of the index to search in
    :param vector_field: Name of the vector field
    :param return_fields: List of fields to return (only read, never modified)
    :param hybrid_fields: Filter expression placed in front of the KNN clause;
        the default "*" applies no filter
    :param k: Number of results to return
    :param print_results: Whether to print the results
    :param model: Name of the OpenAI embedding model used for the query
    :return: The ``docs`` of the search result, ordered by ascending ``vector_score``
    """
    # Creates embedding vector from user query
    embedded_query = openai.Embedding.create(
        input=user_query,
        model=model,
    )["data"][0]["embedding"]

    # Prepare the query: KNN over the vector field, score exposed as vector_score
    base_query = f'{hybrid_fields}=>[KNN {k} @{vector_field} $vector AS vector_score]'

    query = (
        Query(base_query)
        .return_fields(*return_fields)
        .sort_by("vector_score")
        .paging(0, k)
        .dialect(2)
    )

    # The query vector is passed as raw float32 bytes
    params_dict = {
        "vector": np.array(embedded_query, dtype=np.float32).tobytes()
    }

    # perform vector search
    results = redis_client.ft(index_name).search(query, params_dict)

    # Print the results
    if print_results:
        for i, result in enumerate(results.docs):
            print(f"Rank: {i}")
            print(f"Topic: {result.topic}")
            print(f"Overview: {result.overview}")
            print(f"Symptoms: {result.symptoms}")
            print(f"URL: {result.url}")
            # vector_score is shown as 1 - score, so larger printed values rank higher
            score = 1 - float(result.vector_score)
            print(f"Score: {round(score, 3)}")
            print()

    return results.docs

In [53]:
# Search redis. print_results=True is required for the ranked matches to be
# printed below; with the default (False) this cell produces no output.
results = search_redis(redis_client, 'back pain', k=10, print_results=True)

Rank: 0
Topic: Lumbar Pain
Overview: We often bring on our back problems through bad habits, such as: The spine is actually a stack of 24 bones called vertebrae. A healthy spine is S-shaped when viewed from the side. It curves back at your shoulders and inward at your neck and small of your back. It houses and protects your spinal cord, the network of nerves that transmit feeling and control movement throughout your entire body.  One of the more common types of back pain comes from straining the bands of muscles surrounding the spine. It happens most often in the curve of the low back and the base of the neck. These areas support more weight than your upper and mid back, which are less prone to trouble. Injuries from contact sports, accidents, and falls can cause problems ranging from minor muscle strains, to herniated disks, to fractures that damage the spinal column or cord. Stabbing low back pain could be from muscle spasms, when your muscles seize up and don't relax, like a cramp. 