diff --git a/audio-to-text/Containerfile b/audio-to-text/Containerfile
index 8719e66..28d2b58 100644
--- a/audio-to-text/Containerfile
+++ b/audio-to-text/Containerfile
@@ -1,4 +1,4 @@
-FROM registry.access.redhat.com/ubi9/python-311:1-66.1720018730
+FROM registry.access.redhat.com/ubi9/python-311:1-72.1722518949
 WORKDIR /locallm
 COPY requirements.txt /locallm/requirements.txt
 RUN pip install --upgrade pip && \
diff --git a/chatbot/Containerfile b/chatbot/Containerfile
index 585f4e4..29cff1a 100644
--- a/chatbot/Containerfile
+++ b/chatbot/Containerfile
@@ -1,4 +1,4 @@
-FROM registry.access.redhat.com/ubi9/python-311:1-66.1720018730
+FROM registry.access.redhat.com/ubi9/python-311:1-72.1722518949
 WORKDIR /chat
 COPY requirements.txt .
 RUN pip install --upgrade pip
diff --git a/codegen/Containerfile b/codegen/Containerfile
index 6a3258d..b9de4a3 100644
--- a/codegen/Containerfile
+++ b/codegen/Containerfile
@@ -1,4 +1,4 @@
-FROM registry.access.redhat.com/ubi9/python-311:1-66.1720018730
+FROM registry.access.redhat.com/ubi9/python-311:1-72.1722518949
 WORKDIR /codegen
 COPY requirements.txt .
 RUN pip install --upgrade pip
diff --git a/object-detection/Containerfile b/object-detection/Containerfile
index b98d502..a5f6ae3 100644
--- a/object-detection/Containerfile
+++ b/object-detection/Containerfile
@@ -1,4 +1,4 @@
-FROM registry.access.redhat.com/ubi9/python-311:1-66.1720018730
+FROM registry.access.redhat.com/ubi9/python-311:1-72.1722518949
 WORKDIR /locallm
 COPY requirements.txt /locallm/requirements.txt
 RUN pip install --upgrade pip && \
diff --git a/pull-sample-app.sh b/pull-sample-app.sh
index 11c9b69..0a54a14 100755
--- a/pull-sample-app.sh
+++ b/pull-sample-app.sh
@@ -4,6 +4,7 @@ CHATBOT_DIR=$ROOT_DIR/chatbot/
 CODEGEN_DIR=$ROOT_DIR/codegen/
 AUDIO_TO_TEXT_DIR=$ROOT_DIR/audio-to-text/
 OBJECTION_DETECTION_DIR=$ROOT_DIR/object-detection/
+RAG_DIR=$ROOT_DIR/rag/
 
 REPO="https://github.com/containers/ai-lab-recipes"
 
@@ -19,5 +20,6 @@ cp -r $TEMPDIR/$REPONAME/recipes/natural_language_processing/chatbot/app/ $CHATB
 cp -r $TEMPDIR/$REPONAME/recipes/natural_language_processing/codegen/app/ $CODEGEN_DIR
 cp -r $TEMPDIR/$REPONAME/recipes/audio/audio_to_text/app/ $AUDIO_TO_TEXT_DIR
 cp -r $TEMPDIR/$REPONAME/recipes/computer_vision/object_detection/app/ $OBJECTION_DETECTION_DIR
+cp -r $TEMPDIR/$REPONAME/recipes/natural_language_processing/rag/app/ $RAG_DIR
 
 rm -rf $TEMPDIR # clean up
\ No newline at end of file
diff --git a/rag/Containerfile b/rag/Containerfile
new file mode 100644
index 0000000..1aa72f5
--- /dev/null
+++ b/rag/Containerfile
@@ -0,0 +1,25 @@
+FROM registry.access.redhat.com/ubi9/python-311:1-72.1722518949
+### Update sqlite for chroma
+USER root
+RUN dnf remove sqlite3 -y
+RUN wget https://www.sqlite.org/2023/sqlite-autoconf-3410200.tar.gz
+RUN tar -xvzf sqlite-autoconf-3410200.tar.gz
+WORKDIR sqlite-autoconf-3410200
+RUN ./configure
+RUN make
+RUN make install
+RUN mv /usr/local/bin/sqlite3 /usr/bin/sqlite3
+ENV LD_LIBRARY_PATH="/usr/local/lib"
+####
+WORKDIR /rag
+COPY requirements.txt .
+RUN pip install --upgrade pip
+RUN pip install --no-cache-dir --upgrade -r /rag/requirements.txt
+COPY rag_app.py .
+COPY manage_vectordb.py .
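+# Streamlit serves on port 8501 by default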
+EXPOSE 8501
+ENV HF_HUB_CACHE=/rag/models/
+RUN mkdir -p /rag/models/
+RUN chgrp -R 0 /rag/models/ && chmod -R g=u /rag/models/
+ENTRYPOINT [ "streamlit", "run", "rag_app.py" ]
diff --git a/rag/manage_vectordb.py b/rag/manage_vectordb.py
new file mode 100644
index 0000000..82566ab
--- /dev/null
+++ b/rag/manage_vectordb.py
@@ -0,0 +1,81 @@
+from langchain_community.vectorstores import Chroma
+from chromadb import HttpClient
+from chromadb.config import Settings
+import chromadb.utils.embedding_functions as embedding_functions
+from langchain.embeddings.sentence_transformer import SentenceTransformerEmbeddings
+from langchain_community.vectorstores import Milvus
+from pymilvus import MilvusClient
+from pymilvus import connections, utility
+
+class VectorDB:
+    def __init__(self, vector_vendor, host, port, collection_name, embedding_model):
+        self.vector_vendor = vector_vendor
+        self.host = host
+        self.port = port
+        self.collection_name = collection_name
+        self.embedding_model = embedding_model
+
+    def connect(self):
+        # Connection logic
+        print(f"Connecting to {self.host}:{self.port}...")
+        if self.vector_vendor == "chromadb":
+            self.client = HttpClient(host=self.host,
+                                     port=self.port,
+                                     settings=Settings(allow_reset=True,))
+        elif self.vector_vendor == "milvus":
+            self.client = MilvusClient(uri=f"http://{self.host}:{self.port}")
+        return self.client
+
+    def populate_db(self, documents):
+        # Logic to populate the VectorDB with vectors
+        e = SentenceTransformerEmbeddings(model_name=self.embedding_model)
+        print("Populating VectorDB with vectors...")
+        if self.vector_vendor == "chromadb":
+            embedding_func = embedding_functions.SentenceTransformerEmbeddingFunction(model_name=self.embedding_model)
+            collection = self.client.get_or_create_collection(self.collection_name,
+                                                              embedding_function=embedding_func)
+            if collection.count() < 1:
+                db = Chroma.from_documents(
+                    documents=documents,
+                    embedding=e,
+                    collection_name=self.collection_name,
+                    client=self.client
+                )
+                print("DB populated")
+            else:
+                db = Chroma(client=self.client,
+                            collection_name=self.collection_name,
+                            embedding_function=e,
+                            )
+                print("DB already populated")
+
+        elif self.vector_vendor == "milvus":
+            connections.connect(host=self.host, port=self.port)
+            if not utility.has_collection(self.collection_name):
+                print("Populating VectorDB with vectors...")
+                db = Milvus.from_documents(
+                    documents,
+                    e,
+                    collection_name=self.collection_name,
+                    connection_args={"host": self.host, "port": self.port},
+                )
+                print("DB populated")
+            else:
+                print("DB already populated")
+                db = Milvus(
+                    e,
+                    collection_name=self.collection_name,
+                    connection_args={"host": self.host, "port": self.port},
+                )
+        return db
+
+    def clear_db(self):
+        print("Clearing VectorDB...")
+        try:
+            if self.vector_vendor == "chromadb":
+                self.client.delete_collection(self.collection_name)
+            elif self.vector_vendor == "milvus":
+                self.client.drop_collection(self.collection_name)
+            print("Cleared DB")
+        except Exception:
+            print("Couldn't clear the collection, possibly because it doesn't exist")
diff --git a/rag/rag_app.py b/rag/rag_app.py
new file mode 100644
index 0000000..fde09b1
--- /dev/null
+++ b/rag/rag_app.py
@@ -0,0 +1,102 @@
+from langchain_openai import ChatOpenAI
+from langchain_core.prompts import ChatPromptTemplate
+from langchain_core.runnables import RunnablePassthrough
+from langchain.text_splitter import CharacterTextSplitter
+from langchain_community.callbacks import StreamlitCallbackHandler
+from langchain_community.document_loaders import TextLoader
+from langchain_community.document_loaders import PyPDFLoader
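+# VectorDB (manage_vectordb.py) hides the difference between the chromadb and Milvus clients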
+from manage_vectordb import VectorDB
+import tempfile
+import streamlit as st
+import os
+
+model_service = os.getenv("MODEL_ENDPOINT","http://0.0.0.0:8001")
+model_service = f"{model_service}/v1"
+chunk_size = os.getenv("CHUNK_SIZE", 150)
+embedding_model = os.getenv("EMBEDDING_MODEL","BAAI/bge-base-en-v1.5")
+vdb_vendor = os.getenv("VECTORDB_VENDOR", "chromadb")
+vdb_host = os.getenv("VECTORDB_HOST", "0.0.0.0")
+vdb_port = os.getenv("VECTORDB_PORT", "8000")
+vdb_name = os.getenv("VECTORDB_NAME", "test_collection")
+
+vdb = VectorDB(vdb_vendor, vdb_host, vdb_port, vdb_name, embedding_model)
+vectorDB_client = vdb.connect()
+def split_docs(raw_documents):
+    text_splitter = CharacterTextSplitter(separator=".",
+                                          chunk_size=int(chunk_size),
+                                          chunk_overlap=0)
+    docs = text_splitter.split_documents(raw_documents)
+    return docs
+
+
+def read_file(file):
+    file_type = file.type
+    if file_type == "application/pdf":
+        temp = tempfile.NamedTemporaryFile()
+        with open(temp.name, "wb") as f:
+            f.write(file.getvalue())
+        loader = PyPDFLoader(temp.name)
+
+    if file_type == "text/plain":
+        temp = tempfile.NamedTemporaryFile()
+        with open(temp.name, "wb") as f:
+            f.write(file.getvalue())
+        loader = TextLoader(temp.name)
+    raw_documents = loader.load()
+    return raw_documents
+
+st.title("📚 RAG DEMO")
+with st.sidebar:
+    file = st.file_uploader(label="📄 Upload Document",
+                            type=[".txt",".pdf"],
+                            on_change=vdb.clear_db
+                            )
+
+### populate the DB ####
+if file is not None:
+    text = read_file(file)
+    documents = split_docs(text)
+    db = vdb.populate_db(documents)
+    retriever = db.as_retriever(threshold=0.75)
+else:
+    retriever = {}
+    print("Empty VectorDB")
+
+
+########################
+
+if "messages" not in st.session_state:
+    st.session_state["messages"] = [{"role": "assistant",
+                                     "content": "How can I help you?"}]
+
+for msg in st.session_state.messages:
+    st.chat_message(msg["role"]).write(msg["content"])
+
+
+llm = ChatOpenAI(base_url=model_service,
+                 api_key="EMPTY",
+                 streaming=True,
+                 callbacks=[StreamlitCallbackHandler(st.container(),
+                                                     collapse_completed_thoughts=True)])
+
+prompt = ChatPromptTemplate.from_template("""Answer the question based only on the following context:
+{context}
+
+Question: {input}
+"""
+)
+
+chain = (
+    {"context": retriever, "input": RunnablePassthrough()}
+    | prompt
+    | llm
+)
+
+if prompt := st.chat_input():
+    st.session_state.messages.append({"role": "user", "content": prompt})
+    st.chat_message("user").markdown(prompt)
+    response = chain.invoke(prompt)
+    st.chat_message("assistant").markdown(response.content)
+    st.session_state.messages.append({"role": "assistant", "content": response.content})
+    st.rerun()
diff --git a/rag/requirements.txt b/rag/requirements.txt
new file mode 100644
index 0000000..572acc0
--- /dev/null
+++ b/rag/requirements.txt
@@ -0,0 +1,9 @@
+langchain-openai==0.1.7
+langchain==0.1.20
+chromadb==0.5.5
+sentence-transformers==2.7.0
+streamlit==1.34.0
+pypdf==4.2.0
+pymilvus==2.4.1
+
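+# Note: chromadb requires sqlite3 >= 3.35, which is why rag/Containerfile builds a newer sqlite from source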