In [199]:
import os
import time
import openai
import tiktoken
import cohere
import chromadb
import tempfile
import google.generativeai as genai
from llama_index.llms import OpenAI, Gemini
from llama_index.memory import ChatMemoryBuffer
from llama_index import VectorStoreIndex, SimpleDirectoryReader, ServiceContext, StorageContext, PromptHelper, LLMPredictor, load_index_from_storage
from llama_index.embeddings import OpenAIEmbedding, GeminiEmbedding
from llama_index.vector_stores import ChromaVectorStore
from llama_index.indices.postprocessor import SentenceEmbeddingOptimizer, LLMRerank, CohereRerank, LongContextReorder
from llama_index import download_loader
from llama_index.text_splitter import TokenTextSplitter
from llama_index.node_parser import SimpleNodeParser
from flask import Flask, jsonify, flash, request, redirect, render_template, url_for, Response, stream_with_context



## Configs

In [304]:
GOOGLE_API_KEY = 'AIzaSyB4Aew8oVjBgPMZlskdhdmQs27DuyNBDAY'
os.environ["GOOGLE_API_KEY"]  = GOOGLE_API_KEY

In [3]:
genai.configure(
    api_key=GOOGLE_API_KEY,
    client_options={"api_endpoint": "generativelanguage.googleapis.com"},
)

In [4]:
for m in genai.list_models():
    if "generateContent" in m.supported_generation_methods:
        print(m.name)

models/gemini-pro
models/gemini-pro-vision


In [5]:
CHROMADB_HOST = "localhost"
COHERE_RERANK_KEY = 'p8K3ASZaficAE1YlOh9dAY3x5Tkxa8sOmCRtJOtP'
ALLOWED_EXTENSIONS = {'txt', 'htm', 'html', 'pdf', 'doc', 'docx', 'ppt', 'pptx', 'csv'}


## Helpers

To start the chroma server, run the following in terminal:

`chroma run --path ./src/vector_db`

In [168]:
def generate_vector_embedding(index_name, temp_dir):
        
    try:
        # initialize client, setting path to save data
        # db = chromadb.PersistentClient(path="./chroma_db")
        print("Connecting to Chroma database...")
        db = chromadb.HttpClient(host=CHROMADB_HOST, port=8000)
    except:
        return {'statusCode': 400, 'status': 'Could not connect to chroma database'}

    try:
        # create collection
        print("Creating vector embeddings......")
        print("Index name: ", index_name)
        start_time = time.time()
        chroma_collection = db.get_or_create_collection(
            name=index_name,
            metadata={"hnsw:space": "cosine"} # default: L2; used before: ip
            )
    except Exception as e:
        print("Error : : :", e)
        return {'statusCode': 400, 'status': 'A knowledge base with the same name already exists'}

    vector_store = ChromaVectorStore(chroma_collection=chroma_collection)

    # setup our storage (vector db)
    storage_context = StorageContext.from_defaults(
        vector_store=vector_store
    )

    llm = Gemini(api_key=GOOGLE_API_KEY, model='models/gemini-pro', temperature=0.6)
    embed_model = GeminiEmbedding(api_key=GOOGLE_API_KEY)
    # node_parser = SimpleNodeParser.from_defaults(
    #     text_splitter=TokenTextSplitter(chunk_size=1024, chunk_overlap=20)
    #     )

    service_context = ServiceContext.from_defaults(
        llm=llm,
        embed_model=embed_model,
        chunk_size=1024,
        chunk_overlap=20
        )

    documents = SimpleDirectoryReader(input_dir=temp_dir).load_data()
    
    index = VectorStoreIndex.from_documents(
        documents,
        storage_context=storage_context,
        service_context=service_context
    )

    # temp_dir.cleanup() # delete document temp dir

    print(f"Vector embeddings created in {time.time() - start_time} seconds.")

    response = {
        'statusCode': 200,
        'status': 'Chroma embedding complete',
    }
    return response

In [336]:
def get_index_from_vector_db(index_name):
    
    try:
        # initialize client
        db = chromadb.HttpClient(host=CHROMADB_HOST, port=8000)
    except Exception as e:
        print('<<< get_index_from_vector_db() >>> Could not connect to database!\n', e)
        return None, None
    
    # get collection and embedding size
    try:
        chroma_collection = db.get_collection(index_name)
        doc_size = chroma_collection.count()
        print('Computing knowledge base size...', doc_size)
    except Exception as e:
        print(e)
        return None, None

    start_time = time.time()
    vector_store = ChromaVectorStore(chroma_collection=chroma_collection)
    storage_context = StorageContext.from_defaults(vector_store=vector_store)

    # Experimented settings for large docs versus small. Don't change except you have tested extensively!!!
    # context_window=65000 if doc_size>200 else 16384
    context_window=32000 if doc_size>300 else 2048
    embed_model = GeminiEmbedding(api_key=GOOGLE_API_KEY)
    llm = Gemini(api_key=GOOGLE_API_KEY, model='models/gemini-pro', max_tokens=4096, temperature=0.3)
    print_msg = "Using Gemini Pro..."
    print(print_msg)

    # node_parser = SimpleNodeParser.from_defaults(
    #     text_splitter=TokenTextSplitter(chunk_size=1024, chunk_overlap=20)
    #     )

    service_context = ServiceContext.from_defaults(
        llm=llm,
        context_window=context_window, 
        embed_model=embed_model,
        chunk_size=1024,
        chunk_overlap=20
    )

    print('Retrieving knowledge base index from ChromaDB...')            
    index = VectorStoreIndex.from_vector_store(
        vector_store=vector_store, 
        storage_context=storage_context,
        service_context=service_context
    )

    print(f'Index retrieved from ChromaDB in {time.time() - start_time} seconds.')
    return index, doc_size

In [263]:
def postprocessor_args(doc_size):
    if doc_size<500:
        return None
    
    print('Optimising context information...')
    
    # fastest postprocessor
    cohere_rerank = CohereRerank(api_key=COHERE_RERANK_KEY, top_n=20)

    # slowest postprocessor using GPT-3.5, fairly fast using MockLLM
    embed_model = GeminiEmbedding(api_key=GOOGLE_API_KEY)
    service_context = ServiceContext.from_defaults(llm=None, embed_model=embed_model, chunk_size=256, chunk_overlap=20) # use llama_index default MockLLM (faster)
    # service_context = ServiceContext.from_defaults(llm=OpenAI(temperature=0, model="gpt-3.5-turbo-1106",), chunk_size=512, chunk_overlap=0)
    rank_postprocessor = LLMRerank(
        choice_batch_size=10, top_n=100,
        service_context=service_context,
        parse_choice_select_answer_fn=parse_choice_select_answer_fn
    )
    
    # node postprocessors run in the specified order
    node_postprocessors = [
        rank_postprocessor,
        cohere_rerank,
    ]

    return node_postprocessors

In [9]:
def parse_choice_select_answer_fn(
    answer: str, num_choices: int, raise_error: bool = False
):
    """Default parse choice select answer function."""
    answer_lines = answer.split("\n")
    # print(answer_lines)
    answer_nums = []
    answer_relevances = []
    for answer_line in answer_lines:
        line_tokens = answer_line.split(",")
        if len(line_tokens) != 2:
            if not raise_error:
                continue
            else:
                raise ValueError(
                    f"Invalid answer line: {answer_line}. "
                    "Answer line must be of the form: "
                    "answer_num: <int>, answer_relevance: <float>"
                )
        if len(line_tokens[0].split(":"))>1 and line_tokens[0].split(":")[1].strip().isdigit():
            answer_num = int(line_tokens[0].split(":")[1].strip())
            if answer_num > num_choices:
                continue
            answer_nums.append(answer_num)
            answer_relevances.append(float(line_tokens[1].split(":")[1].strip()))
    # print(answer_nums)
    return answer_nums, answer_relevances

In [229]:
def answer_query_stream(query, index_name, chat_history):

    index, doc_size = get_index_from_vector_db(index_name)
    prompt_header = prompt_style()

    if index is None:
        response = "I'm sorry I couldn't find any document in your knowledge base. Please add documents to your knowledge base and try again."
        return response
    else:
        node_postprocessors = postprocessor_args(doc_size)
        # similarity_top_k = 50 if doc_size>500 else 10
        similarity_top_k = 300 if doc_size>500 else 15 if doc_size>200 else 20
        chat_engine = index.as_chat_engine(chat_mode="context", 
                                            memory=chat_history,
                                            system_prompt=prompt_header, 
                                            similarity_top_k=similarity_top_k,
                                            verbose=True, 
                                            # streaming=True,
                                            function_call="query_engine_tool",
                                            node_postprocessors=node_postprocessors
                                            )

        message_body = f"""\nUse the tool to answer:\n{query}\n"""
        response = chat_engine.chat(message_body)
        # print(get_formatted_sources(response) if response.source_nodes else None)
        
        if response is None:
            print("Index retrieved but cannot stream response...")
            chat_response = "I'm sorry I couldn't find an answer to the requested information in your knowledge base. Please rephrase your question and try again."
            # for token in chat_response.split():
            #     print(token, end=" ")
            #     yield f"""{token} """
            return chat_response
        else:
            print('Starting response stream...\n...........................\n...........................')
            # return response.response
            # for token in response.response_gen:
            #     print(token, end="")
            #     yield f"""{token}"""
            print(response.response)

In [225]:

def get_formatted_sources(response, length=100, trim_text=True) -> str:
    """Get formatted sources text."""
    from llama_index.utils import truncate_text
    texts = []
    for source_node in response.source_nodes:
        fmt_text_chunk = source_node.node.get_content()
        if trim_text:
            fmt_text_chunk = truncate_text(fmt_text_chunk, length)
        # node_id = source_node.node.node_id or "None"
        node_id = source_node.node.metadata['page_label'] or "None"
        source_text = f"> Source (Page no: {node_id}): {fmt_text_chunk}"
        texts.append(source_text)
    return "\n\n".join(texts)

In [368]:
def prompt_style(): 

    prompt_header = f"""Your name is Alpha. 
    You are a helpful and friendly Q&A bot, a highly intelligent system that 
    answers my questions based on the information I provide above each question. 
    Only use the information in the knowledge base I have provided, and if the information can not be found, decline to answer it. 
    """

    return prompt_header   

## Embedding

In [192]:
project_name = "Ebible"
index_name = project_name.lower() + '_embeddings'

In [193]:
generate_vector_embedding(index_name, './temp')

Connecting to Chroma database...
Creating vector embeddings......
Index name:  ebible_embeddings
Vector embeddings created in 756.6599473953247 seconds.


{'statusCode': 200, 'status': 'Chroma embedding complete'}

## Retrieval and Chat

In [389]:
def message_thread(memory, reset=None):
    if reset:
        new_conversation_state = init_chat_history()
        return new_conversation_state
    return memory

def init_chat_history():
    new_conversation_state = ChatMemoryBuffer.from_defaults(token_limit=50000)
    return new_conversation_state


chat_history = init_chat_history()

In [391]:
query = "State the verses where Demas is mentioned"

In [392]:
answer_query_stream(query, index_name, chat_history)
    # for token in response.response_gen:
    # print(token, end="")

Computing knowledge base size... 1607
Using Gemini Pro...
Retrieving knowledge base index from ChromaDB...
Index retrieved from ChromaDB in 2.4284427165985107 seconds.
Optimising context information...
LLM is explicitly disabled. Using MockLLM.
Starting response stream...
...........................
...........................
This context does not mention anything about Demas, so I cannot extract the requested data from the provided context.
