<a href="https://colab.research.google.com/github/tarupathak30/rag-astronomy-chatbot-/blob/main/bot_astro_exoplanets.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
import os, shutil

# Create the working folder for the dataset (relative to the current directory).
# exist_ok=True makes re-running this cell harmless.
folder_name = "exoplanet_data"
os.makedirs(folder_name, exist_ok=True)

In [None]:
from google.colab import files
# Upload the dataset through Colab; a later cell expects it at /content/exoplanets.json.
files.upload()

In [None]:
# Location of the dataset inside the data folder. The uploaded/moved file is
# named exoplanets.json, so the path must use that name.
json_path = os.path.join(folder_name, "exoplanets.json")

In [None]:
# Move the uploaded dataset into the data folder.
# The destination is spelled out component by component: joining folder_name
# with an absolute path would silently discard folder_name.
source = "/content/exoplanets.json"  # your existing file
destination = os.path.join("/content", folder_name, "exoplanets.json")

if os.path.exists(source):
    os.makedirs(os.path.dirname(destination), exist_ok=True)
    shutil.move(source, destination)
    print("File moved:", destination)
elif os.path.exists(destination):
    # Cell re-run after a successful move: nothing left to do.
    print("File already in place:", destination)
else:
    raise FileNotFoundError(
        f"{source} not found - upload exoplanets.json before running this cell"
    )

In [None]:
!pip install langchain_community faiss-cpu

In [None]:
import os, json, re, math, torch, faiss
import numpy as np
from typing import List, Dict, Any

from transformers import AutoTokenizer, AutoModelForCausalLM
from sentence_transformers import SentenceTransformer

In [None]:
# Model identifiers and runtime settings used by the cells below.
MODEL_ID       = "google/gemma-2-2b-it"
EMBED_MODEL    = "sentence-transformers/all-mpnet-base-v2"
DEVICE         = "cuda" if torch.cuda.is_available() else "cpu"
# NOTE(review): TOP_K is not referenced by the functions visible in this notebook;
# they carry their own top_k defaults.
TOP_K          = 5
DEBUG          = True   # debug flag; NOTE(review): not read by any code visible in this notebook



In [None]:
from huggingface_hub import login

# Interactive Hugging Face authentication.
# NOTE(review): presumably needed because the Gemma weights are gated - confirm.
login()   # It will ask you to paste your token


In [None]:
# Load the Gemma tokenizer and causal LM.
# GPU: half precision with automatic device placement.
# CPU: full precision with no device map.
print("Loading Gemma...")
tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)
if DEVICE == "cuda":
    llm = AutoModelForCausalLM.from_pretrained(
        MODEL_ID,
        torch_dtype=torch.float16,
        device_map="auto",
    )
else:
    llm = AutoModelForCausalLM.from_pretrained(
        MODEL_ID,
        torch_dtype=torch.float32,
        device_map=None,
    )
print("Loaded Gemma on", DEVICE)


In [None]:
# Load the sentence-embedding model onto DEVICE and record its vector size.
print("Loading embedding model...")
embed_model = SentenceTransformer(EMBED_MODEL, device=DEVICE)
embed_dim = embed_model.get_sentence_embedding_dimension()
print("Embedding dimension =", embed_dim)

In [None]:

# ---------- UTILITIES ----------
def planet_to_text(obj):
    """Render a planet dict as newline-separated text.

    The first line is the planet's name ("Unnamed Planet" when the key is
    absent). Every other entry becomes one ``key: value`` line; a dict value
    is flattened one level into ``key: k1=v1, k2=v2``.
    """
    def render(key, value):
        # Nested dicts are expanded one level; anything else is stringified.
        if isinstance(value, dict):
            inner = ", ".join(f"{ik}={iv}" for ik, iv in value.items())
            return f"{key}: " + inner
        return f"{key}: {value}"

    lines = [obj.get("planet_name", "Unnamed Planet")]
    lines.extend(
        render(key, value)
        for key, value in obj.items()
        # The name is already the header line, so skip it unless it is a dict.
        if isinstance(value, dict) or key != "planet_name"
    )
    return "\n".join(lines)


def chunk_text(text, max_len=300):
    """Split ``text`` on whitespace into chunks of at most ``max_len`` words.

    Words inside a chunk are re-joined with single spaces. Empty or
    whitespace-only text yields an empty list. A ``max_len`` below 1 behaves
    like 1 (one word per chunk).
    """
    tokens = text.split()
    step = max(1, max_len)
    return [
        " ".join(tokens[start:start + step])
        for start in range(0, len(tokens), step)
    ]


In [None]:

# ---------- INDEX BUILDING ----------
def build_index(planets):
    """Embed every planet's text chunks and build a cosine-similarity FAISS index.

    Each planet is flattened with ``planet_to_text`` and split with
    ``chunk_text``. All chunks are embedded in a single batched ``encode``
    call, L2-normalised, and added to an inner-product index (the inner
    product of unit vectors is their cosine similarity).

    Args:
        planets: iterable of planet dicts.

    Returns:
        ``(index, meta)`` where ``index`` is a ``faiss.IndexFlatIP`` and
        ``meta["objects"][i]`` is the planet that produced index row ``i``
        (a planet appears once per chunk).

    Raises:
        ValueError: if ``planets`` produces no text chunks.
    """
    chunk_texts = []
    meta_objects = []

    for planet in planets:
        for chunk in chunk_text(planet_to_text(planet)):
            chunk_texts.append(chunk)
            meta_objects.append(planet)

    if not chunk_texts:
        raise ValueError("build_index: no text chunks to index (is the planet list empty?)")

    # One batched call lets the model process chunks together instead of
    # paying per-call overhead for every chunk.
    embeddings = embed_model.encode(chunk_texts, convert_to_numpy=True)
    # FAISS expects a contiguous float32 matrix.
    embeddings = np.ascontiguousarray(embeddings, dtype="float32")

    print("Total chunks:", len(embeddings))
    print("Embedding shape:", embeddings.shape)

    faiss.normalize_L2(embeddings)

    index = faiss.IndexFlatIP(embeddings.shape[1])
    index.add(embeddings)

    meta = {"objects": meta_objects}
    print("FAISS index built:", index.ntotal)
    return index, meta



In [None]:

# ---------- RETRIEVAL ----------
def retrieve_objects(query, index, meta, top_k=5):
    """Return the planets whose chunks are most similar to ``query``.

    Args:
        query: natural-language question.
        index: FAISS index built by ``build_index``.
        meta: ``{"objects": [...]}`` mapping index rows to planet dicts.
        top_k: number of chunk hits to request.

    Returns:
        Planet dicts in rank order, de-duplicated by object identity. This can
        be shorter than ``top_k`` because several chunks may belong to the
        same planet.
    """
    # Never ask for more neighbours than the index holds.
    k = min(int(top_k), index.ntotal)
    if k <= 0:
        return []

    q_emb = embed_model.encode([query], convert_to_numpy=True).astype("float32")
    faiss.normalize_L2(q_emb)

    _, hit_ids = index.search(q_emb, k)

    objects = meta["objects"]
    results = []
    seen = set()

    for idx in hit_ids[0]:
        # FAISS pads missing results with -1; as a Python index that would
        # silently select the last planet.
        if idx < 0:
            continue
        planet = objects[int(idx)]
        pid = id(planet)
        if pid not in seen:
            seen.add(pid)
            results.append(planet)

    return results


In [None]:

# ---------- FACT CHECK LOGIC ----------
def fact_check(question, candidates):
    """Pick the candidate planet that best answers ``question``.

    Superlative questions ("longest orbital period", "shortest orbital
    period", "largest radius" / "biggest planet", "smallest radius") are
    resolved numerically from ``planet_profile``. Candidates whose value is
    missing or non-numeric are ignored; if no candidate has a usable value the
    first candidate is returned. Any other question falls back to cosine
    similarity between the question and each candidate's flattened text.

    Args:
        question: natural-language question.
        candidates: non-empty list of planet dicts.

    Returns:
        The selected planet dict.

    Raises:
        ValueError: if ``candidates`` is empty.
    """
    if not candidates:
        raise ValueError("fact_check requires at least one candidate")

    q = question.lower()

    def numeric_field(candidate, field):
        # Tolerate a missing or malformed planet_profile and None values.
        profile = candidate.get("planet_profile")
        if not isinstance(profile, dict):
            return None
        try:
            number = float(profile.get(field))
        except (TypeError, ValueError):
            return None
        return None if number != number else number  # drop NaN

    # (trigger phrases, profile field, selector)
    rules = (
        (("longest orbital period",), "orbital_period_days", max),
        (("shortest orbital period",), "orbital_period_days", min),
        (("largest radius", "biggest planet"), "radius_earth_radii", max),
        (("smallest radius",), "radius_earth_radii", min),
    )
    for phrases, field, select in rules:
        if any(phrase in q for phrase in phrases):
            known = [
                (value, cand)
                for value, cand in ((numeric_field(c, field), c) for c in candidates)
                if value is not None
            ]
            if not known:
                return candidates[0]
            return select(known, key=lambda pair: pair[0])[1]

    # Fallback: cosine similarity between the question and each candidate.
    # The sum of an embedding's components carries no relevance signal.
    texts = [planet_to_text(c) for c in candidates]
    vectors = np.asarray(embed_model.encode([question] + texts), dtype="float32")
    norms = np.linalg.norm(vectors, axis=1, keepdims=True)
    vectors = vectors / np.clip(norms, 1e-12, None)
    scores = vectors[1:] @ vectors[0]
    return candidates[int(np.argmax(scores))]


In [None]:

# ---------- MAIN QA FUNCTION ----------
def answer_query(question, index, meta, top_k=5):
    """Answer ``question`` via semantic retrieval followed by ``fact_check``.

    Prints the query and the names of the retrieved planets. Returns the
    string "No planets found." when retrieval is empty, otherwise a dict with
    an ``"answer"`` sentence and the selected ``"planet"``.
    """
    print("\nüîç Query:", question)

    candidates = retrieve_objects(question, index, meta, top_k)
    candidate_names = [cand["planet_name"] for cand in candidates]
    print("\nRetrieved:", candidate_names)

    if not candidates:
        return "No planets found."

    chosen = fact_check(question, candidates)
    message = f"The planet most aligned with your query is **{chosen['planet_name']}**."
    return {"answer": message, "planet": chosen}


# ---------- LOAD ----------
def load_planets(path):
    """Load and return the parsed JSON content of ``path``.

    The file is read as UTF-8 explicitly so that non-ASCII characters decode
    the same way regardless of the platform's default encoding.
    """
    with open(path, "r", encoding="utf-8") as f:
        return json.load(f)


# ---------- RUN ----------
# Load the dataset, build the chunk-level FAISS index, and run one sample query.
planets = load_planets("/content/exoplanet_data/exoplanets.json")
index, meta = build_index(planets)

# This question contains none of the superlative phrases fact_check looks for,
# so it exercises fact_check's fallback branch.
out = answer_query("Which exoplanets were discovered in 2015?", index, meta)

print("\n=== FINAL OUTPUT ===")
print(out)

In [None]:
# Few-shot prompt that asks the LLM to translate a question into a JSON filter.
# It is filled in with str.format(query=...), so literal braces are doubled
# ({{ }}) and {query} is the only placeholder.
# The ">$" operator spelling shown here is the one apply_filter recognises.
FILTER_PROMPT = """
You are a query-to-filter translator for an astronomy database.
Given a natural language question about exoplanets, return ONLY a JSON object
describing the structured filters. Do NOT add commentary.

Examples:

Q: "Which planets were discovered in 2015?"
A: {{"filter": {{"year": 2015}}}}

Q: "Planets with eccentricity greater than 0.5"
A: {{"filter": {{"planet_profile.eccentricity": {{">$": 0.5}}}}}}

Q: "Show me planets orbiting K-type stars"
A: {{"filter": {{"host_star.spectral_type": "K"}}}}

If no structured filter exists, return:
{{"filter": null}}

Now convert the following question:

Q: {query}
A:
"""


In [None]:
import re

def clean_json_str(text):
    """Return the first complete JSON object embedded in ``text``.

    Objects are located with ``json.JSONDecoder.raw_decode`` so that nested
    objects and braces inside string values are handled correctly. A
    non-greedy ``{.*?}`` regex stops at the first ``}`` and truncates any
    nested object.

    Returns:
        The substring holding the first parseable JSON object. If none
        parses, the span from the first ``{`` to the last ``}`` is returned so
        the caller's ``json.loads`` reports the decode error. Returns ``'{}'``
        when ``text`` contains no brace-delimited span at all.
    """
    decoder = json.JSONDecoder()
    start = text.find("{")
    while start != -1:
        try:
            _, end = decoder.raw_decode(text, start)
        except json.JSONDecodeError:
            # Not a valid object at this brace; try the next one.
            start = text.find("{", start + 1)
            continue
        return text[start:end]

    first, last = text.find("{"), text.rfind("}")
    if first != -1 and last > first:
        return text[first:last + 1]
    return '{}'


In [None]:
def llm_to_filter(query):
    """Ask the LLM to translate ``query`` into a structured filter.

    Args:
        query: natural-language question.

    Returns:
        A dict that always contains the key ``"filter"``: either a dict of
        field conditions or ``None`` when the model produced no usable filter.
    """
    prompt = FILTER_PROMPT.format(query=query)
    inputs = tokenizer(prompt, return_tensors="pt")

    device = next(llm.parameters()).device
    inputs = {k: v.to(device) for k, v in inputs.items()}

    outputs = llm.generate(**inputs, max_new_tokens=100)

    # For a causal LM the generated sequence starts with the prompt tokens.
    # Decode only the continuation, otherwise the JSON examples inside the
    # prompt are parsed instead of the model's answer.
    prompt_len = inputs["input_ids"].shape[1]
    out_str = tokenizer.decode(outputs[0][prompt_len:], skip_special_tokens=True)

    print("Raw LLM output:", out_str)  # Debug to see raw string

    try:
        parsed = json.loads(clean_json_str(out_str))
    except json.JSONDecodeError as e:
        # The extracted span was not valid JSON; scan the answer for the first
        # object that parses before giving up.
        parsed = None
        decoder = json.JSONDecoder()
        start = out_str.find("{")
        while start != -1:
            try:
                parsed, _ = decoder.raw_decode(out_str, start)
                break
            except json.JSONDecodeError:
                start = out_str.find("{", start + 1)
        if parsed is None:
            print("JSON decode error:", e)
            return {"filter": None}

    # Callers index the result with ["filter"], so guarantee that key and
    # normalise anything that is not a dict of conditions to None.
    if not isinstance(parsed, dict):
        return {"filter": None}
    if not isinstance(parsed.get("filter"), dict):
        parsed["filter"] = None
    return parsed


In [None]:
import operator

# Comparison operators accepted inside a filter condition such as {">": 0.5}.
OPS = {
    ">":  operator.gt,
    "<":  operator.lt,
    ">=": operator.ge,
    "<=": operator.le,
    "==": operator.eq,
    "!=": operator.ne,
}


def get_nested(obj, path):
    """Access nested fields like host_star.temperature_k.

    Args:
        obj: dict to read from.
        path: dotted string ("a.b.c") or a sequence of keys.

    Returns:
        The nested value, or ``None`` when a key is missing, a value on the
        way is ``None``, or an intermediate value is not a dict.
    """
    parts = path.split(".") if isinstance(path, str) else path
    for p in parts:
        if not isinstance(obj, dict):
            return None
        obj = obj.get(p)
        if obj is None:
            return None
    return obj


def apply_filter(planets, filt):
    """Return the planets that satisfy every condition in ``filt``.

    Args:
        planets: list of planet dicts.
        filt: ``None`` (no filtering) or a dict mapping a dotted field path to
            either a literal (equality test) or a dict of
            ``{operator: operand}`` comparisons. Operators are the keys of
            ``OPS``; a ``$`` prefix or suffix is ignored, so the ``">$"``
            spelling used in FILTER_PROMPT works.

    Returns:
        ``planets`` itself when ``filt`` is ``None``, otherwise a new list of
        the matching planets in their original order. For operator
        conditions, a planet fails when the field is missing, the operator is
        unknown, or the values cannot be compared.
    """
    if filt is None:
        return planets  # no filtering, return all

    def satisfies(value, condition):
        if not isinstance(condition, dict):
            return value == condition
        for op, operand in condition.items():
            compare = OPS.get(str(op).strip("$"))
            # Unknown operators must not silently let every planet through.
            if compare is None or value is None:
                return False
            try:
                if not compare(value, operand):
                    return False
            except TypeError:
                # e.g. comparing a string field with a number
                return False
        return True

    return [
        planet
        for planet in planets
        if all(
            satisfies(get_nested(planet, field), condition)
            for field, condition in filt.items()
        )
    ]



In [None]:
def retrieve_with_filter(question, index, meta, top_k=10):
    """Retrieve planets via an LLM-derived structured filter plus semantic ranking.

    Steps:
        1. Translate the question into a filter with ``llm_to_filter``.
        2. Apply the filter to the distinct planets in ``meta["objects"]``.
        3. Rank the survivors by cosine similarity to the question.

    Args:
        question: natural-language question.
        index: unused here; kept so the signature matches the other retrievers.
        meta: ``{"objects": [...]}`` as returned by ``build_index``.
        top_k: maximum number of planets to return.

    Returns:
        Up to ``top_k`` planet dicts, most similar first; ``[]`` when the
        filter matches nothing.
    """
    # Step 1: LLM -> filter (tolerate a missing or malformed "filter" entry)
    parsed = llm_to_filter(question)
    filt = parsed.get("filter") if isinstance(parsed, dict) else None
    if not isinstance(filt, dict):
        filt = None

    # meta["objects"] holds one entry per chunk; collapse to distinct planets
    # so a planet with several chunks is not ranked several times.
    unique_planets = []
    seen = set()
    for planet in meta["objects"]:
        if id(planet) not in seen:
            seen.add(id(planet))
            unique_planets.append(planet)

    # Step 2: Apply filter if exists
    filtered_planets = apply_filter(unique_planets, filt) if filt else unique_planets

    if not filtered_planets:
        return []

    # Step 3: Semantic narrowing - cosine similarity of each planet to the question
    tmp_texts = [planet_to_text(p) for p in filtered_planets]
    tmp_embs = np.ascontiguousarray(
        embed_model.encode(tmp_texts, convert_to_numpy=True), dtype="float32"
    )
    q_emb = np.ascontiguousarray(
        embed_model.encode([question], convert_to_numpy=True), dtype="float32"
    )
    faiss.normalize_L2(q_emb)
    faiss.normalize_L2(tmp_embs)

    # After L2 normalisation the dot product is the cosine similarity. The
    # result is always 1-D, so a single candidate is handled as well.
    scores = tmp_embs @ q_emb[0]
    order = np.argsort(-scores, kind="stable")[:top_k]

    return [filtered_planets[int(i)] for i in order]


In [None]:
def answer_query_v2(question, index, meta):
    """Answer ``question`` using filter-based retrieval.

    Prints the query. Returns the string "No matching planets found." when
    retrieval is empty, otherwise a dict holding an ``"answer"`` sentence and
    the top-ranked ``"planet"``.
    """
    print("\nüîç Query:", question)

    matches = retrieve_with_filter(question, index, meta)
    if not matches:
        return "No matching planets found."

    top = matches[0]
    return {"answer": f"Best match: **{top['planet_name']}**", "planet": top}


In [None]:
# Run the filter-based pipeline on the sample question and show the result.
result = answer_query_v2("Which exoplanets were discovered in 2015?", index, meta)

print("\n=== FINAL OUTPUT ===")
print(result)

In [None]:
import nbformat as nbf

# Write a copy of the notebook with Colab-specific metadata stripped.
# NOTE: `out` rebinds the name an earlier cell used for a query result.
fname = "exoplanets.ipynb"
out = "bot_exoplanets_.ipynb"

nb = nbf.read(fname, as_version=4)

# remove per-cell metadata keys added by Colab
for cell in nb["cells"]:
    if "metadata" in cell:
        for key in ["id", "colab", "outputId", "executionInfo"]:
            cell["metadata"].pop(key, None)

# remove notebook-level Colab / widget metadata
for key in ["colab", "widgets"]:
    nb["metadata"].pop(key, None)

nbf.write(nb, out)
# Report the file that was actually written instead of a hard-coded name.
print(f"üî• Cleaned! Upload {out} to GitHub.")
