<a href="https://colab.research.google.com/github/sadhvik02/Narrative-Building/blob/main/Task2.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
!pip install sentence-transformers scikit-learn numpy




In [6]:
%%writefile narrative_builder.py
import json
import os
from datetime import datetime
from typing import List, Dict, Any, Optional

import numpy as np
from sentence_transformers import SentenceTransformer
from sklearn.cluster import KMeans
from sklearn.metrics.pairwise import cosine_similarity


# --------------- Data Loading ----------------- #

def load_news(path: str) -> List[Dict[str, Any]]:
    """
    Load the news dataset from disk.

    Supports:
      - A JSON array of articles (``[{...}, {...}]``)
      - JSONL (one JSON value per line)

    The format is sniffed from the first non-whitespace character, and a
    leading UTF-8 byte-order mark is tolerated.

    Args:
        path: Filesystem path to the dataset.

    Returns:
        The parsed records. In JSONL mode, blank lines and lines that are not
        valid JSON are skipped.

    Raises:
        FileNotFoundError: if ``path`` does not exist.
        ValueError: if the JSON-array document does not decode to a list.
    """
    if not os.path.exists(path):
        raise FileNotFoundError(f"News dataset not found at: {path}")

    # "utf-8-sig" strips an optional BOM (and is identical to "utf-8" when
    # there is none); a BOM would otherwise break json parsing.
    with open(path, "r", encoding="utf-8-sig") as f:
        # Skip leading whitespace before sniffing. Otherwise a file starting
        # with a newline/indent before "[" would be parsed line by line and
        # most multi-line array entries would be silently dropped.
        head = f.read(1)
        while head and head.isspace():
            head = f.read(1)
        f.seek(0)

        if head == "[":
            data = json.load(f)
            if not isinstance(data, list):
                raise ValueError("Expected a list of articles in JSON file.")
            return data

        articles = []
        for line in f:
            line = line.strip()
            if not line:
                continue
            try:
                articles.append(json.loads(line))
            except json.JSONDecodeError:
                continue  # best-effort: skip malformed lines
        return articles


def filter_by_source_rating(
    articles: List[Dict[str, Any]], min_rating: float = 8.0
) -> List[Dict[str, Any]]:
    """
    Keep only articles whose ``source_rating`` is strictly greater than ``min_rating``.

    Entries that are not dicts are skipped. A missing, ``None`` or
    non-numeric rating is treated as 0 instead of raising, so one bad record
    cannot abort the whole run.

    Args:
        articles: Parsed records, typically from ``load_news``.
        min_rating: Exclusive lower bound on ``source_rating``.

    Returns:
        The matching articles, in their original order.
    """
    kept: List[Dict[str, Any]] = []
    for a in articles:
        if not isinstance(a, dict):
            continue
        try:
            rating = float(a.get("source_rating", 0))
        except (TypeError, ValueError):
            rating = 0.0
        if rating > min_rating:
            kept.append(a)
    return kept


# --------------- Helpers ----------------- #

def get_headline(article: Dict[str, Any]) -> str:
    """Return the first non-empty of ``headline``/``title``/``short_title``, else a placeholder."""
    for field in ("headline", "title", "short_title"):
        candidate = article.get(field)
        if candidate:
            return candidate
    return "Untitled article"


def get_url(article: Dict[str, Any]) -> str:
    """Return the article's ``url``, falling back to ``link``, or ``""`` when neither is set."""
    primary = article.get("url")
    if primary:
        return primary
    return article.get("link") or ""


def get_date(article: Dict[str, Any]) -> Optional[datetime]:
    """
    Parse the article's publication date into a *naive* ``datetime``.

    The keys ``date``, ``published_at``, ``published``, ``pub_date`` and
    ``time`` are tried in that order. For each, the formats ``YYYY-MM-DD``,
    ``YYYY-MM-DD HH:MM:SS`` and ``YYYY-MM-DDTHH:MM:SS`` are attempted, the
    ``T`` form optionally with fractional seconds and/or a UTC offset
    (``+02:00`` or ``Z``).

    Timezone information is dropped (the wall-clock time is kept, no UTC
    conversion), so every returned value is naive. Mixing aware and naive
    datetimes would make the sort in ``build_timeline`` and the ``<``
    comparison in ``build_graph`` raise ``TypeError``.

    Returns:
        The first successfully parsed datetime, or ``None`` if no key holds a
        parseable value.
    """
    formats = (
        "%Y-%m-%d",
        "%Y-%m-%dT%H:%M:%S%z",
        "%Y-%m-%dT%H:%M:%S",
        "%Y-%m-%d %H:%M:%S",
        # Fractional seconds are common in feed timestamps ("...T10:00:00.123Z").
        "%Y-%m-%dT%H:%M:%S.%f%z",
        "%Y-%m-%dT%H:%M:%S.%f",
    )
    for key in ("date", "published_at", "published", "pub_date", "time"):
        val = article.get(key)
        if not val:
            continue
        for fmt in formats:
            try:
                parsed = datetime.strptime(val, fmt)
            except (TypeError, ValueError):
                # TypeError: non-string value (e.g. a numeric timestamp);
                # ValueError: the string does not match this format.
                continue
            return parsed.replace(tzinfo=None)
    return None


def get_date_str(article: Dict[str, Any]) -> str:
    """
    Return the article date for display.

    When ``get_date`` can parse it, the date is formatted as ``YYYY-MM-DD``.
    Otherwise the raw value of the first truthy ``date`` / ``published_at`` /
    ``published`` field is returned unchanged, or ``""`` if none is set.
    """
    parsed = get_date(article)
    if parsed:
        return parsed.strftime("%Y-%m-%d")
    for key in ("date", "published_at", "published"):
        raw = article.get(key)
        if raw:
            return raw
    return ""


def build_article_text(article: Dict[str, Any]) -> str:
    """
    Concatenate headline, summary, description and content (or body) into
    one space-separated string used as the embedding input.

    Empty or missing fields are skipped.
    """
    headline = get_headline(article)
    summary = article.get("summary") or ""
    description = article.get("description") or ""
    body = article.get("content") or article.get("body") or ""
    return " ".join(filter(None, (headline, summary, description, body))).strip()


# --------------- Core Narrative Logic ----------------- #

def select_relevant_articles(
    articles: List[Dict[str, Any]],
    topic: str,
    model: SentenceTransformer,
    max_articles: int = 100,
    min_sim: float = 0.35,
):
    """
    Rank articles by semantic similarity to ``topic`` and keep the best ones.

    Every article with non-empty text is embedded with ``model`` using
    normalized embeddings. Articles whose cosine similarity to the topic is
    at least ``min_sim`` are kept. If none reach the threshold, the top
    ``max_articles`` by similarity are used instead. The result is ordered
    by descending similarity and capped at ``max_articles``.

    Returns:
        ``(selected_articles, selected_embeddings, selected_sims)``, or three
        empty containers when no article has any text.
    """
    corpus = []
    origin = []  # origin[k] = position in ``articles`` of corpus entry k
    for pos, art in enumerate(articles):
        txt = build_article_text(art)
        if txt:
            corpus.append(txt)
            origin.append(pos)

    if not origin:
        return [], np.array([]), np.array([])

    doc_vecs = model.encode(corpus, normalize_embeddings=True)
    query_vec = model.encode([topic], normalize_embeddings=True)
    scores = cosine_similarity(query_vec, doc_vecs)[0]

    chosen = [k for k, score in enumerate(scores) if score >= min_sim]
    if not chosen:
        # Nothing clears the threshold: fall back to the best-scoring articles.
        chosen = list(np.argsort(-scores)[:max_articles])

    chosen.sort(key=lambda k: scores[k], reverse=True)
    chosen = chosen[:max_articles]

    picked_articles = [articles[origin[k]] for k in chosen]
    return picked_articles, doc_vecs[chosen], scores[chosen]


def build_narrative_summary(
    topic: str,
    articles: List[Dict[str, Any]],
    sims: np.ndarray,
    max_sentences: int = 8,
) -> str:
    """
    Render a templated, one-sentence-per-article summary for ``topic``.

    Articles are paired with their topic similarity (``sims``, same order as
    ``articles``) and visited from most to least similar. For
    ``max_sentences >= 1``, at most that many sentences are produced. Each
    sentence states the article's date, its source and its headline.

    Returns:
        The sentences joined by single spaces, or a "no strong matches"
        message when ``articles`` is empty.
    """
    if not len(articles):
        return f"No strong matches were found in the news dataset for the topic '{topic}'."

    def render(article: Dict[str, Any]) -> str:
        when = get_date_str(article) or "an unknown date"
        title = get_headline(article)
        outlet = (
            article.get("source")
            or article.get("source_name")
            or "an unspecified source"
        )
        return f"On {when}, {outlet} reported: '{title}', which is relevant to {topic}."

    by_relevance = sorted(zip(articles, sims), key=lambda pair: pair[1], reverse=True)

    rendered: List[str] = []
    for article, _similarity in by_relevance[: max_sentences * 2]:
        rendered.append(render(article))
        if len(rendered) >= max_sentences:
            break
    return " ".join(rendered)


def build_timeline(articles: List[Dict[str, Any]]):
    """
    Return the articles as timeline entries in chronological order.

    Dated articles come first, oldest first. Undated articles follow in their
    original relative order, because the sort is stable. Each entry has
    ``date``, ``headline``, ``url`` and a templated ``why_it_matters``.
    """
    def chrono_key(article: Dict[str, Any]):
        when = get_date(article)
        return (when is None, when)

    entries = []
    for article in sorted(articles, key=chrono_key):
        title = get_headline(article)
        entries.append(
            {
                "date": get_date_str(article),
                "headline": title,
                "url": get_url(article),
                "why_it_matters": (
                    "This article contributes context or new developments related to the topic "
                    f"by focusing on '{title}'."
                ),
            }
        )
    return entries


def build_clusters(
    articles: List[Dict[str, Any]],
    embeddings: np.ndarray,
    num_clusters: int = 4,
):
    """
    Group the articles into thematic clusters with KMeans over their embeddings.

    The cluster count is ``min(num_clusters, max(1, n // 5))``. When that is at
    most 1 (or there are 3 or fewer articles), one "General theme" cluster
    holding every article is returned without running KMeans. An empty input
    yields ``[]``. Each KMeans cluster is labelled after the headline of its
    first member in input order.

    Returns:
        A list of dicts with ``cluster_id``, ``label``, ``article_indices`` and
        ``articles`` (``id``/``headline``/``url`` per member).
    """
    total = len(articles)
    if total == 0:
        return []

    def as_member(position: int) -> Dict[str, Any]:
        record = articles[position]
        return {
            "id": int(position),
            "headline": get_headline(record),
            "url": get_url(record),
        }

    n_groups = min(num_clusters, max(1, total // 5))
    if n_groups <= 1 or total <= 3:
        everyone = list(range(total))
        return [
            {
                "cluster_id": 0,
                "label": "General theme",
                "article_indices": everyone,
                "articles": [as_member(p) for p in everyone],
            }
        ]

    assignments = KMeans(n_clusters=n_groups, random_state=42, n_init="auto").fit_predict(
        embeddings
    )

    result = []
    for group in range(n_groups):
        members = [pos for pos, assigned in enumerate(assignments) if assigned == group]
        if members:
            label = f"Theme around: {get_headline(articles[members[0]])}"
        else:
            label = f"Cluster {group}"
        result.append(
            {
                "cluster_id": int(group),
                "label": label,
                "article_indices": members,
                "articles": [as_member(p) for p in members],
            }
        )
    return result


def build_graph(
    articles: List[Dict[str, Any]],
    embeddings: np.ndarray,
    sims: np.ndarray,
    sim_threshold: float = 0.6,
):
    """
    Build a narrative graph over the selected articles.

    Nodes are the articles (``id`` = position in ``articles``). An edge is
    added for every pair whose embedding cosine similarity is at least
    ``sim_threshold``:

    * ``builds_on``: both articles are dated, on different dates; the edge
      points from the earlier article to the later one.
    * ``adds_context``: at least one article is undated, or both share the
      same date, so no temporal direction can be inferred; the edge points
      from the lower to the higher index.

    ``sims`` (topic similarity) is currently unused; it is kept so existing
    callers keep working.

    Returns:
        ``{"nodes": [...], "edges": [...]}``.
    """
    nodes = [
        {
            "id": int(idx),
            "headline": get_headline(article),
            "url": get_url(article),
            "date": get_date_str(article),
        }
        for idx, article in enumerate(articles)
    ]

    n = len(articles)
    if n <= 1:
        return {"nodes": nodes, "edges": []}

    # Parse each date once instead of twice per pair inside the O(n^2) loop.
    dates = [get_date(a) for a in articles]
    pairwise = cosine_similarity(embeddings)

    edges = []
    for i in range(n):
        for j in range(i + 1, n):
            sim_ij = pairwise[i, j]
            if sim_ij < sim_threshold:
                continue

            d_i, d_j = dates[i], dates[j]
            if d_i and d_j and d_i != d_j:
                relation = "builds_on"
                src, tgt = (i, j) if d_i < d_j else (j, i)
            else:
                # Undated or same-day pair: no evidence that one builds on the other.
                relation = "adds_context"
                src, tgt = i, j

            edges.append(
                {
                    "source": int(src),
                    "target": int(tgt),
                    "relation": relation,
                    "similarity": float(sim_ij),
                }
            )

    return {"nodes": nodes, "edges": edges}


def run_narrative(
    topic: str,
    data_path: str,
    max_articles: int = 100,
    *,
    min_rating: float = 8.0,
    min_sim: float = 0.35,
    num_clusters: int = 4,
    model_name: str = "all-MiniLM-L6-v2",
):
    """
    Run the full narrative pipeline for ``topic`` over the dataset at ``data_path``.

    Steps:
      1. Load the dataset and keep articles with ``source_rating > min_rating``.
      2. Embed them with the SentenceTransformer ``model_name``.
      3. Select up to ``max_articles`` topic-relevant articles (similarity
         ``>= min_sim``, with a top-k fallback).
      4. Build the summary, timeline, clusters (at most ``num_clusters``) and
         similarity graph.

    The keyword-only arguments default to the values previously hard-coded
    here, so existing calls behave the same.

    Returns:
        A dict with keys ``narrative_summary``, ``timeline``, ``clusters`` and
        ``graph``. When nothing survives filtering or selection, the lists are
        empty and ``narrative_summary`` explains why.
    """
    def _empty_result(message: str) -> Dict[str, Any]:
        # Same shape as a successful result, so callers need no special-casing.
        return {
            "narrative_summary": message,
            "timeline": [],
            "clusters": [],
            "graph": {"nodes": [], "edges": []},
        }

    # 1. Load & filter
    articles = load_news(data_path)
    articles = filter_by_source_rating(articles, min_rating=min_rating)
    if not articles:
        # ":g" renders the default 8.0 as "8", matching the original message.
        return _empty_result(
            f"No articles with source_rating > {min_rating:g} were found for topic '{topic}'."
        )

    # 2. Model (only loaded once there is something to embed)
    model = SentenceTransformer(model_name)

    # 3. Select relevant
    selected_articles, embeddings, sims = select_relevant_articles(
        articles,
        topic=topic,
        model=model,
        max_articles=max_articles,
        min_sim=min_sim,
    )
    if not selected_articles:
        return _empty_result(f"No relevant articles could be matched for topic '{topic}'.")

    # 4. Build outputs
    return {
        "narrative_summary": build_narrative_summary(topic, selected_articles, sims),
        "timeline": build_timeline(selected_articles),
        "clusters": build_clusters(selected_articles, embeddings, num_clusters=num_clusters),
        "graph": build_graph(selected_articles, embeddings, sims),
    }


Writing narrative_builder.py


In [7]:
!python narrative_builder.py \
  --data_path "/content/14e9e4cc-9174-48da-ad02-abb1330b48fe.json" \
  --topic "AI regulation" \
  --min_source_rating 8.0 \
  --max_articles 80 \
  --n_clusters 4

2025-11-17 08:01:07.723134: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:467] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
E0000 00:00:1763366467.750533    2046 cuda_dnn.cc:8579] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
E0000 00:00:1763366467.758284    2046 cuda_blas.cc:1407] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered
W0000 00:00:1763366467.778108    2046 computation_placer.cc:177] computation placer already registered. Please check linkage and avoid linking the same target more than once.
W0000 00:00:1763366467.778158    2046 computation_placer.cc:177] computation placer already registered. Please check linkage and avoid linking the same target more than once.
W0000 00:00:1763366467.778162    2046 computation_placer.cc:177] computation placer alr

In [9]:
!pip install sentence-transformers scikit-learn python-dateutil




In [18]:
%%writefile narrative_builder.py
import os
os.environ["CUDA_VISIBLE_DEVICES"] = ""  # force CPU, avoid GPU issues / CUDA warnings

import argparse
import json
import os as _os
from datetime import datetime
from typing import List, Dict, Any, Optional

import numpy as np
from sentence_transformers import SentenceTransformer
from sklearn.cluster import KMeans
from sklearn.metrics.pairwise import cosine_similarity


# -------------------- DATA LOADING --------------------

def load_news(path: str) -> List[Dict[str, Any]]:
    """
    Flexible loader for the news dataset.

    Supports:
    - A top-level list of articles (non-dict entries are dropped)
    - A JSON object with article-like dicts nested anywhere inside it
    - JSONL (one JSON object per line; blank, malformed and non-object lines
      are skipped)

    A leading UTF-8 byte-order mark is tolerated.

    Raises:
        FileNotFoundError: if ``path`` does not exist.
    """
    if not _os.path.exists(path):
        raise FileNotFoundError(f"News dataset not found at: {path}")

    # "utf-8-sig" strips an optional BOM. With plain "utf-8", a BOM makes
    # json.load fail and the JSONL fallback then drops the BOM'd first line,
    # so a BOM-prefixed JSON document used to load as an empty list.
    with open(path, "r", encoding="utf-8-sig") as f:
        try:
            data = json.load(f)
        except json.JSONDecodeError:
            # Not a single JSON document (e.g. "Extra data"): treat as JSONL.
            f.seek(0)
            articles: List[Dict[str, Any]] = []
            for line in f:
                line = line.strip()
                if not line:
                    continue
                try:
                    obj = json.loads(line)
                except json.JSONDecodeError:
                    continue
                if isinstance(obj, dict):
                    articles.append(obj)
            return articles

    # Top-level list: keep only dict entries.
    if isinstance(data, list):
        return [x for x in data if isinstance(x, dict)]

    # Otherwise walk the structure and collect every dict that carries at
    # least one typical article field.
    article_keys = (
        "headline",
        "title",
        "summary",
        "description",
        "content",
        "body",
        "url",
        "link",
    )
    collected: List[Dict[str, Any]] = []

    def collect(obj: Any) -> None:
        if isinstance(obj, dict):
            if any(k in obj for k in article_keys):
                collected.append(obj)
            # NOTE(review): recursion continues into a matched article, so a
            # nested dict such as {"source": {"url": ...}} would also be
            # collected as an article. Confirm against the dataset's schema.
            for v in obj.values():
                collect(v)
        elif isinstance(obj, list):
            for v in obj:
                collect(v)

    collect(data)
    return collected


def filter_by_source_rating(articles: List[Dict[str, Any]], min_rating: float = 8.0) -> List[Dict[str, Any]]:
    """
    Return the dict entries of ``articles`` whose ``source_rating`` exceeds ``min_rating``.

    Non-dict entries are skipped. A missing, ``None`` or non-numeric rating
    counts as 0. The original order is preserved.
    """
    def rating_of(article: Dict[str, Any]) -> float:
        try:
            return float(article.get("source_rating", 0))
        except (TypeError, ValueError):
            return 0

    return [
        article
        for article in articles
        if isinstance(article, dict) and rating_of(article) > min_rating
    ]


# -------------------- HELPERS --------------------

def get_headline(a: Dict[str, Any]) -> str:
    """Pick the first non-empty of ``headline``/``title``/``short_title``; default ``"Untitled"``."""
    candidates = (a.get(key) for key in ("headline", "title", "short_title"))
    return next((text for text in candidates if text), "Untitled")


def get_url(a: Dict[str, Any]) -> str:
    """Return ``url``, else ``link``, else an empty string."""
    for key in ("url", "link"):
        value = a.get(key)
        if value:
            return value
    return ""


def get_date(a: Dict[str, Any]) -> Optional[datetime]:
    """
    Parse the article's publication date and return it as a *naive* datetime.

    Keys are tried in order ``date``, ``published_at``, ``published``,
    ``pub_date``, ``time``; the first value matching a supported format wins.
    Supported formats are ``YYYY-MM-DD``, ``YYYY-MM-DD HH:MM:SS`` and
    ``YYYY-MM-DDTHH:MM:SS``, the ``T`` form optionally with fractional
    seconds and/or a UTC offset (``+02:00`` or ``Z``).

    Timezone info is dropped (the wall-clock time is kept, no UTC
    conversion), so results can be compared and sorted without ``TypeError``.

    Returns ``None`` when no key holds a parseable value.
    """
    formats = (
        "%Y-%m-%d",
        "%Y-%m-%dT%H:%M:%S%z",
        "%Y-%m-%dT%H:%M:%S",
        "%Y-%m-%d %H:%M:%S",
        # Fractional seconds are common in feed timestamps ("...T10:00:00.123Z").
        "%Y-%m-%dT%H:%M:%S.%f%z",
        "%Y-%m-%dT%H:%M:%S.%f",
    )
    for k in ("date", "published_at", "published", "pub_date", "time"):
        v = a.get(k)
        if not v:
            continue
        for fmt in formats:
            try:
                dt = datetime.strptime(v, fmt)
            except (TypeError, ValueError):
                # TypeError: non-string value (e.g. a numeric timestamp);
                # ValueError: the string does not match this format.
                continue
            # Normalize to naive (drop timezone) to avoid aware/naive TypeError.
            return dt.replace(tzinfo=None)
    return None


def get_date_str(a: Dict[str, Any]) -> str:
    """
    Return a display date for an article.

    Prefers the parsed date formatted as ``YYYY-MM-DD``. Otherwise it falls
    back to the first truthy raw ``date``/``published_at``/``published``
    value, or ``""``.
    """
    parsed = get_date(a)
    if parsed is not None:
        return parsed.strftime("%Y-%m-%d")
    raw_values = (a.get("date"), a.get("published_at"), a.get("published"))
    return next((raw for raw in raw_values if raw), "")


def build_text(a: Dict[str, Any]) -> str:
    """Join headline, summary, description and content/body into one string for embedding."""
    segments = (
        get_headline(a),
        a.get("summary"),
        a.get("description"),
        a.get("content") or a.get("body"),
    )
    # Falsy segments (missing, None, "") are left out of the joined text.
    return " ".join(seg for seg in segments if seg).strip()


# -------------------- CORE LOGIC --------------------

def select_relevant(
    articles: List[Dict[str, Any]],
    topic: str,
    model: SentenceTransformer,
    max_articles: int = 100,
    min_sim: float = 0.35,
):
    """
    Embed the articles and keep the ones most similar to ``topic``.

    Only articles with non-empty text are embedded (normalized embeddings).
    Those with cosine similarity ``>= min_sim`` are kept. If none qualify,
    the ``max_articles`` highest-scoring articles are used instead. Output is
    sorted by descending similarity and capped at ``max_articles``.

    Returns:
        ``(selected_articles, selected_embeds, selected_sims)``, or three empty
        containers when no article has text.
    """
    pool = [(pos, build_text(art)) for pos, art in enumerate(articles)]
    pool = [(pos, text) for pos, text in pool if text]
    if not pool:
        return [], np.array([]), np.array([])

    positions = [pos for pos, _ in pool]
    embeds = model.encode([text for _, text in pool], normalize_embeddings=True)
    topic_vec = model.encode([topic], normalize_embeddings=True)
    sims = cosine_similarity(topic_vec, embeds)[0]

    picked = [idx for idx, score in enumerate(sims) if score >= min_sim]
    if not picked:
        # Threshold too strict for this topic: fall back to the top-k.
        picked = list(np.argsort(-sims)[:max_articles])
    picked = sorted(picked, key=lambda idx: sims[idx], reverse=True)[:max_articles]

    chosen_articles = [articles[positions[idx]] for idx in picked]
    return chosen_articles, embeds[picked], sims[picked]


def narrative_summary(
    topic: str,
    articles: List[Dict[str, Any]],
    sims: np.ndarray,
    max_len: int = 8,
) -> str:
    """
    Build a templated narrative summary, one sentence per article.

    Articles are ordered by their topic similarity (``sims``), highest first.
    For ``max_len >= 1``, at most ``max_len`` sentences are produced (fewer
    when there are fewer articles). Each sentence names the date, source and
    headline.
    """
    if not len(articles):
        return f"No strong matches were found for '{topic}'."

    def describe(article: Dict[str, Any]) -> str:
        date = get_date_str(article) or "unknown date"
        headline = get_headline(article)
        src = article.get("source") or article.get("source_name") or "a news source"
        return f"On {date}, {src} reported '{headline}', which is relevant to {topic}."

    ranked = sorted(zip(articles, sims), key=lambda item: item[1], reverse=True)

    sentences: List[str] = []
    for article, _ in ranked[: max_len * 2]:
        sentences.append(describe(article))
        if len(sentences) >= max_len:
            break
    return " ".join(sentences)


def build_timeline(articles: List[Dict[str, Any]]):
    """
    Chronological timeline entries with ``date``, ``headline``, ``url`` and
    ``why_it_matters``.

    Dated articles come first, oldest first. Undated ones follow in their
    original order, because the sort is stable.
    """
    def sort_key(article: Dict[str, Any]):
        parsed = get_date(article)
        return (parsed is None, parsed)

    why = "Provides context or developments relevant to the topic."
    return [
        {
            "date": get_date_str(a),
            "headline": get_headline(a),
            "url": get_url(a),
            "why_it_matters": why,
        }
        for a in sorted(articles, key=sort_key)
    ]


def build_clusters(
    articles: List[Dict[str, Any]],
    embeds: np.ndarray,
    k: int = 4,
):
    """
    Group semantically similar articles into clusters (themes) with KMeans.

    The number of clusters is ``min(k, max(1, n // 5))``, so clusters average
    at least ~5 articles. When that is 1 or fewer (fewer than 10 articles, or
    ``k <= 1``), a single "General theme" cluster holding every article is
    returned without running KMeans. An empty input yields ``[]``.

    Each KMeans cluster is labelled with the headline of its first member in
    input order.

    Returns:
        A list of dicts with ``cluster_id``, ``label`` and ``articles``
        (``id``/``headline``/``url`` per member, ``id`` = position in
        ``articles``).
    """
    n = len(articles)
    if n == 0:
        return []

    def members(indices) -> List[Dict[str, Any]]:
        return [
            {
                "id": int(i),
                "headline": get_headline(articles[i]),
                "url": get_url(articles[i]),
            }
            for i in indices
        ]

    n_clusters = min(k, max(1, n // 5))
    if n_clusters <= 1:
        # One cluster is all we can justify; skip KMeans (which would also
        # reject n_clusters <= 0).
        return [
            {
                "cluster_id": 0,
                "label": "General theme",
                "articles": members(range(n)),
            }
        ]

    km = KMeans(n_clusters=n_clusters, random_state=42, n_init="auto")
    labels = km.fit_predict(embeds)

    clusters = []
    for cid in range(n_clusters):
        idxs = [i for i, lbl in enumerate(labels) if lbl == cid]
        label = f"Theme: {get_headline(articles[idxs[0]])}" if idxs else f"Cluster {cid}"
        clusters.append(
            {
                "cluster_id": cid,
                "label": label,
                "articles": members(idxs),
            }
        )
    return clusters


def build_graph(
    articles: List[Dict[str, Any]],
    embeds: np.ndarray,
    sims: np.ndarray,
    threshold: float = 0.6,
):
    """
    Narrative graph over the articles.

    Nodes are the articles (``id`` = position). An undirected "related" edge
    (stored as lower index -> higher index) links every pair whose embedding
    cosine similarity is ``>= threshold``. Edges are listed in row-major pair
    order. ``sims`` is accepted for interface compatibility but not used.
    """
    nodes = [
        {
            "id": i,
            "headline": get_headline(a),
            "url": get_url(a),
            "date": get_date_str(a),
        }
        for i, a in enumerate(articles)
    ]
    count = len(articles)
    if count <= 1:
        return {"nodes": nodes, "edges": []}

    sim_matrix = cosine_similarity(embeds)
    # Upper-triangle (i < j) pairs, in the same row-major order as nested loops.
    rows, cols = np.triu_indices(count, k=1)
    pair_sims = sim_matrix[rows, cols]
    keep = pair_sims >= threshold

    edges = [
        {
            "source": int(i),
            "target": int(j),
            "relation": "related",
            "similarity": float(s),
        }
        for i, j, s in zip(rows[keep], cols[keep], pair_sims[keep])
    ]
    return {"nodes": nodes, "edges": edges}


# -------------------- CLI ENTRY --------------------

def main():
    """
    CLI entry point: build a topic narrative from a news dataset and print it as JSON.

    Pipeline:
      1. Load the dataset and keep articles with
         ``source_rating > --min_source_rating``.
      2. Select articles relevant to ``--topic``.
      3. Build the summary, timeline, clusters and graph.

    The result (or an explanatory empty result) is printed to stdout. Every
    tuning option defaults to the value previously hard-coded here.
    """
    parser = argparse.ArgumentParser(
        description="Build a topic narrative (summary, timeline, clusters, graph) from a news dataset."
    )
    parser.add_argument("--topic", required=True, help="Topic to build narrative for")
    parser.add_argument(
        "--data_path",
        required=True,
        help="Path to the 84MB news JSON/JSONL dataset",
    )
    parser.add_argument(
        "--max_articles",
        type=int,
        default=100,
        help="Maximum number of relevant articles to include",
    )
    parser.add_argument(
        "--min_source_rating",
        type=float,
        default=8.0,
        help="Keep only articles whose source_rating is strictly greater than this (default: 8.0)",
    )
    parser.add_argument(
        "--min_sim",
        type=float,
        default=0.35,
        help="Minimum article/topic cosine similarity for selection (default: 0.35)",
    )
    parser.add_argument(
        "--n_clusters",
        type=int,
        default=4,
        help="Upper bound on the number of theme clusters (default: 4)",
    )
    parser.add_argument(
        "--graph_threshold",
        type=float,
        default=0.6,
        help="Minimum article/article similarity for a graph edge (default: 0.6)",
    )
    args = parser.parse_args()

    def emit(result: Dict[str, Any]) -> None:
        print(json.dumps(result, indent=2, ensure_ascii=False))

    def empty_result(message: str) -> Dict[str, Any]:
        # Same shape as a real result so consumers need no special-casing.
        return {
            "narrative_summary": message,
            "timeline": [],
            "clusters": [],
            "graph": {"nodes": [], "edges": []},
        }

    # 1. Load and filter
    articles = load_news(args.data_path)
    articles = filter_by_source_rating(articles, min_rating=args.min_source_rating)
    if not articles:
        # ":g" renders the default 8.0 as "8", matching the original message.
        emit(empty_result(f"No articles with source_rating > {args.min_source_rating:g} found."))
        return

    # 2. Model
    model = SentenceTransformer("all-MiniLM-L6-v2")

    # 3. Select relevant
    selected, embeds, sims = select_relevant(
        articles,
        topic=args.topic,
        model=model,
        max_articles=args.max_articles,
        min_sim=args.min_sim,
    )
    if not selected:
        emit(empty_result(f"No relevant articles could be matched for topic '{args.topic}'."))
        return

    # 4. Build narrative structure and print final JSON to stdout
    emit(
        {
            "narrative_summary": narrative_summary(args.topic, selected, sims),
            "timeline": build_timeline(selected),
            "clusters": build_clusters(selected, embeds, k=args.n_clusters),
            "graph": build_graph(selected, embeds, sims, threshold=args.graph_threshold),
        }
    )


# Run the CLI only when this file is executed as a script
# (e.g. `python narrative_builder.py --topic ... --data_path ...`), not on import.
if __name__ == "__main__":
    main()


Overwriting narrative_builder.py


In [19]:
!python narrative_builder.py \
  --data_path "/content/14e9e4cc-9174-48da-ad02-abb1330b48fe.json" \
  --topic "AI regulation" \
  --max_articles 80

2025-11-17 08:20:38.137143: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:467] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
E0000 00:00:1763367638.179914    6843 cuda_dnn.cc:8579] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
E0000 00:00:1763367638.192925    6843 cuda_blas.cc:1407] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered
W0000 00:00:1763367638.223513    6843 computation_placer.cc:177] computation placer already registered. Please check linkage and avoid linking the same target more than once.
W0000 00:00:1763367638.223576    6843 computation_placer.cc:177] computation placer already registered. Please check linkage and avoid linking the same target more than once.
W0000 00:00:1763367638.223590    6843 computation_placer.cc:177] computation placer alr