Install required packages

In [1]:
!pip install datasets keybert transformers kagglehub bertopic --quiet
!pip install -q keybert sentence-transformers bertopic kagglehub tqdm wordcloud


Import dependencies

In [2]:
# Standard libraries
import os
import re
import itertools
from collections import Counter

# Third-party libraries
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
import umap
from wordcloud import WordCloud
from sklearn.feature_extraction.text import CountVectorizer
import spacy

# NLP and Transformers
import datasets
from kagglehub import dataset_download
from bertopic import BERTopic
from keybert import KeyBERT
from transformers import AutoTokenizer
from sentence_transformers import SentenceTransformer, util


In [3]:
from google.colab import drive
drive.mount('/content/drive')

PROJECT_DIR = "/content/drive/MyDrive/Colab Notebooks/ANLP_Project"
os.makedirs(PROJECT_DIR, exist_ok=True)


Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).


Load Datasets

In [4]:
# Define paths
raw_metadata_path = os.path.join(PROJECT_DIR, "metadata.csv")
processed_metadata_path = os.path.join(PROJECT_DIR, "processed_cord19.csv")
essential_cols_path = os.path.join(PROJECT_DIR, "cord19_essential.csv")
keywords_cache_path = os.path.join(PROJECT_DIR, "doc_keywords.pkl")
kaggle_meta_path = os.path.join(PROJECT_DIR, "kaggle_metadata.csv")

In [5]:
# Download if not already present
if not os.path.exists(raw_metadata_path):
    print("Downloading CORD-19 metadata...")
    !wget -q https://ai2-semanticscholar-cord-19.s3.us-west-2.amazonaws.com/latest/metadata.csv -O "{raw_metadata_path}"
else:
    print("✅ Raw metadata file already exists.")

✅ Raw metadata file already exists.


In [6]:
# Step 2: Process metadata if missing
if not os.path.exists(processed_metadata_path):
    print("🧹 Cleaning and saving processed metadata...")
    cord19_df = pd.read_csv(raw_metadata_path, low_memory=False)
    cord19_df = cord19_df.drop_duplicates(subset=["cord_uid"])
    cord19_df.to_csv(processed_metadata_path, index=False)
else:
    print("✅ Processed metadata already exists.")

✅ Processed metadata already exists.


In [7]:
# Step 3: Filter to essential columns
if os.path.exists(essential_cols_path):
    cord19_df = pd.read_csv(essential_cols_path)
    print("✅ Loaded essential columns from cache.")
else:
    full_df = pd.read_csv(processed_metadata_path, low_memory=False)
    cols = [
        "cord_uid", "title", "authors", "abstract", "publish_time",
        "journal", "doi", "url", "pdf_json_files", "pmc_json_files"
    ]
    cord19_df = full_df[cols]
    cord19_df.to_csv(essential_cols_path, index=False)
    print("✅ Filtered and saved essential columns.")

  cord19_df = pd.read_csv(essential_cols_path)


✅ Loaded essential columns from cache.


In [8]:
cord19_df["abstract"].head()


Unnamed: 0,abstract
0,OBJECTIVE: This retrospective chart review des...
1,Inflammatory diseases of the respiratory tract...
2,Surfactant protein-D (SP-D) participates in th...
3,Endothelin-1 (ET-1) is a 21 amino acid peptide...
4,Respiratory syncytial virus (RSV) and pneumoni...


# Load googleai/cord19 metadata mappings from Kaggle (links papers to external datasets)


In [9]:
!pip install -q kagglehub


In [10]:
from kagglehub import dataset_download
import shutil

In [11]:
# Define where you want to cache the Kaggle metadata file
#kaggle_meta_path = os.path.join(PROJECT_DIR, "kaggle_metadata.csv")

In [12]:
# Download if not already cached
if os.path.exists(kaggle_meta_path):
    df_meta_cord19 = pd.read_csv(kaggle_meta_path)
    print("✅ Loaded Kaggle metadata from cache.")
else:
    print("⬇️ Downloading Kaggle CORD-19 metadata...")
    path = dataset_download(handle="googleai/dataset-metadata-for-cord19")
    filename = os.listdir(path)[0]
    source_file = os.path.join(path, filename)

    shutil.copyfile(source_file, kaggle_meta_path)
    df_meta_cord19 = pd.read_csv(kaggle_meta_path)
    print("✅ Downloaded and cached Kaggle metadata.")

# Preview
df_meta_cord19.head()

✅ Loaded Kaggle metadata from cache.


Unnamed: 0,cord_uid,paper_url,paper_title,dataset_url,dataset_name,alternate_name,description,author_list,last_updated,license,source_organization,doi,compact_identifier,data_download
0,rmzpiyqj,https://www.ncbi.nlm.nih.gov/pmc/articles/PMC6...,"Nipah virus: epidemiology, pathology, immunobi...",https://data.csiro.au/dap/landingpage?pid=csir...,Nature of exposure drives transmission of Nipa...,,"[""RT-PCR data of comparative viral loads/ tiss...",Bronwyn Clayton; Deborah Middleton; Rachel Ark...,2016,"[{""url"":""https://confluence.csiro.au/display/d...",CSIRO,10.4225/08/56806AAEAD713,,
1,h7g5ecc0,http://europepmc.org/articles/pmc4052367?pdf=r...,Novel approaches and challenges to treatment o...,https://datamed.org/display-item.php?repositor...,Key Role of T cell Defects in Age-Related Vuln...,,"[""In a mouse model of age-related vulnerabilit...",,2019-05-06,,,,,
2,3uvlmww0,https://jvi.asm.org/content/jvi/88/17/10228.fu...,"Verdinexor, a Novel Selective Inhibitor of Nuc...",https://datamed.org/display-item.php?repositor...,MicroRNA Regulation of Human Protease Genes,,the human protease genes required for influenz...,,2011-10-13,,,,,
3,xzps65et,https://doi.org/10.14745/ccdr.v45i04a01,Climate change and infectious diseases: What c...,https://search.datacite.org/works/10.5065/d6sj...,The NA-CORDEX dataset,,"[""The NA-CORDEX data archive contains output f...",Linda Mearns; Seth McGinnis; Daniel Korytina; ...,2017,"[{""url"":""http://na-cordex.org/terms-use""}]",UCAR/NCAR,10.5065/d6sj1jch,,
4,a6p8te8q,https://jvi.asm.org/content/jvi/79/6/3370.full...,Increased Epitope-Specific CD8(+) T Cells Prev...,http://www.immunedata.org/display-item.php?rep...,CMV CD8 T Cells,,"[""We present human T cell responses in multipl...",,2018-09-17,,,,,


#Use of BERT for filtering the Datasets

In [13]:
!pip install bertopic[visualization]




Filtering with KeyBERT, Semantic Similarity, and BERTopic ---

In [14]:
# -----------------------------
# 📊 Step 1: Dataset Overview
# -----------------------------
print("📄 Dataset Info:")
cord19_df.info()

print("\n🧱 Columns:", list(cord19_df.columns))
print("\n🔍 Sample Rows:")
display(cord19_df.sample(3))


📄 Dataset Info:
<class 'pandas.core.frame.DataFrame'>
RangeIndex: 970836 entries, 0 to 970835
Data columns (total 10 columns):
 #   Column          Non-Null Count   Dtype 
---  ------          --------------   ----- 
 0   cord_uid        970836 non-null  object
 1   title           970340 non-null  object
 2   authors         948634 non-null  object
 3   abstract        763499 non-null  object
 4   publish_time    969037 non-null  object
 5   journal         887275 non-null  object
 6   doi             654783 non-null  object
 7   url             684925 non-null  object
 8   pdf_json_files  372888 non-null  object
 9   pmc_json_files  315041 non-null  object
dtypes: object(10)
memory usage: 74.1+ MB

🧱 Columns: ['cord_uid', 'title', 'authors', 'abstract', 'publish_time', 'journal', 'doi', 'url', 'pdf_json_files', 'pmc_json_files']

🔍 Sample Rows:


Unnamed: 0,cord_uid,title,authors,abstract,publish_time,journal,doi,url,pdf_json_files,pmc_json_files
702123,g0lcoedm,Positive rates predict death rates of Covid-19...,"Mimkes, J.; Janssen, R.","In the Covid-19-pandemic, the numbers of decea...",2020-11-28,,10.1101/2020.11.24.20237842,http://medrxiv.org/cgi/content/short/2020.11.2...,document_parses/pdf_json/798a028b2bb8792c2dbf6...,
131254,gszsv5cj,Artificial Intelligence-Based Image Enhancemen...,"Liu, Juan; Malekzadeh, Masoud; Mirian, Niloufa...",High noise and low spatial resolution are two ...,2021-10-01,PET clinics,10.1016/j.cpet.2021.06.005,https://doi.org/10.1016/j.cpet.2021.06.005; ht...,,
28910,nwvsubtj,rs2671655 single nucleotide polymorphism modul...,"Shin, Cheol Min; Park, Kyungtaek; Kim, Nayoung...",OBJECTIVE: To identify genetic variations whic...,2022-03-24,Gastric Cancer,10.1007/s10120-022-01285-x,https://www.ncbi.nlm.nih.gov/pmc/articles/PMC8...,document_parses/pdf_json/8bf48aa3c3978f3a235f9...,document_parses/pmc_json/PMC8943788.xml.json


In [15]:
# -----------------------------
# 🔍 Step 2: Smoking Keyword Exploration
# -----------------------------
SMOKING_KEYWORDS = [
    "smoking", "tobacco", "nicotine", "cigarette", "e-cigarette",
    "vaping", "lung cancer", "chronic bronchitis", "secondhand smoke",
    "nicotine dependence", "smoke exposure", "smoker", "hookah"
]

text_columns = ["title", "abstract"]

for col in text_columns:
    if col in cord19_df.columns:
        print(f"\n📘 Smoking keyword frequencies in column: '{col}'")
        for kw in SMOKING_KEYWORDS:
            count = cord19_df[col].astype(str).str.contains(kw, case=False, na=False).sum()
            if count > 0:
                print(f"  🔹 {kw}: {count}")



📘 Smoking keyword frequencies in column: 'title'
  🔹 smoking: 1122
  🔹 tobacco: 773
  🔹 nicotine: 213
  🔹 cigarette: 592
  🔹 e-cigarette: 262
  🔹 vaping: 264
  🔹 lung cancer: 1705
  🔹 chronic bronchitis: 23
  🔹 secondhand smoke: 9
  🔹 nicotine dependence: 16
  🔹 smoke exposure: 41
  🔹 smoker: 265
  🔹 hookah: 17

📘 Smoking keyword frequencies in column: 'abstract'
  🔹 smoking: 5870
  🔹 tobacco: 2140
  🔹 nicotine: 669
  🔹 cigarette: 1564
  🔹 e-cigarette: 518
  🔹 vaping: 432
  🔹 lung cancer: 2799
  🔹 chronic bronchitis: 98
  🔹 secondhand smoke: 48
  🔹 nicotine dependence: 94
  🔹 smoke exposure: 144
  🔹 smoker: 2270
  🔹 hookah: 36


In [16]:
!pip install huggingface_hub[hf_xet]



In [17]:
tokenizer = AutoTokenizer.from_pretrained("sentence-transformers/all-MiniLM-L6-v2")

SMOKING_KEYWORDS = [
    "smoking", "tobacco", "nicotine", "cigarette", "e-cigarette",
    "vaping", "lung cancer", "chronic bronchitis", "secondhand smoke",
    "nicotine dependence", "smoke exposure", "smoker", "hookah"
]

COVID_KEYWORDS = ['covid', 'coronavirus', 'sars-cov-2', 'pandemic']
ALL_KEYWORDS = SMOKING_KEYWORDS + COVID_KEYWORDS


In [18]:
from keybert import KeyBERT
from collections import defaultdict
import pickle
keywords_cache_path = os.path.join(PROJECT_DIR, "doc_keywords.pkl")

if os.path.exists(keywords_cache_path):
    print("✅ Loaded cached document keywords.")
    with open(keywords_cache_path, "rb") as f:
        doc_keywords = pickle.load(f)
else:
    print("🚀 Running KeyBERT keyword extraction...")
    model = SentenceTransformer("all-MiniLM-L6-v2")
    kw_model = KeyBERT(model)

    def chunk_text_simple(text, max_words=80, overlap=20):
        words = text.split()
        chunks = []
        for i in range(0, len(words), max_words - overlap):
            chunk = " ".join(words[i:i + max_words])
            if len(chunk.split()) >= 10:
                chunks.append(chunk)
        return chunks

    chunked_docs, doc_indices = [], []
    for idx, text in cord19_df['abstract'].dropna().astype(str).items():
        chunks = chunk_text_simple(text)
        chunked_docs.extend(chunks)
        doc_indices.extend([idx] * len(chunks))

    chunked_docs = chunked_docs[:7000]
    doc_indices = doc_indices[:7000]

    keywords_list = kw_model.extract_keywords(
        chunked_docs,
        keyphrase_ngram_range=(1, 2),
        stop_words='english',
        use_mmr=True,
        top_n=3
    )

    doc_keywords = defaultdict(set)
    for doc_id, keywords in zip(doc_indices, keywords_list):
        for kw, _ in keywords:
            doc_keywords[doc_id].add(kw)

    with open(keywords_cache_path, "wb") as f:
        pickle.dump(doc_keywords, f)

# Attach keywords to DataFrame
cord19_df['keywords'] = cord19_df.index.map(lambda idx: list(doc_keywords.get(idx, [])))
cord19_df[['cord_uid', 'keywords']].head()

✅ Loaded cached document keywords.


Unnamed: 0,cord_uid,keywords
0,ug7v899j,"[infections bronchiolitis, cultures respirator..."
1,02tnwd4m,"[diseases respiratory, diseases lung, respirat..."
2,ejv2xln0,"[antigens sp, surfactant lipids, sp increased,..."
3,2b73a28n,"[et 21, endothelin, inflammatory lung]"
4,9785vg6d,"[display reverse, mucosal responses, transcrip..."


In [None]:
def is_relevant(text):
    return any(k in str(text).lower() for k in ALL_KEYWORDS)

cord19_df['is_relevant'] = cord19_df['abstract'].apply(is_relevant)
relevant_df = cord19_df[cord19_df['is_relevant'] & cord19_df['keywords'].notna()]
keywords_flat = list(itertools.chain.from_iterable(relevant_df['keywords'].tolist()))
keywords_text = " ".join(keywords_flat)

wordcloud = WordCloud(
    width=1000, height=500,
    background_color='white',
    max_words=150,
    collocations=False
).generate(keywords_text)

plt.figure(figsize=(14, 7))
plt.imshow(wordcloud, interpolation='bilinear')
plt.axis("off")
plt.title("WordCloud of Extracted Keywords from Relevant Papers")
plt.tight_layout()
plt.show()

keyword_counts = Counter(keywords_flat)
keywords_df = pd.DataFrame(keyword_counts.most_common(30), columns=['keyword', 'frequency'])

plt.figure(figsize=(12, 8))
sns.barplot(x='frequency', y='keyword', data=keywords_df)
plt.title("Top 30 Keyword Frequencies in Relevant Papers")
plt.xlabel("Frequency")
plt.ylabel("Keyword")
plt.tight_layout()
plt.show()


In [None]:
cord19_df['keywords'] = cord19_df.index.map(lambda idx: list(doc_keywords.get(idx, [])))

def is_relevant(text):
    text = str(text).lower()
    return any(k in text for k in ALL_KEYWORDS)

cord19_df['is_relevant'] = cord19_df['abstract'].astype(str).apply(is_relevant)


In [None]:
from transformers import AutoTokenizer
from sentence_transformers import SentenceTransformer
import pandas as pd
from tqdm import tqdm

# Load tokenizer + embedding model (use same model for consistency)
tokenizer = AutoTokenizer.from_pretrained("sentence-transformers/all-MiniLM-L6-v2")
model = SentenceTransformer("all-MiniLM-L6-v2")


In [None]:
def chunk_text(text, max_tokens=512, overlap=50):
    if not isinstance(text, str):
        return []

    tokens = tokenizer.tokenize(text)
    chunks = []
    start = 0

    while start < len(tokens):
        end = min(start + max_tokens, len(tokens))
        chunk_tokens = tokens[start:end]
        chunk_text = tokenizer.convert_tokens_to_string(chunk_tokens)
        chunks.append(chunk_text.strip())
        start += max_tokens - overlap  # move forward with overlap

    return chunks


In [None]:
source_df = cord19_df[cord19_df['is_relevant']].copy()

chunked_docs, doc_ids = [], []

for idx, row in tqdm(source_df.iterrows(), total=len(source_df)):
    abstract = row.get("abstract", "")
    chunks = chunk_text(abstract)
    chunked_docs.extend(chunks)
    doc_ids.extend([idx] * len(chunks))


In [None]:
abstract_embeddings = model.encode(
    chunked_docs,
    batch_size=64,
    show_progress_bar=True,
    convert_to_tensor=True,
    truncate=True  # ✅ This prevents index overflow errors
)
print("Embedding shape:", abstract_embeddings.shape)

In [None]:
print("Example original abstract:")
print(source_df.loc[doc_ids[0], "abstract"][:500])

print("\nExample chunked version:")
print(chunked_docs[0])

print("\nEmbedding shape:", abstract_embeddings.shape)


In [None]:
import pickle

# Save to disk (Google Drive)
with open(os.path.join(PROJECT_DIR, "chunked_docs.pkl"), "wb") as f:
    pickle.dump(chunked_docs, f)

with open(os.path.join(PROJECT_DIR, "doc_ids.pkl"), "wb") as f:
    pickle.dump(doc_ids, f)


In [None]:
import torch

# Save tensor as binary file
torch.save(abstract_embeddings, os.path.join(PROJECT_DIR, "abstract_embeddings.pt"))


In [None]:

chunked_docs_path = os.path.join(PROJECT_DIR, "chunked_docs.pkl")
doc_ids_path = os.path.join(PROJECT_DIR, "doc_ids.pkl")
embedding_path = os.path.join(PROJECT_DIR, "abstract_embeddings.pt")
source_df_path = os.path.join(PROJECT_DIR, "source_df_for_chunks.csv")

In [None]:
if all(os.path.exists(p) for p in [chunked_docs_path, doc_ids_path, embedding_path, source_df_path]):
    print("✅ Loaded chunked data from cache.")

    with open(chunked_docs_path, "rb") as f:
        chunked_docs = pickle.load(f)

    with open(doc_ids_path, "rb") as f:
        doc_ids = pickle.load(f)

    abstract_embeddings = torch.load(embedding_path)
    source_df = pd.read_csv(source_df_path)

else:
    print("🚀 Starting chunking + embedding...")

    source_df = cord19_df[cord19_df["is_relevant"] & cord19_df["abstract"].notna()].copy()
    source_df.reset_index(drop=True, inplace=True)

    source_df.to_csv(source_df_path, index=False)

    tokenizer = AutoTokenizer.from_pretrained("sentence-transformers/all-MiniLM-L6-v2")

    def chunk_text(text, max_tokens=512, overlap=50):
        tokens = tokenizer.tokenize(str(text))
        chunks, start = [], 0
        while start < len(tokens):
            end = min(start + max_tokens, len(tokens))
            chunk = tokenizer.convert_tokens_to_string(tokens[start:end]).strip()
            chunks.append(chunk)
            start += max_tokens - overlap
        return chunks

    chunked_docs, doc_ids = [], []

    for idx, row in tqdm(source_df.iterrows(), total=len(source_df)):
        chunks = chunk_text(row["abstract"])
        chunked_docs.extend(chunks)
        doc_ids.extend([idx] * len(chunks))

    model = SentenceTransformer("all-MiniLM-L6-v2")
    abstract_embeddings = model.encode(
        chunked_docs,
        batch_size=64,
        show_progress_bar=True,
        convert_to_tensor=True,
        truncate=True
    )

    with open(chunked_docs_path, "wb") as f:
        pickle.dump(chunked_docs, f)

    with open(doc_ids_path, "wb") as f:
        pickle.dump(doc_ids, f)

    torch.save(abstract_embeddings, embedding_path)

    print("✅ Saved all embeddings and chunked data.")


# Analyze external dataset mentions among relevant keyword-filtered papers

Understand how many keyword-relevant papers (pre-topic modeling) are referencing datasets.

In [None]:
# ✅ Use the correct column name from current pipeline
relevant_papers = cord19_df[cord19_df['is_relevant']]

# ✅ Match against Kaggle dataset-linked papers
linked = df_meta_cord19['cord_uid'].isin(relevant_papers['cord_uid'])
linked_count = linked.sum()

print(f"🔍 {linked_count} out of {len(relevant_papers)} relevant papers reference external datasets.")


#**Embedding + UMAP**

To visualize how documents differ between filtered and removed abstracts using dimensionality reduction (UMAP).

#**BERTopic Topic Modeling**

This is for semantic structuring and analysis of the filtered corpus. It helps group related documents by themes.

In [None]:
from bertopic import BERTopic

filtered_df = cord19_df[cord19_df['abstract'].notna()].copy()
subset_df = filtered_df.sample(n=7000, random_state=42)

embedding_model = SentenceTransformer("all-MiniLM-L6-v2")
bertopic_model = BERTopic(language="english", verbose=True, embedding_model=embedding_model)

topics, _ = bertopic_model.fit_transform(subset_df['abstract'].astype(str).tolist())

filtered_df["topic"] = -1
filtered_df.loc[subset_df.index, "topic"] = topics

# Topic sizes
topic_sizes = filtered_df["topic"].value_counts()
print("✅ Topic distribution:")
print(topic_sizes.head())


# Post-topic modeling: Analyze external dataset mentions per topic

Understand which topics are associated with dataset-linked papers, and how strong the signal is.

In [None]:

from collections import defaultdict, Counter

In [None]:
# ✅ Already applied above:
# filtered_df["topic"] = -1
# filtered_df.loc[subset_df.index, "topic"] = topics

# ✅ Create set of known dataset-linked cord_uids
referenced_uids = set(df_meta_cord19["cord_uid"])

# ✅ Mark which papers reference datasets
filtered_df["has_dataset_reference"] = filtered_df["cord_uid"].isin(referenced_uids)

# ✅ Group by topic
topic_to_corduids = defaultdict(list)
for index, row in filtered_df.iterrows():
    topic = row['topic']
    if topic != -1:
        topic_to_corduids[topic].append(row['cord_uid'])

# ✅ Count dataset links per topic
external_ref_counts = {
    topic: df_meta_cord19['cord_uid'].isin(uids).sum()
    for topic, uids in topic_to_corduids.items()
}

# ✅ Count documents per topic
topic_sizes = filtered_df['topic'].value_counts().to_dict()
topic_sizes.pop(-1, None)  # remove unassigned

# ✅ Combine size + dataset linkage summary
summary = []
for topic in sorted(topic_to_corduids.keys()):
    size = topic_sizes.get(topic, 0)
    dataset_refs = external_ref_counts.get(topic, 0)
    summary.append((topic, size, dataset_refs))

# ✅ Report: Only topics with 10+ papers
MIN_TOPIC_SIZE = 10
print("📊 Topics with dataset-linked paper counts:\n")
for topic, size, refs in summary:
    if size >= MIN_TOPIC_SIZE:
        print(f"Topic {topic:4d} | Size: {size:4d} | Dataset-linked papers: {refs}")


#Vector DB with LlamaIndex

In [None]:
# core LlamaIndex
!pip install llama-index

# the HuggingFace embeddings integration
!pip install llama-index-embeddings-huggingface


In [None]:

from llama_index.core import Settings, VectorStoreIndex, SimpleDirectoryReader, ServiceContext, Document
from llama_index.embeddings.huggingface import HuggingFaceEmbedding
import os


In [None]:
try:
    row = source_df.iloc[doc_ids[501689]]
    print("✅ Aligned.")
except Exception as e:
    print("❌ Misaligned:", e)


In [None]:
!pip install -q joblib
!pip install -q tqdm



In [None]:
from llama_index.core import VectorStoreIndex, Document
from llama_index.embeddings.huggingface import HuggingFaceEmbedding
from llama_index.core.settings import Settings
import os
import pickle
import torch
import pandas as pd
from joblib import Parallel, delayed
from tqdm import tqdm
from llama_index.core.node_parser import SimpleNodeParser
# Paths
vector_index_path = os.path.join(PROJECT_DIR, "vector_index_sample_5000new")

# Load chunked docs and metadata
with open(os.path.join(PROJECT_DIR, "chunked_docs.pkl"), "rb") as f:
    chunked_docs = pickle.load(f)

with open(os.path.join(PROJECT_DIR, "doc_ids.pkl"), "rb") as f:
    doc_ids = pickle.load(f)

source_df = pd.read_csv(os.path.join(PROJECT_DIR, "source_df_for_chunks.csv"))

# Truncate to 5,000 sample
chunked_docs = chunked_docs[:5000]
doc_ids = doc_ids[:5000]
source_rows = source_df.to_dict("records")

# ✅ Set embed model BEFORE index creation
embed_model = HuggingFaceEmbedding(model_name="all-MiniLM-L6-v2", device="cuda", max_length=512, show_progress_bar=True)
Settings.embed_model = embed_model
Settings.node_parser = SimpleNodeParser.from_defaults()
# Optional flags
dataset_linked_uids = set()
source_df["is_dataset_linked"] = source_df["cord_uid"].isin(dataset_linked_uids)
source_df["is_smoking_keyword"] = source_df["abstract"].apply(
    lambda x: any(kw.lower() in str(x).lower() for kw in SMOKING_KEYWORDS)
)

def build_document(chunk, doc_id):
    row = source_rows[doc_id]
    metadata = {
        "cord_uid": str(row.get("cord_uid")),
        "title": str(row.get("title", ""))[:200],
        "keywords": list(row.get("keywords", []))[:5],
        "is_relevant": bool(row.get("is_relevant", False)),
        "is_smoking_keyword": bool(row.get("is_smoking_keyword", False)),
        "is_dataset_linked": bool(row.get("is_dataset_linked", False)),
        "topic": int(row.get("topic", -1)) if "topic" in row else -1
    }
    return Document(text=chunk, metadata=metadata)  # ✅ Allow LlamaIndex to embed

# Build docs in parallel
documents = Parallel(n_jobs=-1, backend="threading")(
    delayed(build_document)(chunk, doc_id)
    for chunk, doc_id in tqdm(zip(chunked_docs, doc_ids), total=len(chunked_docs), desc="📄 Creating Documents")
)

# ⚠️ Remove old index if it exists
if os.path.exists(vector_index_path):
    import shutil
    shutil.rmtree(vector_index_path)

# ✅ Build new index with embedded docs
#service_context = ServiceContext.from_defaults(embed_model=embed_model)
index = VectorStoreIndex.from_documents(documents, embed_model=embed_model, show_progress=True)
index.storage_context.persist(persist_dir=vector_index_path)
print("✅ Sample index rebuilt with embedded documents.")


In [None]:
print(documents[0])


In [None]:
print(Settings.embed_model)

#QA-Agent

In [None]:
!pip install llama-index llama-index-embeddings-huggingface llama-index-llms-huggingface bitsandbytes

In [None]:
pip install --upgrade "llama-index>=0.6.10" llama-index-embeddings-huggingface


In [None]:
# prompt: install dependency for llama_index.retrievers

!pip install llama-index-retrievers-bm25

In [None]:
from huggingface_hub import login
login(token="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")  #  actual token

In [None]:
# Load model directly
from transformers import AutoTokenizer, AutoModelForCausalLM

tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-3.2-1B-Instruct")
#model = AutoModelForCausalLM.from_pretrained("meta-llama/Llama-3.2-1B")

In [None]:
# === Load the Meta LLM via HuggingFace ===
from llama_index.llms.huggingface import HuggingFaceLLM

llm = HuggingFaceLLM(
   model_name="meta-llama/Llama-3.2-1B-Instruct",
    tokenizer=tokenizer,
    context_window=4096,
    max_new_tokens=512,
)

In [None]:
# === Import Required Modules ===
from llama_index.core import Settings, Document, StorageContext, VectorStoreIndex, ServiceContext, load_index_from_storage
from llama_index.embeddings.huggingface import HuggingFaceEmbedding
from llama_index.llms.huggingface import HuggingFaceLLM
from llama_index.core.chat_engine import CondenseQuestionChatEngine
from llama_index.core.memory import ChatMemoryBuffer


In [None]:
from llama_index.core.retrievers import BaseRetriever
from sentence_transformers import util
import torch

class SemanticBoostRetriever(BaseRetriever):
    """
    A retriever that:
      1. Retrieves candidates from base retriever.
      2. Embeds the query via embed_model.get_query_embedding(...)
      3. Computes cosine similarity with each node.
      4. If node.embedding is missing, computes it on the fly.
      5. Boosts dataset-linked documents.
      6. Optionally filters to smoking or topic.
    """

    def __init__(
        self,
        base_retriever: BaseRetriever,
        embed_model,
        similarity_threshold: float = 0.3,
        boost_factor: float = 2.0,
        filter_smoking: bool = False,
        topic_filter: int = None,
    ):
        super().__init__()
        self.retriever = base_retriever
        self.embed_model = embed_model
        self.similarity_threshold = similarity_threshold
        self.boost_factor = boost_factor
        self.filter_smoking = filter_smoking
        self.topic_filter = topic_filter

        # Safety checks
        if isinstance(embed_model, list):
            raise ValueError("Expected a single embed_model, but got a list.")

        if not hasattr(embed_model, "get_query_embedding"):
            raise ValueError("Embed model must support .get_query_embedding(...)")

    def _retrieve(self, query_bundle):
        query_str = query_bundle.query_str

        # Step 1: Get candidates from base retriever
        candidate_nodes = self.retriever.retrieve(query_bundle) or []

        # Step 2: Embed the query
        query_vector = self.embed_model.get_query_embedding(query_str)
        query_tensor = torch.tensor(query_vector).unsqueeze(0)  # (1, D)

        results = []

        for node in candidate_nodes:
            # Step 3: Fallback embedding if not already embedded
            if node.embedding is None:
                try:
                    node.embedding = self.embed_model.get_text_embedding(node.text)
                except Exception as e:
                    print(f"⚠️ Failed to embed node {node.node_id}: {e}")
                    continue

            node_tensor = torch.tensor(node.embedding).unsqueeze(0)  # (1, D)
            sim_score = util.cos_sim(query_tensor, node_tensor).item()

            # Step 4: Filter by similarity threshold
            if sim_score < self.similarity_threshold:
                continue

            # Step 5: Optional filter for smoking-related
            if self.filter_smoking and not node.metadata.get("is_smoking_keyword", False):
                continue

            # Step 6: Optional topic filter
            if self.topic_filter is not None:
                if node.metadata.get("topic") != self.topic_filter:
                    continue

            # Step 7: Boost score if linked to dataset
            if node.metadata.get("is_dataset_linked", False):
                node.score *= self.boost_factor

            results.append(node)

        if not results:
            return candidate_nodes[:10]  # fallback

        results.sort(key=lambda n: n.score, reverse=True)
        return results[:10]


In [None]:
from llama_index.core.indices.prompt_helper import PromptHelper

Settings.llm = llm
Settings.embed_model = embed_model
# Add PromptHelper to safely manage prompt sizes
# Reason: Prevent context overflow errors by automatically adjusting the size of prompts and chunks.
# It ensures that:
#   - Total tokens (input + output) stay within model limits (e.g. 4096)
#   - Long documents are chunked efficiently (e.g. overlapping content preserved)
#   - Helps maintain performance and avoids runtime crashes
Settings.prompt_helper = PromptHelper(
    context_window=4096,     # Max context tokens LLM supports
             # Expected max response size from the LLM
    #chunk_overlap_ratio=0.1, # Overlap between chunks to retain context across boundaries
    tokenizer=tokenizer  # Use tokenizer of the loaded LLM
)

In [None]:
PROJECT_DIR = "/content/drive/MyDrive/Colab Notebooks/ANLP_Project"
index_path = os.path.join(PROJECT_DIR, "vector_index_sample_5000new")


In [None]:
# 2. Load vector index and apply custom retriever
from llama_index.core import StorageContext, load_index_from_storage

storage_context = StorageContext.from_defaults(persist_dir=index_path)

index = load_index_from_storage(storage_context)

In [None]:
print(index.docstore.docs.values())

In [None]:
# Get the base retriever
base_retriever = index.as_retriever(retriever_mode="embedding", similarity_top_k=50)

retriever = SemanticBoostRetriever(
    base_retriever=base_retriever,
    embed_model=embed_model,
    similarity_threshold=0.6,      # 🔧 Tune based on quality
    boost_factor=2.0,              # 🚀 Boost dataset-linked docs
    filter_smoking=False,          # 🔥 Optional: focus on smoking topics
    topic_filter= None              # 🎯 Optional: filter to a specific topic ID
)

In [None]:
from llama_index.core.query_engine import RetrieverQueryEngine
from llama_index.core.retrievers import BaseRetriever
from llama_index.core.chat_engine import ContextChatEngine
from llama_index.core.memory import ChatMemoryBuffer
from llama_index.llms.huggingface import HuggingFaceLLM

In [None]:
# 4. Build chat engine with system prompt + memory + boosted retriever
system_prompt = (
    "You are a biomedical assistant trained to answer questions about COVID-19, smoking, "
    "and the correlation between the two. Your answers are based strictly on scientific abstracts. "
    "Do not give general health advice. Only respond using information retrieved from research data."
)

query_engine = RetrieverQueryEngine.from_args(retriever=retriever, llm=llm)
chat_engine = ContextChatEngine.from_defaults(
    retriever=retriever,
    llm=llm,
    memory=ChatMemoryBuffer.from_defaults(token_limit=32000),
    system_prompt=system_prompt
)

In [None]:
import torch
gpu_name = torch.cuda.get_device_name(0) if torch.cuda.is_available() else "Keine GPU"
total_mem = torch.cuda.get_device_properties(0).total_memory / (1024**3) if torch.cuda.is_available() else 0
print(f"GPU: {gpu_name} mit insgesamt {total_mem:.1f} GiB VRAM")


In [None]:
import torch

# Irgendwo vor der Textgenerierung
torch.cuda.empty_cache()


In [None]:
# ────────────────────────────────────────────────────────────────────────────────
# 1) Two‐Stage Reranker + Reader Helper (in‐notebook)
# ────────────────────────────────────────────────────────────────────────────────

from sentence_transformers import SentenceTransformer, util
import torch
from llama_index.core import QueryBundle

# Load a pretrained bi‐encoder (only once)
reranker = SentenceTransformer("all-MiniLM-L6-v2")

def two_stage_answer(query: str, top_k_retrieve: int = 50, top_k_rerank: int = 5) -> str:
    """
    Stage 1: retriever.retrieve(...) → top_k_retrieve candidates
    Stage 2: rerank those candidates with Sentence‐Transformer; keep top_k_rerank
    Stage 3: feed top passages into chat_engine.stream_chat(...) and return final answer
    """
    # ── Stage 1: Semantic‐boost retrieval
    qb = QueryBundle(query_str=query)
    candidates = retriever.retrieve(qb) or []
    candidates = candidates[:top_k_retrieve]
    if not candidates:
        return "No passages retrieved."

    # Extract raw texts
    docs = [node.node.get_text() for node in candidates]

    # ── Stage 2: embed query + docs, compute cosine scores
    query_emb = reranker.encode(query, convert_to_tensor=True)
    docs_emb  = reranker.encode(docs,  convert_to_tensor=True)
    cosine_scores = util.cos_sim(query_emb, docs_emb)[0]  # shape: (top_k_retrieve,)

    # Pick the indices of the top_k_rerank documents
    topk_idxs = torch.topk(
        cosine_scores,
        k=min(top_k_rerank, len(docs)),
        largest=True
    ).indices.tolist()

    # Build the reranked context
    reranked_passages = [docs[i] for i in topk_idxs]
    context_text = "\n\n".join(reranked_passages)

    # ── Stage 3: Ask the LLM with an extractive prompt
    prompt = (
        "You are a biomedical QA assistant. Find the single sentence that answers the question "
        "and output that sentence verbatim. If no answer is present, say \"No answer in retrieved passages.\".\n\n"
        f"Passages:\n{context_text}\n\n"
        f"Question: {query}\n"
        "Exact sentence (copy verbatim):"
    )

    response = chat_engine.stream_chat(prompt)
    answer = "".join([tok for tok in response.response_gen]).strip()
    return answer

# ────────────────────────────────────────────────────────────────────────────────
# 2) Test it right here in the notebook
# ────────────────────────────────────────────────────────────────────────────────
test_q = "How does nicotine affect ACE2 expression in lung tissue?"
print("Two‐Stage Answer:", two_stage_answer(test_q, top_k_retrieve=50, top_k_rerank=5))


In [None]:
while True:
    query = input(">>> ")
    if query.lower().strip() == "quit":
        break

    print("Agent: ", end="", flush=True)
    # call our two‐stage pipeline instead of direct stream_chat
    answer = two_stage_answer(query, top_k_retrieve=50, top_k_rerank=5)
    print(answer)
    chat_engine.reset()


In [None]:
while True:
    query = input("> ")
    if query.lower() == "quit":
        break

    # Streamelt válasz
    print("Agent: ", end="", flush=True)
    response = chat_engine.stream_chat(query)
    for token in response.response_gen:
        print(token, end="", flush=True)
    print()


chat_engine.reset()


In [None]:
# prompt: implement validation methods to test the agent model performance

import pandas as pd
# Define validation methods
def evaluate_response(query: str, expected_answer: str, agent_response: str) -> dict:
    """Evaluates a single response based on expected answer and keywords."""
    evaluation = {
        "query": query,
        "expected_answer": expected_answer,
        "agent_response": agent_response,
        "contains_expected_keywords": False,
        "is_relevant_to_query": False, # Requires manual review or more sophisticated NLP
        "uses_retrieved_context": False # Requires analyzing source nodes
    }

    # Simple keyword check for relevance
    expected_keywords = set(expected_answer.lower().split())
    response_words = set(agent_response.lower().split())
    common_keywords = expected_keywords.intersection(response_words)

    if len(common_keywords) > 0:
        evaluation["contains_expected_keywords"] = True

    # Placeholder for more advanced checks
    # For 'is_relevant_to_query' and 'uses_retrieved_context', you'd need more complex NLP
    # or access to the source nodes the agent used.

    return evaluation

def run_evaluation_suite(queries_and_answers: list):
    """Runs a suite of queries through the agent and evaluates responses."""
    results = []
    print("\n--- Running Agent Evaluation Suite ---")
    for query, expected_answer in tqdm(queries_and_answers, desc="Evaluating Agent"):
        try:
            # Get response from the agent
            # Assuming chat_engine is accessible from the previous code
            response_stream = chat_engine.stream_chat(query)
            agent_response_tokens = [token for token in response_stream.response_gen]
            agent_response = "".join(agent_response_tokens).strip()

            # Evaluate the response
            evaluation_result = evaluate_response(query, expected_answer, agent_response)
            results.append(evaluation_result)

        except Exception as e:
            print(f"\nError processing query '{query}': {e}")
            results.append({
                "query": query,
                "expected_answer": expected_answer,
                "agent_response": f"ERROR: {e}",
                "contains_expected_keywords": False,
                "is_relevant_to_query": False,
                "uses_retrieved_context": False
            })

    print("--- Evaluation Complete ---")
    return pd.DataFrame(results)

# --- Example Usage ---

# Define a list of queries and their expected answers for testing
# These should ideally be derived from your dataset or knowledge base
test_cases = [
    ("What is the link between smoking and COVID-19 severity?", "Smoking is associated with increased severity of COVID-19."),
    ("Does vaping increase the risk of COVID-19 infection?", "Some studies suggest vaping may be a risk factor, but more research is needed."),
    ("Are former smokers at higher risk for severe COVID-19?", "Former smokers might have residual lung damage affecting COVID-19 outcomes."),
    ("Mention a study related to tobacco use and coronavirus.", "Look for papers discussing SARS-CoV-2 and smoking habits."),
    ("What datasets are linked to papers on COVID-19 and smoking?", "Find papers in the dataset-linked category that also mention smoking keywords.")
]

# Run the evaluation suite
evaluation_df = run_evaluation_suite(test_cases)

# Display results
print("\nEvaluation Results:")
display(evaluation_df)

# Analyze results
print("\nSummary Statistics:")
print(evaluation_df["contains_expected_keywords"].value_counts(normalize=True))

# You would typically add more sophisticated metrics:
# - ROUGE scores (for recall/precision of n-grams compared to expected answer)
# - Semantic similarity between agent response and expected answer (e.g., using another embedding model)
# - Manual review for relevance and factual correctness
# - Checks if the response cites or is based on the retrieved nodes (requires modifying the agent or analyzing its output structure)


In [None]:
# ──────────────────────────────────────────────────────────────────────────
# Two‐Stage Reranking + Reader Example
# ──────────────────────────────────────────────────────────────────────────

from sentence_transformers import SentenceTransformer, util
from llama_index.core import QueryBundle

# 1) Load a pretrained bi‐encoder as reranker
#    You can swap in any Sentence‐Transformer model (e.g. "all-mpnet-base-v2").
reranker = SentenceTransformer("all-MiniLM-L6-v2")

# 2) A helper function to do reranking on top_K candidates
def rerank_and_read(query: str, top_k_retrieve: int = 10, top_k_rerank: int = 3):
    """
    1) Retrieve top_k_retrieve passages via retriever.
    2) Embed query + passages with `reranker`, compute cosine scores.
    3) Take top top_k_rerank passages, join them into a single context.
    4) Call chat_engine.predict(...) on that concatenated context+query.
    """

    # ── Stage 1: retrieve top_k_retrieve candidates ─────────────────────────────
    qb = QueryBundle(query_str=query)
    candidates = retriever.retrieve(qb) or []
    candidates = candidates[:top_k_retrieve]

    if len(candidates) == 0:
        return "No passages retrieved."

    # Extract their texts
    docs = [node.node.get_text() for node in candidates]

    # ── Stage 2: rerank those K by cosine similarity ────────────────────────────
    #    Embed the query once, embed all K document texts, compute sim scores in batch.
    query_emb = reranker.encode(query, convert_to_tensor=True)
    docs_emb  = reranker.encode(docs,  convert_to_tensor=True)
    # cosine_sim is a 1×K tensor
    cosine_scores = util.cos_sim(query_emb, docs_emb)[0]  # shape: (K,)

    # Sort indices by descending score, pick top_k_rerank
    topk_idxs = torch.topk(cosine_scores, k=min(top_k_rerank, len(docs)), largest=True).indices
    topk_idxs = topk_idxs.tolist()

    # Build the reranked context_text by concatenating the chosen passages
    reranked_passages = [docs[i] for i in topk_idxs]
    context_text = "\n\n".join(reranked_passages)

    # ── Stage 3: feed reranked context into your LLM reader ────────────────────
    prompt = (
        "You are a biomedical QA assistant. Use ONLY the passages below to answer EXACTLY. "
        "If no answer is present, say \"No answer in retrieved passages.\".\n\n"
        f"Passages:\n{context_text}\n\n"
        f"Question: {query}\n"
        "Answer:"
    )

    # Use chat_engine.predict(...) so you get the full string
        # Use stream_chat instead of predict()
    response = chat_engine.stream_chat(prompt)
    answer = "".join([token for token in response.response_gen]).strip()

    return answer

# 3) Example usage: Evaluate a single question
q = "How does nicotine affect ACE2 expression in lung tissue?"
print("Agent’s answer:", rerank_and_read(q, top_k_retrieve=10, top_k_rerank=3))


# ──────────────────────────────────────────────────────────────────────────
# (Optional) Batch‐evaluate on test_set (token‐F1 or human grade)
# ──────────────────────────────────────────────────────────────────────────

import json, re
from collections import Counter

# Load a small test list of {question, reference} for fuzzy‐match evaluation.
eval_questions = [
    {"question": "Which receptor does SARS-CoV-2 use to enter lung cells?",
     "reference": "ACE2 receptor"},
    {"question": "What is the primary carcinogen in cigarette smoke?",
     "reference": "Benzopyrene"},
    # … add 20–30 more …
]

def normalize(text: str) -> str:
    text = text.lower()
    text = re.sub(r"[^a-z0-9\s]", "", text)
    return text.strip()

def token_f1(pred: str, ref: str) -> float:
    ptoks = normalize(pred).split()
    rtoks = normalize(ref).split()
    common = Counter(ptoks) & Counter(rtoks)
    num_same = sum(common.values())
    if num_same == 0:
        return 0.0
    p = num_same / len(ptoks)
    r = num_same / len(rtoks)
    return 2 * (p*r) / (p + r)

scores = []
for entry in eval_questions:
    q_text = entry["question"]
    ref    = entry["reference"]
    pred   = rerank_and_read(q_text, top_k_retrieve=10, top_k_rerank=3)
    scores.append(token_f1(pred, ref))

print("Avg token‐F1 on eval set:", sum(scores) / len(scores))


In [None]:
# prompt: install rouge

!pip install rouge --quiet

In [None]:
# prompt: Use Python in Colab to evaluate the performance of a question-answering agent. Visualize the following metrics over a test set of question-answer pairs

import pandas as pd
import matplotlib.pyplot as plt
from sklearn.metrics import accuracy_score, precision_score, recall_score, f1_score
from rouge import Rouge

# --- Evaluation ---
# Create a test set of question-answer pairs (manual creation or from a benchmark)
test_set = [
    ("What is the relationship between smoking and severe COVID-19 outcomes?", "Smoking is associated with increased severity and mortality in COVID-19 patients."),
    ("Does smoking increase the risk of contracting COVID-19?", "Current evidence suggests smoking may increase susceptibility to SARS-CoV-2 infection."),
    ("How does nicotine affect the ACE2 receptor?", "Research suggests nicotine may upregulate ACE2 receptor expression, which could potentially increase SARS-CoV-2 entry."),
    ("Are e-cigarettes linked to severe COVID-19?", "Studies indicate that e-cigarette use may also be associated with increased risk of severe COVID-19."),
    ("What are the long-term respiratory effects of smoking after COVID-19?", "Smoking may exacerbate long-term respiratory complications following COVID-19.")
]

# Store agent responses
agent_responses = []
for question, _ in test_set:
    # Get response from the chat engine
    response = chat_engine.chat(question) # Use chat() for non-streaming in eval
    agent_responses.append(str(response))

# Prepare for evaluation
true_answers = [ans for _, ans in test_set]

# 1. ROUGE Score (measures overlap between agent and true answers)
rouge = Rouge()
rouge_scores = rouge.get_scores(agent_responses, true_answers, avg=True)
print("\nROUGE Scores:")
print(rouge_scores)

# Note: For a true QA evaluation, you would need a way to score answers as correct/incorrect,
# which usually requires human annotation or a carefully designed scoring system.
# The metrics below are illustrative and would require binary labels (e.g., is the answer correct?).
# Since we don't have binary labels for this task, we'll skip accuracy, precision, recall, F1.

# --- Visualization ---

# Function to visualize ROUGE scores
def plot_rouge_scores(rouge_scores):
    rouge_types = list(rouge_scores.keys())
    metrics = ['r', 'p', 'f'] # recall, precision, f-score
    metric_labels = {'r': 'Recall', 'p': 'Precision', 'f': 'F1-Score'}

    data = []
    for rt in rouge_types:
        for met in metrics:
            data.append({'ROUGE Type': rt.upper(), 'Metric': metric_labels[met], 'Score': rouge_scores[rt][met]})

    df_rouge = pd.DataFrame(data)

    plt.figure(figsize=(10, 6))
    sns.barplot(data=df_rouge, x='Metric', y='Score', hue='ROUGE Type')
    plt.title('ROUGE Scores for QA Agent')
    plt.ylabel('Score')
    plt.ylim(0, 1)
    plt.show()

plot_rouge_scores(rouge_scores)

# Note: More comprehensive QA evaluation often involves metrics like F1 score on extracted answers (like in SQuAD)
# or using an independent QA model to score the generated answers against references.
# This ROUGE-based approach gives a measure of semantic overlap, which is a good start.


In [None]:
!pip install bert_score  --quiet

In [None]:
# Step 1: Install Hugging Face evaluation package
pip install git+https://github.com/google-research/bleurt.git

In [None]:
# prompt: Evaluate With Semantic Similarity
#     compute  BERTScore or BLEURT to capture semantic match.

import matplotlib.pyplot as plt

from datasets import load_metric

# 2. BERTScore
# Load the BERTScore metric
bertscore = load_metric("bertscore")

# Compute BERTScore
bertscore_results = bertscore.compute(
    predictions=agent_responses,
    references=true_answers,
    lang="en",
    model_type="bert-base-uncased" # Can try other models like 'roberta-base'
)

print("\nBERTScore Results:")
# BERTScore provides precision, recall, and F1 for each prediction-reference pair.
# Let's print the average F1 score as a single metric.
print(f"Average BERTScore F1: {sum(bertscore_results['f1']) / len(bertscore_results['f1'])}")
# You can also inspect individual scores:
# print("Individual BERTScore F1s:", bertscore_results['f1'])


# 3. BLEURT (Requires checkpoint download, can be slow first time)
# BLEURT is typically used for evaluating machine translation, but can also be used for
# evaluating text generation where a reference is available.
# It requires a pre-trained checkpoint. You can specify one, or it will try to download a default.

# try:
#     bleurt = load_metric("bleurt", checkpoint="bleurt-base-512")
#     bleurt_results = bleurt.compute(predictions=agent_responses, references=true_answers)

#     print("\nBLEURT Results:")
#     # BLEURT provides a score for each prediction-reference pair.
#     # Let's print the average score.
#     print(f"Average BLEURT Score: {sum(bleurt_results['scores']) / len(bleurt_results['scores'])}")
#     # print("Individual BLEURT Scores:", bleurt_results['scores'])

# except Exception as e:
#     print(f"\nCould not compute BLEURT: {e}")
#     print("BLEURT requires downloading a checkpoint which might fail or take time.")
#     print("Consider running this part separately or handling potential errors.")
#     print("Check BLEURT documentation for available checkpoints and installation steps.")
#     # Example of handling: if checkpoint fails, you might just skip BLEURT or use a local path if pre-downloaded.

# --- Visualization for BERTScore F1 ---

def plot_bertscore_f1(bertscore_results):
    f1_scores = bertscore_results['f1']
    questions = [f"Q{i+1}" for i in range(len(f1_scores))]

    plt.figure(figsize=(10, 6))
    sns.barplot(x=questions, y=f1_scores)
    plt.title('BERTScore F1 for each QA Pair')
    plt.ylabel('BERTScore F1')
    plt.ylim(0, 1)
    plt.show()

plot_bertscore_f1(bertscore_results)


# Note: Evaluating LLM responses for factual correctness is complex.
# Semantic similarity metrics like BERTScore and BLEURT give a sense of how similar the *meaning*
# is between the generated answer and the reference answer, which is better than simple word overlap (like ROUGE)
# for capturing paraphrasing or different wording that conveys the same information.
# However, they don't guarantee factual accuracy. A high score means the answer is semantically close to the *reference*,
# but if the reference is wrong, the score will still be high for a semantically similar wrong answer.
# A truly robust evaluation often involves:
# 1. Fact-checking the agent's response against the source documents or external knowledge.
# 2. Human evaluation for relevance, coherence, completeness, and accuracy.
# 3. Using reference answers that are verified to be correct and comprehensive.

In [None]:
# Create dev_set.json directly
import json

dev_questions = [
    {
        "question": "Which receptor does SARS-CoV-2 use to enter lung cells?",
        "gold_node_ids": ["abc156"],
        "gold_answer": "ACE2 receptor"
    },
    {
        "question": "What is the primary carcinogen in cigarette smoke?",
        "gold_node_ids": ["acf123"],
        "gold_answer": "Benzopyrene"
    },
    {
        "question": "Does active smoking increase the risk of severe COVID-19?",
        "gold_node_ids": ["abc123"],
        "gold_answer": "Yes, active smokers have a higher risk of severe COVID-19."
    },
    {
        "question": "What is the incubation period of COVID-19?",
        "gold_node_ids": ["def456"],
        "gold_answer": "Approximately 2 to 14 days"
    }
]

with open("dev_set.json", "w") as f:
    json.dump(dev_questions, f, indent=2)


In [None]:
# Create test_set.json directly
import json

test_questions = [
    {
        "question": "What are the most common symptoms of COVID-19?",
        "gold_node_ids": ["ghi789"],
        "gold_answer": "Fever, dry cough, and fatigue"
    },
    {
        "question": "Which lung disease is most strongly linked to chronic smoking?",
        "gold_node_ids": ["jkl012"],
        "gold_answer": "Chronic obstructive pulmonary disease"
    },
    {
        "question": "How does nicotine affect ACE2 expression in lung tissue?",
        "gold_node_ids": ["jkl076"],
        "gold_answer": "Nicotine upregulates ACE2 expression in lung epithelial cells"
    },
    {
        "question": "Is there a difference in COVID-19 mortality between smokers and non-smokers?",
        "gold_node_ids": ["jdp012"],
        "gold_answer": "Smokers have higher COVID-19 mortality rates than non-smokers"
    }
]

with open("test_set.json", "w") as f:
    json.dump(test_questions, f, indent=2)


In [None]:
# 1) Overwrite dev_set.json and test_set.json so that each gold_node_ids becomes
#    the top‐1 retrieved node_id for that question (for quick alignment).

import json
from llama_index.core import QueryBundle

# Load current question lists
with open("dev_set.json",  "r") as f:
    dev_questions  = json.load(f)
with open("test_set.json", "r") as f:
    test_questions = json.load(f)

# For each question, retrieve the top‐1 node and replace gold_node_ids with its node_id
def align_gold_ids(retriever, question_list):
    for entry in question_list:
        q_text = entry["question"]
        qb = QueryBundle(query_str=q_text)
        candidates = retriever.retrieve(qb) or []
        if candidates:
            entry["gold_node_ids"] = [candidates[0].node_id]
        else:
            entry["gold_node_ids"] = []  # no candidate found
    return question_list

dev_questions  = align_gold_ids(retriever, dev_questions)
test_questions = align_gold_ids(retriever, test_questions)

# Write back to JSON files
with open("dev_set.json",  "w") as f:
    json.dump(dev_questions, f, indent=2)
with open("test_set.json", "w") as f:
    json.dump(test_questions, f, indent=2)


# 2) Re‐run the evaluation cell now that gold_node_ids have valid IDs.

# ──────────────────────────────────────────────────────────────────────────
# STEP 0: Imports and load dev/test question files
# ──────────────────────────────────────────────────────────────────────────

import torch
from sentence_transformers import util

with open("dev_set.json",  "r") as f:
    dev_questions  = json.load(f)
with open("test_set.json", "r") as f:
    test_questions = json.load(f)

# ──────────────────────────────────────────────────────────────────────────
# STEP 1: Helper to compute Recall@K and MRR@K
# ──────────────────────────────────────────────────────────────────────────

def evaluate_retriever(retriever, question_data, K=10):
    total = len(question_data)
    num_recall_hits = 0
    sum_reciprocal_rank = 0.0

    for entry in question_data:
        q_text  = entry["question"]
        gold_ids = set(entry["gold_node_ids"])

        qb = QueryBundle(query_str=q_text)
        candidates = retriever.retrieve(qb) or []
        topk = candidates[:K]

        found_rank = None
        for rank, node in enumerate(topk, start=1):
            if node.node_id in gold_ids:
                found_rank = rank
                break

        if found_rank is not None:
            num_recall_hits += 1
            sum_reciprocal_rank += 1.0 / found_rank

    recall_at_K = num_recall_hits / total
    mrr_at_K    = sum_reciprocal_rank / total
    return recall_at_K, mrr_at_K

# ──────────────────────────────────────────────────────────────────────────
# STEP 2: Evaluate retriever on dev set
# ──────────────────────────────────────────────────────────────────────────

dev_recall10, dev_mrr10 = evaluate_retriever(retriever, dev_questions,  K=10)
dev_recall20, dev_mrr20 = evaluate_retriever(retriever, dev_questions,  K=20)

print(f"Dev Recall@10 = {dev_recall10:.3f}, MRR@10 = {dev_mrr10:.3f}")
print(f"Dev Recall@20 = {dev_recall20:.3f}, MRR@20 = {dev_mrr20:.3f}")

# ──────────────────────────────────────────────────────────────────────────
# STEP 3: Hyperparameter grid search on dev
# ──────────────────────────────────────────────────────────────────────────

best_combination = None
best_recall = 0.0

for thresh in [0.2, 0.3, 0.4, 0.5, 0.6]:
    for boost in [1.0, 1.2, 1.5, 2.0]:
        retriever.similarity_threshold = thresh
        retriever.boost_factor        = boost
        recall10, _ = evaluate_retriever(retriever, dev_questions, K=10)
        if recall10 > best_recall:
            best_recall      = recall10
            best_combination = (thresh, boost)

print(
    "Best on dev: threshold =", best_combination[0],
    "boost =", best_combination[1],
    "→ Recall@10 =", best_recall
)

best_threshold, best_boost = best_combination
retriever.similarity_threshold = best_threshold
retriever.boost_factor        = best_boost

# ──────────────────────────────────────────────────────────────────────────
# STEP 4: Compute test‐set retrieval metrics
# ──────────────────────────────────────────────────────────────────────────

test_recall10, test_mrr10 = evaluate_retriever(retriever, test_questions, K=10)
test_recall20, test_mrr20 = evaluate_retriever(retriever, test_questions, K=20)

print(f"Test Recall@10 = {test_recall10:.3f}, MRR@10 = {test_mrr10:.3f}")
print(f"Test Recall@20 = {test_recall20:.3f}, MRR@20 = {test_mrr20:.3f}")


In [None]:
from evaluate import load

# Load BLEURT with pre-trained checkpoint
bleurt = load("bleurt", checkpoint="bleurt-base-512")

# Compute BLEURT scores
bleurt_results = bleurt.compute(
    predictions=agent_responses,
    references=true_answers
)

# Display average BLEURT score
average_bleurt = sum(bleurt_results["scores"]) / len(bleurt_results["scores"])
print(f"✅ Average BLEURT Score: {average_bleurt:.4f}")


In [None]:

# === Configure Embedding Model ===
Settings.embed_model = embed_model
Settings.llm = llm

# === Load Stored Index ===
storage_context = StorageContext.from_defaults(persist_dir="/content/ANLP_Project/index")

index = load_index_from_storage(storage_context)

# Use the boosted retriever
retriever = DatasetBoostRetriever(index=index, embed_model=embed_model)

chat_engine = index.as_chat_engine(
    chat_mode="context",  # Use context-based QA mode
    memory=ChatMemoryBuffer.from_defaults(token_limit=32000),  # Large memory for longer convos
    system_prompt=(
        "You are a biomedical assistant trained to answer questions about COVID-19, "
        "smoking, and the correlation between the two. You only answer based on the "
        "context provided through your knowledge base."
    )
)


In [None]:
from llama_index.core.chat_engine import ContextChatEngine
from llama_index.core.query_engine import RetrieverQueryEngine
from llama_index.core.memory import ChatMemoryBuffer

# 1. Setup: Boosted Retriever
retriever = DatasetBoostRetriever(index=index, embed_model=embed_model)

# 2. Build the Chat Engine with the retriever
chat_engine = ContextChatEngine.from_defaults(
    retriever=retriever,  # ✅ Corrected line
    memory=ChatMemoryBuffer.from_defaults(token_limit=32000),
    system_prompt=(
        "You are a biomedical assistant trained to answer questions about COVID-19, "
        "smoking, and the correlation between the two. Only answer based on the "
        "retrieved context from the knowledge base."
    )
)


In [None]:
from llama_index.core.retrievers import VectorIndexRetriever

class DatasetBoostRetriever(VectorIndexRetriever):
    def _postprocess_nodes(self, nodes, query_bundle):
        for node in nodes:
            if node.metadata.get("is_dataset_linked"):
                node.score *= 1.3  # 🔼 Boost score if dataset-linked
        return sorted(nodes, key=lambda n: n.score, reverse=True)
