# Introduction: A Retrieval Augmented Generation (RAG)-based AI assistant

#### We use Retrieval Augmented Generation (RAG) to answer questions on the document(s) using an LLM
#### In this notebook we go through the following steps:

1. Collect the document(s) from different webpages in the website: We choose to build an AI assistant on the history of Bengal, so we pick the relevant pages from wikipedia
2. Split the document(s) into chunks -- could be the original pages themselves
3. Transform the chunks into embedding vectors 
4. Create a FAISS vector database with the chunks and embeddings
5. Query the database with a question, and retrieve the most relevant chunks, which we call context
6. Combine a "system prompt" and the context to create a query for the LLM
7. Generate an answer from the LLM using this query

#### - Large Language Model used: Gemma2B 
#### - Embedding Model used: bge-small-en-v1.5





# Pip install packages 

In [None]:
%%time
%%capture
!pip install tiktoken FlagEmbedding transformers faiss-gpu
!pip install sentence_transformers
!pip install -q -U wikipedia-api
#!pip install requests bs4

# Obtain the wikipedia pages for a given topic

In [None]:
from tqdm import tqdm
import wikipediaapi
import re

# Compiled once at module level: matches a non-greedy "{...}" group on one line,
# or a lone "}" that has no opening partner.
BRACES_PATTERN = re.compile(r'\{.*?\}|\}')

def remove_braces_and_content(text):
    """Strip every ``{...}`` group, and any stray ``}``, from ``text``.

    A ``{`` with no closing brace on the same line is left in place, because
    the pattern only matches complete groups or a lone closing brace.

    Args:
        text: The string to clean.

    Returns:
        A copy of ``text`` with all matches of ``BRACES_PATTERN`` deleted.
    """
    stripped = re.sub(BRACES_PATTERN, '', text)
    return stripped

def clean_string(input_string):
    """Normalise whitespace and drop ``{...}`` groups from a piece of text.

    Steps:
      1. Collapse every run of whitespace to a single space. ``str.split()``
         with no argument splits on spaces, tabs, ``\\n`` and ``\\r`` alike,
         so no separate carriage-return handling is needed. This also puts
         brace groups that spanned several lines onto one line so the brace
         pattern can match them.
      2. Remove curly-brace groups via ``remove_braces_and_content``.
      3. Collapse whitespace again, because deleting a group that sat between
         two words leaves a double space (or a leading/trailing space) behind.

    Args:
        input_string: Raw text, e.g. a Wikipedia summary or section body.

    Returns:
        The cleaned, single-spaced string.
    """
    collapsed = ' '.join(input_string.split())
    without_braces = remove_braces_and_content(collapsed)
    return ' '.join(without_braces.split())

def extract_wiki(wiki, category_name):
    """List the titles of all members of a Wikipedia category.

    Args:
        wiki: Object exposing ``page(title)``; the returned page must provide
            ``exists()`` and a ``categorymembers`` mapping.
        category_name: Category name without the ``"Category:"`` prefix.

    Returns:
        A list with the ``title`` of every category member, in the mapping's
        iteration order. Empty when the category page does not exist.
    """
    category_page = wiki.page("Category:" + category_name)

    # Guard clause: a missing category contributes nothing.
    if not category_page.exists():
        return []

    return [member.title for member in category_page.categorymembers.values()]


def _collect_page_titles(wiki, categories, max_depth):
    """Walk the category tree level by level and return the article titles found."""
    page_titles = {}             # dict used as an insertion-ordered set of titles
    explored_categories = set()  # categories already crawled (avoids repeat requests)
    current_level = list(categories)
    depth = 0

    while current_level:
        next_level = []
        for category_name in current_level:
            if category_name in explored_categories:
                continue
            explored_categories.add(category_name)
            print(f"\tExploring {category_name} on Wikipedia")

            for ref in extract_wiki(wiki, category_name):
                if "Category:" in ref:
                    # Member is itself a category: queue it for the next level
                    next_level.append(ref.replace("Category:", ""))
                else:
                    page_titles[ref] = None

        if max_depth is not None and depth >= max_depth:
            break
        current_level = next_level
        depth += 1

    return list(page_titles)


def _extract_page_texts(wiki, page_titles):
    """Fetch each page and return one "title : text" string per summary/section."""
    extracted_texts = []

    for page_title in tqdm(page_titles):
        try:
            page = wiki.page(page_title)
            summary = page.summary

            # Pages whose summary mentions these keywords are skipped entirely
            if "Biden" in summary or "Trump" in summary:
                continue

            # Keep only texts longer than the title itself (filters out empty stubs)
            if len(summary) > len(page.title):
                extracted_texts.append(page.title + " : " + clean_string(summary))

            for section in page.sections:
                if len(section.text) > len(page.title):
                    extracted_texts.append(page.title + " : " + clean_string(section.text))

        except Exception as e:
            # Best effort: report and continue. Use page_title here because
            # `page` may not have been assigned if wiki.page() itself failed.
            print(f"Error processing page {page_title}: {e}")

    return extracted_texts


def get_wiki(categories, max_depth=1):
    """Retrieve Wikipedia pages from a list of categories and extract their content.

    The given categories are crawled first (depth 0); categories found among
    their members are then crawled level by level up to ``max_depth``. Each
    category is crawled at most once and article titles are de-duplicated in
    first-seen order, so the output order is stable for a given crawl.

    Args:
        categories: Iterable of category names without the ``"Category:"`` prefix.
        max_depth: How many levels of subcategories to follow below the given
            categories. The default of 1 follows direct subcategories only.
            ``None`` follows subcategories without a depth limit, which can be
            a very large crawl.

    Returns:
        A list of strings of the form ``"<page title> : <cleaned text>"``,
        one per page summary and one per page section that is longer than the
        page title.
    """
    # Create a Wikipedia object (user agent, language)
    wiki_wiki = wikipediaapi.Wikipedia('AI_Assistant', 'en')

    print("- Processing Wikipedia categories:")
    wikipedia_pages = _collect_page_titles(wiki_wiki, categories, max_depth)

    print("- Processing Wikipedia pages:")
    return _extract_page_texts(wiki_wiki, wikipedia_pages)

In [None]:
# Seed categories to crawl; get_wiki returns one text per page summary / page section
categories = ["History_of_Bengal", "History of Bangladesh", "Kolkata",
              "History of Kolkata", "History of West Bengal"]
extracted_texts = get_wiki(categories)
# The list holds summaries and sections, not whole pages, so report it as chunks
print("Found", len(extracted_texts), "Wikipedia text chunks")

In [None]:
# Persist the extracted texts to disk for future use
import pickle

with open("bengal_history_wikidump.pickle", 'wb') as file:
    # Serialise in memory, then write the bytes in one call
    file.write(pickle.dumps(extracted_texts))

# Reformat text chunks

#### How large are the chunks on average?

In [None]:
import numpy as np
from matplotlib import pyplot as plt

# Character length of every extracted text
chunks_sizes = list(map(len, extracted_texts))
print("Average number of characters per chunk : %.2f" % np.mean(chunks_sizes))

# Histogram of the lengths, with the median marked by a dashed vertical line
percentile_50th = np.percentile(chunks_sizes, 50)
plt.hist(chunks_sizes, bins=30)
plt.axvline(x=percentile_50th, color='yellow', linestyle='--', alpha=0.9)
plt.title('Distribution of text-chunks length, with 50th percentile')
plt.xlabel('Length of text-chunk in characters')
plt.gca().spines[['top', 'right']].set_visible(False)
plt.show();


In [None]:
def split_text_chunks(strings, K, L):
    """Split long strings into overlapping windows of at most ``K`` characters.

    Strings no longer than ``K`` are passed through unchanged. Longer strings
    are cut into windows of length ``K`` whose start positions advance by
    ``K - L``, so consecutive windows share ``L`` characters. Splitting stops
    as soon as a window reaches the end of the string, which avoids emitting
    tail chunks that are entirely contained in the previous chunk.

    Args:
        strings: Iterable of strings to split.
        K: Maximum chunk length in characters; must be positive.
        L: Overlap between consecutive chunks in characters; must be smaller
            than ``K``.

    Returns:
        A flat list of chunks, in input order.

    Raises:
        ValueError: If ``K`` is not positive or ``L >= K`` (the window would
            never advance).
    """
    if K <= 0:
        raise ValueError("K must be a positive chunk length")
    if L >= K:
        raise ValueError("L (overlap) must be smaller than K (chunk length)")

    step = K - L
    result = []
    for string in strings:
        if len(string) <= K:
            result.append(string)
            continue

        start = 0
        while start < len(string):
            result.append(string[start:start + K])
            # This window already covers the end of the string: any further
            # window would only repeat text that is already in this chunk.
            if start + K >= len(string):
                break
            start += step
    return result

In [None]:
# Re-split the extracted texts: windows of 2000 characters with 1000 characters of overlap
new_chunks = split_text_chunks(extracted_texts, K=2000, L=1000)

# Character length of every chunk after splitting
chunks_sizes = list(map(len, new_chunks))
print("Average number of characters per chunk : %.2f" % np.mean(chunks_sizes))

# Histogram of the lengths, with the median marked by a dashed vertical line
percentile_50th = np.percentile(chunks_sizes, 50)
plt.hist(chunks_sizes, bins=30)
plt.axvline(x=percentile_50th, color='yellow', linestyle='--', alpha=0.9)
plt.title('Distribution of text-chunks length, with 50th percentile')
plt.xlabel('Length of text-chunk in characters')
plt.gca().spines[['top', 'right']].set_visible(False)
plt.show();


In [None]:
# Number of chunks after splitting; this is how many vectors will be indexed.
len(new_chunks)
# For comparison, the count before splitting:
#len(extracted_texts)

# Build a vector database using an embedding model


#### Download an embedding model. We consider 'BAAI/bge-small-en-v1.5' 
- The Embedding model will transform the text chunks into numerical vectors in high-dimensions
- Embedding vectors corresponding to text chunks similar in meaning (semantics) to each other, are closer to each other in the vector space
- Next we use FAISS to build a vector database

In [None]:
%%time
# Imports for the embedding / vector-database stage
from sentence_transformers import SentenceTransformer, util
import faiss
from transformers import AutoTokenizer, AutoModel
import torch
# Load a pre-trained sentence transformer model used to embed the text chunks.
# Alternative model, kept for reference:
#model = SentenceTransformer('all-MiniLM-L6-v2')
model = SentenceTransformer('BAAI/bge-small-en-v1.5') 
                            
# Function to embed text chunks
def embed_text_chunks(text_chunks, model, device=None):
    """Encode text chunks into a NumPy array of embedding vectors.

    Args:
        text_chunks: List of strings to embed.
        model: Object with a SentenceTransformer-style
            ``encode(texts, convert_to_tensor=..., device=...)`` method.
        device: Device string passed to ``model.encode``. When ``None``,
            ``'cuda'`` is used if a GPU is available and ``'cpu'`` otherwise,
            so the notebook also runs on CPU-only machines.

    Returns:
        The embeddings as a NumPy array on the host, one row per chunk.
    """
    if device is None:
        device = 'cuda' if torch.cuda.is_available() else 'cpu'
    embeddings = model.encode(text_chunks, convert_to_tensor=True, device=device)
    # Move to host memory before converting; FAISS is fed NumPy arrays below
    return embeddings.cpu().numpy()

# Function to create FAISS index
def create_faiss_index(text_chunks, model):
    """Embed ``text_chunks`` and load the vectors into a flat L2 FAISS index.

    Args:
        text_chunks: List of strings to index.
        model: Embedding model forwarded to ``embed_text_chunks``.

    Returns:
        A ``faiss.IndexFlatL2`` holding one vector per chunk, added in the
        same order as ``text_chunks``.
    """
    vectors = embed_text_chunks(text_chunks, model)
    # The index dimensionality is the width of the embedding matrix
    flat_index = faiss.IndexFlatL2(vectors.shape[1])
    flat_index.add(vectors)
    return flat_index

# Create the FAISS index: embed every chunk in new_chunks and add the vectors
# to a flat L2 index. Row i of the index corresponds to new_chunks[i].
faiss_index = create_faiss_index(new_chunks, model)

In [None]:
# Ask PyTorch to release cached GPU memory before the next model is loaded
torch.cuda.empty_cache()

#### Download the reranker model.  Consider 'BAAI/bge-reranker-large'
- A query on the embedding vectors returns chunks that are most relevant to the query
- A reranker is used to further evaluate these chunks, based on certain criteria 
- It reorders or "re-ranks" them according to their quality or relevance

In [None]:
%%time
from FlagEmbedding import FlagReranker
# Load the reranker used to re-score retrieved chunks against the query.
reranker = FlagReranker('BAAI/bge-reranker-large', use_fp16=True) # Setting use_fp16 to True speeds up computation with a slight performance degradation

In [None]:
# Function to search the FAISS index
def query_faiss(query_text, index, model, k=5, device=None):
    """Embed a query and return its ``k`` nearest neighbours from the index.

    Args:
        query_text: The question to search for.
        index: Object with a FAISS-style ``search(vectors, k)`` method.
        model: Object with a SentenceTransformer-style
            ``encode(texts, convert_to_tensor=..., device=...)`` method; should
            be the same model that produced the indexed vectors.
        k: Number of neighbours to retrieve.
        device: Device string passed to ``model.encode``. When ``None``,
            ``'cuda'`` is used if a GPU is available and ``'cpu'`` otherwise.

    Returns:
        ``(distances, indices)`` exactly as returned by ``index.search``; with
        a single query both have one row of ``k`` entries.
    """
    if device is None:
        device = 'cuda' if torch.cuda.is_available() else 'cpu'
    # Encode as a batch of one so the result is a 2-D (1, dim) array
    query_embedding = model.encode([query_text], convert_to_tensor=True, device=device).cpu().numpy()
    distances, indices = index.search(query_embedding, k)
    return distances, indices




# Load the Large Language Model and Query

- We use Gemma2B, a pretrained lightweight LLM with 2 billion parameters, released by Google 


In [None]:
%%time

from transformers import AutoTokenizer, AutoModelForCausalLM

# Path of the model checkpoint directory to load
language_model_name = "/kaggle/input/gemma/transformers/2b-it/3"

tokenizer = AutoTokenizer.from_pretrained(language_model_name, use_fast=True)
language_model = AutoModelForCausalLM.from_pretrained(
    language_model_name,
    device_map="cuda",  # change this to "auto" if you want to run the LLM distributed on 2 GPUs
    trust_remote_code=False,
    revision="main",
)
# Override the activation recorded in the loaded config.
# NOTE(review): presumably needed for this checkpoint/transformers version - confirm.
language_model.config.hidden_activation = 'gelu_pytorch_tanh'

print('Done loading model: ' + language_model_name)

#### Specify system prompt


In [None]:
# System prompt: prepended to the context/question template when the LLM query is built
system_message =  """You are a Professor of Bengali History with great expertise in the history of Bengal. 
Answer the question using the context provided.
Answer must be very detailed and thorough.
"""


#### Specify the query, aka question we want to ask
- Some example questions


In [None]:

# Example questions. Spelling matters here: the text is embedded as-is for
# retrieval and passed verbatim to the reranker and the LLM.
query1 = "Who was Shashanka?"
query2 = "When did the partition of Bengal take place?"
query3 = "Who was Netaji Subhash Chandra Bose?"
query4 = "When was Mujibur Rahman assassinated?"
query5 = "What is Rabindranath Tagore famous for?"
query6 = "Where were the Portuguese trading posts in Bengal?"
query7 = "Who founded the Pala Empire?"
query8 = "What was the religion of the Pala Empire?"
query9 = "How did Islam spread in Bengal?"

#### query the vector database to identify matching chunks

- We choose the top 5 best matching chunks (this number could be something else as well, by varying the `n_top_matches` parameter)

#### context_chunks_init are the chunks with relevant context (with regards to the query)

- We will use the reranker to sort these chunks based on relevance

#### Keep the "best" chunks that are the most relevant as per the reranker
 - Use a threshold
 
 #### Query the LLM with the full query, with system_message, context and query

In [None]:
%%time
n_top_matches = 5

query = query9

# --- 1. Retrieve candidate chunks from the vector database ---
distances, indices = query_faiss(query, faiss_index, model, k = n_top_matches)
print(distances)
# FAISS pads the result with -1 when the index holds fewer than k vectors;
# skipping those avoids silently picking new_chunks[-1].
context_chunks_init = [new_chunks[idx] for idx in indices[0] if idx >= 0]

# --- 2. Re-rank the candidates with the reranker ---
# make pairs of query and chunks
query_and_chunks = [[query, chunk] for chunk in context_chunks_init]
# atleast_1d: some FlagEmbedding versions return a bare float for a single pair
scores_reranker = np.atleast_1d(reranker.compute_score(query_and_chunks))
print("Scores of retrieved chunks:\n", np.round(scores_reranker, decimals=2))
# indexes sorted according to new rank (highest score first)
print("\nOld order of retrieved chunks:", np.arange(len(context_chunks_init)))
max_idx_reranked = np.argsort(-scores_reranker)
print("\nRe-ranked order of retrieved chunks:", max_idx_reranked)

# --- 3. Keep the chunks that pass the relevance threshold ---
# Every chunk scoring at least cut_off_score is kept. If no chunk reaches the
# cut-off, the threshold falls back to the best score so that at least the
# top-ranked chunk is always used as context.
cut_off_score = 3
min_score = min(cut_off_score, np.max(scores_reranker))

context_chunks = [
    context_chunks_init[idx]
    for idx in max_idx_reranked
    if scores_reranker[idx] >= min_score
]

# --- 4. Build the full LLM query: system prompt + context + question ---
prompt_template = system_message + """\n\n\nContext:
{context}

Question: {question}

Answer:"""

context_chunks_as_str = '\n###\n'.join([str(elem) for elem in context_chunks])
llm_full_query = prompt_template.format(context=context_chunks_as_str, question=query)

# --- 5. Generate the answer ---
input_ids = tokenizer(llm_full_query, return_tensors="pt").to("cuda")

outputs = language_model.generate(**input_ids, max_new_tokens = 1024, do_sample = True, top_k = 5, top_p = 0.1, temperature = 0.1)
full_answer = tokenizer.decode(outputs[0])

# The generated sequence starts with the prompt tokens, so decode only what
# follows them. Splitting the decoded text on "Answer:" would break whenever
# the retrieved context or the answer itself contains that string.
prompt_length = input_ids["input_ids"].shape[1]
printed_answer = tokenizer.decode(outputs[0][prompt_length:], skip_special_tokens=True)
print("\nQuestion:" + query+"\n")
print(printed_answer)

torch.cuda.empty_cache()