In [None]:
# Reference tutorial (RAG-based Q&A with Microsoft Phi-3): https://www.pragnakalp.com/leverage-phi-3-exploring-rag-based-qna-with-microsofts-phi-3/

In [None]:
!pip install torch
!pip install transformers
!pip install langchain chromadb pypdf openai sentence-transformers accelerate
!pip install rapidocr-onnxruntime

In [None]:
from threading import Thread

import torch
from langchain import HuggingFacePipeline
from langchain.chains import RetrievalQA
from langchain.chains.question_answering import load_qa_chain
from langchain.document_loaders import PyPDFLoader, PyMuPDFLoader
from langchain.embeddings import HuggingFaceEmbeddings
from langchain.prompts import PromptTemplate
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain.vectorstores import Chroma, FAISS
from langchain.vectorstores.base import VectorStoreRetriever
from langchain.vectorstores.utils import DistanceStrategy
from transformers import AutoTokenizer, AutoModelForCausalLM, pipeline
from transformers import TextIteratorStreamer
# Embedding model for the retriever. No model_name is passed, so HuggingFaceEmbeddings
# uses LangChain's default sentence-transformers model (TODO confirm which one the
# installed version defaults to). Fall back to CPU when no CUDA device is present
# instead of failing outright; the LLM below already adapts via device_map='auto'.
EMBED_DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
model_kwargs = {'device': EMBED_DEVICE}
embeddings = HuggingFaceEmbeddings(model_kwargs=model_kwargs)

MODEL_ID = "microsoft/Phi-3-mini-128k-instruct"
MAX_NEW_TOKENS = 600  # upper bound on generated answer length, in tokens

tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)
# trust_remote_code=True allows transformers to run modeling code shipped with the model repo.
model = AutoModelForCausalLM.from_pretrained(
    MODEL_ID, device_map='auto', torch_dtype="auto", trust_remote_code=True,
)

# A single streamer shared by every generation: the pipeline pushes decoded text into it
# and ask() drains it. skip_prompt drops the echoed prompt; timeout (seconds) bounds how
# long a reader waits for the next piece of text before giving up.
streamer = TextIteratorStreamer(
    tokenizer=tokenizer, skip_prompt=True, skip_special_tokens=True, timeout=300.0
)
pipe = pipeline(
    "text-generation",
    model=model,
    tokenizer=tokenizer,
    max_new_tokens=MAX_NEW_TOKENS,
    streamer=streamer,
)
llm = HuggingFacePipeline(pipeline=pipe)

In [None]:
# Source PDF(s) to index. A 2024-05-15 note planned for several files, so every path in
# pdf_links is loaded and the pages are pooled before chunking.
#pdf_link = "Weekly Report KYD210.pdf"
pdf_link = "2106.09685v2.pdf"
pdf_links = [pdf_link]

# extract_images=True also OCRs images embedded in the PDF. Observed on Windows: this
# always logs "Windows platform detected, try to use DirectML as primary provider" and
# runs very slowly; with extract_images=False the problem did not always occur.
EXTRACT_IMAGES = True

# load() returns one Document per page. load_and_split() would already chunk the pages
# with a default splitter (with overlap), making the splitter configured below nearly a
# no-op and leaving its start_index metadata relative to those pre-split fragments.
pages = []
for path in pdf_links:
    loader = PyPDFLoader(path, extract_images=EXTRACT_IMAGES)
    pages.extend(loader.load())

# Split pages into chunks of at most 4000 characters with no overlap; add_start_index
# records each chunk's character offset within its page.
text_splitter = RecursiveCharacterTextSplitter(
    chunk_size=4000,
    chunk_overlap=0,
    length_function=len,
    add_start_index=True,
)
chunks = text_splitter.split_documents(pages)
assert chunks, "no text was extracted from the PDF(s); check the paths / OCR setting"

In [None]:
# Embed every chunk into an in-memory FAISS index scored by dot product, then expose it
# as a retriever returning the k=2 nearest chunks for each query.
vectorstore = FAISS.from_documents(
    documents=chunks,
    embedding=embeddings,
    distance_strategy=DistanceStrategy.DOT_PRODUCT,
)
retrieval_settings = {"k": 2}
retriever = VectorStoreRetriever(
    search_kwargs=retrieval_settings,
    vectorstore=vectorstore,
)

In [None]:
# Store data into database
#db=Chroma.from_documents(chunks,embedding=embeddings,persist_directory="test_index")
#db.persist()

In [None]:
# Load the database
#vectordb = Chroma(persist_directory="test_index", embedding_function = embeddings)

# Load the retriever
#retriever = vectordb.as_retriever(search_kwargs = {"k" : 3})

In [None]:
# Phi-3 chat-format prompt: a system turn restricting answers to the supplied context,
# a user turn carrying the retrieved {context} and the {question}, and a trailing
# <|assistant|> tag that cues the model to reply.
qna_prompt_template = "\n".join(
    [
        "<|system|>",
        'You have been provided with the context and a question, try to find out the answer to the question only using the context information. If the answer to the question is not found within the context, return "I dont know" as the response.<|end|>',
        "<|user|>",
        "Context:",
        "{context}",
        "",
        "Question: {question}<|end|>",
        "<|assistant|>",
    ]
)
PROMPT = PromptTemplate(
    input_variables=["context", "question"],
    template=qna_prompt_template,
)

# RetrievalQA fetches documents through `retriever` and "stuffs" them into PROMPT's
# {context} slot before calling the LLM. (Retriever-less alternative:
# load_qa_chain(llm, chain_type="stuff", prompt=PROMPT).)
chain = RetrievalQA.from_chain_type(
    llm=llm,
    chain_type="stuff",
    retriever=retriever,
    chain_type_kwargs={"prompt": PROMPT},
)


In [None]:
def ask(question, qa_chain=None, token_streamer=None):
    """Answer ``question`` with the retrieval QA chain and return the streamed reply.

    The chain runs on a worker thread because generation pushes its text into a
    streamer that the calling thread drains concurrently. The streamed pieces are
    concatenated and stripped of surrounding whitespace.

    Args:
        question: The user's question, passed to the chain as ``{"query": question}``.
        qa_chain: Object with an ``invoke(input=...)`` method. Defaults to the
            notebook-level ``chain``.
        token_streamer: Iterable yielding text pieces produced while ``qa_chain``
            runs. Defaults to the notebook-level ``streamer`` attached to the pipeline.

    Returns:
        The concatenated streamed text, with leading/trailing whitespace removed.

    Raises:
        Exception: Whatever ``qa_chain.invoke`` raised on the worker thread is
            re-raised here rather than being lost inside the thread.
    """
    if qa_chain is None:
        qa_chain = chain
    if token_streamer is None:
        token_streamer = streamer

    worker_errors = []

    def _run_chain():
        # An exception escaping a Thread target is only printed to stderr, so record it
        # for the caller instead.
        try:
            qa_chain.invoke(input={"query": question})
        except Exception as exc:
            worker_errors.append(exc)

    worker = Thread(target=_run_chain)
    worker.start()

    pieces = []
    try:
        for piece in token_streamer:
            pieces.append(piece)
    except Exception:
        # The streamer gave up (e.g. its timeout expired). If the chain itself failed,
        # that failure is the real cause, so surface it instead.
        if worker_errors:
            raise worker_errors[0]
        raise

    # Wait for the chain to finish so a late failure is reported, not silently dropped.
    worker.join()
    if worker_errors:
        raise worker_errors[0]
    return "".join(pieces).strip()

In [None]:
# A utility function for answer generation
#def ask(question):
#   context = retriever.get_relevant_documents(question)
#   answer = (chain({"input_documents": context, "question": question}, return_only_outputs=True))['output_text']
#   return answer

In [None]:
# Read one question from the user, answer it, and print the result. Only the text after
# the last "<|assistant|>" tag is kept (the whole string when the tag is absent), in case
# the prompt was echoed back.
user_question = input("User: ")
raw_answer = ask(user_question)
answer = raw_answer.rpartition("<|assistant|>")[2].strip()
print("Answer:", answer)