In [1]:
# https://www.pragnakalp.com/leverage-phi-3-exploring-rag-based-qna-with-microsofts-phi-3/

In [2]:
!pip install torch
!pip install transformers
!pip install langchain chromadb pypdf openai sentence-transformers accelerate
!pip install rapidocr-onnxruntime



In [3]:
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain.document_loaders import PyPDFLoader
from langchain.embeddings import HuggingFaceEmbeddings
from langchain.vectorstores import Chroma
from transformers import AutoTokenizer, AutoModelForCausalLM, pipeline
from langchain import HuggingFacePipeline
from langchain.chains.question_answering import load_qa_chain
from langchain.prompts import PromptTemplate

model_kwargs = {'device': 'cuda'}
embeddings = HuggingFaceEmbeddings(model_kwargs=model_kwargs)

tokenizer = AutoTokenizer.from_pretrained("microsoft/Phi-3-mini-128k-instruct")
model = AutoModelForCausalLM.from_pretrained("microsoft/Phi-3-mini-128k-instruct", device_map='auto', torch_dtype="auto", trust_remote_code=True,)

pipe = pipeline("text-generation", model=model, tokenizer=tokenizer, max_new_tokens=600)
llm = HuggingFacePipeline(pipeline=pipe)

Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.
`flash-attention` package not found, consider installing for better performance: No module named 'flash_attn'.
Current `flash-attenton` does not support `window_size`. Either upgrade or use `attn_implementation='eager'`.


Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]

In [4]:
# Load the PDF file
pdf_link = "Weekly Report KYD210.pdf"
#loader = PyPDFLoader(pdf_link, extract_images=False)
loader = PyPDFLoader(pdf_link, extract_images=True)
pages = loader.load_and_split()


# Split data into chunks
text_splitter = RecursiveCharacterTextSplitter(
   chunk_size = 4000,
   chunk_overlap  = 20,
   length_function = len,
   add_start_index = True,
)
chunks = text_splitter.split_documents(pages)

Windows platform detected, try to use DirectML as primary provider
Windows platform detected, try to use DirectML as primary provider
Windows platform detected, try to use DirectML as primary provider
Windows platform detected, try to use DirectML as primary provider
Windows platform detected, try to use DirectML as primary provider
Windows platform detected, try to use DirectML as primary provider
Windows platform detected, try to use DirectML as primary provider
Windows platform detected, try to use DirectML as primary provider
Windows platform detected, try to use DirectML as primary provider
Windows platform detected, try to use DirectML as primary provider
Windows platform detected, try to use DirectML as primary provider
Windows platform detected, try to use DirectML as primary provider
Windows platform detected, try to use DirectML as primary provider
Windows platform detected, try to use DirectML as primary provider
Windows platform detected, try to use DirectML as primary prov

In [5]:
# Store data into database
db=Chroma.from_documents(chunks,embedding=embeddings,persist_directory="test_index")
db.persist()

In [6]:
# Load the database
vectordb = Chroma(persist_directory="test_index", embedding_function = embeddings)

# Load the retriver
retriever = vectordb.as_retriever(search_kwargs = {"k" : 3})

In [7]:
# Define the custom prompt template suitable for the Phi-3 model
qna_prompt_template="""<|system|>
You have been provided with the context and a question, try to find out the answer to the question only using the context information. If the answer to the question is not found within the context, return "I dont know" as the response.<|end|>
<|user|>
Context:
{context}

Question: {question}<|end|>
<|assistant|>"""
PROMPT = PromptTemplate(
   template=qna_prompt_template, input_variables=["context", "question"]
)

# Define the QNA chain
chain = load_qa_chain(llm, chain_type="stuff", prompt=PROMPT)

In [8]:
# A utility function for answer generation
def ask(question):
   context = retriever.get_relevant_documents(question)
   print(context)

   answer = (chain({"input_documents": context, "question": question}, return_only_outputs=True))['output_text']
   return answer

In [14]:
# Take the user input and call the function to generate output
user_question = input("User: ")
answer = ask(user_question)
answer = (answer.split("<|assistant|>")[-1]).strip()
print("Answer:", answer)

[Document(page_content='CONFIDENTIAL\nCTE Operation Update (2024 -01 to 2024 -05)\nWorking ItemDT \n6 UnitedDT \nSTCNB \nBmornNB \nCLEVONB \nGreat WallNB \nHenaNB\nJWIPCTotal\n(hours)\nBuilding Preload 17 9 7 3 4 10 3 53\nMeeting 2 13 7 11 33\nTraining & Q&A 55 28 25 9 2 21 140\nWorking Hours 74 50 39 12 6 42 3 226\nPreload Released 16 7 4 2 2 8 3 42', metadata={'page': 2, 'source': 'Weekly Report KYD210.pdf', 'start_index': 0}), Document(page_content="CONFIDENTIAL\n 14\n2024Win11 24H2 Refreshschedule\nMSFT24H2Timeline\nUCS Testing\n24H2 RTM\nGA +CI\n4/10\nH2\nschedule to ODM\ntversion\nreport\nregression test start\nrefresh\nfrozen\nconfirm\nereport\nrelease\nlist\nlist\ntFinish\nimage refresh start\n24H2\nlist\nApp\nTest\nTest Start\nmode\nUpgrade/App\ntemplate\nReq./Check\nRelease\nRefresh\nle2\nWav\nbd\nApp\nRefresh \nonly)\nRefresh\na)\nPreload \nAPP\nApp & MRD frozen\nWindows\na) App pass QT testing for image refresh\nImage Refresh Start for ALL OS\nApp upgrade T\nPilot-run SCL r