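1. Information-set (infoset) indexing helpers

Each player acts at 6 information sets: 3 possible private cards × 2 public histories at which that player must move. The helpers below map a (card, history) pair to an index 0..5.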

In [1]:
import torch
import torch.nn as nn
import numpy as np
import random

device = torch.device("cpu")

# Seed every RNG the notebook draws from (card deals use `random`, action
# sampling uses torch) so "Restart & Run All" reproduces the same curves.
SEED = 0
random.seed(SEED)
torch.manual_seed(SEED)

# --------- infoset indexing ---------

def p1_infoset_index(card, history):
    """
    Map a player-1 information set (card, history) to its index in 0..5.

    P1 acts at the root (history "") and after check-bet (history "cb"):
      ""   -> indices 0..2 (card 0, 1, 2)
      "cb" -> indices 3..5 (card 0, 1, 2)

    Raises ValueError for any other history.
    """
    if history == "cb":
        return 3 + card
    if history == "":
        return card
    raise ValueError(f"invalid p1 history {history}")

def p2_infoset_index(card, history):
    """
    Map a player-2 information set (card, history) to its index in 0..5.

    P2 acts after a check (history "c") or after a bet (history "b"):
      "c" -> indices 0..2 (card 0, 1, 2)
      "b" -> indices 3..5 (card 0, 1, 2)

    Raises ValueError for any other history.
    """
    if history == "b":
        return 3 + card
    if history == "c":
        return card
    raise ValueError(f"invalid p2 history {history}")


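2. Tabular policies

Each player's policy is a vector of 6 logits, one per infoset. The sigmoid of a logit is the Bernoulli probability of the “aggressive” action:

- opening decision (P1 at the root, or P2 after a check): 0 = check, 1 = bet

- decision when facing a bet: 0 = fold, 1 = call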

In [2]:
class KuhnPolicy(nn.Module):
    """
    Tabular Kuhn Poker policy: one logit per information set.

    sigmoid(logits[i]) is the probability of the "aggressive" action
    (1 = bet or call) at infoset i; the passive action (0 = check or fold)
    gets the complement.
    """

    def __init__(self, num_infosets=6):
        super().__init__()
        # one scalar logit per infoset; zeros -> 50/50 initial policy
        self.logits = nn.Parameter(torch.zeros(num_infosets))

    def action_dist(self, infoset_indices):
        """
        infoset_indices: LongTensor [B] of infoset ids (0..num_infosets-1)
        returns Bernoulli distribution over action 1 (bet/call), batch shape [B]
        """
        logits = self.logits[infoset_indices]  # [B]
        # Parameterise by logits, not sigmoid(logits): with `probs=` torch
        # clamps probabilities to [eps, 1 - eps] before computing log_prob /
        # entropy, which gives wrong log-probs for saturated logits.
        return torch.distributions.Bernoulli(logits=logits)

    def action_and_logp(self, infoset_index):
        """
        Sample one action at a single infoset.

        Returns (action as int 0/1, log-prob of that action as a 0-dim tensor).
        """
        # build the index on the parameters' device so the policy also works off-CPU
        idx = torch.tensor([infoset_index], dtype=torch.long,
                           device=self.logits.device)
        dist = self.action_dist(idx)
        a = dist.sample()            # 0. or 1.
        logp = dist.log_prob(a)      # [1]
        return int(a.item()), logp[0]

    def probs_for_all_infosets(self):
        """Numpy array with P(action 1) for every infoset (no gradient)."""
        with torch.no_grad():
            return torch.sigmoid(self.logits).cpu().numpy()


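One episode of Kuhn Poker self-play

We simulate one hand until it reaches a terminal history. Along the way we record each player's infosets, sampled actions and log-probabilities, plus the final payoffs: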

In [3]:
def kuhn_deal():
    """
    Deal one private card each to P1 and P2 from the 3-card deck {0, 1, 2}.

    Returns (p1_card, p2_card); the third card stays undealt.
    """
    deck = list(range(3))
    random.shuffle(deck)
    p1_card, p2_card, _undealt = deck
    return p1_card, p2_card

def kuhn_episode(pi1: KuhnPolicy, pi2: KuhnPolicy):
    """
    Run one Kuhn Poker episode with self-play policies.

    P1 opens (check/bet). P2 then either checks/bets (after a check) or
    folds/calls (after a bet). If P1 checked and P2 bet, P1 folds/calls.
    P2 therefore makes exactly one decision per episode, and P1 makes one or two.

    Returns:
      p1_infos, p1_actions, p1_logps, p2_infos, p2_actions, p2_logps, r1, r2

      *_infos   : LongTensor of infoset indices visited by that player
      *_actions : FloatTensor of sampled actions (0/1)
      *_logps   : tensor of rollout log-probs, computed without autograd
                  history (they serve as PPO's constant logp_old)
      r1, r2    : terminal payoffs to P1 / P2 (r2 == -r1)
    """
    p1_card, p2_card = kuhn_deal()

    # storage for decisions
    p1_infos, p1_actions, p1_logps = [], [], []
    p2_infos, p2_actions, p2_logps = [], [], []

    def showdown(pot):
        # the higher card wins the opponent's half of the pot
        return pot / 2 if p1_card > p2_card else -pot / 2

    def act(policy, info, infos, actions, logps):
        # Sample without building a graph: the rollout log-prob is only ever
        # used as a constant, so keeping the graph just wastes memory.
        with torch.no_grad():
            a, logp = policy.action_and_logp(info)
        infos.append(info)
        actions.append(a)
        logps.append(logp)
        return a

    # ---- P1 first decision: 0=check, 1=bet ----
    a1 = act(pi1, p1_infoset_index(p1_card, ""), p1_infos, p1_actions, p1_logps)

    if a1 == 0:
        # ---- P2 decision after check: 0=check, 1=bet ----
        a2 = act(pi2, p2_infoset_index(p2_card, "c"), p2_infos, p2_actions, p2_logps)
        if a2 == 0:
            r1 = showdown(2)                       # "cc"
        else:
            # ---- P1 decision after check-bet: 0=fold, 1=call ----
            a1_2 = act(pi1, p1_infoset_index(p1_card, "cb"),
                       p1_infos, p1_actions, p1_logps)
            r1 = -1.0 if a1_2 == 0 else showdown(4)  # "cbf" / "cbc"
    else:
        # ---- P2 decision after bet: 0=fold, 1=call ----
        a2 = act(pi2, p2_infoset_index(p2_card, "b"), p2_infos, p2_actions, p2_logps)
        r1 = 1.0 if a2 == 0 else showdown(4)         # "bf" / "bc"
    r2 = -r1

    def to_tensors(infos, actions, logps):
        return (
            torch.tensor(infos, dtype=torch.long),
            torch.tensor(actions, dtype=torch.float32),
            torch.stack(logps) if len(logps) > 0 else torch.tensor([]),
        )

    p1_infos, p1_actions, p1_logps = to_tensors(p1_infos, p1_actions, p1_logps)
    p2_infos, p2_actions, p2_logps = to_tensors(p2_infos, p2_actions, p2_logps)

    return p1_infos, p1_actions, p1_logps, p2_infos, p2_actions, p2_logps, r1, r2


Batch rollout for PPO

We collect many episodes into one batch and flatten each player's decisions into per-decision tensors:


In [4]:
def kuhn_rollout_batch(pi1, pi2, batch_size):
    """
    Collect `batch_size` self-play episodes and flatten them per decision.

    Returns a dict with, for each player k in {1, 2}:
      "pk_infos"                : long  [N_k] infoset index of every decision
      "pk_actions"              : float [N_k] sampled action (0/1)
      "pk_logps"                : float [N_k] rollout log-prob, detached (PPO's logp_old)
      "pk_returns_per_decision" : float [N_k] episode return repeated per decision
      "pk_rewards"              : float [batch_size] episode return (logging only)
    """

    p1_infos, p1_actions, p1_logps = [], [], []
    p2_infos, p2_actions, p2_logps = [], [], []

    # per-decision returns (one scalar per decision taken)
    p1_returns_per_decision, p2_returns_per_decision = [], []

    # per-episode returns (for logging only)
    p1_rewards, p2_rewards = [], []

    for _ in range(batch_size):
        (i1, a1, lp1,
         i2, a2, lp2,
         r1, r2) = kuhn_episode(pi1, pi2)

        # The rollout log-probs are constants for the update that follows.
        # Detaching keeps the batch from holding one autograd graph per
        # decision and stops gradients from flowing through logp_old.
        lp1 = lp1.detach()
        lp2 = lp2.detach()

        # --- player 1 ---
        if i1.numel() > 0:
            p1_infos.append(i1)
            p1_actions.append(a1)
            p1_logps.append(lp1)
            # repeat the episode return for each decision in this episode
            p1_returns_per_decision.append(torch.full_like(lp1, float(r1)))

        # --- player 2 ---
        if i2.numel() > 0:
            p2_infos.append(i2)
            p2_actions.append(a2)
            p2_logps.append(lp2)
            p2_returns_per_decision.append(torch.full_like(lp2, float(r2)))

        p1_rewards.append(r1)
        p2_rewards.append(r2)

    def cat_list(tensors):
        # torch.cat rejects an empty list; fall back to an empty float tensor
        if len(tensors) == 0:
            return torch.tensor([], dtype=torch.float32)
        return torch.cat(tensors, dim=0)

    return {
        "p1_infos": cat_list(p1_infos).long(),
        "p1_actions": cat_list(p1_actions).float(),
        "p1_logps": cat_list(p1_logps).float(),
        "p1_returns_per_decision": cat_list(p1_returns_per_decision).float(),
        "p1_rewards": torch.tensor(p1_rewards, dtype=torch.float32),

        "p2_infos": cat_list(p2_infos).long(),
        "p2_actions": cat_list(p2_actions).float(),
        "p2_logps": cat_list(p2_logps).float(),
        "p2_returns_per_decision": cat_list(p2_returns_per_decision).float(),
        "p2_rewards": torch.tensor(p2_rewards, dtype=torch.float32),
    }


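PPO Loss for Tabular Policy

Note: `train_kuhn` takes a single gradient step on each freshly collected batch. The probability ratio is therefore exactly 1 when the loss is evaluated, and the clipping never binds. Each update is effectively a policy gradient with normalized advantages plus an entropy bonus.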

In [5]:
def kuhn_ppo_loss(
    policy,
    infos,
    actions,
    logp_old,
    returns_per_decision,
    clip_eps=0.2,
    entropy_coef=0.01,
):
    """
    Clipped-surrogate PPO loss with an entropy bonus for a tabular policy.

    policy: object exposing ``action_dist(infos)`` -> torch distribution
    infos: [N] infoset indices
    actions: [N] (0/1)
    logp_old: [N] log-probs of `actions` under the rollout policy; treated
              as constants (detached here)
    returns_per_decision: [N] Monte Carlo returns for each decision
    clip_eps: clipping range for the probability ratio
    entropy_coef: weight of the entropy bonus

    Returns a scalar loss to minimise. With no decisions (N == 0), returns a
    zero tensor that requires grad so ``.backward()`` can still be called.
    """
    if infos.numel() == 0:
        return torch.tensor(0.0, requires_grad=True)

    # logp_old must be a constant. If it still carries the graph of the policy
    # that produced it (the same parameters being updated), then
    # d(logp - logp_old)/dθ is exactly 0 and the policy-gradient term has no
    # gradient at all.
    logp_old = logp_old.detach()

    # advantages: centered + normalized per batch
    adv = returns_per_decision.detach()
    adv = adv - adv.mean()
    if adv.numel() > 1:
        # the unbiased std of a single element is NaN; skip scaling in that case
        adv = adv / (adv.std() + 1e-8)

    dist = policy.action_dist(infos)
    logp = dist.log_prob(actions)

    ratio = torch.exp(logp - logp_old)
    surr1 = ratio * adv
    surr2 = torch.clamp(ratio, 1.0 - clip_eps, 1.0 + clip_eps) * adv
    pg_loss = -torch.min(surr1, surr2).mean()

    entropy = dist.entropy().mean()
    loss = pg_loss - entropy_coef * entropy
    return loss


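Self-play training with three update modes: PPO, O-PPO-Last (optimistic step using the previous gradient) and O-PPO-EMA (optimistic step using an exponential moving average of gradients)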

In [6]:
import torch.optim as optim

_KUHN_MODES = ("ppo", "oppo_last", "oppo_ema")


def _clone_grads(policy):
    """Snapshot every parameter's current .grad (zeros where .grad is None)."""
    with torch.no_grad():
        return [
            p.grad.detach().clone() if p.grad is not None else torch.zeros_like(p)
            for p in policy.parameters()
        ]


def _apply_update(mode, policy, opt, state, omega, beta):
    """
    Take one optimizer step for `policy` according to `mode`, and return the new state.

    Expects .grad to already hold the gradient g_t of the current loss.
      "ppo":       step on g_t; `state` is returned unchanged.
      "oppo_last": step on (1 + omega) * g_t - omega * g_{t-1}; new state = g_t.
      "oppo_ema":  m_t = (1 - beta) * g_t + beta * m_{t-1};
                   step on (1 + omega) * g_t - omega * m_t; new state = m_t.
    """
    if mode == "ppo":
        opt.step()
        return state

    cur_grads = _clone_grads(policy)
    if mode == "oppo_last":
        reference, new_state = state, cur_grads
    else:  # "oppo_ema"
        new_state = [(1.0 - beta) * g_t + beta * m_prev
                     for g_t, m_prev in zip(cur_grads, state)]
        reference = new_state

    for p, g_t, ref in zip(policy.parameters(), cur_grads, reference):
        # parameters that received no gradient this step are left untouched
        if p.grad is not None:
            p.grad = (1.0 + omega) * g_t - omega * ref
    opt.step()
    return new_state


def train_kuhn(
    mode="ppo",          # "ppo", "oppo_last", "oppo_ema"
    num_iters=400,
    batch_size=512,
    lr=3e-4,
    omega=0.3,
    beta=0.6,
):
    """
    Self-play training of two tabular Kuhn Poker policies with SGD.

    Each iteration collects `batch_size` episodes and updates both players from
    that batch with the PPO loss (player 1 first, then player 2), using the
    gradient rule selected by `mode` (see `_apply_update`; `omega` is the
    optimism weight and `beta` the EMA decay). Exploitability is evaluated
    exactly after every iteration.

    Returns a dict with per-iteration lists "p1_returns", "p2_returns",
    "exploitability", and the trained policies "pi1", "pi2".

    Raises ValueError if `mode` is not one of "ppo", "oppo_last", "oppo_ema".
    """
    if mode not in _KUHN_MODES:
        raise ValueError(f"unknown mode {mode!r}; expected one of {_KUHN_MODES}")

    pi1 = KuhnPolicy().to(device)
    pi2 = KuhnPolicy().to(device)

    opt1 = optim.SGD(pi1.parameters(), lr=lr)
    opt2 = optim.SGD(pi2.parameters(), lr=lr)

    # per-player optimistic state: last gradient ("oppo_last") or EMA ("oppo_ema")
    state1 = [torch.zeros_like(p) for p in pi1.parameters()]
    state2 = [torch.zeros_like(p) for p in pi2.parameters()]

    p1_return_hist, p2_return_hist = [], []
    exploit_list = []

    for it in range(num_iters):
        traj = kuhn_rollout_batch(pi1, pi2, batch_size)

        # ------ Player 1 update ------
        opt1.zero_grad()
        loss1 = kuhn_ppo_loss(
            pi1,
            traj["p1_infos"],
            traj["p1_actions"],
            traj["p1_logps"],
            traj["p1_returns_per_decision"],
        )
        loss1.backward()
        state1 = _apply_update(mode, pi1, opt1, state1, omega, beta)

        # ------ Player 2 update (symmetric) ------
        opt2.zero_grad()
        loss2 = kuhn_ppo_loss(
            pi2,
            traj["p2_infos"],
            traj["p2_actions"],
            traj["p2_logps"],
            traj["p2_returns_per_decision"],
        )
        loss2.backward()
        state2 = _apply_update(mode, pi2, opt2, state2, omega, beta)

        # logging averages
        p1_return_hist.append(traj["p1_rewards"].mean().item())
        p2_return_hist.append(traj["p2_rewards"].mean().item())

        # exact exploitability
        exploit, _br1, _br2, _v = kuhn_exploitability(pi1, pi2)
        exploit_list.append(exploit)

        if (it + 1) % 50 == 0:
            print(f"iter {it+1} | mode={mode} | "
                  f"E[R1]={p1_return_hist[-1]:.3f} | exploit={exploit:.4f}")

    return {
        "p1_returns": p1_return_hist,
        "p2_returns": p2_return_hist,
        "exploitability": exploit_list,
        "pi1": pi1,
        "pi2": pi2,
    }


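Running and plotting

**Exploitability**

Exploitability is computed exactly rather than estimated. Each player has 6 binary infosets, so there are only $2^6 = 64$ pure strategies to enumerate per player. With $u_1$ the expected payoff to player 1:

$$\text{expl}(\pi_1,\pi_2) = \tfrac12\Big[\big(\max_{\sigma_1} u_1(\sigma_1,\pi_2) - u_1(\pi_1,\pi_2)\big) + \big(u_1(\pi_1,\pi_2) - \min_{\sigma_2} u_1(\pi_1,\sigma_2)\big)\Big]$$

The training runs and plots are in the final cells of the notebook.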

In [7]:
def enumerate_pure_strategies(num_infosets=6):
    """
    All deterministic strategies over `num_infosets` binary infosets.

    Strategy number `code` (0 .. 2**num_infosets - 1) plays action
    (code >> k) & 1 at infoset k, so each entry is a list of num_infosets
    0/1 actions.
    """
    return [
        [(code >> bit) & 1 for bit in range(num_infosets)]
        for code in range(1 << num_infosets)
    ]


# every pure strategy for a 6-infoset Kuhn player (64 in total)
PURE_STRATS = enumerate_pure_strategies(6)


In [8]:
def kuhn_exploitability(pi1, pi2):
    """
    Exact exploitability of the strategy profile (pi1, pi2) in Kuhn Poker.

    Returns (exploit, br1, br2, v) where
      v       = u1(pi1, pi2), the value to player 1
      br1     = max over P1 pure strategies of u1(., pi2)
      br2     = min over P2 pure strategies of u1(pi1, .)
      exploit = half the sum of both players' best-response gains
    """
    value = kuhn_expected_value(pi1, pi2)
    p1_best = best_response_value_to_p2(pi2)
    p2_best = best_response_value_to_p1(pi1)

    p1_gain = p1_best - value   # what P1 gains by deviating to a best response
    p2_gain = value - p2_best   # what P2 gains (P2's utility is -u1)
    exploit = 0.5 * (p1_gain + p2_gain)  # symmetric definition
    return exploit, p1_best, p2_best, value


In [9]:
def kuhn_payoff_given_actions(card1, card2, history):
    """
    Terminal payoff to player 1 for a Kuhn Poker history.

    Showdowns ("cc" with pot 2, "cbc"/"bc" with pot 4) pay +-pot/2 according to
    whose card is higher. Folds pay 1 to the player who did not fold:
    "bf" (P2 folded) -> +1, "cbf" (P1 folded) -> -1.

    Raises ValueError for any non-terminal or unknown history.
    """
    if history == "cc":
        pot = 2
    elif history in ("cbc", "bc"):
        pot = 4
    elif history == "bf":   # P2 folded
        return 1
    elif history == "cbf":  # P1 folded
        return -1
    else:
        # single formatted message (passing several args to ValueError
        # renders the message as a tuple)
        raise ValueError(f"invalid terminal history: {history!r}")
    return pot / 2 if card1 > card2 else -pot / 2


def kuhn_prob_action1_p1(pi1, card, history):
    """
    Probability that player 1 takes action 1 (bet/call) at infoset
    (card, history) under the tabular policy ``pi1``.

    The sigmoid is evaluated with torch in float64. It cannot overflow, whereas
    1 / (1 + np.exp(-logit)) emits overflow RuntimeWarnings for very negative
    logits.
    """
    idx = p1_infoset_index(card, history)
    with torch.no_grad():
        return torch.sigmoid(pi1.logits[idx].detach().double()).item()


def kuhn_prob_action1_p2(pi2, card, history):
    """
    Probability that player 2 takes action 1 (bet/call) at infoset
    (card, history) under the tabular policy ``pi2``.

    The sigmoid is evaluated with torch in float64. It cannot overflow, whereas
    1 / (1 + np.exp(-logit)) emits overflow RuntimeWarnings for very negative
    logits.
    """
    idx = p2_infoset_index(card, history)
    with torch.no_grad():
        return torch.sigmoid(pi2.logits[idx].detach().double()).item()


def kuhn_expected_value(pi1, pi2, pure1=None, pure2=None):
    """
    Exact expected payoff to player 1 when P1 plays pi1 and P2 plays pi2.

    Either side may be overridden by a pure strategy (used for best
    responses): pure1 / pure2 is a length-6 list of 0/1 actions indexed by
    that player's infoset id, or None to use the tabular policy.

    The 6 equally likely deals are enumerated. For each deal, the 5 terminal
    histories are weighted by their reach probability.
    """

    def p1_aggressive(card, history):
        # P(P1 bets/calls) at (card, history)
        slot = p1_infoset_index(card, history)
        if pure1 is None:
            return kuhn_prob_action1_p1(pi1, card, history)
        return float(pure1[slot])  # pure strategy: 0.0 or 1.0

    def p2_aggressive(card, history):
        # P(P2 bets/calls) at (card, history)
        slot = p2_infoset_index(card, history)
        if pure2 is None:
            return kuhn_prob_action1_p2(pi2, card, history)
        return float(pure2[slot])

    deals = [(c1, c2) for c1 in [0, 1, 2] for c2 in [0, 1, 2] if c1 != c2]

    total = 0.0
    for c1, c2 in deals:
        w = 1.0 / 6.0  # uniform over the 6 deals

        bet = p1_aggressive(c1, "")        # P1 at the root
        check = 1.0 - bet
        bet_c = p2_aggressive(c2, "c")     # P2 after a check
        check_c = 1.0 - bet_c
        call_cb = p1_aggressive(c1, "cb")  # P1 after check-bet
        fold_cb = 1.0 - call_cb
        call_b = p2_aggressive(c2, "b")    # P2 after a bet
        fold_b = 1.0 - call_b

        # (reach probability, terminal history) for every leaf of this deal
        leaves = [
            (w * check * check_c, "cc"),
            (w * check * bet_c * fold_cb, "cbf"),
            (w * check * bet_c * call_cb, "cbc"),
            (w * bet * fold_b, "bf"),
            (w * bet * call_b, "bc"),
        ]
        total += sum(reach * kuhn_payoff_given_actions(c1, c2, h)
                     for reach, h in leaves)

    return total


In [10]:
def best_response_value_to_p2(pi2):
    """
    Value to player 1 of a best response against ``pi2``:
    the maximum of u1(pure1, pi2) over every pure P1 strategy in PURE_STRATS.
    """
    best = -999  # sentinel below every reachable payoff
    for pure1 in PURE_STRATS:
        candidate = kuhn_expected_value(None, pi2, pure1=pure1, pure2=None)
        best = candidate if candidate > best else best
    return best


def best_response_value_to_p1(pi1):
    """
    Player 1's value when player 2 best-responds to ``pi1``:
    the minimum of u1(pi1, pure2) over every pure P2 strategy in PURE_STRATS
    (player 2 minimises player 1's payoff).
    """
    worst = 999  # sentinel above every reachable payoff
    for pure2 in PURE_STRATS:
        candidate = kuhn_expected_value(pi1, None, pure1=None, pure2=pure2)
        worst = candidate if candidate < worst else worst
    return worst


In [None]:
# Settings shared by all three runs; only the update rule and step size differ.
RUN_KWARGS = dict(num_iters=400, batch_size=512)

results_ppo = train_kuhn(mode="ppo", lr=3e-4, **RUN_KWARGS)

# slightly smaller learning rates for the optimistic variants
results_last = train_kuhn(mode="oppo_last", lr=2e-4, omega=0.3, **RUN_KWARGS)

results_ema = train_kuhn(mode="oppo_ema", lr=1e-4, omega=0.3, beta=0.6, **RUN_KWARGS)


import matplotlib.pyplot as plt
# derive the x-axis from the recorded history instead of hard-coding 400,
# so the plot stays valid when num_iters changes
iters = range(len(results_ppo["p1_returns"]))

fig, ax = plt.subplots(figsize=(7, 4))
for label, res in [
    ("PPO p1", results_ppo),
    ("O-PPO-Last p1", results_last),
    ("O-PPO-EMA p1", results_ema),
]:
    ax.plot(iters, res["p1_returns"], label=label)
ax.set_xlabel("iteration")
ax.set_ylabel("avg return (player 1)")
ax.legend()
ax.set_title("Kuhn Poker: average return of player 1")
fig.tight_layout()
plt.show()


In [None]:
import matplotlib.pyplot as plt
import numpy as np

# one exploitability curve per training mode, indexed by iteration
iters = range(len(results_ppo["exploitability"]))
curves = [
    ("PPO", results_ppo),
    ("O-PPO-Last", results_last),
    ("O-PPO-EMA", results_ema),
]

fig, ax = plt.subplots(figsize=(7, 4))
for label, res in curves:
    ax.plot(iters, res["exploitability"], label=label)
ax.set_xlabel("iteration")
ax.set_ylabel("exploitability")
ax.set_title("Kuhn Poker: exploitability over time")
ax.legend()
fig.tight_layout()
plt.show()
