In [14]:
import sys

sys.path.append("../")
import streamlit as st
from dotenv import load_dotenv
from PyPDF2 import PdfReader
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain.embeddings import HuggingFaceInstructEmbeddings
from langchain.vectorstores import FAISS
from langchain.memory import ConversationBufferMemory
from langchain.chains import ConversationalRetrievalChain, RetrievalQA
from models.sambanova_endpoint import SambaNovaEndpoint
from langchain.prompts import PromptTemplate
import glob

## Configuration: provide these values to access the PDF data and the LLM endpoint

In [15]:
import os

# PROVIDE: Directory containing the PDF files to index.
# All settings are read from environment variables so that credentials are
# never hard-coded into (or committed with) the notebook. `load_dotenv` is
# imported in the first cell: call `load_dotenv()` before this cell if you keep
# these values in a .env file. A value of None means the variable is not set.
input_data_loc = os.environ.get("PDF_INPUT_DIR")

# PROVIDE: API info for the SambaNova endpoint.
base_url = os.environ.get("SAMBANOVA_BASE_URL", "https://sjc1-demo1.sambanova.net")
project_id = os.environ.get("SAMBANOVA_PROJECT_ID")
endpoint_id = os.environ.get("SAMBANOVA_ENDPOINT_ID")
api_key = os.environ.get("SAMBANOVA_API_KEY")

## Functions

In [16]:
## Extract text and metadata from pdf
def get_pdf_text_and_metadata(pdf_doc):
    """Extract the text of every page of a PDF, plus per-page metadata.

    Args:
        pdf_doc: Path (or other source) accepted by ``PdfReader``; it is also
            stored verbatim as the ``"filename"`` metadata value.

    Returns:
        A ``(text, metadata)`` pair of equal-length lists: ``text[i]`` is the
        extracted text of page ``i`` and ``metadata[i]`` is
        ``{"filename": pdf_doc, "page": i}`` with a 0-based page index.
    """
    pdf_reader = PdfReader(pdf_doc)
    text = []
    metadata = []
    # enumerate yields the 0-based page index directly, with no lookup back
    # into the reader (get_page_number returns -1 when a page is not found).
    for page_number, page in enumerate(pdf_reader.pages):
        # Defensive: guarantee a string for pages with no extractable text so
        # the downstream text splitter never receives None.
        text.append(page.extract_text() or "")
        metadata.append({"filename": pdf_doc, "page": page_number})
    return text, metadata


# Read the pdf files and extract text + metadata
def get_data_for_splitting(pdf_docs):
    """Gather page texts and page metadata from several PDF files.

    Args:
        pdf_docs: Iterable of PDF paths, each passed to
            ``get_pdf_text_and_metadata``.

    Returns:
        Two flat lists ``(texts, metadatas)`` covering every page of every
        file, ordered by file first and by page within each file.
    """
    per_file = [get_pdf_text_and_metadata(pdf_path) for pdf_path in pdf_docs]
    all_texts = [page_text for texts, _ in per_file for page_text in texts]
    all_metadatas = [page_meta for _, metas in per_file for page_meta in metas]
    return all_texts, all_metadatas


# Chunk the extracted data
def get_text_chunks(text, metadata, chunk_size=1000, chunk_overlap=200):
    """Split page texts into overlapping LangChain documents.

    Args:
        text: List of page strings (as produced by ``get_data_for_splitting``).
        metadata: List of metadata dicts aligned with ``text``; each chunk
            carries the metadata of the page it was cut from.
        chunk_size: Target chunk length in characters (measured with ``len``).
        chunk_overlap: Characters shared between consecutive chunks.

    Returns:
        The documents produced by ``RecursiveCharacterTextSplitter.create_documents``.
    """
    text_splitter = RecursiveCharacterTextSplitter(
        chunk_size=chunk_size, chunk_overlap=chunk_overlap, length_function=len
    )
    return text_splitter.create_documents(text, metadata)


def get_vectorstore(text_chunks):
    """Embed document chunks with BAAI/bge-large-en and index them in FAISS.

    Args:
        text_chunks: Documents to index (the output of ``get_text_chunks``).

    Returns:
        The FAISS vector store built by ``FAISS.from_documents``.
    """
    bge_embeddings = HuggingFaceInstructEmbeddings(
        model_name="BAAI/bge-large-en",
        # Passages are embedded without any instruction; only queries get the
        # retrieval prefix below.
        embed_instruction="",
        query_instruction="Represent this sentence for searching relevant passages: ",
        # Forwarded to the embedding model's encode call.
        encode_kwargs=dict(normalize_embeddings=True),
    )
    return FAISS.from_documents(documents=text_chunks, embedding=bge_embeddings)


def get_qa_retrieval_chain(vectorstore, llm=None):
    """Build a single-turn RetrievalQA chain over ``vectorstore``.

    Args:
        vectorstore: Vector store whose default ``as_retriever()`` supplies the
            context documents.
        llm: Optional pre-configured LLM (e.g. a ``SambaNovaEndpoint`` built with
            explicit base_url/project_id/endpoint_id/api_key). When omitted, a
            ``SambaNovaEndpoint`` is created with deterministic decoding settings
            but no explicit connection settings. NOTE(review): presumably it then
            falls back to its own defaults/environment; pass ``llm`` to be sure.

    Returns:
        A ``RetrievalQA`` chain that reads its input under ``"question"`` and
        also returns the retrieved source documents.
    """
    if llm is None:
        llm = SambaNovaEndpoint(model_kwargs={"do_sample": False, "temperature": 0.0})

    qa_chain = RetrievalQA.from_llm(
        llm=llm,
        retriever=vectorstore.as_retriever(),
        return_source_documents=True,
        input_key="question",
    )
    return qa_chain


def get_conversation_chain(vectorstore, llm=None):
    """Build a multi-turn ConversationalRetrievalChain with buffer memory.

    Args:
        vectorstore: Vector store whose default ``as_retriever()`` supplies the
            context documents.
        llm: Optional pre-configured LLM. When omitted, a ``SambaNovaEndpoint``
            is created with deterministic decoding settings but no explicit
            connection settings. NOTE(review): presumably it then falls back to
            its own defaults/environment; pass ``llm`` to be sure.

    Returns:
        A ``ConversationalRetrievalChain`` that keeps the conversation in a
        ``ConversationBufferMemory``.
    """
    if llm is None:
        llm = SambaNovaEndpoint(model_kwargs={"do_sample": False, "temperature": 0.0})

    # History is stored under "chat_history" as message objects
    # (return_messages=True) rather than one concatenated string.
    memory = ConversationBufferMemory(memory_key="chat_history", return_messages=True)
    conversation_chain = ConversationalRetrievalChain.from_llm(
        llm=llm, retriever=vectorstore.as_retriever(), memory=memory
    )
    return conversation_chain

## Read the PDF files and extract their text and metadata

In [17]:
# Sort the matches: glob's result order depends on the file system, which would
# otherwise make page/chunk order differ between machines and runs.
pdf_docs = sorted(glob.glob(f"{input_data_loc}/*.pdf"))
if not pdf_docs:
    # Fail here with a clear message rather than building an empty index later.
    raise FileNotFoundError(
        f"No PDF files found in {input_data_loc!r}; "
        "set input_data_loc to a directory containing .pdf files."
    )

# get pdf text
raw_text, meta_data = get_data_for_splitting(pdf_docs)

## Chunk the text

In [18]:
# Split every page's text into overlapping chunks, carrying its page metadata along.
text_chunks = get_text_chunks(text=raw_text, metadata=meta_data)

## Create a vector store (for example: FAISS)

In [19]:
# Embed the chunks and index them in FAISS; the log lines below come from
# loading the embedding model.
vectorstore = get_vectorstore(text_chunks=text_chunks)

load INSTRUCTOR_Transformer
max_seq_length  512


## Initialize the large language model

- **Note**: the API info must be updated to point to the customer's endpoint.

In [20]:
# Connection settings come from the configuration cell near the top.
endpoint_settings = {
    "base_url": base_url,
    "project_id": project_id,
    "endpoint_id": endpoint_id,
    "api_key": api_key,
}
# No sampling and temperature 0, intended to make answers deterministic.
llm = SambaNovaEndpoint(
    **endpoint_settings,
    model_kwargs={"do_sample": False, "temperature": 0.0},
)

## Initialize a RetrievalQA chain

In [21]:
# Retrieval settings: at most RETRIEVAL_TOP_K chunks are returned, and only
# those whose relevance score reaches RETRIEVAL_SCORE_THRESHOLD.
RETRIEVAL_SCORE_THRESHOLD = 0.5
RETRIEVAL_TOP_K = 4

retriever = vectorstore.as_retriever(
    search_type="similarity_score_threshold",
    search_kwargs={"score_threshold": RETRIEVAL_SCORE_THRESHOLD, "k": RETRIEVAL_TOP_K},
)

# Single-turn QA chain (for multi-turn chat with memory, see get_conversation_chain).
# It reads the query under "question", writes the reply under "answer" and also
# returns the retrieved chunks (return_source_documents=True).
retrieval_chain = RetrievalQA.from_llm(
    llm=llm,
    retriever=retriever,
    return_source_documents=True,
    input_key="question",
    output_key="answer",
)

## Custom prompt

In [22]:
custom_prompt_template = """Use the following pieces of context to answer the question at the end.
If the answer is not in the context, say that you don't know; don't try to make up an answer or provide an answer not extracted from the provided context.
Cross-check whether the answer is contained in the provided context. If not, then say "I do not have information regarding this."

{context}

Question: {question}
Helpful Answer:"""
CUSTOMPROMPT = PromptTemplate(
    template=custom_prompt_template, input_variables=["context", "question"]
)

## Inject custom prompt
# The prompt is swapped in after the chain was built, so nothing re-validates it.
# Check explicitly that the variable the retrieved documents are stuffed into is
# still an input of the new prompt. NOTE(review): this is "context" for chains
# built by RetrievalQA.from_llm; confirm for your LangChain version.
combine_chain = retrieval_chain.combine_documents_chain
assert combine_chain.document_variable_name in CUSTOMPROMPT.input_variables, (
    f"custom prompt is missing the '{combine_chain.document_variable_name}' input variable"
)
combine_chain.llm_chain.prompt = CUSTOMPROMPT

## Ask a question

In [23]:
user_question = "What is the input voltage range for LT8625S?"
chain_inputs = {"question": user_question}  # key matches the chain's input_key
response = retrieval_chain(chain_inputs)

In [24]:
# The generated answer can start with whitespace (see the output below), so
# strip it and put the separating space in the format string itself.
print(f'Response = {response["answer"].strip()}')
# The chain was built with return_source_documents=True; show which file/page
# each supporting chunk came from.
for source_doc in response.get("source_documents", []):
    print(f'  source: {source_doc.metadata.get("filename")} (page {source_doc.metadata.get("page")})')

Response = The LT8625S has an input voltage range of 2.7V to 18V.
