In [28]:
import json
import nltk
import string
import numpy as np
import pandas as pd
from unidecode import unidecode
import transformers
import torch
from sklearn.metrics.pairwise import cosine_similarity
from pathlib import Path
from tqdm import tqdm

from langchain_community.document_loaders import PyPDFLoader
from langchain_community.embeddings import OllamaEmbeddings
from langchain_community.llms import Ollama
from langchain.indexes import VectorstoreIndexCreator
from langchain.vectorstores import DocArrayInMemorySearch
from langchain.chains import RetrievalQA
from langchain_community.llms.huggingface_pipeline import HuggingFacePipeline

## Setup device

In [29]:
# Pick the torch backend in order of preference: CUDA, then Apple MPS, then CPU.
if torch.cuda.is_available():
  device = 'cuda'
elif torch.backends.mps.is_available():
  device = 'mps'
else:
  device = 'cpu'

## Setup metric

In [30]:
lemmatizer = nltk.stem.WordNetLemmatizer()

# Filled on first use by _get_stopset(); the stopword corpora are then read only
# once instead of on every preprocess() call.
_stopset = None

def _get_stopset():
  """Return the cached set of tokens to drop: English and Russian stopwords plus punctuation."""
  global _stopset
  if _stopset is None:
    # A frozenset gives constant-time membership tests in the token filter.
    _stopset = frozenset(
      nltk.corpus.stopwords.words('english')
      + nltk.corpus.stopwords.words('russian')
      + list(string.punctuation)
    )
  return _stopset

def preprocess(corpus):
  """Normalize a text before it is embedded.

  Steps, in order: lowercase, tokenize with nltk.word_tokenize, drop stopwords
  and punctuation tokens, lemmatize each remaining token with the WordNet
  lemmatizer, join the tokens with single spaces, and transliterate the result
  with unidecode.

  Args:
    corpus: The text to normalize (a str).

  Returns:
    The normalized text as a single space-separated string.
  """
  stopset = _get_stopset()
  tokens = nltk.word_tokenize(corpus.lower())
  tokens = [lemmatizer.lemmatize(t) for t in tokens if t not in stopset]
  return unidecode(' '.join(tokens))

In [31]:
# Shared embedding client (Ollama, model 'llama3'); used by the similarity
# metric and by the vector store index defined later in this notebook.
embeddings = OllamaEmbeddings(model='llama3')

In [32]:
def embeddings_cosine_sim_metric(expected_answers, predicted_answers):
  """Mean cosine similarity between embeddings of paired answers.

  Each expected/predicted pair is run through preprocess(), embedded with the
  module-level `embeddings` client, and compared with sklearn's
  cosine_similarity. Pairs are formed with zip, so surplus items in the longer
  sequence are ignored.

  Args:
    expected_answers: Iterable of reference answer strings.
    predicted_answers: Iterable of model answer strings, in the same order.

  Returns:
    np.mean over the per-pair similarities.
  """
  def pair_similarity(expected, predicted):
    expected_text, predicted_text = preprocess(expected), preprocess(predicted)
    # cosine_similarity works on 2-D inputs, hence the (1, dim) row shape.
    expected_row = np.array(embeddings.embed_query(expected_text)).reshape(1, -1)
    predicted_row = np.array(embeddings.embed_query(predicted_text)).reshape(1, -1)
    return cosine_similarity(expected_row, predicted_row)[0][0]

  return np.mean([
    pair_similarity(expected, predicted)
    for expected, predicted in zip(expected_answers, predicted_answers)
  ])

## Load QA dataset

In [33]:
# Question/answer pairs used for evaluation; the experiment loop reads the
# 'question' and 'answer' columns.
qa_df = pd.read_csv(filepath_or_buffer='qa.csv')
qa_df

Unnamed: 0,question,answer
0,Following endaerectomy on the right common car...,Central aery of the retina
1,Hypomimia is ?,Deficit of expression by gesture
2,With which of the following receptors theophyl...,Adenosine receptors
3,"During Sx for meningioma, the left paracentral...",Rt. Leg and perineus
4,All of the following structure lies outside th...,Maxillary nerve
...,...,...
2071,What is the purpose of gephyrin in the glycine...,Involved in anchoring the receptor to a specif...
2072,What is the glycine receptor involved in ?,Reflex response\nCauses reciprocal inhibition ...
2073,What happens in hyperperplexia ?,It’s an exaggerated reflex Often caused by a m...
2074,What is hyperperplexia treated with ?,Benzodiazepine


## Load documents

In [34]:
# One PyPDFLoader per regular file in ./docs whose suffix is exactly '.pdf'.
docs_dir = Path('./docs')
loaders = [
  PyPDFLoader(entry)
  for entry in tqdm(docs_dir.iterdir())
  if entry.is_file() and entry.suffix == '.pdf'
]

5it [00:00, 1286.99it/s]


## Setup LLMs

In [35]:
def get_llama3():
  """Return an Ollama LLM client for the 'llama3' model with temperature 0."""
  llama3 = Ollama(model='llama3', temperature=0)
  return llama3

In [36]:
def openbiollm_parser(output):
  """Strip everything up to and including the first 'Helpful Answer: ' marker.

  Args:
    output: Mapping with a 'query' entry and a 'result' string.

  Returns:
    A new dict with the same 'query' and a 'result' holding the text after the
    first marker, or the unchanged result when the marker is absent.
  """
  marker = 'Helpful Answer: '
  raw = output['result']
  # partition splits at the first occurrence; an empty separator means "not found".
  _, found, answer = raw.partition(marker)
  return {'query': output['query'], 'result': answer if found else raw}

def get_openbiollm_8b():
  """Build the OpenBioLLM-Llama3-8B text-generation LLM piped into openbiollm_parser.

  Generation stops at either the tokenizer's EOS token or the '<|eot_id|>'
  token, produces at most 256 new tokens, and samples with temperature 0.01
  and top_p 0.9.

  Returns:
    The runnable `llm | openbiollm_parser`.
  """
  model = 'aaditya/OpenBioLLM-Llama3-8B'
  model_kwargs = {'torch_dtype': torch.bfloat16}

  # Only the tokenizer is needed to resolve the stop-token ids. Loading it
  # directly avoids instantiating the full 8B model a second time just to read
  # `pipeline.tokenizer`.
  tokenizer = transformers.AutoTokenizer.from_pretrained(model)
  terminators = [
    tokenizer.eos_token_id,
    tokenizer.convert_tokens_to_ids('<|eot_id|>'),
  ]

  # NOTE(review): the notebook-level `device` is not passed here, so model
  # placement follows HuggingFacePipeline.from_model_id defaults — confirm.
  llm = HuggingFacePipeline.from_model_id(
    model_id=model,
    task='text-generation',
    model_kwargs=model_kwargs,
    pipeline_kwargs={
      'max_new_tokens': 256,
      'eos_token_id': terminators,
      'do_sample': True,
      'temperature': 0.01,
      'top_p': 0.9,
    },
  )
  # NOTE(review): openbiollm_parser indexes output['query'] / output['result'];
  # verify that the value piped from `llm` has that shape.
  return llm | openbiollm_parser

## Setup index stores

In [37]:
def get_doc_array_in_memory_search(loaders=None):
  """Index the given loaders into a DocArrayInMemorySearch vector store.

  Args:
    loaders: Document loaders to index. Defaults to no loaders; a None
      sentinel replaces the former mutable `[]` default so no list object is
      shared between calls.

  Returns:
    The vector store of the index built with the module-level `embeddings`.
  """
  if loaders is None:
    loaders = []
  index = VectorstoreIndexCreator(
    vectorstore_cls=DocArrayInMemorySearch,
    embedding=embeddings,
  ).from_loaders(loaders)
  return index.vectorstore

## Setup experiments

In [38]:
# (display name, LLM instance) pairs evaluated by the experiment loop.
# Currently disabled: ('OpenBioLLM Llama3 8B', get_openbiollm_8b())
llms = [('LLaMA 3', get_llama3())]

# (display name, factory) pairs; the factory is called with a list of loaders
# and returns a vector store.
vector_stores = [('Doc Array In Memory Search', get_doc_array_in_memory_search)]

In [39]:
# Load the on-disk answer cache. The experiment loop writes it back to this
# path with json.dump.
cache_path = Path('cache.json')
if cache_path.exists():
  with open(cache_path, 'r') as f:
    cache = json.load(f)
else:
  # Fresh checkout or deleted cache: start empty instead of failing with
  # FileNotFoundError, so Restart & Run All works without a pre-existing file.
  cache = {}

In [40]:
# Evaluate every (LLM, vector store, use_docs) combination on a sample of the
# QA dataset and collect one result row per combination in `df`.
RESULT_COLUMNS = ['llm', 'vector_store', 'use_docs', 'accuracy']
df = pd.DataFrame(columns=RESULT_COLUMNS)

# A fixed seed keeps the sampled questions identical across runs; with an
# unseeded sample the on-disk cache is rarely hit after a kernel restart.
sample_df = qa_df.sample(frac=0.1, random_state=42)
questions = sample_df['question'].tolist()
expected_answers = sample_df['answer'].tolist()

# One dict per finished configuration; the frame is built from this list
# rather than grown with pd.concat, and a dict of scalars cannot be turned
# into a DataFrame without an index anyway.
records = []

for llm_name, llm in tqdm(llms, desc='LLMs'):
  for vector_store_name, get_vector_store in tqdm(vector_stores, desc='Vector Stores', leave=False):
    for use_docs in tqdm((True, False), desc='Use Docs', leave=False):
      # With use_docs=False the store is built from an empty loader list.
      vector_store = get_vector_store(loaders if use_docs else [])
      qa_llm = RetrievalQA.from_chain_type(
        llm=llm,
        chain_type='stuff',
        retriever=vector_store.as_retriever(),
        verbose=True,
        chain_type_kwargs = {'document_separator': '<<<<>>>>>'},
      )

      # Cache layout: {configuration key: {question: answer}}. The key does
      # not depend on the question, so it is resolved once per configuration.
      key = f'{llm_name}_{vector_store_name}_{use_docs}'
      cached_answers = cache.setdefault(key, {})

      predicted_answers = []

      for question in tqdm(questions, desc='Questions', leave=False):
        if question not in cached_answers:
          cached_answers[question] = qa_llm.invoke(question)['result']
          # Persist right after each new answer so an interrupted run can
          # resume; cache hits no longer rewrite the file.
          with open(cache_path, 'w') as f:
            json.dump(cache, f)

        # Append the answer string itself, not the whole per-key dict.
        predicted_answers.append(cached_answers[question])

      accuracy = embeddings_cosine_sim_metric(expected_answers, predicted_answers)

      records.append({
        'llm': llm_name,
        'vector_store': vector_store_name,
        'use_docs': use_docs,
        'accuracy': accuracy,
      })
      # Rebuild after every configuration so partial results are visible in
      # `df` even if the run is interrupted.
      df = pd.DataFrame(records, columns=RESULT_COLUMNS)

LLMs:   0%|          | 0/1 [00:00<?, ?it/s]
[A

1


KeyboardInterrupt: 

Exception ignored in: 'zmq.backend.cython.message.Frame.__dealloc__'
Traceback (most recent call last):
  File "zmq/backend/cython/checkrc.pxd", line 13, in zmq.backend.cython.checkrc._check_rc
KeyboardInterrupt: 


In [None]:
# Display the results table: one row per (llm, vector_store, use_docs) combination.
df