<a href="https://colab.research.google.com/github/pinballsurgeon/NLP_Engine_Study/blob/main/gemini_game_of_life.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [46]:
!pip -q install --upgrade google-genai google-generativeai pillow imageio imageio-ffmpeg numpy matplotlib

In [47]:
import os, time, json, numpy as np
from pathlib import Path

# ---------------------------
# constants
# ---------------------------
# Grid and run
N = 8                    # grid size (finite edges)
STEPS = 50               # number of agentic turns
P_ALIVE = 0.35           # initial density
RNG_SEED = 2025
TEMPS = [0.0, 0.25, 0.50, 0.75, 1.00]

MODEL_NAME = "gemini-2.5-flash-lite"

SYSTEM_HDR = (
  "You control an NxN Game of Life grid with finite edges (no wrap). "
  "Apply Conway rule B3/S23 exactly: a live cell survives with 2 or 3 live neighbors; "
  "a dead cell becomes live if it has exactly 3 live neighbors; otherwise it is dead. "
  "Neighbors are the 8 surrounding cells within the grid; cells outside are ignored. "
  "Return STRICT JSON only with schema: {\"next_grid\": [[0 or 1,...],...], \"reason\": \"...\"}. "
  "The next_grid must have the same size as the input grid and contain only integers 0 or 1."
)

# Minimal few-shot (3x3 blinker)
FEWSHOT_IN  = [[0,1,0],[0,1,0],[0,1,0]]
FEWSHOT_OUT = [[0,0,0],[1,1,1],[0,0,0]]



try:
    from google.colab import userdata
    GEMINI_API_KEY = userdata.get('GEMINI_API_KEY') or os.environ.get('GEMINI_API_KEY', '')
except Exception:
    GEMINI_API_KEY = os.environ.get('GEMINI_API_KEY', '')

if not GEMINI_API_KEY:
    try:
        from getpass import getpass
        GEMINI_API_KEY = getpass("Enter GEMINI_API_KEY (input hidden): ").strip()
    except Exception:
        GEMINI_API_KEY = ""

if GEMINI_API_KEY:
    os.environ["GEMINI_API_KEY"] = GEMINI_API_KEY
    print("✅ GEMINI_API_KEY configured.")
else:
    print("⚠️ No GEMINI_API_KEY found yet. You can still run the classic track; set it before agent steps.")


def trk_name(t: float) -> str:
    return f"g_t{t:.2f}"

TRACKS = ["classic"] + [trk_name(t) for t in TEMPS]

# Color palette (hex) for plots/panels
COLOR_CLASSIC = "#60a5fa"   # blue
COLOR_T = {
    0.00: "#a78bfa",  # purple
    0.25: "#22c55e",  # green
    0.50: "#fb923c",  # orange
    0.75: "#ef4444",  # red
    1.00: "#ec4899",  # pink
}

STRICT_MODE = True       # if True: halt a Gemini track on invalid JSON
BOUNDARY = "finite"      # no wrap (easier for model reasoning)

VIDEO_W, VIDEO_H = 1088, 1920     # divisible by 16 (codec-friendly)
FPS = 4                           # 25 steps ~ 2.5s; we can extend later with loops
TITLE = "Gemini Game of Life"


# ---------------------------
# 4) Output paths + run snapshot
# ---------------------------
RUN_ID = f"agol-{int(time.time())}"
OUT_DIR = Path(f"./out_{RUN_ID}")
OUT_DIR.mkdir(parents=True, exist_ok=True)

RUN_CONFIG = {
    "model": MODEL_NAME,
    "grid": {"N": N, "boundary": BOUNDARY, "p_alive": P_ALIVE},
    "steps": STEPS,
    "temps": TEMPS,
    "tracks": TRACKS,
    "strict_mode": STRICT_MODE,
    "video": {"W": VIDEO_W, "H": VIDEO_H, "fps": FPS, "title": TITLE},
    "rng_seed": RNG_SEED,
}

with open(OUT_DIR / "run_config.json", "w") as f:
    json.dump(RUN_CONFIG, f, indent=2)

# ---------------------------
# 5) Friendly summary
# ---------------------------
print("🎯 Project: 3 Gemini temperatures vs Classic (100% agent next_grid)")
print(f"🧠 Model: {MODEL_NAME}   Key set: {bool(GEMINI_API_KEY)}")
print(f"🧩 Grid: N={N}, steps={STEPS}, boundary={BOUNDARY}, seed={RNG_SEED}")
print(f"🧪 Tracks: {', '.join(TRACKS)} (temps={TEMPS})")
print(f"🎥 Video: {VIDEO_W}×{VIDEO_H} @ {FPS} fps")
print("📁 Outputs →", OUT_DIR.resolve())


✅ GEMINI_API_KEY configured.
🎯 Project: 3 Gemini temperatures vs Classic (100% agent next_grid)
🧠 Model: gemini-2.5-flash-lite   Key set: True
🧩 Grid: N=8, steps=50, boundary=finite, seed=2025
🧪 Tracks: classic, g_t0.00, g_t0.25, g_t0.50, g_t0.75, g_t1.00 (temps=[0.0, 0.25, 0.5, 0.75, 1.0])
🎥 Video: 1088×1920 @ 4 fps
📁 Outputs → /content/out_agol-1755953812


In [48]:
# Cell 1 — Classic GoL (finite), deterministic seeding, and fast metrics

import numpy as np

# ---------- Classic B3/S23 on FINITE edges (no wrap) ----------
def life_neighbors_finite(grid, y, x):
    n = len(grid)
    s = 0
    for dy in (-1,0,1):
        for dx in (-1,0,1):
            if dy==0 and dx==0:
                continue
            yy, xx = y+dy, x+dx
            if 0 <= yy < n and 0 <= xx < n:
                s += 1 if grid[yy][xx] else 0
    return s

def life_step_finite(grid):
    n = len(grid)
    nxt = [[0]*n for _ in range(n)]
    for y in range(n):
        for x in range(n):
            c = life_neighbors_finite(grid, y, x)
            alive = 1 if grid[y][x] else 0
            if alive:
                nxt[y][x] = 1 if (c==2 or c==3) else 0
            else:
                nxt[y][x] = 1 if (c==3) else 0
    return nxt

# ---------- Deterministic seed (mini pack) ----------
def seed_mini_pack(n=8, rng_seed=RNG_SEED, p_alive=P_ALIVE):
    """
    Compose a reliable starter: a glider + a blinker fragment,
    then sprinkle a few random cells with fixed RNG seed.
    """
    assert n >= 8, "seed_mini_pack expects n>=8"
    g = [[0]*n for _ in range(n)]

    # glider at (1,1)
    for (yy,xx) in [(1,2),(2,3),(3,1),(3,2),(3,3)]:
        g[yy][xx] = 1
    # blinker at (5,2) horizontal
    for xx in (2,3,4):
        g[5][xx] = 1

    # light random sprinkle (stable by seed)
    rng = np.random.default_rng(rng_seed)
    mask = (rng.random((n,n)) < p_alive*0.15)
    for y in range(n):
        for x in range(n):
            g[y][x] = 1 if (g[y][x] or mask[y,x]) else 0
    return g

# ---------- Metrics ----------
def alive_density(grid):
    arr = np.array(grid, dtype=np.uint8)
    return float(arr.mean())

def perimeter(grid):
    """4-neighbor edge mismatches inside finite grid."""
    arr = np.array(grid, dtype=np.uint8)
    n = arr.shape[0]
    # horizontal mismatches
    Lh = (arr[:,1:] != arr[:,:-1]).sum()
    # vertical mismatches
    Lv = (arr[1:,:] != arr[:-1,:]).sum()
    return int(Lh + Lv)

def volatility(prev, cur):
    """(births, deaths, total flips) from prev->cur."""
    a = np.array(prev, dtype=np.uint8)
    b = np.array(cur, dtype=np.uint8)
    births = int((~a.astype(bool) & b.astype(bool)).sum())
    deaths = int((a.astype(bool) & ~b.astype(bool)).sum())
    return births, deaths, births + deaths

def hamming(a, b):
    """Total differing cells."""
    A = np.array(a, dtype=np.uint8)
    B = np.array(b, dtype=np.uint8)
    return int((A != B).sum())

print("✅ Cell 1 ready: classic GoL, seeding, metrics.")


✅ Cell 1 ready: classic GoL, seeding, metrics.


In [49]:
# Cell 2 — Gemini client + robust JSON caller for full next_grid (strict mode supported)

import os, json, time

# ---------------------------
# 0) SDK init (new first, legacy fallback)
# ---------------------------
_use_new = False
_client_new = None
_types_new = None

try:
    from google import genai as genai_new
    from google.genai import types as types_new
    _client_new = genai_new.Client(api_key=os.environ.get("GEMINI_API_KEY",""))
    _types_new = types_new
    _use_new = True
except Exception:
    # Legacy SDK will be used in llm_call_json fallback
    try:
        import google.generativeai as genai_old
        genai_old.configure(api_key=os.environ.get("GEMINI_API_KEY",""))
    except Exception as e:
        raise RuntimeError(
            "Gemini SDK init failed. Ensure google-genai OR google-generativeai is installed and GEMINI_API_KEY is set."
        ) from e

def build_agent_prompt(grid_now):
    n = len(grid_now)
    return (
        f"{SYSTEM_HDR}\n"
        f"Example input: {json.dumps(FEWSHOT_IN)}\n"
        f"Example output: {json.dumps({'next_grid': FEWSHOT_OUT, 'reason':'example'})}\n"
        f"Current grid size: {n}\n"
        f"Current grid: {json.dumps(grid_now)}\n"
        "Output JSON only."
    )

# ---------------------------
# 2) Helpers for parsing & repair
# ---------------------------
def _strip_code_fences(txt: str) -> str:
    s = (txt or "").strip()
    if s.startswith("```"):
        nl = s.find("\n")
        if nl != -1:
            s = s[nl+1:]
        if s.endswith("```"):
            s = s[:-3]
        s = s.strip()
    return s

def _coerce_binary_matrix(obj, N):
    """Strict: ensure NxN of ints in {0,1}; raise on error."""
    if not (isinstance(obj, list) and len(obj) == N and all(isinstance(r, list) for r in obj)):
        raise ValueError("bad shape")
    out = []
    for r in obj:
        if len(r) != N:
            raise ValueError("bad row length")
        row = []
        for x in r:
            try:
                xi = int(round(float(x)))
            except Exception:
                raise ValueError("non-numeric")
            row.append(1 if xi >= 1 else 0)
        out.append(row)
    return out

def _repair_to_NxN(obj, N):
    """Best-effort: flatten numbers, coerce to {0,1}, pad/crop to NxN."""
    vals = []

    def _flatten(v):
        if isinstance(v, list):
            for z in v:
                _flatten(z)
        else:
            vals.append(v)

    _flatten(obj)
    if not vals:
        return [[0]*N for _ in range(N)]

    # coerce to binary
    bin_vals = []
    for v in vals:
        try:
            xi = int(round(float(v)))
        except Exception:
            xi = 0
        bin_vals.append(1 if xi >= 1 else 0)

    total = N*N
    if len(bin_vals) < total:
        bin_vals += [0]*(total - len(bin_vals))
    else:
        bin_vals = bin_vals[:total]

    return [bin_vals[i*N:(i+1)*N] for i in range(N)]

# ---------------------------
# 3) Robust extractor for new SDK responses
# ---------------------------
def _extract_text_from_new(resp):
    # Preferred
    try:
        txt = resp.text
        if isinstance(txt, str) and txt.strip():
            return txt
    except Exception:
        pass
    # Walk candidates -> content -> parts
    try:
        cand_list = getattr(resp, "candidates", None) or []
        chunks = []
        for c in cand_list:
            content = getattr(c, "content", None)
            parts = getattr(content, "parts", None) if content is not None else None
            if parts:
                for p in parts:
                    t = getattr(p, "text", None)
                    if isinstance(t, str):
                        chunks.append(t)
        txt = "\n".join(chunks)
        if txt.strip():
            return txt
    except Exception:
        pass
    # Fallback: stringified object
    try:
        s = str(resp)
        if s.strip():
            return s
    except Exception:
        pass
    return ""

# ---------------------------
# 4) Low-level JSON call (new SDK → legacy fallback)
# ---------------------------
def llm_call_json(prompt: str, temperature: float = 0.0):
    """
    Returns (obj, latency_s, raw_text).
    Tries new SDK first with robust extraction; if empty or error, falls back to legacy SDK.
    """
    t0 = time.perf_counter()
    raw = ""

    # Try new SDK
    if _use_new and _client_new is not None and _types_new is not None:
        try:
            try:
                part = _types_new.Part.from_text(text=prompt)
            except Exception:
                part = _types_new.Part(text=prompt)
            try:
                cfg = _types_new.GenerateContentConfig(
                    response_mime_type="application/json",
                    thinking_config=_types_new.ThinkingConfig(thinking_budget=0),
                    temperature=temperature,
                )
            except Exception:
                cfg = _types_new.GenerateContentConfig(
                    response_mime_type="application/json",
                    thinking_config=_types_new.ThinkingConfig(thinking_budget=0),
                )
            resp = _client_new.models.generate_content(
                model=MODEL_NAME,
                contents=[_types_new.Content(role="user", parts=[part])],
                config=cfg,
            )
            raw = _extract_text_from_new(resp)
            txt = _strip_code_fences(raw)
            if txt:
                try:
                    obj = json.loads(txt)
                    return obj, (time.perf_counter()-t0), raw
                except Exception:
                    # fall through to legacy if JSON parse fails
                    pass
        except Exception:
            pass

    # Legacy fallback
    try:
        import google.generativeai as genai_old
        model = genai_old.GenerativeModel(MODEL_NAME)
        r = model.generate_content(prompt, generation_config={"temperature": temperature})
        raw = getattr(r, "text", "") or ""
        txt = _strip_code_fences(raw)
        obj = json.loads(txt)
        return obj, (time.perf_counter()-t0), raw
    except Exception:
        return None, (time.perf_counter()-t0), raw

# ---------------------------
# 5) High-level: get next_grid (strict or best-effort)
# ---------------------------
def call_next_grid(grid_now, temperature: float, strict: bool = None):
    """
    Build prompt, call Gemini, and return dict:
      {
        "next_grid": NxN list[list[int]],
        "latency_s": float,
        "ok": bool,
        "error": str|None,
        "reason": str|None,
        "raw": str
      }
    strict:
      - True  → require perfect NxN JSON, else ok=False
      - False → attempt best-effort repair to NxN
      - None  → default to STRICT_MODE if defined, else True
    """
    if strict is None:
        strict = globals().get("STRICT_MODE", True)

    prompt = build_agent_prompt(grid_now)
    obj, latency, raw = llm_call_json(prompt, temperature=temperature)

    out = {
        "next_grid": None,
        "latency_s": float(latency),
        "ok": False,
        "error": None,
        "reason": None,
        "raw": raw,
    }

    try:
        if obj is None:
            raise ValueError("no JSON")
        ng = obj.get("next_grid", None)
        reason = obj.get("reason", None) if isinstance(obj, dict) else None

        if strict:
            next_grid = _coerce_binary_matrix(ng, len(grid_now))
        else:
            try:
                next_grid = _coerce_binary_matrix(ng, len(grid_now))
            except Exception:
                next_grid = _repair_to_NxN(ng, len(grid_now))

        out.update(next_grid=next_grid, ok=True, reason=reason)
        return out
    except Exception as e:
        out.update(error=f"{type(e).__name__}: {e}")
        return out

print("✅ Cell 2 ready: Gemini caller rebuilt (robust extraction, legacy fallback, strict/repair modes).")


✅ Cell 2 ready: Gemini caller rebuilt (robust extraction, legacy fallback, strict/repair modes).


In [50]:
# Cell 3 — Run tracks (Classic + 5 Gemini temps), collect histories & metrics, persist

import json, math
import numpy as np
from copy import deepcopy
from pathlib import Path

# ---- sanity on required globals from prior cells ----
_required = [
    "N","STEPS","P_ALIVE","RNG_SEED","TRACKS","TEMPS","STRICT_MODE","OUT_DIR",
    "seed_mini_pack","life_step_finite","alive_density","perimeter","volatility","hamming",
    "call_next_grid","trk_name",
]
_missing = [k for k in _required if k not in globals()]
if _missing:
    raise RuntimeError(f"Missing definitions from earlier cells: {_missing}. Re-run Cells 0–2 first.")

# ---- 0) Legacy-name compatibility (if some code still tries g_t{:.1f}) ----
LEGACY_TO_CANON = {f"g_t{t:.1f}": trk_name(t) for t in TEMPS}

def canon_name(name: str) -> str:
    return LEGACY_TO_CANON.get(name, name)

# ---- 1) Seed once, clone across tracks ----
g0 = seed_mini_pack(n=N, rng_seed=RNG_SEED, p_alive=P_ALIVE)

histories = {trk: [deepcopy(g0)] for trk in TRACKS}  # keys are "classic" + g_t{:.2f}
metrics   = {
    trk: dict(
        step=[],
        density=[],
        perimeter=[],
        volatility=[],
        hamming_vs_classic=[],
        latency_s=[],
        ok=[],
        error=[],
    ) for trk in TRACKS
}

halted      = {trk: False for trk in TRACKS}
halt_reason = {trk: None  for trk in TRACKS}

# helpers
def _dens(g): return alive_density(g)
def _peri(g): return perimeter(g)
def _vol(prev, cur):
    b,d,flips = volatility(prev, cur)
    return b, d, flips

# ---- 2) Main loop over steps ----
for t in range(1, STEPS+1):
    # Classic evolves by its own physics (no LLM)
    prev_c = histories["classic"][-1]
    next_c = life_step_finite(prev_c)
    histories["classic"].append(next_c)

    rho = _dens(next_c); L = _peri(next_c)
    b,d,flips = _vol(prev_c, next_c)
    metrics["classic"]["step"].append(t)
    metrics["classic"]["density"].append(rho)
    metrics["classic"]["perimeter"].append(L)
    metrics["classic"]["volatility"].append(flips)
    metrics["classic"]["hamming_vs_classic"].append(0)   # by definition
    metrics["classic"]["latency_s"].append(0.0)
    metrics["classic"]["ok"].append(True)
    metrics["classic"]["error"].append(None)

    # Gemini tracks (100% agentic): next_grid comes from the model
    for temp in TEMPS:
        trk = trk_name(temp)  # <<< canonical name g_t{:.2f}
        prev_g = histories[trk][-1]

        if halted[trk]:
            # keep timeline aligned by repeating the last state
            histories[trk].append(prev_g)
            gt_from_prev = life_step_finite(prev_g)
            metrics[trk]["step"].append(t)
            metrics[trk]["density"].append(_dens(prev_g))
            metrics[trk]["perimeter"].append(_peri(prev_g))
            metrics[trk]["volatility"].append(0)
            metrics[trk]["hamming_vs_classic"].append(hamming(prev_g, gt_from_prev))
            metrics[trk]["latency_s"].append(0.0)
            metrics[trk]["ok"].append(False)
            metrics[trk]["error"].append(halt_reason[trk])
            continue

        # Call Gemini for full next grid
        res = call_next_grid(prev_g, temperature=temp, strict=STRICT_MODE)
        ok      = bool(res.get("ok", False))
        err_str = res.get("error")
        next_g  = res.get("next_grid") if ok else prev_g  # if invalid, freeze state

        histories[trk].append(next_g)

        # Metrics vs *classic-from-this-track's previous state*
        gt_from_prev = life_step_finite(prev_g)
        rho = _dens(next_g); L = _peri(next_g)
        b,d,flips = _vol(prev_g, next_g)
        ham = hamming(next_g, gt_from_prev)

        metrics[trk]["step"].append(t)
        metrics[trk]["density"].append(rho)
        metrics[trk]["perimeter"].append(L)
        metrics[trk]["volatility"].append(flips)
        metrics[trk]["hamming_vs_classic"].append(ham)
        metrics[trk]["latency_s"].append(float(res.get("latency_s", 0.0)))
        metrics[trk]["ok"].append(ok)
        metrics[trk]["error"].append(err_str)

        # Halt this track if strict mode and invalid JSON
        if (not ok) and STRICT_MODE:
            halted[trk] = True
            halt_reason[trk] = err_str

# ---- 3) Persist artifacts ----
OUT_DIR = Path(OUT_DIR)
for trk in TRACKS:
    (OUT_DIR / f"history_{trk}.json").write_text(json.dumps(histories[trk]))
    clean = {k: [float(x) if isinstance(x, (np.floating,)) else x for x in v]
             for k,v in metrics[trk].items()}
    (OUT_DIR / f"metrics_{trk}.json").write_text(json.dumps(clean))

(OUT_DIR / "halt_status.json").write_text(json.dumps({"halted": halted, "reason": halt_reason}, indent=2))

# ---- 4) Console summary ----
print("✅ Tracks complete.")
for trk in TRACKS:
    ok_count = sum(1 for x in metrics[trk]["ok"] if x)
    print(f"  • {trk:8s}  steps_ok={ok_count}/{STEPS}  halted={halted[trk]}  reason={halt_reason[trk]}")


✅ Tracks complete.
  • classic   steps_ok=50/50  halted=False  reason=None
  • g_t0.00   steps_ok=50/50  halted=False  reason=None
  • g_t0.25   steps_ok=50/50  halted=False  reason=None
  • g_t0.50   steps_ok=50/50  halted=False  reason=None
  • g_t0.75   steps_ok=50/50  halted=False  reason=None
  • g_t1.00   steps_ok=50/50  halted=False  reason=None


In [51]:
# Cell 4 — 2x3 portrait MP4 (Classic + 5 temps) + animated 3-chart review

import io, os, math, numpy as np, imageio.v2 as imageio, matplotlib.pyplot as plt
from PIL import Image, ImageDraw, ImageFont

# ---------- sanity / fallbacks ----------
if "TEMPS" not in globals() or "TRACKS" not in globals():
    raise RuntimeError("Missing TEMPS/TRACKS. Run Cells 0–3 first (or 3R to restore).")
if "histories" not in globals() or "metrics" not in globals():
    raise RuntimeError("Missing histories/metrics. Run Cell 3 (or 3R to restore).")
if "VIDEO_W" not in globals() or "VIDEO_H" not in globals() or "FPS" not in globals():
    VIDEO_W, VIDEO_H, FPS = 1088, 1920, 10
if "TITLE" not in globals():
    TITLE = "Agentic Game of Life"

# ensure we have a canonical track name helper
if "trk_name" not in globals():
    def trk_name(t: float) -> str:
        return f"g_t{t:.2f}"

# ---------- fonts ----------
FONT_CANDIDATES = [
    "/usr/share/fonts/truetype/dejavu/DejaVuSansMono.ttf",
    "/usr/share/fonts/truetype/dejavu/DejaVuSans.ttf",
    "/usr/share/fonts/truetype/liberation/LiberationSans-Regular.ttf",
    "/usr/share/fonts/truetype/ubuntu/UbuntuMono-R.ttf",
]
def load_ttf(size):
    for p in FONT_CANDIDATES:
        if os.path.exists(p):
            return ImageFont.truetype(p, size)
    return ImageFont.load_default()

TITLE_SIZE = 60  # bump if you want bigger
HUD_SIZE   = 32
FONT_TITLE = load_ttf(TITLE_SIZE)
FONT_HUD   = load_ttf(HUD_SIZE)

# ---------- colors ----------
def _hex_to_rgb(h):
    h = h.lstrip("#")
    return tuple(int(h[i:i+2], 16) for i in (0,2,4))

if "COLOR_CLASSIC" not in globals():
    COLOR_CLASSIC = "#60a5fa"   # blue
if "COLOR_T" not in globals():
    COLOR_T = {0.00:"#a78bfa", 0.25:"#22c55e", 0.50:"#fb923c", 0.75:"#ef4444", 1.00:"#ec4899"}

CLASSIC_COLOR = _hex_to_rgb(COLOR_CLASSIC)
TEMP_COLOR = {}
for t in TEMPS:
    ct = float(t)
    TEMP_COLOR[trk_name(ct)] = _hex_to_rgb(COLOR_T.get(ct, "#ef4444"))

BG = (10,10,12)
FG = (235,235,235)
BORDER = (190,190,190)

# ---------- geometry (2 columns x 3 rows) ----------
GAP = 24
tmp = Image.new("RGB", (1,1))
dd = ImageDraw.Draw(tmp)
tw, th = dd.textbbox((0,0), TITLE, font=FONT_TITLE)[2:]
TITLE_H = th + 40

# Decide N from histories
_some_trk = next(iter(histories.keys()))
N = len(histories[_some_trk][0])

ROWS, COLS = 3, 2
PANEL_W = (VIDEO_W - (COLS + 1) * GAP) // COLS
PANEL_H = (VIDEO_H - TITLE_H - (ROWS + 1) * GAP) // ROWS
PANEL   = min(PANEL_W, PANEL_H)
SCALE   = max(1, PANEL // N)
PANEL   = SCALE * N  # exact multiple for crisp nearest

def grid_pos(r, c):
    x = GAP + c * (PANEL + GAP)
    y = TITLE_H + GAP + r * (PANEL + GAP)
    return x, y

# Track placement (reading order, 2 per row):
ORDER = ["classic"] + [trk_name(t) for t in TEMPS]
# Keep exactly 6 for 2x3 layout; if more, take first 6
ORDER = ORDER[:6]
if len(ORDER) < 6:
    # If fewer than 6 tracks exist, pad duplicates of classic just to layout without crashing
    ORDER += ["classic"] * (6 - len(ORDER))

RC_MAP = [(0,0),(0,1),(1,0),(1,1),(2,0),(2,1)]  # row, col for each of 6 slots

# ---------- helpers ----------
def grid_to_panel(grid, alive=(245,245,245), dead=(20,20,24)):
    arr = np.array(grid, dtype=np.uint8)
    rgb = np.zeros((N,N,3), dtype=np.uint8)
    rgb[arr==1] = alive
    rgb[arr==0] = dead
    return Image.fromarray(rgb, "RGB").resize((PANEL, PANEL), Image.NEAREST)

def overlay_mismatches(img_rgba, agent_grid, gt_grid, color=(255,64,64,120)):
    ov = Image.new("RGBA", (PANEL,PANEL), (0,0,0,0))
    d  = ImageDraw.Draw(ov)
    for y in range(N):
        for x in range(N):
            if int(agent_grid[y][x]) != int(gt_grid[y][x]):
                x0, y0 = x*SCALE, y*SCALE
                d.rectangle([x0, y0, x0+SCALE-1, y0+SCALE-1], fill=color)
    img_rgba.alpha_composite(ov)

def draw_box_label(canvas, xy, text, color=BORDER, pad=10):
    d = ImageDraw.Draw(canvas)
    x,y = xy
    x0,y0,x1,y1 = d.textbbox((x,y), text, font=FONT_HUD)
    box = [x, y, x + (x1 - x0) + 2*pad, y + (y1 - y0) + 2*pad]
    d.rectangle(box, fill=(0,0,0,200))
    d.rectangle(box, outline=color, width=3)
    d.text((x+pad, y+pad), text, fill=FG, font=FONT_HUD)

def compose_mosaic_frame(step_idx, states, prev_states, per_meta):
    # For fair comparison, GT is classic step from each track's *own* previous state
    gt = {trk: life_step_finite(prev_states[trk]) for trk in ORDER}

    canvas = Image.new("RGB", (VIDEO_W, VIDEO_H), BG)
    d = ImageDraw.Draw(canvas)

    # title
    title = f"{TITLE} — step {step_idx:02d}"
    TW, TH = d.textbbox((0,0), title, font=FONT_TITLE)[2:]
    d.text(((VIDEO_W - TW)//2, (TITLE_H - TH)//2), title, fill=FG, font=FONT_TITLE)

    # panels
    for trk, (r,c) in zip(ORDER, RC_MAP):
        px, py = grid_pos(r, c)
        grid_now  = states[trk]
        prev_grid = prev_states[trk]

        panel = grid_to_panel(grid_now, alive=(245,245,245) if trk=="classic" else (232,232,232)).convert("RGBA")
        if trk != "classic":
            overlay_mismatches(panel, grid_now, gt[trk])

        canvas.paste(panel.convert("RGB"), (px,py))
        color = CLASSIC_COLOR if trk=="classic" else TEMP_COLOR.get(trk, (200,200,200))
        d.rectangle([px-2, py-2, px+PANEL+1, py+PANEL+1], outline=color, width=3)

        # HUD
        cur = np.array(grid_now, dtype=np.uint8)
        prv = np.array(prev_grid, dtype=np.uint8)
        rho = float(cur.mean())
        L   = perimeter(cur.tolist())
        flips = int((cur != prv).sum())
        births = int(np.logical_and(cur==1, prv==0).sum())
        deaths = int(np.logical_and(cur==0, prv==1).sum())
        hud = f"{trk}  ρ={rho:.3f}  L={L}  Δ(+{births}/-{deaths})"
        draw_box_label(canvas, (px, py-46), hud, color=color)

        if trk != "classic":
            meta   = per_meta.get(trk, {})
            lat_ms = meta.get("latency_ms", 0.0)
            ham    = meta.get("hamming", 0)
            status = "ok" if meta.get("ok", False) else "halted"
            extra  = f"lat={lat_ms:.0f}ms ham={ham}"
            draw_box_label(canvas, (px, py+PANEL+8), f"{status} {extra}", color=color)

    return np.array(canvas, dtype=np.uint8)

def _per_step_meta(t):
    info = {}
    for trk in ORDER:
        if trk == "classic":
            info[trk] = {"ok": True, "latency_ms": 0.0, "hamming": 0}
        else:
            idx = t-1
            ok  = bool(metrics[trk]["ok"][idx]) if idx < len(metrics[trk]["ok"]) else False
            lat = float(metrics[trk]["latency_s"][idx])*1000.0 if idx < len(metrics[trk]["latency_s"]) else 0.0
            ham = int(metrics[trk]["hamming_vs_classic"][idx]) if idx < len(metrics[trk]["hamming_vs_classic"]) else 0
            info[trk] = {"ok": ok, "latency_ms": lat, "hamming": ham}
    return info

# ---------- metric series for animated charts ----------
def binary_entropy(p):
    if p <= 0.0 or p >= 1.0:
        return 0.0
    return float(-(p*math.log2(p) + (1-p)*math.log2(1-p)))

PERI_MAX = 2 * N * (N - 1)
VOL_MAX  = N * N

def compute_series(histories, t_max):
    series = {trk: {"H": [], "Lnorm": [], "Vnorm": []} for trk in ORDER}
    for trk in ORDER:
        for t in range(1, t_max+1):
            cur  = np.array(histories[trk][t],   dtype=np.uint8)
            prev = np.array(histories[trk][t-1], dtype=np.uint8)
            p  = float(cur.mean())
            H  = binary_entropy(p)
            L  = perimeter(cur.tolist()) / PERI_MAX
            flips = int((cur != prev).sum())
            V  = flips / VOL_MAX
            series[trk]["H"].append(H)
            series[trk]["Lnorm"].append(L)
            series[trk]["Vnorm"].append(V)
    return series

def chart_frame(series, upto_t, width=VIDEO_W, height=VIDEO_H, sizes=None):
    """
    Clean light charts with small fonts and a single top legend.
    sizes (optional): dict(title,label,tick,legend,lw_classic,lw_agent)
    """
    import io, matplotlib.pyplot as plt
    import numpy as np

    # small, readable defaults (~20–30% of your previous large sizes)
    if sizes is None:
        sizes = dict(title=16, label=12, tick=10, legend=11,
                     lw_classic=2.2, lw_agent=2.0)

    plt.style.use("default")  # light theme

    with plt.rc_context({
        "font.size": sizes["tick"],
        "axes.titlesize": sizes["title"],
        "axes.labelsize": sizes["label"],
        "xtick.labelsize": sizes["tick"],
        "ytick.labelsize": sizes["tick"],
        "legend.fontsize": sizes["legend"],
        "axes.grid": True,
        "grid.alpha": 0.25,
    }):
        fig, axes = plt.subplots(
            3, 1,
            figsize=(width/200, height/200),  # compact figure
            dpi=160,
            constrained_layout=True
        )
        # leave headroom for global legend
        fig.subplots_adjust(top=0.88)

        metrics_order = [("H","Entropy"), ("Lnorm","Perimeter norm"), ("Vnorm","Volatility norm")]

        def _lstyle(temp):
            return {0.00:"solid", 0.25:"dashed", 0.50:"dashdot", 0.75:"dotted", 1.00:(0,(1,1))}.get(temp, "solid")

        handles, labels = [], []
        for ax_idx, (key, label) in enumerate(metrics_order):
            ax = axes[ax_idx]
            xs = np.arange(1, upto_t+1)

            for trk in ORDER:
                ys = series[trk][key][:upto_t]
                if trk == "classic":
                    (line,) = ax.plot(xs, ys, color=COLOR_CLASSIC, linewidth=sizes["lw_classic"], label="classic")
                else:
                    temp = float(trk.split("g_t")[-1])
                    (line,) = ax.plot(xs, ys, color=COLOR_T[temp], linestyle=_lstyle(temp),
                                      linewidth=sizes["lw_agent"], label=f"t={temp:.2f}")
                if ax_idx == 0:
                    handles.append(line); labels.append(line.get_label())

            ax.set_xlim(1, max(1, len(xs)))
            ax.set_ylim(0, 1)
            ax.set_ylabel(label)
            if ax_idx < 2:
                ax.tick_params(labelbottom=False)

        axes[-1].set_xlabel("step")

        # clean, single legend outside the plots
        fig.legend(handles, labels, ncol=3, frameon=False,
                   loc="upper center", bbox_to_anchor=(0.5, 0.995),
                   handlelength=2.6, borderaxespad=0.2, labelspacing=0.4)

        buf = io.BytesIO()
        plt.tight_layout()
        fig.savefig(buf, format="png")  # white background
        plt.close(fig)
        buf.seek(0)

    chart_img = Image.open(buf).convert("RGB")

    # paste on a white canvas (no dark background)
    canvas = Image.new("RGB", (VIDEO_W, VIDEO_H), (255,255,255))
    cw, ch = chart_img.size
    scale = min((VIDEO_W - 2*GAP)/cw, (VIDEO_H - 2*GAP)/ch)
    chart_img = chart_img.resize((int(cw*scale), int(ch*scale)), Image.BICUBIC)
    x = (VIDEO_W - chart_img.size[0]) // 2
    y = (VIDEO_H - chart_img.size[1]) // 2
    canvas.paste(chart_img, (x,y))
    return np.array(canvas, dtype=np.uint8)



# ---------- assemble video ----------
frames = []
T = min(len(histories[ORDER[0]])-1, *(len(histories[trk])-1 for trk in ORDER))

# Scene A — 2x3 mosaic evolution
for t in range(1, T+1):
    states = {trk: histories[trk][t]   for trk in ORDER}
    prevs  = {trk: histories[trk][t-1] for trk in ORDER}
    meta   = _per_step_meta(t)
    frames.append(compose_mosaic_frame(t, states, prevs, meta))

# Scene B — animated charts (~1s per step at FPS=10)
series = compute_series(histories, T)
for t in range(1, T+1):
    cf = chart_frame(series, t)
    for _ in range(FPS):
        frames.append(cf)

# write MP4
mp4_path = OUT_DIR / "agol_2x3_plus_charts.mp4"
with imageio.get_writer(
    mp4_path, fps=FPS, codec="libx264", macro_block_size=16,
    ffmpeg_params=["-pix_fmt","yuv420p","-movflags","+faststart","-crf","20","-preset","medium"]
) as w:
    for fr in frames:
        w.append_data(fr)


  return Image.fromarray(rgb, "RGB").resize((PANEL, PANEL), Image.NEAREST)
  fig.subplots_adjust(top=0.88)
  ax.set_xlim(1, max(1, len(xs)))
  plt.tight_layout()
  fig.subplots_adjust(top=0.88)
  plt.tight_layout()
  fig.subplots_adjust(top=0.88)
  plt.tight_layout()
  fig.subplots_adjust(top=0.88)
  plt.tight_layout()
  fig.subplots_adjust(top=0.88)
  plt.tight_layout()
  fig.subplots_adjust(top=0.88)
  plt.tight_layout()
  fig.subplots_adjust(top=0.88)
  plt.tight_layout()
  fig.subplots_adjust(top=0.88)
  plt.tight_layout()
  fig.subplots_adjust(top=0.88)
  plt.tight_layout()
  fig.subplots_adjust(top=0.88)
  plt.tight_layout()
  fig.subplots_adjust(top=0.88)
  plt.tight_layout()
  fig.subplots_adjust(top=0.88)
  plt.tight_layout()
  fig.subplots_adjust(top=0.88)
  plt.tight_layout()
  fig.subplots_adjust(top=0.88)
  plt.tight_layout()
  fig.subplots_adjust(top=0.88)
  plt.tight_layout()
  fig.subplots_adjust(top=0.88)
  plt.tight_layout()
  fig.subplots_adjust(top=0.88)
  plt.tight