# Question Answering with LangChain and a custom LLM accessible via a custom API

The goal of this notebook is to illustrate how to run a [LangChain](https://github.com/hwchase17/langchain) question answering chain that relies on a custom Language Model that is exposed via an API.

Out of the box, LangChain comes with built-in integrations for [multiple Large Language Models](https://python.langchain.com/en/latest/modules/models/llms/integrations.html). However, here we want to cover the case where our model is exposed via our own custom API, which does not match any of the pre-existing integrations.

In particular, this example notebook interacts with a model that is hosted and exposed via the API from the [text-generation-webui](https://github.com/oobabooga/text-generation-webui/) application.

References: 
- https://blog.langchain.dev/tutorial-chatgpt-over-your-data/ explains the overall workflow that this notebook follows. The example in the blog post uses OpenAI LLMs and a FAISS vector store; here we will use a locally hosted LM and ChromaDB instead.
- https://github.com/hwchase17/langchain/blob/master/docs/modules/models/llms/examples/custom_llm.ipynb

# Initialization

In [1]:
import json
import os
import requests
from dotenv import load_dotenv, find_dotenv
from langchain.chains import RetrievalQA
from langchain.document_loaders import DirectoryLoader, TextLoader
from langchain.embeddings import HuggingFaceEmbeddings
from langchain.llms.base import LLM
from langchain.prompts import PromptTemplate
from langchain.text_splitter import MarkdownTextSplitter
from langchain.vectorstores import Chroma
from typing import Optional, List, Mapping, Any

# Custom LLM

To allow LangChain to interact with our model via its custom API, we define a custom LLM wrapper around that API. Essentially, the class that we define must inherit from LangChain's `LLM` base class and provide a `_call` method that takes care of querying our model through its API:

In [2]:
class ApiLLM(LLM):
    
    api_url: str
    "The URL of the API endpoint."

    api_params: dict
    "A dictionary of parameters to pass to the API."

    @property
    def _llm_type(self) -> str:
        return "custom"
    
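    def _call(self, prompt: str, stop: Optional[List[str]] = None) -> str:
        """Call the LLM through the text-generation-webui REST API."""
        request = {"prompt": prompt, **self.api_params}
        if stop:
            # Forward LangChain's stop sequences through the API's `stopping_strings` parameter
            request["stopping_strings"] = list(request.get("stopping_strings", [])) + list(stop)
        response = requests.post(self.api_url, json=request)
        response.raise_for_status()  # surface HTTP errors instead of a confusing KeyError below
        return response.json()["results"][0]["text"]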
    
    @property
    def _identifying_params(self) -> Mapping[str, Any]:
        """Get the identifying parameters."""
        return {"api_url": self.api_url, "api_params": self.api_params}
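The `stop` argument that LangChain passes to `_call` is a list of strings at which generation should end. Whether or not the server honors them, a simple client-side fallback is to cut the returned text at the earliest stop sequence. This is a minimal sketch; `truncate_at_stop` is a hypothetical helper, not part of LangChain or text-generation-webui:

```python
from typing import List, Optional


def truncate_at_stop(text: str, stop: Optional[List[str]] = None) -> str:
    """Cut `text` at the earliest occurrence of any stop sequence.

    Hypothetical helper: a client-side fallback for servers that ignore
    stop sequences. The stop sequence itself is not kept.
    """
    if not stop:
        return text
    cut = len(text)
    for s in stop:
        if not s:
            continue  # an empty stop string would otherwise truncate everything
        idx = text.find(s)
        if idx != -1:
            cut = min(cut, idx)
    return text[:cut]


# Example: the model finishes its answer, then starts inventing a new question
raw = " The answer is 42.\nQuestion: next one?"
print(repr(truncate_at_stop(raw, ["\nQuestion:"])))
```

Inside `_call`, it would simply wrap the text extracted from the API response, e.g. `return truncate_at_stop(text, stop)`.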

## Instantiating our custom LLM

Here we create an instance of the custom language model by providing an endpoint URL and a set of default parameters to pass to the API along with each prompt.

Note that there is no way to pick a specific model in the API we are using here: the model will be whatever is being served by the API endpoint.

In [3]:
# API URL
# The endpoint for the API can be stored in a .env file (see credentials_example.env)
load_dotenv(find_dotenv("credentials.env"), override=True)
endpoint = os.environ.get("LLM_API_URL", None)

# Alternatively, set it manually by uncommenting the following line and replacing the URL
#endpoint = "https://text-generation-api-llm.example.com/api/v1/generate"

# Generation parameters
# Reference: https://huggingface.co/docs/transformers/main_classes/text_generation#transformers.GenerationConfig
params = {
    'max_new_tokens': 250,
    'do_sample': True,
    'temperature': 0.1,
    'top_p': 0.1,
    'typical_p': 1,
    'repetition_penalty': 1.1,
    'encoder_repetition_penalty': 1.0,
    'top_k': 30,
    'min_length': 0,
    'no_repeat_ngram_size': 0,
    'num_beams': 1,
    'penalty_alpha': 0,
    'length_penalty': 1,
    'early_stopping': False,
    'seed': -1,
    'add_bos_token': True,
    'truncation_length': 2048,
    'ban_eos_token': False,
    'skip_special_tokens': True,
    'stopping_strings': []
}

llm = ApiLLM(api_url=endpoint, api_params=params)

Let's do a sample test call to our custom LLM to verify that things work as expected:

In [4]:
print(llm("Once upon a time,"))

 of the new blog.
I’m not sure how to use this yet, but I will figure it out soon enough! This should be fun and interesting for me as well…and hopefully you too!!
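The test call above needs a running text-generation-webui instance. To exercise the same request/response contract without one (for example in CI), a local stub server can stand in for the endpoint. This is a minimal sketch, assuming that the only part of the contract that matters is the `{"results": [{"text": ...}]}` response shape that `ApiLLM._call` reads; `StubGenerateHandler`, `start_stub_server` and `generate` are hypothetical names, and the stub returns a canned reply instead of running a model:

```python
import json
import threading
import urllib.request
from http.server import BaseHTTPRequestHandler, HTTPServer


class StubGenerateHandler(BaseHTTPRequestHandler):
    """Hypothetical stand-in for the generate endpoint: it mimics only the
    response shape {"results": [{"text": ...}]} and never runs a model."""

    def do_POST(self):
        length = int(self.headers.get("Content-Length", 0))
        body = json.loads(self.rfile.read(length) or b"{}")
        prompt = body.get("prompt", "")
        data = json.dumps({"results": [{"text": f" [stub reply to {len(prompt)} chars]"}]}).encode("utf-8")
        self.send_response(200)
        self.send_header("Content-Type", "application/json")
        self.send_header("Content-Length", str(len(data)))
        self.end_headers()
        self.wfile.write(data)

    def log_message(self, format, *args):
        pass  # keep the notebook output quiet


def start_stub_server() -> HTTPServer:
    # Port 0 lets the OS pick a free port; the server runs in a background thread
    server = HTTPServer(("127.0.0.1", 0), StubGenerateHandler)
    threading.Thread(target=server.serve_forever, daemon=True).start()
    return server


def generate(url: str, prompt: str, params: dict) -> str:
    """Same request/response contract as ApiLLM._call, using only the standard library."""
    req = urllib.request.Request(
        url,
        data=json.dumps({"prompt": prompt, **params}).encode("utf-8"),
        headers={"Content-Type": "application/json"},
        method="POST",
    )
    # Bypass any proxy settings so the request reaches the local stub directly
    opener = urllib.request.build_opener(urllib.request.ProxyHandler({}))
    with opener.open(req, timeout=10) as resp:
        return json.loads(resp.read())["results"][0]["text"]


server = start_stub_server()
stub_url = f"http://127.0.0.1:{server.server_address[1]}/api/v1/generate"
print(generate(stub_url, "Once upon a time,", {"max_new_tokens": 10}))
server.shutdown()
server.server_close()
```

While the stub is running, `ApiLLM(api_url=stub_url, api_params=params)` should also work against it, since the wrapper only reads `results[0]["text"]` from the response.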


# Document indexing

This part of the notebook takes care of preparing an embeddings database for the documents that will become part of the context prompt provided to our model to help it answer the question.

First we load the data, which we have as a set of Markdown files, and split it into chunks that attempt to respect the logical structure determined by Markdown:

In [5]:
# Note: using TextLoader here instead of UnstructuredMarkdownLoader. MarkdownTextSplitter will do the job of parsing markdown
loader = DirectoryLoader('../data/external', glob="**/*.md", loader_cls=TextLoader)
documents = loader.load()  # FYI / note to self: with the dataset at the time of this writing, documents[42] is the FAQ

# Split the documents into chunks, respecting Markdown structure.
# The chunk size represents a trade-off between the LLM context size and the quantity of information that can be provided as context.
text_splitter = MarkdownTextSplitter(chunk_overlap=0, chunk_size=750)
texts = text_splitter.split_documents(documents)
print(f"{len(documents)} documents were loaded in {len(texts)} chunks")

75 documents were loaded in 7058 chunks
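The idea behind structure-aware chunking can be illustrated with a simplified sketch. This is not how `MarkdownTextSplitter` is implemented (it is a recursive character splitter with Markdown-specific separators); `split_markdown` is a hypothetical helper that splits at headings first and then packs whole sections into chunks of at most `chunk_size` characters, falling back to hard cuts only for a single oversized section:

```python
import re
from typing import List


def split_markdown(text: str, chunk_size: int = 750) -> List[str]:
    """Simplified sketch of heading-aware chunking (not LangChain's implementation)."""
    # Split before every line that starts with 1-6 '#' characters followed by a space
    sections = [s for s in re.split(r"(?m)^(?=#{1,6} )", text) if s.strip()]
    chunks: List[str] = []
    current = ""
    for section in sections:
        if len(current) + len(section) <= chunk_size:
            current += section  # keep packing whole sections into the current chunk
            continue
        if current:
            chunks.append(current)
        # A section larger than chunk_size is cut into fixed-size pieces
        while len(section) > chunk_size:
            chunks.append(section[:chunk_size])
            section = section[chunk_size:]
        current = section
    if current:
        chunks.append(current)
    return chunks


sample = "# Section A\nFirst paragraph.\n## Section A.1\nSecond paragraph.\n# Section B\nThird paragraph.\n"
for chunk in split_markdown(sample, chunk_size=64):
    print(repr(chunk))
```

Smaller chunks make each retrieved snippet more focused, while larger ones keep more surrounding context together; the limit is set by how much text fits in the model's prompt.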


Now we create (or load, if it already exists) a vector store from the documents:

In [8]:
embeddings = HuggingFaceEmbeddings()

# https://langchain.readthedocs.io/en/latest/modules/indexes/vectorstore_examples/chroma.html#persist-the-database
db_dir = "../data/interim"
docsearch = None
if os.path.isdir(os.path.join(db_dir, "index")):
    # Load the existing vector store
    docsearch = Chroma(persist_directory=db_dir, embedding_function=embeddings)
else:
    # Create a new vector store
    docsearch = Chroma.from_documents(texts, embeddings, persist_directory=db_dir)
    docsearch.persist()

Running Chroma using direct local API.
loaded in 7058 embeddings
loaded in 1 collections
collection with name langchain already exists, returning existing collection
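Conceptually, the retriever built from this vector store embeds the question and returns the `k` stored chunks whose embeddings are closest to it. The sketch below shows that nearest-neighbour step with cosine similarity over toy 3-dimensional vectors. Chroma uses its own index and a distance function configured per collection, so treat this as an illustration only; `cosine`, `top_k` and the chunk ids are hypothetical:

```python
import math
from typing import Dict, List, Sequence, Tuple


def cosine(a: Sequence[float], b: Sequence[float]) -> float:
    """Cosine similarity; returns 0.0 if either vector has zero length."""
    dot = sum(x * y for x, y in zip(a, b))
    norm = math.sqrt(sum(x * x for x in a)) * math.sqrt(sum(y * y for y in b))
    return dot / norm if norm else 0.0


def top_k(query_vec: Sequence[float], index: Dict[str, Sequence[float]], k: int = 3) -> List[Tuple[str, float]]:
    """Return the k (chunk id, score) pairs most similar to the query, best first."""
    scored = [(chunk_id, cosine(query_vec, vec)) for chunk_id, vec in index.items()]
    scored.sort(key=lambda pair: pair[1], reverse=True)
    return scored[:k]


# Toy "embeddings"; real sentence-embedding vectors have hundreds of dimensions
index = {
    "pricing": [0.9, 0.1, 0.0],
    "delete-cluster": [0.0, 0.2, 0.9],
    "autoscaling": [0.1, 0.9, 0.2],
}
print(top_k([1.0, 0.2, 0.0], index, k=2))
```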


# Question Answering

For this example we use a manually defined `RetrievalQA` chain with a custom prompt template.

References:
- https://python.langchain.com/en/latest/modules/chains/index_examples/vector_db_qa.html
- see other notebooks in this repo for different chain approaches
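The `"stuff"` chain type used below is the simplest way to combine retrieved documents: all `k` chunks are concatenated into the `{context}` slot of the prompt and sent to the LLM in a single call. A minimal sketch of that assembly step, assuming a shortened, hypothetical template and a blank-line separator between chunks (`stuff_prompt` is not a LangChain function):

```python
from typing import List

# Hypothetical, shortened template with the same two placeholders as the real one
TEMPLATE = "Context:\n{context}\nQuestion: {question}\nAnswer:"


def stuff_prompt(question: str, chunks: List[str], separator: str = "\n\n") -> str:
    """Join all retrieved chunks into one context string and fill in the template."""
    return TEMPLATE.format(context=separator.join(chunks), question=question)


print(stuff_prompt("Is there an upfront commitment?", ["<chunk 1 text>", "<chunk 2 text>"]))
```

Because everything goes into one prompt, the total size of the `k` chunks has to fit in the model's context window alongside the template and the generated answer.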

## Chain definition

In [9]:
template = """When learning about Red Hat OpenShift Service on AWS (ROSA), and considering the following context:
========= snippets start here =========
{context}
========= snippets end here =========
Given this question: {question}
The answer to the question is:"""
qa_prompt = PromptTemplate(template=template, input_variables=["question", "context"])
chain_type_kwargs = {"prompt": qa_prompt}

qa_chain = RetrievalQA.from_chain_type(
    llm=llm,
    chain_type="stuff",
    retriever=docsearch.as_retriever(search_kwargs={"k": 3}),  # k is the number of documents to retrieve
    chain_type_kwargs=chain_type_kwargs)
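The choice of `k=3` interacts with `chunk_size=750` (characters) and with the generation parameters `max_new_tokens=250` and `truncation_length=2048`. A back-of-the-envelope check, treating `truncation_length` as the model's context window and assuming roughly 4 characters per token (a common rule of thumb for English text, not a tokenizer) plus a guessed 400 characters for the template and question; `estimate_tokens` and `fits_context` are hypothetical helpers:

```python
import math


def estimate_tokens(n_chars: int, chars_per_token: float = 4.0) -> int:
    """Rough token count from a character count (heuristic, not a real tokenizer)."""
    return math.ceil(n_chars / chars_per_token)


def fits_context(k: int, chunk_size: int, overhead_chars: int,
                 max_new_tokens: int, context_tokens: int,
                 chars_per_token: float = 4.0) -> bool:
    """True if k full-size chunks, the prompt overhead and the generated
    tokens are estimated to fit in the context window."""
    prompt_tokens = estimate_tokens(k * chunk_size + overhead_chars, chars_per_token)
    return prompt_tokens + max_new_tokens <= context_tokens


# chunk_size=750, max_new_tokens=250 and truncation_length=2048 come from this notebook;
# overhead_chars=400 is an assumption
print(estimate_tokens(3 * 750 + 400))  # → 663
print(max(k for k in range(1, 16) if fits_context(k, 750, 400, 250, 2048)))  # → 9
```

Under these assumptions `k=3` leaves plenty of headroom; with a more pessimistic 2 characters per token the same check allows at most `k=4`. A real check should use the served model's own tokenizer.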

## Question answering

We define a list of questions to ask:

In [10]:
queries = ['Where can I see a roadmap or make feature requests for the service?',
           'How is the pricing of Red Hat OpenShift Service on AWS calculated?',
           'Is there an upfront commitment?',
           'How can I delete ROSA cluster?',
           'Can I shut down my VMs temporarily?', # https://docs.openshift.com/rosa/rosa_architecture/rosa_policy_service_definition/rosa-service-definition.html#rosa-sdpolicy-instance-types_rosa-service-definition
           'How can I automatically deploy ROSA cluster?',
           'How can my ROSA cluster autoscale?',
           'How can I install aws load balancer controller',
           'How can I install Prometheus Operator with my ROSA cluster?',
           'What time is it?', # adversarial example
           'How can I federate metrics to a centralized Prometheus Cluster?',
           'What is the meaning of life?'] # adversarial example

Now we iterate over the list of questions and call the QA chain for each, showing the resulting answer.

**NOTE**: the results below come from an API that is exposing the [MPT-7B Base](https://www.mosaicml.com/blog/mpt-7b) model.

In [11]:
answers = []
for query in queries:
    answers.append(qa_chain(query))

# Print the answers
for result in answers:
    print("="*80)
    print("Question:", result["query"])
    print("Answer: ", result["result"])

Question: Where can I see a roadmap or make feature requests for the service?
Answer:   https://access.redhat.com/documentation/en-US/OpenShift_Service_on_AWS/1.0/html/GettingStartedGuide/#gettingstartedguide-roadmapprojectfeaturerequests
Question: How is the pricing of Red Hat OpenShift Service on AWS calculated?
Answer:   _There's no simple way to calculate it because each customer will use different amounts of compute power over time_. For example, if your application requires more CPU than usual during peak times or when new features go live then that could increase your bill significantly.<br><br>
For most customers we recommend using our Reserved Instance program as described below so they can lock down their price at a discount compared with On Demand prices. This allows them to budget accurately based upon what they expect to spend rather than having unpredictable bills from month to month due to usage spikes. <br><br>

*Reserved instance discounts apply only to clusters create

# Conclusions

This notebook presents a simple way to wrap a custom Large Language Model (LLM) exposed via an API so that it can be integrated into LangChain. The integration is illustrated with a sample set of questions about Red Hat OpenShift Service on AWS (ROSA).

The integration provided here is very simple and uses a REST API. This works well enough for the use case of "batch" (or essentially non-interactive) question answering.

For more interactive use cases (interactive question answering or conversations), the simple wrapper class could be adapted to use the streaming WebSocket API that is also available. This should provide a better user experience in this type of use case. Some sample code to access the WebSocket API in `text-generation-webui` is available in [their repository](https://github.com/oobabooga/text-generation-webui/blob/main/api-example-stream.py).
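As a starting point for such an adaptation, the sketch below covers only the message-handling part of a streaming client. It assumes the message format used by the repository's streaming example at the time of writing (a `text_stream` event carrying a `text` fragment, followed by a `stream_end` event); verify this against the version you run. The WebSocket connection itself is replaced by a list of simulated messages, and `iter_stream_text` is a hypothetical helper:

```python
import json
from typing import Iterable, Iterator


def iter_stream_text(messages: Iterable[str]) -> Iterator[str]:
    """Yield text fragments from JSON messages until the stream ends.

    Assumed format: {"event": "text_stream", "text": "..."} per fragment,
    then {"event": "stream_end"}. Other events are ignored.
    """
    for raw in messages:
        message = json.loads(raw)
        event = message.get("event")
        if event == "text_stream":
            yield message["text"]
        elif event == "stream_end":
            return


# Simulated messages in place of a live WebSocket connection
fake_messages = [
    json.dumps({"event": "text_stream", "text": "Once upon "}),
    json.dumps({"event": "text_stream", "text": "a time,"}),
    json.dumps({"event": "stream_end"}),
    json.dumps({"event": "text_stream", "text": " ignored"}),
]
print("".join(iter_stream_text(fake_messages)))
```

In an interactive application, each fragment would be displayed as soon as it arrives instead of being joined at the end.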

There is room for improvement in the question answering process: it should be possible to improve the results by working on better prompts, trying different models, and improving document retrieval (the quantity and quality of the indexed data, as well as how documents are chunked and retrieved). However, this is beyond the scope of this notebook, whose main purpose is to illustrate the integration of a custom API-based LLM into LangChain.