In [1]:
# Run this once/as needed
# !wget https://raw.githubusercontent.com/alexeygrigorev/minsearch/main/minsearch.py

--2024-09-15 00:27:44--  https://raw.githubusercontent.com/alexeygrigorev/minsearch/main/minsearch.py
Resolving raw.githubusercontent.com (raw.githubusercontent.com)... 185.199.108.133, 185.199.111.133, 185.199.109.133, ...
Connecting to raw.githubusercontent.com (raw.githubusercontent.com)|185.199.108.133|:443... connected.
HTTP request sent, awaiting response... 200 OK
Length: 3832 (3.7K) [text/plain]
Saving to: ‘minsearch.py’


2024-09-15 00:27:44 (2.05 MB/s) - ‘minsearch.py’ saved [3832/3832]



In [2]:
import pandas as pd

## Ingestion

In [39]:
# Load the medquad.csv dataset from the sibling data directory.
df = pd.read_csv('../data/medquad.csv')
# Keep at most the first 1028 rows (slice end is exclusive); `df` is rebound to the subset.
df = df[0:1028]
# Echo the frame as the cell output.
df

Unnamed: 0,id,question,answer,source,focus_area
0,0,What is (are) Glaucoma ?,Glaucoma is a group of diseases that can damag...,NIHSeniorHealth,Glaucoma
1,1,What causes Glaucoma ?,"Nearly 2.7 million people have glaucoma, a lea...",NIHSeniorHealth,Glaucoma
2,2,What are the symptoms of Glaucoma ?,Symptoms of Glaucoma Glaucoma can develop in ...,NIHSeniorHealth,Glaucoma
3,3,What are the treatments for Glaucoma ?,"Although open-angle glaucoma cannot be cured, ...",NIHSeniorHealth,Glaucoma
4,4,What is (are) Glaucoma ?,Glaucoma is a group of diseases that can damag...,NIHSeniorHealth,Glaucoma
...,...,...,...,...,...
1023,1023,What are the stages of Childhood Brain Stem Gl...,Key Points\n - The plan for...,CancerGov,Childhood Brain Stem Glioma
1024,1024,what research (or clinical trials) is being do...,New types of treatment are being tested in cli...,CancerGov,Childhood Brain Stem Glioma
1025,1025,What are the treatments for Childhood Brain St...,Key Points\n - There are di...,CancerGov,Childhood Brain Stem Glioma
1026,1026,What is (are) Colorectal Cancer ?,Key Points\n - Colorectal c...,CancerGov,Colorectal Cancer


In [40]:
# One dict per DataFrame row, keyed by column name.
documents = df.to_dict(orient='records')

In [41]:
# Echo the whole list of records as the cell output (every document is printed).
documents

[{'id': 0,
  'question': 'What is (are) Glaucoma ?',
  'answer': "Glaucoma is a group of diseases that can damage the eye's optic nerve and result in vision loss and blindness. While glaucoma can strike anyone, the risk is much greater for people over 60. How Glaucoma Develops  There are several different types of glaucoma. Most of these involve the drainage system within the eye. At the front of the eye there is a small space called the anterior chamber. A clear fluid flows through this chamber and bathes and nourishes the nearby tissues. (Watch the video to learn more about glaucoma. To enlarge the video, click the brackets in the lower right-hand corner. To reduce the video, press the Escape (Esc) button on your keyboard.) In glaucoma, for still unknown reasons, the fluid drains too slowly out of the eye. As the fluid builds up, the pressure inside the eye rises. Unless this pressure is controlled, it may cause damage to the optic nerve and other parts of the eye and result in loss 

In [42]:
import minsearch

In [43]:
# Index configuration: four free-text fields plus `id` as a keyword field.
# NOTE(review): presumably text_fields are scored against the query and keyword_fields
# are used for exact-match filtering — confirm against minsearch.py.
index = minsearch.Index(
    text_fields=['question','answer','source','focus_area'],
    keyword_fields=['id']
)

In [44]:
# Fit the index on the document records; the echoed output is the index object itself.
index.fit(documents)

<minsearch.Index at 0x133533a10>

In [45]:
# Echo the index object's repr.
index

<minsearch.Index at 0x133533a10>

## RAG flow

In [46]:
from openai import OpenAI

# Client built with no explicit arguments.
# NOTE(review): presumably credentials are read from the environment (e.g. OPENAI_API_KEY) — confirm.
client = OpenAI()

In [47]:
def search(query):
    """Query the module-level `index` for `query` and return its results.

    The call uses an empty filter, an empty boost mapping, and asks the
    index for 10 results.
    """
    no_filters = {}
    no_boosts = {}
    return index.search(
        query=query,
        filter_dict=no_filters,
        boost_dict=no_boosts,
        num_results=10,
    )

In [48]:
# Outer prompt: instructions, then the user's question and the retrieved context.
prompt_template = """
You're a medical question answering assistant. Answer the QUESTION based on the CONTEXT from our question and answer database.
Use only the facts from the CONTEXT when answering the QUESTION.

QUESTION: {question}

CONTEXT:
{context}
""".strip()

# Rendering of a single retrieved document inside the CONTEXT section.
entry_template = """
question: {question}
answer: {answer}
source: {source}
focus_area: {focus_area}

""".strip()

def build_prompt(query, search_results):
    """Render the final LLM prompt for `query`.

    Each item of `search_results` is a mapping that must provide the keys
    used by `entry_template` (extra keys are ignored by `str.format`).
    Entries are separated by a blank line; surrounding whitespace of the
    finished prompt is stripped.
    """
    rendered_entries = [entry_template.format(**doc) + "\n\n" for doc in search_results]
    context = "".join(rendered_entries)
    return prompt_template.format(question=query, context=context).strip()

In [49]:
def llm(prompt, model='gpt-4o-mini'):
    """Send `prompt` as a single user message and return the reply text.

    Uses the module-level `client`; `model` is passed through to the
    chat-completions call. Only the first choice's message content is returned.
    """
    messages = [{"role": "user", "content": prompt}]
    completion = client.chat.completions.create(model=model, messages=messages)
    first_choice = completion.choices[0]
    return first_choice.message.content

In [50]:
def rag(query, model='gpt-4o-mini'):
    """Answer `query` by chaining search -> build_prompt -> llm.

    `model` is forwarded to `llm`; the generated answer text is returned.
    """
    retrieved = search(query)
    return llm(build_prompt(query, retrieved), model=model)

In [51]:
# Smoke test of the RAG flow with one hand-written question (uses rag's default model).
question = 'Can diabetes be prevented, and if so, how?'
answer = rag(question)
print(answer)

Yes, diabetes can be prevented, particularly type 2 diabetes. Currently, there is no way to delay or prevent type 1 diabetes. To prevent or delay type 2 diabetes, research has shown that making modest lifestyle changes can be effective. 

Key strategies include:

1. **Weight Loss**: Losing 5 to 10 percent of your starting weight can significantly cut the risk of developing type 2 diabetes. For instance, if you weigh 200 pounds, losing 10 to 20 pounds can be beneficial.

2. **Physical Activity**: Engaging in moderate physical activity for about 150 minutes per week can help. This could be made up of activities such as walking or other exercises.

3. **Dietary Changes**: Following a low-calorie, low-fat diet that enhances fiber intake and promotes healthy food choices is essential. Eating more fruits, vegetables, whole grains, and reducing portion sizes can aid in weight management and overall health.

4. **Regular Monitoring**: Individuals at risk should monitor their blood glucose leve

## Retrieval evaluation

In [61]:
# Load the ground-truth questions used for retrieval evaluation.
df_question = pd.read_csv('../data/ground-truth-retrieval.csv')

In [62]:
# Preview the first rows of the ground-truth frame.
df_question.head()

Unnamed: 0,id,question
0,0,What are the main causes of glaucoma and how d...
1,0,Can you explain the differences between open-a...
2,0,What are the symptoms that might indicate some...
3,0,How does increased pressure in the eye lead to...
4,0,What treatment options are available to help m...


In [63]:
# One dict per ground-truth row, keyed by column name.
ground_truth = df_question.to_dict(orient='records')

In [64]:
# Inspect a single ground-truth record.
ground_truth[0]

{'id': 0,
 'question': 'What are the main causes of glaucoma and how does it develop over time?'}

In [65]:
def hit_rate(relevance_total):
    """Fraction of queries whose relevance list contains at least one True.

    `relevance_total` is a list with one list of booleans per query.
    Raises ZeroDivisionError when `relevance_total` is empty.
    """
    queries_with_hit = sum(1 for flags in relevance_total if True in flags)
    return queries_with_hit / len(relevance_total)

def mrr(relevance_total):
    """Mean reciprocal rank over a batch of queries.

    `relevance_total` is a list with one list of booleans per query, ordered
    by rank (position 0 is the top result). Each query contributes
    1 / rank of its FIRST relevant result, or 0 when none is relevant.

    Returns the average contribution, a float in [0, 1].
    Raises ZeroDivisionError when `relevance_total` is empty.
    """
    total_score = 0.0

    for line in relevance_total:
        for rank, is_relevant in enumerate(line, start=1):
            if is_relevant:
                total_score += 1 / rank
                # Reciprocal rank only counts the highest-ranked relevant hit;
                # without this break several hits would be summed and the
                # metric could exceed 1.0.
                break

    return total_score / len(relevance_total)

In [66]:
def minsearch_search(query):
    """Search the module-level `index` for `query` with no filter and no boosts.

    Asks the index for 10 results and returns whatever it yields.
    """
    return index.search(
        query=query,
        filter_dict={},
        boost_dict={},
        num_results=10,
    )

In [67]:
def evaluate(ground_truth, search_function):
    """Score a retrieval function against ground-truth records.

    For every record, `search_function` is called with the whole record and
    each returned document is marked relevant when its 'id' equals the
    record's 'id'. Progress is shown with tqdm.

    Returns a dict with the keys 'hit_rate' and 'mrr'.
    """
    relevance_total = []

    for record in tqdm(ground_truth):
        expected_id = record['id']
        retrieved = search_function(record)
        relevance_total.append([hit['id'] == expected_id for hit in retrieved])

    metrics = {
        'hit_rate': hit_rate(relevance_total),
        'mrr': mrr(relevance_total),
    }
    return metrics

In [68]:
from tqdm.auto import tqdm

In [69]:
# Baseline retrieval metrics. evaluate() passes the whole record to the search
# function, so the lambda extracts the question text.
evaluate(ground_truth, lambda q: minsearch_search(q['question']))

  0%|          | 0/5140 [00:00<?, ?it/s]

{'hit_rate': 0.7964980544747081, 'mrr': 0.4015955623494519}

## Finding the best parameters

In [70]:
# First 100 ground-truth rows are used for tuning; the remainder is held out.
# NOTE(review): df_test is not referenced by any later cell visible in this notebook.
df_validation = df_question[:100]
df_test = df_question[100:]

In [71]:
import random

def simple_optimize(param_ranges, objective_function, n_iterations=10, seed=None):
    """Random search that MAXIMIZES `objective_function`.

    Parameters
    ----------
    param_ranges : dict
        Maps a parameter name to a `(min_val, max_val)` pair. When both bounds
        are ints the value is drawn with `randint` (both ends inclusive);
        otherwise it is drawn with `uniform`.
    objective_function : callable
        Called with a dict of sampled parameters; must return a comparable
        score where larger is better.
    n_iterations : int
        Number of random parameter sets to try.
    seed : optional
        When given, sampling uses a private `random.Random(seed)` so the
        search is reproducible. When None (default) the module-level `random`
        functions are used, exactly as before.

    Returns
    -------
    (best_params, best_score)
        The best parameter dict and its score. If no iteration runs, this is
        `(None, float('-inf'))`.
    """
    # A private generator makes seeded runs repeatable without touching the
    # global random state; unseeded calls keep the previous behavior.
    rng = random.Random(seed) if seed is not None else random

    best_params = None
    best_score = float('-inf')  # maximizing: any real score beats this

    for _ in range(n_iterations):
        # Draw one random value per parameter.
        current_params = {}
        for param, (min_val, max_val) in param_ranges.items():
            if isinstance(min_val, int) and isinstance(max_val, int):
                current_params[param] = rng.randint(min_val, max_val)
            else:
                current_params[param] = rng.uniform(min_val, max_val)

        current_score = objective_function(current_params)

        # Keep the highest-scoring parameter set seen so far.
        if current_score > best_score:
            best_score = current_score
            best_params = current_params

    return best_params, best_score

In [72]:
# Validation records (one dict per row) used by the tuning objective.
gt_val = df_validation.to_dict(orient='records')

In [73]:
def minsearch_search(query, boost=None):
    """Search the module-level `index` for `query`, optionally with boosts.

    `boost` is passed to the index as `boost_dict`; None means an empty
    mapping. No filter is applied and 10 results are requested.
    """
    boost_dict = {} if boost is None else boost
    return index.search(
        query=query,
        filter_dict={},
        boost_dict=boost_dict,
        num_results=10,
    )

In [76]:
# Search space for the boost weights: one (min, max) pair per indexed text field.
# Float bounds make simple_optimize draw continuous values.
param_ranges = {
    'question': (0.0, 3.0),
    'answer': (0.0, 3.0),
    'source': (0.0, 3.0),
    'focus_area': (0.0, 3.0),
}

def objective(boost_params):
    """Return the validation-set MRR obtained with the given boost weights.

    Runs `evaluate` over `gt_val`, searching each record's question with
    `boost_params` applied.
    """
    metrics = evaluate(
        gt_val,
        lambda record: minsearch_search(record['question'], boost_params),
    )
    return metrics['mrr']

In [77]:
# Try 20 random boost combinations; each one triggers a full pass over the validation set.
simple_optimize(param_ranges, objective, n_iterations=20)

  0%|          | 0/100 [00:00<?, ?it/s]

  0%|          | 0/100 [00:00<?, ?it/s]

  0%|          | 0/100 [00:00<?, ?it/s]

  0%|          | 0/100 [00:00<?, ?it/s]

  0%|          | 0/100 [00:00<?, ?it/s]

  0%|          | 0/100 [00:00<?, ?it/s]

  0%|          | 0/100 [00:00<?, ?it/s]

  0%|          | 0/100 [00:00<?, ?it/s]

  0%|          | 0/100 [00:00<?, ?it/s]

  0%|          | 0/100 [00:00<?, ?it/s]

  0%|          | 0/100 [00:00<?, ?it/s]

  0%|          | 0/100 [00:00<?, ?it/s]

  0%|          | 0/100 [00:00<?, ?it/s]

  0%|          | 0/100 [00:00<?, ?it/s]

  0%|          | 0/100 [00:00<?, ?it/s]

  0%|          | 0/100 [00:00<?, ?it/s]

  0%|          | 0/100 [00:00<?, ?it/s]

  0%|          | 0/100 [00:00<?, ?it/s]

  0%|          | 0/100 [00:00<?, ?it/s]

  0%|          | 0/100 [00:00<?, ?it/s]

({'question': 0.3417604466420481,
  'answer': 1.8419912735713542,
  'source': 0.25631365201530965,
  'focus_area': 0.22228140744464786},
 0.48292063492063514)

In [79]:
def minsearch_improved(query):
    """Search the module-level `index` for `query` using fixed per-field boosts.

    The boost weights are hard-coded below. No filter is applied and
    10 results are requested.
    """
    tuned_boost = {
        'question': 0.3417604466420481,
        'answer': 1.8419912735713542,
        'source': 0.25631365201530965,
        'focus_area': 0.22228140744464786,
    }
    return index.search(
        query=query,
        filter_dict={},
        boost_dict=tuned_boost,
        num_results=10,
    )

# Retrieval metrics with the tuned boosts.
# NOTE(review): ground_truth is the full question set, so it includes the 100 validation
# rows that the boosts were tuned on.
evaluate(ground_truth, lambda q: minsearch_improved(q['question']))

  0%|          | 0/5140 [00:00<?, ?it/s]

{'hit_rate': 0.8587548638132295, 'mrr': 0.47430895250447663}

## RAG evaluation

In [80]:
# LLM-as-judge prompt. {question} and {answer_llm} are filled via str.format;
# the doubled braces render as literal { } around the JSON example.
prompt2_template = """
You are an expert evaluator for a RAG system.
Your task is to analyze the relevance of the generated answer to the given question.
Based on the relevance of the generated answer, you will classify it
as "NON_RELEVANT", "PARTLY_RELEVANT", or "RELEVANT".

Here is the data for evaluation:

Question: {question}
Generated Answer: {answer_llm}

Please analyze the content and context of the generated answer in relation to the question
and provide your evaluation in parsable JSON without using code blocks:

{{
  "Relevance": "NON_RELEVANT" | "PARTLY_RELEVANT" | "RELEVANT",
  "Explanation": "[Provide a brief explanation for your evaluation]"
}}
""".strip()

In [81]:
# Number of ground-truth questions available.
len(ground_truth)

5140

In [85]:
# Dry run of the judge flow on the first ground-truth question.
# Note: this rebinds the `question` name used in an earlier cell.
record = ground_truth[0]
question = record['question']
answer_llm = rag(question)

In [86]:
# Show the generated answer for the dry-run question.
print(answer_llm)

The main causes of glaucoma include high eye pressure and damage to the optic nerve, which can occur due to a buildup of fluid in the eye that does not drain properly. Key risk factors include age—particularly individuals over 60 and African-Americans over 40—as well as a family history of glaucoma and elevated blood pressure. 

Glaucoma develops over time as the pressure inside the eye increases, often leading to damage of the optic nerve. The most common type, open-angle glaucoma, occurs when the fluid drains too slowly through the eye's drainage system, leading to increased intraocular pressure. If untreated, this can result in gradual vision loss, starting with peripheral vision and potentially leading to complete vision loss. Early diagnosis and treatment are important to protect remaining vision.


In [87]:
# Render the judge prompt for the dry-run question/answer pair and inspect it.
prompt = prompt2_template.format(question=question, answer_llm=answer_llm)
print(prompt)

You are an expert evaluator for a RAG system.
Your task is to analyze the relevance of the generated answer to the given question.
Based on the relevance of the generated answer, you will classify it
as "NON_RELEVANT", "PARTLY_RELEVANT", or "RELEVANT".

Here is the data for evaluation:

Question: What are the main causes of glaucoma and how does it develop over time?
Generated Answer: The main causes of glaucoma include high eye pressure and damage to the optic nerve, which can occur due to a buildup of fluid in the eye that does not drain properly. Key risk factors include age—particularly individuals over 60 and African-Americans over 40—as well as a family history of glaucoma and elevated blood pressure. 

Glaucoma develops over time as the pressure inside the eye increases, often leading to damage of the optic nerve. The most common type, open-angle glaucoma, occurs when the fluid drains too slowly through the eye's drainage system, leading to increased intraocular pressure. If unt

In [88]:
import json

In [89]:
# Draw 200 ground-truth questions; the fixed random_state makes the sample repeatable.
df_sample = df_question.sample(n=200, random_state=1)

In [90]:
# One dict per sampled row, keyed by column name.
sample = df_sample.to_dict(orient='records')

In [91]:
# Collects (record, generated answer, parsed judge verdict) for every sampled question.
evaluations = []

for record in tqdm(sample):
    question = record['question']
    # Answer generated with rag's default model.
    answer_llm = rag(question) 

    prompt = prompt2_template.format(
        question=question,
        answer_llm=answer_llm
    )

    # The judge call also uses llm's default model.
    evaluation = llm(prompt)
    # NOTE(review): json.loads raises json.JSONDecodeError if the reply is not valid JSON,
    # which would abort the loop part-way through.
    evaluation = json.loads(evaluation)

    evaluations.append((record, answer_llm, evaluation))

  0%|          | 0/200 [00:00<?, ?it/s]

In [92]:
# Flatten the (record, answer, evaluation) tuples into a tabular frame.
df_eval = pd.DataFrame(evaluations, columns=['record', 'answer', 'evaluation'])

# Pull the fields of interest out of the nested record dict.
df_eval['id'] = df_eval.record.apply(lambda d: d['id'])
df_eval['question'] = df_eval.record.apply(lambda d: d['question'])

# Pull the judge's verdict and rationale out of the parsed JSON
# (raises KeyError if a reply lacks either key).
df_eval['relevance'] = df_eval.evaluation.apply(lambda d: d['Relevance'])
df_eval['explanation'] = df_eval.evaluation.apply(lambda d: d['Explanation'])

# Drop the nested dict columns now that their contents are flattened.
del df_eval['record']
del df_eval['evaluation']

In [93]:
# Share of each relevance label among the judged answers.
df_eval.relevance.value_counts(normalize=True)

relevance
RELEVANT           0.925
NON_RELEVANT       0.045
PARTLY_RELEVANT    0.030
Name: proportion, dtype: float64

In [94]:
# Persist the evaluation table (without the DataFrame index) for the default-model run.
df_eval.to_csv('../data/rag-eval-gpt-4o-mini.csv', index=False)

In [95]:
# Inspect the answers the judge labelled NON_RELEVANT.
df_eval[df_eval.relevance == 'NON_RELEVANT']

Unnamed: 0,answer,id,question,relevance,explanation
6,The provided context does not contain specific...,633,How does the reactivated virus affect the skin?,NON_RELEVANT,The generated answer clearly states that it la...
10,The provided context does not contain any info...,128,Are there any specific eligibility requirement...,NON_RELEVANT,The generated answer indicates that there is n...
27,The key symptoms associated with Meniere's dis...,155,What are the key symptoms associated with Meni...,NON_RELEVANT,The generated answer explicitly states that ke...
77,The provided context does not contain specific...,214,How does parathyroid hormone therapy work for ...,NON_RELEVANT,The generated answer acknowledges the lack of ...
109,The provided context does not contain specific...,322,What is the progression of vision loss in Wet ...,NON_RELEVANT,The generated answer states that it cannot pro...
119,The provided context does not contain informat...,210,At what age does bone mass typically stop incr...,NON_RELEVANT,The generated answer does not provide any info...
123,The provided context does not specifically men...,488,What type of anesthesia is used during a media...,NON_RELEVANT,The generated answer states that it cannot pro...
149,The context provided does not specifically add...,716,What specific tasks can I request help with wh...,NON_RELEVANT,The generated answer explicitly states that it...
150,The context provided does not contain informat...,111,What are the potential outcomes of living with...,NON_RELEVANT,The generated answer states that it cannot pro...


In [107]:
# Same evaluation loop over the same sample, but answers are generated with 'gpt-4o'.
evaluations_gpt4o = []

for record in tqdm(sample):
    question = record['question']
    answer_llm = rag(question, model='gpt-4o') 

    prompt = prompt2_template.format(
        question=question,
        answer_llm=answer_llm
    )

    # The judge call still uses llm's default model, not 'gpt-4o'.
    evaluation = llm(prompt)
    # NOTE(review): json.loads raises json.JSONDecodeError if the reply is not valid JSON,
    # which would abort the loop part-way through.
    evaluation = json.loads(evaluation)
    
    evaluations_gpt4o.append((record, answer_llm, evaluation))

  0%|          | 0/200 [00:00<?, ?it/s]

In [108]:
# Flatten the gpt-4o results into a tabular frame.
# Note: this rebinds `df_eval`, replacing the frame built from `evaluations`.
df_eval = pd.DataFrame(evaluations_gpt4o, columns=['record', 'answer', 'evaluation'])

# Pull the fields of interest out of the nested record dict.
df_eval['id'] = df_eval.record.apply(lambda d: d['id'])
df_eval['question'] = df_eval.record.apply(lambda d: d['question'])

# Pull the judge's verdict and rationale out of the parsed JSON
# (raises KeyError if a reply lacks either key).
df_eval['relevance'] = df_eval.evaluation.apply(lambda d: d['Relevance'])
df_eval['explanation'] = df_eval.evaluation.apply(lambda d: d['Explanation'])

# Drop the nested dict columns now that their contents are flattened.
del df_eval['record']
del df_eval['evaluation']

In [109]:
# Absolute count of each relevance label for the gpt-4o answers.
df_eval.relevance.value_counts()

relevance
RELEVANT           188
NON_RELEVANT         6
PARTLY_RELEVANT      6
Name: count, dtype: int64

In [110]:
# Share of each relevance label for the gpt-4o answers.
df_eval.relevance.value_counts(normalize=True)

relevance
RELEVANT           0.94
NON_RELEVANT       0.03
PARTLY_RELEVANT    0.03
Name: proportion, dtype: float64

In [111]:
# Persist the gpt-4o evaluation table (without the DataFrame index).
df_eval.to_csv('../data/rag-eval-gpt-4o.csv', index=False)