# 1. Gomoku

In [None]:
import numpy as np

class Gomoku:
    """
    Gomoku on an N×N board.

    Cells hold +1 (black), -1 (white) or 0 (empty); black moves first.
    A player wins by forming a contiguous line of at least `win_len`
    (default 5) of their own stones in a row, column or diagonal. A full
    board without such a line is a draw.
    """

    # One representative per axis; the opposite direction is obtained by negation.
    _AXES = ((1, 0), (0, 1), (1, 1), (1, -1))

    def __init__(self, N=9, win_len=5):
        self.N = int(N)
        self.win_len = int(win_len)
        self.reset()

    def reset(self):
        """Clear the board, give the move to black (+1) and return `get_state()`."""
        self.board = np.zeros((self.N, self.N), dtype=np.int8)
        self.current_player = 1
        self.done = False
        self.last_move = None  # (r, c) of the most recent stone, or None
        return self.get_state()

    # --- Canonical features for NN (3 planes): me, opp, to-move ---
    def to_planes(self):
        """
        Return float32 features of shape [3, N, N] for the player to move:
        plane 0 = own stones, plane 1 = opponent stones, plane 2 = all ones.
        """
        me = (self.board == self.current_player).astype(np.float32)
        opp = (self.board == -self.current_player).astype(np.float32)
        turn = np.ones_like(me)  # constant "to move" plane
        return np.stack([me, opp, turn], axis=0)

    # For drop-in compatibility with older code:
    def get_state(self):
        """
        Return the flattened board multiplied by `current_player` as float32
        (length N*N), so the player to move always sees their stones as +1.
        """
        return (self.board.flatten() * self.current_player).astype(np.float32)

    def get_legal_actions(self):
        """Return the flat indices (0..N*N-1, Python ints) of all empty cells."""
        return np.flatnonzero(self.board.reshape(-1) == 0).tolist()

    def step(self, action):
        """
        Apply action index (0..N^2-1). Returns (state, reward, done).

        * If the game is already over, nothing changes and
          (state, 0.0, True) is returned.
        * Playing on an occupied cell ends the game with reward -1.0
          (safeguard; AlphaZero-style callers normally mask illegal moves).
        * Otherwise the stone is placed and `reward` is +1.0 if the mover
          just won, else 0.0. `state` is `get_state()` taken *before* the
          turn passes, i.e. from the mover's perspective; `current_player`
          is flipped afterwards.

        Raises:
            IndexError: if `action` is outside 0..N*N-1. Negative indices
                are rejected explicitly because they would otherwise wrap
                around and corrupt `last_move` / win detection.
        """
        if self.done:
            return self.get_state(), 0.0, True

        action = int(action)
        if not 0 <= action < self.N * self.N:
            raise IndexError(
                f"action {action} is out of range for a {self.N}x{self.N} board"
            )

        r, c = divmod(action, self.N)
        if self.board[r, c] != 0:
            self.done = True
            return self.get_state(), -1.0, True

        self.board[r, c] = self.current_player
        self.last_move = (r, c)

        reward, done = self.check_winner_fast()
        self.done = done
        state = self.get_state()
        self.current_player *= -1
        return state, float(reward), bool(done)

    # --------- winner checks ----------
    def check_winner_fast(self):
        """
        Check only the lines passing through self.last_move for speed.

        Returns (reward_for_player_who_just_moved, done):
        (1.0, True) if that move completed a line of at least `win_len`,
        (0.0, True) if the board is full (draw), otherwise (0.0, False).
        """
        if self.last_move is None:
            return 0.0, False
        r, c = self.last_move
        p = self.board[r, c]
        if p == 0:
            return 0.0, False

        for dr, dc in self._AXES:
            # Both counts include (r, c) itself, hence the -1.
            run = (self._count_dir(r, c, dr, dc, p)
                   + self._count_dir(r, c, -dr, -dc, p) - 1)
            if run >= self.win_len:
                return 1.0, True

        if (self.board != 0).all():
            return 0.0, True  # draw

        return 0.0, False

    def _count_dir(self, r, c, dr, dc, p):
        """Count contiguous stones with color p starting at (r,c) inclusive along (dr,dc)."""
        N = self.N
        cnt = 0
        rr, cc = r, c
        while 0 <= rr < N and 0 <= cc < N and self.board[rr, cc] == p:
            cnt += 1
            rr += dr
            cc += dc
        return cnt

    # --------- convenience ----------
    def render_ascii(self):
        """Print the board with X for +1, O for -1 and . for empty."""
        sym = {1: 'X', -1: 'O', 0: '.'}
        print("\n".join(" ".join(sym[v] for v in row) for row in self.board))

# 2. ResNet Policy-Value Model

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F

class ResidualBlock(nn.Module):
    def __init__(self, channels):
        super().__init__()
        self.conv1 = nn.Conv2d(channels, channels, kernel_size=3, padding=1, bias=False)
        self.bn1   = nn.BatchNorm2d(channels)
        self.conv2 = nn.Conv2d(channels, channels, kernel_size=3, padding=1, bias=False)
        self.bn2   = nn.BatchNorm2d(channels)

    def forward(self, x):
        y = F.relu(self.bn1(self.conv1(x)))
        y = self.bn2(self.conv2(y))
        return F.relu(x + y)

class PVResNet(nn.Module):
    """
    Residual policy–value network for Gomoku.

    Input:  [B, 3, N, N] feature planes.
    Outputs:
      - policy_logits: [B, N*N]  (unnormalised, one logit per board cell)
      - value:         [B, 1]    (squashed to [-1, 1] by tanh)
    """

    @staticmethod
    def _conv_bn_relu(in_ch, out_ch, kernel_size, padding=0):
        """Build the Conv2d (no bias) -> BatchNorm2d -> ReLU triple used by stem and heads."""
        return nn.Sequential(
            nn.Conv2d(in_ch, out_ch, kernel_size=kernel_size, padding=padding, bias=False),
            nn.BatchNorm2d(out_ch),
            nn.ReLU(inplace=True),
        )

    def __init__(self, board_size=9, channels=64, n_blocks=6):
        super().__init__()
        self.N = board_size
        cells = self.N * self.N

        # Shared feature extractor: 3×3 stem followed by the residual tower.
        self.stem = self._conv_bn_relu(3, channels, kernel_size=3, padding=1)
        self.trunk = nn.Sequential(*(ResidualBlock(channels) for _ in range(n_blocks)))

        # Policy head: 1×1 conv down to 2 planes, then a linear map to N*N logits.
        self.p_head = self._conv_bn_relu(channels, 2, kernel_size=1)
        self.p_fc = nn.Linear(2 * cells, cells)

        # Value head: 1×1 conv down to 1 plane, then a two-layer MLP to a scalar.
        self.v_head = self._conv_bn_relu(channels, 1, kernel_size=1)
        self.v_fc1 = nn.Linear(1 * cells, channels)
        self.v_fc2 = nn.Linear(channels, 1)

        # Kaiming-normal weights for every conv / linear layer; linear biases start at zero.
        for module in self.modules():
            if isinstance(module, (nn.Conv2d, nn.Linear)):
                nn.init.kaiming_normal_(module.weight, nonlinearity='relu')
            if isinstance(module, nn.Linear):
                nn.init.zeros_(module.bias)

    def forward(self, x):  # x: [B, 3, N, N]
        """Run the shared trunk and both heads; returns (policy_logits, value)."""
        features = self.trunk(self.stem(x))
        batch = x.size(0)

        # policy branch -> [B, N*N]
        policy_feat = self.p_head(features).view(batch, -1)
        policy_logits = self.p_fc(policy_feat)

        # value branch -> [B, 1] in [-1, 1]
        value_feat = self.v_head(features).view(batch, -1)
        hidden = F.relu(self.v_fc1(value_feat))
        value = torch.tanh(self.v_fc2(hidden))

        return policy_logits, value

# 3. Monte Carlo Tree Search (MCTS)

In [None]:
# ---- Node ----
class Node:
    """
    One position in the search tree.

    Stores the raw board (0/±1, flattened to a tuple), the player to move,
    the expanded children keyed by action, and the visit statistics.
    `total_value` is accumulated from the perspective of `self.player`.
    """

    def __init__(self, raw_board_tuple, player_to_move):
        self.state, self.player = raw_board_tuple, player_to_move
        self.children = {}  # action -> Node
        self.visit_count = 0
        self.total_value = 0.0
        self.prior = 0.0

    def value(self):
        """Mean backed-up value, or 0.0 for a node that has never been visited."""
        if self.visit_count == 0:
            return 0.0
        return self.total_value / self.visit_count

# ---- MCTS ----
class MCTS:
    """
    PUCT Monte Carlo Tree Search over raw Gomoku boards, guided by a
    policy–value model.

    Tree nodes hold the raw board (0/±1) flattened to a tuple plus the player
    to move. The board side N is inferred from the tuple length, so any square
    board works as long as the model emits N*N policy logits.

    Args:
        model: callable mapping a float32 tensor [1, 3, N, N] to
            (policy_logits [1, N*N], value [1, 1]).
        simulations: number of simulations per `run`.
        c_puct: exploration constant of the PUCT rule.
        win_len: stones in a line needed to win; must match the environment
            whose positions are searched (default 5).

    Raises:
        ValueError: if `win_len` is smaller than 1.
    """

    def __init__(self, model, simulations=50, c_puct=1.0, win_len=5):
        win_len = int(win_len)
        if win_len < 1:
            raise ValueError("win_len must be >= 1")
        self.model = model
        self.simulations = simulations
        self.c_puct = c_puct
        self.win_len = win_len

    def run(self, env):
        """Search from `env`'s current position and return the root node."""
        # store RAW board in the root; player stored separately
        root = Node(tuple(env.board.flatten()), env.current_player)
        for _ in range(self.simulations):
            self.simulate(root)
        return root

    def simulate(self, node):
        """
        Run one simulation from `node`, updating visit statistics along the
        path. Returns the backed-up value from `node.player`'s perspective.
        """
        # Terminal: check_terminal reports the winner's colour, so multiply by
        # node.player to express it from the perspective of the player to move.
        winner, done = self.check_terminal(node.state)
        if done:
            v = float(winner * node.player)
            node.visit_count += 1
            node.total_value += v
            return v

        # Unexpanded leaf: evaluate with the network and create the children.
        if not node.children:
            v = self._expand(node)
            node.visit_count += 1
            node.total_value += v
            return v

        # Select child by PUCT. child.value() is stored from the child's
        # (opponent's) perspective, so it is negated for the parent.
        sqrt_parent = np.sqrt(node.visit_count + 1e-8)
        best_score, best_a = -np.inf, None
        for a, child in node.children.items():
            exploit = -child.value()
            explore = self.c_puct * child.prior * sqrt_parent / (1 + child.visit_count)
            score = exploit + explore
            if score > best_score:
                best_score, best_a = score, a

        # Recurse, switching perspective on backup.
        v = -self.simulate(node.children[best_a])
        node.visit_count += 1
        node.total_value += v
        return v

    def _expand(self, node):
        """Evaluate `node` with the model, attach one child per legal move, return the value."""
        legal = self.legal_actions_from_raw(node.state)
        planes = self._planes_from_raw(node.state, node.player)       # [3, N, N] float32
        state_tensor = torch.from_numpy(planes).unsqueeze(0)          # [1, 3, N, N]
        device = self._model_device()
        if device is not None:
            state_tensor = state_tensor.to(device)

        # Inference only: no autograd graph is needed. This class does not
        # change the model's train/eval mode; callers should use model.eval().
        with torch.no_grad():
            logits, value = self.model(state_tensor)
        logits = logits.detach().cpu().numpy().reshape(-1).astype(np.float64)
        if logits.size != len(node.state):
            raise ValueError(
                f"model produced {logits.size} policy logits for a board "
                f"with {len(node.state)} cells"
            )

        # Softmax restricted to legal moves (numerically stabilised by the max).
        legal_logits = logits[legal]
        weights = np.exp(legal_logits - legal_logits.max())
        priors = weights / weights.sum()

        for a, prior in zip(legal, priors):
            child = Node(self.apply_action_raw(node.state, a, node.player), -node.player)
            child.prior = float(prior)
            node.children[a] = child

        return float(value.item())  # value is from node.player perspective

    def _model_device(self):
        """Return the device of the model's first parameter, or None if it has none."""
        params = getattr(self.model, "parameters", None)
        if not callable(params):
            return None
        first = next(iter(params()), None)
        return None if first is None else first.device

    # --- helpers on RAW board (0/±1) ---
    @staticmethod
    def _as_board(raw_board_tuple):
        """Reshape a flat raw board to [N, N]; raises ValueError if it is not square."""
        flat = np.asarray(raw_board_tuple)
        n = int(round(flat.size ** 0.5))
        if n * n != flat.size:
            raise ValueError(f"board of length {flat.size} is not square")
        return flat.reshape(n, n)

    def _planes_from_raw(self, raw_board_tuple, player):
        """
        Build the [3, N, N] float32 input (own stones, opponent stones, ones)
        for `player`, the same layout as `Gomoku.to_planes`.
        NOTE(review): replaces the call to `gomoku_planes_from_raw`, which is
        not visible here - confirm that helper used the same encoding.
        """
        board = self._as_board(raw_board_tuple)
        me = (board == player).astype(np.float32)
        opp = (board == -player).astype(np.float32)
        return np.stack([me, opp, np.ones_like(me)], axis=0)

    @staticmethod
    def _has_line(mask, length):
        """True if boolean [N, N] `mask` contains `length` contiguous True cells in any direction."""
        n = mask.shape[0]
        if length > n:
            return False
        span = n - length + 1  # number of valid start offsets along one axis
        steps = range(length)
        windows = (
            [mask[:, k:k + span] for k in steps],                                # rows
            [mask[k:k + span, :] for k in steps],                                # columns
            [mask[k:k + span, k:k + span] for k in steps],                       # diagonals
            [mask[k:k + span, length - 1 - k:length - 1 - k + span] for k in steps],  # anti-diagonals
        )
        return any(np.logical_and.reduce(w).any() for w in windows)

    def legal_actions_from_raw(self, raw_board_tuple):
        """Return the flat indices (Python ints) of the empty cells."""
        return np.flatnonzero(np.asarray(raw_board_tuple) == 0).tolist()

    def apply_action_raw(self, raw_board_tuple, action, player):
        """
        Return a new board tuple with `player`'s stone at flat index `action`.

        Raises:
            IndexError: if `action` is outside the board.
        """
        cells = list(raw_board_tuple)
        action = int(action)
        if not 0 <= action < len(cells):
            raise IndexError(f"action {action} is out of range for {len(cells)} cells")
        cells[action] = player
        return tuple(cells)

    def check_terminal(self, raw_board_tuple):
        """
        Scan the whole board for a finished game.

        Returns (winner, done) where `winner` is the colour (+1 / -1) owning a
        line of at least `win_len` stones, or 0 for a draw / unfinished game.
        The result is colour-based, not relative to the player to move.
        """
        board = self._as_board(raw_board_tuple)
        for colour in (1, -1):
            if self._has_line(board == colour, self.win_len):
                return (colour, True)
        if (board != 0).all():
            return (0, True)  # draw
        return (0, False)

    def select_action(self, root, temperature=1.0):
        """
        Pick a move from the root's visit counts.

        temperature == 0 picks the most visited child (lowest action on ties);
        otherwise a child is sampled with probability proportional to
        visits ** (1 / temperature). If no child has been visited, a child is
        chosen uniformly at random.

        Raises:
            ValueError: if the root has no children (e.g. a terminal position).
        """
        if not root.children:
            raise ValueError("cannot select an action: root has no children")

        actions = np.array(list(root.children.keys()), dtype=np.int64)
        visits = np.array(
            [child.visit_count for child in root.children.values()], dtype=np.float64
        )

        if visits.sum() <= 0:
            return int(np.random.choice(actions))
        if temperature == 0:
            return int(actions[np.argmax(visits)])

        # Dividing by the max first keeps the power from overflowing at small temperatures.
        scaled = (visits / visits.max()) ** (1.0 / temperature)
        probs = scaled / scaled.sum()
        return int(np.random.choice(actions, p=probs))