In [None]:
import os
import time

from pymilvus import connections
from pymilvus import FieldSchema
from pymilvus import CollectionSchema
from pymilvus import DataType
from pymilvus import Collection
from pymilvus import utility

from langchain.chains.question_answering import load_qa_chain
from langchain.embeddings import HuggingFaceEmbeddings
from langchain.document_loaders import PyPDFLoader
from langchain.vectorstores import Milvus
from langchain.text_splitter import RecursiveCharacterTextSplitter, CharacterTextSplitter

from dotenv import load_dotenv
import genai.extensions.langchain
from genai.extensions.langchain import LangChainInterface
from genai.schemas import GenerateParams
from genai import Credentials
from genai import Model
from genai import PromptPattern

from langchain.vectorstores import FAISS

In [None]:
# Load variables from a local .env file (if present) into the process
# environment, then read each setting. Every value is a string, or None when the
# variable is not set.
load_dotenv()

# Generative-AI service credentials (consumed by `Credentials` further down).
api_key = os.getenv("GENAI_KEY")
api_endpoint = os.getenv("GENAI_API")

# Milvus connection / collection settings.
# NOTE(review): DIMENSION, COUNT and MAX are read but not used by the cells
# visible here; the schema hardcodes dim=768.
COLLECTION_NAME = os.getenv("COLLECTION_NAME")
DIMENSION = os.getenv("DIMENSION")
COUNT = os.getenv("COUNT")
MAX = os.getenv("MAX")
MILVUS_HOST = os.getenv("MILVUS_HOST")
MILVUS_PORT = os.getenv("MILVUS_PORT")

In [None]:
# Chunking settings shared by both splitters: up to 1000 characters per chunk,
# with 150 characters of overlap between consecutive chunks.
chunk_size, chunk_overlap = 1000, 150
separator = "\n"

# The recursive splitter is the one loadPDF uses; the plain character splitter
# (splitting on `separator`) is not used by the cells shown here.
r_splitter = RecursiveCharacterTextSplitter(
    chunk_size=chunk_size,
    chunk_overlap=chunk_overlap,
)
c_splitter = CharacterTextSplitter(
    separator=separator,
    chunk_size=chunk_size,
    chunk_overlap=chunk_overlap,
)

In [None]:
# Re-read .env and the API key, then build a first set of credentials and
# greedy-decoding parameters.
# NOTE(review): both `creds` and `params` are reassigned in the LLM cell further
# down (with the endpoint and token limits), so these values appear unused.
load_dotenv()
api_key = os.getenv("GENAI_KEY")
creds = Credentials(api_key)
params = GenerateParams(decoding_method="greedy")

In [None]:
# Connect to Milvus and (re)create the collection that storeToMilvus writes to.
connections.connect(host=MILVUS_HOST, port=MILVUS_PORT)

# Start from a clean slate: an existing collection with this name is dropped,
# along with all of its data.
if utility.has_collection(COLLECTION_NAME):
    utility.drop_collection(COLLECTION_NAME)

# Schema: caller-supplied INT64 primary key (auto_id=False), the chunk text as
# VARCHAR, and a 768-dimensional float vector.
# NOTE(review): dim=768 is hardcoded and must match the embedding model output.
id_field = FieldSchema(
    name="id", dtype=DataType.INT64, description="Ids", is_primary=True, auto_id=False
)
content_field = FieldSchema(
    name="content", dtype=DataType.VARCHAR, description="Content texts", max_length=768 * 8
)
embedding_field = FieldSchema(
    name="embedding", dtype=DataType.FLOAT_VECTOR, description="Embedding vectors", dim=768
)
fields = [id_field, content_field, embedding_field]

schema = CollectionSchema(fields=fields, description="content collection")
collection = Collection(name=COLLECTION_NAME, schema=schema)

# IVF_FLAT index over the embedding field, compared with L2 distance.
index_params = dict(
    index_type="IVF_FLAT",
    metric_type="L2",
    params={"nlist": 1024},
)
collection.create_index(field_name="embedding", index_params=index_params)

In [None]:
# Sentence-transformer embeddings used for both the manual Milvus inserts and the
# LangChain vector store. The model is named explicitly instead of relying on the
# library default, so an upgrade cannot silently change the vector size.
# NOTE(review): all-mpnet-base-v2 yields 768-dim vectors, matching dim=768 in the
# collection schema - keep these in sync if the model is changed.
embeddings = HuggingFaceEmbeddings(model_name="sentence-transformers/all-mpnet-base-v2")

In [None]:
def loadPDF(filename):
    """Load a PDF and split its pages into overlapping text chunks.

    Pages are loaded with ``PyPDFLoader`` and split with the module-level
    ``r_splitter``. As a quick sanity check, the second chunk's text and
    metadata are printed when the document produced at least two chunks.

    Args:
        filename: Path to the PDF file.

    Returns:
        The list of chunk documents from ``r_splitter.split_documents``
        (possibly empty).
    """
    pages = PyPDFLoader(filename).load()
    docs = r_splitter.split_documents(pages)

    # Preview one chunk. Guarded because a short PDF can yield fewer than two
    # chunks, and docs[1] would then raise IndexError.
    if len(docs) > 1:
        print(docs[1].page_content)
        print(docs[1].metadata)
    return docs

In [None]:
def storeToMilvus(docs, start_id=0):
    """Embed document chunks and insert them into the Milvus ``collection``.

    Three parallel columns are built to match the collection schema: an integer
    id, the full chunk text, and its embedding vector. Text longer than 768
    characters is cut to its first 766 characters plus ".." before embedding;
    the stored text itself is not truncated.

    Args:
        docs: Sequence of chunk documents (objects with ``page_content``).
            ``None`` or an empty sequence inserts nothing.
        start_id: Primary-key value for the first chunk; later chunks get
            consecutive ids. When inserting several batches into the same
            collection, pass the value returned by the previous call so ids do
            not repeat.

    Returns:
        The next unused id, i.e. ``start_id + len(docs)``.
    """
    started = time.time()
    ids, contents, vectors = [], [], []
    for offset, doc in enumerate(docs or []):
        text = doc.page_content
        ids.append(start_id + offset)
        contents.append(text)
        # Only the embedding input is shortened; the full text is stored.
        to_embed = text[:766] + ".." if len(text) > 768 else text
        vectors.append(embeddings.embed_query(to_embed))

    # Skip the round-trip entirely for an empty batch rather than sending
    # empty columns to Milvus.
    if ids:
        collection.insert([ids, contents, vectors])
    print("Duration: ", time.time() - started)
    return start_id + len(ids)

In [None]:
from pathlib import Path

# Load every PDF under menu/ (recursively, in a stable sorted order) and collect
# all chunks. They are then stored in ONE storeToMilvus call, so each chunk gets
# a distinct id; per-file calls would restart the ids at 0 for every file.
all_docs = []
for pdf_path in sorted(Path("menu").rglob("*.pdf")):
    print(pdf_path)
    # Use the real path: rglob also finds PDFs in sub-folders, which a
    # "menu/" + file-name path would point to the wrong location.
    all_docs.extend(loadPDF(str(pdf_path)))

if not all_docs:
    raise FileNotFoundError("No PDF chunks found under menu/")

storeToMilvus(all_docs)

# The vector-store cell below indexes `docs`; bind it to the chunks of every
# PDF rather than leaving it as the last file processed.
docs = all_docs

In [None]:
# Read the evaluation questions, one per line. The context manager closes the
# file. Trailing newlines are stripped and blank lines skipped, so neither ends
# up as (part of) a query sent to the retriever / LLM.
with open("sample.txt", "r", encoding="utf-8") as question_file:
    questions = [line.strip() for line in question_file if line.strip()]

print(questions)

# Searching

In [None]:
# Build a LangChain vector store over the chunks in `docs`, embedding them with
# `embeddings` and storing them in the Milvus server configured above.
# NOTE(review): this goes through LangChain, separately from the hand-built
# `collection`. It presumably uses LangChain's own default collection; confirm.
# NOTE(review): `docs` is whatever the loading cell left bound. Make sure it
# holds the chunks of every PDF, not only the last file loaded.
milvus_connection_args = {"host": MILVUS_HOST, "port": MILVUS_PORT}
db = Milvus.from_documents(
    docs,
    connection_args=milvus_connection_args,
    embedding=embeddings,
)

In [None]:
# Credentials built from the key and endpoint read in the configuration cell.
creds = Credentials(api_key, api_endpoint)

# Generation settings: greedy decoding, 15-300 new tokens, repetition penalty 2.
params = GenerateParams(
    repetition_penalty=2,
    min_new_tokens=15,
    max_new_tokens=300,
    decoding_method="greedy",
)

# LangChain LLM wrapper for the meta-llama/llama-2-13b model via the genai SDK.
llm = LangChainInterface(
    credentials=creds,
    params=params,
    model="meta-llama/llama-2-13b",
)

In [None]:
# Question-answering chain using the "stuff" chain type: the retrieved documents
# are handed to the LLM together with the question.
chain = load_qa_chain(llm, chain_type="stuff")

for question in questions:
    print("Q:" + question)
    # The 3 most similar chunks from the vector store serve as context.
    context_docs = db.similarity_search(question, k=3)
    answer = chain.run(input_documents=context_docs, question=question)
    print("A:" + answer)