<a href="https://colab.research.google.com/github/vamsiChittem/story-teller/blob/main/story_teller_with_llc.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
# Quick probe to make sure search & fetch are OK
print("SEARCH:", search_web_safe("first Moon landing", 3))
print("FETCH:", fetch_clean_safe("https://en.wikipedia.org/wiki/Apollo_11")[:300], "...")


In [None]:
# --- Install deps (safe for Colab Py3.10/3.11) ---
!pip -q install duckduckgo-search==5.3.1 trafilatura==1.12.2 wikipedia-api

# --- Imports & helpers ---
import re, html
import trafilatura as tf
from duckduckgo_search import DDGS
import wikipediaapi
from urllib.parse import urlparse

def search_web_safe(query: str, n: int = 5):
    """DuckDuckGo search with Wikipedia fallback."""
    cleaned = []
    try:
        with DDGS() as ddgs:
            res = list(ddgs.text(query, max_results=int(n), region="wt-wt", safesearch="moderate"))
        for r in res:
            url = (r.get("href") or r.get("url") or "").strip()
            title = (r.get("title") or "").strip()
            snippet = (r.get("body") or "").strip()
            if url and title:
                cleaned.append({"title": title, "url": url, "snippet": snippet})
    except Exception as e:
        print("DDG search error:", e)

    # fallback to Wikipedia if needed
    if not cleaned:
        try:
            wiki = wikipediaapi.Wikipedia("en")
            page = wiki.page(query)
            if page.exists():
                cleaned.append({"title": page.title, "url": page.fullurl, "snippet": page.summary[:240]})
        except Exception as e:
            print("Wikipedia fallback error:", e)
    return cleaned

def fetch_clean_safe(url: str, timeout: int = 15) -> str:
    """Fetch + clean page text; uses Wikipedia API when URL is a wiki page."""
    try:
        downloaded = tf.fetch_url(url, timeout=timeout, no_ssl=True)
        if downloaded:
            text = tf.extract(downloaded, include_comments=False, include_tables=False) or ""
            text = html.unescape(re.sub(r"\s+", " ", text)).strip()
            if len(text) >= 200:
                return text
    except Exception as e:
        print("Trafilatura fetch error:", e)

    # Wikipedia API fallback
    try:
        if "wikipedia.org" in url:
            title = url.split("/wiki/")[-1].replace("_", " ")
            wiki = wikipediaapi.Wikipedia("en")
            page = wiki.page(title)
            if page.exists():
                return re.sub(r"\s+", " ", page.text).strip()
    except Exception as e:
        print("Wikipedia API read error:", e)
    return ""

def make_source_index(sources):
    return "\n".join([f"[{i}] {s['title']} — {urlparse(s['url']).netloc}" for i,s in enumerate(sources, start=1)])


In [None]:
print("SEARCH:", search_web_safe("first Moon landing", 3))
print("FETCH:", fetch_clean_safe("https://en.wikipedia.org/wiki/Apollo_11")[:300], "...")


DDG search error: https://duckduckgo.com/ 202 Ratelimit
Wikipedia fallback error: Please, be nice to Wikipedia and specify user agent - https://meta.wikimedia.org/wiki/User-Agent_policy. Current user_agent: 'en' is not sufficient. Use Wikipedia(user_agent='your-user-agent', language='en')
SEARCH: []
Trafilatura fetch error: fetch_url() got an unexpected keyword argument 'timeout'
Wikipedia API read error: Please, be nice to Wikipedia and specify user agent - https://meta.wikimedia.org/wiki/User-Agent_policy. Current user_agent: 'en' is not sufficient. Use Wikipedia(user_agent='your-user-agent', language='en')
FETCH:  ...


In [None]:
# ============================================================
# 🌐📚 Web-Grounded Story Teller — Colab (Py3.10+)
# Wikipedia user-agent set to your email: vamsikrishnavk098@gmail.com
# ============================================================

# --- System & Installs ---
!sudo apt -q update && sudo apt -q install -y ffmpeg >/dev/null
!pip -q install duckduckgo-search==5.3.1 trafilatura==1.12.2 wikipedia-api==0.6.0
!pip -q install "transformers>=4.44.0" "accelerate>=0.34.0" "gradio>=4.44.0" sentencepiece safetensors gTTS

# --- Imports & Config ---
import os, re, html, time, traceback, torch
import gradio as gr
import trafilatura as tf
from urllib.parse import urlparse
from typing import List, Dict, Tuple
from duckduckgo_search import DDGS
import wikipediaapi
from gtts import gTTS
from transformers import AutoTokenizer, AutoModelForCausalLM

SEED = 42
torch.manual_seed(SEED)
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
DTYPE  = torch.float16 if DEVICE=="cuda" else torch.float32

# ✅ Polite Wikipedia user agent with your email
WIKI_AGENT = "WebGroundedStoryTeller/1.0 (Colab; contact: vamsikrishnavk098@gmail.com)"
wiki = wikipediaapi.Wikipedia(language="en", user_agent=WIKI_AGENT)

MODEL_ID = "TinyLlama/TinyLlama-1.1B-Chat-v1.0"
GEN_KW = dict(max_new_tokens=512, do_sample=True, temperature=0.8, top_p=0.95, top_k=50, repetition_penalty=1.1)
SAFETY_BAN = ["sexual","explicit","gore","suicide","self-harm","hate","slur","violence","terrorism","extremist"]

print("Loading LLM…")
tokenizer = AutoTokenizer.from_pretrained(MODEL_ID, use_fast=True)
model = AutoModelForCausalLM.from_pretrained(MODEL_ID, torch_dtype=DTYPE, low_cpu_mem_usage=True).to(DEVICE).eval()

def chat_prompt(system: str, user: str) -> str:
    return f"<|system|>\n{system}\n<|user|>\n{user}\n<|assistant|>\n"

@torch.inference_mode()
def generate_llm(text: str, **kw) -> str:
    ids = tokenizer(text, return_tensors="pt").to(DEVICE)
    out = model.generate(**ids, **{**GEN_KW, **kw}, pad_token_id=tokenizer.eos_token_id, eos_token_id=tokenizer.eos_token_id)
    s = tokenizer.decode(out[0], skip_special_tokens=True)
    return s.split("<|assistant|>")[-1].strip() if "<|assistant|>" in s else s.strip()

# ---------- Wikipedia search (primary) ----------
def search_wikipedia(query: str, n: int = 5) -> List[Dict]:
    titles = wiki.search(query)
    cleaned = []
    for title in (titles or [])[:n]:
        page = wiki.page(title)
        if page.exists():
            cleaned.append({"title": page.title, "url": page.fullurl, "snippet": page.summary[:240] if hasattr(page,"summary") else ""})
    return cleaned

# ---------- DuckDuckGo search (optional; may rate-limit) ----------
def search_ddg(query: str, n: int = 5, retries: int = 2, backoff: float = 1.5) -> List[Dict]:
    for attempt in range(retries + 1):
        try:
            with DDGS() as ddgs:
                res = list(ddgs.text(query, max_results=int(n), region="wt-wt", safesearch="moderate"))
            cleaned = []
            for r in res:
                url = (r.get("href") or r.get("url") or "").strip()
                title = (r.get("title") or "").strip()
                snippet = (r.get("body") or "").strip()
                if url and title:
                    cleaned.append({"title": title, "url": url, "snippet": snippet})
            if cleaned:
                return cleaned
        except Exception as e:
            print(f"DDG attempt {attempt+1}/{retries+1} error:", e)
        time.sleep(backoff * (attempt + 1))
    return []

# ---------- Fetch with Wikipedia API fallback ----------
def fetch_clean_safe(url: str, timeout: int = 15) -> str:
    try:
        fetched = tf.fetch_url(url, timeout=timeout, no_ssl=True)
        if fetched:
            text = tf.extract(fetched, include_comments=False, include_tables=False) or ""
            text = html.unescape(re.sub(r"\s+", " ", text)).strip()
            if len(text) >= 200:
                return text
    except Exception as e:
        print("Trafilatura fetch error:", e)
    try:
        if "wikipedia.org" in url:
            title = url.split("/wiki/")[-1].replace("_", " ")
            page = wiki.page(title)
            if page.exists():
                return re.sub(r"\s+", " ", page.text).strip()
    except Exception as e:
        print("Wikipedia API read error:", e)
    return ""

# ---------- Helpers ----------
def make_source_index(sources: List[Dict]) -> str:
    return "\n".join([f"[{i}] {s['title']} — {urlparse(s['url']).netloc}" for i,s in enumerate(sources, start=1)])

def urls_block(sources: List[Dict]) -> str:
    return "\n".join([f"[{i}] {s['url']}" for i,s in enumerate(sources, start=1)])

def chunk_text(text: str, max_chars: int = 2000):
    return [text[i:i+max_chars] for i in range(0, len(text), max_chars)]

# ---------- Fact distillation ----------
FACT_SYS = ("You are a careful research assistant. Given content from multiple web pages, "
            "extract concise factual bullets about the topic. Avoid opinions. "
            "End each bullet with the source index in brackets like [1].")

def distill_facts(topic: str, pages: List[Tuple[int, str]], max_bullets: int = 12) -> str:
    per_page = []
    for idx, content in pages:
        if not content:
            continue
        for j, ch in enumerate(chunk_text(content, 1600), start=1):
            user = (f"Topic: {topic}\nSource Index: [{idx}]\nContent (part {j}):\n{ch}\n\n"
                    f"Extract 3–5 short factual bullets, each ending with [{idx}].")
            per_page.append(generate_llm(chat_prompt(FACT_SYS, user), max_new_tokens=220))
    merged = "\n".join(per_page) if per_page else "No content."
    user = (f"Topic: {topic}\nBelow are bullets from several sources with citations like [1], [2]. "
            f"Merge, deduplicate, and produce at most {max_bullets} crisp bullets, preserving citations "
            f"(allow [1,2] when a fact appears in multiple sources).\n\n{merged}")
    return generate_llm(chat_prompt(FACT_SYS, user), max_new_tokens=320)

# ---------- Grounded story ----------
STORY_SYS = ("You are a creative but faithful storyteller. Write a coherent story grounded ONLY in the factual bullets. "
             "Do not invent specific facts beyond the bullets. You may add safe narrative glue and dialogue. "
             "End with a 'Sources' section listing the citation numbers you used (e.g., [1], [2]).")

def safety_ok(text: str) -> bool:
    t = text.lower()
    return not any(b in t for b in SAFETY_BAN)

def craft_story(bullets: str, style: str, audience: str, length_hint: str) -> str:
    user = (f"Audience: {audience}\nStyle & Tone: {style}\nLength Hint: {length_hint}\n\n"
            f"FACT BULLETS:\n{bullets}\n\nWrite the story now. End with a 'Sources' section listing citation numbers.")
    return generate_llm(chat_prompt(STORY_SYS, user), max_new_tokens=680)

# ---------- gTTS narration ----------
def tts_mp3(text: str, path: str = "/content/story.mp3") -> str:
    clean = re.sub(r"\s+", " ", text).strip() or "There is no story to narrate."
    if len(clean) > 5000: clean = clean[:5000] + " ..."
    gTTS(clean).save(path)
    return path

# ---------- Main pipeline ----------
def run_pipeline(topic: str, style: str, audience: str, length_hint: str,
                 n_results: int, narrate: bool, wiki_only: bool):
    try:
        topic = (topic or "").strip()
        if not topic:
            return "Please enter a topic.", "", "", None

        # Sources
        if wiki_only:
            results = search_wikipedia(topic, n=int(n_results))
        else:
            results = search_ddg(topic, n=int(n_results)) or search_wikipedia(topic, n=int(n_results))
        if not results:
            return ("❌ No sources found. Enable 'Wikipedia-only', reduce results to 3–4, or try a broader topic."), "", "", None

        texts = [(i, fetch_clean_safe(r["url"])) for i, r in enumerate(results, start=1)]
        src_idx = make_source_index(results)
        facts = distill_facts(topic, texts, max_bullets=12) or "⚠️ Could not distill facts."
        story = craft_story(facts, style=style, audience=audience, length_hint=length_hint)
        if not safety_ok(story):
            story = ("⚠️ Safety filter: The topic/content may be unsuitable. Try a different topic or a kid-friendly angle.")
        story_full = f"{story}\n\nSources (URLs):\n{urls_block(results)}"

        audio_path = None
        if narrate and story and "⚠️" not in story and "❌" not in story:
            try:
                main_text = story.split("\nSources")[0]
                audio_path = tts_mp3(main_text, "/content/story.mp3")
            except Exception as e:
                print("TTS error:", e)
                audio_path = None

        return story_full, src_idx, facts, audio_path

    except Exception as e:
        print("=== PIPELINE ERROR ===")
        print("".join(traceback.format_exc()))
        return f"❌ Runtime error: {e}", "", "", None

# ---------- UI ----------
try:
    demo.close()
except Exception:
    pass

with gr.Blocks(title="🌐📚 Web-Grounded Story Teller") as demo:
    gr.Markdown("# 🌐📚 Web-Grounded Story Teller\nEnter any topic. I will **search sources**, extract **facts**, and weave a **grounded story** with **citations**.")
    with gr.Row():
        with gr.Column():
            topic = gr.Textbox(label="Your topic / request", value="A cozy story about the first Moon landing for kids")
            audience = gr.Textbox(label="Audience", value="Kids (7–10)")
            style = gr.Textbox(label="Style & Tone", value="Cozy, positive, simple language")
            length_hint = gr.Textbox(label="Length hint", value="~500–700 words")
            n_results = gr.Slider(1, 10, value=3, step=1, label="Results to use")
            wiki_only = gr.Checkbox(value=True, label="Use Wikipedia-only (avoid rate limits)")
            narrate = gr.Checkbox(value=True, label="🔊 Narrate the story (MP3 via gTTS)")
            go = gr.Button("✨ Tell me a story")
        with gr.Column():
            story_out = gr.Markdown(label="Story with citations")
            gr.Markdown("### Sources (index)")
            src_out = gr.Markdown()
            gr.Markdown("### Distilled Fact Bullets")
            facts_out = gr.Markdown()
            audio_out = gr.Audio(label="Narration (MP3)", interactive=False)

    go.click(run_pipeline,
             inputs=[topic, style, audience, length_hint, n_results, narrate, wiki_only],
             outputs=[story_out, src_out, facts_out, audio_out])

demo.queue().launch(share=True, debug=False)


Hit:1 https://developer.download.nvidia.com/compute/cuda/repos/ubuntu2204/x86_64  InRelease
Hit:2 http://archive.ubuntu.com/ubuntu jammy InRelease
Hit:3 http://security.ubuntu.com/ubuntu jammy-security InRelease
Hit:4 https://cloud.r-project.org/bin/linux/ubuntu jammy-cran40/ InRelease
Hit:5 https://cli.github.com/packages stable InRelease
Hit:6 https://r2u.stat.illinois.edu/ubuntu jammy InRelease
Hit:7 http://archive.ubuntu.com/ubuntu jammy-updates InRelease
Hit:8 http://archive.ubuntu.com/ubuntu jammy-backports InRelease
Hit:9 https://ppa.launchpadcontent.net/deadsnakes/ppa/ubuntu jammy InRelease
Hit:10 https://ppa.launchpadcontent.net/graphics-drivers/ppa/ubuntu jammy InRelease
Hit:11 https://ppa.launchpadcontent.net/ubuntugis/ppa/ubuntu jammy InRelease
Reading package lists...
Building dependency tree...
Reading state information...
38 packages can be upgraded. Run 'apt list --upgradable' to see them.
[1;33mW: [0mSkipping acquire of configured file 'main/source/Sources' as reposi



In [1]:
# ============================================================
# 🌐📚 Web-Grounded Story Teller — OOM-safe, no bitsandbytes
# - Wikipedia REST search (polite UA with your email)
# - Optional DuckDuckGo fallback
# - Hierarchical fact distillation to keep prompts short
# - TinyLlama 1.1B Chat in FP16 (no bitsandbytes)
# - OOM-safe generate (truncates + retries)
# - gTTS narration + Gradio UI
# ============================================================

# ---------------- System & Installs ----------------
!sudo apt -q update && sudo apt -q install -y ffmpeg >/dev/null
!pip -q install duckduckgo-search==5.3.1 trafilatura==1.12.2
!pip -q install "transformers>=4.44.0" "accelerate>=0.34.0" "gradio>=4.44.0" sentencepiece safetensors
!pip -q install gTTS requests

# ---------------- Imports & Config ----------------
import os, re, html, time, json, traceback, requests, torch
import gradio as gr
import trafilatura as tf
from urllib.parse import urlparse
from typing import List, Dict, Tuple
from duckduckgo_search import DDGS
from gtts import gTTS
from transformers import AutoTokenizer, AutoModelForCausalLM

# Make CUDA allocator less fragile (helps fragmentation)
os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:True"

SEED = 42
torch.manual_seed(SEED)
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
DTYPE  = torch.float16 if DEVICE=="cuda" else torch.float32

# Wikipedia REST with your email as UA (be polite 🫶)
WIKI_AGENT = "WebGroundedStoryTeller/1.0 (Colab; contact: vamsikrishnavk098@gmail.com)"
session = requests.Session()
session.headers.update({"User-Agent": WIKI_AGENT, "Accept": "application/json"})

# Model & generation limits (kept modest to avoid OOM)
MODEL_ID = "TinyLlama/TinyLlama-1.1B-Chat-v1.0"
MAX_PROMPT_TOKENS = 1024         # cap input length to model
FACT_GEN_TOKENS   = 200          # per generation during distillation
STORY_GEN_TOKENS  = 360          # final story length (lower = safer)
PAGE_BULLETS_MAX  = 5            # bullets per page after compression
FINAL_BULLETS_MAX = 10           # total bullets used for story

GEN_KW = dict(
    do_sample=True,
    temperature=0.8,
    top_p=0.95,
    top_k=50,
    repetition_penalty=1.1
)

SAFETY_BAN = ["sexual","explicit","gore","suicide","self-harm","hate","slur","violence","terrorism","extremist"]

# ---------------- Load FP16 Model (no bitsandbytes) ----------------
print("Loading TinyLlama 1.1B Chat (FP16)…")
tokenizer = AutoTokenizer.from_pretrained(MODEL_ID, use_fast=True)
model = AutoModelForCausalLM.from_pretrained(
    MODEL_ID,
    torch_dtype=DTYPE,
    low_cpu_mem_usage=True,
).to(DEVICE).eval()

def chat_prompt(system: str, user: str) -> str:
    return f"<|system|>\n{system}\n<|user|>\n{user}\n<|assistant|>\n"

def num_tokens(txt: str) -> int:
    return len(tokenizer(txt, return_tensors="pt").input_ids[0])

def truncate_to_tokens(txt: str, max_tokens: int) -> str:
    ids = tokenizer(txt, return_tensors="pt").input_ids[0]
    if len(ids) <= max_tokens:
        return txt
    # keep the tail (latest instructions/context)
    ids = ids[-max_tokens:]
    return tokenizer.decode(ids, skip_special_tokens=True)

@torch.inference_mode()
def safe_generate(prompt: str, max_new_tokens: int) -> str:
    """Generate with OOM fallback + prompt truncation."""
    try:
        encoded = tokenizer(prompt, return_tensors="pt").to(DEVICE)
        out = model.generate(
            **encoded,
            max_new_tokens=max_new_tokens,
            pad_token_id=tokenizer.eos_token_id,
            eos_token_id=tokenizer.eos_token_id,
            use_cache=True,
            **GEN_KW
        )
        s = tokenizer.decode(out[0], skip_special_tokens=True)
        return s.split("<|assistant|>")[-1].strip() if "<|assistant|>" in s else s.strip()
    except torch.cuda.OutOfMemoryError:
        torch.cuda.empty_cache()
        short = truncate_to_tokens(prompt, MAX_PROMPT_TOKENS // 2)
        encoded = tokenizer(short, return_tensors="pt").to(DEVICE)
        out = model.generate(
            **encoded,
            max_new_tokens=max_new_tokens // 2,
            pad_token_id=tokenizer.eos_token_id,
            eos_token_id=tokenizer.eos_token_id,
            use_cache=True,
            **GEN_KW
        )
        s = tokenizer.decode(out[0], skip_special_tokens=True)
        return s.split("<|assistant|>")[-1].strip() if "<|assistant|>" in s else s.strip()

# ---------------- Wikipedia REST search ----------------
def wikipedia_search(query: str, limit: int = 5) -> List[Dict]:
    try:
        url = "https://en.wikipedia.org/w/rest.php/v1/search/page"
        r = session.get(url, params={"q": query, "limit": int(limit)}, timeout=10)
        if not r.ok:
            print("Wikipedia search HTTP", r.status_code, r.text[:200])
            return []
        data = r.json()
        pages = data.get("pages", []) or []
        cleaned = []
        for p in pages[:limit]:
            title   = p.get("title") or ""
            excerpt = re.sub("<[^>]+>", "", (p.get("excerpt") or ""))
            if title:
                cleaned.append({
                    "title": title,
                    "url": f"https://en.wikipedia.org/wiki/{title.replace(' ', '_')}",
                    "snippet": excerpt[:240]
                })
        return cleaned
    except Exception as e:
        print("Wikipedia search error:", e)
        return []

def wikipedia_fetch_plain_by_title(title: str) -> str:
    try:
        api = "https://en.wikipedia.org/w/api.php"
        params = {
            "action": "query",
            "prop": "extracts",
            "explaintext": 1,
            "exsectionformat": "plain",
            "redirects": 1,
            "titles": title,
            "format": "json"
        }
        r = session.get(api, params=params, timeout=12)
        if not r.ok:
            print("Wikipedia extract HTTP", r.status_code, r.text[:200])
            return ""
        data = r.json()
        pages = data.get("query", {}).get("pages", {})
        if not pages: return ""
        page = next(iter(pages.values()))
        text = page.get("extract", "") or ""
        return re.sub(r"\s+", " ", text).strip()
    except Exception as e:
        print("Wikipedia extract error:", e)
        return ""

# ---------------- DuckDuckGo (optional) ----------------
def search_ddg(query: str, n: int = 5, retries: int = 2, backoff: float = 1.5) -> List[Dict]:
    for attempt in range(retries + 1):
        try:
            with DDGS() as ddgs:
                res = list(ddgs.text(query, max_results=int(n), region="wt-wt", safesearch="moderate"))
            cleaned = []
            for r in res:
                url = (r.get("href") or r.get("url") or "").strip()
                title = (r.get("title") or "").strip()
                snippet = (r.get("body")  or "").strip()
                if url and title:
                    cleaned.append({"title": title, "url": url, "snippet": snippet})
            if cleaned:
                return cleaned
        except Exception as e:
            print(f"DDG attempt {attempt+1}/{retries+1} error:", e)
        time.sleep(backoff * (attempt + 1))
    return []

# ---------------- Fetch & clean ----------------
def fetch_clean_safe(url: str, timeout: int = 15) -> str:
    if "wikipedia.org/wiki/" in url:
        title = url.split("/wiki/")[-1].replace("_", " ")
        return wikipedia_fetch_plain_by_title(title)
    try:
        downloaded = tf.fetch_url(url, timeout=timeout, no_ssl=True)
        if downloaded:
            text = tf.extract(downloaded, include_comments=False, include_tables=False) or ""
            text = html.unescape(re.sub(r"\s+", " ", text)).strip()
            if len(text) >= 200:
                return text
    except Exception as e:
        print("Trafilatura fetch error:", e)
    return ""

# ---------------- Helpers ----------------
def make_source_index(sources: List[Dict]) -> str:
    return "\n".join([f"[{i}] {s['title']} — {urlparse(s['url']).netloc}" for i,s in enumerate(sources, start=1)])

def urls_block(sources: List[Dict]) -> str:
    return "\n".join([f"[{i}] {s['url']}" for i,s in enumerate(sources, start=1)])

def chunk_text(text: str, max_chars: int = 1800):
    return [text[i:i+max_chars] for i in range(0, len(text), max_chars)]

# ---------------- Fact distillation (hierarchical) ----------------
FACT_SYS = (
    "You are a careful research assistant. From each page chunk, extract 2-3 short, factual bullets. "
    "Avoid opinions. End each bullet with the source index like [1]."
)
MERGE_SYS = (
    "You merge bullets. Deduplicate and keep the strongest facts. Preserve citations like [1] or [1,2]. "
    f"Limit to at most {PAGE_BULLETS_MAX} bullets for this page."
)
FINAL_SYS = (
    "You merge page summaries into final bullets. Deduplicate across pages, preserve citations, "
    f"and output at most {FINAL_BULLETS_MAX} crisp bullets."
)

def distill_facts(topic: str, pages: List[Tuple[int,str]]) -> str:
    page_summaries = []
    for idx, content in pages:
        if not content:
            continue
        # 1) Chunk and extract small bullet sets per chunk
        chunk_bullets = []
        for j, ch in enumerate(chunk_text(content, 1500), start=1):
            user = (
                f"Topic: {topic}\nSource Index: [{idx}]\n"
                f"Content part {j}:\n{ch}\n\n"
                "Extract 2-3 short factual bullets. Each bullet MUST end with "
                f"[{idx}]."
            )
            prompt = chat_prompt(FACT_SYS, truncate_to_tokens(user, MAX_PROMPT_TOKENS))
            out = safe_generate(prompt, max_new_tokens=FACT_GEN_TOKENS)
            chunk_bullets.append(out)

        # 2) Merge within the page down to PAGE_BULLETS_MAX
        merged_text = "\n".join(chunk_bullets) if chunk_bullets else ""
        if merged_text.strip():
            user = f"Topic: {topic}\nBullets from source [{idx}]:\n{merged_text}\n\nMerge now."
            prompt = chat_prompt(MERGE_SYS, truncate_to_tokens(user, MAX_PROMPT_TOKENS))
            page_summary = safe_generate(prompt, max_new_tokens=FACT_GEN_TOKENS)
            lines = [l.strip() for l in page_summary.splitlines() if l.strip()]
            page_summaries.append("\n".join(lines[:PAGE_BULLETS_MAX]))

    # 3) Final merge across pages to FINAL_BULLETS_MAX
    merged_all = "\n\n".join(page_summaries) if page_summaries else "No content."
    user = f"Topic: {topic}\nPage bullet summaries:\n{merged_all}\n\nProduce final bullets."
    prompt = chat_prompt(FINAL_SYS, truncate_to_tokens(user, MAX_PROMPT_TOKENS))
    final_bullets = safe_generate(prompt, max_new_tokens=FACT_GEN_TOKENS)
    lines = [l.strip() for l in final_bullets.splitlines() if l.strip()]
    return "\n".join(lines[:FINAL_BULLETS_MAX])

# ---------------- Grounded story ----------------
STORY_SYS = (
    "You are a creative but faithful storyteller. Write a coherent story grounded ONLY in the factual bullets. "
    "Do not invent specific facts beyond the bullets. You may add safe narrative glue and dialogue. "
    "End with a 'Sources' section listing the citation numbers you used (e.g., [1], [2])."
)

def safety_ok(text: str) -> bool:
    t = text.lower()
    return not any(b in t for b in SAFETY_BAN)

def craft_story(bullets: str, style: str, audience: str, length_hint: str) -> str:
    user = (
        f"Audience: {audience}\nStyle & Tone: {style}\nLength Hint: {length_hint}\n\n"
        f"FACT BULLETS:\n{bullets}\n\nWrite the story now. "
        "End with a 'Sources' section listing citation numbers."
    )
    prompt = chat_prompt(STORY_SYS, truncate_to_tokens(user, MAX_PROMPT_TOKENS))
    return safe_generate(prompt, max_new_tokens=STORY_GEN_TOKENS)

# ---------------- gTTS narration ----------------
def tts_mp3(text: str, path: str = "/content/story.mp3") -> str:
    clean = re.sub(r"\s+", " ", text).strip() or "There is no story to narrate."
    if len(clean) > 5000: clean = clean[:5000] + " ..."
    gTTS(clean).save(path)
    return path

# ---------------- Source selection helpers ----------------
def wikipedia_search_sources(query: str, limit: int) -> List[Dict]:
    return wikipedia_search(query, limit=limit)

def ddg_or_wiki(query: str, limit: int) -> List[Dict]:
    return search_ddg(query, n=limit) or wikipedia_search(query, limit=limit)

# ---------------- Main pipeline ----------------
def run_pipeline(topic: str, style: str, audience: str, length_hint: str,
                 n_results: int, narrate: bool, wiki_only: bool):
    try:
        topic = (topic or "").strip()
        if not topic:
            return "Please enter a topic.", "", "", None

        # Sources
        results = (wikipedia_search_sources(topic, int(n_results))
                   if wiki_only else ddg_or_wiki(topic, int(n_results)))
        if not results:
            return ("❌ No sources found. Enable 'Wikipedia-only', reduce results to 2–3, or try a broader topic."), "", "", None

        # Fetch texts
        texts = [(i, fetch_clean_safe(r["url"])) for i, r in enumerate(results, start=1)]
        src_idx = make_source_index(results)

        # Distill (hierarchical, capped)
        facts = distill_facts(topic, texts) or "⚠️ Could not distill facts."

        # Story
        story = craft_story(facts, style=style, audience=audience, length_hint=length_hint)
        if not safety_ok(story):
            story = ("⚠️ Safety filter: The topic/content may be unsuitable. "
                     "Try a different topic or a kid-friendly angle.")

        story_full = f"{story}\n\nSources (URLs):\n{urls_block(results)}"

        # Narration
        audio_path = None
        if narrate and story and "⚠️" not in story and "❌" not in story:
            try:
                main_text = story.split("\nSources")[0]
                audio_path = tts_mp3(main_text, "/content/story.mp3")
            except Exception as e:
                print("TTS error:", e)
                audio_path = None

        return story_full, src_idx, facts, audio_path

    except Exception as e:
        print("=== PIPELINE ERROR ===")
        print("".join(traceback.format_exc()))
        return f"❌ Runtime error: {e}", "", "", None

# ---------------- UI ----------------
try:
    demo.close()
except Exception:
    pass

with gr.Blocks(title="🌐📚 Web-Grounded Story Teller (OOM-safe, no BnB)") as demo:
    gr.Markdown("# 🌐📚 Web-Grounded Story Teller\nGrounded in web facts • FP16 • OOM-safe.")
    with gr.Row():
        with gr.Column():
            topic = gr.Textbox(label="Your topic / request",
                               value="A cozy story about the first Moon landing for kids")
            audience = gr.Textbox(label="Audience", value="Kids (7–10)")
            style = gr.Textbox(label="Style & Tone", value="Cozy, positive, simple language")
            length_hint = gr.Textbox(label="Length hint", value="~350–500 words")
            n_results = gr.Slider(1, 6, value=3, step=1, label="Results to use")
            wiki_only = gr.Checkbox(value=True, label="Use Wikipedia-only (avoid rate limits)")
            narrate = gr.Checkbox(value=True, label="🔊 Narrate the story (MP3 via gTTS)")
            go = gr.Button("✨ Tell me a story")
        with gr.Column():
            story_out = gr.Markdown(label="Story with citations")
            gr.Markdown("### Sources (index)")
            src_out = gr.Markdown()
            gr.Markdown("### Distilled Fact Bullets")
            facts_out = gr.Markdown()
            audio_out = gr.Audio(label="Narration (MP3)", interactive=False)

    go.click(run_pipeline,
             inputs=[topic, style, audience, length_hint, n_results, narrate, wiki_only],
             outputs=[story_out, src_out, facts_out, audio_out])

demo.queue().launch(share=True, debug=False)


Hit:1 http://archive.ubuntu.com/ubuntu jammy InRelease
Get:2 http://archive.ubuntu.com/ubuntu jammy-updates InRelease [128 kB]
Get:3 http://security.ubuntu.com/ubuntu jammy-security InRelease [129 kB]
Get:4 https://cloud.r-project.org/bin/linux/ubuntu jammy-cran40/ InRelease [3,632 B]
Get:5 http://archive.ubuntu.com/ubuntu jammy-backports InRelease [127 kB]
Get:6 https://cli.github.com/packages stable InRelease [3,917 B]
Get:7 https://developer.download.nvidia.com/compute/cuda/repos/ubuntu2204/x86_64  InRelease [1,581 B]
Hit:8 https://ppa.launchpadcontent.net/deadsnakes/ppa/ubuntu jammy InRelease
Get:9 https://ppa.launchpadcontent.net/graphics-drivers/ppa/ubuntu jammy InRelease [24.3 kB]
Hit:10 https://ppa.launchpadcontent.net/ubuntugis/ppa/ubuntu jammy InRelease
Get:11 https://r2u.stat.illinois.edu/ubuntu jammy InRelease [6,555 B]
Get:12 http://security.ubuntu.com/ubuntu jammy-security/main amd64 Packages [3,310 kB]
Get:13 http://security.ubuntu.com/ubuntu jammy-security/universe amd6

The secret `HF_TOKEN` does not exist in your Colab secrets.
To authenticate with the Hugging Face Hub, create a token in your settings tab (https://huggingface.co/settings/tokens), set it as secret in your Google Colab and restart your session.
You will be able to reuse this secret in all of your notebooks.
Please note that authentication is recommended but still optional to access public models or datasets.


tokenizer_config.json: 0.00B [00:00, ?B/s]

tokenizer.model:   0%|          | 0.00/500k [00:00<?, ?B/s]

tokenizer.json: 0.00B [00:00, ?B/s]

special_tokens_map.json:   0%|          | 0.00/551 [00:00<?, ?B/s]

config.json:   0%|          | 0.00/608 [00:00<?, ?B/s]

`torch_dtype` is deprecated! Use `dtype` instead!


model.safetensors:   0%|          | 0.00/2.20G [00:00<?, ?B/s]

generation_config.json:   0%|          | 0.00/124 [00:00<?, ?B/s]

Colab notebook detected. To show errors in colab notebook, set debug=True in launch()
* Running on public URL: https://52522a7c1165bbe882.gradio.live

This share link expires in 1 week. For free permanent hosting and GPU upgrades, run `gradio deploy` from the terminal in the working directory to deploy to Hugging Face Spaces (https://huggingface.co/spaces)


