In [27]:
import nltk
import string
import numpy as np
import pandas as pd
from unidecode import unidecode
import transformers
import torch
from sklearn.metrics.pairwise import cosine_similarity

from langchain_community.embeddings import OllamaEmbeddings
from langchain_community.llms import Ollama
from langchain.indexes import VectorstoreIndexCreator
from langchain.vectorstores import DocArrayInMemorySearch
from langchain.chains import RetrievalQA
from langchain_community.llms.huggingface_pipeline import HuggingFacePipeline

## Setup metric

In [28]:
lemmatizer = nltk.stem.WordNetLemmatizer()

def preprocess(corpus):
  corpus = corpus.lower()
  stopset = nltk.corpus.stopwords.words('english') + nltk.corpus.stopwords.words('russian') + list(string.punctuation)
  tokens = nltk.word_tokenize(corpus)
  tokens = [t for t in tokens if t not in stopset]
  tokens = [lemmatizer.lemmatize(t) for t in tokens]
  corpus = ' '.join(tokens)
  corpus = unidecode(corpus)
  return corpus

In [29]:
embeddings = OllamaEmbeddings(model='llama3')

In [30]:
def embeddings_cosine_sim_metric(expected_answers, predicted_answers):
  results = []

  for expected_answer, predicted_answer in zip(expected_answers, predicted_answers):
    expected_answer = preprocess(expected_answer)
    predicted_answer = preprocess(predicted_answer)

    expected_embedding = np.array(embeddings.embed_query(expected_answer))
    predicted_embedding = np.array(embeddings.embed_query(predicted_answer))

    sim = cosine_similarity(
      expected_embedding.reshape(1, -1),
      predicted_embedding.reshape(1, -1),
    )[0][0]

    results.append(sim)

  return np.mean(results)

## Load QA dataset

In [31]:
qa_df = pd.read_csv('qa.csv')
qa_df

Unnamed: 0,question,answer
0,Following endaerectomy on the right common car...,Central aery of the retina
1,Hypomimia is ?,Deficit of expression by gesture
2,With which of the following receptors theophyl...,Adenosine receptors
3,"During Sx for meningioma, the left paracentral...",Rt. Leg and perineus
4,All of the following structure lies outside th...,Maxillary nerve
...,...,...
2071,What is the purpose of gephyrin in the glycine...,Involved in anchoring the receptor to a specif...
2072,What is the glycine receptor involved in ?,Reflex response\nCauses reciprocal inhibition ...
2073,What happens in hyperperplexia ?,It’s an exaggerated reflex Often caused by a m...
2074,What is hyperperplexia treated with ?,Benzodiazepine


## Setup experiments

In [32]:
def get_llama2_index_store_rag():
  index = VectorstoreIndexCreator(
    vectorstore_cls=DocArrayInMemorySearch,
    embedding=embeddings,
  ).from_loaders([])
  llm = Ollama(temperature=0, model='llama2')
  qa_llm = RetrievalQA.from_chain_type(
    llm=llm,
    chain_type='stuff',
    retriever=index.vectorstore.as_retriever(),
    verbose=True,
    chain_type_kwargs = {
      'document_separator': '<<<<>>>>>',
    },
  )

  return qa_llm

In [33]:
def get_llama3_index_store_rag():
  index = VectorstoreIndexCreator(
    vectorstore_cls=DocArrayInMemorySearch,
    embedding=embeddings,
  ).from_loaders([])
  llm = Ollama(temperature=0, model='llama3')
  qa_llm = RetrievalQA.from_chain_type(
    llm=llm,
    chain_type='stuff',
    retriever=index.vectorstore.as_retriever(),
    verbose=True,
    chain_type_kwargs = {
      'document_separator': '<<<<>>>>>',
    },
  )

  return qa_llm

In [52]:
def openbiollm_parser(output):
  query = output['query']
  result = output['result']
  idx = result.find('Helpful Answer: ')
  if idx != -1:
    return {'query': query, 'result': result[idx + len('Helpful answer: '):]}
  else:
    return {'query': query, 'result': result}

def get_llama3_openbiollm_8b_index_store_rag():
  index = VectorstoreIndexCreator(
    vectorstore_cls=DocArrayInMemorySearch,
    embedding=embeddings,
  ).from_loaders([])
  model = 'aaditya/OpenBioLLM-Llama3-8B'
  model_kwargs = {'torch_dtype': torch.bfloat16}
  pipeline = transformers.pipeline(
      'text-generation',
      model=model,
      model_kwargs=model_kwargs,
      device='cpu',
  )
  terminators = [
    pipeline.tokenizer.eos_token_id,
    pipeline.tokenizer.convert_tokens_to_ids('<|eot_id|>')
  ]
  llm = HuggingFacePipeline.from_model_id(
    model_id=model,
    task='text-generation',
    model_kwargs=model_kwargs,
    pipeline_kwargs={
      'max_new_tokens': 256,
      'eos_token_id': terminators,
      'do_sample': True,
      'temperature': 0.01,
      'top_p': 0.9,
    },
  )
  qa_llm = RetrievalQA.from_chain_type(
    llm=llm,
    chain_type='stuff',
    retriever=index.vectorstore.as_retriever(),
    verbose=True,
    chain_type_kwargs = {
      'document_separator': '<<<<>>>>>',
    },
  )

  return qa_llm | openbiollm_parser

## Compare

In [53]:
rags = [
  { 'label': 'llama 2 index store rag', 'rag': get_llama2_index_store_rag() },
  { 'label': 'llama 3 index store rag', 'rag': get_llama3_index_store_rag() },
  { 'label': 'llama3 openbiollm 8b index store rag', 'rag': get_llama3_openbiollm_8b_index_store_rag() },
]

Loading checkpoint shards: 100%|██████████| 4/4 [00:25<00:00,  6.32s/it]
Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.
Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.
Loading checkpoint shards: 100%|██████████| 4/4 [00:22<00:00,  5.50s/it]


In [58]:
results = {}

sample_df = qa_df.sample(frac=0.1)
questions = sample_df['question'].tolist()
expected_answers = sample_df['answer'].tolist()

for rag in rags:
  predicted_answers = []

  for index, question in enumerate(questions):
    percent = (index + 1) / len(sample_df) * 100
    predicted_answers.append(rag['rag'].invoke(question)['result'])

  result = embeddings_cosine_sim_metric(expected_answers, predicted_answers)

  results[rag['label']] = result

llama 2 index store rag
0.48%


[1m> Entering new RetrievalQA chain...[0m

[1m> Finished chain.[0m
0.96%


[1m> Entering new RetrievalQA chain...[0m

[1m> Finished chain.[0m
1.44%


[1m> Entering new RetrievalQA chain...[0m

[1m> Finished chain.[0m
1.92%


[1m> Entering new RetrievalQA chain...[0m

[1m> Finished chain.[0m
2.40%


[1m> Entering new RetrievalQA chain...[0m

[1m> Finished chain.[0m
2.88%


[1m> Entering new RetrievalQA chain...[0m

[1m> Finished chain.[0m
3.37%


[1m> Entering new RetrievalQA chain...[0m

[1m> Finished chain.[0m
3.85%


[1m> Entering new RetrievalQA chain...[0m

[1m> Finished chain.[0m
4.33%


[1m> Entering new RetrievalQA chain...[0m

[1m> Finished chain.[0m
4.81%


[1m> Entering new RetrievalQA chain...[0m

[1m> Finished chain.[0m
5.29%


[1m> Entering new RetrievalQA chain...[0m

[1m> Finished chain.[0m
5.77%


[1m> Entering new RetrievalQA chain...[0m

[1m> Finished chain.[0m
6.25%


[1m> Entering new R

In [59]:
results

{'llama 2 index store rag': 0.41725124681704473,
 'llama 3 index store rag': 0.4304846762514193,
 'llama3 openbiollm 8b index store rag': 0.6262885296215033}