![image](https://raw.githubusercontent.com/IBM/watson-machine-learning-samples/master/cloud/notebooks/headers/watsonx-Prompt_Lab-Notebook.png)
# Use Watsonx to respond to natural language questions using RAG approach

**Note:** Please note that for the watsonx challenge, please consider running these notebooks locally on your laptop/desktop.

This notebook contains the steps and code to demonstrate support of Retrieval Augumented Generation in watsonx.ai. It introduces commands for data retrieval, knowledge base building & querying, and model testing.

This notebook uses Python 3.10.

#### Objective

Use ibm/granite-13b-chat-v2, Langchain and Milvus to create a Retrieval Augmented Generation (RAG) system. This will allow us to ask questions about our documents (that were not included in the training data), without fine-tunning the Large Language Model (LLM). When using RAG, if you are given a question, you first do a retrieval step to fetch any relevant documents from a special database, a vector database where these documents were indexed.

Retrieval Augmented Generation (RAG) is a versatile pattern that can unlock a number of use cases requiring factual recall of information, such as querying a knowledge base in natural language.

##### Definitions:

* LLM - Large Language Model
* granite-13b-chat-v2 - LLM from IBM
* Langchain - a framework designed to simplify the creation of applications using LLMs
* Vector database - a database that organizes data through high-dimmensional vectors
* Milvus - vector database
* RAG - Retrieval Augmented Generation (see below more details about RAGs)

#### What is a Retrieval Augmented Generation (RAG) system?
Large Language Models (LLMs) has proven their ability to understand context and provide accurate answers to various NLP tasks, including summarization, Q&A, when prompted. While being able to provide very good answers to questions about information that they were trained with, they tend to hallucinate when the topic is about information that they do "not know", i.e. was not included in their training data. Retrieval Augmented Generation combines external resources with LLMs. The main two components of a RAG are therefore a retriever and a generator.

The retriever part can be described as a system that is able to encode our data so that can be easily retrieved the relevant parts of it upon queriying it. The encoding is done using text embeddings, i.e. a model trained to create a vector representation of the information. The best option for implementing a retriever is a vector database. As vector database, there are multiple options, both open source or commercial products. Few examples are ChromaDB, Mevius, FAISS, Pinecone, Weaviate. Our option in this Notebook will be a local instance of ChromaDB (persistent).

For the generator part, the obvious option is a LLM. In this Notebook we will use a quantized LLaMA v2 model, from the Kaggle Models collection.

The orchestration of the retriever and generator will be done using Langchain. A specialized function from Langchain allows us to create the receiver-generator in one line of code.

In its simplest form, RAG requires 3 steps:

- Index knowledge base passages (once)
- Retrieve relevant passage(s) from knowledge base (for every user query)
- Generate a response by feeding retrieved passage into a large language model (for every user query)


<a id="setup"></a>
##  Set up the environment



### Install and import dependecies

**Note:** For Windows environments, please remove `| tail -n 1` commands in the cell below.

In [1]:
#!pip install sentence_transformers | tail -n 1
#!pip install pandas | tail -n 1
#!pip install rouge_score | tail -n 1
#!pip install nltk | tail -n 1
#!pip install "ibm-watson-machine-learning>=1.0.312" | tail -n 1
#!pip install PyPDF2 | tail -n 1
#!pip install langchain | tail -n 1
#!pip install --upgrade pip | tail -n 1
#!pip install ibm-generative-ai | tail -n 1
#!pip install ipywidgets | tail -n 1
#!pip install ipywidgets widgetsnbextension pandas-profiling | tail -n 1
#!jupyter nbextension enable --py widgetsnbextension
#!python3 -m pip install tensorflow-macos | tail -n 1
#!python -m pip install tensorflow-metal | tail -n 1

In [2]:
# install milvus with [client] extras by pip
#!python3 -m pip install "milvus[client]" | tail -n 1
#!pip install -U --pre torch torchvision torchaudio --extra-index-url https://download.pytorch.org/whl/nightly/cpu

**Note:** Please restart the notebook kernel to pick up proper version of packages installed above.

In [3]:
import os
import pandas as pd
import string
import numpy as np
import re
import os
import zipfile
from pathlib import Path
from tqdm.notebook import tqdm
import pickle
import requests
import warnings
warnings.filterwarnings("ignore")
np.random.seed(0)

from sentence_transformers import SentenceTransformer
from milvus import default_server
from pymilvus import (
    Collection,
    CollectionSchema,
    DataType,
    FieldSchema,
    connections,
    utility,
)
from sklearn.model_selection import train_test_split
from langchain.document_loaders import DataFrameLoader
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain.vectorstores import Milvus

from genai import Model
from genai.model import Credentials
from genai.schemas import GenerateParams
from dotenv import load_dotenv

### Watsonx API connection
This cell defines the credentials required to work with watsonx API for Foundation
Model inferencing.

**Action:** Provide the IBM Cloud user API key. For details, see
[documentation](https://cloud.ibm.com/docs/account?topic=account-userapikey&interface=ui).

In [4]:
load_dotenv()
BAM_API_URL = os.getenv("BAM_API_URL")
BAM_API_KEY = os.getenv("BAM_API_KEY")

In [5]:
credentials = Credentials(api_key=BAM_API_KEY, api_endpoint=BAM_API_URL)

### Defining the project id
The API requires project id that provides the context for the call. We will obtain the id from the project in which this notebook runs. Otherwise, please provide the project id.

**Hint**: You can find the `project_id` as follows. Open the prompt lab in watsonx.ai. At the very top of the UI, there will be `Projects / <project name> /`. Click on the `<project name>` link. Then get the `project_id` from Project's Manage tab (Project -> Manage -> General -> Details).


In [6]:
try:
    project_id = os.environ["PROJECT_ID"]
except KeyError:
    project_id = 'b9dc502c-91f0-4c12-ab48-684495500d51'

<a id="data"></a>
## Train and Test data loading

Load train and test datasets. At first, training dataset (`train_data`) should be used to work with the models to prepare and tune prompt. Then, test dataset (`test_data`) should be used to calculate the metrics score for selected model, defined prompts and parameters.

#### mss_alerts_1011.json 

In [7]:
df_08 = pd.read_table('S08_question_answer_pairs.txt')
df_09 = pd.read_table('S09_question_answer_pairs.txt')
df_10 = pd.read_table('S10_question_answer_pairs.txt', encoding='Windows-1252')

In [8]:
df = pd.concat([df_08, df_09, df_10])

In [9]:

df.drop(['DifficultyFromQuestioner', 'DifficultyFromAnswerer', 'ArticleTitle'], axis = 1, inplace=True)

df.dropna(inplace=True)
print('-' * 15, df.isna().sum(), sep='\n')

---------------
Question       0
Answer         0
ArticleFile    0
dtype: int64


In [10]:
# Limpar coluna "Answer"
def strip_last_punctuation(s):
  if s and s[-1] in string.punctuation:
    return s[:-1].strip()
  else:
    return s.strip()

df['answer_clean'] = df['Answer'].str.lower().map(strip_last_punctuation)

# Remove os dados faltantes da base de treino
df.dropna(inplace=True)
print('-' * 15, df.isna().sum(), sep='\n')

---------------
Question        0
Answer          0
ArticleFile     0
answer_clean    0
dtype: int64


In [11]:
train_data, test_data = train_test_split(df, test_size=0.20)
test_data = test_data.reset_index(drop=True)

## Build up knowledge base

The current state-of-the-art in RAG is to create dense vector representations of the knowledge base in order to calculate the semantic similarity to a given user query.

We can generate dense vector representations using embedding models. In this notebook, we use [SentenceTransformers](https://www.google.com/search?client=safari&rls=en&q=sentencetransformers&ie=UTF-8&oe=UTF-8) [all-MiniLM-L6-v2](https://huggingface.co/sentence-transformers/all-MiniLM-L6-v2) to embed both the knowledge base passages and user queries. `all-MiniLM-L6-v2` is a performant open-source model that is small enough to run locally.

A vector database is optimized for dense vector indexing and retrieval. This notebook uses [Milvus](https://milvus.io), a user-friendly open-source vector database, licensed under Apache 2.0, which offers good speed and performance with all-MiniLM-L6-v2 embedding model.

The size of each passage is limited by the embedding model's context window (which is 256 tokens for `all-MiniLM-L6-v2`).

### Create an embedding function

Note that you can feed a custom embedding function to be used by Milvischromadb. The performance of Milvus may differ depending on the embedding model used.

In [12]:
# Select a Sentence Transformer: https://www.sbert.net/docs/pretrained_models.html
embedding_model = SentenceTransformer("all-MiniLM-L6-v2")

### Start Milvus

Start the Milvus embedded server

In [13]:
if default_server.running != True:
    default_server.start()
    print("Server should have now started")
else:
    default_server.stop()
    default_server.cleanup()
    default_server.start()
    print("Server is already running")

Server should have now started


Establish a connection with the embedded server and print its version information.

In [14]:
connections.connect(host="localhost", port=default_server.listen_port)
print(utility.get_server_version())

v2.3.3-lite


Define the collection

In [15]:
COLLECTION_NAME = "wikpedia_collection"
INDEX_NAME = "wikpedia_index"

In [16]:
# Run if you want to drop your old data
try:
    utility.drop_collection(COLLECTION_NAME)
    print("Collection has been deleted")
except:  # noqa: E722
    pass

Collection has been deleted


In [17]:
id = FieldSchema(
    name="id",
    dtype=DataType.INT64,
    is_primary=True,
    auto_id=True,
)

text = FieldSchema(
    name="text",
    dtype=DataType.VARCHAR,
    max_length=5120,
)

text_vector = FieldSchema(name="text_vector", dtype=DataType.FLOAT_VECTOR, dim=384)

qid = FieldSchema(name="qid", dtype=DataType.INT64)

title = FieldSchema(
    name="title",
    dtype=DataType.VARCHAR,
    max_length=5120,
)

schema = CollectionSchema(
    fields=[id, text, text_vector, qid, title],
    description="SIEM vector store",
    enable_dynamic_field=True,
)

collection = Collection(
    name=COLLECTION_NAME, schema=schema, using="default", shards_num=2
)

### Embed and index documents with Milvus

**Note: Could take several minutes if you don't have pre-built indices**


Prepare collection

In [18]:
text_splitter = RecursiveCharacterTextSplitter(
    chunk_size=5000, chunk_overlap=50, length_function=len, add_start_index=False
)


def split_and_prepare_document_new(qid: str, title: str, text: str):
    split_text = [text.page_content for text in text_splitter.create_documents([text])]
    ids = [qid] * len(split_text)
    titles = [title] * len(split_text)
    embeddings = [
        embedding_model.encode(xc, show_progress_bar=False) for xc in split_text
    ]
    return split_text, ids, titles, embeddings


def process_batch(document_list):
    batch_results = []

    for id, title, text in zip(
        document_list["id"].values.tolist(),
        document_list["title"].values.tolist(),
        document_list["text"].values.tolist(),
    ):
        for sub_text, sub_id, sub_title, sub_embedding in zip(
            *split_and_prepare_document_new(id, title, text)
        ):
            batch_results.append(tuple((sub_id, sub_title, sub_text, sub_embedding)))
    return batch_results

In [19]:
batch_size = 10
processed_docs = []
cache_filename = '/Users/rrfsantos/i-watsonx-siem/SIEM/prepared-docs.pkl'
allow_cache = True

if allow_cache and os.path.isfile(cache_filename):
    os.remove(cache_filename)

for i in tqdm(range(0, len(documents), batch_size), desc="Processing Documents in Batches"):
    # find end of batch
    i_end = min(i + batch_size, len(documents))
    documents_batch = documents[i:i_end]

    # Process the batch
    processed = process_batch(documents_batch)
    processed_docs.extend(processed)

with open(cache_filename, "wb") as f:
    pickle.dump(processed_docs, f)

print("Processed docs saved to pickle checkpoint")

NameError: name 'documents' is not defined

Insert the embeddings, texts, titles and documents id's in collection.

In [None]:
if default_server.running:
    collection = Collection(COLLECTION_NAME)
    error  = []
    batch_size = 1000
    for i in tqdm(
        range(0, len(processed_docs), batch_size),
        desc="Inserting documents batches to Milvus VectorDB",
    ):
        # find end of batch
        i_end = min(i + batch_size, len(processed_docs))
        id_l, title_l, text_l, embed_l = list(zip(*processed_docs[i:i_end]))

        data_to_insert = [text_l, embed_l, id_l, title_l]
        try:
            collection.insert(data_to_insert)
        except Exception as ex:
            print(f"Failed to insert: {ex}")
            error.append(title_l)
else:
    print("Milvus server is not running! Rerun related notebook cells.")

Inserting documents batches to Milvus VectorDB:   0%|          | 0/2 [00:00<?, ?it/s]

RPC error: [batch_insert], <MilvusException: (code=1100, message=the length (5141) of 35th string exceeds max length (5120): expected=valid length string, actual=string length exceeds max length: invalid parameter)>, <Time:{'RPC start': '2023-11-27 12:49:51.352043', 'RPC error': '2023-11-27 12:49:51.455651'}>
RPC error: [batch_insert], <MilvusException: (code=1100, message=the length (5153) of 31th string exceeds max length (5120): expected=valid length string, actual=string length exceeds max length: invalid parameter)>, <Time:{'RPC start': '2023-11-27 12:49:51.456497', 'RPC error': '2023-11-27 12:49:51.487858'}>


Failed to insert: <MilvusException: (code=1100, message=the length (5141) of 35th string exceeds max length (5120): expected=valid length string, actual=string length exceeds max length: invalid parameter)>
Failed to insert: <MilvusException: (code=1100, message=the length (5153) of 31th string exceeds max length (5120): expected=valid length string, actual=string length exceeds max length: invalid parameter)>


In [None]:
NLIST_SIZE = 1024

index_params = {
    "metric_type": "COSINE",
    "index_type": "HNSW",
    "params": {"nlist": NLIST_SIZE},
    "M": 16,
    "efConstruction": 200,
}

collection.create_index(field_name="text_vector", index_params=index_params)

print("Collection index has been successfully created!")

Collection index has been successfully created!


<a id="models"></a>
## Foundation Models on Watsonx

In [None]:
load_dotenv()
BAM_API_URL = os.getenv("BAM_API_URL")
BAM_API_KEY = os.getenv("BAM_API_KEY")

In [None]:
# get the list of supported models from the API
models = Model.models(credentials=credentials)

model_ids = []
for model_n in models:
    print(model_n.id)

salesforce/codegen2-16b
codellama/codellama-34b-instruct
tiiuae/falcon-180b
tiiuae/falcon-40b
ibm/falcon-40b-8lang-instruct
google/flan-t5-xl
google/flan-t5-xxl
google/flan-ul2
eleutherai/gpt-neox-20b
ibm/granite-13b-chat-v1
ibm/granite-13b-chat-v2
ibm/granite-13b-instruct-v1
ibm/granite-13b-instruct-v2
ibm/granite-20b-code-instruct-v1
ibm/granite-3b-code-plus-v1
elyza/japanese-llama-2-7b-fast
elyza/japanese-llama-2-7b-instruct
meta-llama/llama-2-13b
meta-llama/llama-2-13b-chat
meta-llama/llama-2-13b-chat-beam
meta-llama/llama-2-70b
meta-llama/llama-2-70b-chat
thebloke/llama-2-70b-chat-gptq
meta-llama/llama-2-7b
meta-llama/llama-2-7b-chat
mosaicml/mpt-30b
ibm/mpt-7b-instruct
bigscience/mt0-xxl
bigcode/starcoder
flan-t5-xl-mpt-HrlayZEh-2023-10-25-18-15-34


In [None]:
# select generative model to use
model_id = "meta-llama/llama-2-70b-chat"

# Iterate over the "results" list to find the matching model ID
for model_n in models:
    if model_n.id == model_id:
        model_token_limit = model_n.token_limit
        print(f"Model was found, it's token limit is {model_token_limit}.")
        break
else:
    print("Model was not found, pick a different one!")
    model_token_limit = None

Model was found, it's token limit is 8192.


In [None]:
# set-up inference parameters
params = GenerateParams(decoding_method="greedy", max_new_tokens=500, min_new_tokens=10,  repetition_penalty=1)

model = Model(model=model_id, credentials=credentials, params=params)

The input token limit depends on the selected generative model's max sequence length. The total input tokens in the RAG prompt should not exceed the model's max sequence length minus the number of desired output tokens. The choice of the number of paragraphs to retrieve as context impacts the number tokens in the prompt.

In [None]:
# For setting the input token limit we subtract the max_new_tokens (to be generated) and -1 from the model_token_limit
input_token_limit = model_token_limit - params.max_new_tokens - 1
print(f"Input token limit: {input_token_limit}")

Input token limit: 7991


<a id="predict"></a>
## Generate a retrieval-augmented response to a question

### Feed the context and the questions to `watsonx.ai` model.

### Using the context and train data

Select a question

Feed the context and the question to `genai` model

In [None]:
# Token counting function
def token_count(doc):
    return model.tokenize([doc])[0].token_count

`prompt_template` is a function to create a prompt from the given context and question. Changing the prompt will sometimes result in much more appropriate answers (or it may degrade the quality significantly). The prompt template below is most appropriate for short-form extractive use cases.

`make_prompt` includes a script to truncate the context length provided as an input in case the total token inputs exceed the model's limit. The paragraphs with the largest distance are truncated first. This functionality is helpful in case the embedded passages are not of the same size.

In [None]:
few_shot_example=[]
few_shot_examples=[]
for input,output,id in train_data.groupby('answer').apply(lambda x: x.sample(10)).values:
    few_shot_example.append(f"input: {input}\noutput: {output}")
few_shot_examples='\n\n\n'.join(few_shot_example)

In [None]:
def prompt_template(context, few_shot_examples, question_text):
    return ('''You are cybersecurity agent, your primary responsibility is to determine whether a given alert is a true positive or a false positive based on your knowledge of typical cyberattack indicators. Examine the provided alert, indicate whether it is a "false positive" or a "true positive",
provide a confidence score, and offer a detailed explanation of your reasoning.
Provide your final answer in the format: “true positive” or “false positive” 
Confidence score: number between 0 and 100
Reasoning: detailed explanation why the alert is true positive or false positive.
Below is an example of answer:
“true positive”
Confidence score: 80
Reasoning: The flow record analytics provided contains several indicators that suggest a potential DNS-based cyber attack.
Firstly, the source IP address fe80:0:0:0:a8f5:bfcb:1515:24dd is an internal IP address, which could indicate that the attacker is trying to use the internal
DNS server to perform the attack. Secondly, the destination IP address ff02:0:0:0:0:0:1:3 is a multicast DNS address,
which is commonly used in DNS-based attacks. Thirdly, the DNS query ID 17327 and the fact that it's a PTR, A, AAAA request,
suggests that the attacker is trying to perform a DNS lookup. Lastly, the fact that the source port is 57592, which is an unusual'''
        + "Context:\n\n"
        + f"{context}\n\n" 
        + "##\n\n"
        + f"{few_shot_examples}\n\n"
        + "##\n\n"
        + f"Alert: {question_text}\n\n"
        + "Answer: "
    )

In [None]:
def make_prompt(alert, few_shot_examples, input_token_limit):

    ### Create question embedding
    question_embeddings = embedding_model.encode(alert)

    ### Collect the context in vetctordb
    
    search_params = {"metric_type": "COSINE", "params": {"ef": 10}}

    results = collection.search(
       data=[question_embeddings],
       anns_field="text_vector",
       param=search_params,
       limit=4,
       expr=None,
       output_fields=["id","text"],  # name of the field to retrieve from the search result
       consistency_level="Strong",
    )
    print(f"Found {len(results[0])} results in the collection.")

    documents = []
    for raw_result in results:
        for result in raw_result:
            documents.append(result.entity.get("text"))
            print("=========")
            print("Paragraph : ", result.entity.get("text"))
            print("Distance : ", result.distance)

    context = "\n\n\n".join(documents)

    prompt = prompt_template(context, few_shot_examples, alert)

    prompt_token_count = token_count(prompt)

    if prompt_token_count <= input_token_limit:
        return prompt

    print("exceeded input token limit, truncating context", prompt_token_count)
    
    # documents with the lower distance scores are included in the truncated context first
    distances = results[0].distances
    sorted_indices = sorted(range(len(distances)), key=lambda k: distances[k])

    truncated_context = ""
    token_count_so_far = 0
    i = 0

    while token_count_so_far <= input_token_limit and i < len(sorted_indices):
        doc_index = sorted_indices[i]
        document = documents[doc_index]
        doc_token_count = token_count(document) + token_count(few_shot_examples)

        if token_count_so_far + doc_token_count <= input_token_limit:
            truncated_context += document + "\n\n\n"
            token_count_so_far += doc_token_count
        else:
            remaining_tokens = input_token_limit - token_count_so_far
            truncated_context += document[:remaining_tokens]
            break

        i += 1

    return prompt_template(truncated_context, few_shot_examples, alert)

In [None]:
%time
alerts = test_data['question'].tolist()
collection = Collection(COLLECTION_NAME)      # Get the existing collection.
collection.load() 
prompt_texts = []
for alert in alerts:
    prompt_text = make_prompt(alert, few_shot_examples, input_token_limit)
    prompt_texts.append(prompt_text)

CPU times: user 1e+03 ns, sys: 1 µs, total: 2 µs
Wall time: 3.1 µs
Found 0 results in the collection.
Found 0 results in the collection.
Found 0 results in the collection.
Found 0 results in the collection.
Found 0 results in the collection.
Found 0 results in the collection.
Found 0 results in the collection.
Found 0 results in the collection.
Found 0 results in the collection.
Found 0 results in the collection.
Found 0 results in the collection.
Found 0 results in the collection.
Found 0 results in the collection.
Found 0 results in the collection.
Found 0 results in the collection.
Found 0 results in the collection.
Found 0 results in the collection.
Found 0 results in the collection.
Found 0 results in the collection.
Found 0 results in the collection.
Found 0 results in the collection.
Found 0 results in the collection.
Found 0 results in the collection.
Found 0 results in the collection.
Found 0 results in the collection.
Found 0 results in the collection.
Found 0 results in the 

Generate responses

In [None]:
%time
answers = []
for response in model.generate(prompt_texts):
    answers.append(response.generated_text)

CPU times: user 2 µs, sys: 2 µs, total: 4 µs
Wall time: 5.96 µs


In [None]:
correct = 0
total = 0
for idx, answer in enumerate(answers):
    print("Model Output = ", answer )
    print("Expected Answer = ", test_data.iloc[idx]['answer'])
    print("\n")
    if answer.lower()[2:6] == (test_data.iloc[idx]['answer'].lower())[0:4]:
        correct += 1
    total += 1

Model Output =  
"false positive"

Reasoning: The flow record analytics provided contains several indicators that suggest a potential DNS-based cyber attack.
Firstly, the source IP address fe80:0:0:0:a8f5:bfcb:1515:24dd is an internal IP address, which could indicate that the attacker is trying to use the internal
DNS server to perform the attack. Secondly, the destination IP address ff02:0:0:0:0:0:1:3 is a multicast DNS address,
which is commonly used in DNS-based attacks. Thirdly, the DNS query ID 17327 and the fact that it's a PTR, A, AAAA request,
suggest that the attacker is trying to perform a DNS lookup. Lastly, the fact that the source port is 57592, which is an unusual port,
suggests that the attacker is trying to perform a DNS lookup.

However, the
Expected Answer =  false positive


Model Output =  
"true positive"

Confidence score: 100

Reasoning: The alert is a true positive because it contains indicators that suggest a potential cyber attack.

Firstly, the source IP addr

In [None]:
print("correct: ", correct)
print("total: ", total)
print("percent correct ", correct/total)

correct:  124
total:  166
percent correct  0.7469879518072289


In [None]:
default_server.stop()