# 1. Gomoku

In [1]:
import numpy as np

class Gomoku:
    """
    N×N board, players in {+1 (black), -1 (white)}, empty=0.
    Win condition: 5 in a row (any direction).
    """
    def __init__(self, N=9, win_len=5):
        self.N = int(N)
        self.win_len = int(win_len)
        self.reset()

    def reset(self):
        self.board = np.zeros((self.N, self.N), dtype=np.int8)
        self.current_player = 1
        self.done = False
        self.last_move = None  # (r,c) or None
        return self.get_state()  # canonical 1×(N*N) if you want; see below

    # --- Canonical features for NN (3 planes): me, opp, to-move ---
    def to_planes(self):
        me  = (self.board == self.current_player).astype(np.float32)
        opp = (self.board == -self.current_player).astype(np.float32)
        turn = np.full_like(me, 1.0, dtype=np.float32)  # “to move” plane
        # shape: [3, N, N]
        return np.stack([me, opp, turn], axis=0)

    # For drop-in compatibility with your older code:
    def get_state(self):
        # vectorized canonical state (board * current_player) like tic-tac-toe
        return (self.board.flatten() * self.current_player).astype(np.float32)

    def get_legal_actions(self):
        # actions are 0..N*N-1
        flat = self.board.reshape(-1)
        return [i for i in range(self.N * self.N) if flat[i] == 0]

    def step(self, action):
        """Apply action index (0..N^2-1). Returns (state, reward, done)."""
        if self.done:
            return self.get_state(), 0.0, True
        r, c = divmod(int(action), self.N)
        if self.board[r, c] != 0:
            # illegal move penalty (AlphaZero typically forbids via masking;
            # but keep this as a safeguard)
            self.done = True
            return self.get_state(), -1.0, True

        self.board[r, c] = self.current_player
        self.last_move = (r, c)

        reward, done = self.check_winner_fast()
        self.done = done
        state = self.get_state()
        self.current_player *= -1
        return state, float(reward), bool(done)

    # --------- winner checks ----------
    def check_winner_fast(self):
        """
        Check only the lines passing through self.last_move for speed.
        Returns (reward_for_player_who_just_moved, done)
        reward ∈ {+1 win, 0 draw/ongoing, -1 loss} relative to the mover.
        """
        if self.last_move is None:
            return 0.0, False
        r, c = self.last_move
        p = self.board[r, c]
        if p == 0:
            return 0.0, False

        if ( self._count_dir(r, c, 1, 0, p) + self._count_dir(r, c, -1, 0, p) - 1 >= self.win_len or
             self._count_dir(r, c, 0, 1, p) + self._count_dir(r, c, 0, -1, p) - 1 >= self.win_len or
             self._count_dir(r, c, 1, 1, p) + self._count_dir(r, c, -1, -1, p) - 1 >= self.win_len or
             self._count_dir(r, c, 1, -1, p) + self._count_dir(r, c, -1, 1, p) - 1 >= self.win_len ):
            # player p (who just moved) wins: reward +1 for mover
            return 1.0, True

        if (self.board != 0).all():
            return 0.0, True  # draw

        return 0.0, False

    def _count_dir(self, r, c, dr, dc, p):
        """Count contiguous stones with color p starting at (r,c) inclusive along (dr,dc)."""
        N = self.N
        cnt = 0
        rr, cc = r, c
        while 0 <= rr < N and 0 <= cc < N and self.board[rr, cc] == p:
            cnt += 1
            rr += dr
            cc += dc
        return cnt

    # --------- convenience ----------
    def render_ascii(self):
        sym = {1:'X', -1:'O', 0:'.'}
        print("\n".join(" ".join(sym[v] for v in row) for row in self.board))

# 2. ResNet Policy-Value Model

In [2]:
import torch
import torch.nn as nn
import torch.nn.functional as F

class ResidualBlock(nn.Module):
    def __init__(self, channels):
        super().__init__()
        self.conv1 = nn.Conv2d(channels, channels, kernel_size=3, padding=1, bias=False)
        self.bn1   = nn.BatchNorm2d(channels)
        self.conv2 = nn.Conv2d(channels, channels, kernel_size=3, padding=1, bias=False)
        self.bn2   = nn.BatchNorm2d(channels)

    def forward(self, x):
        y = F.relu(self.bn1(self.conv1(x)))
        y = self.bn2(self.conv2(y))
        return F.relu(x + y)

class PVResNet(nn.Module):
    """
    Policy–Value net for Gomoku.
    Input: [B, 3, N, N]
    Outputs:
      - policy_logits: [B, N*N]
      - value:         [B, 1]  (tanh)
    """
    def __init__(self, board_size=9, channels=64, n_blocks=6):
        super().__init__()
        self.N = board_size
        C = channels

        self.stem = nn.Sequential(
            nn.Conv2d(3, C, kernel_size=3, padding=1, bias=False),
            nn.BatchNorm2d(C),
            nn.ReLU(inplace=True)
        )
        self.trunk = nn.Sequential(*[ResidualBlock(C) for _ in range(n_blocks)])

        # Policy head
        self.p_head = nn.Sequential(
            nn.Conv2d(C, 2, kernel_size=1, bias=False),
            nn.BatchNorm2d(2),
            nn.ReLU(inplace=True)
        )
        self.p_fc = nn.Linear(2 * self.N * self.N, self.N * self.N)

        # Value head
        self.v_head = nn.Sequential(
            nn.Conv2d(C, 1, kernel_size=1, bias=False),
            nn.BatchNorm2d(1),
            nn.ReLU(inplace=True)
        )
        self.v_fc1 = nn.Linear(1 * self.N * self.N, C)
        self.v_fc2 = nn.Linear(C, 1)

        # (optional) init
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.kaiming_normal_(m.weight, nonlinearity='relu')
            elif isinstance(m, nn.Linear):
                nn.init.kaiming_normal_(m.weight, nonlinearity='relu')
                nn.init.zeros_(m.bias)

    def forward(self, x):  # x: [B, 3, N, N]
        z = self.stem(x)
        z = self.trunk(z)

        # policy
        p = self.p_head(z)
        p = p.view(p.size(0), -1)
        policy_logits = self.p_fc(p)  # [B, N*N]

        # value
        v = self.v_head(z)
        v = v.view(v.size(0), -1)
        v = F.relu(self.v_fc1(v))
        value = torch.tanh(self.v_fc2(v))  # [-1, 1]

        return policy_logits, value

# 3. Monte Carlo Tree Search (MCTS)

In [3]:
import numpy as np
import torch
import torch.nn.functional as F

def gomoku_planes_from_raw(raw_board_tuple, player, N):
    """Build [3, N, N] planes from (raw_board, player_to_move)."""
    b = np.array(raw_board_tuple, dtype=np.int8).reshape(N, N)
    me  = (b == player).astype(np.float32)
    opp = (b == -player).astype(np.float32)
    turn = np.ones_like(me, dtype=np.float32)
    return np.stack([me, opp, turn], axis=0)

def legal_actions_from_raw(raw_board_tuple, N):
    b = np.array(raw_board_tuple, dtype=np.int8).reshape(N, N)
    return [i for i in range(N*N) if b.flat[i] == 0]

def apply_action_raw(raw_board_tuple, action, player, N):
    b = np.array(raw_board_tuple, dtype=np.int8).reshape(N, N).copy()
    r, c = divmod(int(action), N)
    b[r, c] = player
    return tuple(b.reshape(-1))

def check_terminal_gomoku(raw_board_tuple, N, win_len=5):
    """Return (value_for_current_node_player, done). Full scan (O(N^2)), fine for 9x9."""
    b = np.array(raw_board_tuple, dtype=np.int8).reshape(N, N)

    def has_k(p):
        # rows
        for r in range(N):
            run = 0
            for c in range(N):
                run = run + 1 if b[r, c] == p else 0
                if run >= win_len: return True
        # cols
        for c in range(N):
            run = 0
            for r in range(N):
                run = run + 1 if b[r, c] == p else 0
                if run >= win_len: return True
        # diag \
        for r0 in range(N):
            run = 0
            r, c = r0, 0
            while r < N and c < N:
                run = run + 1 if b[r, c] == p else 0
                if run >= win_len: return True
                r += 1; c += 1
        for c0 in range(1, N):
            run = 0
            r, c = 0, c0
            while r < N and c < N:
                run = run + 1 if b[r, c] == p else 0
                if run >= win_len: return True
                r += 1; c += 1
        # diag /
        for r0 in range(N):
            run = 0
            r, c = r0, 0
            while r >= 0 and c < N:
                run = run + 1 if b[r, c] == p else 0
                if run >= win_len: return True
                r -= 1; c += 1
        for c0 in range(1, N):
            run = 0
            r, c = N-1, c0
            while r >= 0 and c < N:
                run = run + 1 if b[r, c] == p else 0
                if run >= win_len: return True
                r -= 1; c += 1
        return False

    if has_k(+1): return (+1.0, True)
    if has_k(-1): return (-1.0, True)
    if (b != 0).all(): return (0.0, True)
    return (0.0, False)

# Opening bias
def center_bias(N, sigma=0.35):
    """Return normalized N*N center-biased prior (Gaussian on [0,1]^2 grid)."""
    xs = np.linspace(0, 1, N)
    X, Y = np.meshgrid(xs, xs, indexing='ij')
    cx = cy = 0.5
    g = np.exp(-((X-cx)**2 + (Y-cy)**2) / (2*(sigma**2)))
    v = g.reshape(-1).astype(np.float32)
    v /= v.sum() + 1e-8
    return v

def mix_opening_bias(node, N, bias_vec, eps=0.25):
    """Blend bias_vec into child priors at the root (legal only)."""
    if eps is None or eps <= 0.0:
        return
    legal = list(node.children.keys())
    if not legal: return
    b = bias_vec.copy()
    # renorm on legal only
    s = b[legal].sum()
    b = b / (s + 1e-8)
    for a in legal:
        p = node.children[a].prior
        node.children[a].prior = (1 - eps) * p + eps * float(b[a])

In [4]:
# ---- Node ----
class Node:
    def __init__(self, raw_board_tuple, player_to_move):
        self.state = raw_board_tuple          # tuple of length N*N (0/±1)
        self.player = int(player_to_move)     # +1 or -1
        self.children = {}                    # action -> Node
        self.visit_count = 0
        self.total_value = 0.0                # mean value from this node's player perspective
        self.prior = 0.0

    def value(self):
        return 0.0 if self.visit_count == 0 else self.total_value / self.visit_count

# ---- MCTS ----
class MCTS:
    def __init__(self, model, simulations=800, c_puct=1.5, win_len=5):
        self.model = model
        self.simulations = simulations
        self.c_puct = c_puct
        self.win_len = win_len
        self.ttable = {}  # (state_tuple, player, N) -> (priors_np[N*N], value_float)

    # ----------------------------
    def run(self, env, add_root_noise=True, dir_alpha=0.3, dir_eps=0.25,
            opening_bias=None, opening_eps=0.25):
        N = env.N
        root = Node(tuple(env.board.reshape(-1)), env.current_player)
        if not root.children:
            _ = self._expand_and_evaluate(root, N)
    
        if add_root_noise:
            self._add_dirichlet_noise(root, N, alpha=dir_alpha, eps=dir_eps)
    
        if opening_bias is not None:
            mix_opening_bias(root, N, opening_bias, eps=opening_eps)
    
        for _ in range(self.simulations):
            self._simulate(root, N)
        return root

    # ----------------------------
    def _simulate(self, node, N):
        # Terminal test
        wval, done = check_terminal_gomoku(node.state, N, self.win_len)
        if done:
            v = float(wval)                      # value from node.player perspective
            node.visit_count += 1
            node.total_value += v
            return v

        # Leaf: not visited yet and no children → expand
        if node.visit_count == 0 and not node.children:
            v = self._expand_and_evaluate(node, N)
            node.visit_count += 1
            node.total_value += v
            return v

        # Select via PUCT
        best_a, best_score = None, -1e9
        sqrt_N = np.sqrt(node.visit_count + 1e-8)
        for a, child in node.children.items():
            ucb = child.value() + self.c_puct * child.prior * (sqrt_N / (1 + child.visit_count))
            if ucb > best_score:
                best_score, best_a = ucb, a

        child = node.children[best_a]
        v_child = self._simulate(child, N)
        v = -v_child                              # perspective switch
        node.visit_count += 1
        node.total_value += v
        return v

    # ----------------------------
    def _expand_and_evaluate(self, node, N):
        key = (node.state, node.player, N)
        if key in self.ttable:
            priors, v = self.ttable[key]
        else:
            planes = gomoku_planes_from_raw(node.state, node.player, N)
            st = torch.from_numpy(planes).unsqueeze(0)
            with torch.no_grad():
                logits, value = self.model(st)
                logits = logits.squeeze(0)        # [N*N]
                v = float(value.item())
            policy = F.softmax(logits, dim=-1).cpu().numpy()  # length N*N

            legal = legal_actions_from_raw(node.state, N)
            if len(legal) == 0:
                priors = np.zeros(N*N, dtype=np.float32)
            else:
                mask = np.full(N*N, -1e9, dtype=np.float32)
                mask[legal] = 0.0
                masked = np.log(policy + 1e-12) + mask  # numeric stable masking
                masked -= masked.max()
                priors = np.exp(masked)
                s = priors[legal].sum()
                priors = priors / (s + 1e-8)

            self.ttable[key] = (priors.astype(np.float32), v)

        # Create children
        for a in np.where(self.ttable[key][0] > 0)[0]:
            child_board = apply_action_raw(node.state, int(a), node.player, N)
            child = Node(child_board, -node.player)
            child.prior = float(self.ttable[key][0][a])
            node.children[int(a)] = child

        return v

    # ----------------------------
    def _add_dirichlet_noise(self, node, N, alpha=0.3, eps=0.25):
        if not node.children:
            return
        legal_actions = list(node.children.keys())
        if len(legal_actions) == 0:
            return
        noise = np.random.dirichlet([alpha] * len(legal_actions)).astype(np.float32)
        for a, n in zip(legal_actions, noise):
            p = node.children[a].prior
            node.children[a].prior = (1 - eps) * p + eps * float(n)

    # ----------------------------
    def select_action(self, root, temperature=1.0):
        N2 = int(np.sqrt(len(self.ttable.get((root.state, root.player, int(np.sqrt(len(root.state)))), (np.zeros(len(root.state)), 0.0))[0])) or len(root.state))
        # safer: compute over children
        visits = np.zeros(N2*N2, dtype=np.float64)
        legal  = np.zeros(N2*N2, dtype=bool)
        for a, child in root.children.items():
            visits[a] = child.visit_count
            legal[a] = True

        legal_idxs = np.where(legal)[0]
        if legal_idxs.size == 0:
            # no children? pick any legal by state
            legal_idxs = np.array(legal_actions_from_raw(root.state, N2), dtype=int)
            if legal_idxs.size == 0:
                return 0
            return int(np.random.choice(legal_idxs))

        if temperature == 0.0:
            vmax = visits[legal_idxs].max()
            ties = legal_idxs[visits[legal_idxs] == vmax]
            return int(np.random.choice(ties))

        x = visits[legal_idxs] ** (1.0 / temperature)
        p = x / (x.sum() + 1e-8)
        return int(np.random.choice(legal_idxs, p=p))

# 4. Self-Play with Curriculum Learning

In [5]:
import random, numpy as np, torch

# Reproducibility
random.seed(0); np.random.seed(0); torch.manual_seed(0)

# Build env/model
env = Gomoku(N=9, win_len=5)
model = PVResNet(board_size=env.N, channels=64, n_blocks=6)
model.eval()

mcts = MCTS(model, simulations=800, c_puct=1.5, win_len=env.win_len)

# One move:
root = mcts.run(env, add_root_noise=True)
a = mcts.select_action(root, temperature=1.0)   # temperature=1 in opening, 0 later
state, reward, done = env.step(a)

In [9]:
import torch, random, math, collections
import random
import numpy as np

# 8 symmetry transforms for NxN boards
def symmetries(planes):  # planes: [C,N,N] numpy
    ps = []
    for k in range(4):  # rotations
        r = np.rot90(planes, k=k, axes=(1,2))
        ps.append(r)
        ps.append(np.flip(r, axis=1))  # mirror
    return ps

class Replay:
    def __init__(self, capacity=100_000):
        self.buf = collections.deque(maxlen=capacity)
    def add(self, x): self.buf.append(x)
    def sample(self, bs): return random.sample(self.buf, bs)
    def __len__(self): return len(self.buf)

# precompute once
OPENING_BIAS = center_bias(N=env.N, sigma=0.35)

"""
# fixed ceiling for clarity
MAX_SIMS = 800

def sims_schedule(iter_idx, move_no, base=200, max_sims=MAX_SIMS, warmup_iters=20, ramp_moves=12):
    # ramp with training iteration (more search later)
    # ramp with move number (opening cheaper, mid/late richer)
    
    # ramp by iteration
    f_iter = min(1.0, iter_idx / float(warmup_iters))
    sims_iter = int(base + (max_sims - base) * f_iter)
    # ramp by move number
    f_move = min(1.0, move_no / float(ramp_moves))
    sims = int(base + (sims_iter - base) * f_move)
    return max(base, min(max_sims, sims))

def opening_eps_schedule(iter_idx, move_no, max_eps=0.40, stop_iter=30, stop_move=10):
    f_iter = max(0.0, 1.0 - iter_idx/float(stop_iter))
    f_move = max(0.0, 1.0 - move_no/float(stop_move))
    return max_eps * f_iter * f_move

def self_play_episode(env, model, mcts, temp_moves=8, iter_idx=1):
    env.reset()
    # list of (planes, pi, player)
    traj = []
    done = False
    move_no = 0
    while not done:
        # adjust sims on the fly
        # mcts.simulations = sims_schedule(iter_idx, move_no, base=200, max_sims=mcts.simulations)
        mcts.simulations = sims_schedule(iter_idx, move_no, base=200, max_sims=MAX_SIMS)
        eps_open = opening_eps_schedule(iter_idx, move_no)
        root = mcts.run(
            env, 
            add_root_noise=True, 
            dir_alpha=0.15, 
            dir_eps=0.25,
            opening_bias=OPENING_BIAS,
            opening_eps=eps_open
        )
        T = 1.0 if move_no < temp_moves else 0.0
        a = mcts.select_action(root, temperature=T)

        # build visit-count policy target π
        pi = np.zeros(env.N * env.N, dtype=np.float32)
        for aa, child in root.children.items():
            pi[aa] = child.visit_count
        pi /= (pi.sum() + 1e-8)

        planes = env.to_planes()
        traj.append((planes, pi, env.current_player))

        _, reward, done = env.step(a)
        move_no += 1

    # assign final outcome z to each step from that step's player perspective
    data = []
    for planes, pi, player in traj:
        # +1 win for player's perspective, 0 draw, -1 loss
        z = reward * player  
        # 8 sym augmentations
        for sp in symmetries(planes):
            data.append((sp.astype(np.float32), pi.copy(), float(z)))
        # list of (planes[3,N,N], pi[N*N], z)
    return data
"""

# ----- curriculum knobs -----
CURRICULUM = dict(
    max_sims=800,          # stronger ceiling later if you like (e.g., 1600)
    base_sims=200,         # very cheap opening search
    sims_warmup_iters=40,  # how many iterations to reach near max_sims
    sims_ramp_moves=15,    # ramp sims across the opening moves
    opening_max_eps=0.40,  # strength of center prior mix at start
    opening_stop_iter=30,  # how fast to fade opening bias across iters
    opening_stop_move=10,  # ... and across moves
    temp_opening_moves=8   # use T=1 for first K moves, then T=0
)

def sims_schedule(iter_idx, move_no, cfg=CURRICULUM):
    base, maxs = cfg['base_sims'], cfg['max_sims']
    f_iter = min(1.0, iter_idx / float(cfg['sims_warmup_iters']))
    sims_iter = int(base + (maxs - base) * f_iter)
    f_move = min(1.0, move_no / float(cfg['sims_ramp_moves']))
    sims = int(base + (sims_iter - base) * f_move)
    return max(base, min(maxs, sims))

def opening_eps_schedule(iter_idx, move_no, cfg=CURRICULUM):
    max_eps = cfg['opening_max_eps']
    f_iter = max(0.0, 1.0 - iter_idx/float(cfg['opening_stop_iter']))
    f_move = max(0.0, 1.0 - move_no/float(cfg['opening_stop_move']))
    return max_eps * f_iter * f_move

def temperature_for_move(move_no, cfg=CURRICULUM):
    return 1.0 if move_no < cfg['temp_opening_moves'] else 0.0

def self_play_episode(env, model, mcts, iter_idx=1, cfg=CURRICULUM):
    env.reset()
    traj, done, move_no = [], False, 0
    while not done:
        mcts.simulations = sims_schedule(iter_idx, move_no, cfg)
        eps_open = opening_eps_schedule(iter_idx, move_no, cfg)
        root = mcts.run(
            env,
            add_root_noise=True,
            dir_alpha=0.15, dir_eps=0.25,
            opening_bias=center_bias(env.N, sigma=0.35),
            opening_eps=eps_open
        )
        T = temperature_for_move(move_no, cfg)
        a = mcts.select_action(root, temperature=T)

        pi = np.zeros(env.N * env.N, dtype=np.float32)
        for aa, child in root.children.items(): pi[aa] = child.visit_count
        pi /= (pi.sum() + 1e-8)

        traj.append((env.to_planes(), pi, env.current_player))
        _, reward, done = env.step(a)
        move_no += 1

    data = []
    for planes, pi, player in traj:
        z = reward * player
        for sp in symmetries(planes):
            data.append((sp.astype(np.float32), pi.copy(), float(z)))
    return data

def train_step(model, optimizer, batch, N, weight_decay=1e-4):
    # batch: list of tuples
    planes = torch.from_numpy(np.stack([b[0] for b in batch]))    # [B,3,N,N]
    pi      = torch.from_numpy(np.stack([b[1] for b in batch]))   # [B,N*N]
    z       = torch.tensor([[b[2]] for b in batch], dtype=torch.float32)

    logits, value = model(planes)
    policy_loss = -(pi * F.log_softmax(logits, dim=1)).sum(dim=1).mean()
    value_loss  = F.mse_loss(value, z)
    loss = policy_loss + value_loss

    optimizer.zero_grad()
    loss.backward()
    # decoupled weight decay (AdamW style)
    optimizer.step()
    return float(loss.item()), float(policy_loss.item()), float(value_loss.item())

def random_action(env):
    return random.choice(env.get_legal_actions())

def evaluate_vs_random(model, sims=800, games=50, N=9, win_len=5):
    wins = draws = losses = 0
    for g in range(games):
        env = Gomoku(N=N, win_len=win_len)
        mcts = MCTS(model, simulations=sims, c_puct=1.5, win_len=win_len)
        # alternate who starts
        human_like = (+1 if g % 2 == 0 else -1)  # model plays both colors
        env.current_player = human_like

        done = False
        while not done:
            if env.current_player == human_like:
                root = mcts.run(env, add_root_noise=False)  # no root noise in eval
                a = mcts.select_action(root, temperature=0.0)
            else:
                a = random_action(env)
            _, reward, done = env.step(a)

        # reward is from the perspective of the player who moved last
        # winner = -env.current_player  (player who just moved)
        if reward == 0:
            draws += 1
        else:
            winner = -env.current_player
            if winner == human_like:
                wins += 1
            else:
                losses += 1
    print(f"[Eval vs random] W {wins}  D {draws}  L {losses}  (games={games})")
    return wins, draws, losses

# ------------- main training loop (sketch) -------------
def train_gomoku(num_iters=200, episodes_per_iter=20, sims=800, batch_size=256):
    env = Gomoku(N=9, win_len=5)
    model = PVResNet(board_size=env.N, channels=64, n_blocks=6)
    optimizer = torch.optim.AdamW(model.parameters(), lr=1e-3, weight_decay=1e-4)
    sched = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=num_iters)
    mcts = MCTS(model, simulations=sims, c_puct=1.5, win_len=env.win_len)
    replay = Replay(100_000)
    
    for it in range(1, num_iters+1):
        # Self-play
        for _ in range(episodes_per_iter):
            # data = self_play_episode(env, model, mcts, temp_moves=8, iter_idx=it)
            data = self_play_episode(env, model, mcts, iter_idx=it, cfg=CURRICULUM)
            for item in data:
                replay.add(item)

        # Train
        losses = []
        for _ in range(min(400, len(replay)//batch_size) ):
            batch = replay.sample(batch_size)
            l, pl, vl = train_step(model, optimizer, batch, env.N)
            losses.append(l)
        sched.step()

        print(f"Iter {it:03d} | replay {len(replay)} | sims {sims} | lr {sched.get_last_lr()[0]:.4e} | loss {np.mean(losses) if losses else 0:.3f}")

        if it % 5 == 0:
            evaluate_vs_random(model, sims=400, games=20, N=env.N, win_len=env.win_len)

        if it % 10 == 0:
            torch.save(model.state_dict(), f"models/gomoku_alpha_zero_iter{it:03d}.pth")

    torch.save(model.state_dict(), "models/gomoku_alpha_zero.pth")
    return model

In [10]:
train_gomoku()

Iter 001 | replay 656 | sims 800 | lr 0.0000e+00 | loss 5.545


PVResNet(
  (stem): Sequential(
    (0): Conv2d(3, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
    (1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (2): ReLU(inplace=True)
  )
  (trunk): Sequential(
    (0): ResidualBlock(
      (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    )
    (1): ResidualBlock(
      (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn2): BatchNorm2d(64, eps=

# 5. Strength Ladder

In [13]:
# Strength ladder
def eval_vs_mcts(model, opp_sims, games=40, N=9, win_len=5):
    wins = draws = losses = 0
    for g in range(games):
        env = Gomoku(N=N, win_len=win_len)
        # alternate who starts
        env.current_player = +1 if (g % 2 == 0) else -1
        my_sims   = max(opp_sims, 800)  # or fix both; we’ll compare across opp_sims
        my_mcts   = MCTS(model, simulations=my_sims, c_puct=1.5, win_len=win_len)
        opp_mcts  = MCTS(model, simulations=opp_sims, c_puct=1.5, win_len=win_len)

        done = False
        while not done:
            if env.current_player == +1:   # “our” side
                root = my_mcts.run(env, add_root_noise=False)
                a = my_mcts.select_action(root, temperature=0.0)
            else:                           # opponent strength
                root = opp_mcts.run(env, add_root_noise=False)
                a = opp_mcts.select_action(root, temperature=0.0)
            _, reward, done = env.step(a)

        if reward == 0: draws += 1
        else:
            winner = -env.current_player
            if winner == +1: wins += 1
            else: losses += 1
    return wins, draws, losses

def strength_ladder(model, N=9, win_len=5):
    print("[Ladder] vs random")
    evaluate_vs_random(model, sims=400, games=50, N=N, win_len=win_len)

    print("\n[Ladder] symmetric sims")
    for sims in [100, 200, 400, 800, 1200]:
        w,d,l = eval_vs_mcts(model, opp_sims=sims, games=50, N=N, win_len=win_len)  # set my_sims=opp_sims inside if you want symmetric
        print(f"MCTS({sims})  W {w}  D {d}  L {l}")

    print("\n[Ladder] advantage (ours ≥800 sims)")
    for sims in [100, 200, 400, 800, 1200]:
        w,d,l = eval_vs_mcts(model, opp_sims=sims, games=50, N=N, win_len=win_len)  # keep your current my_sims=max(opp_sims,800)
        print(f"opp MCTS({sims})  W {w}  D {d}  L {l}")

In [14]:
model.eval()
strength_ladder(model, N=9, win_len=5)

[Ladder] vs random
[Eval vs random] W 28  D 1  L 21  (games=50)
[Ladder] vs MCTS(100)  W 50  D 0  L 0
[Ladder] vs MCTS(200)  W 50  D 0  L 0


KeyboardInterrupt: 

# 6. Interative Games

In [11]:
# pip install ipywidgets
import ipywidgets as W
from IPython.display import display

class GomokuUI:
    def __init__(self, model, N=9, win_len=5, sims=800, human_player=-1):
        self.model = model
        self.N = N
        self.win_len = win_len
        self.human_player = int(human_player)   # +1 black (X), -1 white (O)
        self.sims = sims

        self.env = Gomoku(N=N, win_len=win_len)
        self.buttons = [[W.Button(description=" ", layout=W.Layout(width="40px", height="40px")) 
                         for _ in range(N)] for _ in range(N)]
        for r in range(N):
            for c in range(N):
                self.buttons[r][c].on_click(self._make_cb(r, c))

        self.msg = W.HTML(value="")
        self.reset_btn = W.Button(description="Reset")
        self.reset_btn.on_click(self.reset)

        grid = W.GridBox(
            children=[self.buttons[r][c] for r in range(N) for c in range(N)],
            layout=W.Layout(grid_template_columns=f"repeat({N}, 42px)")
        )
        display(W.VBox([grid, W.HBox([self.reset_btn]), self.msg]))

        # AI may start
        if self.human_player == -1 and self.env.current_player == +1:
            self.ai_move()
        self.refresh()

    def _make_cb(self, r, c):
        def _cb(btn):
            if self.env.current_player != self.human_player:
                return
            a = r * self.N + c
            if a not in self.env.get_legal_actions():
                return
            _, _, done = self.env.step(a)
            if not done:
                self.ai_move()
            self.refresh()
        return _cb

    def ai_move(self):
        mcts = MCTS(self.model, simulations=self.sims, c_puct=1.5, win_len=self.win_len)
        root = mcts.run(self.env, add_root_noise=False)  # no noise for a strong play
        a = mcts.select_action(root, temperature=0.0)
        self.env.step(a)

    def refresh(self):
        sym = {1:"X", -1:"O", 0:" "}
        for r in range(self.N):
            for c in range(self.N):
                v = int(self.env.board[r, c])
                b = self.buttons[r][c]
                b.description = sym[v]
                b.disabled = bool(v != 0)

        reward, done = self.env.check_winner_fast()
        if done:
            if reward == 0:
                self.msg.value = "<b>Draw.</b> Click Reset."
            else:
                winner = 'X' if -self.env.current_player == +1 else 'O'
                self.msg.value = f"<b>{winner} wins!</b> Click Reset."
            for r in range(self.N):
                for c in range(self.N):
                    self.buttons[r][c].disabled = True
        else:
            turn = 'X' if self.env.current_player==+1 else 'O'
            self.msg.value = f"Player {turn} to move."

    def reset(self, _=None):
        self.env.reset()
        if self.human_player == -1 and self.env.current_player == +1:
            self.ai_move()
        self.refresh()

In [12]:
# Example usage:
model.eval()
ui = GomokuUI(model, N=9, win_len=5, sims=600, human_player=-1)  # human plays O

VBox(children=(GridBox(children=(Button(description=' ', layout=Layout(height='40px', width='40px'), style=But…