# 1. Data Loading and Filtering Records with Focus (Primary or Secondary)

In [10]:
import json
with open("Data/meta_test.json", "r", encoding="utf-8") as f:
    records = json.load(f)

# Focus labels that qualify a record for indexing.
KEEP_FOCUS = {"primary", "secondary"}


def _has_kept_focus(entry):
    """Return True when the entry's metadata lists a "primary" or "secondary" focus.

    `focus` is normally a list of labels; a bare string is treated as a single
    label (not iterated character by character). Entries with no metadata or
    no focus are treated as not matching.
    """
    focus = (entry.get("metadata") or {}).get("focus") or []
    if isinstance(focus, str):
        focus = [focus]
    return any(label in KEEP_FOCUS for label in focus)


# Filter JSON entries where focus is primary or secondary
filtered_records = [entry for entry in records if _has_kept_focus(entry)]

# Share of records kept, in percent; an empty input file yields 0.0 instead of ZeroDivisionError.
filtered_records_percent = (
    round(len(filtered_records) / len(records) * 100, 2) if records else 0.0
)

print(f"Only {filtered_records_percent}% of entire records are Primary or Secondary ")



Only 85.42% of entire records are Primary or Secondary 


# 2. Data Restructuring

In [None]:
def list_all_roots(records):
    """Return the distinct `root_name` values (whitespace-trimmed), sorted.

    Records whose metadata has no truthy `root_name` are ignored. Case is
    preserved here, unlike `build_alias_map`, which lowercases roots.
    """
    names = {
        str(meta["root_name"]).strip()
        for meta in (rec.get("metadata", {}) for rec in records)
        if meta.get("root_name")
    }
    return sorted(names)

def build_alias_map(records):
    """Build a lookup from lowercase alias (root name or synonym) to its root name.

    Args:
        records: iterable of dicts whose "metadata" dict holds "root_name" and,
            optionally, "synonyms" (a list, or a single string).

    Returns:
        dict mapping each stripped, lowercased alias to the stripped, lowercased
        root. Records without a root are skipped. None and blank synonyms are
        ignored, so they can never turn into the aliases "none" or "". When the
        same alias appears under different roots, the later record wins.
    """
    alias2root = {}
    for r in records:
        md = r.get("metadata") or {}
        root = str(md.get("root_name") or "").strip().lower()
        if not root:
            continue
        # map root to itself
        alias2root[root] = root
        # map synonyms to root
        synonyms = md.get("synonyms") or []
        if isinstance(synonyms, str):
            # a bare string is one synonym, not a sequence of characters
            synonyms = [synonyms]
        for s in synonyms:
            if s is None:
                continue
            alias = str(s).strip().lower()
            if alias:
                alias2root[alias] = root
    return alias2root

def metadata_restructuring(records):
    """Wrap each record as {"metadata": {...}}, lifting selected top-level fields in.

    The record's own "metadata" dict is shallow-copied, so the input is not
    mutated. Any of root_name / search_term / synonyms / PMID / pubmed_type
    found at the record's top level then overwrite the copy. All other
    top-level keys are dropped.
    """
    lifted_fields = ("root_name", "search_term", "synonyms", "PMID", "pubmed_type")

    def _restructure(record):
        merged = record.get("metadata", {}).copy()
        merged.update({name: record[name] for name in lifted_fields if name in record})
        return {"metadata": merged}

    return [_restructure(record) for record in records]

# Working copy of the filtered records, one {"metadata": {...}} per paper.
restructured_records = metadata_restructuring(filtered_records)

# Lowercase alias -> root lookup, plus the sorted list of distinct roots it covers.
ALIAS2ROOT = build_alias_map(restructured_records)
ALL_ROOTS = sorted({root for root in ALIAS2ROOT.values()})

### Optional: Validation checkpoint — look up a record in `json_list` by PMID

In [12]:
# def get_record_by_pmid(json_list, pmid):
#     """Pass PMID and get matching record from json_list"""
#     for record in json_list:
#         if record['metadata']['PMID'] == pmid:
#             return record
#     return None


# # Example usage:
# result = get_record_by_pmid(restructured_records, 11524119)

# if result:
#     print(json.dumps(result, indent=2))  # Prints the entire matching record
# else:
#     print("PMID not found")

# 3. Flattening the Data

In [None]:
# Flatten nested interventions/outcomes into parallel lists on each record, and
# normalize metadata types/case so it can be used for metadata filtering.
# The loop mutates `restructured_records` in place and deletes the nested source
# fields, so records that are already flattened (e.g. when this cell is re-run)
# are skipped instead of being overwritten with empty lists.
# NOTE(review): the flattened lists are stored on the record itself, not inside
# record["metadata"]; the ingestion cell only reads row["metadata"], so they are
# not attached to the indexed nodes — confirm this is intended.
OUTCOME_DOMAINS = ("biomarker", "function", "condition")
LOWERCASE_LIST_FIELDS = ("study_type", "species", "experimental_model", "usage",
                         "keywords", "benefits", "diseases", "symptoms", "sample_gender")

for record in restructured_records:
    metadata = record["metadata"]

    if "intervention_names" in record and "interventions" not in metadata:
        continue  # already flattened by a previous run of this cell

    # Process interventions with parallel indexing (position i = the same intervention)
    interventions = metadata.get("interventions") or []
    record["intervention_names"] = [i.get("ingredient") for i in interventions]
    record["intervention_dosages"] = [i.get("daily_dosage") for i in interventions]
    record["intervention_units"] = [i.get("units") or "" for i in interventions]
    record["intervention_original_texts"] = [i.get("original_text") for i in interventions]

    # Process outcomes with parallel indexing: one names/types/results triple per domain
    outcomes = metadata.get("outcomes") or []
    for domain in OUTCOME_DOMAINS:
        in_domain = [o for o in outcomes if o.get("domain") == domain]
        record[f"{domain}_names"] = [o.get("name") for o in in_domain]
        record[f"{domain}_types"] = [o.get("type") for o in in_domain]
        record[f"{domain}_results"] = [o.get("result") for o in in_domain]

    # Force consistent types for filtering. isdecimal() (not isdigit()) guarantees int() succeeds.
    year = metadata.get("published_year")
    if isinstance(year, str) and year.strip().isdecimal():
        metadata["published_year"] = int(year.strip())
    if "PMID" in metadata:
        metadata["PMID"] = str(metadata["PMID"])

    # Lowercase/canonicalize list fields (beyond synonyms); drop None/blank entries
    for key in LOWERCASE_LIST_FIELDS:
        if isinstance(metadata.get(key), list):
            metadata[key] = [str(x).strip().lower() for x in metadata[key]
                             if x is not None and str(x).strip()]

    # Lowercase single-string fields used in equality filters
    for key in ("population", "location"):
        if isinstance(metadata.get(key), str):
            metadata[key] = metadata[key].strip().lower()

    # Normalize synonyms: lowercase, strip, dedupe, sort. The canonical root is
    # included as well, because the ingredient filter matches the query's root
    # names against this `synonyms` field.
    syns = metadata.get("synonyms") or []
    if not isinstance(syns, list):
        syns = []
    normalized = {str(x).strip().lower() for x in syns if x is not None and str(x).strip()}
    root = str(metadata.get("root_name") or "").strip().lower()
    if root:
        normalized.add(root)
    metadata["synonyms"] = sorted(normalized)

    # Delete original detailed fields
    for key in ("interventions", "outcomes", "biomarkers", "functions", "conditions"):
        metadata.pop(key, None)


In [14]:
# Persist the flattened records (UTF-8, non-ASCII kept as-is) for inspection and reuse.
with open("flatten.json", mode="w", encoding="utf-8") as f:
    json.dump(restructured_records, f, ensure_ascii=False, indent=2)

# 4. Data Ingestion into Pinecone

### 4a. Converting into Embeddings and Performing Semantic Chunking

In [15]:
# %pip install -U \
#   pandas \
#   "llama-index" \
#   "llama-index-embeddings-huggingface" \
#   "llama-index-vector-stores-pinecone" \
#   "llama-index-retrievers-bm25" \
#   pinecone-client \
#   "sentence-transformers" \
#   transformers \
#   "torch" \
#   python-dotenv \
#   tqdm \
#     biopython

In [16]:
#pip install "numpy<2"

In [17]:
import pandas as pd
from llama_index.core import Document, VectorStoreIndex, StorageContext
from llama_index.core.storage.docstore import SimpleDocumentStore
from llama_index.core.node_parser import SemanticSplitterNodeParser
from llama_index.embeddings.huggingface import HuggingFaceEmbedding
from llama_index.vector_stores.pinecone import PineconeVectorStore
from pinecone import Pinecone, ServerlessSpec
from UPDATED_meta_data_generation import *
from dotenv import load_dotenv
import os
from tqdm import tqdm

load_dotenv()

# --------------------------
# Initialize Pinecone
# --------------------------
INDEX_NAME = "pubmed-abstracts-v4"
# Must equal the output size of the embedding model below (index stats report 768).
EMBED_DIM = 768
client = Pinecone(api_key=os.getenv("PINECONE_API"))
spec = ServerlessSpec(cloud="aws", region="us-east-1")

# Create the serverless index on the first run only.
if INDEX_NAME not in client.list_indexes().names():
    client.create_index(
        name=INDEX_NAME,
        dimension=EMBED_DIM,
        metric="cosine",
        spec=spec
    )

pinecone_index = client.Index(INDEX_NAME)
vector_store = PineconeVectorStore(pinecone_index=pinecone_index)

# --------------------------
# Initialize embedding + semantic chunker
# --------------------------
embed_model = HuggingFaceEmbedding(model_name="NeuML/pubmedbert-base-embeddings")

# Semantic splitting compares embeddings of adjacent sentence groups
# (buffer_size=1) and breaks where dissimilarity exceeds the 95th percentile.
splitter = SemanticSplitterNodeParser(
    buffer_size=1,
    breakpoint_percentile_threshold=95,
    embed_model=embed_model
)

# --------------------------
# Build all semantic nodes: per paper, one title node (node_index 0) followed
# by the semantic chunks of the abstract (node_index 1..n).
# --------------------------
all_nodes = []
skipped_pmids = []  # papers for which neither a title nor an abstract came back

for row in tqdm(restructured_records, desc="Processing papers"):
    md = row["metadata"]
    # NOTE(review): assumes fetch_extract_and_abstract (star-imported from
    # UPDATED_meta_data_generation) returns a dict with "title"/"abstract"
    # strings — confirm what it returns for papers it cannot fetch.
    paper = fetch_extract_and_abstract(md['PMID']) or {}
    title = paper.get("title") or ""
    abstract = paper.get("abstract") or ""

    if not title.strip() and not abstract.strip():
        skipped_pmids.append(md['PMID'])
        continue

    # Record metadata is spread FIRST so the node-specific "type"/"node_index"
    # keys cannot be overwritten by a same-named field in the record.
    if title.strip():
        title_node = Document(
            text=title,
            metadata={**md, "type": "title", "node_index": 0}
        )
        all_nodes.append(title_node)

    if abstract.strip():
        abstract_doc = Document(
            text=abstract,
            metadata={**md, "type": "abstract"}
        )
        abstract_nodes = splitter.get_nodes_from_documents([abstract_doc])
        for i, node in enumerate(abstract_nodes, start=1):
            node.metadata["node_index"] = i
            all_nodes.append(node)

if skipped_pmids:
    print(f"Skipped {len(skipped_pmids)} paper(s) with no title or abstract: {skipped_pmids}")

# --------------------------
# Create a persistent docstore (reloaded later for BM25 and paper reconstruction)
# --------------------------
docstore = SimpleDocumentStore()
docstore.add_documents(all_nodes)

storage_context = StorageContext.from_defaults(
    vector_store=vector_store,
    docstore=docstore
)

storage_context.persist(persist_dir="pubmed_nodes")



Processing papers: 100%|██████████| 82/82 [00:59<00:00,  1.39it/s]


### 4b. Upserting Embedded Chunks into Pinecone

In [18]:
# --------------------------
# 4. Store nodes in Pinecone on Cloud via LlamaIndex
# --------------------------
# Start from an empty index bound to the Pinecone-backed storage context, then
# embed and upsert every node built in the previous cell.
index = VectorStoreIndex([], storage_context=storage_context, embed_model=embed_model)
if not all_nodes:
    print("WARNING: No nodes to upsert.")
else:
    index.insert_nodes(all_nodes, show_progress=True)

Generating embeddings:   0%|          | 0/247 [00:00<?, ?it/s]

Upserted vectors:   0%|          | 0/247 [00:00<?, ?it/s]

In [19]:
# Get Stats of Vector Index: dimension, metric, and vector counts per namespace
# (a quick check that the upsert above landed in Pinecone).
stats = pinecone_index.describe_index_stats()
stats

{'dimension': 768,
 'index_fullness': 0.0,
 'metric': 'cosine',
 'namespaces': {'': {'vector_count': 247}},
 'total_vector_count': 247,
 'vector_type': 'dense'}

# Debugging: Similarity Search 

In [20]:
from llama_index.core.vector_stores import MetadataFilter, MetadataFilter, MetadataFilters, FilterCondition
from llama_index.core import Document
from llama_index.retrievers.bm25 import BM25Retriever
from llama_index.core.retrievers import QueryFusionRetriever
from llama_index.core.base.base_retriever import BaseRetriever
from llama_index.core.schema import QueryBundle

import re

from openai import OpenAI
import os, json

# OpenAI client used by the facet-extraction / root-mapping helpers below.
# NOTE: this rebinds `client`, which the ingestion cell used for the Pinecone
# client; code that needs Pinecone (e.g. `client.list_indexes()`) must create
# a new Pinecone client instead of reusing `client`.
OPENAI_API_KEY = os.environ.get("OPENAI_API_KEY")
client = OpenAI(api_key=OPENAI_API_KEY)

def llm_extract_facets(query_text: str, llm_client=None) -> dict:
    """Ask the LLM to turn a free-text query into structured search facets.

    Args:
        query_text: the user's natural-language query.
        llm_client: optional OpenAI-compatible client (anything exposing
            ``chat.completions.create``). Defaults to the module-level ``client``.

    Returns:
        A dict with exactly the keys listed in the system prompt:
        - list facets: lists of stripped, non-blank strings (None items dropped,
          a bare string wrapped in a list);
        - scalar string facets: a stripped string or None;
        - year bounds: an int (ints, integral floats, or digit strings) or None.
        A malformed response (bad JSON, non-object JSON, wrong types) degrades
        to [] / None values instead of raising.
    """
    system = (
        "You extract structured search facets for PubMed-style retrieval. "
        "Return STRICT JSON with keys:\n"
        "- candidate_ingredients: list[string]\n"
        "- published_year_min: int or null\n"
        "- published_year_max: int or null\n"
        "- study_type: list[string]\n"
        "- neuraci_type: string or null\n"
        "- species: list[string]\n"
        "- population: string or null\n"
        "- sample_gender: list[string]\n"
        "- benefits: list[string]\n"
        "- diseases: list[string]\n"
        "- symptoms: list[string]\n"
        "- location: string or null\n"
        "Use null for unknown scalars and [] for unknown lists. No extra keys."
    )
    user = f"Query: {query_text}\nReturn only JSON."

    llm = llm_client if llm_client is not None else client
    resp = llm.chat.completions.create(
        model="gpt-4o-mini",
        response_format={"type": "json_object"},
        messages=[{"role": "system", "content": system},
                  {"role": "user", "content": user}],
        temperature=0.1,
    )
    try:
        data = json.loads(resp.choices[0].message.content or "")
    except (ValueError, TypeError, IndexError, AttributeError):
        # best effort: an unusable response means "no facets", not a crash
        data = {}
    if not isinstance(data, dict):
        data = {}

    def _str_list(key):
        value = data.get(key)
        if isinstance(value, str):
            value = [value]
        if not isinstance(value, list):
            return []  # covers JSON null and other wrong types
        return [str(x).strip() for x in value if x is not None and str(x).strip()]

    def _str_or_none(key):
        value = data.get(key)
        return value.strip() if isinstance(value, str) and value.strip() else None

    def _year_or_none(key):
        value = data.get(key)
        if isinstance(value, bool):
            return None  # bool is an int subclass but never a year
        if isinstance(value, int):
            return value
        if isinstance(value, float) and value.is_integer():
            return int(value)
        if isinstance(value, str) and value.strip().isdecimal():
            return int(value.strip())
        return None

    # strict normalization to the schema
    return {
        "candidate_ingredients": _str_list("candidate_ingredients"),
        "published_year_min": _year_or_none("published_year_min"),
        "published_year_max": _year_or_none("published_year_max"),
        "study_type": _str_list("study_type"),
        "neuraci_type": _str_or_none("neuraci_type"),
        "species": _str_list("species"),
        "population": _str_or_none("population"),
        "sample_gender": _str_list("sample_gender"),
        "benefits": _str_list("benefits"),
        "diseases": _str_list("diseases"),
        "symptoms": _str_list("symptoms"),
        "location": _str_or_none("location"),
    }
    
def llm_map_to_roots(candidates: list[str], allowed_roots: list[str], llm_client=None) -> list[str]:
    """
    Ask LLM to map candidate ingredient mentions to the canonical root_name(s) from allowed_roots.

    Args:
        candidates: free-text ingredient mentions extracted from the query.
        allowed_roots: canonical root names the answer must be drawn from.
        llm_client: optional OpenAI-compatible client (anything exposing
            ``chat.completions.create``). Defaults to the module-level ``client``.

    Returns:
        A sorted, unique list of roots from `allowed_roots`, in their canonical
        spelling. Returned names are matched case-insensitively and with
        whitespace trimmed. Returns [] when there are no candidates or when the
        response cannot be parsed.
    """
    if not candidates:
        return []
    # case-insensitive lookup back to the exact canonical spelling
    canonical = {str(r).strip().lower(): r for r in allowed_roots}

    system = (
        "Map ingredient mentions to canonical names from a provided list. "
        "If no match, omit it. Return JSON: {\"roots\": [..canonical names..]}"
    )
    user = (
        "Candidates: " + json.dumps(candidates) + "\n"
        "Allowed roots: " + json.dumps(allowed_roots) + "\n"
        "Return only JSON."
    )
    llm = llm_client if llm_client is not None else client
    resp = llm.chat.completions.create(
        model="gpt-4o-mini",
        response_format={"type": "json_object"},
        messages=[
            {"role": "system", "content": system},
            {"role": "user", "content": user},
        ],
        temperature=0.0,
    )
    try:
        data = json.loads(resp.choices[0].message.content or "")
    except (ValueError, TypeError, IndexError, AttributeError):
        return []

    roots = data.get("roots") if isinstance(data, dict) else None
    if isinstance(roots, str):
        roots = [roots]  # a single name, not a sequence of characters
    if not isinstance(roots, list):
        return []
    # keep only valid ones in allowed list
    keys = (str(r).strip().lower() for r in roots if r is not None)
    return sorted({canonical[k] for k in keys if k in canonical})

def facets_to_filters(facets):
    """Convert a facets dict into an AND-ed MetadataFilters group for retrieval.

    Args:
        facets: dict as assembled in `rag_retrieve`. It holds "root_name" (a list
            of canonical roots, or a single string), the ints
            "published_year_min"/"published_year_max", and list/string facets
            such as "study_type", "species", "population" and "neuraci_type".

    Returns:
        MetadataFilters (condition AND) with one MetadataFilter per non-empty
        facet. List facets become "in" filters, scalar facets "==" filters, and
        years ">=" / "<=".

    Query values are stripped. For the fields the flattening step lowercases at
    index time they are lowercased too, and roots are lowercased to match the
    lowercased `synonyms` field. Otherwise a value like "Homo sapiens" could
    never equal the stored "homo sapiens".

    NOTE(review): list/string facets are exact matches AND-ed together, so
    free-text facets (benefits, diseases, ...) make the filter very restrictive.
    """
    # Fields lowercased at index time by the flattening cell.
    lowercased_keys = {"study_type", "species", "sample_gender", "benefits",
                       "diseases", "symptoms", "population", "location"}
    flist = []

    def _clean(key, x):
        text = str(x).strip()
        return text.lower() if key in lowercased_keys else text

    def add_eq_or_in(key, val):
        if val is None:
            return
        if isinstance(val, (list, tuple, set)):
            cleaned = (_clean(key, x) for x in val if x is not None)
            vals = list(dict.fromkeys(v for v in cleaned if v))  # dedupe, keep order
            if vals:
                flist.append(MetadataFilter(key=key, operator="in", value=vals))
        else:
            value = _clean(key, val)
            if value:  # never emit an `== ""` filter
                flist.append(MetadataFilter(key=key, operator="==", value=value))

    # INGREDIENTS → filter on synonyms (root included, lowercased, in synonyms at index time)
    roots = facets.get("root_name")
    if isinstance(roots, str):
        roots = [roots]
    if roots:
        root_vals = list(dict.fromkeys(
            str(r).strip().lower() for r in roots if r is not None and str(r).strip()
        ))
        if root_vals:
            flist.append(MetadataFilter(key="synonyms", operator="in", value=root_vals))

    # YEARS
    ymin = facets.get("published_year_min")
    ymax = facets.get("published_year_max")
    if isinstance(ymin, int):
        flist.append(MetadataFilter(key="published_year", operator=">=", value=ymin))
    if isinstance(ymax, int):
        flist.append(MetadataFilter(key="published_year", operator="<=", value=ymax))

    # neuraci_type (string; not lowercased at index time, so matched as given)
    add_eq_or_in("neuraci_type", facets.get("neuraci_type"))

    add_eq_or_in("study_type", facets.get("study_type"))
    add_eq_or_in("species", facets.get("species"))
    add_eq_or_in("sample_gender", facets.get("sample_gender"))
    add_eq_or_in("benefits", facets.get("benefits"))
    add_eq_or_in("diseases", facets.get("diseases"))
    add_eq_or_in("symptoms", facets.get("symptoms"))

    # population/location are strings in the JSON; equality filter only (exact match)
    add_eq_or_in("population", facets.get("population"))
    add_eq_or_in("location", facets.get("location"))

    # IMPORTANT: Pinecone adapter expects MetadataFilters, not a bare list
    return MetadataFilters(filters=flist, condition=FilterCondition.AND)

def md_matches_filter(md, mf):
    """Evaluate metadata filters locally against one node's metadata dict.

    Used to post-filter BM25 hits with the same filters that are passed to the
    vector retriever.

    Args:
        md: the node's metadata dict.
        mf: a MetadataFilters group, a list of MetadataFilter (AND-ed), a single
            MetadataFilter, or anything else (treated as "no filters").

    Returns:
        True if `md` satisfies the filters. Supported operators are ==, !=, in,
        nin, >=, <=, > and <; any other operator never matches. For "in"/"nin",
        a list-valued metadata field matches when any/none of its elements is
        among the target values. Nested MetadataFilters groups inside `.filters`
        are evaluated recursively.
    """
    # normalize to (filters_list, condition)
    if isinstance(mf, MetadataFilters):
        filters_list = mf.filters or []
        condition = mf.condition or FilterCondition.AND
    elif isinstance(mf, list):
        filters_list = mf
        condition = FilterCondition.AND
    elif isinstance(mf, MetadataFilter):
        filters_list = [mf]
        condition = FilterCondition.AND
    else:
        return True  # no filters

    if not filters_list:
        return True

    def as_values(tgt):
        # "in"/"nin" targets should be lists; wrap a bare scalar so a string
        # target is not iterated character by character.
        if isinstance(tgt, list):
            return tgt
        if isinstance(tgt, (tuple, set)):
            return list(tgt)
        return [tgt]

    def match_one(m, flt):
        # MetadataFilters.filters may itself contain a MetadataFilters group
        if isinstance(flt, MetadataFilters):
            return md_matches_filter(m, flt)
        key, op, tgt = flt.key, flt.operator, flt.value
        val = m.get(key)
        if op == "==":
            return val == tgt
        if op == "!=":
            return val != tgt
        if op == "in":
            values = as_values(tgt)
            if isinstance(val, list):
                return any(x in val for x in values)
            return val in values
        if op == "nin":
            values = as_values(tgt)
            if isinstance(val, list):
                return all(x not in val for x in values)
            return val not in values
        is_number = isinstance(val, (int, float))
        if op == ">=":
            return is_number and val >= tgt
        if op == "<=":
            return is_number and val <= tgt
        if op == ">":
            return is_number and val > tgt
        if op == "<":
            return is_number and val < tgt
        return False

    if condition == FilterCondition.OR:
        return any(match_one(md, f) for f in filters_list)
    return all(match_one(md, f) for f in filters_list)

# Local BM25 index over the same chunks that were upserted to Pinecone. Each node
# is copied into a Document (text + metadata copy) so BM25 hits can be
# post-filtered with `md_matches_filter`.
# similarity_top_k is capped at the corpus size, because a top-k larger than the
# number of documents may raise in bm25s-backed BM25Retriever versions.
BM25_TOP_K = 200
nodes_for_bm25 = [Document(text=n.text, metadata=dict(n.metadata)) for n in all_nodes]
BM25_ALL = BM25Retriever.from_defaults(
    nodes=nodes_for_bm25,
    similarity_top_k=max(1, min(BM25_TOP_K, len(nodes_for_bm25))),
)

class BM25FilteredRetriever(BaseRetriever):
    """BM25 retriever that applies metadata filters *after* keyword scoring.

    The wrapped BM25 retriever is asked for its full candidate list. Hits whose
    metadata fail `filters` (checked with `md_matches_filter`) are dropped, and
    the survivors are truncated to `top_k`.
    """

    def __init__(self, base_bm25, top_k=50, filters=None):
        super().__init__()
        self.base = base_bm25    # unfiltered BM25 retriever supplying candidates
        self.top_k = top_k       # max hits returned after filtering
        self.filters = filters   # MetadataFilters / list / single filter / None

    def _post_filter(self, hits):
        """Keep only hits whose node metadata satisfies `self.filters`."""
        if self.filters:
            return [hit for hit in hits if md_matches_filter(hit.node.metadata, self.filters)]
        return hits

    def _retrieve(self, query_bundle: QueryBundle):
        candidates = self.base.retrieve(query_bundle.query_str)
        kept = self._post_filter(candidates)
        return kept[:self.top_k]

    async def _aretrieve(self, query_bundle: QueryBundle):
        """Async entry point; simply delegates to the synchronous path."""
        return self._retrieve(query_bundle)

def choose_alpha_for_query(q):
    """Pick the vector-vs-keyword balance for a query from a crude lexical signal.

    Queries dense in digits and punctuation ("-_:/().,") per whitespace token
    look keyword/identifier-heavy and get a lower alpha. Plain prose gets a
    higher one.

    Returns:
        One of 0.35, 0.45, 0.6 or 0.75. `make_hybrid_retriever` scales the vector
        retriever's top-k by alpha and BM25's top-k by (1 - alpha).
    """
    lowered = q.lower()
    punctuation = "-_:/().,"
    signal_chars = sum(1 for ch in lowered if ch.isdigit() or ch in punctuation)
    token_count = max(len(lowered.split()), 1)
    density = signal_chars / token_count

    # thresholds checked from most to least keyword-heavy
    for threshold, alpha in ((0.6, 0.35), (0.4, 0.45)):
        if density > threshold:
            return alpha
    return 0.75 if density < 0.15 else 0.6

def make_hybrid_retriever(index, filters, alpha, vec_top_k=50, bm25_top_k=50, final_k=50):
    """Build a vector + BM25 retriever fused by reciprocal-rank fusion.

    Args:
        index: VectorStoreIndex queried for dense hits.
        filters: metadata filters (the MetadataFilters object produced by
            `facets_to_filters`). They are passed to the vector retriever and
            applied after scoring to the local BM25 hits.
        alpha: share of the candidate budget given to the vector side. The
            vector retriever gets int(vec_top_k * alpha) hits and BM25 gets
            int(bm25_top_k * (1 - alpha)); each gets at least 1.
        vec_top_k, bm25_top_k: candidate budgets before the alpha split.
        final_k: number of fused results returned.

    Returns:
        A QueryFusionRetriever that runs only the original query (num_queries=1).
    """
    dense = index.as_retriever(
        similarity_top_k=max(1, int(vec_top_k * alpha)),
        filters=filters,
    )
    sparse = BM25FilteredRetriever(
        BM25_ALL,
        top_k=max(1, int(bm25_top_k * (1.0 - alpha))),
        filters=filters,
    )
    return QueryFusionRetriever(
        retrievers=[dense, sparse],
        mode="reciprocal_rerank",
        num_queries=1,
        similarity_top_k=final_k,
    )



In [21]:
def map_candidates_to_roots_via_alias(cands, alias2root):
    """Resolve free-text ingredient mentions to canonical roots via the alias table.

    Each candidate is stripped and lowercased before lookup, and unknown mentions
    are dropped. Returns the distinct matching roots in sorted order.
    """
    lookup_keys = (str(c).strip().lower() for c in cands)
    return sorted({alias2root[k] for k in lookup_keys if k in alias2root})

def rag_retrieve(query_text, vec_top_k=50, bm25_top_k=50, final_k=50, llm_root_fallback=False):
    """Run facet extraction plus filtered hybrid retrieval for one query.

    Args:
        query_text: the user's natural-language query.
        vec_top_k, bm25_top_k: candidate budgets for the vector and BM25 sides.
            Each is scaled by alpha / (1 - alpha) in `make_hybrid_retriever`.
        final_k: number of fused results to return.
        llm_root_fallback: when the alias table resolves no ingredient roots,
            ask the LLM to map the candidates onto `ALL_ROOTS`.

    Returns:
        (results, facets, filters, alpha). `facets` contains only the non-empty
        facets that were turned into filters.
    """
    raw = llm_extract_facets(query_text)
    candidates = raw.get("candidate_ingredients", [])

    roots = map_candidates_to_roots_via_alias(candidates, ALIAS2ROOT)
    if not roots and llm_root_fallback:
        roots = llm_map_to_roots(candidates, ALL_ROOTS)

    facets = {
        "root_name": roots or None,
        "published_year_min": raw.get("published_year_min"),
        "published_year_max": raw.get("published_year_max"),
        **{
            key: raw.get(key) or None
            for key in ("study_type", "neuraci_type", "species", "population",
                        "sample_gender", "benefits", "diseases", "symptoms", "location")
        },
    }
    # drop unset facets so they do not become filters
    facets = {k: v for k, v in facets.items() if v not in (None, [], {})}

    filters = facets_to_filters(facets)
    alpha = choose_alpha_for_query(query_text)

    hybrid = make_hybrid_retriever(index, filters, alpha, vec_top_k, bm25_top_k, final_k)
    results = hybrid.retrieve(query_text)
    return results, facets, filters, alpha

In [23]:
# Baseline: plain (unfiltered) vector retriever, top 5 similar chunks.
retriever = index.as_retriever(similarity_top_k=5)
query_text = "I’m looking for solid human studies on cedar leaf oil to help with anxiety during radiotherapy. Can you pull primary-outcome evidence since about 2000, inhalation use only, and ignore chemistry/extraction papers?"

# Filtered hybrid pipeline on the same query, for comparison.
hits, facets, filters, alpha = rag_retrieve(query_text, vec_top_k=40, bm25_top_k=40)

for label, value in (("α used:", alpha), ("facets:", facets), ("Number of hits:", len(hits))):
    print(label, value)
for r in hits:
    md = r.node.metadata
    print(
        f"Score {r.score:.4f} | PMID {md.get('PMID')} | {md.get('type')} | "
        f"year {md.get('published_year')} | root {md.get('root_name')}"
    )

print("\n")
results = retriever.retrieve(query_text)
for res in results:
    print("Score:", res.score)
    print("Text:", res.node.text)
    for label, key in (("PMID:", "PMID"), ("Type:", "type")):
        print(label, res.node.metadata.get(key))
    print("-" * 80)


α used: 0.6
facets: {'root_name': ['cedarwood'], 'published_year_min': 2000, 'study_type': ['clinical trial', 'randomized controlled trial'], 'species': ['Homo sapiens'], 'benefits': ['anxiety reduction'], 'diseases': ['cancer'], 'symptoms': ['anxiety']}
Number of hits: 0


Score: 0.627434254
Text: Inhalation aromatherapy during radiotherapy: results of a placebo-controlled double-blind randomized trial.
PMID: 12805340
Type: title
--------------------------------------------------------------------------------
Score: 0.605291367
Text: New woody and ambery notes from cedarwood and turpentine oil.
PMID: 17191830
Type: title
--------------------------------------------------------------------------------
Score: 0.598267615
Text: To determine whether the inhalation of aromatherapy during radiotherapy reduces anxiety. Three hundred thirteen patients undergoing radiotherapy were randomly assigned to receive either carrier oil with fractionated oils, carrier oil only, or pure essential oils o

# Debugging: Reconstructing a Paper from Its Nodes

In [None]:
# Function to reconstruct a paper from nodes
def reconstruct_paper(all_nodes, pmid, verbose=True):
    """Rebuild a paper's text (title + abstract chunks) from its nodes.

    Args:
        all_nodes: iterable of nodes exposing `.metadata` (dict) and `.text`.
        pmid: PubMed ID. It is compared as a string, so int and str IDs both match.
        verbose: print how many nodes matched (the original behavior).

    Returns:
        The matching node texts joined with newlines, ordered by `node_index`
        (the title is 0 and abstract chunks are 1..n). Returns "" if no node
        matches.
    """
    target = str(pmid)
    paper_nodes = sorted(
        (node for node in all_nodes if str(node.metadata.get("PMID")) == target),
        key=lambda node: node.metadata.get("node_index", 0),
    )
    if verbose:
        print("Number of Nodes:", len(paper_nodes))
    return "\n".join(node.text for node in paper_nodes)

# Example usage: rebuild the first filtered paper from the in-memory nodes.
pmid_to_reconstruct = restructured_records[0]['metadata']['PMID']
full_paper_text = reconstruct_paper(all_nodes, pmid_to_reconstruct)

print("Reconstructed Paper Text:", full_paper_text, sep="\n")


### Optional: Delete the Pinecone Index

In [None]:
# from pinecone import Pinecone
# import os
# from dotenv import load_dotenv
# load_dotenv()

# INDEX_NAME = "pubmed-abstracts"

# # Initialize Pinecone client
# client = Pinecone(api_key=os.getenv("PINECONE_API"))

# try:
#     client.delete_index(name=INDEX_NAME)
#     print("Index deleted")
# except:
#     print("Data base is empty")

### Optional: Check Available Indexes

In [None]:
# indexes = client.list_indexes()
# print(f"Available indexes: {indexes.names()}")
# print(f"Current index name: {INDEX_NAME}")

# Hybrid Search Retrieval Pipeline Directly from Pinecone
Note: Restart the kernel before running the cells below, so that results come from Pinecone and the persisted docstore rather than from in-memory (RAM) variables left over from earlier cells.

In [None]:
from llama_index.core import StorageContext
#from llama_index.core.storage.docstore import SimpleDocumentStore

# Rebuild a storage context from the folder written by `storage_context.persist(...)`.
# Only locally persisted stores are loaded here; no Pinecone vector store is attached.
storage_context = StorageContext.from_defaults(persist_dir="pubmed_nodes")
docstore = storage_context.docstore

# Sanity check: number of nodes available for BM25 / reconstruction.
print("Number of documents:", len(storage_context.docstore.docs))


In [24]:
from pinecone import Pinecone
from llama_index.vector_stores.pinecone import PineconeVectorStore
from llama_index.core import StorageContext, VectorStoreIndex
from llama_index.embeddings.huggingface import HuggingFaceEmbedding
from llama_index.retrievers.bm25 import BM25Retriever
from llama_index.core.retrievers import QueryFusionRetriever
from llama_index.core.llms import MockLLM
from dotenv import load_dotenv

import os

load_dotenv()

# This cell is meant to run on a fresh kernel (see the note above), so it must not
# depend on constants from the ingestion cells. Keep in sync with section 4a.
INDEX_NAME = "pubmed-abstracts-v4"

pc = Pinecone(api_key=os.getenv("PINECONE_API"))
pinecone_index = pc.Index(INDEX_NAME)

vector_store = PineconeVectorStore(pinecone_index=pinecone_index)

# Must be the same model used at ingestion so query vectors match the stored ones.
embed_model = HuggingFaceEmbedding(model_name="NeuML/pubmedbert-base-embeddings")

# from_vector_store builds its own StorageContext around `vector_store`; passing a
# separate storage_context as well is redundant (and may be rejected as a
# duplicate keyword argument by some llama_index versions).
index = VectorStoreIndex.from_vector_store(
    vector_store=vector_store,
    embed_model=embed_model
)

# Create vector retriever
vector_retriever = index.as_retriever(similarity_top_k=5)

# Keyword (BM25) retriever over the locally persisted docstore loaded in the
# previous cell (`docstore`).
bm25_retriever = BM25Retriever.from_defaults(
    docstore=docstore,
    similarity_top_k=5
)

# Fuse both result lists.
# - num_queries=1: no LLM query generation. The default (>1) would ask MockLLM
#   for extra queries, and MockLLM just echoes the prompt back.
# - mode="relative_score": normalizes each retriever's scores and applies
#   retriever_weights; the default "simple" mode ignores the weights.
hybrid_retriever = QueryFusionRetriever(
    retrievers=[vector_retriever, bm25_retriever],
    retriever_weights=[0.5, 0.5],  # Equal weight to both retrievers
    mode="relative_score",
    num_queries=1,
    llm=MockLLM(),  # Use MockLLM to avoid needing OpenAI API key
    use_async=False,
)

# Perform hybrid search
query = "Which analytical method was used to photosynthetic tissues?"
results = hybrid_retriever.retrieve(query)

# Display results
for res in results:
    print("Score:", res.score)
    print("Text:", res.node.text)
    print("PMID:", res.node.metadata.get("PMID"))
    print("Type:", res.node.metadata.get("type"))
    print("-" * 80)


KeyboardInterrupt: 