# 1. Gomoku

In [20]:
import numpy as np

class Gomoku:
    """
    N×N Gomoku environment.

    Cells hold +1 (black), -1 (white) or 0 (empty); +1 moves first after
    ``reset()``. A player wins by placing ``win_len`` stones in an unbroken
    horizontal, vertical or diagonal line; a full board without a winner is a
    draw. Actions are flat cell indices ``r * N + c`` in ``[0, N*N)``.
    """

    def __init__(self, N=9, win_len=5):
        self.N = int(N)
        self.win_len = int(win_len)
        self.reset()

    def reset(self):
        """Clear the board, give the move to +1 and return ``get_state()``."""
        self.board = np.zeros((self.N, self.N), dtype=np.int8)
        self.current_player = 1
        self.done = False
        self.last_move = None  # (r, c) of the most recent stone, or None
        return self.get_state()

    # --- Canonical features for NN (3 planes): me, opp, to-move ---
    def to_planes(self):
        """
        Return float32 planes of shape [3, N, N] from the side to move's view:
        plane 0 = own stones, plane 1 = opponent stones, plane 2 = all ones.

        The third plane is constant because planes 0/1 are already expressed
        relative to ``current_player``.
        """
        me = (self.board == self.current_player).astype(np.float32)
        opp = (self.board == -self.current_player).astype(np.float32)
        turn = np.full_like(me, 1.0, dtype=np.float32)  # "to move" plane
        return np.stack([me, opp, turn], axis=0)

    # For drop-in compatibility with older code:
    def get_state(self):
        """Flat float32 board (length N*N) multiplied by ``current_player``."""
        return (self.board.flatten() * self.current_player).astype(np.float32)

    def get_legal_actions(self):
        """Return the flat indices of empty cells, ascending, as Python ints."""
        return np.flatnonzero(self.board == 0).tolist()

    def step(self, action):
        """
        Play ``action`` (flat index in ``[0, N*N)``) for ``current_player``.

        Returns:
            ``(state, reward, done)``. ``state`` is ``get_state()`` taken before
            the turn passes (i.e. from the mover's view). ``reward`` is relative
            to the mover: +1.0 win, 0.0 draw/ongoing, -1.0 illegal move.

        An illegal move (occupied OR out-of-range cell) ends the game without
        passing the turn. Calling ``step`` after the game has ended returns
        ``(state, 0.0, True)`` and changes nothing.
        """
        if self.done:
            return self.get_state(), 0.0, True

        a = int(action)
        # Out-of-range indices must be rejected explicitly: divmod on a negative
        # index yields a negative row, which NumPy would silently wrap around.
        if not 0 <= a < self.N * self.N or self.board.flat[a] != 0:
            # illegal move penalty (AlphaZero typically forbids via masking;
            # this is a safeguard)
            self.done = True
            return self.get_state(), -1.0, True

        r, c = divmod(a, self.N)
        self.board[r, c] = self.current_player
        self.last_move = (r, c)

        reward, done = self.check_winner_fast()
        self.done = done
        state = self.get_state()
        self.current_player *= -1
        return state, float(reward), bool(done)

    # --------- winner checks ----------
    def check_winner_fast(self):
        """
        Check only the four lines passing through ``self.last_move``.

        Returns ``(reward, done)`` relative to the player who just moved:
        ``(1.0, True)`` for a win, ``(0.0, True)`` for a full-board draw and
        ``(0.0, False)`` while the game continues.
        """
        if self.last_move is None:
            return 0.0, False
        r, c = self.last_move
        p = self.board[r, c]
        if p == 0:
            return 0.0, False

        for dr, dc in ((1, 0), (0, 1), (1, 1), (1, -1)):
            # Both half-rays count (r, c) itself, hence the "- 1".
            run = self._count_dir(r, c, dr, dc, p) + self._count_dir(r, c, -dr, -dc, p) - 1
            if run >= self.win_len:
                return 1.0, True

        if (self.board != 0).all():
            return 0.0, True  # draw

        return 0.0, False

    def _count_dir(self, r, c, dr, dc, p):
        """Count contiguous stones of color p starting at (r, c) inclusive along (dr, dc)."""
        N = self.N
        cnt = 0
        rr, cc = r, c
        while 0 <= rr < N and 0 <= cc < N and self.board[rr, cc] == p:
            cnt += 1
            rr += dr
            cc += dc
        return cnt

    def num_moves(self):
        """Return the number of stones on the board (= moves played so far)."""
        return int((self.board != 0).sum())

    # --------- convenience ----------
    def render_ascii(self):
        """Print the board with X = +1, O = -1 and . = empty."""
        sym = {1: 'X', -1: 'O', 0: '.'}
        print("\n".join(" ".join(sym[v] for v in row) for row in self.board))

# 2. ResNet Policy-Value Model

In [21]:
import torch
import torch.nn as nn
import torch.nn.functional as F

class ResidualBlock(nn.Module):
    """
    Residual unit: conv3x3 -> BN -> ReLU -> conv3x3 -> BN, added to the input,
    followed by a final ReLU. Spatial size and channel count are preserved.
    """

    def __init__(self, channels):
        super().__init__()
        # Attribute names are part of the state_dict layout; keep them stable.
        self.conv1 = nn.Conv2d(channels, channels, kernel_size=3, padding=1, bias=False)
        self.bn1 = nn.BatchNorm2d(channels)
        self.conv2 = nn.Conv2d(channels, channels, kernel_size=3, padding=1, bias=False)
        self.bn2 = nn.BatchNorm2d(channels)

    def forward(self, x):
        skip = x
        out = F.relu(self.bn1(self.conv1(x)))
        out = self.bn2(self.conv2(out))
        out = skip + out
        return F.relu(out)

class PVResNet(nn.Module):
    """
    Policy–value residual network for an N×N Gomoku board.

    Input:  float tensor [B, 3, N, N].
    Output: ``(policy_logits, value)`` with policy_logits [B, N*N] and
            value [B, 1] squashed into [-1, 1] by tanh.

    Layout: 3x3 conv stem -> ``n_blocks`` ResidualBlocks of width ``channels``
    -> separate policy and value heads.
    """

    def __init__(self, board_size=9, channels=64, n_blocks=6):
        super().__init__()
        self.N = board_size
        width = channels
        n_cells = self.N * self.N

        # Shared trunk (module names/order define the state_dict and the
        # order in which default initialisers consume the RNG).
        self.stem = nn.Sequential(
            nn.Conv2d(3, width, kernel_size=3, padding=1, bias=False),
            nn.BatchNorm2d(width),
            nn.ReLU(inplace=True),
        )
        self.trunk = nn.Sequential(*(ResidualBlock(width) for _ in range(n_blocks)))

        # Policy head: 1x1 conv to 2 planes, then one logit per cell.
        self.p_head = nn.Sequential(
            nn.Conv2d(width, 2, kernel_size=1, bias=False),
            nn.BatchNorm2d(2),
            nn.ReLU(inplace=True),
        )
        self.p_fc = nn.Linear(2 * n_cells, n_cells)

        # Value head: 1x1 conv to 1 plane, hidden layer, scalar output.
        self.v_head = nn.Sequential(
            nn.Conv2d(width, 1, kernel_size=1, bias=False),
            nn.BatchNorm2d(1),
            nn.ReLU(inplace=True),
        )
        self.v_fc1 = nn.Linear(n_cells, width)
        self.v_fc2 = nn.Linear(width, 1)

        self._init_weights()

    def _init_weights(self):
        """Kaiming-normal (ReLU gain) for conv/linear weights; zero linear biases."""
        for module in self.modules():
            if isinstance(module, (nn.Conv2d, nn.Linear)):
                nn.init.kaiming_normal_(module.weight, nonlinearity='relu')
                if isinstance(module, nn.Linear):
                    nn.init.zeros_(module.bias)

    def forward(self, x):  # x: [B, 3, N, N]
        features = self.trunk(self.stem(x))
        batch = features.size(0)

        policy_logits = self.p_fc(self.p_head(features).view(batch, -1))  # [B, N*N]

        hidden = F.relu(self.v_fc1(self.v_head(features).view(batch, -1)))
        value = torch.tanh(self.v_fc2(hidden))  # [B, 1] in [-1, 1]

        return policy_logits, value

# 3. Monte Carlo Tree Search (MCTS)

In [22]:
import numpy as np
import torch
import torch.nn.functional as F

def gomoku_planes_from_raw(raw_board_tuple, player, N):
    """
    Build float32 feature planes [3, N, N] from a flat board and the side to move:
    plane 0 = ``player``'s stones, plane 1 = opponent stones, plane 2 = ones.
    """
    board = np.asarray(raw_board_tuple, dtype=np.int8).reshape(N, N)
    planes = np.empty((3, N, N), dtype=np.float32)
    planes[0] = board == player
    planes[1] = board == -player
    planes[2] = 1.0
    return planes

def legal_actions_from_raw(raw_board_tuple, N):
    """Return the flat indices (ascending Python ints) of empty cells of an N×N board."""
    cells = np.asarray(raw_board_tuple, dtype=np.int8).reshape(N * N)
    return np.flatnonzero(cells == 0).tolist()

def apply_action_raw(raw_board_tuple, action, player, N):
    """Return a new flat board tuple with ``player``'s stone placed at ``action``."""
    grid = np.array(raw_board_tuple, dtype=np.int8).reshape(N, N)  # np.array copies
    row, col = divmod(int(action), N)
    grid[row, col] = player
    return tuple(grid.ravel())

def check_terminal_gomoku(raw_board_tuple, N, win_len=5):
    """
    Classify a flat board as won, drawn or still in progress.

    Args:
        raw_board_tuple: flat sequence of N*N cells with values in {0, +1, -1}.
        N: board side length.
        win_len: number of stones in an unbroken line needed to win.

    Returns:
        ``(winner, done)``. ``winner`` is the COLOR (+1.0 or -1.0) that owns a
        line of ``win_len`` stones, else 0.0 -- it is not relative to any
        player to move; callers compare it with the side they care about.
        ``done`` is True on a win or on a full board. If both colors own a
        line, +1 is reported.
    """
    b = np.array(raw_board_tuple, dtype=np.int8).reshape(N, N)
    # Number of window start positions along one axis (0 if no line fits).
    span = max(N - win_len + 1, 0)

    def has_k(p):
        # Vectorised sliding-window AND: a window survives only if every one of
        # its win_len cells belongs to p. This replaces a per-cell Python scan;
        # the function is called many times per MCTS expansion.
        mine = b == p
        horiz = np.ones((N, span), dtype=bool)
        vert = np.ones((span, N), dtype=bool)
        diag = np.ones((span, span), dtype=bool)   # "\" direction
        anti = np.ones((span, span), dtype=bool)   # "/" direction
        for k in range(win_len):
            horiz &= mine[:, k:k + span]
            vert &= mine[k:k + span, :]
            diag &= mine[k:k + span, k:k + span]
            j = win_len - 1 - k  # "/" windows run from top-right to bottom-left
            anti &= mine[k:k + span, j:j + span]
        return bool(horiz.any() or vert.any() or diag.any() or anti.any())

    if has_k(+1): return (+1.0, True)
    if has_k(-1): return (-1.0, True)
    if (b != 0).all(): return (0.0, True)
    return (0.0, False)

# Opening bias
def center_bias(N, sigma=0.35):
    """
    Return a float32 prior of length N*N peaked at the board center.

    Cell (i, j) is placed at (i/(N-1), j/(N-1)) in the unit square and weighted
    by an isotropic Gaussian of std ``sigma`` centered at (0.5, 0.5); the
    weights are normalised to sum to 1.
    """
    grid = np.linspace(0, 1, N)
    sq_dist = (grid[:, None] - 0.5) ** 2 + (grid[None, :] - 0.5) ** 2
    weights = np.exp(-sq_dist / (2 * (sigma ** 2))).reshape(-1).astype(np.float32)
    weights /= weights.sum() + 1e-8
    return weights

def mix_opening_bias(node, N, bias_vec, eps=0.25):
    """
    Blend a prior vector into the priors of ``node``'s existing children.

    ``bias_vec`` (indexable by action, length N*N) is renormalised over the
    child actions only; each child prior then becomes
    ``(1 - eps) * prior + eps * bias``. ``eps`` of None or <= 0, or a childless
    node, leaves everything untouched. ``bias_vec`` itself is not modified and
    ``N`` is not used.
    """
    if eps is None or eps <= 0.0:
        return
    actions = list(node.children.keys())
    if not actions:
        return
    weights = bias_vec / (bias_vec[actions].sum() + 1e-8)
    keep = 1 - eps
    for action in actions:
        child = node.children[action]
        child.prior = keep * child.prior + eps * float(weights[action])

In [34]:
# ---- Node ----
class Node:
    """Search-tree node: a board position plus the player whose turn it is there."""

    def __init__(self, raw_board_tuple, player_to_move):
        self.state = raw_board_tuple            # flat tuple of N*N cells (0 / +1 / -1)
        self.player = int(player_to_move)       # +1 or -1, side to move at this node
        self.children = {}                      # action index -> child Node
        self.visit_count = 0
        self.total_value = 0.0                  # SUM of backed-up values, from self.player's view
        self.prior = 0.0                        # prior probability assigned by the parent

    def value(self):
        """Mean backed-up value (0.0 while unvisited)."""
        if not self.visit_count:
            return 0.0
        return self.total_value / self.visit_count

# ---- MCTS ----
class MCTS:
    """
    AlphaZero-style PUCT Monte Carlo Tree Search for Gomoku.

    Value convention: every ``Node`` accumulates values from the perspective
    of ``node.player`` (the side to move at that node) -- see ``_simulate``.
    A parent choosing among its children therefore scores them with
    ``-child.value()``.

    Network evaluations are memoised in ``self.ttable`` keyed by
    ``(state_tuple, player, N)``. Nothing invalidates that cache when the
    model's weights change: call ``clear_cache()`` after training updates,
    otherwise later searches keep using the old network's priors and values.

    NOTE(review): the model is called in whatever mode it is in. If it has
    BatchNorm layers it should be in ``eval()`` mode while searching;
    otherwise every single-position forward pass uses batch statistics and
    updates the running statistics.
    """

    def __init__(self, model, simulations=800, c_puct=2.0, win_len=5):
        self.model = model
        self.simulations = simulations
        self.c_puct = c_puct
        self.win_len = win_len
        self.ttable = {}   # (state_tuple, player, N) -> (priors[N*N], value_float)
        self.root = None

    def clear_cache(self):
        """Forget cached network evaluations and the current tree (e.g. after a weight update)."""
        self.ttable.clear()
        self.root = None

    # ---------- public API ----------
    def run(self, env, add_root_noise=True, dir_alpha=0.3, dir_eps=0.25,
            opening_bias=None, opening_eps=0.25):
        """
        Search from ``env``'s current position and return the root Node.

        The root is reused from the previous search when it (or one of its
        children) matches ``env``. Optionally mixes Dirichlet noise
        (``dir_alpha``, ``dir_eps``) and an N*N prior vector ``opening_bias``
        (weight ``opening_eps``) into the root's child priors, then runs
        ``self.simulations`` simulations.
        """
        N = env.N
        self.sync_to_env(env)

        if add_root_noise:
            self._add_dirichlet_noise(self.root, N, alpha=dir_alpha, eps=dir_eps)

        if opening_bias is not None:
            mix_opening_bias(self.root, N, opening_bias, eps=opening_eps)

        for _ in range(self.simulations):
            self._simulate(self.root, N)
        return self.root

    def advance(self, action, N):
        """
        Move the root to the child reached by ``action``. If that child is not
        in the tree, build and expand a fresh root by playing ``action`` for the
        current root's player (or for +1 on an empty board when there is no root).
        """
        if self.root and action in self.root.children:
            self.root = self.root.children[action]
        else:
            base_state = self.root.state if self.root else tuple([0] * (N * N))
            base_player = self.root.player if self.root else +1
            child_board = apply_action_raw(base_state, int(action), base_player, N)
            self.root = Node(child_board, -base_player)
            if not self.root.children:
                self._expand_and_evaluate(self.root, N)

    def sync_to_env(self, env):
        """
        Make ``self.root`` describe ``env``'s position: keep it if it already
        matches, else descend to a matching child, else build and expand a new root.
        """
        N = env.N
        state = tuple(env.board.reshape(-1))
        player = env.current_player

        if self.root and self.root.state == state and self.root.player == player:
            return
        if self.root:
            for ch in self.root.children.values():
                if ch.state == state and ch.player == player:
                    self.root = ch
                    return
        self.root = Node(state, player)
        if not self.root.children:
            self._expand_and_evaluate(self.root, N)

    # ---------- core search ----------
    def _simulate(self, node, N):
        """
        Run one simulation below ``node`` and back the result up.

        Returns the leaf value from ``node.player``'s perspective; the same
        value is added to ``node.total_value``.
        """
        winner, done = check_terminal_gomoku(node.state, N, self.win_len)
        if done:
            if winner == 0:
                v = 0.0
            elif winner == node.player:
                v = +1.0
            else:
                v = -1.0
            node.visit_count += 1
            node.total_value += v
            return v

        # Short-circuit: if the side to move can win right now, that is the value.
        can_win, win_a = self.has_immediate_win(node.state, node.player, N, self.win_len)
        if can_win:
            v = +1.0
            node.visit_count += 1
            node.total_value += v
            if win_a not in node.children:
                nb = apply_action_raw(node.state, win_a, node.player, N)
                ch = Node(nb, -node.player)
                ch.prior = 1.0
                node.children[win_a] = ch
            child = node.children[win_a]
            child.visit_count += 1
            child.total_value += -v   # the child's player is the one who loses
            return v

        # Unexpanded leaf: evaluate with the network and end this simulation.
        # Testing only for missing children (not visit_count == 0) also expands
        # a visited-but-childless node instead of failing on children[None].
        if not node.children:
            v = self._expand_and_evaluate(node, N)
            node.visit_count += 1
            node.total_value += v
            return v

        # PUCT selection. child.value() is stored from the child's (opponent's)
        # perspective, so it is negated to score the move for node.player.
        sqrt_n = np.sqrt(node.visit_count + 1e-8)
        best_a, best_score = None, -np.inf
        for a, child in node.children.items():
            q = -child.value()
            u = self.c_puct * child.prior * (sqrt_n / (1 + child.visit_count))
            score = q + u
            if score > best_score:
                best_score, best_a = score, a

        v = -self._simulate(node.children[best_a], N)
        node.visit_count += 1
        node.total_value += v
        return v

    def _expand_and_evaluate(self, node, N):
        """
        Attach one child per legal move to ``node`` (priors from the network,
        cached in ``self.ttable``) and return the network value for ``node``
        from ``node.player``'s perspective.
        """
        key = (node.state, node.player, N)
        legal = legal_actions_from_raw(node.state, N)
        cached = self.ttable.get(key)
        if cached is None:
            cached = self._evaluate(node.state, node.player, N, legal)
            self.ttable[key] = cached
        priors, v = cached

        for a in legal:
            child_board = apply_action_raw(node.state, int(a), node.player, N)
            ch = Node(child_board, -node.player)
            ch.prior = float(priors[a])
            node.children[int(a)] = ch
        return v

    def _model_device(self):
        """Device of the model's first parameter (CPU if it exposes none)."""
        try:
            return next(self.model.parameters()).device
        except (AttributeError, StopIteration):
            return torch.device("cpu")

    def _evaluate(self, state, player, N, legal):
        """
        Run the network on one position and return ``(priors[N*N], value)``.

        Priors are a softmax over the legal-move logits (illegal cells get 0).
        Moves that give the opponent an immediate win are zeroed out as long as
        at least one safe move remains; the rest are renormalised.
        """
        planes = gomoku_planes_from_raw(state, player, N).astype(np.float32)
        batch = torch.from_numpy(planes).unsqueeze(0).to(self._model_device())
        with torch.no_grad():
            logits, value = self.model(batch)
        v = float(value.item())
        logits_np = logits.squeeze(0).detach().cpu().numpy()

        priors = np.zeros(N * N, dtype=np.float32)
        if legal:
            ll = logits_np[legal]
            ll = ll - ll.max()  # numerically stable softmax
            p_legal = np.exp(ll)
            p_legal /= (p_legal.sum() + 1e-8)
            priors[legal] = p_legal

        suicidal = [a for a in legal
                    if self.opponent_has_immediate_win_after(state, a, player, N, self.win_len)]
        if 0 < len(suicidal) < len(legal):
            priors[suicidal] = 0.0
            s = priors[legal].sum()
            if s > 0:
                priors[legal] /= s
            else:
                suicidal_set = set(suicidal)
                survivors = [a for a in legal if a not in suicidal_set]
                priors[survivors] = 1.0 / len(survivors)

        return priors.astype(np.float32), v

    # ---------- root noise ----------
    def _add_dirichlet_noise(self, node, N, alpha=0.3, eps=0.25):
        """Mix Dir(alpha) noise into the priors of ``node``'s children with weight ``eps``."""
        if not node.children: return
        legal_actions = list(node.children.keys())
        if not legal_actions: return
        noise = np.random.dirichlet([alpha] * len(legal_actions)).astype(np.float32)
        for a, n in zip(legal_actions, noise):
            p = node.children[a].prior
            node.children[a].prior = (1 - eps) * p + eps * float(n)

    # ---------- action picker ----------
    def select_action(self, root, temperature=1.0, tactical_guard=True):
        """
        Choose a move from the root's child visit counts.

        With ``tactical_guard`` moves that let the opponent win next turn are
        dropped when at least one safe move exists. ``temperature == 0`` picks
        the most-visited move (random tie-break); otherwise moves are sampled
        with probability proportional to ``visits ** (1 / temperature)``. If no
        candidate has any visits the choice is uniform among the candidates.
        """
        n_side = int(np.sqrt(len(root.state)))
        visits = np.zeros(n_side * n_side, dtype=np.float64)
        for a, child in root.children.items():
            visits[a] = child.visit_count
        legal_idxs = np.array(sorted(root.children), dtype=int)

        if legal_idxs.size == 0:
            fallback = np.array(legal_actions_from_raw(root.state, n_side), dtype=int)
            return int(np.random.choice(fallback)) if fallback.size > 0 else 0

        if tactical_guard:
            survivors = [a for a in legal_idxs
                         if not self.opponent_has_immediate_win_after(root.state, int(a),
                                                                      root.player, n_side, self.win_len)]
            if survivors:
                legal_idxs = np.array(survivors, dtype=int)

        if temperature == 0.0:
            return self._pick_most_visited(visits, legal_idxs)

        x = visits[legal_idxs] ** (1.0 / max(1e-8, temperature))
        total = x.sum()
        if not np.isfinite(total):
            # Tiny temperature overflowed the power: behave greedily.
            return self._pick_most_visited(visits, legal_idxs)
        if total <= 0.0:
            # No visits among the candidates (e.g. the guard removed every
            # visited move): sample uniformly instead of passing p=0 to choice.
            return int(np.random.choice(legal_idxs))
        return int(np.random.choice(legal_idxs, p=x / total))

    @staticmethod
    def _pick_most_visited(visits, candidates):
        """Return a uniformly random candidate among those with the most visits."""
        counts = visits[candidates]
        ties = candidates[counts == counts.max()]
        return int(np.random.choice(ties))

    # ---------- tiny tactical helpers ----------
    @staticmethod
    def _completes_line(grid, r, c, player, win_len):
        """True if ``player``'s stone at (r, c) lies on a run of >= win_len stones.

        ``grid`` is an indexable N×N board (list of lists or 2-D array)."""
        n = len(grid)
        for dr, dc in ((0, 1), (1, 0), (1, 1), (1, -1)):
            run = 1
            for sign in (1, -1):
                rr, cc = r + sign * dr, c + sign * dc
                while 0 <= rr < n and 0 <= cc < n and grid[rr][cc] == player:
                    run += 1
                    rr += sign * dr
                    cc += sign * dc
            if run >= win_len:
                return True
        return False

    @staticmethod
    def has_immediate_win(raw_board_tuple, player, N, win_len):
        """
        Return ``(True, action)`` for the first legal move (ascending index)
        that wins on the spot for ``player``, else ``(False, None)``.
        """
        legal = legal_actions_from_raw(raw_board_tuple, N)
        if not legal: return False, None

        _, already_over = check_terminal_gomoku(raw_board_tuple, N, win_len)
        if already_over:
            # Finished board (rare): keep the exhaustive full-board semantics.
            for a in legal:
                nb = apply_action_raw(raw_board_tuple, a, player, N)
                winner, done = check_terminal_gomoku(nb, N, win_len)
                if done and winner == player:
                    return True, a
            return False, None

        # Live board: nobody owns a line yet, so any new winning line must pass
        # through the stone just placed -- checking those 4 lines is enough.
        grid = np.array(raw_board_tuple, dtype=np.int8).reshape(N, N).tolist()
        for a in legal:
            r, c = divmod(a, N)
            grid[r][c] = player
            wins = MCTS._completes_line(grid, r, c, player, win_len)
            grid[r][c] = 0
            if wins:
                return True, a
        return False, None

    @staticmethod
    def opponent_has_immediate_win_after(raw_board_tuple, my_action, my_player, N, win_len):
        """
        True if, after ``my_player`` plays ``my_action``, the opponent has a
        move that wins immediately. If ``my_action`` itself ends the game (win
        or full board) the opponent never moves, so the answer is False.
        """
        child = apply_action_raw(raw_board_tuple, my_action, my_player, N)
        _, over = check_terminal_gomoku(child, N, win_len)
        if over:
            # Without this check a winning move by -1 was reported as
            # "suicidal" whenever +1 could also complete a line, because the
            # full-board check reports +1 first.
            return False

        opp = -my_player
        grid = np.array(child, dtype=np.int8).reshape(N, N).tolist()
        for oa in legal_actions_from_raw(child, N):
            r, c = divmod(oa, N)
            grid[r][c] = opp
            wins = MCTS._completes_line(grid, r, c, opp, win_len)
            grid[r][c] = 0
            if wins:
                return True
        return False

# 4. Tactical Unit Test

In [24]:
import random, numpy as np, torch

# Reproducibility: seed the Python, NumPy and PyTorch RNGs (root noise,
# move sampling / tie-breaks and network initialisation draw from these).
random.seed(0); np.random.seed(0); torch.manual_seed(0)

# Build env/model (untrained weights unless the checkpoint line below is enabled)
env = Gomoku(N=9, win_len=5)
model = PVResNet(board_size=env.N, channels=64, n_blocks=6)
# model.load_state_dict(torch.load("models/gomoku_alpha_zero.pth", map_location="cpu"))
model.eval()  # BatchNorm layers use running statistics during search

mcts = MCTS(model, simulations=800, c_puct=2.0, win_len=env.win_len)

# One move: search from the current position, sample a move from the root
# visit counts, and apply it to the environment.
root = mcts.run(env, add_root_noise=True)
a = mcts.select_action(root, temperature=1.0)   # temperature=1 in opening, 0 later
state, reward, done = env.step(a)

In [10]:
def must_block_test(model, N=9, win_len=5, sims=800, c_puct=2.0):
    """
    Tactical check for the defender.

    +1 owns a four at the left edge of row 4 (columns 0-3), so (4, 4) is the
    only square that completes five. With -1 to move, the search should pick
    that square. Returns ``(picked, expected_block)`` as flat indices.
    """
    attack_row = 4
    position = np.zeros((N, N), dtype=np.int8)
    position[attack_row, [0, 1, 2, 3]] = +1   # single-ended four at the edge

    env = Gomoku(N=N, win_len=win_len)
    env.board[:] = position
    env.current_player = -1   # defender to move

    searcher = MCTS(model, simulations=sims, c_puct=c_puct, win_len=win_len)
    root = searcher.run(env, add_root_noise=False)
    chosen = searcher.select_action(root, temperature=0.0)
    return chosen, attack_row * N + 4

# Defender scenario with a reduced search budget (100 simulations) for a quick check.
picked, exp_block = must_block_test(model, sims=100)  # bump sims if needed
print("picked:", picked, " expected block:", exp_block)

picked: 40  expected block: 40


In [11]:
def must_win_test(model, N=9, win_len=5, sims=200, c_puct=2.0):
    """
    Tactical check for the attacker.

    +1 (to move) owns a four at the left edge of row 4 (columns 0-3) and can
    win at once on (4, 4). Returns ``(picked, expected_win)`` as flat indices.

    ``c_puct`` mirrors ``must_block_test``; its default 2.0 is the value that
    was previously hard-coded here.
    """
    b = np.zeros((N, N), dtype=np.int8)
    r = 4
    for c in [0, 1, 2, 3]:
        b[r, c] = +1
    env = Gomoku(N=N, win_len=win_len)
    env.board[:] = b
    env.current_player = +1   # attacker to move

    expected_win = r * N + 4
    mcts = MCTS(model, simulations=sims, c_puct=c_puct, win_len=win_len)
    root = mcts.run(env, add_root_noise=False)
    picked = mcts.select_action(root, temperature=0.0)
    return picked, expected_win

# Attacker scenario with a reduced search budget (100 simulations) for a quick check.
picked, expected_win = must_win_test(model, sims=100)  # bump sims if needed
print("picked:", picked, " expected win:", expected_win)

picked: 40  expected win: 40


In [31]:
def avoid_suicide_test(model, N=9, win_len=5, sims=400, c_puct=2.0):
    """
    Position where attacker has a single immediate winning square.
    Defender (-1) must block that square; any other move is suicidal.

    Returns a dict with the picked move, the blocking square, whether the pick
    blocks, whether it lets +1 win next turn, and the overall pass flag.
    """
    r = 4
    position = np.zeros((N, N), dtype=np.int8)
    position[r, [0, 1, 2, 3]] = +1   # four at the edge: only (4, 4) completes five

    env = Gomoku(N=N, win_len=win_len)
    env.board[:] = position
    env.current_player = -1

    block = r * N + 4   # the only square that prevents an immediate win
    assert block in set(env.get_legal_actions())

    searcher = MCTS(model, simulations=sims, c_puct=c_puct, win_len=win_len)
    root = searcher.run(env, add_root_noise=False)
    picked = searcher.select_action(root, temperature=0.0, tactical_guard=True)

    # Suicide = the opponent can win in one move after our pick.
    raw = tuple(env.board.reshape(-1))
    is_suicide = MCTS.opponent_has_immediate_win_after(raw, picked, -1, N, win_len)
    is_block = picked == block

    return {
        "picked": picked,
        "block": block,
        "is_block": is_block,
        "is_suicide": is_suicide,
        "pass_test": (is_block and not is_suicide),
    }

# Example: run the suicide-avoidance scenario with 100 simulations and show the result dict.
res = avoid_suicide_test(model, sims=100)
print(res)

{'picked': 40, 'block': 40, 'is_block': True, 'is_suicide': False, 'pass_test': True}


# 5. Self-Play with Curriculum Learning

In [35]:
import torch, random, math, collections
import random
import numpy as np

# --- helpers to rotate/reflect both planes and a flat policy over N*N cells ---
def rotate90_board(planes):    # planes: [C,N,N]
    """Rotate every plane 90° counter-clockwise (np.rot90, k=1); returns a fresh C-ordered array."""
    # np.rot90(k=1) == flip the column axis, then swap the two board axes.
    return np.swapaxes(planes[:, :, ::-1], 1, 2).copy()

def rotate180_board(planes):
    """Rotate every plane by 180° over the last two axes; returns a fresh C-ordered array."""
    return planes[:, ::-1, ::-1].copy()

def rotate270_board(planes):
    """Rotate every plane 270° counter-clockwise (= 90° clockwise); returns a fresh array."""
    return np.rot90(planes, k=-1, axes=(1, 2)).copy()

def flip_h_board(planes):
    """Mirror every plane left-right (reverse the column axis); returns a fresh array."""
    return planes[:, :, ::-1].copy()

def idx_map_rotate90(i, N):
    """
    Flat index of cell ``i`` after the board is rotated exactly like
    ``np.rot90(board, k=1)`` (counter-clockwise), i.e. like ``rotate90_board``.

    np.rot90 sends cell (r, c) to (N-1-c, r). The previous mapping
    (r, c) -> (c, N-1-r) was the clockwise rotation, so policy targets did
    not line up with the rotated planes.
    """
    r, c = divmod(i, N)
    rr, cc = N - 1 - c, r
    return rr * N + cc

def idx_map_rotate180(i, N):
    """Flat index of cell ``i`` after a 180° rotation: (r, c) -> (N-1-r, N-1-c)."""
    # (N-1-r)*N + (N-1-c) simplifies to N*N - 1 - (r*N + c) = N*N - 1 - i.
    return N * N - 1 - i

def idx_map_rotate270(i, N):
    """
    Flat index of cell ``i`` after the board is rotated exactly like
    ``np.rot90(board, k=3)`` (270° counter-clockwise), i.e. like ``rotate270_board``.

    np.rot90 with k=3 sends cell (r, c) to (c, N-1-r). The previous mapping
    (r, c) -> (N-1-c, r) was the k=1 rotation, so policy targets did not
    line up with the rotated planes.
    """
    r, c = divmod(i, N)
    rr, cc = c, N - 1 - r
    return rr * N + cc

def idx_map_flip_h(i, N):
    """Flat index of cell ``i`` after a left-right mirror: (r, c) -> (r, N-1-c)."""
    # r*N + (N-1-c) == i - c + (N-1-c) == i + N - 1 - 2c, with c = i mod N.
    return i + N - 1 - 2 * (i % N)

def transform_pi(pi, idx_map, N):
    """
    Permute a flat policy vector: the mass at index ``i`` moves to
    ``idx_map(i, N)``. Positions not hit by the map stay 0.
    """
    moved = np.zeros_like(pi)
    for src in range(len(pi)):
        moved[idx_map(src, N)] = pi[src]
    return moved

def symmetries_board_and_pi(planes, pi, N):
    """
    Return the 8 dihedral symmetries of one training sample.

    Args:
        planes: array [C, N, N] of feature planes.
        pi: flat policy target of length N*N (row-major, index r*N + c).
        N: board side length.

    Returns:
        list of 8 ``(planes_t, pi_t)`` pairs in the order
        I, R90, R180, R270, FH, FH+R90, FH+R180, FH+R270, where R is
        ``np.rot90`` (counter-clockwise) over the board axes and FH is a
        left-right flip. Every returned array is a fresh copy.

    The policy is transformed by applying the very same array operations to
    ``pi`` viewed as a [1, N, N] plane, so move targets always line up with the
    transformed board. (The previous index-map lambdas took one argument but
    were called with two, so every non-identity symmetry raised TypeError.)
    """
    pi_plane = np.asarray(pi).reshape(1, N, N)
    out = []
    for flip in (False, True):
        base_planes = np.flip(planes, axis=2) if flip else planes
        base_pi = np.flip(pi_plane, axis=2) if flip else pi_plane
        for k in range(4):
            planes_t = np.rot90(base_planes, k=k, axes=(1, 2)).copy()
            pi_t = np.rot90(base_pi, k=k, axes=(1, 2)).reshape(-1).copy()
            out.append((planes_t, pi_t))
    return out

class Replay:
    """
    Fixed-capacity FIFO buffer of training samples.

    Once ``capacity`` items are stored, each ``add`` evicts the oldest item.
    """

    def __init__(self, capacity=100_000):
        self.buf = collections.deque(maxlen=capacity)

    def add(self, x):
        """Append one sample (evicting the oldest when the buffer is full)."""
        self.buf.append(x)

    def sample(self, bs):
        """
        Return ``bs`` distinct samples drawn uniformly without replacement.
        Raises ValueError if ``bs`` exceeds ``len(self)``.
        """
        return random.sample(self.buf, bs)

    def __len__(self):
        return len(self.buf)


# NOTE: a string-disabled draft of sims_schedule / opening_eps_schedule /
# self_play_episode used to sit here as a bare triple-quoted literal. It was
# never executed, referenced names not defined anywhere above
# (`symmetries`, `OPENING_BIAS`), and is superseded by the CURRICULUM-driven
# versions that follow, so it has been removed.

# ----- curriculum knobs -----
CURRICULUM = {
    # MCTS simulation budget (see sims_schedule)
    'max_sims': 1600,          # ceiling reached once sims_warmup_iters iterations have passed
    'base_sims': 200,          # budget at iteration 0 and at move 0 (cheap opening search)
    'sims_warmup_iters': 40,   # iterations over which the ceiling ramps base_sims -> max_sims
    'sims_ramp_moves': 15,     # moves over which a game's budget ramps up to that ceiling
    # Center-prior mixing at the root (see opening_eps_schedule)
    'opening_max_eps': 0.40,   # mixing weight at iteration 0, move 0
    'opening_stop_iter': 30,   # weight fades linearly to 0 over this many iterations...
    'opening_stop_move': 10,   # ...and over this many moves within a game
    # Move selection (see temperature_for_move)
    'temp_opening_moves': 10,  # temperature 1 for the first K moves, then 0
}

def sims_schedule(iter_idx, move_no, cfg=None):
    """
    MCTS simulation budget for move ``move_no`` of a game in iteration ``iter_idx``.

    The per-iteration ceiling ramps linearly from ``base_sims`` to ``max_sims``
    over ``sims_warmup_iters`` iterations; within a game the budget ramps from
    ``base_sims`` up to that ceiling over ``sims_ramp_moves`` moves. A ramp
    length <= 0 means "no ramp" (full value immediately) instead of dividing
    by zero.

    ``cfg`` defaults to the module-level ``CURRICULUM`` looked up at call time,
    so re-defining ``CURRICULUM`` (e.g. re-running its notebook cell) takes
    effect without re-defining this function.
    """
    if cfg is None:
        cfg = CURRICULUM
    base, maxs = cfg['base_sims'], cfg['max_sims']

    warm = cfg['sims_warmup_iters']
    f_iter = 1.0 if warm <= 0 else min(1.0, iter_idx / float(warm))
    sims_iter = int(base + (maxs - base) * f_iter)

    ramp = cfg['sims_ramp_moves']
    f_move = 1.0 if ramp <= 0 else min(1.0, move_no / float(ramp))
    sims = int(base + (sims_iter - base) * f_move)
    return max(base, min(maxs, sims))

def opening_eps_schedule(iter_idx, move_no, cfg=None):
    """
    Weight of the center prior mixed into root priors.

    Starts at ``opening_max_eps`` and fades linearly to 0 both across training
    iterations (``opening_stop_iter``) and across moves within a game
    (``opening_stop_move``); the two factors multiply. A stop value <= 0
    disables the bias instead of dividing by zero.

    ``cfg`` defaults to the module-level ``CURRICULUM`` looked up at call time.
    """
    if cfg is None:
        cfg = CURRICULUM
    max_eps = cfg['opening_max_eps']
    stop_iter, stop_move = cfg['opening_stop_iter'], cfg['opening_stop_move']
    f_iter = 0.0 if stop_iter <= 0 else max(0.0, 1.0 - iter_idx / float(stop_iter))
    f_move = 0.0 if stop_move <= 0 else max(0.0, 1.0 - move_no / float(stop_move))
    return max_eps * f_iter * f_move

def temperature_for_move(move_no, cfg=None):
    """
    Sampling temperature for move ``move_no``: 1.0 for the first
    ``temp_opening_moves`` moves, 0.0 (greedy) afterwards.

    ``cfg`` defaults to the module-level ``CURRICULUM`` looked up at call time.
    """
    if cfg is None:
        cfg = CURRICULUM
    return 1.0 if move_no < cfg['temp_opening_moves'] else 0.0

def self_play_episode(env, model, mcts, iter_idx=1, cfg=None):
    """
    Play one self-play game with ``mcts`` and return augmented training samples.

    Args:
        env: Gomoku environment (reset at the start).
        model: unused here (the search uses ``mcts.model``); kept for the
            existing call signature.
        mcts: MCTS instance; its ``simulations`` is set per move by
            ``sims_schedule``.
        iter_idx: training iteration, drives the curriculum schedules.
        cfg: curriculum dict; defaults to ``CURRICULUM`` looked up at call time.

    Returns:
        list of ``(planes[3,N,N] float32, pi[N*N] float32, z float)`` with the
        8 symmetry augmentations of every position. ``z`` is the final result
        from the perspective of the player to move at that position
        (+1 win, 0 draw, -1 loss).
    """
    if cfg is None:
        cfg = CURRICULUM
    env.reset()
    opening_prior = center_bias(env.N, sigma=0.35)  # loop-invariant, not mutated by MCTS
    traj, done, move_no = [], False, 0
    reward, mover = 0.0, env.current_player
    while not done:
        mcts.simulations = sims_schedule(iter_idx, move_no, cfg)
        eps_open = opening_eps_schedule(iter_idx, move_no, cfg)
        root = mcts.run(
            env,
            add_root_noise=True,
            dir_alpha=0.15, dir_eps=0.25,
            opening_bias=opening_prior,
            opening_eps=eps_open
        )
        T = temperature_for_move(move_no, cfg)
        a = mcts.select_action(root, temperature=T)

        # Policy target: normalised root visit counts.
        pi = np.zeros(env.N * env.N, dtype=np.float32)
        for aa, child in root.children.items():
            pi[aa] = child.visit_count
        s = pi.sum()
        if s > 0: pi /= s

        mover = env.current_player
        traj.append((env.to_planes(), pi, mover))
        _, reward, done = env.step(a)
        move_no += 1

    # env.step's reward is relative to the last mover (+1 win, 0 draw,
    # -1 illegal move), so the result from +1's perspective is reward * mover.
    # Tracking the mover explicitly stays correct on illegal moves, where the
    # turn does not pass and -env.current_player would name the wrong player.
    outcome = reward * mover

    data = []
    for planes, pi, player in traj:
        z = outcome * player
        for planes_aug, pi_aug in symmetries_board_and_pi(planes, pi, env.N):
            data.append((planes_aug.astype(np.float32), pi_aug.astype(np.float32), float(z)))
    return data

def train_step(
    model,
    optimizer,
    batch,
    N,
    policy_coef=1.0,
    value_coef=0.5,
    entropy_coef=0.0,   # try 1e-3 for the first ~20 iterations
    grad_clip=None,     # e.g., 1.0 to clip by norm
):
    """
    One optimisation step on a batch of ``(planes[3,N,N], pi[N*N], z)`` samples.

    loss = policy_coef * cross_entropy(pi, softmax(logits))
         + value_coef  * mse(value, z)
         - entropy_coef * entropy(softmax(logits))

    Gradients are optionally clipped to norm ``grad_clip``. ``N`` is unused.

    Returns:
        ``(loss, policy_loss, value_loss)`` as Python floats, measured before
        the parameter update.
    """
    device = next(model.parameters()).device

    planes = torch.as_tensor(np.stack([item[0] for item in batch]), dtype=torch.float32, device=device)
    pi = torch.as_tensor(np.stack([item[1] for item in batch]), dtype=torch.float32, device=device)
    z = torch.tensor([[item[2]] for item in batch], device=device, dtype=torch.float32)

    logits, value = model(planes)                 # logits: [B, N*N], value: [B, 1]
    log_probs = F.log_softmax(logits, dim=1)
    probs = F.softmax(logits, dim=1)

    policy_loss = torch.sum(pi * log_probs, dim=1).mean().neg()
    value_loss = F.mse_loss(value, z)
    entropy = torch.sum(probs * log_probs, dim=1).mean().neg()

    loss = policy_coef * policy_loss + value_coef * value_loss - entropy_coef * entropy

    optimizer.zero_grad(set_to_none=True)
    loss.backward()
    if grad_clip is not None:
        torch.nn.utils.clip_grad_norm_(model.parameters(), grad_clip)
    optimizer.step()

    return float(loss.item()), float(policy_loss.item()), float(value_loss.item())

def random_action(env):
    """Return a uniformly random legal action of ``env`` (Python ``random`` module)."""
    legal = env.get_legal_actions()
    return random.choice(legal)

def evaluate_vs_random(model, sims=800, games=50, N=9, win_len=5):
    """Play the MCTS-guided model against a uniform-random mover.

    Our stone color alternates per game (+1 on even games, -1 on odd games).
    The environment keeps its own starting player (``Gomoku.reset`` sets
    ``current_player = 1``), so we move first in half of the games and
    second in the other half. Forcing our color to start would always give
    us the first move and make the alternation meaningless.

    Args:
        model: network handed to ``MCTS``.
        sims: MCTS simulations per move for our side.
        games: number of games to play.
        N, win_len: board size and line length needed to win.

    Returns:
        ``(wins, draws, losses)`` from the model's point of view. A summary
        line is also printed.
    """
    wins = draws = losses = 0
    for g in range(games):
        env = Gomoku(N=N, win_len=win_len)      # fresh board; +1 (black) to move
        my_color = +1 if (g % 2 == 0) else -1   # alternate our color (and thus move order)
        my_mcts = MCTS(model, simulations=sims, c_puct=2.0, win_len=win_len)

        done = False
        reward = 0.0
        while not done:
            if env.current_player == my_color:
                root = my_mcts.run(env, add_root_noise=False)
                a = my_mcts.select_action(root, temperature=0.0)
            else:
                a = random_action(env)
            _, reward, done = env.step(a)

        if reward == 0:
            draws += 1
        else:
            # After a normal move step() has already flipped current_player, so
            # -current_player is the player who just moved. On an illegal move
            # step() returns without flipping, so -current_player is the
            # offender's opponent, who is correctly credited with the win.
            winner = -env.current_player
            if winner == my_color:
                wins += 1
            else:
                losses += 1
    print(f"[Eval vs random] W {wins}  D {draws}  L {losses}  (games={games})")
    return wins, draws, losses

# ------------- main training loop (sketch) -------------
def train_gomoku(num_iters=200, episodes_per_iter=20, sims=800, batch_size=256,
                 save_dir="models", eval_every=None):
    """Self-play training loop for 9×9 Gomoku (5 in a row).

    Each iteration:
      1. plays ``episodes_per_iter`` self-play games (MCTS with ``sims``
         simulations per move) and adds every returned sample to the replay
         buffer;
      2. runs up to 400 minibatch updates of ``batch_size`` samples (none
         until the buffer holds at least one full batch);
      3. steps the cosine learning-rate schedule and prints a progress line.

    Args:
        num_iters: number of iterations (also the cosine schedule length).
        episodes_per_iter: self-play games per iteration.
        sims: MCTS simulations per move during self-play.
        batch_size: minibatch size for ``train_step``.
        save_dir: checkpoint directory. It is created if missing. A
            checkpoint is written every 10 iterations, plus a final one.
        eval_every: if a positive int, run a small evaluation against a
            random player every ``eval_every`` iterations (off by default).

    Returns:
        The trained model.
    """
    import os

    # Create the checkpoint directory before training. Otherwise torch.save
    # raises FileNotFoundError only after the first 10 (expensive) iterations.
    os.makedirs(save_dir, exist_ok=True)

    env = Gomoku(N=9, win_len=5)
    model = PVResNet(board_size=env.N, channels=64, n_blocks=6)
    optimizer = torch.optim.AdamW(model.parameters(), lr=1e-3, weight_decay=1e-4)
    sched = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=num_iters)
    mcts = MCTS(model, simulations=sims, c_puct=2.0, win_len=env.win_len)
    replay = Replay(200_000)

    ENTROPY_WARM_ITERS = 20     # entropy bonus is active only for the first ~20 iterations
    ENTROPY_COEF_START = 1e-3   # or 5e-4

    for it in range(1, num_iters + 1):
        # ------- self-play -------
        for _ in range(episodes_per_iter):
            data = self_play_episode(env, model, mcts, iter_idx=it, cfg=CURRICULUM)
            for item in data:
                replay.add(item)

        # ------- entropy schedule -------
        # Cosine decay from ENTROPY_COEF_START (it == 1) down to 0
        # (it == ENTROPY_WARM_ITERS); disabled afterwards.
        if it <= ENTROPY_WARM_ITERS:
            t = (it - 1) / max(1, ENTROPY_WARM_ITERS - 1)
            entropy_coef = 0.5 * (1 + math.cos(math.pi * t)) * ENTROPY_COEF_START
        else:
            entropy_coef = 0.0

        # ------- train steps -------
        losses = []
        num_batches = min(400, len(replay) // batch_size)
        for _ in range(num_batches):
            batch = replay.sample(batch_size)
            l, pl, vl = train_step(
                model, optimizer, batch, env.N,
                policy_coef=1.0, value_coef=0.5,
                entropy_coef=entropy_coef,
                grad_clip=1.0,   # optional but often helpful
            )
            losses.append(l)

        sched.step()
        print(f"Iter {it:03d} | replay {len(replay)} | sims {sims} | "
              f"lr {sched.get_last_lr()[0]:.4e} | loss {np.mean(losses) if losses else 0:.3f}")

        # Optional, cheap evaluation (small search budget, few games).
        if eval_every and it % eval_every == 0:
            evaluate_vs_random(model, sims=200, games=8, N=env.N, win_len=env.win_len)

        if it % 10 == 0:
            torch.save(model.state_dict(),
                       os.path.join(save_dir, f"gomoku_alpha_zero_iter{it:03d}.pth"))

    torch.save(model.state_dict(), os.path.join(save_dir, "gomoku_alpha_zero.pth"))
    return model

In [36]:
# Launch a training run with train_gomoku's default settings.
# NOTE(review): the returned model is not assigned here; the evaluation and UI
# cells below use a global `model` that must be defined elsewhere (e.g. loaded
# from a checkpoint) — confirm.
train_gomoku()

KeyboardInterrupt: 

# 6. Strength Ladder

In [None]:
def eval_vs_mcts(model, my_sims, opp_sims, games=40, N=9, win_len=5):
    """Play ``model`` against itself with (possibly) different MCTS budgets.

    Both sides use the same network. "Our" side searches with ``my_sims``
    simulations per move and the opponent with ``opp_sims``. Both play
    greedily (temperature 0.0) without root noise. Our stone color
    alternates per game (+1 on even games, -1 on odd games). The
    environment keeps its default starting player (+1 after
    ``Gomoku.reset``), so we move first in half of the games and second in
    the other half.

    NOTE(review): if MCTS is fully deterministic, games with the same color
    assignment will repeat move for move — confirm this is acceptable.

    Returns:
        ``(wins, draws, losses)`` from our side's point of view.
    """
    wins = draws = losses = 0
    for g in range(games):
        env = Gomoku(N=N, win_len=win_len)      # fresh board; +1 (black) to move
        my_color = +1 if (g % 2 == 0) else -1   # alternate our color (and thus move order)

        my_mcts  = MCTS(model, simulations=my_sims,  c_puct=2.0, win_len=win_len)
        opp_mcts = MCTS(model, simulations=opp_sims, c_puct=2.0, win_len=win_len)

        done = False
        reward = 0.0
        while not done:
            searcher = my_mcts if env.current_player == my_color else opp_mcts
            root = searcher.run(env, add_root_noise=False)
            a = searcher.select_action(root, temperature=0.0)
            _, reward, done = env.step(a)

        if reward == 0:
            draws += 1
        else:
            # step() flips current_player after a completed move, so the
            # player who just moved is -current_player.
            winner = -env.current_player
            if winner == my_color:
                wins += 1
            else:
                losses += 1
    return wins, draws, losses

def strength_ladder(model, N=9, win_len=5):
    """Print a three-part evaluation ladder for ``model``.

    1. 400-simulation play against a random mover (50 games).
    2. Self-matches with equal simulation budgets on both sides.
    3. Self-matches where our side searches with ``max(opp, 800)`` sims.

    Each MCTS-vs-MCTS row is 50 games; rows are printed as W/D/L counts.
    """
    budgets = [100, 200, 400, 800, 1200]

    print("[Ladder] vs random (color-aware)")
    evaluate_vs_random(model, sims=400, games=50, N=N, win_len=win_len)

    print("\n[Ladder] symmetric sims (ours == theirs)")
    for budget in budgets:
        won, drawn, lost = eval_vs_mcts(
            model, my_sims=budget, opp_sims=budget, games=50, N=N, win_len=win_len
        )
        print(f"sims={budget:4}  W {won:2}  D {drawn:2}  L {lost:2}")

    print("\n[Ladder] advantage (ours ≥ 800 sims)")
    for opp_budget in budgets:
        ours = max(opp_budget, 800)
        won, drawn, lost = eval_vs_mcts(
            model, my_sims=ours, opp_sims=opp_budget, games=50, N=N, win_len=win_len
        )
        print(f"opp={opp_budget:4}  ours={ours:4}  W {won:2}  D {drawn:2}  L {lost:2}")

In [None]:
# Switch the network to inference mode (affects mode-dependent layers such as
# BatchNorm/Dropout, if the model has any), then run the evaluation ladder on
# a 9×9 board with 5 in a row to win.
model.eval()
strength_ladder(model, N=9, win_len=5)

# 7. Interactive Games

In [27]:
# pip install ipywidgets
import ipywidgets as W
from IPython.display import display

class GomokuUI:
    """Clickable ipywidgets board for playing Gomoku against the MCTS-guided model.

    The human plays ``human_player`` (+1 = black "X", -1 = white "O") and the
    AI plays the other color. A fresh ``Gomoku`` has +1 to move, so when the
    human plays -1 the AI makes the first move, both on construction and on
    every reset.

    Args:
        model: network passed to ``MCTS`` for the AI's moves.
        N, win_len: board size and line length needed to win.
        sims: MCTS simulations per AI move.
        human_player: +1 (X) or -1 (O).
        opening_moves: while fewer than this many stones are on the board, the
            AI adds root noise and selects with temperature 1.0 for variety;
            afterwards it selects with temperature 0.0.
    """

    def __init__(self, model, N=9, win_len=5, sims=800, human_player=-1, opening_moves=8):
        self.model = model
        self.N = N
        self.win_len = win_len
        self.human_player = int(human_player)   # +1 black (X), -1 white (O)
        self.sims = sims
        self.opening_moves = int(opening_moves)

        self.env = Gomoku(N=N, win_len=win_len)
        self.buttons = [[W.Button(description=" ", layout=W.Layout(width="40px", height="40px"))
                         for _ in range(N)] for _ in range(N)]
        for r in range(N):
            for c in range(N):
                self.buttons[r][c].on_click(self._make_cb(r, c))

        self.msg = W.HTML(value="")
        self.reset_btn = W.Button(description="Reset")
        self.reset_btn.on_click(self.reset)

        grid = W.GridBox(
            children=[self.buttons[r][c] for r in range(N) for c in range(N)],
            layout=W.Layout(grid_template_columns=f"repeat({N}, 42px)")
        )
        display(W.VBox([grid, W.HBox([self.reset_btn]), self.msg]))

        # +1 moves first on a fresh board; if the human is -1 the AI opens.
        if self.human_player == -1 and self.env.current_player == +1:
            self.ai_move()
        self.refresh()

    def _num_moves(self):
        """Return the number of stones on the board, i.e. moves played so far.

        Computed from the board itself so the UI does not depend on the
        environment providing a ``num_moves()`` helper.
        """
        return int(np.count_nonzero(self.env.board))

    def _make_cb(self, r, c):
        """Build the click handler for board cell (r, c)."""
        a = r * self.N + c

        def _cb(btn):
            # Ignore clicks once the game is over or when it is the AI's turn.
            if self.env.done or self.env.current_player != self.human_player:
                return
            # Ignore clicks on occupied cells.
            if self.env.board[r, c] != 0:
                return
            _, _, done = self.env.step(a)
            if not done:
                self.ai_move()
            self.refresh()
        return _cb

    def ai_move(self):
        """Let the AI search with MCTS and play one move for the side to move."""
        mcts = MCTS(self.model, simulations=self.sims, c_puct=2.0, win_len=self.win_len)
        in_opening = self._num_moves() < self.opening_moves
        # Opening: small root noise (dir_alpha/dir_eps) plus temperature 1.0 for
        # variety between games. Later: temperature 0.0 (presumably the
        # most-visited move — see MCTS.select_action).
        root = mcts.run(self.env, add_root_noise=in_opening, dir_alpha=0.15, dir_eps=0.10)
        a = mcts.select_action(root, temperature=1.0 if in_opening else 0.0)
        self.env.step(a)

    def refresh(self):
        """Sync button labels and enabled state, and the status message, with the game."""
        sym = {1: "X", -1: "O", 0: " "}
        for r in range(self.N):
            for c in range(self.N):
                v = int(self.env.board[r, c])
                b = self.buttons[r][c]
                b.description = sym[v]
                b.disabled = bool(v != 0)

        reward, done = self.env.check_winner_fast()
        if done:
            if reward == 0:
                self.msg.value = "<b>Draw.</b> Click Reset."
            else:
                # step() has already handed the turn over, so the winner is
                # the player opposite to current_player.
                winner = 'X' if -self.env.current_player == +1 else 'O'
                self.msg.value = f"<b>{winner} wins!</b> Click Reset."
            for r in range(self.N):
                for c in range(self.N):
                    self.buttons[r][c].disabled = True
        else:
            turn = 'X' if self.env.current_player == +1 else 'O'
            self.msg.value = f"Player {turn} to move."

    def reset(self, _=None):
        """Start a new game; also used as the Reset button's click callback."""
        self.env.reset()
        if self.human_player == -1 and self.env.current_player == +1:
            self.ai_move()
        self.refresh()

In [25]:
# Example usage:
# The human plays O (-1), so the AI plays X and makes the first move,
# searching with 600 MCTS simulations per move.
model.eval()
ui = GomokuUI(model, N=9, win_len=5, sims=600, human_player=-1)  # human plays O

VBox(children=(GridBox(children=(Button(description=' ', layout=Layout(height='40px', width='40px'), style=But…

In [26]:
# Example usage:
# Same setup as above but with 1600 MCTS simulations per AI move
# (more search per move, so each AI reply takes longer).
model.eval()
ui = GomokuUI(model, N=9, win_len=5, sims=1600, human_player=-1)  # human plays O

VBox(children=(GridBox(children=(Button(description=' ', layout=Layout(height='40px', width='40px'), style=But…