### Create vector embeddings and a FAISS index from the knowledge base

In [3]:
import json
import faiss
import numpy as np
from pathlib import Path
from sentence_transformers import SentenceTransformer
import pickle

  from .autonotebook import tqdm as notebook_tqdm


In [4]:
# Input directory (the knowledge base JSON is read from here) and output directory for
# the FAISS artifacts. Both are relative, so they resolve against the kernel's working directory.
RAG_DIR = Path("../data/rag/")
FAISS_DIR = Path("../data/faiss/")
# Create the output directory up front; this is a no-op when it already exists.
FAISS_DIR.mkdir(parents=True, exist_ok=True)

# Sentence-transformers model name; it is also written to config.json next to the index.
# EMBEDDING_MODEL = "sentence-transformers/all-MiniLM-L6-v2"  # 384 dimensions, fast
EMBEDDING_MODEL = "sentence-transformers/all-mpnet-base-v2"  # 768 dimensions, better quality

In [5]:
def load_knowledge_base(kb_path=None):
	"""Load the knowledge base from a JSON file.

	Args:
		kb_path: Optional path (str or Path) to the knowledge base JSON file.
			When omitted, ``RAG_DIR / "knowledge_base.json"`` is used, which is
			the location this notebook has always read from.

	Returns:
		The parsed JSON document. The rest of the notebook iterates over it
		and treats every entry as a dict describing one chunk.

	Raises:
		FileNotFoundError: If the file does not exist.
	"""
	print("Loading knowledge base...")
	# Only fall back to the notebook-level default when no explicit path is given.
	kb_path = RAG_DIR / "knowledge_base.json" if kb_path is None else Path(kb_path)

	if not kb_path.exists():
		raise FileNotFoundError(
			f"Knowledge base not found at {kb_path}. "
			"Please run data preprocessing first."
		)

	with open(kb_path, 'r', encoding='utf-8') as f:
		knowledge_base = json.load(f)

	print(f"Loaded {len(knowledge_base)} knowledge base entries.")
	return knowledge_base

In [6]:
def create_embeddings(knowledge_base, model=None):
	"""Embed every knowledge base chunk and collect its retrieval metadata.

	Args:
		knowledge_base: Iterable of chunk dicts. ``drug_name``, ``category`` and
			``text`` are required; ``section_title`` and ``source`` are optional
			and may be missing or null.
		model: Optional, already loaded embedding model exposing
			``get_sentence_embedding_dimension()`` and ``encode(...)``. When
			omitted, ``SentenceTransformer(EMBEDDING_MODEL)`` is loaded.

	Returns:
		Tuple ``(embeddings, metadata, model)``: whatever ``model.encode``
		returned for the chunk texts, one metadata dict per chunk in the same
		order, and the model that produced the embeddings.

	Raises:
		KeyError: If a chunk lacks ``drug_name``, ``category`` or ``text``.
	"""
	print("Creating embeddings...")
	if model is None:
		model = SentenceTransformer(EMBEDDING_MODEL)
	print(f"Model loaded dimension: {model.get_sentence_embedding_dimension()}")

	texts = []
	metadata = []

	for chunk in knowledge_base:
		# `.get(key, '')` only covers a missing key; `or ''` also maps a JSON null
		# to an empty string, and the same value feeds both the text and the metadata.
		section_title = chunk.get('section_title') or ''
		source = chunk.get('source') or ''

		# Prefix the content with its drug/category/section so the embedding carries that context.
		parts = [
			f"Drug: {chunk['drug_name']}. ",
			f"Category: {chunk['category']}. ",
		]
		if section_title:
			# Skip the segment entirely rather than embedding "Section: ." or "Section: None."
			parts.append(f"Section: {section_title}. ")
		parts.append(f"Content: {chunk['text']}")
		texts.append("".join(parts))

		# Stored alongside the index so a search hit can be mapped back to its chunk.
		metadata.append({
			'drug_name': chunk['drug_name'],
			'category': chunk['category'],
			'section_title': section_title,
			'text': chunk['text'],
			'source': source
		})

	print(f"Encoding {len(texts)} texts...")
	embeddings = model.encode(
		texts,
		show_progress_bar=True,
		batch_size=32,
		convert_to_numpy=True
	)

	print(f"Created embeddings with shape: {embeddings.shape}")
	return embeddings, metadata, model

In [7]:
def build_faiss_index(embeddings):
	"""Build an L2 FAISS index over the given embedding matrix.

	Fewer than 10000 vectors get an exact ``IndexFlatL2``. Larger collections
	get an ``IndexIVFFlat`` trained on the same vectors, with
	``min(100, n // 10)`` clusters.

	Args:
		embeddings: Array-like of shape ``(n_vectors, dimension)``. It is
			converted to a C-contiguous float32 array before being handed to
			FAISS, so float64 input or nested lists are accepted.

	Returns:
		The populated FAISS index.

	Raises:
		ValueError: If ``embeddings`` is not a non-empty 2-D matrix.
	"""
	print("Building FAISS index...")
	# FAISS works on C-contiguous float32 matrices; normalise once so that
	# train() and add() both receive the same, valid buffer.
	embeddings = np.ascontiguousarray(embeddings, dtype=np.float32)
	if embeddings.ndim != 2 or embeddings.shape[0] == 0:
		raise ValueError(
			"embeddings must be a non-empty 2-D array of shape (n_vectors, dimension); "
			f"got shape {embeddings.shape}."
		)

	n_embeddings, dimension = embeddings.shape
	print(f"Dimension: {dimension}")
	print(f"Number of embeddings: {n_embeddings}")

	if n_embeddings < 10000:
		index = faiss.IndexFlatL2(dimension)  # exact search
		print("Using IndexFlatL2 for exact search.")
	else:
		# for larger datasets, use an approximate index like IndexIVFFlat
		nlist = min(100, n_embeddings // 10)  # number of clusters
		quantizer = faiss.IndexFlatL2(dimension)
		index = faiss.IndexIVFFlat(quantizer, dimension, nlist)
		print(f"Using IndexIVFFlat for approximate search with {nlist} clusters.")
		# an IVF index has to learn its clusters before vectors can be added
		print("Training the index...")
		index.train(embeddings)

	# add vectors to the index
	print("Adding embeddings to the index...")
	index.add(embeddings)
	print(f"FAISS index build with {index.ntotal} vectors.")

	return index

In [8]:
def save_index_and_metadata(index, metadata, model):
	"""Write the FAISS index, the chunk metadata and a run config into FAISS_DIR.

	Three files are produced: ``drug_knowledge.index`` (via ``faiss.write_index``),
	``metadata.pkl`` (pickled ``metadata``) and ``config.json`` (model name,
	embedding dimension, number of chunks and the index class name).

	Args:
		index: FAISS index to serialise.
		metadata: Per-chunk metadata; its length is recorded as ``num_chunks``.
		model: Embedding model; only ``get_sentence_embedding_dimension()`` is used.
	"""
	print("Saving FAISS index and metadata...")

	# 1) the index itself; faiss is handed the path as a plain string
	index_file = FAISS_DIR / "drug_knowledge.index"
	faiss.write_index(index, str(index_file))
	print(f"FAISS index saved to {index_file} ")

	# 2) chunk metadata (pickle: only load this file back from a trusted location)
	metadata_file = FAISS_DIR / "metadata.pkl"
	with metadata_file.open('wb') as handle:
		pickle.dump(metadata, handle)
	print(f"metadata saved to {metadata_file} ")

	# 3) settings a consumer needs in order to query the index consistently
	run_config = dict(
		embedding_model=EMBEDDING_MODEL,
		embedding_dimension=model.get_sentence_embedding_dimension(),
		num_chunks=len(metadata),
		index_type=type(index).__name__,
	)
	config_file = FAISS_DIR / "config.json"
	with config_file.open('w') as handle:
		json.dump(run_config, handle, indent=2)
	print(f"Configuration saved to {config_file}")

### Run all functions above

In [9]:
# End-to-end build: load chunks -> embed -> index -> persist.
# `index`, `metadata` and `model` stay in the kernel and are reused by the
# retrieval test and evaluation cells further down.
# load knowledge base
knowledge_base = load_knowledge_base()
# create embeddings (also returns the model so queries can be encoded with it later)
embeddings, metadata, model = create_embeddings(knowledge_base)
# build faiss index
index = build_faiss_index(embeddings)
# save index, metadata and config.json under FAISS_DIR
save_index_and_metadata(index, metadata, model)

Loading knowledge base...
Loaded 504 knowledge base entries.
Creating embeddings...
Model loaded dimension: 768
Encoding 504 texts...


Batches: 100%|██████████| 16/16 [00:04<00:00,  3.99it/s]

Created embeddings with shape: (504, 768)
Building FAISS index...
Dimension: 768
Number of embeddings: 504
Using IndexFlatL2 for exact search.
Adding embeddings to the index...
FAISS index build with 504 vectors.
Saving FAISS index and metadata...
FAISS index saved to ..\data\faiss\drug_knowledge.index 
metadata saved to ..\data\faiss\metadata.pkl 
Configuration saved to ..\data\faiss\config.json





### Retrieval Test

In [10]:
def test_retrieval(index, metadata, model):
	test_queries = [
		"what is the dosage for ibuprofen",
		"side effects of acetaminophen",
		"how to take amoxicillin",
		"Contraindications of aspirin"
	]

	for query in test_queries:
		print(f"\nQuery: {query}")

		# encode query 
		query_embedding = model.encode([query])

		# search top K
		k = 3
		distance, indices = index.search(query_embedding, k)

		# display results
		for i, (dist, idx) in enumerate(zip(distance[0], indices[0])):
			if idx < len(metadata):
				result = metadata[idx]
				print(f"[Result {i+1}] Distance: {dist:.4f}")
				print(f"Drug: {result['drug_name']}")
				print(f"Category: {result['category']}")
				print(f"Text: {result['text'][:200]}...")

In [11]:
# Smoke-test the index built above: prints the top matches for the sample queries
test_retrieval(index, metadata, model)


Query: what is the dosage for ibuprofen
[Result 1] Distance: 0.6634
Drug: Ibuprofen
Category: overdosage
Text: Ingestion of less than 100 mg/kg is unlikely to produce toxicity. Children ingesting 100 to 200 mg/kg may be managed with induced emesis and a minimal observation time of four hours. Children ingestin...
[Result 2] Distance: 0.6994
Drug: Ibuprofen
Category: overdosage
Text: Children ingesting 200 to 400 mg/kg of ibuprofen should have immediate gastric emptying and at least four hours observation in a health care facility. Children ingesting greater than 400 mg/kg require...
[Result 3] Distance: 0.7608
Drug: Ibuprofen
Category: overdosage
Text: In children, the estimated amount of ibuprofen ingested per body weight may be helpful to predict the potential for development of toxicity although each case must be evaluated. Ingestion of less than...

Query: side effects of acetaminophen
[Result 1] Distance: 0.7299
Drug: Ibuprofen
Category: side_effects
Text: In patients taking ibup

### Evaluation

In [None]:

def recall_at_k(gold, index, metadata, model, k=3):
  """Fraction of gold queries with at least one relevant chunk in the top k.

  A retrieved chunk is relevant when its ``drug_name`` and ``category`` both
  match one of the query's ``relevant`` entries, compared case-insensitively.

  Args:
    gold: Sized collection of dicts with a ``query`` string and a ``relevant``
      list of ``{"drug_name", "category"}`` dicts.
    index: Search index exposing ``search(query_embedding, k)``.
    metadata: Sequence of chunk dicts aligned with the index positions.
    model: Embedding model exposing ``encode(list_of_str)``.
    k: Number of neighbours inspected per query.

  Returns:
    Recall@k as a float in [0, 1]; 0.0 when ``gold`` is empty.
  """
  if not gold:
    # Nothing to evaluate: report 0.0 rather than dividing by zero.
    return 0.0

  hits = 0
  for item in gold:
    wanted = {
      (gt["drug_name"].lower(), gt["category"].lower())
      for gt in item['relevant']
    }

    q_emb = model.encode([item['query']])
    _, indices = index.search(q_emb, k)
    # FAISS pads with -1 when fewer than k vectors are found; without the
    # lower bound, metadata[-1] (the last chunk) could be counted as a hit.
    retrieved = [metadata[i] for i in indices[0] if 0 <= i < len(metadata)]

    if any((r["drug_name"].lower(), r["category"].lower()) in wanted for r in retrieved):
      hits += 1

  return hits / len(gold)

### Evaluation

In [None]:
import json
import numpy as np
from collections import defaultdict

In [None]:
# Recall@K evaluation: is at least one relevant chunk retrieved in the top K?
# Example: Recall@3 = number of queries with a relevant chunk in the top 3 / total number of queries
def evaluate_recall_at_k(
    index,
    metadata,
    model,
    gold_queries_path,
    k_values=(1, 3, 5),
    verbose=False
):
    """Compute Recall@K for several K values from a gold query JSON file.

    A retrieved chunk counts as relevant when its ``drug_name`` equals the
    query's ``expected_drug`` and its ``category`` is one of the query's
    ``expected_category`` values (exact, case-sensitive comparison).

    Args:
        index: Search index exposing ``search(query_embedding, k)``.
        metadata: Sequence of chunk dicts aligned with the index positions.
        model: Embedding model exposing ``encode(list_of_str, convert_to_numpy=True)``.
        gold_queries_path: Path to a JSON list of objects with ``query``,
            ``expected_drug`` and ``expected_category`` (a list of category
            names, or a single name as a string).
        k_values: Cut-offs to evaluate; the index is searched once per query
            with the largest of them.
        verbose: When true, print every query that misses at all K, followed
            by up to five of its retrieved chunks.

    Returns:
        Tuple ``(recall_scores, detailed_results)``: a dict mapping
        ``"Recall@K"`` to the score rounded to four decimals (0.0 for every K
        when the file holds no queries), and one record per query with its
        per-K hit flags.

    Raises:
        ValueError: If ``k_values`` is empty.
    """
    k_values = tuple(k_values)
    if not k_values:
        raise ValueError("k_values must contain at least one K.")

    with open(gold_queries_path, "r", encoding="utf-8") as f:
        gold_queries = json.load(f)

    total_queries = len(gold_queries)
    max_k = max(k_values)  # loop-invariant: one search per query covers every K
    recall_hits = {k: 0 for k in k_values}

    detailed_results = []

    for item in gold_queries:
        query = item["query"]
        expected_drug = item["expected_drug"]
        raw_categories = item["expected_category"]
        if isinstance(raw_categories, str):
            # set("dosage") would split the name into characters; wrap it instead.
            raw_categories = [raw_categories]
        expected_categories = set(raw_categories)

        # Encode query
        query_embedding = model.encode([query], convert_to_numpy=True)

        # Search with max K
        _, indices = index.search(query_embedding, max_k)

        # FAISS pads with -1 when fewer than max_k vectors are found; without the
        # lower bound, metadata[-1] (the last chunk) would be treated as retrieved.
        retrieved = [
            metadata[idx] for idx in indices[0] if 0 <= idx < len(metadata)
        ]

        # Check hits per K
        hit_at_k = {}

        for k in k_values:
            hit = any(
                (doc["drug_name"] == expected_drug) and
                (doc["category"] in expected_categories)
                for doc in retrieved[:k]
            )

            hit_at_k[k] = hit
            if hit:
                recall_hits[k] += 1

        detailed_results.append({
            "query": query,
            "expected_drug": expected_drug,
            # sorted() keeps the report deterministic; set order is not.
            "expected_category": sorted(expected_categories),
            "hit_at_k": hit_at_k
        })

        if verbose and not any(hit_at_k.values()):
            print(f"MISS: {query}")
            for i, doc in enumerate(retrieved[:5]):
                print(
                    f"   {i+1}. {doc['drug_name']} | {doc['category']}"
                )

    # Compute recall; an empty gold set yields 0.0 instead of a ZeroDivisionError
    recall_scores = {
        f"Recall@{k}": round(recall_hits[k] / total_queries, 4) if total_queries else 0.0
        for k in k_values
    }

    return recall_scores, detailed_results


In [20]:
# Score retrieval against the gold query file at K = 1, 3 and 5.
# verbose=True prints each query that misses at every K, with its retrieved chunks.
# NOTE(review): the file name is spelled "queris" - confirm it matches the file on disk.
recall_scores, details = evaluate_recall_at_k(
    index=index,
    metadata=metadata,
    model=model,
    gold_queries_path="../data/eval_retrieval/new_gold_queris.json",
    k_values=(1, 3, 5),
    verbose=True
)

print(recall_scores)


MISS: What precautions are listed for Omeprazole use?
   1. Omeprazole | side_effects
   2. Omeprazole | contraindications
   3. Omeprazole | contraindications
   4. Omeprazole | overdosage
   5. Omeprazole | overdosage
MISS: What precautions should patients know before taking Omeprazole?
   1. Omeprazole | side_effects
   2. Omeprazole | contraindications
   3. Omeprazole | overdosage
   4. Omeprazole | overdosage
   5. Omeprazole | overdosage
MISS: What drug interactions are associated with Albuterol?
MISS: Which medications should not be taken together with Ibuprofen?
   1. Ibuprofen | indications
   3. Ibuprofen | indications
   4. Ibuprofen | indications
   5. Ibuprofen | indications
MISS: Are there drug interaction concerns for Albuterol users?
MISS: What substances may interact with Ibuprofen?
   1. Ibuprofen | indications
   3. Ibuprofen | contraindications
   4. Ibuprofen | indications
   5. Ibuprofen | contraindications
MISS: Are there overdose risks associated with Atorvasta

In [21]:
# Dump the per-query records returned by evaluate_recall_at_k, one dict per line.
print("Detail result: ")
for detail in details:
    print(detail)

Detail result: 
{'query': 'Does Atorvastatin interact with other medications?', 'expected_drug': 'Atorvastatin', 'expected_category': ['interactions'], 'hit_at_k': {1: True, 3: True, 5: True}}
{'query': 'What drug interactions are associated with Albuterol?', 'expected_drug': 'Albuterol', 'expected_category': ['interactions'], 'hit_at_k': {1: False, 3: False, 5: False}}
{'query': 'Are there known drug interactions with Omeprazole?', 'expected_drug': 'Omeprazole', 'expected_category': ['interactions'], 'hit_at_k': {1: True, 3: True, 5: True}}
{'query': 'Which medications should not be taken together with Ibuprofen?', 'expected_drug': 'Ibuprofen', 'expected_category': ['interactions'], 'hit_at_k': {1: False, 3: False, 5: False}}
{'query': 'What medications interact with Atorvastatin therapy?', 'expected_drug': 'Atorvastatin', 'expected_category': ['interactions'], 'hit_at_k': {1: True, 3: True, 5: True}}
{'query': 'Are there drug interaction concerns for Albuterol users?', 'expected_drug