In [93]:
import os
import redis
from redis.commands.search.query import Query
from redis.commands.search.field import (
    NumericField,
    TagField,
    TextField,
    VectorField,
)
from redis.commands.search.indexDefinition import IndexDefinition, IndexType

from sentence_transformers import SentenceTransformer, util

import numpy as np

from my_util import get_chunks, get_topk_similarity
from typing import List

## Ingest Redis

In [2]:
content, metadata = get_chunks(company_name="novo_nordisk")

In [3]:
embedder = SentenceTransformer(
    "sentence-transformers/msmarco-distilbert-base-tas-b",
    cache_folder="cache",
)

In [89]:
embeddings = embedder.encode(content)

In [90]:
print(embeddings.shape)

(327, 768)


In [24]:
redis_client = redis.Redis(host="localhost", port=6379, decode_responses=True)

In [25]:
def insert_redis_index(client, index_name: str, docs, embeddings, metadata, dim=768):
        
    pipeline = client.pipeline()
    
    # insert docs by keys
    for idx, doc in enumerate(docs):
        redis_key = f"docs:{idx}"
        pipeline.json().set(redis_key, "$", {"content": doc})
    
    res = pipeline.execute()
    
    # insert embeddings by keys
    keys = sorted(client.keys("docs:*"))
    
    for key, embedding, meta in zip(keys, embeddings, metadata):
        pipeline.json().set(key, "$.content_vector", embedding)
        pipeline.json().set(key, "$.source", meta["source"])
        pipeline.json().set(key, "$.start_index", meta["start_index"])

    res = pipeline.execute()
    
    # create index
    schema = (
        TextField("$.content", as_name="content"),
        TextField("$.source", as_name="source"),
        NumericField("$.start_index", as_name="start_index"),
        VectorField(
            "$.content_vector",
            "HNSW",
            {
                "TYPE": "FLOAT32",
                "DIM": dim,
                "DISTANCE_METRIC": "COSINE",
            },
            as_name="content_vector",
        ),
    )

    definition = IndexDefinition(prefix=["docs:"], index_type=IndexType.JSON)
    
    res = client.ft(index_name).create_index(
        fields=schema, definition=definition
    )
    print(res)

In [27]:
insert_redis_index(
    client=redis_client,
    index_name = "rag-redis-demo", 
    docs=content,
    embeddings=embeddings.tolist(), 
    metadata=metadata, 
    dim=768
)

OK


## Search Redis

In [19]:
from langchain.prompts import ChatPromptTemplate
from langchain_core.output_parsers import StrOutputParser
from langchain_core.pydantic_v1 import BaseModel
from langchain_core.runnables import RunnableParallel, RunnablePassthrough
from langchain.chat_models import ChatOpenAI
from langchain.vectorstores import Redis

from langchain_core.vectorstores import VectorStore, VectorStoreRetriever
from langchain_core.documents import Document
from langchain_core.callbacks.manager import CallbackManagerForRetrieverRun

In [28]:
def query_redis_index(client, query_str: str, index_name, embedder):
    queries = [query_str]
    encoded_queries = embedder.encode(queries)

    sql = (
        Query("(*)=>[KNN 10 @content_vector $query_vector AS vector_score]")
        .sort_by("vector_score")
        .return_fields("id", "content", "content_vector", "vector_score")
        .dialect(2)
    )

    for query, encoded_query in zip(queries, encoded_queries):
        query_vector = np.array(encoded_query, dtype=np.float32).tobytes()
        result_docs = client.ft(index_name).search(sql, {"query_vector": query_vector}).docs

        print(f"Num of results: {len(result_docs)}")

        for result_doc in result_docs:
            vector_score = round(1 - float(result_doc.vector_score), 2)
            
            print("{} \t {} \t {:.4f}".format(result_doc.id, query, vector_score))

In [29]:
query_redis_index(
    client=redis_client,
    query_str = "what is scope 1 emissions", 
    index_name="rag-redis-demo", 
    embedder=embedder,
)

Num of results: 10
docs:39 	 what is scope 1 emissions 	 0.8400
docs:40 	 what is scope 1 emissions 	 0.8200
docs:4 	 what is scope 1 emissions 	 0.8100
docs:38 	 what is scope 1 emissions 	 0.7800
docs:115 	 what is scope 1 emissions 	 0.7800
docs:43 	 what is scope 1 emissions 	 0.7800
docs:119 	 what is scope 1 emissions 	 0.7700
docs:117 	 what is scope 1 emissions 	 0.7700
docs:118 	 what is scope 1 emissions 	 0.7700
docs:44 	 what is scope 1 emissions 	 0.7600


## Generate answer

### Custom retriever

In [81]:
class CustomRetriever(VectorStoreRetriever):
    vectorstore: VectorStore
    search_type: str = "similarity"

    def _get_relevant_documents(self, query: str, *, run_manager: CallbackManagerForRetrieverRun) -> List[Document]:
        
        # Option 1: Find cosine similarity, instead of KNN or MMR
        # encoded_query = embedder.encode(query)
        # result_pairs = get_topk_similarity(
        #     k=10, 
        #     encoded_query=encoded_query, 
        #     encoded_docs=embeddings, 
        #     is_cos_sim=True, 
        #     debug=False
        # )
        #
        # docs_indices = [doc_idx for (doc_idx, score) in result_pairs]
        
        # Option 2: Hard-coded with the correct answer
        docs_indices = [251]
        
        # Search
        results = []
        
        for doc_idx in docs_indices:
            doc = redis_client.json().get(f"docs:{doc_idx}")
            results.append(doc)
        
        # Prepare document results
        docs = []
        
        for doc_idx, result in zip(docs_indices, results):
            
            metadata = {"id": doc_idx}  
            docs.append(
                Document(page_content=result["content"], metadata=metadata)
            )
        
        return docs

In [82]:
vectorstore = Redis.from_existing_index(
    embedding=embedder, 
    index_name="rag-redis-demo", 
    schema="schema.yml", 
    redis_url="redis://localhost:6379",
)

# search_type="similarity" -> KNN
# search_type="mmr" -> MMR
# retriever = vectorstore.as_retriever(search_type="similarity")

retriever = CustomRetriever(vectorstore=vectorstore)

### Template

In [83]:
template = """
Use the following pieces of context from the sustainability report
to answer the question. Do not make up an answer if there is no
context provided to help answer it.

Context:
---------
{context}

---------
Question: {question}
---------

Answer:
"""


prompt = ChatPromptTemplate.from_template(template)

### Chain

In [94]:
llm = ChatOpenAI(
    model_name="gpt-3.5-turbo", 
    openai_api_key=os.getenv("OPENAI_API_KEY"), 
    temperature=0
)

In [85]:
class Question(BaseModel):
    __root__: str

In [86]:
chain = (
    RunnableParallel({"context": retriever, "question": RunnablePassthrough()})
    | prompt
    | llm
    | StrOutputParser()
).with_types(input_type=Question)

### Chat

In [87]:
chain.invoke("What is the scope 1 emissions?")
# chain.invoke("what is scope 1 emissions")

'The scope 1 emissions for Novo Nordisk is 7.1 thousand tonnes of CO2.'