In [1]:
import os
import argparse
import pandas as pd
from tqdm import tqdm
from termcolor import colored
from transformers import pipeline
from langchain_community.llms import HuggingFacePipeline
from langchain_community.chat_models import ChatOllama
from langchain_community.document_loaders import PyPDFLoader
from langchain.text_splitter import CharacterTextSplitter, RecursiveCharacterTextSplitter, TokenTextSplitter, NLTKTextSplitter, SpacyTextSplitter
import sys
sys.path.append('..')

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
def read_csv(file_name:str):
    """Load a CSV file into a DataFrame.

    Args:
        file_name: Path of the CSV, joined onto the current working directory
            (an absolute path is used as-is by ``os.path.join``).

    Returns:
        The ``pandas.DataFrame`` produced by ``pd.read_csv``.
    """
    working_dir = os.getcwd()
    return pd.read_csv(os.path.join(working_dir, file_name))

def _parse_choice(response: str) -> str:
    """Reduce a raw model response to a single upper-case choice letter.

    Returns the first non-whitespace character of ``response`` upper-cased, or
    ``'X'`` (the prompt's "don't know" marker) when the response is empty or
    whitespace-only, instead of raising IndexError.
    """
    text = response.strip()
    return text[0].upper() if text else 'X'


def measure_accuracy(LLM, prompt, Q, A, RAG:bool=False):
    """Score an LLM (or RAG chain) on multiple-choice questions.

    Each question is sent as ``prompt + q``. The response is reduced to a single
    upper-case letter by ``_parse_choice``. ``'X'`` counts as "unsure", a letter
    matching the expected answer counts as correct, and anything else counts as
    wrong. Per-question results and coloured totals are printed.

    Args:
        LLM: Object exposing ``invoke(str)``. With ``RAG=False`` the return value
            must have a ``.content`` attribute. With ``RAG=True`` it must be
            indexable by ``'result'``.
        prompt: Instruction text prepended to every question.
        Q: Sized iterable of question strings.
        A: Iterable of expected answer letters, aligned with ``Q``.
        RAG: Selects how the response text is extracted (see ``LLM``).

    Returns:
        ``(correct_count, wrong_count, unsure_count)``.
    """
    correct_count = 0
    wrong_count = 0
    unsure_count = 0
    total = len(Q)
    for q, a in tqdm(zip(Q, A), total=total, desc="Measuring Accuracy"):
        if RAG:
            response = LLM.invoke(prompt + q)['result']
        else:
            response = LLM.invoke(prompt + q).content
        pred = _parse_choice(response)
        # Normalise the reference too, so stray whitespace/lower-case in the CSV
        # is not scored as a wrong answer.
        expected = str(a).strip().upper()
        print(f"Correct Answer: {a}, Predicted Answer: {pred}")

        if pred == 'X':  # the model declared it does not know
            unsure_count += 1
        elif pred == expected:
            correct_count += 1
        else:
            wrong_count += 1
    print(colored(f"Correct: {correct_count}/{total}", 'green'))
    print(colored(f"Wrong: {wrong_count}/{total}", 'red'))
    print(colored(f"Unsure: {unsure_count}/{total}", 'yellow'))
    return correct_count, wrong_count, unsure_count
        

def get_llm_config(params:dict, model_name:str="mistral:instruct"):
    """Create the chat LLM and record its identity in ``params``.

    Args:
        params: Configuration dict. ``params['temperature']`` is passed to the
            model. NOTE: this dict is updated in place. ``'llm'`` is set to
            ``model_name`` and ``'is_hf'`` to ``False``; ``get_llm`` reads
            ``params['is_hf']`` after calling this function.
        model_name: Ollama model tag to load. The default is
            https://ollama.com/library/mistral:instruct.

    Returns:
        ``(LLM, params)``: the ``ChatOllama`` instance and the same (mutated) dict.
    """
    LLM = ChatOllama(model=model_name, temperature=params['temperature'])
    params['llm'] = model_name
    # ChatOllama is not a Hugging Face model, so get_llm must not wrap it in a pipeline.
    params['is_hf'] = False
    return LLM, params

def get_llm(params:dict):
    """Return the configured LLM, wrapped in a Hugging Face pipeline when needed.

    Construction is delegated to ``get_llm_config``. If ``params['is_hf']`` is
    truthy after that call, the model is wrapped in a ``transformers`` pipeline
    and then in LangChain's ``HuggingFacePipeline``. The pipeline reads 'task',
    'tokenizer', 'max_length', 'temperature', 'repetition_penalty', plus
    'do_sample' (text2text-generation only) and 'top_p' (text-generation only)
    from ``params``. Otherwise the model is returned unchanged.

    Returns:
        ``(LLM, config)``, where ``config`` is whatever ``get_llm_config`` returned.
    """
    model, config = get_llm_config(params)
    if not params['is_hf']:
        return model, config

    task = params['task']
    tokenizer = params['tokenizer']
    hf_pipe = pipeline(
        task=task,
        model=model,
        tokenizer=tokenizer,
        pad_token_id=tokenizer.eos_token_id,
        max_length=params['max_length'],
        temperature=params['temperature'],
        # Each sampling knob is only meaningful for one task type; pass None otherwise.
        do_sample=params['do_sample'] if task == 'text2text-generation' else None,
        top_p=params['top_p'] if task == 'text-generation' else None,
        repetition_penalty=params['repetition_penalty'],
    )
    return HuggingFacePipeline(pipeline=hf_pipe), config

### Multiple Choice Questions (MCQs)

In [3]:
# Multiple-choice questions about the Omicron variant: column 'Q' holds the
# question text sent to the model, column 'A' the expected answer.
MCQs_omicrons = read_csv("MCQs_omicron.csv")
MCQs_omicrons_Q, MCQs_omicrons_A = MCQs_omicrons['Q'], MCQs_omicrons['A']

### Model without RAG

In [4]:
# Model / pipeline parameters, shared by the no-RAG and RAG runs below.
params = {
    'chain_type': 'stuff',            # passed to get_qa_chain
    'embedding_device': 'cuda',       # passed to get_embeddings
    'embedding_model': 'sentence-transformers/all-MiniLM-L6-v2',
    'k': 3,                           # passed as k to get_retriever (presumably top-k chunks)
    'llm': 'google/flan-t5-base',     # overwritten in place by get_llm_config with the model actually used
    'llm_device': 'cuda',             # NOTE(review): not read by any code visible in this notebook
    'max_length': 2000,               # HF pipeline branch of get_llm only
    'quantize': True,                 # NOTE(review): not read by any code visible in this notebook
    'query_instruction': 'Represent the question for retrieving supporting documents',
    'repetition_penalty': 1.0,        # HF pipeline branch of get_llm only
    'search_type': 'similarity',      # passed to get_retriever
    'separator': '\n\n',              # text-splitter separator used by RAG
    'temperature': 0.05,              # sampling temperature for the LLM
    'top_p': 1.0,                     # HF pipeline branch of get_llm only (text-generation)
    'chunk_size': 500,
    'chunk_overlap': 0,
}

# Engineered prompt template; each question is appended directly after "Context is: ".
prompt = """You are the agent that has to select the correct answer to the following multiple choice question in the context provided.
            You cannot speak human language, but you can only say one single letter.
            Choose the letter corresponding to the correct answer.
            If you don't know or are unsure about the answer, just display the letter X without any additional text.
            If you know the answer, display the letter corresponding to the correct answer without any additional text.
            Whether you know or don't know the answer, do not display any other text except for the letter.
            Your response always has to be just one single letter.
            Example of the answer that you have to say: A.
            Context is: """

In [5]:
# Instantiate the answering LLM (no retrieval); the returned config is discarded into `_`.
LLM, _ = get_llm(params)
print(colored(f"MCQ Generator LLM: {LLM}", color="yellow"))

[33mMCQ Generator LLM: model='mistral:instruct' temperature=0.05[0m


In [6]:
# Baseline: score the bare LLM (no retrieval) on the Omicron MCQs.
acc = measure_accuracy(LLM, prompt, MCQs_omicrons_Q, MCQs_omicrons_A, RAG=False)

Measuring Accuracy:   4%|▍         | 1/25 [00:09<03:50,  9.61s/it]

Correct Answer: C, Predicted Answer: C


Measuring Accuracy:   8%|▊         | 2/25 [00:21<04:11, 10.95s/it]

Correct Answer: B, Predicted Answer: C


Measuring Accuracy:  12%|█▏        | 3/25 [00:35<04:30, 12.29s/it]

Correct Answer: B, Predicted Answer: X


Measuring Accuracy:  16%|█▌        | 4/25 [00:45<04:03, 11.60s/it]

Correct Answer: A, Predicted Answer: X


Measuring Accuracy:  20%|██        | 5/25 [00:52<03:12,  9.62s/it]

Correct Answer: B, Predicted Answer: X


Measuring Accuracy:  24%|██▍       | 6/25 [01:02<03:08,  9.90s/it]

Correct Answer: A, Predicted Answer: A


Measuring Accuracy:  28%|██▊       | 7/25 [01:14<03:11, 10.63s/it]

Correct Answer: C, Predicted Answer: X


Measuring Accuracy:  32%|███▏      | 8/25 [01:21<02:40,  9.43s/it]

Correct Answer: A, Predicted Answer: A


Measuring Accuracy:  36%|███▌      | 9/25 [01:34<02:51, 10.70s/it]

Correct Answer: A, Predicted Answer: X


Measuring Accuracy:  40%|████      | 10/25 [01:43<02:31, 10.09s/it]

Correct Answer: C, Predicted Answer: X


Measuring Accuracy:  44%|████▍     | 11/25 [01:53<02:20, 10.00s/it]

Correct Answer: B, Predicted Answer: B


Measuring Accuracy:  48%|████▊     | 12/25 [02:04<02:13, 10.25s/it]

Correct Answer: A, Predicted Answer: X


Measuring Accuracy:  52%|█████▏    | 13/25 [02:09<01:43,  8.60s/it]

Correct Answer: A, Predicted Answer: X


Measuring Accuracy:  56%|█████▌    | 14/25 [02:21<01:48,  9.83s/it]

Correct Answer: C, Predicted Answer: C


Measuring Accuracy:  60%|██████    | 15/25 [02:27<01:26,  8.67s/it]

Correct Answer: C, Predicted Answer: C


Measuring Accuracy:  64%|██████▍   | 16/25 [02:36<01:17,  8.56s/it]

Correct Answer: C, Predicted Answer: X


Measuring Accuracy:  68%|██████▊   | 17/25 [02:46<01:11,  9.00s/it]

Correct Answer: C, Predicted Answer: C


Measuring Accuracy:  72%|███████▏  | 18/25 [02:54<01:01,  8.77s/it]

Correct Answer: C, Predicted Answer: C


Measuring Accuracy:  76%|███████▌  | 19/25 [03:01<00:48,  8.17s/it]

Correct Answer: C, Predicted Answer: X


Measuring Accuracy:  80%|████████  | 20/25 [03:13<00:47,  9.54s/it]

Correct Answer: B, Predicted Answer: C


Measuring Accuracy:  84%|████████▍ | 21/25 [03:34<00:51, 12.79s/it]

Correct Answer: C, Predicted Answer: X


Measuring Accuracy:  88%|████████▊ | 22/25 [03:44<00:36, 12.09s/it]

Correct Answer: C, Predicted Answer: C


Measuring Accuracy:  92%|█████████▏| 23/25 [03:53<00:22, 11.17s/it]

Correct Answer: C, Predicted Answer: C


Measuring Accuracy:  96%|█████████▌| 24/25 [04:02<00:10, 10.49s/it]

Correct Answer: A, Predicted Answer: X


Measuring Accuracy: 100%|██████████| 25/25 [04:09<00:00,  9.97s/it]

Correct Answer: C, Predicted Answer: C
[32mCorrect: 11/25[0m
[31mWrong: 2/25[0m
[33mUnsure: 12/25[0m





### Model with RAG

In [7]:
from langchain_community.embeddings import HuggingFaceInstructEmbeddings, HuggingFaceEmbeddings
from langchain.text_splitter import CharacterTextSplitter, RecursiveCharacterTextSplitter, TokenTextSplitter, NLTKTextSplitter, SpacyTextSplitter
from langchain_community.vectorstores import Cassandra, Chroma, FAISS # vector database
from utils import pdf_loader, docs_splitter, get_embeddings, \
                    build_database, get_retriever, get_qa_chain

In [8]:
def _separator_list(separator):
    r"""Normalise a separator config value into the list RecursiveCharacterTextSplitter expects.

    ``RecursiveCharacterTextSplitter(separators=...)`` takes a list of strings. A
    bare string such as '\n\n' gets iterated character by character (['\n', '\n']),
    so the configured separator would never be used. A string is therefore
    promoted to the highest-priority separator, followed by the library's
    default separators as fallbacks. ``None`` is passed through, and any other
    sequence is converted to a list.
    """
    if separator is None:
        return None
    if isinstance(separator, str):
        fallbacks = ["\n\n", "\n", " ", ""]
        return [separator] + [s for s in fallbacks if s != separator]
    return list(separator)


def RAG(pdf:str, params:dict, data_dir:str="../data"):
    """Build a retrieval-augmented Q&A chain over a single PDF.

    Args:
        pdf: File name of the PDF, resolved relative to ``data_dir``.
        params: Configuration dict. Reads 'chunk_size', 'chunk_overlap',
            'separator', 'embedding_model', 'embedding_device',
            'query_instruction', 'search_type', 'k' and 'chain_type'. The whole
            dict is also passed to ``get_llm``, which may update it in place.
        data_dir: Directory containing the PDF (defaults to "../data", as before).

    Returns:
        The Q&A chain produced by ``get_qa_chain``.

    Raises:
        AssertionError: if the path does not have a ``.pdf`` extension
            (raised via ``assert``).
    """
    pdf_path = os.path.join(data_dir, pdf)
    # Validate before loading the LLM so a bad file name fails fast.
    assert pdf_path.lower().endswith(".pdf"), "Please provide a PDF document"
    LLM, _ = get_llm(params)

    docs = PyPDFLoader(pdf_path).load()
    text_splitter = RecursiveCharacterTextSplitter(
        chunk_size=params['chunk_size'],
        chunk_overlap=params['chunk_overlap'],
        separators=_separator_list(params['separator']),
    )
    chunks = docs_splitter(text_splitter=text_splitter, docs=docs)

    # NOTE(review): params['embedding_model'] names a sentence-transformers model,
    # but it is loaded through HuggingFaceInstructEmbeddings (INSTRUCTOR).
    # Confirm this pairing is intended.
    embeddings = get_embeddings(
        HuggingFaceInstructEmbeddings,
        params['embedding_model'],
        device=params['embedding_device'],
        query_instruction=params['query_instruction'],
    )
    db = build_database(FAISS, chunks, embeddings)
    retriever = get_retriever(db, search_type=params['search_type'], k=params['k'])
    return get_qa_chain(LLM, retriever, chain_type=params['chain_type'])

In [9]:
# Build the retrieval-augmented Q&A chain over the Omicron reference PDF (file inside ../data).
pdf = "Omicron Variant Symptoms and Treatment.pdf"
qa_chain = RAG(pdf=pdf, params=params)

[1m[32mDocument has been split into 31 chunks[0m
load INSTRUCTOR_Transformer


  return self.fget.__get__(instance, owner)()


max_seq_length  512
[1m[32mEmbeddings have been generated using HuggingFaceInstructEmbeddings[0m
[1m[32mBuilding FAISS vector database...[0m
[1m[32mFAISS vector database has successfully been built[0m
[1m[32mVector retriever has been created for similarity search[0m
[1m[32mQ&A chain has been created[0m


In [10]:
# RAG run: the chain returns a mapping, so RAG=True makes measure_accuracy read its 'result' key.
acc = measure_accuracy(
    qa_chain, prompt,
    MCQs_omicrons_Q, MCQs_omicrons_A,
    RAG=True,
)

Measuring Accuracy:   4%|▍         | 1/25 [00:10<04:22, 10.94s/it]

Correct Answer: C, Predicted Answer: C


Measuring Accuracy:   8%|▊         | 2/25 [00:13<02:20,  6.11s/it]

Correct Answer: B, Predicted Answer: B


Measuring Accuracy:  12%|█▏        | 3/25 [00:17<01:54,  5.19s/it]

Correct Answer: B, Predicted Answer: B


Measuring Accuracy:  16%|█▌        | 4/25 [00:38<03:58, 11.37s/it]

Correct Answer: A, Predicted Answer: X


Measuring Accuracy:  20%|██        | 5/25 [00:39<02:32,  7.61s/it]

Correct Answer: B, Predicted Answer: B


Measuring Accuracy:  24%|██▍       | 6/25 [00:51<02:53,  9.14s/it]

Correct Answer: A, Predicted Answer: X


Measuring Accuracy:  28%|██▊       | 7/25 [01:03<02:59,  9.95s/it]

Correct Answer: C, Predicted Answer: C


Measuring Accuracy:  32%|███▏      | 8/25 [01:31<04:27, 15.73s/it]

Correct Answer: A, Predicted Answer: X


Measuring Accuracy:  36%|███▌      | 9/25 [01:40<03:40, 13.77s/it]

Correct Answer: A, Predicted Answer: A


Measuring Accuracy:  40%|████      | 10/25 [01:54<03:24, 13.66s/it]

Correct Answer: C, Predicted Answer: C


Measuring Accuracy:  44%|████▍     | 11/25 [02:17<03:52, 16.60s/it]

Correct Answer: B, Predicted Answer: B


Measuring Accuracy:  48%|████▊     | 12/25 [02:28<03:12, 14.80s/it]

Correct Answer: A, Predicted Answer: A


Measuring Accuracy:  52%|█████▏    | 13/25 [02:33<02:24, 12.00s/it]

Correct Answer: A, Predicted Answer: I


Measuring Accuracy:  56%|█████▌    | 14/25 [02:36<01:40,  9.15s/it]

Correct Answer: C, Predicted Answer: C


Measuring Accuracy:  60%|██████    | 15/25 [02:39<01:12,  7.26s/it]

Correct Answer: C, Predicted Answer: C


Measuring Accuracy:  64%|██████▍   | 16/25 [03:10<02:10, 14.49s/it]

Correct Answer: C, Predicted Answer: C


Measuring Accuracy:  68%|██████▊   | 17/25 [03:21<01:48, 13.52s/it]

Correct Answer: C, Predicted Answer: C


Measuring Accuracy:  72%|███████▏  | 18/25 [04:10<02:49, 24.19s/it]

Correct Answer: C, Predicted Answer: C


Measuring Accuracy:  76%|███████▌  | 19/25 [04:15<01:49, 18.20s/it]

Correct Answer: C, Predicted Answer: C


Measuring Accuracy:  80%|████████  | 20/25 [04:30<01:26, 17.28s/it]

Correct Answer: B, Predicted Answer: C


Measuring Accuracy:  84%|████████▍ | 21/25 [04:56<01:19, 19.86s/it]

Correct Answer: C, Predicted Answer: A


Measuring Accuracy:  88%|████████▊ | 22/25 [05:00<00:45, 15.20s/it]

Correct Answer: C, Predicted Answer: C


Measuring Accuracy:  92%|█████████▏| 23/25 [05:17<00:31, 15.83s/it]

Correct Answer: C, Predicted Answer: C


Measuring Accuracy:  96%|█████████▌| 24/25 [05:31<00:15, 15.31s/it]

Correct Answer: A, Predicted Answer: A


Measuring Accuracy: 100%|██████████| 25/25 [05:46<00:00, 13.86s/it]

Correct Answer: C, Predicted Answer: C
[32mCorrect: 19/25[0m
[31mWrong: 3/25[0m
[33mUnsure: 3/25[0m



