In [2]:
# Use the %pip magic so the install targets the running kernel's environment,
# and pin the version that this notebook was last run against.
%pip install networkx==3.5

Collecting networkx
  Obtaining dependency information for networkx from https://files.pythonhosted.org/packages/eb/8d/776adee7bbf76365fdd7f2552710282c79a4ead5d2a46408c9043a2b70ba/networkx-3.5-py3-none-any.whl.metadata
  Downloading networkx-3.5-py3-none-any.whl.metadata (6.3 kB)
Downloading networkx-3.5-py3-none-any.whl (2.0 MB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m2.0/2.0 MB[0m [31m7.4 MB/s[0m eta [36m0:00:00[0m00:01[0m00:01[0m
[?25hInstalling collected packages: networkx
Successfully installed networkx-3.5

[1m[[0m[34;49mnotice[0m[1;39;49m][0m[39;49m A new release of pip is available: [0m[31;49m23.2.1[0m[39;49m -> [0m[32;49m25.3[0m
[1m[[0m[34;49mnotice[0m[1;39;49m][0m[39;49m To update, run: [0m[32;49mpip install --upgrade pip[0m


Nodes: agents (personas).

Edges: possible social interactions (who sees whose message).

Build a homophilic network:
Start with a random graph (e.g., Erdős–Rényi or fixed degree).

Mean Belief (mean_belief): The average belief value across all agents. This shows the central tendency of opinions.

Standard Deviation (std_belief): The spread of beliefs. A higher value means more disagreement or diversity in opinions.

Fraction Extreme (frac_extreme): The proportion of agents whose beliefs are very strong (absolute value > 0.8). This measures how many agents hold extreme views.

Cluster Gap (cluster_gap): If there are at least 4 agents, the code uses KMeans clustering (with 2 clusters) on the beliefs. The absolute difference between the two cluster centers is computed. A large gap suggests the population is split into two distinct groups (i.e., polarization).

1. Mean Belief (mean_belief):

Range: -1 to 1
Interpretation:
-1 = all agents strongly disagree (e.g., with climate change).
0 = agents are, on average, neutral or evenly split.
1 = all agents strongly agree.


2. Standard Deviation (std_belief):

Range: 0 to 1 (1 is the theoretical maximum; observed values are usually lower)
Interpretation:
0 = all agents have the same belief (no diversity).
Higher values = more disagreement/diversity in beliefs.
Max is 1 if half are at -1 and half at 1.


3. Fraction Extreme (frac_extreme):

Range: 0 to 1
Interpretation:
0 = no agents have extreme beliefs (|belief| ≤ 0.8).
1 = all agents have extreme beliefs (|belief| > 0.8).
Values near 1 mean most agents are at the extremes.



4. Cluster Gap (cluster_gap):

Range: 0 to 2
Interpretation:
0 = no separation between two main groups (everyone similar).
2 = two groups at -1 and 1 (maximal polarization).
Higher values mean two distinct camps; low values mean beliefs are mixed or unimodal.
Summary Table:

Metric	Range	What High Value Means	What Low Value Means
mean_belief	-1 to 1	Near +1: consensus in agreement (near -1: consensus in disagreement)	Near 0: neutral or evenly split
std_belief	0 to 1	High diversity/disagreement	Uniformity
frac_extreme	0 to 1	Most agents are extreme	Most are moderate/neutral
cluster_gap	0 to 2	Two opposing camps (polarization)	No clear split


In [None]:
import argparse
import json
import re
from pathlib import Path
from typing import Dict, Any, List, Optional

import random
import numpy as np
import pandas as pd
import networkx as nx
from tqdm import tqdm
from sklearn.cluster import KMeans

import ollama
from prompts import *

def build_persona_description(row: pd.Series) -> str:
    """
    Render a persona row as a bulleted attribute list, one attribute per line.

    Every line has the form "- <Field>: <value>". Values are looked up with
    `row.get`, so a field that is absent from `row` is rendered as "None".
    The result has no trailing newline.
    """
    persona_fields = (
        "PersonaID",
        "AgeGroup",
        "Gender",
        "EducationLevel",
        "OccupationSector",
        "Region",
        "PoliticalIdeology",
        "Trust_ScienceInstitutions",
        "Belief_ClimateExists",
        "Belief_HumanContribution",
        "Emotional_WorryAboutClimate",
        "BehaviouralOrientation",
        "SocialConnectivity",
    )
    return "\n".join(f"- {field}: {row.get(field)}" for field in persona_fields)


def stance_to_numeric(stance: str) -> float:
    """
    Convert a five-point Likert stance label into a belief score in [-1, 1].

    Matching ignores case and surrounding whitespace. Non-string input and
    labels outside the five known ones both map to 0.0 (neutral).
    """
    likert_scale = {
        "strongly disagree": -1.0,
        "slightly disagree": -0.5,
        "neutral": 0.0,
        "slightly agree": 0.5,
        "strongly agree": 1.0,
    }
    if isinstance(stance, str):
        return likert_scale.get(stance.strip().lower(), 0.0)
    return 0.0


def initial_belief_from_persona(row: pd.Series) -> float:
    """
    Derive an agent's starting belief in [-1, 1] from `Belief_ClimateExists`.

    - A string value is treated as a Likert label and mapped through
      `stance_to_numeric`.
    - A numeric value is clipped to [-1, 1].
    - A missing value (absent key, None, or NaN such as an empty CSV cell)
      or anything that cannot be converted to float yields 0.0.

    Args:
        row: Persona record supporting `.get(key)`.

    Returns:
        The initial belief as a finite float in [-1, 1].
    """
    val = row.get("Belief_ClimateExists")
    if isinstance(val, str):
        return stance_to_numeric(val)
    try:
        v = float(val)
    except (TypeError, ValueError, OverflowError):
        return 0.0
    # NaN survives float() and np.clip(); without this guard it would become
    # the agent's belief and contaminate any statistic computed from it.
    if np.isnan(v):
        return 0.0
    return float(np.clip(v, -1.0, 1.0))


def load_claims_with_label_JSON(path: str) -> List[Dict[str, Any]]:
    """
    Load labeled claims from a JSON file containing a list of objects.

    Each item must carry its text under the "claim" key. The id is taken from
    "claim_id" (falling back to "id"), and the label from the first non-empty
    of "label", "claim_label", "verdict", "stance".

    Args:
        path: Path to the JSON file (read as UTF-8).

    Returns:
        A list of dicts with keys "claim_id", "claim_text" and
        "claim_stance_label".

    Raises:
        ValueError: If an item has no (or an empty) "claim" value.
    """
    # Read explicitly as UTF-8 rather than relying on the platform default.
    raw = json.loads(Path(path).read_text(encoding="utf-8"))
    out = []
    for it in raw:
        claim_text = it.get("claim")
        if not claim_text:
            raise ValueError(f"Claim text not found in item: {it}")

        # Compare against None instead of using `or`, so a legitimate id of 0
        # is not discarded.
        claim_id = it.get("claim_id")
        if claim_id is None:
            claim_id = it.get("id")

        out.append({
            "claim_id": claim_id,
            "claim_text": claim_text,
            "claim_stance_label": it.get("label")
            or it.get("claim_label")
            or it.get("verdict")
            or it.get("stance"),
        })
    return out


def load_claims_with_label_CSV(path: str) -> List[Dict[str, Any]]:
    """
    Load labeled claims from a CSV file.

    The claim text is taken from the first populated column among "claim",
    "claim_text", "statement" and "text"; rows with no text are skipped.
    The id comes from "claim_id" (falling back to "id") and the label from
    "stance_label". Empty cells are reported as None.

    Args:
        path: Path to the CSV file.

    Returns:
        A list of dicts with keys "claim_id", "claim_text" and
        "claim_stance_label". If the file cannot be read, the error is printed
        and an empty list is returned.
    """

    def _first_present(record: Dict[str, Any], keys, allow_falsy: bool = False):
        """Return the first value under `keys` that is not None/NaN (and truthy unless allow_falsy)."""
        for key in keys:
            value = record.get(key)
            # pandas represents empty cells as NaN, which is truthy; it must be
            # rejected explicitly or it would be taken as a real value.
            if value is None or pd.isna(value):
                continue
            if allow_falsy or value:
                return value
        return None

    try:
        df = pd.read_csv(Path(path))
    except Exception as e:
        print(f"Error reading CSV file at {path}: {e}")
        return []

    out = []
    for it in df.to_dict("records"):
        claim_text = _first_present(it, ("claim", "claim_text", "statement", "text"))
        if claim_text is None:
            continue

        label = it.get("stance_label")
        out.append({
            # allow_falsy keeps a legitimate id of 0.
            "claim_id": _first_present(it, ("claim_id", "id"), allow_falsy=True),
            "claim_text": claim_text,
            "claim_stance_label": None if label is None or pd.isna(label) else label,
        })
    return out


def filter_balanced_claims(claims: List[Dict[str, Any]], n_each: int = 100) -> List[Dict[str, Any]]:
    """
    Draw an equal number of SUPPORTS and REFUTES claims and shuffle them.

    Labels are upper-cased (after conversion to str) before matching, and the
    returned records carry the upper-cased label. The per-label count is
    capped at the size of the smaller class. Sampling and shuffling use a
    fixed random_state of 42.

    Raises:
        ValueError: If the claims have no 'claim_stance_label' field.
    """
    frame = pd.DataFrame(claims)
    if "claim_stance_label" not in frame.columns:
        raise ValueError("Expected 'claim_stance_label' field in claims.")

    frame["claim_stance_label"] = frame["claim_stance_label"].astype(str).str.upper()
    labels = frame["claim_stance_label"]

    pro = frame[labels == "SUPPORTS"]
    con = frame[labels == "REFUTES"]

    # Never ask for more rows than the smaller class can provide.
    per_label = min(n_each, len(pro), len(con))
    picked = [group.sample(n=per_label, random_state=42) for group in (pro, con)]

    shuffled = pd.concat(picked).sample(frac=1, random_state=42)
    return shuffled.to_dict(orient="records")


def coerce_json(text: str) -> Dict[str, Any]:
    """
    Leniently parse a model reply into a dict.

    Code fences are stripped and the span from the first "{" to the last "}"
    is parsed as JSON. If that fails, or the parsed value is not a JSON
    object, the two expected fields are recovered with regular expressions,
    defaulting to "Neutral" / "Not Support" when they cannot be found.

    Args:
        text: Raw text returned by the model.

    Returns:
        Always a dict. On the fallback path it contains exactly the keys
        "climateChangeStance" and "claimStance".
    """
    cleaned = text.strip().strip("`").replace("```json", "").replace("```", "")
    start, end = cleaned.find("{"), cleaned.rfind("}")
    # Require the closing brace to come after the opening one; otherwise the
    # slice would be empty and discard text the regex fallback could still use.
    if start != -1 and end > start:
        cleaned = cleaned[start:end + 1]

    try:
        parsed = json.loads(cleaned)
    except (ValueError, RecursionError):
        parsed = None

    # Valid JSON that is not an object (e.g. `42`, `"ok"`, `[]`) must not be
    # returned: callers use `.get` on the result.
    if isinstance(parsed, dict):
        return parsed

    stance_match = re.search(
        r'"climateChangeStance"\s*:\s*"([^"]+)"', cleaned)
    claim_match = re.search(r'"claimStance"\s*:\s*"([^"]+)"', cleaned)
    return {
        "climateChangeStance": stance_match.group(1) if stance_match else "Neutral",
        "claimStance": claim_match.group(1) if claim_match else "Not Support",
    }

def chat_once(model: str, temperature: float, system_msg: str, user_msg: str) -> str:
    """
    Send one system + user message pair to `ollama.chat` and return the reply.

    The reply is the "content" of the response's "message", with surrounding
    whitespace stripped.
    """
    conversation = [
        {"role": "system", "content": system_msg},
        {"role": "user", "content": user_msg},
    ]
    response = ollama.chat(
        model=model,
        messages=conversation,
        options={"temperature": temperature},
    )
    reply_text = response["message"]["content"]
    return reply_text.strip()


def chat_seq(model: str, temperature: float, messages: List[Dict[str, str]]) -> str:
    """
    Send a caller-built message sequence to `ollama.chat` and return the reply.

    The reply is the "content" of the response's "message", with surrounding
    whitespace stripped.
    """
    sampling_options = {"temperature": temperature}
    response = ollama.chat(model=model, messages=messages, options=sampling_options)
    reply_text = response["message"]["content"]
    return reply_text.strip()



class Agent:
    """
    A single LLM agent backed by one persona row.

    `current_belief` starts from the persona's own belief field; `history`
    collects one record per simulated time step.
    """

    def __init__(self, idx: int, row: pd.Series):
        persona_id = row.get("PersonaID", f"persona_{idx}")

        self.idx = idx
        self.persona_id = persona_id
        self.name = f"Persona_{persona_id}"
        self.persona_desc = build_persona_description(row)
        self.current_belief = initial_belief_from_persona(row)
        # Step records, oldest first; the last element is the latest response.
        self.history: List[Dict[str, Any]] = []

    def _latest(self, key: str) -> Optional[Any]:
        """Return `key` from the most recent history record, or None if there is no history."""
        return self.history[-1].get(key) if self.history else None

    def last_stance_text(self) -> Optional[str]:
        """Raw model reply from the most recent step, or None if there is no history."""
        return self._latest("llm_response_raw")

    def last_stance_struct(self) -> Optional[Dict[str, Any]]:
        """Parsed model reply from the most recent step, or None if there is no history."""
        return self._latest("llm_response")


def build_fully_connected_graph(n_agents: int) -> nx.Graph:
    """
    Complete graph on `n_agents` nodes: each agent is adjacent to every other.

    The topology does not depend on agents' beliefs.
    """
    return nx.complete_graph(n_agents)


def build_random_graph(n_agents: int, avg_degree: int = 4, seed: int = 42) -> nx.Graph:
    """
    Erdős–Rényi random graph, no bias from beliefs.

    The edge probability is chosen so the expected degree is `avg_degree`.
    If the sampled graph is disconnected, one extra edge is added between each
    pair of consecutive components so the result is connected (this slightly
    raises the degree of the nodes used as bridges).

    Args:
        n_agents: Number of nodes. Zero yields an empty graph.
        avg_degree: Target expected degree per node.
        seed: Seed passed to the networkx generator.

    Returns:
        The generated (and, when non-empty, connected) undirected graph.
    """
    p = avg_degree / max(n_agents - 1, 1)
    G = nx.erdos_renyi_graph(n_agents, p, seed=seed)

    # nx.is_connected raises NetworkXPointlessConcept on a graph with no
    # nodes, so only check connectivity when there is something to connect.
    if G.number_of_nodes() > 0 and not nx.is_connected(G):
        components = list(nx.connected_components(G))
        # Chain the components together with one bridging edge each.
        for c1, c2 in zip(components[:-1], components[1:]):
            G.add_edge(next(iter(c1)), next(iter(c2)))
    return G


def build_small_world_graph(
    n_agents: int, k: int = 4, beta: float = 0.1, seed: int = 42
) -> nx.Graph:
    """
    Watts–Strogatz small-world graph.
    Also neutral w.r.t. beliefs.

    Args:
        n_agents: Number of nodes.
        k: Each node is connected to its k nearest neighbors in a ring. An odd
            value is rounded up to the next even number, and the value is
            reduced when it does not fit the population size.
        beta: Rewiring probability.
        seed: Seed passed to the networkx generator.

    Returns:
        The generated undirected graph. For fewer than three agents there is
        no meaningful ring, so the complete graph on `n_agents` nodes is
        returned instead (empty graph, single node, or single edge).
    """
    # With 0 or 1 agents the adjusted k (minimum 2) would exceed the node
    # count and watts_strogatz_graph raises; with 2 agents it already returns
    # the complete graph. Handle all three uniformly here.
    if n_agents < 3:
        return nx.complete_graph(max(n_agents, 0))

    if k % 2 == 1:
        k += 1  # enforce even
    if k >= n_agents:
        # Largest even k that is strictly below the node count, but at least 2.
        k = max(2, n_agents - 1)
        if k % 2 == 1:
            k -= 1
    G = nx.watts_strogatz_graph(n_agents, k, beta, seed=seed)
    return G



def summarize_neighbors(
    agent_idx: int,
    agents: List[Agent],
    G: nx.Graph,
    max_neighbors: int = 3,
) -> str:
    """
    Describe, in plain text, what some of an agent's neighbors last answered.

    Only neighbors that already have history are considered. They are shuffled
    with the global `random` module and the first `max_neighbors` are reported,
    one bullet line each, using the stances from their latest record. When
    there is nothing to report, a fixed "No one in your social circle ..."
    sentence is returned instead.
    """
    nobody_spoke = "No one in your social circle has expressed an opinion yet."

    neighbor_ids = list(G.neighbors(agent_idx))
    if len(neighbor_ids) == 0:
        return nobody_spoke

    # Neighbors with an empty history have nothing to report yet.
    spoken = [nid for nid in neighbor_ids if agents[nid].history]
    if len(spoken) == 0:
        return nobody_spoke

    random.shuffle(spoken)

    summary_lines = []
    for nid in spoken[:max_neighbors]:
        neighbor = agents[nid]
        response = neighbor.history[-1].get("llm_response") or {}
        climate_stance = response.get("climateChangeStance", "Neutral")
        # NOTE(review): this default is 'Support', whereas coerce_json falls
        # back to 'Not Support' — confirm which is intended.
        claim_stance = response.get("claimStance", "Support")
        summary_lines.append(
            f"- {neighbor.name} evaluated a similar climate-related claim and had climateChangeStance='{climate_stance}' "
            f"and claimStance='{claim_stance}'."
        )

    if not summary_lines:
        return nobody_spoke
    return "\n".join(summary_lines)

def simulate_polarization(
    model: str,
    personas_df: pd.DataFrame,
    claims: List[Dict[str, Any]],
    graph_type: str = "fully_connected",
    steps: int = 10,
    temperature: float = 0.2,
    avg_degree: int = 4,
    small_world_k: int = 4,
    small_world_beta: float = 0.1,
    seed: int = 42,
) -> pd.DataFrame:
    """
    Run multi-agent simulation with 3 graph options.
    Returns a log DataFrame (one row per agent per time step).

    At each time step one claim is shown to every agent in turn. Each agent is
    prompted with its persona, the claim, and a summary of what some of its
    graph neighbors last answered; the stance in its reply becomes its new
    belief.

    Args:
        model: Model name passed to the chat call.
        personas_df: One row per persona; each row becomes one agent.
        claims: Claim dicts; "claim_id" and "claim_text" are read from each.
        graph_type: "fully_connected", "random" or "small_world".
        steps: Number of time steps to simulate.
        temperature: Sampling temperature passed to the chat call.
        avg_degree: Target degree, used only for the "random" graph.
        small_world_k: Ring neighbors, used only for the "small_world" graph.
        small_world_beta: Rewiring probability, used only for "small_world".
        seed: Seeds `random`, `np.random`, and the graph generators.

    Returns:
        DataFrame of step records (see `step_record` below for the columns).

    Raises:
        ValueError: If `graph_type` is not recognized or `claims` is empty.
    """
    # Seed both global RNGs; summarize_neighbors shuffles with `random`.
    random.seed(seed)
    np.random.seed(seed)

    # Initialize agents
    # Agent idx is the positional order of the row, not the DataFrame index,
    # so it lines up with the integer node labels of the graph built below.
    agents: List[Agent] = [
        Agent(idx=i, row=row) for i, (_, row) in enumerate(personas_df.iterrows())
    ]
    n_agents = len(agents)

    # Build graph
    if graph_type == "fully_connected":
        G = build_fully_connected_graph(n_agents)
    elif graph_type == "random":
        G = build_random_graph(n_agents, avg_degree=avg_degree, seed=seed)
    elif graph_type == "small_world":
        G = build_small_world_graph(
            n_agents, k=small_world_k, beta=small_world_beta, seed=seed)
    else:
        raise ValueError(f"Unknown graph_type '{graph_type}'. Use one of: fully_connected, random, small_world")

    print(f"Graph '{graph_type}' -> nodes={G.number_of_nodes()}, edges={G.number_of_edges()}")

    # Flat list of every step record across all agents and time steps.
    logs: List[Dict[str, Any]] = []

    # Cycle through claims
    if not claims:
        raise ValueError("No claims loaded.")

    for t in range(steps):
        # One claim per time step, shared by all agents; wraps around when
        # there are more steps than claims.
        claim = claims[t % len(claims)]
        claim_id = claim.get("claim_id")
        claim_text = str(claim.get("claim_text"))

        # Agents are updated sequentially: each record is appended to the
        # agent's history right away, so agents later in this loop can see
        # answers that earlier agents gave during the same time step.
        for agent in tqdm(agents, desc=f"Time step {t}", leave=False):
            # Build system prompt for this persona
            system_msg = GROUP_SYSTEM_TMPL.replace(
                "{PERSONA_DESCRIPTION}", agent.persona_desc
            )

            # Build neighbor summary
            neighbor_summary = summarize_neighbors(
                agent.idx, agents, G, max_neighbors=3
            )
            # The "nothing to report" case is detected by matching the fixed
            # sentence that summarize_neighbors returns for it.
            if "No one in your social circle" in neighbor_summary:
                user_prompt = USER_TMPL_NO_NEIGHBORS.replace(
                    "{CLAIM_TEXT}", claim_text
                )
            else:
                user_prompt = USER_TMPL_WITH_NEIGHBORS.replace(
                    "{CLAIM_TEXT}", claim_text
                ).replace("{NEIGHBOR_SUMMARY}", neighbor_summary)

            # Call model
            raw = chat_once(
                model=model,
                temperature=temperature,
                system_msg=system_msg,
                user_msg=user_prompt
                + "\n\nONLY return the JSON above — no words or explanation.",
            )

            # The belief is replaced outright by the numeric value of the
            # stance in this reply; the previous belief is not blended in.
            parsed = coerce_json(raw)
            new_belief = stance_to_numeric(parsed.get("climateChangeStance", "Neutral"))
            agent.current_belief = new_belief

            # Append to agent history
            # The same dict object is stored in both the agent's history and
            # the global log.
            step_record = {
                "time": t,
                "persona_id": agent.persona_id,
                "agent_idx": agent.idx,
                "name": agent.name,
                "claim_id": claim_id,
                "claim_text": claim_text,
                "graph_type": graph_type,
                "climateChangeStance": parsed.get("climateChangeStance"),
                "claimStance": parsed.get("claimStance"),
                "belief_numeric": agent.current_belief,
                "llm_response": parsed,
                "llm_response_raw": raw,
                "neighbor_summary": neighbor_summary,
            }
            agent.history.append(step_record)
            logs.append(step_record)

    df_logs = pd.DataFrame(logs)
    return df_logs


# ============================================================
# 8. Polarization metrics
# ============================================================

def compute_polarization_metrics(df_logs: pd.DataFrame) -> pd.DataFrame:
    """
    Compute simple polarization metrics per time step from df_logs.

    Args:
        df_logs: Simulation log with at least the columns "time" and
            "belief_numeric".

    Returns:
        One row per time step with the columns:
        - time: the time step.
        - mean_belief: mean of the beliefs.
        - std_belief: population standard deviation of the beliefs.
        - frac_extreme: share of beliefs with absolute value above 0.8.
        - cluster_gap: distance between the two KMeans cluster centers;
          0.0 when all beliefs are identical, NaN with fewer than 4 agents.
        An empty frame with these columns is returned when `df_logs` has no
        rows (e.g. a simulation run with zero steps).
    """
    metric_columns = ["time", "mean_belief", "std_belief", "frac_extreme", "cluster_gap"]

    # A run with no steps or no agents produces a frame without a "time"
    # column; grouping on it would raise KeyError.
    if df_logs.empty or "time" not in df_logs.columns:
        return pd.DataFrame(columns=metric_columns)

    metrics = []
    for t, df_t in df_logs.groupby("time"):
        # KMeans expects a 2-D (n_samples, n_features) array.
        beliefs = df_t["belief_numeric"].to_numpy(dtype=float).reshape(-1, 1)
        mean = float(beliefs.mean())
        std = float(beliefs.std())
        frac_extreme = float(np.mean(np.abs(beliefs) > 0.8))

        # cluster gap (two-cluster separation)
        if len(beliefs) < 4:
            cluster_gap = float("nan")
        elif np.unique(beliefs).size < 2:
            # Full consensus: two clusters cannot be separated, so the gap is
            # zero by definition. Skipping KMeans avoids fitting k=2 on a
            # single distinct value.
            cluster_gap = 0.0
        else:
            km = KMeans(n_clusters=2, n_init=10, random_state=0)
            km.fit(beliefs)
            m1, m2 = km.cluster_centers_.flatten()
            cluster_gap = float(abs(m1 - m2))

        metrics.append({
            "time": t,
            "mean_belief": mean,
            "std_belief": std,
            "frac_extreme": frac_extreme,
            "cluster_gap": cluster_gap,
        })

    return pd.DataFrame(metrics, columns=metric_columns)


# ============================================================
# 9. CLI main
# ============================================================

def main(argv: Optional[List[str]] = None):
    """
    Command-line entry point: load inputs, run the simulation, save outputs.

    Args:
        argv: Argument list to parse, e.g.
            ["--model", "llama3", "--personas", "p.csv", "--claims", "c.csv"].
            When None (the default), argparse reads sys.argv, which is the
            normal behavior for a script. Passing an explicit list lets the
            function be called from a notebook or a test, where sys.argv
            belongs to the host process.

    Side effects:
        Writes "<out_prefix>_logs.jsonl" and "<out_prefix>_metrics.csv" into
        an "outputs" directory (created if missing) under the current working
        directory, and prints progress information.
    """
    parser = argparse.ArgumentParser(description="Multi-agent climate polarization simulation with LLM personas.")
    parser.add_argument("--model", required=True, help="Ollama model name")
    parser.add_argument("--personas", required=True, help="Path to personas CSV file")
    parser.add_argument("--claims", required=True, help="Path to claims CSV or JSON file")
    parser.add_argument(
        "--claims_format",
        choices=["csv", "json"],
        default="csv",
        help="Format of claims file",
    )
    parser.add_argument(
        "--graph_type",
        choices=["fully_connected", "random", "small_world"],
        default="fully_connected",
        help="Interaction graph type",
    )
    parser.add_argument("--steps", type=int, default=10, help="Number of time steps")
    parser.add_argument("--temperature", type=float, default=0.2, help="Sampling temperature")
    parser.add_argument("--avg_degree", type=int, default=4, help="Avg degree for random graph")
    parser.add_argument("--small_world_k", type=int, default=4, help="k for small-world")
    parser.add_argument("--small_world_beta", type=float, default=0.1, help="beta for small-world")
    parser.add_argument("--n_personas", type=int, default=None, help="Subset of personas")
    parser.add_argument("--n_claims", type=int, default=50, help="Number of claims to use")
    parser.add_argument("--balanced_claims", action="store_true",
                        help="If set, balance SUPPORTS / REFUTES (only for labeled datasets)")
    parser.add_argument("--seed", type=int, default=42, help="Random seed")
    parser.add_argument("--out_prefix", default="group_polarization", help="Output prefix")

    # argv=None makes argparse fall back to sys.argv[1:].
    args = parser.parse_args(argv)

    # Personas
    personas_df = pd.read_csv(args.personas)
    if args.n_personas is not None:
        # Cap at the number of available rows so sample() cannot fail.
        personas_df = personas_df.sample(
            n=min(args.n_personas, len(personas_df)),
            random_state=args.seed)

    # Claims
    if args.claims_format == "csv":
        claims = load_claims_with_label_CSV(args.claims)
    else:
        claims = load_claims_with_label_JSON(args.claims)

    if args.balanced_claims:
        # Half of the requested claims from each label.
        claims = filter_balanced_claims(claims, n_each=args.n_claims // 2)
    else:
        # Just sample without balancing
        if args.n_claims is not None and args.n_claims < len(claims):
            claims = (
                pd.DataFrame(claims)
                .sample(n=args.n_claims, random_state=args.seed)
                .to_dict(orient="records"))

    print(f"Loaded {len(personas_df)} personas and {len(claims)} claims.")

    df_logs = simulate_polarization(
        model=args.model,
        personas_df=personas_df,
        claims=claims,
        graph_type=args.graph_type,
        steps=args.steps,
        temperature=args.temperature,
        avg_degree=args.avg_degree,
        small_world_k=args.small_world_k,
        small_world_beta=args.small_world_beta,
        seed=args.seed,
    )

    df_metrics = compute_polarization_metrics(df_logs)

    out_dir = Path("outputs")
    out_dir.mkdir(parents=True, exist_ok=True)

    logs_path = out_dir / f"{args.out_prefix}_logs.jsonl"
    metrics_path = out_dir / f"{args.out_prefix}_metrics.csv"

    # Save logs as JSONL
    with logs_path.open("w", encoding="utf-8") as f:
        for _, row in df_logs.iterrows():
            f.write(json.dumps(row.to_dict(), ensure_ascii=False) + "\n")

    # Save metrics
    df_metrics.to_csv(metrics_path, index=False)

    print(f"Saved logs to {logs_path}")
    print(f"Saved metrics to {metrics_path}")


# Script entry point: run the CLI when this code is executed directly.
# NOTE(review): in a Jupyter kernel __name__ is presumably "__main__" as well,
# so running this cell would call main() and argparse would parse the kernel's
# own sys.argv — confirm this is intended before executing it in a notebook.
if __name__ == "__main__":
    main()
