In [None]:
!pip install langchain-groq tavily-python langgraph PyPDF2 reportlab sentence-transformers langsmith


# Session configuration through environment variables.
# NOTE(review): the three *_API_KEY values below are placeholder strings, not
# real credentials, and these assignments overwrite anything already exported
# in the environment. Replace them (or load from a secrets store) before
# running, and never commit real keys.
import os
os.environ["GROQ_API_KEY"] = "GROQ_API_KEY"
os.environ["TAVILY_API_KEY"] = "TAVILY_API_KEY"
os.environ["LANGSMITH_TRACING"] = "true"  # presumably switches LangSmith tracing on — confirm
os.environ["LANGSMITH_API_KEY"] = "LANGSMITH_API_KEY"

Collecting langchain-groq
  Downloading langchain_groq-1.1.1-py3-none-any.whl.metadata (2.4 kB)
Collecting tavily-python
  Downloading tavily_python-0.7.17-py3-none-any.whl.metadata (9.0 kB)
Collecting PyPDF2
  Downloading pypdf2-3.0.1-py3-none-any.whl.metadata (6.8 kB)
Collecting reportlab
  Downloading reportlab-4.4.6-py3-none-any.whl.metadata (1.7 kB)
Collecting groq<1.0.0,>=0.30.0 (from langchain-groq)
  Downloading groq-0.37.1-py3-none-any.whl.metadata (16 kB)
Downloading langchain_groq-1.1.1-py3-none-any.whl (19 kB)
Downloading tavily_python-0.7.17-py3-none-any.whl (18 kB)
Downloading pypdf2-3.0.1-py3-none-any.whl (232 kB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m232.6/232.6 kB[0m [31m10.2 MB/s[0m eta [36m0:00:00[0m
[?25hDownloading reportlab-4.4.6-py3-none-any.whl (2.0 MB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m2.0/2.0 MB[0m [31m16.3 MB/s[0m eta [36m0:00:00[0m
[?25hDownloading groq-0.37.1-py3-none-any.whl (137 kB)
[2K   [

In [None]:


import os, re, json, datetime, html, time, hashlib
from dataclasses import dataclass
from typing import List, Optional, TypedDict
from pathlib import Path
import importlib

from langgraph.graph import StateGraph, END
from IPython.display import Markdown, display
# Force the embedding debug flag off for this session; the value is read back
# into _EMBEDDING_DEBUG further down, which gates the embedding diagnostics.
os.environ['EMBEDDING_DEBUG'] = '0'

# Global buffer for ALL raw LLM logs: groq_chat appends one formatted entry per
# call (cached or live) and save_all_raw_to_one_pdf writes the whole buffer out.

GLOBAL_RAW_LOG: List[str] = []

#  LangSmith tracing

try:
    from langsmith import traceable
except ImportError:
    def traceable(*targs, **tkwargs):
        """No-op stand-in for ``langsmith.traceable`` when langsmith is missing.

        Supports both decorator spellings so code does not break depending on
        whether langsmith is installed:

        * ``@traceable`` (bare) - the decorated function arrives as the only
          positional argument and is returned unchanged.
        * ``@traceable(name="...")`` - the arguments are ignored and a
          pass-through decorator is returned.
        """
        # Bare usage: the single positional argument IS the function to wrap.
        if len(targs) == 1 and callable(targs[0]) and not tkwargs:
            return targs[0]

        def decorator(fn):
            return fn
        return decorator

# API Keys (Groq + Tavily)

# Read both keys once at import time. A missing variable or a whitespace-only
# value ends up as "" and is treated as absent by the check below.
GROQ_API_KEY = os.getenv("GROQ_API_KEY", "").strip()
TAVILY_API_KEY = os.getenv("TAVILY_API_KEY", "").strip()

# Fail fast: stop execution with setup instructions when either key is empty.
# NOTE(review): only emptiness is checked, so a placeholder string passes.
if not GROQ_API_KEY or not TAVILY_API_KEY:
    raise SystemExit(
        "ERROR: GROQ_API_KEY and TAVILY_API_KEY must be set in environment.\n"
        "In Colab, run:\n"
        "import os\n"
        "os.environ['GROQ_API_KEY'] = 'your_groq_key'\n"
        "os.environ['TAVILY_API_KEY'] = 'your_tavily_key'"
    )


# External clients & PDF lib (reading)

try:
    from tavily import TavilyClient
except Exception:
    raise RuntimeError("Missing tavily library. Install with: pip install tavily-python")

# PDF lib for reading user PDFs.
# Import directly rather than probing with importlib.util.find_spec: a plain
# `import importlib` does not guarantee that the `importlib.util` submodule is
# loaded, so the probe itself could raise AttributeError.
try:
    import PyPDF2 as _pdf_lib
except ImportError:
    _pdf_lib = None  # extract_text_from_pdf raises RuntimeError when this is None

# Shared Tavily client; search_tavily() sends every web search through it.
tavily_client = TavilyClient(api_key=TAVILY_API_KEY)


# PDF writer for saving raw LLM outputs

# reportlab is optional: record whether it imported so save_raw_to_pdf can
# skip PDF output (with a message) instead of failing.
try:
    from reportlab.pdfgen import canvas
    from reportlab.lib.pagesizes import letter
    _pdf_writer_available = True
except ImportError:
    _pdf_writer_available = False


def save_raw_to_pdf(text: str, filename: str):
    """Write ``text`` to a letter-sized PDF at ``filename``, one string per row.

    Each input line is cut into pieces of at most 110 characters; every piece is
    drawn on its own row, and a new page is started whenever the drawing cursor
    drops below the bottom margin. When reportlab is unavailable the function
    prints a notice and returns; any failure while writing is printed rather
    than raised.
    """
    if not _pdf_writer_available:
        print("[PDF] reportlab not installed; skipping PDF save.")
        return
    try:
        pdf = canvas.Canvas(filename, pagesize=letter)
        _, page_height = letter
        left_margin = 40
        wrap_at = 110
        cursor_y = page_height - 50
        for source_line in text.split("\n"):
            # Cut the line into full-width pieces followed by the remainder.
            pieces = []
            remainder = source_line
            while len(remainder) > wrap_at:
                pieces.append(remainder[:wrap_at])
                remainder = remainder[wrap_at:]
            pieces.append(remainder)

            for piece in pieces:
                if cursor_y < 40:
                    # Out of room on this page: start a new one at the top.
                    pdf.showPage()
                    cursor_y = page_height - 50
                pdf.drawString(left_margin, cursor_y, piece)
                cursor_y -= 15
        pdf.save()
        print(f"[PDF Saved] {filename}")
    except Exception as e:
        print(f"[PDF ERROR] Failed to save PDF '{filename}': {e}")


def save_all_raw_to_one_pdf(filename: str = "all_llm_raw_output.pdf"):
    """Dump every entry of ``GLOBAL_RAW_LOG`` into a single PDF.

    Entries are joined with newlines and handed to ``save_raw_to_pdf``. When
    the log is empty a notice is printed and nothing is written.
    """
    if GLOBAL_RAW_LOG:
        save_raw_to_pdf("\n".join(GLOBAL_RAW_LOG), filename)
    else:
        print("[PDF] No raw LLM data to save.")



# Groq LLM via LangChain

from langchain_groq import ChatGroq

GROQ_MODEL = os.getenv("GROQ_MODEL", "llama-3.1-8b-instant")
_LLM_CACHE_PATH = Path("/tmp/groq_inference_cache.json")


def _load_llm_cache(path: Path) -> dict:
    """Load the prompt-to-response cache from ``path``.

    Returns an empty dict when the file is missing, unreadable, not valid JSON,
    or does not contain a JSON object, so a damaged cache file never prevents
    startup. The file is opened in a ``with`` block so the handle is closed.
    """
    try:
        if not path.exists():
            return {}
        with open(path) as cache_file:
            loaded = json.load(cache_file)
    except (OSError, ValueError):
        # ValueError covers json.JSONDecodeError and UnicodeDecodeError.
        return {}
    # groq_chat does key lookups and item assignment, so only a dict is usable.
    return loaded if isinstance(loaded, dict) else {}


_LLM_CACHE = _load_llm_cache(_LLM_CACHE_PATH)


def _save_llm_cache():
    """Persist ``_LLM_CACHE`` to ``_LLM_CACHE_PATH`` as indented JSON (best effort).

    The data is written to a sibling temporary file first and then moved into
    place, so a failure part-way through serialisation does not leave a
    truncated cache file. Failures are reported with a printed message and
    never raised, because caching is an optimisation rather than a requirement.
    """
    tmp_path = _LLM_CACHE_PATH.with_suffix(".tmp")
    try:
        with open(tmp_path, "w") as cache_file:
            json.dump(_LLM_CACHE, cache_file, indent=2, ensure_ascii=False)
        os.replace(tmp_path, _LLM_CACHE_PATH)
    except Exception as e:
        print(f"[LLM Cache] Failed to save cache to '{_LLM_CACHE_PATH}': {e}")


def _prompt_key(prompt: str) -> str:
    """Return the hexadecimal SHA-256 digest of the UTF-8 encoded prompt."""
    hasher = hashlib.sha256()
    hasher.update(prompt.encode("utf-8"))
    return hasher.hexdigest()


# Single shared chat client used by groq_chat(). The output cap (max_tokens=800)
# is fixed here at construction time.
# NOTE(review): temperature=0.0 is presumably chosen to keep answers as
# repeatable as possible, which suits the prompt-keyed cache — confirm.
groq_llm = ChatGroq(
    model=GROQ_MODEL,
    temperature=0.0,
    max_tokens=800,
    groq_api_key=GROQ_API_KEY,
)


def _extract_response_text(resp) -> str:
    """Flatten a chat response object into plain text.

    Handles string content, list content (string parts and dict parts that
    carry a string under ``"text"``), and falls back to ``str()`` otherwise.
    """
    if not hasattr(resp, "content"):
        return str(resp)
    content = resp.content
    if content is None:
        return ""
    if isinstance(content, str):
        return content
    if isinstance(content, list):
        parts = []
        for part in content:
            if isinstance(part, str):
                parts.append(part)
            elif isinstance(part, dict) and isinstance(part.get("text"), str):
                # Structured text block: keep the text, not the dict repr.
                parts.append(part["text"])
            else:
                parts.append(str(part))
        return "".join(parts)
    return str(content)


@traceable(name="groq_llm_call")
def groq_chat(prompt: str, max_retries: int = 3, max_new_tokens: int = 800, timeout: int = 60) -> str:
    """Send ``prompt`` to the shared Groq client and return the reply text.

    Replies are cached on disk under the SHA-256 of the prompt; a cache hit is
    returned without calling the model. Every call, cached or live, appends an
    entry to ``GLOBAL_RAW_LOG``.

    Args:
        prompt: Full prompt text; also the cache key.
        max_retries: Total number of attempts made when the client raises
            (values below 1 are treated as 1). A short exponential pause is
            taken between attempts.
        max_new_tokens: Accepted for interface compatibility; not applied here
            (the client is built with its own ``max_tokens``).
        timeout: Accepted for interface compatibility; not applied here.

    Returns:
        The stripped reply text, or ``""`` when every attempt failed. Empty
        replies are not cached, so a later call can try again.
    """
    key = _prompt_key(prompt)
    if key in _LLM_CACHE:
        txt = _LLM_CACHE[key]
        GLOBAL_RAW_LOG.append(
            f"\n\n=== CACHED LLM CALL ===\nTimestamp: {datetime.datetime.now(datetime.timezone.utc).isoformat()}\n"
            f"Prompt:\n{prompt}\n\nResponse:\n{txt}\n"
        )
        return txt

    attempts = max(1, int(max_retries))
    resp = None
    for attempt in range(1, attempts + 1):
        try:
            resp = groq_llm.invoke(prompt)
            break
        except Exception as e:
            print(f"Groq LLM error (attempt {attempt}/{attempts}): {e}")
            if attempt == attempts:
                return ""
            # Back off 1s, 2s, 4s ... capped at 8s before the next attempt.
            time.sleep(min(2 ** (attempt - 1), 8))

    try:
        txt = _extract_response_text(resp)
    except Exception:
        txt = str(resp)
    txt = (txt or "").strip()

    # Only persist useful replies; caching "" would pin a bad result forever.
    if txt:
        _LLM_CACHE[key] = txt
        _save_llm_cache()
    GLOBAL_RAW_LOG.append(
        f"\n\n=== RAW LLM CALL ===\nTimestamp: {datetime.datetime.now(datetime.timezone.utc).isoformat()}\n"
        f"Prompt:\n{prompt}\n\nResponse:\n{txt}\n"
    )
    return txt



# Checkpoint structure

@dataclass
class Checkpoint:
    """One unit of the learning path: a topic, its objectives and a success statement."""
    id: str  # short unique key (e.g. "cp1"); matched by get_checkpoint_by_id
    topic: str  # topic title; used when building search queries and prompts
    objectives: List[str]  # learning objectives checked against the gathered context
    success_criteria: str  # free-text statement of what the learner should be able to do


# Catalogue of available checkpoints, in teaching order. get_checkpoint_by_id
# looks entries up by their `id`.
CHECKPOINTS = [
    Checkpoint(
        id="cp1",
        topic="Basics of Neural Networks",
        objectives=[
            "Understand what an artificial neuron is",
            "Understand input, hidden, and output layers",
            "Understand the concept of forward propagation",
        ],
        success_criteria="Learner can explain a simple feedforward neural network.",
    ),
    Checkpoint(
        id="cp2",
        topic="Gradient Descent",
        objectives=[
            "Understand loss minimization",
            "Understand gradient as slope of loss",
            "Understand iterative parameter updates",
        ],
        success_criteria="Learner can describe how gradient descent updates parameters.",
    ),
    Checkpoint(
        id="cp3",
        topic="Activation Functions",
        objectives=[
            "Know common activations (sigmoid, tanh, relu)",
            "Understand when/why to use each",
        ],
        success_criteria="Learner can choose and justify an activation for a simple task.",
    ),
    Checkpoint(
        id="cp4",
        topic="Backpropagation",
        objectives=[
            "Understand chain rule for gradients",
            "Understand how weight updates propagate backward",
        ],
        success_criteria="Learner can explain backpropagation at high level.",
    ),
]



# Agent State

class AgentState(TypedDict):
    """State dictionary passed between the graph nodes for one checkpoint run."""
    cp_id: str  # id of the checkpoint to run; resolved by start_checkpoint
    checkpoint: Optional[Checkpoint]  # filled in by start_checkpoint
    user_notes: str  # learner-supplied notes, read by gather_context
    user_pdfs: List[str]  # paths of learner-supplied PDFs, read by gather_context
    gathered_context: str  # context text written by gather_context / validate_context
    context_sources: List[str]  # labels such as "user_notes", "pdf_upload", "web_search"
    relevance_score_model: Optional[int]  # 1-5 objective-coverage score set by validate_context
    refetch_attempted: bool  # True once validate_context has tried a web refetch
    score_meta: Optional[dict]  # per-objective coverage details from validate_context

    processed_chunks: List[str]  # text chunks produced by process_context
    questions: List[str]  # quiz questions produced by generate_questions
    learner_answers: List[str]  # learner answers, read by verify_understanding
    score_percent: Optional[float]  # average quiz score (0-100) set by verify_understanding
    pass_threshold_met: Optional[bool]  # whether score_percent reached the 70% pass mark
    # added field:
    temp_vector_store: Optional[dict]  # store built by build_temp_vector_store


_score_re = re.compile(r"\b([1-5])\b")


def parse_score_from_text(raw: str) -> int:
    """Pull the first standalone digit 1-5 out of ``raw``.

    Returns 3 when ``raw`` is empty or contains no such digit.
    """
    found = _score_re.search(raw) if raw else None
    if found is None:
        return 3
    return int(found.group(1))



# PDF extraction helpers

def extract_text_from_pdf(path: str) -> str:
    """Return the text of every page in the PDF at ``path``, joined by newlines.

    Pages whose extraction raises, or that yield no text, are skipped.

    Raises:
        RuntimeError: if PyPDF2 could not be imported.
    """
    if not _pdf_lib:
        raise RuntimeError("PyPDF2 not installed. `pip install PyPDF2` to enable PDF ingestion.")

    def _safe_page_text(page) -> str:
        # A single unreadable page must not abort the whole document.
        try:
            return page.extract_text() or ""
        except Exception:
            return ""

    with open(path, "rb") as handle:
        per_page = [_safe_page_text(page) for page in _pdf_lib.PdfReader(handle).pages]
    return "\n".join(fragment for fragment in per_page if fragment)


def gather_texts_from_pdfs(paths: List[str]) -> str:
    """Concatenate the text of several PDFs, each under a ``--- PDF: name ---`` banner.

    PDFs that fail to read are reported with a printed message and skipped;
    PDFs with only whitespace text contribute nothing.
    """
    sections = []
    for pdf_path in paths:
        try:
            body = extract_text_from_pdf(pdf_path)
            if body.strip():
                sections.append(f"\n--- PDF: {os.path.basename(pdf_path)} ---\n{body}\n")
        except Exception as e:
            print(f"Failed to read PDF {pdf_path}: {e}")
    return "".join(sections)



# Summarizer (uses Groq)

def summarize_text(text: str) -> str:
    """Ask the LLM for a focused summary of the first 5000 characters of ``text``.

    Blank input is returned unchanged without calling the model. When the
    model returns nothing, the 5000-character excerpt itself is returned.
    """
    if text.strip() == "":
        return text
    excerpt = text[:5000]
    prompt = f"""
Summarize the following text into a focused explanation matching the learning objectives.
Keep it concise, clean, and relevant.

Text:
\"\"\"{excerpt}\"\"\"\n
Summary:
"""
    condensed = groq_chat(prompt).strip()
    if condensed:
        return condensed
    return excerpt



# Tavily wrapper with rate limiting (< 10 searches/min)

_search_timestamps: List[float] = []


def _enforce_search_rate_limit():
    """Block until another search may be issued, then record its timestamp.

    Keeps a sliding 60-second window of past search times. Once 9 searches are
    inside the window, sleeps until the oldest one leaves it, then prunes the
    window again before recording the new search.
    """
    global _search_timestamps

    def _within_window(reference):
        return [stamp for stamp in _search_timestamps if reference - stamp < 60]

    current = time.time()
    _search_timestamps = _within_window(current)
    if len(_search_timestamps) >= 9:
        delay = 60 - (current - min(_search_timestamps))
        if delay > 0:
            print(f"[Rate Limit] Tavily search rate reached. Waiting {delay:.1f}s to stay under 10/min...")
            time.sleep(delay)
        current = time.time()
        _search_timestamps = _within_window(current)
    _search_timestamps.append(time.time())


@traceable(name="tavily_search")
def search_tavily(query: str, max_results: int = 5) -> List[dict]:
    """Run one rate-limited Tavily search and return its ``results`` list.

    Returns an empty list when the response is not a dict or when the search
    raises; the failure is printed, not propagated.
    """
    _enforce_search_rate_limit()
    try:
        response = tavily_client.search(query=query, max_results=max_results)
        if isinstance(response, dict):
            return response.get("results", [])
        return []
    except Exception as e:
        print("Tavily search failed:", e)
        return []



# Evidence cleaning & user JSON helpers

def clean_evidence(raw: str) -> str:
    """Reduce a raw evidence string from the LLM to a short plain sentence.

    Removes fenced code blocks, literal ``\\n`` sequences, a leading JSON
    wrapper up to the ``"evidence"`` key, stray ``"covered":`` tokens, braces
    and triple quotes, then a leading ``evidence:`` label. HTML entities are
    unescaped and the result is cut to at most 300 characters at a word
    boundary with ``...`` appended. Falsy input yields ``""``.
    """
    if not raw:
        return ""
    text = re.sub(r"```.*?```", "", raw, flags=re.DOTALL)
    text = re.sub(r"\s*\\n\s*", " ", text)
    text = re.sub(r'^\s*\{.*?"evidence"\s*:\s*', "", text, flags=re.DOTALL)
    # Literal leftovers of JSON / quoting, removed in this exact order.
    for leftover in ('"covered":', "{", "}", '"""', "'''"):
        text = text.replace(leftover, "")
    text = text.strip()
    text = re.sub(r'^\s*["\']?evidence["\']?\s*[:\-]?\s*', "", text, flags=re.I)
    text = html.unescape(text).strip()
    if len(text) <= 300:
        return text
    return text[:300].rsplit(" ", 1)[0] + "..."


def simplify_score_meta_for_user(score_meta: Optional[dict]):
    """Convert internal scoring metadata into a user-facing coverage report.

    Args:
        score_meta: Dict with ``covered_count``, ``total`` and ``details``
            (a list of ``{"objective", "covered", "evidence"}`` dicts), or a
            falsy value.

    Returns:
        ``None`` for falsy input, otherwise a dict with ``coverage_percent``
        (int), ``summary``, ``explain`` and ``objective_reports``. A ``total``
        of 0 or None is reported as 0% instead of raising ZeroDivisionError.
    """
    if not score_meta:
        return None
    covered = score_meta.get("covered_count", 0) or 0
    total = score_meta.get("total", 1) or 0
    objectives = [
        {
            "objective": d.get("objective"),
            "covered": str(d.get("covered", "no")).lower() == "yes",
            "evidence": clean_evidence(d.get("evidence", "")),
        }
        for d in (score_meta.get("details") or [])
    ]
    # Guard the division: metadata with no objectives has zero coverage.
    coverage_percent = int(round((covered / total) * 100)) if total > 0 else 0
    summary = f"{covered}/{total} objectives covered"
    explain = f"{coverage_percent}% — {summary}"
    return {
        "coverage_percent": coverage_percent,
        "summary": summary,
        "explain": explain,
        "objective_reports": objectives,
    }


# Embeddings + Temporary in-memory vector store (Milestone 2)

# Module-level embedding state. Both stay None unless the imports/model load
# below succeed; the embedding helpers check them before using embeddings.
_emb_model = None
_np = None
# Debug flag: on only when EMBEDDING_DEBUG is exactly "1", "true" or "True".
_EMBEDDING_DEBUG = os.getenv("EMBEDDING_DEBUG", "0") in ("1", "true", "True")

# Chunking configuration (character counts; process_context slices the context
# string with these values).
CHUNK_SIZE = 1200
CHUNK_OVERLAP = 250
MIN_CHUNK_LENGTH = 300

try:
    from sentence_transformers import SentenceTransformer
    import numpy as np

    try:
        print("[Embedding] Loading SentenceTransformer 'all-MiniLM-L6-v2' ... (this may take a few seconds)")
        _emb_model = SentenceTransformer("all-MiniLM-L6-v2")
        _np = np
        print("[Embedding] Model loaded:all-MiniLM-L6-v2")
    except Exception as e:
        # The libraries imported but the model could not be constructed:
        # disable embeddings entirely (numpy is cleared too).
        print(f"[Embedding] Failed to load SentenceTransformer model: {e}")
        _emb_model = None
        _np = None
except Exception:
    # The import of sentence_transformers (or numpy) failed: run without an
    # embedding model, but still try to make numpy available on its own.
    _emb_model = None
    try:
        import numpy as np
        _np = np
    except Exception:
        _np = None


def is_context_relevant_semantically(
    context: str,
    cp,
    threshold: float = 0.35
) -> bool:
    """
    Embedding-based semantic relevance check.

    Compares the first 2000 characters of ``context`` with the checkpoint's
    topic and objectives using cosine similarity of their embeddings.

    Args:
        context: Candidate context text.
        cp: Object exposing ``topic`` (str) and ``objectives`` (list of str).
        threshold: Minimum similarity for the context to count as relevant.

    Returns:
        False for empty/blank/None context or when the similarity is below
        ``threshold``. True when the similarity reaches the threshold, and also
        when embeddings are unavailable or the computation fails, so that this
        check never blocks the pipeline.
    """
    if not context or not context.strip():
        return False

    # If embeddings unavailable, don't block pipeline
    if _emb_model is None or _np is None:
        return True

    reference_text = cp.topic + " " + " ".join(cp.objectives)

    try:
        ctx_vec = _emb_model.encode([context[:2000]], convert_to_numpy=True)
        ref_vec = _emb_model.encode([reference_text], convert_to_numpy=True)

        ctx_vec = ctx_vec / (_np.linalg.norm(ctx_vec, axis=1, keepdims=True) + 1e-12)
        ref_vec = ref_vec / (_np.linalg.norm(ref_vec, axis=1, keepdims=True) + 1e-12)

        # The product is a 1x1 array; .item() extracts the scalar explicitly
        # instead of relying on implicit array-to-float conversion.
        similarity = float((ctx_vec @ ref_vec.T).item())

        if _EMBEDDING_DEBUG:
            print(f"[Semantic Relevance] similarity={similarity:.3f}")

        return similarity >= threshold

    except Exception as e:
        # Fail open: an embedding problem must not discard the user's context.
        if _EMBEDDING_DEBUG:
            print(f"[Semantic Relevance] check failed, treating context as relevant: {e}")
        return True




def build_temp_vector_store(chunks: List[str]):
    """
    Assemble the in-memory vector store used for this session.

    Returns a dict ``{'chunks': [...], 'vectors': array or None, 'meta': {...}}``.
    ``vectors`` holds row-normalised embeddings when the embedding model and
    numpy are available and encoding succeeds; otherwise it is None and
    ``meta['embeddings_used']`` is False (callers then fall back to token
    overlap).
    """
    def _plain_store(items):
        return {"chunks": items, "vectors": None, "meta": {"embeddings_used": False}}

    if not chunks:
        return _plain_store([])

    if _emb_model and _np is not None:
        try:
            raw_vectors = _emb_model.encode(chunks, convert_to_numpy=True, show_progress_bar=False)
            # Unit-length rows make a dot product equal to cosine similarity.
            unit_vectors = raw_vectors / (_np.linalg.norm(raw_vectors, axis=1, keepdims=True) + 1e-12)
            if _EMBEDDING_DEBUG:
                print(f"[Embedding] Built vector store: vectors shape = {unit_vectors.shape}")
            return {"chunks": chunks, "vectors": unit_vectors, "meta": {"embeddings_used": True}}
        except Exception as e:
            print(f"[Embedding] Exception while encoding chunks: {e}")

    # Reached when no model is loaded or when encoding failed above.
    if _EMBEDDING_DEBUG:
        print("[Embedding] No embedding model available; using token-overlap fallback.")
    return _plain_store(chunks)


def embedding_debug_print(store, label: str = ""):
    """Print whether ``store`` carries embedding vectors, plus a small sample.

    When vectors are present, prints their ``shape`` attribute (None if absent)
    and the first six values of the first vector. Any error raised while
    inspecting the store is printed instead of propagated.
    """
    try:
        vecs = store.get("vectors") if store else None
        if vecs is None:
            print(f"[Embedding Confirm] {label} embeddings used: False")
            return
        print(f"[Embedding Confirm] {label} embeddings used: True | vector shape: {getattr(vecs, 'shape', None)}")
        # Show a tiny sample: the first six values of the first vector.
        first = vecs[0]
        if hasattr(first, "tolist"):
            sample = first[:6].tolist()
        else:
            sample = list(first[:6])
        print(f"[Embedding Confirm] sample vec[0][:6] ~ {sample}")
    except Exception as e:
        print(f"[Embedding Confirm] Error during debug print: {e}")


def retrieve_top_k(store, query: str, k: int = 3) -> List[str]:
    """Return up to ``k`` chunks from ``store`` most relevant to ``query``.

    Uses cosine similarity against the stored vectors when the store has
    vectors and the embedding model is loaded; otherwise (or if that path
    fails) ranks chunks by the number of distinct word tokens shared with the
    query. If no chunk shares a token with the query, the first ``k`` chunks
    are returned.

    Args:
        store: Dict with ``chunks`` and ``vectors`` keys (as built by
            ``build_temp_vector_store``); ``None`` is treated as empty.
        query: Text to match against.
        k: Maximum number of chunks to return; non-positive values yield ``[]``.
    """
    store = store or {}
    chunks = store.get("chunks", []) or []
    vectors = store.get("vectors", None)

    if not chunks or k <= 0:
        return []

    if vectors is not None and _emb_model is not None and _np is not None:
        try:
            qv = _emb_model.encode([query], convert_to_numpy=True, show_progress_bar=False)
            qv = qv / (_np.linalg.norm(qv, axis=1, keepdims=True) + 1e-12)

            # Flatten the (n, 1) product to one similarity value per chunk.
            sims = (vectors @ qv.T).reshape(-1)

            idx_sorted = sims.argsort()[::-1][:k]
            top_chunks = [chunks[i] for i in idx_sorted if i < len(chunks)]

            if _EMBEDDING_DEBUG:
                top_info = [(int(i), float(sims[i])) for i in idx_sorted]
                print(f"[Retrieve Debug] top-k indices+sim: {top_info}")

            return top_chunks

        except Exception as e:
            print(f"[Retrieve] Embedding retrieval failed: {e}")

    # ---- fallback: token overlap ----
    q_words = set(re.findall(r"\w+", (query or "").lower()))
    scored = [
        (len(q_words & set(re.findall(r"\w+", c.lower()))), i)
        for i, c in enumerate(chunks)
    ]
    # Tuples sort by overlap first; equal overlaps are ordered by higher index.
    scored.sort(reverse=True)
    top = [chunks[i] for s, i in scored[:k] if s > 0]

    # last fallback: first k chunks
    return top or chunks[:k]



# LangGraph Nodes

def get_checkpoint_by_id(cp_id: str) -> Checkpoint:
    """Return the first checkpoint in ``CHECKPOINTS`` whose ``id`` equals ``cp_id``.

    Raises:
        ValueError: if no checkpoint has that id.
    """
    match = next((candidate for candidate in CHECKPOINTS if candidate.id == cp_id), None)
    if match is None:
        raise ValueError(f"Checkpoint {cp_id} not found")
    return match


def start_checkpoint(state: AgentState) -> AgentState:
    """Resolve the checkpoint named by ``state['cp_id']`` and reset per-run fields.

    Returns a shallow copy of ``state`` with ``checkpoint`` set and the context,
    chunk, question, score and vector-store fields cleared; the input dict is
    not modified.
    """
    checkpoint = get_checkpoint_by_id(state["cp_id"])
    fresh = dict(state)
    fresh.update(
        checkpoint=checkpoint,
        gathered_context="",
        context_sources=[],
        score_meta=None,
        processed_chunks=[],
        questions=[],
        score_percent=None,
        pass_threshold_met=None,
        temp_vector_store=None,
    )
    return fresh


@traceable(name="gather_context_node")
def gather_context(state: AgentState) -> AgentState:
    """Collect study context from user notes, uploaded PDFs, or a web search.

    User notes and PDF text are used when present. Only when both are empty is
    a Tavily search run on the checkpoint topic and objectives, and its results
    summarised. Any context longer than 5000 characters is summarised. The
    context and its sources are printed and stored on a copy of the state.
    """
    cp = state["checkpoint"]
    notes = state["user_notes"]
    pdfs = state.get("user_pdfs", [])

    pieces = []
    sources = []

    trimmed_notes = notes.strip()
    if trimmed_notes:
        pieces.append(f"User Notes:\n{trimmed_notes}\n")
        sources.append("user_notes")

    if pdfs:
        pdf_text = gather_texts_from_pdfs(pdfs)
        if pdf_text.strip():
            pieces.append(pdf_text)
            sources.append("pdf_upload")

    context = "".join(pieces)

    if not context.strip():
        # Nothing supplied by the learner: fall back to a web search.
        query = f"{cp.topic} - " + "; ".join(cp.objectives)
        for item in search_tavily(query=query):
            content = item.get("content")
            if content:
                context += content + "\n"
        if context.strip():
            sources.append("web_search")
            context = summarize_text(context)

    if context.strip() and len(context) > 5000:
        context = summarize_text(context)

    updated = dict(state)
    updated["gathered_context"] = context
    updated["context_sources"] = sources

    banner = "=" * 80
    print("\n" + banner)
    print(f"[Gathered Context] Source(s): {', '.join(sources) if sources else 'None'}")
    print("-" * 80)
    print(context.strip() if context.strip() else "[No context gathered]")
    print(banner + "\n")

    return updated


@traceable(name="validate_context_node")
def validate_context(state: AgentState) -> AgentState:
    """Check the gathered context against the checkpoint and refetch once if weak.

    Steps:
      1. Semantic filter: if the context is judged unrelated and no refetch has
         happened yet, replace it with summarised web-search results.
      2. Ask the LLM, per objective, whether the context covers it, and turn
         the covered fraction into a 1-5 score.
      3. If the score is 2 or lower and no refetch has happened yet, refetch
         from the web and score again.

    Returns a copy of ``state`` with ``gathered_context``,
    ``relevance_score_model``, ``refetch_attempted``, ``score_meta`` and
    ``context_sources`` updated. ``context_sources`` is rebuilt as a new list,
    so the list held by the caller's state is not mutated.
    """
    cp = state["checkpoint"]
    context = state["gathered_context"]
    refetch = state["refetch_attempted"]

    def refetch_from_web() -> str:
        # Shared by both refetch triggers: search, concatenate, summarise.
        query = f"{cp.topic} for beginners; " + "; ".join(cp.objectives)
        results = search_tavily(query=query)

        new_context = ""
        for item in results:
            c = item.get("content")
            if c:
                new_context += c + "\n"

        if new_context.strip():
            new_context = summarize_text(new_context)
        return new_context

    def score_ctx_objectives_only(ctx: str):
        # One LLM call per objective; unparseable replies count as not covered.
        ctx = ctx or ""
        objectives = cp.objectives
        covered = 0
        details = []

        for obj in objectives:
            prompt = f"""
You must respond with a single valid JSON object and NOTHING else.

Keys:
- "covered": either "yes" or "no"
- "evidence": one short sentence (<=30 words) quoting or paraphrasing the CONTEXT

OBJECTIVE:
\"\"\"{obj}\"\"\"\n
CONTEXT:
\"\"\"{ctx[:8000]}\"\"\"\n
Return only JSON like: {{ "covered": "yes", "evidence": "..." }}
"""
            raw = groq_chat(prompt).strip()

            cov = "no"
            evidence = ""

            m = re.search(r"\{.*\}", raw, re.DOTALL)
            if m:
                try:
                    parsed = json.loads(m.group(0))
                    cov = str(parsed.get("covered", "no")).lower()
                    evidence = str(parsed.get("evidence", "")).strip()
                    if cov not in ("yes", "no"):
                        cov = "no"
                except ValueError:
                    # Malformed JSON from the model: keep the "no" default.
                    pass

            if cov == "yes":
                covered += 1

            details.append({
                "objective": obj,
                "covered": cov,
                "evidence": evidence
            })

        total = len(objectives) or 1
        base_score = round((covered / total) * 5)
        score = min(5, max(1, base_score))

        meta = {
            "covered_count": covered,
            "total": total,
            "details": details
        }
        return score, meta

    #  SEMANTIC FILTER
    is_relevant = is_context_relevant_semantically(context, cp)

    if not is_relevant and not refetch:
        print("[Semantic Filter] Context is unrelated. Refetching from web...")
        context = refetch_from_web()
        refetch = True

    score, score_meta = score_ctx_objectives_only(context)

    if score <= 2 and not refetch:
        print("Low relevance detected. Refetching...")
        context = refetch_from_web()
        score, score_meta = score_ctx_objectives_only(context)
        refetch = True

    # Copy the list: dict(state) is shallow, so appending to the original list
    # would also change the caller's state.
    sources = list(state.get("context_sources") or [])
    if refetch and "web_search" not in sources:
        sources.append("web_search")

    state = dict(state)
    state["gathered_context"] = context
    state["relevance_score_model"] = score
    state["refetch_attempted"] = refetch
    state["score_meta"] = score_meta
    state["context_sources"] = sources

    print(f"Score for {cp.id}: {score} ({score_meta['covered_count']}/{score_meta['total']} objectives covered)")
    for d in score_meta["details"]:
        print(f" - Obj: {d['objective'][:60]}... => {d['covered']} | evidence: {clean_evidence(d['evidence'])[:140]}")

    return state



# Milestone 2: Context processing (chunking) + build temp vector store

@traceable(name="process_context_node")
def process_context(state: AgentState) -> AgentState:
    """Split the gathered context into overlapping chunks and build the vector store.

    Windows of ``CHUNK_SIZE`` characters are taken every
    ``CHUNK_SIZE - CHUNK_OVERLAP`` characters; a window is kept when its
    stripped length is at least ``MIN_CHUNK_LENGTH``. If no window qualifies
    but there is context, the whole context becomes the single chunk.

    Returns a copy of ``state`` with ``processed_chunks`` and
    ``temp_vector_store`` set.
    """
    context = (state.get("gathered_context") or "").strip()
    chunks: List[str] = []

    if context:
        text_len = len(context)
        # Always advance by at least one character: with an overlap equal to
        # or larger than the chunk size the loop would otherwise never end.
        stride = max(1, CHUNK_SIZE - CHUNK_OVERLAP)
        start = 0

        while start < text_len:
            end = start + CHUNK_SIZE
            chunk = context[start:end].strip()

            if len(chunk) >= MIN_CHUNK_LENGTH:
                chunks.append(chunk)

            # This window already reached the end of the text; a further
            # window would only repeat its tail.
            if end >= text_len:
                break

            # move forward with overlap
            start += stride

    # Fallback safety
    if not chunks and context:
        chunks = [context]

    state = dict(state)
    state["processed_chunks"] = chunks

    # Build temporary vector store (Milestone 2 requirement)
    store = build_temp_vector_store(chunks)
    state["temp_vector_store"] = store

    # Debug confirmation
    if _EMBEDDING_DEBUG:
        print(f"[Chunking] Produced {len(chunks)} chunks "
              f"(size={CHUNK_SIZE}, overlap={CHUNK_OVERLAP})")
        embedding_debug_print(store, label=f"cp={state.get('cp_id')}")

    return state




# Milestone 2: Question generation (uses retrieval of top-k chunks to focus generation)

@traceable(name="generate_questions_node")
def generate_questions(state: AgentState) -> AgentState:
    """Generate up to five quiz questions for the current checkpoint.

    The top three chunks most relevant to the topic and objectives are used as
    context (truncated to 4000 characters). The LLM reply is parsed as JSON
    with a ``questions`` list; if that yields nothing, each non-empty reply
    line is used after stripping bullet/number prefixes. Returns a copy of
    ``state`` with ``questions`` set.
    """
    cp = state["checkpoint"]
    chunks = state.get("processed_chunks") or []
    store = state.get("temp_vector_store") or {"chunks": chunks, "vectors": None, "meta": {"embeddings_used": False}}

    # Use retrieval: combine top-k chunks for question generation
    focus = ""
    if chunks:
        objective_query = cp.topic + " " + " ".join(cp.objectives)
        focus = "\n\n".join(retrieve_top_k(store, query=objective_query, k=3))

    base_context = (focus or "\n\n".join(chunks))[:4000]
    objectives_json = json.dumps(cp.objectives, ensure_ascii=False)
    prompt = f"""
You are a helpful tutor.

Based on the following CONTEXT and LEARNING OBJECTIVES,
generate 3 to 5 short, focused questions that test conceptual understanding.
Avoid yes/no questions. Make them open-ended but concise.

Return ONLY a JSON object like:
{{ "questions": ["Q1 ...", "Q2 ...", "..."] }}

CONTEXT:
\"\"\"{base_context}\"\"\"\n
OBJECTIVES:
{objectives_json}
"""
    raw = groq_chat(prompt).strip()

    def _questions_from_json(reply: str) -> List[str]:
        # Preferred path: a JSON object carrying a "questions" list of strings.
        try:
            found = re.search(r"\{.*\}", reply, re.DOTALL)
            if not found:
                return []
            candidates = json.loads(found.group(0)).get("questions", [])
            if not isinstance(candidates, list):
                return []
            return [q.strip() for q in candidates if isinstance(q, str) and q.strip()]
        except Exception:
            return []

    def _questions_from_lines(reply: str) -> List[str]:
        # Fallback: treat every sufficiently long line as one question.
        collected = []
        for entry in reply.splitlines():
            entry = entry.strip()
            if not entry:
                continue
            entry = re.sub(r"^[\-\*\d\.\)\s]+", "", entry)
            if len(entry) > 5:
                collected.append(entry)
        return collected

    questions = (_questions_from_json(raw) or _questions_from_lines(raw))[:5]

    updated = dict(state)
    updated["questions"] = questions
    print(f"Generated {len(questions)} questions for {cp.id}")
    return updated



# Milestone 2: Understanding verification (scoring)
# - Strict: counts too-short answers as 0, uses retrieval per-question for focused grading

@traceable(name="verify_understanding_node")
def verify_understanding(state):
    """Grade the learner's answers and record the average score.

    Questions and answers are paired by position up to the shorter list.
    Answers with fewer than three words longer than two characters score 0
    without an LLM call. Every other answer is graded by the LLM against the
    top three chunks retrieved for its question (or the gathered context when
    retrieval returns nothing). Scores are clamped to 0-100 and averaged; the
    checkpoint is passed at an average of 70 or more.

    Returns a copy of ``state`` with ``score_percent`` and
    ``pass_threshold_met`` set (both None when there is nothing to grade).
    """
    cp = state["checkpoint"]
    questions = state.get("questions") or []
    answers = state.get("learner_answers") or []
    store = state.get("temp_vector_store") or {"chunks": state.get("processed_chunks", []), "vectors": None, "meta": {"embeddings_used": False}}

    if not questions or not answers:
        print(f"No questions or learner answers for {cp.id}; skipping verification.")
        state = dict(state)
        state["score_percent"] = None
        state["pass_threshold_met"] = None
        return state

    # Both lists are non-empty here, so there is at least one pair.
    n = min(len(questions), len(answers))

    # Guard against a missing/None gathered_context before slicing it.
    fallback_context = (state.get("gathered_context") or "")[:4000]

    scores = []

    def meaningful_word_count(text: str) -> int:
        if not text or not text.strip():
            return 0
        words = re.findall(r"\w+", text)
        return sum(1 for w in words if len(w) > 2)

    def parse_score(raw: str) -> int:
        # Preferred: the "score" field of the JSON object in the reply.
        m = re.search(r"\{.*\}", raw, re.DOTALL)
        if m:
            try:
                data = json.loads(m.group(0))
                return int(float(data.get("score", 0)))
            except (ValueError, TypeError, AttributeError, OverflowError):
                pass
        # Fallback: first 1-3 digit number anywhere in the reply, else 0.
        m2 = re.search(r"(\d{1,3})", raw)
        return int(m2.group(1)) if m2 else 0

    for i in range(n):
        q = questions[i]
        a = answers[i] or ""

        # quick strict rule: if answer has less than 3 meaningful words → score 0
        if meaningful_word_count(a) < 3:
            score_val = 0
            print(f"[Verify] {cp.id} Q{i+1} detected as empty/too-short → score: 0")
            scores.append(score_val)
            continue

        # Retrieve top-k chunks relevant to this question for precise grading
        top_chunks = retrieve_top_k(store, query=q, k=3)
        context_for_grading = "\n\n".join(top_chunks)[:4000] if top_chunks else fallback_context

        prompt = f"""
You are an AI tutor grading a learner's short answer. RESPOND ONLY WITH A SINGLE JSON OBJECT and NOTHING ELSE.

CONTEXT (ground truth reference):
\"\"\"{context_for_grading}\"\"\"\n

LEARNING OBJECTIVES:
{json.dumps(cp.objectives, ensure_ascii=False)}

QUESTION:
{q}

LEARNER ANSWER:
{a}

SCORING RULES (STRICT):
1. If the learner answer is empty, contains only punctuation, or has fewer than 3 meaningful words (words with length >2), return score 0.
2. Check whether the answer mentions or correctly paraphrases key concepts from CONTEXT and OBJECTIVES.
   - If the answer shows no relevant concepts from the context → return a score in [0,30].
   - If the answer shows partial understanding (mentions some correct concepts but misses important parts) → return a score in (30,70).
   - If the answer is largely correct and aligned with context → return a score in [70,100].
3. Be strict: do not give partial credit for blank or off-topic answers.
4. Output MUST be valid JSON only, with keys:
   {{ "score": <integer 0-100>, "explanation": "<short reason, <=30 words>" }}

Now produce the JSON for this question strictly following rules above.
"""
        raw = groq_chat(prompt).strip()

        score_val = max(0, min(100, parse_score(raw)))
        scores.append(score_val)
        print(f"[Verify] {cp.id} Q{i+1} score: {score_val}")

    avg_score = sum(scores) / len(scores) if scores else 0.0
    passed = avg_score >= 70.0
    state = dict(state)
    state["score_percent"] = avg_score
    state["pass_threshold_met"] = passed
    print(f"Overall quiz score for {cp.id}: {avg_score:.1f}% (pass >= 70%)")
    return state


def feynman_node(state: AgentState) -> AgentState:
    """Placeholder for the re-teaching step: print a notice and return ``state`` unchanged."""
    checkpoint = state["checkpoint"]
    notice = (
        f"[Feynman Placeholder] Learner did NOT meet 70% threshold for {checkpoint.id} ({checkpoint.topic}). "
        "In future, this node will trigger Feynman-style explanation & re-teaching."
    )
    print(notice)
    return state


def route_after_verification(state: AgentState) -> str:
    """Choose the branch after grading: ``"pass"`` at 70% or above, else ``"feynman"``.

    A missing or None ``score_percent`` is treated as 0.0.
    """
    achieved = state.get("score_percent") or 0.0
    return "pass" if achieved >= 70.0 else "feynman"



# Build graph

def build_graph():
    """Assemble and compile the checkpoint workflow graph.

    Flow: start_checkpoint -> gather_context -> validate_context ->
    process_context -> generate_questions -> verify_understanding, then either
    END (pass) or feynman_node, which is explicitly wired to END so the failing
    branch has a defined terminal edge.
    """
    g = StateGraph(AgentState)
    g.add_node("start_checkpoint", start_checkpoint)
    g.add_node("gather_context", gather_context)
    g.add_node("validate_context", validate_context)
    g.add_node("process_context", process_context)
    g.add_node("generate_questions", generate_questions)
    g.add_node("verify_understanding", verify_understanding)
    g.add_node("feynman_node", feynman_node)
    g.set_entry_point("start_checkpoint")
    g.add_edge("start_checkpoint", "gather_context")
    g.add_edge("gather_context", "validate_context")
    g.add_edge("validate_context", "process_context")
    g.add_edge("process_context", "generate_questions")
    g.add_edge("generate_questions", "verify_understanding")
    g.add_conditional_edges(
        "verify_understanding",
        route_after_verification,
        {
            "pass": END,
            "feynman": "feynman_node",
        },
    )
    # Terminate the failing branch explicitly instead of leaving a dead end.
    g.add_edge("feynman_node", END)
    return g.compile()


# Compiled once at import time and reused.
graph = build_graph()



# Helper: read multi-line input (end with an 'END' line)

def read_multiline(prompt_msg: str) -> str:
    """Collect several lines of console input as one string.

    Prints ``prompt_msg`` and a usage hint, then reads lines with ``input()``
    until a line whose stripped content is exactly ``END`` arrives or the
    input stream is exhausted (``EOFError``). The terminator line is not
    included.

    Returns:
        The collected lines joined with newlines, with leading and trailing
        whitespace of the whole text removed.
    """
    print(prompt_msg)
    print("Enter/Paste your text. End with a single line containing only: END")
    collected = []
    try:
        while (entry := input()).strip() != "END":
            collected.append(entry)
    except EOFError:
        # End of input is treated the same as the END sentinel.
        pass
    return "\n".join(collected).strip()



# Interactive per-checkpoint runner

def run_single_checkpoint_interactive(cp_id: str) -> dict:
    """Run one checkpoint end to end, with the learner answering at the console.

    Steps:
      1. Look up the checkpoint and ask for optional notes and PDF paths.
      2. Run the node functions in order up to question generation.
      3. Show the questions and read one multi-line answer per question.
      4. Score the answers with ``verify_understanding``; when the pass flag
         is not set, also run ``feynman_node``.

    Args:
        cp_id: Identifier passed to ``get_checkpoint_by_id``.

    Returns:
        A summary dict with keys ``cp_id``, ``topic``, ``context_score``,
        ``quiz_score``, ``passed``, ``sources``, ``questions`` and ``answers``.
    """
    cp = get_checkpoint_by_id(cp_id)
    print(f"\n--- Checkpoint {cp.id}: {cp.topic} ---")
    notes = read_multiline("Provide user notes for this checkpoint (or leave blank and type END):")
    raw_paths = input("Enter comma-separated PDF paths for this checkpoint (or leave blank): ").strip()
    # Blank fragments (including a completely blank reply) are discarded.
    pdf_paths = [piece for piece in (part.strip() for part in raw_paths.split(",")) if piece]

    # Fresh state for this checkpoint.
    state: AgentState = dict(
        cp_id=cp_id,
        checkpoint=None,
        user_notes=notes,
        user_pdfs=pdf_paths,
        gathered_context="",
        context_sources=[],
        relevance_score_model=None,
        refetch_attempted=False,
        score_meta=None,
        processed_chunks=[],
        questions=[],
        learner_answers=[],
        score_percent=None,
        pass_threshold_met=None,
        temp_vector_store=None,
    )

    # start -> gather -> validate -> process -> generate
    pipeline = (
        start_checkpoint,
        gather_context,
        validate_context,
        process_context,
        generate_questions,
    )
    for stage in pipeline:
        state = stage(state)

    # Show the questions, then gather one answer for each of them.
    questions = state.get("questions") or []
    print("\nGenerated questions:")
    for number, question in enumerate(questions, start=1):
        print(f"{number}. {question}")

    print("\nNow provide your answers. For each question, paste the answer and end with END.")
    state["learner_answers"] = [
        read_multiline(f"\nAnswer for Q{number}: {question}")
        for number, question in enumerate(questions, start=1)
    ]

    # Score the answers; a failed quiz is routed through the placeholder node.
    state = verify_understanding(state)
    if not state.get("pass_threshold_met"):
        state = feynman_node(state)

    # Map summary keys to the state keys they are read from.
    summary = {"cp_id": cp_id, "topic": cp.topic}
    field_map = (
        ("context_score", "relevance_score_model"),
        ("quiz_score", "score_percent"),
        ("passed", "pass_threshold_met"),
        ("sources", "context_sources"),
        ("questions", "questions"),
        ("answers", "learner_answers"),
    )
    for summary_key, state_key in field_map:
        summary[summary_key] = state.get(state_key)
    return summary



# Evaluation suite (automated tests for Q relevance & scoring)

def _generate_good_answer_for_question(context: str, question: str) -> str:
    """Ask the LLM for a reference-quality answer to ``question``.

    The first prompt supplies up to 3500 characters of ``context`` and asks
    for a 25-80 word answer. If the reply contains fewer than 15 words longer
    than two characters, a second, shorter prompt (up to 800 characters of
    context) is sent; its reply is used unless it is empty after stripping,
    in which case the first reply is kept.

    Returns:
        The stripped answer text.
    """
    primary_prompt = f"""
Use the CONTEXT below to write a concise, correct, and focused answer to the QUESTION.
Answer must be at least 25 words (to avoid short-answer penalties) and at most 80 words.
Keep it factual and directly relevant to the question.

CONTEXT:
\"\"\"{context[:3500]}\"\"\"\n

QUESTION:
{question}

Answer:
"""
    answer = groq_chat(primary_prompt).strip()

    # Count only words longer than two characters when judging length.
    substantial_words = sum(1 for token in re.findall(r"\w+", answer) if len(token) > 2)
    if substantial_words >= 15:
        return answer

    retry_prompt = f"Provide a direct, explanatory answer (>=25 words) to: {question}\nUsing: {context[:800]}"
    return groq_chat(retry_prompt).strip() or answer


def _score_answer_set(state: dict, answers: List[str]):
    """Grade one set of answers for an already-prepared checkpoint state.

    Works on a shallow copy so the caller's ``state`` keeps its own
    ``learner_answers``. Returns ``(score_percent, passed)`` with missing or
    ``None`` values normalised to ``0.0`` and ``False``.
    """
    trial = dict(state)
    trial["learner_answers"] = answers
    trial = verify_understanding(trial)
    return trial.get("score_percent") or 0.0, trial.get("pass_threshold_met") or False


def _print_evaluation_summary(overall: List[dict]) -> None:
    """Print the per-checkpoint results followed by the aggregate metrics."""
    print("\n=== Evaluation Summary ===\n")
    total_q_rel = 0.0
    good_pass_count = 0
    bad_fail_count = 0
    emb_used_count = 0
    for r in overall:
        print(f"Checkpoint: {r['cp_id']} - {r['topic']}")
        print(f"  Questions: {r['num_questions']}, Q-rel: {r['q_rel']:.2f}")
        print(f"  Context score (1-5): {r['context_score']}")
        print(f"  Good answers -> score: {r['good_score']:.1f}%, pass: {r['good_pass']}")
        print(f"  Bad answers  -> score: {r['bad_score']:.1f}%, pass: {r['bad_pass']}")
        print(f"  Embeddings used: {r.get('embeddings_used')}")
        print()
        total_q_rel += r['q_rel']
        if r['good_pass']:
            good_pass_count += 1
        if not r['bad_pass']:
            bad_fail_count += 1
        if r.get('embeddings_used'):
            emb_used_count += 1

    # Avoid dividing by zero when no checkpoint was evaluated.
    n = len(overall) or 1
    print("--- Overall Metrics ---")
    print(f"Average question relevance (fraction): {total_q_rel / n:.3f}")
    print(f"Good answers pass-rate (should be high): {good_pass_count / n * 100:.1f}%")
    print(f"Bad answers fail-rate (should be high): {bad_fail_count / n * 100:.1f}%")
    print(f"Embeddings used in checkpoints: {emb_used_count}/{n}\n")


def run_evaluation_suite():
    """
    Runs automated evaluation across all CHECKPOINTS.
    For each checkpoint:
      - gather/process/generate questions (no user notes or PDFs are supplied)
      - create 'good' answers via LLM (padded with filler if under 30 words)
      - create 'bad' answers (the fixed text "I don't know.")
      - run verify_understanding for both sets and record metrics

    Returns:
        A list with one dict per checkpoint containing ``cp_id``, ``topic``,
        ``num_questions``, ``q_rel``, ``context_score``, ``good_score``,
        ``good_pass``, ``bad_score``, ``bad_pass`` and ``embeddings_used``.
        ``q_rel`` is 1.0 when at least one question was generated and 0.0
        otherwise; it does not measure semantic relevance.
    """
    overall = []
    print("Running automated evaluation suite for all checkpoints...\n")
    for cp in CHECKPOINTS:
        # initialize state
        state: AgentState = {
            "cp_id": cp.id,
            "checkpoint": None,
            "user_notes": "",  # no notes; will use web search fallback
            "user_pdfs": [],
            "gathered_context": "",
            "context_sources": [],
            "relevance_score_model": None,
            "refetch_attempted": False,
            "score_meta": None,
            "processed_chunks": [],
            "questions": [],
            "learner_answers": [],
            "score_percent": None,
            "pass_threshold_met": None,
            "temp_vector_store": None,
        }

        state = start_checkpoint(state)
        state = gather_context(state)
        state = validate_context(state)
        state = process_context(state)
        state = generate_questions(state)

        questions = state.get("questions") or []
        # `or []` guards against the key being present but set to None,
        # which would make str.join raise TypeError.
        context_text = state.get("gathered_context") or "\n".join(state.get("processed_chunks") or []) or ""
        store = state.get("temp_vector_store") or {"chunks": [], "vectors": None, "meta": {"embeddings_used": False}}

        # Double-confirm embedding usage for this checkpoint
        emb_used = store.get("vectors") is not None
        print(f"[Eval] Checkpoint {cp.id} embeddings_used = {emb_used}")
        if emb_used:
            try:
                embedding_debug_print(store, label=f"eval-cp={cp.id}")
            except Exception as exc:
                # Debug output is best-effort: report the problem and carry on
                # instead of hiding it.
                print(f"[Eval] embedding debug print failed for {cp.id}: {exc}")

        # Build 'good' answers using the LLM (ensured to be long enough)
        good_answers = []
        for q in questions:
            ans = _generate_good_answer_for_question(context_text, q)
            if len(re.findall(r"\w+", ans)) < 30:
                ans = ans + " " + ("This answer expands on the main points to ensure full coverage of the objective. " * 2)
            good_answers.append(ans)

        # Build 'bad' answers (short/off-topic)
        bad_answers = ["I don't know." for _ in questions]

        # Score both answer sets against the same prepared state.
        good_score, good_pass = _score_answer_set(state, good_answers)
        bad_score, bad_pass = _score_answer_set(state, bad_answers)

        q_rel = 1.0 if questions else 0.0

        overall.append({
            "cp_id": cp.id,
            "topic": cp.topic,
            "num_questions": len(questions),
            "q_rel": q_rel,
            "context_score": state.get("relevance_score_model"),
            "good_score": good_score,
            "good_pass": good_pass,
            "bad_score": bad_score,
            "bad_pass": bad_pass,
            "embeddings_used": emb_used,
        })

    _print_evaluation_summary(overall)

    return overall



# Interactive run for multiple checkpoints

def _md_escape(text) -> str:
    """Escape ``|`` so the text cannot split a Markdown table cell."""
    return str(text).replace("|", "\\|")


def interactive_run():
    """Run one or more checkpoints interactively and show a summary table.

    Lists the available checkpoints, asks which ids to run (blank or ``all``
    selects every checkpoint), and runs each through
    ``run_single_checkpoint_interactive``. A checkpoint that raises is
    reported and skipped so the remaining ones still run. The results are
    rendered as a Markdown table (plain text if rich display fails), and the
    raw LLM log is then written to ``all_llm_raw_output.pdf``.

    Missing scores are shown as ``-`` in both score columns, and ``|`` inside
    cell text is escaped so it cannot break the table layout.
    """
    print("Interactive Milestone 2 runner.")
    print("Available checkpoints:")
    for cp in CHECKPOINTS:
        print(f" - {cp.id}: {cp.topic}")

    chosen = input("Enter comma-separated checkpoint ids to run (or 'all'): ").strip()
    if chosen.lower() == "all" or not chosen:
        ids = [cp.id for cp in CHECKPOINTS]
    else:
        ids = [c.strip() for c in chosen.split(",") if c.strip()]

    results = []
    for cp_id in ids:
        try:
            res = run_single_checkpoint_interactive(cp_id)
            results.append(res)
        except Exception as e:
            print(f"Error running {cp_id}: {e}")

    # Build markdown summary table
    lines = []
    lines.append("### Summary Table\n")
    lines.append("| Topic | Objectives | Sources | Context Score (1–5) | Quiz Score (%) | Pass (>=70%) |")
    lines.append("|-------|------------|---------|---------------------|----------------|--------------|")
    for r in results:
        cp = get_checkpoint_by_id(r["cp_id"])
        objectives_str = "<br>".join(f"- {_md_escape(o)}" for o in cp.objectives)
        sources_str = ", ".join(_md_escape(s) for s in r["sources"]) if r["sources"] else "None"
        context_score = r["context_score"]
        # Use the same "-" placeholder as the quiz column when no score exists.
        context_display = "-" if context_score is None else context_score
        quiz_score = r["quiz_score"]
        quiz_display = "-" if quiz_score is None else f"{quiz_score:.1f}"
        passed = r["passed"]
        passed_str = "✅" if passed else ("❌" if passed is not None else "-")
        lines.append(
            f"| {_md_escape(r['topic'])} | {objectives_str} | {sources_str} | "
            f"{context_display} | {quiz_display} | {passed_str} |"
        )

    table_md = "\n".join(lines)
    try:
        display(Markdown(table_md))
    except Exception:
        # Fall back to plain text when rich display fails.
        print(table_md)

    # Save ALL raw LLM prompt+response logs into ONE PDF
    save_all_raw_to_one_pdf("all_llm_raw_output.pdf")


# Main

if __name__ == "__main__":
    # Entry point: show the menu, then dispatch to the selected runner.
    print("Milestone 1 & 2 runner")
    print("Options:\n  1) interactive run (generate questions & answer interactively)\n  2) run evaluation suite (automated tests for Q relevance & scoring)")
    # A blank reply falls back to option 1; anything other than "2" also runs it.
    selection = input("Enter 1 or 2 (default 1): ").strip() or "1"
    runner = run_evaluation_suite if selection == "2" else interactive_run
    runner()


[Embedding] Loading SentenceTransformer 'all-MiniLM-L6-v2' ... (this may take a few seconds)
[Embedding] Model loaded:all-MiniLM-L6-v2
Milestone 1 & 2 runner
Options:
  1) interactive run (generate questions & answer interactively)
  2) run evaluation suite (automated tests for Q relevance & scoring)
Enter 1 or 2 (default 1): 2
Running automated evaluation suite for all checkpoints...


[Gathered Context] Source(s): web_search
--------------------------------------------------------------------------------
**Forward Propagation in Neural Networks**

**Objective:** Calculate the activations at each neuron for each successive hidden layer until arriving at the output layer.

**Key Components:**

1. **Neural Network Layers:** Input, Hidden, and Output layers.
2. **Neurons:** Each layer consists of several neurons that receive and process input data.
3. **Forward Propagation:** The process of passing input data through the network, generating an output (prediction) at each layer.

**Step-b

  similarity = float(ctx_vec @ ref_vec.T)


Generated 3 questions for cp1
[Eval] Checkpoint cp1 embeddings_used = True
[Embedding Confirm] eval-cp=cp1 embeddings used: True | vector shape: (1, 384)
[Embedding Confirm] sample vec[0][:6] ~ [-0.0344725176692009, -0.0859455019235611, 0.007470434997230768, -0.011288967914879322, 0.0015978480223566294, 0.035555195063352585]
[Verify] cp1 Q1 score: 90
[Verify] cp1 Q2 score: 90
[Verify] cp1 Q3 score: 90
Overall quiz score for cp1: 90.0% (pass >= 70%)
[Verify] cp1 Q1 detected as empty/too-short → score: 0
[Verify] cp1 Q2 detected as empty/too-short → score: 0
[Verify] cp1 Q3 detected as empty/too-short → score: 0
Overall quiz score for cp1: 0.0% (pass >= 70%)

[Gathered Context] Source(s): web_search
--------------------------------------------------------------------------------
**Gradient Descent: An Optimization Algorithm for Machine Learning**

**Key Points:**

1. **Definition:** Gradient descent is an iterative optimization algorithm used to find the best weights and bias for a linea