## Example: querying a set of documents, with sources, using Caikit-TGIS, Llama 2, LangChain and a custom prompt

In [1]:
# Dependencies for this example. `%pip` (rather than `!pip`) installs into the environment
# of the running kernel, so the packages are importable without a kernel restart mismatch.
%pip install -q evaluate
%pip install -q rouge_score

# chromadb is installed in the next cell (after caikit), so it is not repeated here.

%pip install -q git+https://github.com/caikit/caikit@v0.11.3
# TODO: pin caikit-nlp to a tag/commit known to work with caikit v0.11.3 for reproducible installs.
%pip install -q git+https://github.com/caikit/caikit-nlp

# Clone only when the folder is not already present, so the cell can be re-run safely.
![ -d caikit-nlp ] || git clone https://github.com/caikit/caikit-nlp

Collecting evaluate
  Obtaining dependency information for evaluate from https://files.pythonhosted.org/packages/70/63/7644a1eb7b0297e585a6adec98ed9e575309bb973c33b394dae66bc35c69/evaluate-0.4.1-py3-none-any.whl.metadata
  Downloading evaluate-0.4.1-py3-none-any.whl.metadata (9.4 kB)
Collecting datasets>=2.0.0 (from evaluate)
  Obtaining dependency information for datasets>=2.0.0 from https://files.pythonhosted.org/packages/89/a9/8e097f79d2941a2f96e33f57032957429a79f66c8252ac7fcce586a43406/datasets-2.19.0-py3-none-any.whl.metadata
  Downloading datasets-2.19.0-py3-none-any.whl.metadata (19 kB)
Collecting xxhash (from evaluate)
  Obtaining dependency information for xxhash from https://files.pythonhosted.org/packages/eb/3a/25c4aecb61a49d4415fd71d4f66a8a5b558dd44a52d7054ea9aa59ccbac1/xxhash-3.4.1-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata
  Downloading xxhash-3.4.1-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (12 kB)
Collecting multiprocess

In [2]:
# chromadb pins some of its dependencies (the install log shows pydantic<2.0,>=1.9 and
# fastapi==0.85.1); installing it after caikit presumably restores versions that the
# previous installs changed — TODO confirm before removing this step.
%pip install chromadb
%pip install -q caikit-nlp-client

Collecting pydantic<2.0,>=1.9 (from chromadb)
  Obtaining dependency information for pydantic<2.0,>=1.9 from https://files.pythonhosted.org/packages/28/67/359f77395bedd707fbbd29dba5725d412b436250dfb14c2261adbd59d256/pydantic-1.10.15-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata
  Downloading pydantic-1.10.15-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (150 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m150.6/150.6 kB[0m [31m6.0 MB/s[0m eta [36m0:00:00[0m
Collecting fastapi==0.85.1 (from chromadb)
  Obtaining dependency information for fastapi==0.85.1 from https://files.pythonhosted.org/packages/bf/54/6eb1882b5cb29e6647df92ee74d0a93dab149234ec914563cab955fa667f/fastapi-0.85.1-py3-none-any.whl.metadata
  Downloading fastapi-0.85.1-py3-none-any.whl.metadata (24 kB)
Collecting starlette==0.20.4 (from fastapi==0.85.1->chromadb)
  Obtaining dependency information for starlette==0.20.4 from https://files.pythonhosted.org/p

### Set the inference server URL (replace it with your own address) and the model ID

In [3]:
import os

# Inference endpoint and model to query. The defaults below can be overridden without editing
# the notebook through the CAIKIT_INFERENCE_SERVER_URL / CAIKIT_MODEL_ID environment variables.
# NOTE(review): the default host/port (ollama..., 11434) looks like an Ollama server, while the
# client created later is a Caikit-TGIS client — confirm this endpoint serves the Caikit API.
inference_server_url = os.environ.get(
    "CAIKIT_INFERENCE_SERVER_URL", "http://ollama.myredis.svc.cluster.local:11434"
)
model_id = os.environ.get("CAIKIT_MODEL_ID", "Llama-2-7b-chat-hf")

# If your endpoint is using a self-signed certificate, export the certificate chain as a .pem file and provide its path
# Example: certificate_chain_file = "certificate.pem" (or set CAIKIT_CERTIFICATE_CHAIN)
# Adjust llm instantiation to use this parameter or not
certificate_chain_file = os.environ.get("CAIKIT_CERTIFICATE_CHAIN", "certificate.pem")

In [4]:
# Optional, requirements if they are not already present
# !pip -q install grpcio grpcio-reflection

### Load some data from the folder where we have stored the PDF documentation

In [5]:
from langchain.document_loaders import PyPDFDirectoryLoader

# Folder holding the PDF documentation, relative to the notebook's working directory.
pdf_folder_path = 'rhods-doc'

loader = PyPDFDirectoryLoader(pdf_folder_path)
docs = loader.load()

# Fail fast with a clear message: if the folder is missing or holds no readable PDFs, `docs`
# may come back empty and the problem would only surface later, when building the vector store.
assert docs, f"No PDF content loaded from {pdf_folder_path!r}; check that the folder exists and contains PDFs"

### Split the data into chunks large enough to give meaningful answers, with some overlap so that nothing is missed

In [6]:
from langchain.text_splitter import RecursiveCharacterTextSplitter

# Chunk the loaded pages: chunks of up to 1024 units (measured by the splitter's length
# function), with 40 units of overlap so text at a chunk boundary is not lost.
text_splitter = RecursiveCharacterTextSplitter(
    chunk_size=1024,
    chunk_overlap=40,
)
all_splits = text_splitter.split_documents(docs)

### Store the data as embeddings in a vector database (Chroma)

In [7]:
from langchain.embeddings.huggingface import HuggingFaceEmbeddings
from langchain.vectorstores import Chroma

# Name the embedding model explicitly rather than relying on the library default, so the index
# stays reproducible if that default changes in a later langchain release.
# "sentence-transformers/all-mpnet-base-v2" is HuggingFaceEmbeddings' documented default model.
EMBEDDING_MODEL_NAME = "sentence-transformers/all-mpnet-base-v2"

vectorstore = Chroma.from_documents(
    documents=all_splits,
    embedding=HuggingFaceEmbeddings(model_name=EMBEDDING_MODEL_NAME),
)

modules.json:   0%|          | 0.00/349 [00:00<?, ?B/s]

config_sentence_transformers.json:   0%|          | 0.00/116 [00:00<?, ?B/s]

README.md:   0%|          | 0.00/10.6k [00:00<?, ?B/s]

sentence_bert_config.json:   0%|          | 0.00/53.0 [00:00<?, ?B/s]

config.json:   0%|          | 0.00/571 [00:00<?, ?B/s]

model.safetensors:   0%|          | 0.00/438M [00:00<?, ?B/s]

tokenizer_config.json:   0%|          | 0.00/363 [00:00<?, ?B/s]

vocab.txt:   0%|          | 0.00/232k [00:00<?, ?B/s]

tokenizer.json:   0%|          | 0.00/466k [00:00<?, ?B/s]

special_tokens_map.json:   0%|          | 0.00/239 [00:00<?, ?B/s]

1_Pooling/config.json:   0%|          | 0.00/190 [00:00<?, ?B/s]

Unable to connect optimized C data functions [No module named '_testbuffer'], falling back to pure Python


### Test data retrieval

In [8]:
# Check the retriever on its own before wiring it into the QA chain.
# The hits get their own name so they do not overwrite `docs` (the full set of loaded PDF
# pages read by the splitting step); reusing `docs` broke re-running that step.
question = "How do I create a Data Science Project"
retrieved_docs = vectorstore.similarity_search(question)
retrieved_docs

[Document(page_content='CHAPTER 3. WORKING ON DATA SCIENCE PROJECTS\n17', metadata={'source': 'rhods-doc/red_hat_openshift_ai_self-managed-2-latest-working_on_data_science_projects-en-us.pdf', 'page': 20}),
 Document(page_content='Procedure\n1\n. \nFrom the OpenShift AI dashboard, click \nData Science Projects\n.\nThe \nData science projects\n page opens.\n2\n. \nClick \nCreate data science project\n.\nThe \nCreate a data science project\n dialog opens.\n3\n. \nEnter a \nname\n for your data science project.\n4\n. \nOptional: Edit the \nresource name\n for your data science project. The resource name must\nconsist of lowercase alphanumeric characters, \n-\n, and must start and end with an alphanumeric\ncharacter.\n5\n. \nEnter a \ndescription\n for your data science project.\n6\n. \nClick \nCreate\n.\nA project details page opens. From this page, you can create workbenches, add cluster storage\nand data connections, import pipelines, and deploy models.\nVerification\nThe project that y

### Create the chain

In [9]:
from langchain.callbacks.streaming_stdout import StreamingStdOutCallbackHandler
from langchain.chains import RetrievalQA
from langchain.prompts import PromptTemplate

import caikit_tgis_langchain

# NOTE: This template syntax is specific to Llama2.
# The system prompt sits between <<SYS>> and <</SYS>>; {question} and {context} are the
# prompt variables filled in by the QA chain.
template = """<s>[INST] <<SYS>>
You are a helpful, respectful and honest assistant.
You will be given a question you need to answer, and a context to provide you with information. You must answer the question based as much as possible on this context.
Always answer as helpfully as possible, while being safe. Your answers should not include any harmful, unethical, racist, sexist, toxic, dangerous, or illegal content. Please ensure that your responses are socially unbiased and positive in nature.

If a question does not make any sense, or is not factually coherent, explain why instead of answering something not correct. If you don't know the answer to a question, please don't share false information.
<</SYS>>

Question: {question}
Context: {context} [/INST]
"""
QA_CHAIN_PROMPT = PromptTemplate.from_template(template)

# Client for the Caikit-TGIS endpoint configured at the top of the notebook, with streaming on.
# NOTE(review): the recorded run failed here with
#   ValueError: "CaikitLLM" object has no field "inference_server"
# No argument passed below is named `inference_server`, so the error appears to come from inside
# caikit_tgis_langchain.CaikitLLM itself — check that module's version/field declarations (TODO).
llm = caikit_tgis_langchain.CaikitLLM(
    model_id=model_id,
    inference_server_url=inference_server_url,
    certificate_chain=certificate_chain_file,
    streaming=True,
)

# Retrieval-augmented QA over the Chroma store, using the custom prompt and also returning the
# retrieved documents so their sources can be listed afterwards.
qa_chain = RetrievalQA.from_chain_type(
    llm,
    retriever=vectorstore.as_retriever(),
    chain_type_kwargs={"prompt": QA_CHAIN_PROMPT},
    return_source_documents=True,
)

ValueError: "CaikitLLM" object has no field "inference_server"

### Launch the query

In [None]:
question = "How do I create a Data Science Project?"
# RetrievalQA reads only the "query" key of its input. Generation settings such as
# min_new_tokens / max_new_tokens placed in this dict were never forwarded to the LLM,
# so they are not passed here.
# NOTE(review): to control answer length, configure it on the CaikitLLM instance when it is
# created (if it exposes such settings) — TODO confirm which fields it supports.
result = qa_chain({"query": question}, callbacks=[StreamingStdOutCallbackHandler()])

### Print the sources

In [None]:
def remove_duplicates(input_list):
    """Return the distinct source paths of the given documents, in first-seen order.

    Despite its name, this does not return de-duplicated documents: each item is mapped to
    ``item.metadata['source']`` and repeated values are dropped.

    Args:
        input_list: iterable of documents, each with a ``metadata`` dict containing a
            ``'source'`` key (e.g. the ``source_documents`` returned by the QA chain).

    Returns:
        list: the unique ``source`` values, in order of first appearance.

    Raises:
        KeyError: if an item's metadata has no ``'source'`` key.
    """
    # dict keeps insertion order and gives O(1) membership checks, replacing the
    # O(n^2) "not in list" scan while producing the same ordered result.
    return list(dict.fromkeys(item.metadata['source'] for item in input_list))

# Reduce the chain's retrieved documents to their distinct source files and list them, one per line.
results = remove_duplicates(result["source_documents"])
for source_path in results:
    print(source_path)