# Week 11 - Coalgebraic RL with GT + Diagrammatic Backprop (DB)

[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/sridharmahadevan/Category-Theory-for-AGI-UMass-CMPSCI-692CT/blob/main/notebooks/week11_coalgebra_gtdb_rl_demo.ipynb)

This notebook is a compact experiment-first demo for **coalgebras and coinduction** in RL:

- Environment as a coalgebra-like transition structure over states
- Three learning arms: **MLP**, **GT**, **GT+DB**
- Harder synthetic MDP (random starts, stochastic slips, reward noise)
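- Compare sample efficiency (area under the return curve, episodes needed to reach an 80% rolling success rate) and final success (mean success over the last 20% of episodes)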


In [None]:
# Environment
import random
from collections import deque
from dataclasses import dataclass
from statistics import mean

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import matplotlib.pyplot as plt

# Seed Python, NumPy and PyTorch with the same value, in that order.
SEED = 0
for _seed_fn in (random.seed, np.random.seed, torch.manual_seed):
    _seed_fn(SEED)

# Use CUDA when PyTorch reports it as available, otherwise the CPU.
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
print('device:', device)


In [None]:
@dataclass(frozen=True)
class MDPConfig:
    """Immutable settings for the line-world MDP.

    Raises:
        ValueError: if a field holds a value the environment cannot use
            meaningfully (for example a goal state outside the state range).
    """

    n_states: int = 32  # states are indexed 0 .. n_states - 1
    start_state: int = 0  # initial state when random_start is False
    goal_state: int = 31  # entering this state ends the episode
    max_steps: int = 36  # the episode is cut off after this many steps
    step_penalty: float = -0.01  # base reward of every non-goal transition
    goal_reward: float = 1.0  # base reward of the transition into goal_state
    random_start: bool = True  # draw the start state at random on reset
    slip_prob: float = 0.10  # probability that the chosen action is flipped
    reward_noise_std: float = 0.05  # std-dev of Gaussian noise added to rewards

    def __post_init__(self):
        # Fail fast: an out-of-range value would otherwise not raise anywhere
        # and simply yield an unreachable goal or a nonsensical noise model.
        if self.n_states < 1:
            raise ValueError(f"n_states must be at least 1, got {self.n_states}")
        for field_name in ("start_state", "goal_state"):
            value = getattr(self, field_name)
            if not 0 <= value < self.n_states:
                raise ValueError(
                    f"{field_name} must be in [0, {self.n_states}), got {value}"
                )
        if self.max_steps < 1:
            raise ValueError(f"max_steps must be at least 1, got {self.max_steps}")
        if not 0.0 <= self.slip_prob <= 1.0:
            raise ValueError(f"slip_prob must be in [0, 1], got {self.slip_prob}")
        if self.reward_noise_std < 0.0:
            raise ValueError(
                f"reward_noise_std must be non-negative, got {self.reward_noise_std}"
            )

class LineWorldMDP:
    """Chain of ``cfg.n_states`` states with two actions: 0 = left, 1 = right.

    Moves are clipped at both ends of the chain. With probability
    ``cfg.slip_prob`` the chosen action is flipped, and zero-mean Gaussian
    noise with std-dev ``cfg.reward_noise_std`` is added to every reward.

    Args:
        cfg: configuration object exposing the ``MDPConfig`` fields.
        rng: optional randomness source providing ``random()``,
            ``randrange()`` and ``gauss()`` (e.g. ``random.Random(seed)``).
            Defaults to the global ``random`` module, so seeding via
            ``random.seed`` keeps working as before.
    """

    def __init__(self, cfg: "MDPConfig", rng=None):
        self.cfg = cfg
        # A dedicated generator lets the environment's noise be seeded
        # independently of whatever else draws from the global RNG.
        self.rng = random if rng is None else rng
        self.state = cfg.start_state
        self.steps = 0

    def reset(self):
        """Start a new episode and return the initial state index."""
        self.steps = 0
        if self.cfg.random_start:
            self.state = self.rng.randrange(self.cfg.n_states)
        else:
            self.state = self.cfg.start_state
        return self.state

    def step(self, action):
        """Apply ``action`` and return ``(next_state, reward, done, info)``.

        ``done`` is True when the goal state is entered or when
        ``cfg.max_steps`` steps have been taken; ``info['is_goal']`` is 1 only
        in the first case, which lets callers tell the two apart.

        Raises:
            ValueError: if ``action`` is not 0 or 1.
        """
        a = int(action)
        if a not in (0, 1):
            # Any other value used to be treated as "right" and was not
            # flipped correctly by the slip logic below.
            raise ValueError(f"action must be 0 or 1, got {action!r}")

        self.steps += 1
        if self.rng.random() < self.cfg.slip_prob:
            a = 1 - a

        if a == 0:
            ns = max(0, self.state - 1)
        else:
            ns = min(self.cfg.n_states - 1, self.state + 1)

        is_goal = ns == self.cfg.goal_state
        r = self.cfg.goal_reward if is_goal else self.cfg.step_penalty
        done = is_goal or self.steps >= self.cfg.max_steps

        # Noise is drawn on every step, including the goal transition.
        r += self.rng.gauss(0.0, self.cfg.reward_noise_std)
        self.state = ns
        return ns, float(r), bool(done), {'is_goal': int(is_goal)}

def build_transition_relation(n_states):
    """Return the float32 adjacency matrix of a chain with ``n_states`` states.

    Entry ``[s, t]`` is 1.0 when ``t`` is the left or right neighbour of
    ``s`` and 0.0 otherwise; the diagonal stays zero.
    """
    adjacency = torch.zeros((n_states, n_states), dtype=torch.float32)
    # Each pair of consecutive states contributes one edge in each direction.
    for left in range(n_states - 1):
        right = left + 1
        adjacency[left, right] = 1.0
        adjacency[right, left] = 1.0
    return adjacency


In [None]:
class GTQ(nn.Module):
    """Q-network that combines each state's embedding with its neighbours' mean."""

    def __init__(self, n_states, n_actions=2, hidden_dim=64):
        super().__init__()
        self.n_states = n_states
        # Layers are created in the order emb, msg, head; with a fixed seed
        # this order determines their initial weights.
        self.emb = nn.Embedding(n_states, hidden_dim)
        self.msg = nn.Linear(hidden_dim, hidden_dim)
        self.head = nn.Linear(hidden_dim, n_actions)

    def q_values(self, state, rel):
        """Return ``(q[state], hidden)`` after one round of neighbour averaging.

        ``rel`` is the relation matrix over states (the matrix products below
        need it to be ``n_states x n_states``). ``hidden`` holds the
        post-ReLU representation of every state.
        """
        node_ids = torch.arange(self.n_states, device=rel.device)
        base = self.emb(node_ids)

        # Row sums count neighbours; clamping avoids dividing by zero for
        # states that have none.
        degree = rel.sum(dim=1, keepdim=True).clamp(min=1.0)
        neighbour_mean = torch.matmul(rel, base) / degree

        hidden = F.relu(base + self.msg(neighbour_mean))
        q_all = self.head(hidden)
        return q_all[state], hidden

    def db_loss(self, h_all, rel):
        """Mean squared distance between representations of related states."""
        sources, targets = torch.nonzero(rel > 0.0, as_tuple=True)
        if sources.numel() == 0:
            # No related pairs: return a scalar zero matching h_all.
            return h_all.new_zeros(())
        gap = h_all[sources] - h_all[targets]
        per_edge = gap.pow(2).sum(dim=1)
        return per_edge.mean()

class MLPQ(nn.Module):
    """Two-layer MLP Q-network over a one-hot state encoding (ignores ``rel``)."""

    def __init__(self, n_states, n_actions=2, hidden_dim=64):
        super().__init__()
        self.n_states = n_states
        self.fc1 = nn.Linear(n_states, hidden_dim)
        self.fc2 = nn.Linear(hidden_dim, n_actions)

    def q_values(self, state, rel):
        """Return ``(q, hidden)`` for ``state``.

        ``rel`` is accepted only so the signature matches the graph model.
        The hidden vector is tiled to ``n_states`` rows.
        """
        del rel  # unused by this baseline
        one_hot = torch.zeros(self.n_states, device=self.fc1.weight.device)
        one_hot[state] = 1.0

        hidden = F.relu(self.fc1(one_hot.unsqueeze(0)))
        q = self.fc2(hidden).squeeze(0)
        tiled = hidden.repeat(self.n_states, 1)
        return q, tiled

    def db_loss(self, h_all, rel):
        """Return a constant scalar zero; this model has no DB term."""
        del h_all, rel  # unused by this baseline
        return torch.zeros((), device=self.fc1.weight.device)


In [None]:
def train_one(model_name='gtdb', seed=0, episodes=180, gamma=0.99, lr=1e-3, eps0=0.2, eps_min=0.02, eps_decay=0.995, db_coef=0.1):
    """Train one online Q-learning agent on the line world and log each episode.

    Args:
        model_name: 'mlp' (one-hot MLP), 'gt' (graph model without the DB
            term) or 'gtdb' (graph model with the DB regulariser).
        seed: seed applied to ``random``, NumPy and PyTorch before training.
        episodes: number of training episodes.
        gamma: discount factor of the one-step TD target.
        lr: Adam learning rate.
        eps0: initial epsilon for epsilon-greedy exploration.
        eps_min: lower bound on epsilon.
        eps_decay: factor epsilon is multiplied by after every episode.
        db_coef: weight of the DB loss; only used when model_name is 'gtdb'.

    Returns:
        dict with keys 'model' and 'seed' plus the per-episode lists
        'return', 'success', 'td' (mean TD loss) and 'db' (mean DB loss).

    Raises:
        ValueError: if ``model_name`` is not one of the three known arms.
    """
    if model_name not in ('mlp', 'gt', 'gtdb'):
        # A typo used to fall through silently to the 'gt' arm.
        raise ValueError(f"unknown model_name {model_name!r}; expected 'mlp', 'gt' or 'gtdb'")

    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)

    cfg = MDPConfig()
    env = LineWorldMDP(cfg)
    rel = build_transition_relation(cfg.n_states).to(device)

    if model_name == 'mlp':
        model = MLPQ(cfg.n_states).to(device)
        db_coef_eff = 0.0
    else:
        model = GTQ(cfg.n_states).to(device)
        db_coef_eff = db_coef if model_name == 'gtdb' else 0.0

    opt = torch.optim.Adam(model.parameters(), lr=lr)

    ep_returns, ep_success, ep_td, ep_db = [], [], [], []
    eps = eps0

    for ep in range(episodes):
        s = env.reset()
        ret = 0.0
        td_hist, db_hist = [], []
        got_goal = 0

        for _ in range(cfg.max_steps):
            q_s, h_all = model.q_values(s, rel)
            if random.random() < eps:
                a = random.choice([0, 1])
            else:
                a = int(torch.argmax(q_s).item())

            ns, r, done, info = env.step(a)
            q_sa = q_s[a]

            # Only reaching the goal is a true terminal state. When the
            # episode merely hits the step limit the next state still has
            # value, so the target must keep bootstrapping from it.
            terminal = bool(info['is_goal'])

            with torch.no_grad():
                q_ns, _ = model.q_values(ns, rel)
                target = torch.tensor(r, device=device) + (0.0 if terminal else gamma * torch.max(q_ns))

            td = F.mse_loss(q_sa, target)
            db = model.db_loss(h_all, rel)
            loss = td + db_coef_eff * db

            opt.zero_grad(set_to_none=True)
            loss.backward()
            opt.step()

            td_hist.append(td.item())
            db_hist.append(db.item())
            ret += r
            s = ns
            got_goal = max(got_goal, int(info['is_goal']))
            if done:
                break

        ep_returns.append(ret)
        ep_success.append(float(got_goal))
        ep_td.append(mean(td_hist) if td_hist else 0.0)
        ep_db.append(mean(db_hist) if db_hist else 0.0)
        eps = max(eps_min, eps * eps_decay)

    return {
        'model': model_name,
        'seed': seed,
        'return': ep_returns,
        'success': ep_success,
        'td': ep_td,
        'db': ep_db,
    }

def moving_avg(xs, w=20):
    """Return the trailing moving average of ``xs`` as a list.

    Element ``i`` of the result averages the last ``w`` values up to and
    including ``xs[i]``; the first ``w - 1`` entries average however many
    values have been seen so far.

    Args:
        xs: iterable of numbers.
        w: window length, a positive integer.

    Raises:
        ValueError: if ``w`` is smaller than 1.
    """
    if w < 1:
        # A non-positive window used to empty the buffer and divide by zero.
        raise ValueError(f"w must be a positive integer, got {w!r}")

    window = deque()  # O(1) eviction of the oldest value
    total = 0.0
    averages = []
    for x in xs:
        window.append(x)
        total += x
        if len(window) > w:
            total -= window.popleft()
        averages.append(total / len(window))
    return averages

def episodes_to_target(success_series, target=0.8, window=20, min_periods=None):
    """Return the 1-based episode at which the rolling success rate hits ``target``.

    The rolling rate is the mean of the last ``window`` outcomes. A window
    only counts once it holds at least ``min_periods`` outcomes, so a single
    early success cannot satisfy the target on its own.

    Args:
        success_series: iterable of per-episode outcomes (0.0 / 1.0).
        target: rolling success rate that must be reached.
        window: number of most recent episodes averaged.
        min_periods: minimum number of outcomes required before the rolling
            rate is compared with ``target``. Defaults to ``window`` (full
            windows only); pass 1 to also accept the partial warm-up windows.

    Returns:
        The first qualifying episode number, or ``len(success_series) + 1``
        when the target is never reached.

    Raises:
        ValueError: if ``window`` < 1 or ``min_periods`` is outside
            ``[1, window]``.
    """
    if window < 1:
        raise ValueError(f"window must be a positive integer, got {window!r}")
    required = window if min_periods is None else min_periods
    if not 1 <= required <= window:
        raise ValueError(f"min_periods must be in [1, {window}], got {min_periods!r}")

    recent = deque()
    running = 0.0
    episode = 0
    for episode, outcome in enumerate(success_series, start=1):
        recent.append(outcome)
        running += outcome
        if len(recent) > window:
            running -= recent.popleft()
        if len(recent) >= required and running / len(recent) >= target:
            return episode
    # Sentinel one past the last episode: the target was never reached.
    return episode + 1

def auc(y):
    """Area under ``y`` by the trapezoid rule with unit spacing.

    Returns 0.0 when ``y`` has fewer than two points.
    """
    n_points = len(y)
    if n_points < 2:
        return 0.0
    # Pair every point with its successor and sum the trapezoid areas.
    trapezoids = (0.5 * (left + right) for left, right in zip(y, y[1:]))
    return float(sum(trapezoids))


In [None]:
# Compact ablation: train every model arm once per seed
MODELS = ['gtdb', 'gt', 'mlp']
SEEDS = list(range(6))
EPISODES = 180

# Results are ordered model-major, seed-minor.
runs = [
    train_one(model_name=m, seed=s, episodes=EPISODES)
    for m in MODELS
    for s in SEEDS
]

print('completed runs:', len(runs))


In [None]:
# Collapse the per-seed runs into one row of summary statistics per model
summary = {}
for m in MODELS:
    rs = [run for run in runs if run['model'] == m]

    auc_vals = [auc(run['return']) for run in rs]
    # Final success = mean over the last 20% of episodes (at least one).
    final_success_vals = [
        mean(run['success'][-max(1, int(0.2 * len(run['success']))):])
        for run in rs
    ]
    ett_vals = [
        episodes_to_target(run['success'], target=0.8, window=20)
        for run in rs
    ]
    # A run reached the target if its episode count is within the budget.
    reached = [float(v <= EPISODES) for v in ett_vals]

    summary[m] = dict(
        auc_mean=mean(auc_vals),
        final_success_mean=mean(final_success_vals),
        episodes_to_target_mean=mean(ett_vals),
        reached_target_rate=mean(reached),
    )

for m in MODELS:
    s = summary[m]
    print(f"{m:>4} | AUC={s['auc_mean']:.2f} | final_success={s['final_success_mean']:.3f} | ETT={s['episodes_to_target_mean']:.1f} | reached={s['reached_target_rate']:.3f}")


In [None]:
# Rolling return and rolling success per model: mean across seeds with a +/- 1 std band
x = np.arange(1, EPISODES + 1)
fig, axes = plt.subplots(1, 2, figsize=(12, 4))

# (axis, key in each run dict, title, y-label, optional y-limits)
panels = (
    (axes[0], 'return', 'Rolling Return (window=20)', 'Return', None),
    (axes[1], 'success', 'Rolling Success Rate (window=20)', 'Success', (-0.05, 1.05)),
)

for m in MODELS:
    rs = [run for run in runs if run['model'] == m]
    for ax, key, _, _, _ in panels:
        stack = np.array([moving_avg(run[key], 20) for run in rs])
        centre = stack.mean(axis=0)
        spread = stack.std(axis=0)
        ax.plot(x, centre, label=m)
        ax.fill_between(x, centre - spread, centre + spread, alpha=0.15)

for ax, _, title, y_label, y_limits in panels:
    ax.set_title(title)
    ax.set_xlabel('Episode')
    ax.set_ylabel(y_label)
    if y_limits is not None:
        ax.set_ylim(*y_limits)
    ax.grid(True, alpha=0.3)
    ax.legend()

plt.tight_layout()
plt.show()


## Interpretation

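If **GT+DB** improves episodes-to-target or AUC under this harder stochastic setting, the explanation is language-level: DB enforces invariants from the declared transition structure. In this notebook that is a penalty on the squared distance between the hidden representations of states linked by the transition relation. This reduces geometry-inconsistent solutions during RL optimization.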

This is the key course message for coalgebras/coinduction in practical ML:

- coalgebra gives the behavior type,
- GT gives representation over that structure,
- DB makes the optimization respect the declared diagram.
