In [1]:
!pip install tf_agents
!pip install "tensorflow-probability==0.24.0"
!pip install tf-keras

Collecting tf_agents
  Downloading tf_agents-0.19.0-py3-none-any.whl.metadata (12 kB)
Collecting gym<=0.23.0,>=0.17.0 (from tf_agents)
  Downloading gym-0.23.0.tar.gz (624 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m624.4/624.4 kB[0m [31m9.0 MB/s[0m eta [36m0:00:00[0m
[?25h  Installing build dependencies ... [?25l[?25hdone
  Getting requirements to build wheel ... [?25l[?25hdone
  Preparing metadata (pyproject.toml) ... [?25l[?25hdone
Collecting typing-extensions==4.5.0 (from tf_agents)
  Downloading typing_extensions-4.5.0-py3-none-any.whl.metadata (8.5 kB)
Collecting pygame==2.1.3 (from tf_agents)
  Downloading pygame-2.1.3-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (9.3 kB)
Collecting tensorflow-probability~=0.23.0 (from tf_agents)
  Downloading tensorflow_probability-0.23.0-py2.py3-none-any.whl.metadata (13 kB)
Downloading tf_agents-0.19.0-py3-none-any.whl (1.4 MB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m 

In [2]:
#!/usr/bin/env python3
"""
PPO training for Super Tic-Tac-Toe with:
- legal action masking
- reward shaping
- parallel environments
- mixed opponent policies
"""

import random, time, pathlib, math
import numpy as np, matplotlib.pyplot as plt, tqdm.auto as tqdm
import tensorflow as tf, tensorflow_probability as tfp
from tf_agents.environments  import py_environment, tf_py_environment, parallel_py_environment
from tf_agents.specs         import array_spec
from tf_agents.trajectories  import time_step as ts, policy_step, from_transition
from tf_agents.networks      import network
from tf_agents.agents.ppo    import ppo_agent
from tf_agents.replay_buffers import tf_uniform_replay_buffer
from tf_agents.policies      import py_policy
from tf_agents.system        import multiprocessing as tf_mp
# Best-effort switch of tf_agents multiprocessing into interactive (notebook) mode.
# NOTE(review): the ValueError is presumably raised when this cell is re-run and
# multiprocessing has already been initialised in the kernel — confirm against
# tf_agents.system.multiprocessing before relying on it.
try:
    tf_mp.enable_interactive_mode()
except ValueError:
    pass
from tf_agents.policies import greedy_policy

# ─── GPU configuration ───────────────────────────────────────────────
# Memory growth must be configured before TensorFlow initialises the GPUs.
# When this cell is re-run in a live kernel (after any tensor was created),
# set_memory_growth raises RuntimeError; report it instead of aborting the cell.
gpus = tf.config.list_physical_devices('GPU')
if gpus:
    try:
        for g in gpus:
            tf.config.experimental.set_memory_growth(g, True)
    except RuntimeError as err:
        print(f"GPU memory-growth left unchanged: {err}")
    else:
        print(f"{len(gpus)} GPU(s) memory-growth enabled")

# ═════════════ Board geometry and constants ══════════════════════════════
BOARD = 12
EMPTY, X, O = 0, 1, -1

# The playable area is a plus sign on the 12×12 grid: a 4-row horizontal band
# crossed with a 4-column vertical band, both starting at offset C0.
C0 = BOARD // 2 - 2
CROSS = np.zeros((BOARD, BOARD), bool)
CROSS[C0:C0+4, :] = True
CROSS[:, C0:C0+4] = True

# Action index -> (row, col): playable cells listed in row-major order
COORDS    = [(r, c) for r, c in np.argwhere(CROSS).tolist()]
IDX2COORD = np.array(COORDS, int)
N_CELLS   = len(COORDS)

# The eight surrounding offsets, used when a move drifts to a neighbor
NEIGH8  = [(dr, dc) for dr in (-1, 0, 1) for dc in (-1, 0, 1) if (dr, dc) != (0, 0)]
# Row-major flat position (row * BOARD + col) of every playable cell, in action
# order; legal_mask gathers these entries from a flattened board.
IDX_LIN = tf.constant([BOARD * row + col for row, col in COORDS], dtype=tf.int32)

def in_bounds(r, c):
    """
    Return True when (r, c) lies inside the BOARD x BOARD grid.

    Only the array bounds are checked. Membership in the playable cross is
    NOT tested here; callers that need it combine this with CROSS[r, c].
    """
    return 0 <= r < BOARD and 0 <= c < BOARD

def _count_line(board, r, c, p):
    """
    Measure the runs of player p that pass through cell (r, c).

    For each axis (horizontal, vertical, main diagonal, anti-diagonal) the
    cell itself counts as 1 and the run is extended in both directions while
    the next cell is on the board, inside CROSS and holds p.
    Returns the four run lengths as a list, in that axis order.
    """
    def reach(step_r, step_c):
        # Consecutive p-cells found walking away from (r, c) in one direction.
        steps = 0
        rr, cc = r + step_r, c + step_c
        while in_bounds(rr, cc) and CROSS[rr, cc] and board[rr, cc] == p:
            steps += 1
            rr, cc = rr + step_r, cc + step_c
        return steps

    axes = ((0, 1), (1, 0), (1, 1), (-1, 1))
    return [1 + reach(dr, dc) + reach(-dr, -dc) for dr, dc in axes]

# ═════════════ Reward shaping constants ══════════════════════════════
SCALE       = 1.0    # multiplier for positional, pass and final rewards (R_ILLEGAL is returned unscaled)
R_LEGAL     = 0.5    # base reward for a legal move
B_TWO       = 0.5    # bonus when the longest own line through the new piece is exactly 2
B_THREE     = 0.8    # bonus when the longest own line through the new piece is 3 or more
R_ILLEGAL   = -1.0   # penalty when the agent targets an occupied cell
R_PASS_EDGE = -1.0   # penalty when the move drifts and no empty playable neighbor exists
P_NO_SPACE  = -1.0   # penalty if the placed piece has no empty playable neighbor
P_IGN_THR   = -0.7   # penalty when _opponent_threat() is true after the agent's placement

# ═════════════ Environment definition ════════════════════════════════
class SuperTicTacToe(py_environment.PyEnvironment):
    """
    Custom PyEnvironment for the Super Tic-Tac-Toe game.

    One call to step() plays the agent's move and then one reply from the
    built-in opponent. Every placement is noisy (see _apply_move). A player
    wins with 4 in a row horizontally/vertically or 5 in a row diagonally;
    the episode is a draw when the playable cross is full.
    """
    def __init__(self, opponent="random"):
        """
        Initialize environment specs, opponent policies, and state.
        opponent: "random", "rule", or "mixed" ("mixed" uses the rule-based
        policy with probability 0.3 on each move, the random one otherwise).
        Any other value falls back to the random policy.
        """
        # Initialise the tf_agents base class before setting up our own state.
        super().__init__()
        self._obs_spec = array_spec.BoundedArraySpec((BOARD, BOARD, 3), np.float32, 0, 1)
        self._act_spec = array_spec.BoundedArraySpec((), np.int32, 0, N_CELLS - 1)
        self._opp_type = opponent
        self._rnd  = RandomPolicy(self)
        self._rule = RuleBasedPolicy(self)
        self._state = None
        self._done  = False
        self._turn  = X

    def observation_spec(self):
        """Return the observation spec."""
        return self._obs_spec

    def action_spec(self):
        """Return the action spec."""
        return self._act_spec

    def _reset(self):
        """Reset the board state and turn. Return initial TimeStep."""
        self._state = np.zeros((BOARD, BOARD), np.int8)
        self._turn  = X
        self._done  = False
        return ts.restart(self._obs())

    def _step(self, action):
        """
        Apply agent move, check terminal, then apply opponent move,
        and return the appropriate TimeStep.

        After an episode has terminated, the next call ignores `action` and
        starts a new episode instead.
        """
        if self._done:
            return self.reset()

        # Agent's move
        reward = self._apply_move(action, self._turn, agent_move=True)
        done, winner = self._check_terminal()
        if done:
            return self._end_episode(winner)

        # Opponent's reply
        opp_act = self._select_opponent().action(self._fake()).action
        self._apply_move(int(opp_act), self._turn, agent_move=False)
        done, winner = self._check_terminal()
        if done:
            return self._end_episode(winner)

        # Non-terminal: the agent receives the shaped reward of its own move
        return ts.transition(self._obs(), reward, discount=1.0)

    def _select_opponent(self):
        """Return the policy object that plays the opponent's next move."""
        if self._opp_type == "mixed":
            return self._rule if random.random() < 0.3 else self._rnd
        return self._rule if self._opp_type == "rule" else self._rnd

    def _end_episode(self, winner):
        """
        Mark the episode as finished and build its termination TimeStep.
        Setting _done is what makes the following _step() call reset the game.
        """
        self._done = True
        return ts.termination(self._obs(), self._final_reward(winner))

    def _is_free(self, r, c):
        """True when (r, c) is on the board, inside CROSS and currently empty."""
        return bool(in_bounds(r, c) and CROSS[r, c] and self._state[r, c] == EMPTY)

    def _apply_move(self, idx, player, agent_move=True):
        """
        Attempt to place piece for player at IDX2COORD[idx].
        With 50% chance place at target, else random empty neighbor.
        The turn always passes to the other player, whatever the outcome.
        Returns shaped reward or penalty (always 0.0 for opponent moves).
        """
        r, c = IDX2COORD[int(idx)]
        # Illegal: occupied or outside CROSS. The move is forfeited.
        if self._state[r, c] != EMPTY or not CROSS[r, c]:
            self._turn = -player
            return R_ILLEGAL if agent_move else 0.0

        if random.random() < 0.5:
            # The piece lands on the chosen cell
            rr, cc = r, c
        else:
            # The piece drifts to a random empty playable neighbor
            legal = [
                (r + dr, c + dc)
                for dr, dc in NEIGH8
                if self._is_free(r + dr, c + dc)
            ]
            if not legal:
                self._turn = -player
                return (R_PASS_EDGE * SCALE) if agent_move else 0.0
            rr, cc = random.choice(legal)

        self._state[rr, cc] = player
        # The shaped reward must be computed before the turn flips:
        # _opponent_threat() identifies the opponent as -self._turn.
        raw = self._positional_reward(rr, cc, player) if agent_move else 0.0
        self._turn = -player
        return raw * SCALE

    def _positional_reward(self, r, c, p):
        """
        Compute shaped reward after placing at (r,c) by player p:
        - base legal move reward
        - bonus for 2-in-row or 3+-in-row
        - penalty if no empty neighbors
        - penalty if the opponent still has a threat on the board
        """
        bonus = R_LEGAL

        # Bonus for creating lines
        longest = max(_count_line(self._state, r, c, p))
        if longest >= 3:
            bonus += B_THREE
        elif longest == 2:
            bonus += B_TWO

        # Penalty if no empty adjacent cells
        if not any(self._is_free(r + dr, c + dc) for dr, dc in NEIGH8):
            bonus += P_NO_SPACE

        # Penalty if the opponent has an unanswered threat
        if self._opponent_threat():
            bonus += P_IGN_THR

        return bonus

    def _run_is_open(self, r, c, dr, dc, p):
        """
        True when the run of p through (r, c) along direction (dr, dc) has an
        empty playable cell immediately beyond both of its ends.
        """
        for sr, sc in ((dr, dc), (-dr, -dc)):
            rr, cc = r + sr, c + sc
            # Walk to the first cell past this end of the run
            while in_bounds(rr, cc) and CROSS[rr, cc] and self._state[rr, cc] == p:
                rr, cc = rr + sr, cc + sc
            if not self._is_free(rr, cc):
                return False
        return True

    def _opponent_threat(self):
        """
        Check if opponent has any 3-in-row threat or live-2 threat
        not yet blocked on the board.

        The opponent is -self._turn. A live-2 is a horizontal or vertical run
        of exactly two pieces whose both ends are open.
        """
        opp = -self._turn
        for r, c in COORDS:
            if self._state[r, c] != opp:
                continue
            lengths = _count_line(self._state, r, c, opp)
            if max(lengths) >= 3:
                return True
            # Live-2 horizontal threat
            if lengths[0] == 2 and self._run_is_open(r, c, 0, 1, opp):
                return True
            # Live-2 vertical threat
            if lengths[1] == 2 and self._run_is_open(r, c, 1, 0, opp):
                return True
        return False

    def _check_terminal(self):
        """
        Check if game is won or board is full.
        Returns (done_flag, winner) where winner ∈ {X, O, 0}, or
        (False, None) while the game is still running.
        """
        for r, c in COORDS:
            p = self._state[r, c]
            if p == EMPTY:
                continue
            lengths = _count_line(self._state, r, c, p)
            # Win: 4 in row/col (first two entries) or 5 on a diagonal (last two)
            if max(lengths[:2]) >= 4 or max(lengths[2:]) >= 5:
                return True, p
        # Draw if no empty cells
        if np.all(self._state[CROSS] != EMPTY):
            return True, 0
        return False, None

    @staticmethod
    def _final_reward(w):
        """
        Compute final reward: +3 for X win, -3 for O win, 0 for draw.
        """
        base = 3.0 if w == X else -3.0 if w == O else 0.0
        return base * SCALE

    def _obs(self):
        """
        Build observation tensor of shape (BOARD,BOARD,3):
        channel 0: X positions, channel 1: O positions,
        channel 2: current player flag (1.0 when X is to move, else 0.0).
        """
        x = (self._state == X).astype(np.float32)
        o = (self._state == O).astype(np.float32)
        p = np.full_like(x, 1.0 if self._turn == X else 0.0, np.float32)
        return np.stack([x, o, p], -1)

    def _fake(self):
        """
        Generate a fake TimeStep for opponent policy calls,
        preserving observation shape.
        """
        return ts.TimeStep(
            step_type   = np.array(1, np.int32),
            reward      = np.array(0.0, np.float32),
            discount    = np.array(1.0, np.float32),
            observation = self._obs()
        )

# ════════════ Opponent policies ════════════════════════════
class RandomPolicy(py_policy.PyPolicy):
    """Opponent policy that picks uniformly among the empty playable cells."""
    def __init__(self, env):
        super().__init__(env.time_step_spec(), env.action_spec())

    def _action(self, time_step, _=()):
        """
        Return a PolicyStep whose action is a randomly chosen empty cell index.

        A cell is treated as empty when channels 0 and 1 of the observation
        are equal there (their difference is 0).
        """
        planes = time_step.observation
        occupancy = planes[..., 0] - planes[..., 1]
        legal = []
        for action_idx, (row, col) in enumerate(IDX2COORD):
            if occupancy[row, col] == 0:
                legal.append(action_idx)
        picked = random.choice(legal)
        return policy_step.PolicyStep(np.int32(picked), (), ())

class RuleBasedPolicy(py_policy.PyPolicy):
    """Opponent policy: take an immediately winning cell if one exists, else move randomly."""
    def __init__(self, env):
        super().__init__(env.time_step_spec(), env.action_spec())

    def _action(self, time_step, _=()):
        """
        If a winning move exists, take it; otherwise random legal.

        The policy plays O. A move counts as winning when it completes
        4 in a row horizontally/vertically or 5 in a row diagonally.
        It does not try to block the other player.
        Returns a PolicyStep whose action is an np.int32 cell index.
        """
        Bx = time_step.observation[..., 0]
        Bo = time_step.observation[..., 1]
        # Signed board (a fresh array, safe to scribble on): +1 = O, -1 = X, 0 = empty
        board = Bo - Bx
        legal = [i for i, (r, c) in enumerate(IDX2COORD) if board[r, c] == 0]

        # Try each empty cell in action order and keep the first one that wins
        for idx in legal:
            r, c = IDX2COORD[idx]
            board[r, c] = 1
            lengths = _count_line(board, r, c, 1)
            board[r, c] = 0
            if max(lengths[:2]) >= 4 or max(lengths[2:]) >= 5:
                return policy_step.PolicyStep(np.int32(idx), (), ())

        # Otherwise pick random empty
        return policy_step.PolicyStep(np.int32(random.choice(legal)), (), ())

# ════════════ Actor network with legal-mask ═══════════════════════════════
def legal_mask(obs):
    """
    Build the per-action legality mask for a batch of observations.

    obs is indexed as [batch, BOARD, BOARD, channels]; a cell counts as
    occupied when either of its first two channels is set. Returns a float32
    tensor of shape [batch, N_CELLS] holding 1 for empty playable cells and
    0 otherwise.
    """
    occupied = tf.reduce_max(obs[..., :2], -1)
    free = 1.0 - tf.cast(occupied, tf.float32)
    batch = tf.shape(free)[0]
    free_flat = tf.reshape(free, (batch, BOARD * BOARD))
    return tf.gather(free_flat, IDX_LIN, axis=1)

class MaskedActor(network.Network):
    """
    Policy network: conv stack -> flatten -> dense stack -> one logit per
    playable cell. The output is a Categorical distribution in which the
    logits of occupied cells are pushed down by a constant.
    """
    def __init__(self, obs_spec, act_spec,
                 conv_params=((64,3,1),(64,3,1),(64,3,1)),
                 fc_params=(256,128)):
        super().__init__(input_tensor_spec=obs_spec, state_spec=(), name="actor")
        # One Conv2D per (filters, kernel_size, strides) triple
        conv_layers = []
        for filters, kernel, stride in conv_params:
            conv_layers.append(
                tf.keras.layers.Conv2D(filters=filters, kernel_size=kernel, strides=stride,
                                       activation='relu', padding='same'))
        self._conv = conv_layers
        self._flat = tf.keras.layers.Flatten()
        # One ReLU Dense layer per entry of fc_params
        dense_layers = []
        for units in fc_params:
            dense_layers.append(tf.keras.layers.Dense(units=units, activation='relu'))
        self._fc = dense_layers
        # One logit per action index (act_spec.maximum + 1 outputs)
        self._logits = tf.keras.layers.Dense(act_spec.maximum + 1)

    def call(self, obs, step_type=None, network_state=(), training=False):
        """
        Forward pass: conv -> flatten -> fc -> logits -> mask -> distribution.

        All leading dimensions of obs are folded into one batch axis for the
        layers and restored on the logits before building the distribution.
        """
        outer_dims = tf.shape(obs)[:-3]
        boards = tf.reshape(obs, (-1, BOARD, BOARD, 3))
        hidden = tf.cast(boards, tf.float32)
        for conv in self._conv:
            hidden = conv(hidden, training=training)
        hidden = self._flat(hidden)
        for dense in self._fc:
            hidden = dense(hidden, training=training)
        raw_logits = self._logits(hidden)
        # Occupied cells keep their logit minus a large constant
        is_legal = legal_mask(boards) > 0
        masked = tf.where(is_legal, raw_logits, raw_logits - 1e2)
        out_shape = tf.concat([outer_dims, [N_CELLS]], axis=0)
        masked = tf.reshape(masked, out_shape)
        return tfp.distributions.Categorical(logits=masked), network_state

class SharedValue(network.Network):
    """
    Value network that reuses the actor's conv, flatten and dense layer
    objects and adds its own single-unit head.
    """
    def __init__(self, actor):
        super().__init__(input_tensor_spec=actor.input_tensor_spec,
                         state_spec=(), name="value")
        # Same layer objects as the actor, so the trunk weights are shared
        self._conv, self._flat, self._fc = actor._conv, actor._flat, actor._fc
        self._v = tf.keras.layers.Dense(1)

    def call(self, obs, step_type=None, network_state=(), training=False):
        """
        Forward pass: conv -> flatten -> fc -> single value -> squeeze.

        Returns one value per leading (outer) index of obs.
        """
        outer_dims = tf.shape(obs)[:-3]
        boards = tf.reshape(obs, (-1, BOARD, BOARD, 3))
        hidden = tf.cast(boards, tf.float32)
        for conv in self._conv:
            hidden = conv(hidden, training=training)
        hidden = self._flat(hidden)
        for dense in self._fc:
            hidden = dense(hidden, training=training)
        value = self._v(hidden)
        # Restore the outer dimensions, then drop the trailing unit axis
        value = tf.reshape(value, tf.concat([outer_dims, [1]], axis=0))
        return tf.squeeze(value, -1), network_state

# ───────── Hyper-parameters ─────────────────────────────────────────────
# NOTE(review): the training loop that reads these values is not part of this
# cell, so the per-line notes below are inferred from the names only — confirm
# them against the code that consumes each constant.
EPOCHS          = 1000   # presumably the number of training iterations
PRETRAIN_EPOCHS = 20     # presumably iterations of an initial warm-up phase
NUM_ENVS        = 8      # presumably the number of parallel environments
COLLECT_PER_ENV = 64     # presumably steps collected per environment per iteration
EVAL_EPIS       = 80     # presumably episodes played per evaluation
SAVE_EVERY      = 200    # presumably the checkpoint interval, in iterations
TARGET_WIN      = 0.90   # presumably the win rate that counts as "solved"

LR_BASE      = 1e-4      # presumably the optimizer's base learning rate
ENTROPY_INIT = 0.2       # presumably the starting entropy-regularization coefficient
MIN_ENT_COEF = 0.05      # presumably the floor for that coefficient
CLIP_RATIO   = 0.30      # presumably the PPO importance-ratio clipping range
NUM_PPO_EPOCH= 5         # presumably optimization passes per collected batch
VALUE_COEF   = 1.5       # presumably the weight of the value loss
GAE_LAMBDA   = 0.80      # presumably the lambda of generalized advantage estimation



# Unit tests (the cell below redefines and tests simplified copies of the environment, policies and helpers)

In [5]:
# -*- coding: utf-8 -*-
"""
Super Tic-Tac-Toe environment, policies, and unit tests.

This module defines the game logic for a variant of Tic-Tac-Toe played on a 12x12 board
with a cross-shaped playable area, two simple policies, and unit tests.
"""

# Imports for testing
import unittest
from unittest.mock import patch, MagicMock
import numpy as np
import tensorflow as tf
import random  # Will be mocked where necessary

# tf-agents imports
from tf_agents.environments import py_environment
from tf_agents.specs import array_spec
from tf_agents.trajectories import time_step as ts
from tf_agents.policies import py_policy
from tf_agents.trajectories import policy_step

# Board geometry and constants
BOARD = 12               # side length of the square grid
EMPTY, X, O = 0, 1, -1   # cell values: empty, X piece, O piece

# Playable "cross" region: a 4-row band and a 4-column band through the centre
C0 = BOARD // 2 - 2      # first row/column of each band
CROSS = np.zeros((BOARD, BOARD), bool)
CROSS[C0:C0+4, :] = True
CROSS[:, C0:C0+4] = True

# Playable coordinates in row-major order; the list position is the action index
COORDS    = [(r, c) for r, c in np.argwhere(CROSS).tolist()]
IDX2COORD = np.array(COORDS, int)
N_CELLS   = len(COORDS)

# The eight surrounding offsets used by the neighbor fallback in _apply_move
NEIGH8  = [(dr, dc) for dr in (-1, 0, 1) for dc in (-1, 0, 1) if (dr, dc) != (0, 0)]

# Flat row-major position of each playable cell, consumed by legal_mask
IDX_LIN = tf.constant([BOARD * row + col for row, col in COORDS], dtype=tf.int32)

# Reward values used by the environment below
SCALE       = 1.0    # multiplier for final, positional and pass rewards
R_LEGAL     = 0.5    # reward of a non-terminal step and of a legal placement
R_ILLEGAL   = -1.0   # returned when the agent targets an unavailable cell
R_PASS_EDGE = -1.0   # returned when the neighbor fallback finds no empty cell


def in_bounds(r, c):
    """
    Tell whether (r, c) addresses a cell of the BOARD x BOARD grid.

    Only the index range is checked; the playable cross is not consulted.

    Args:
        r (int): Row index.
        c (int): Column index.

    Returns:
        bool: True when both indices fall in 0..BOARD-1, False otherwise.
    """
    row_ok = 0 <= r < BOARD
    return row_ok and 0 <= c < BOARD


def _count_line(board, r, c, p):
    """
    Measure the runs of player p that pass through (r, c).

    The cell (r, c) always contributes 1; each run is then extended in both
    directions while the next cell is in bounds, inside CROSS and equal to p.

    Args:
        board (np.ndarray): 2D board array.
        r (int): Row of the cell the runs pass through.
        c (int): Column of the cell the runs pass through.
        p (int): Player value (X or O).

    Returns:
        List[int]: Run lengths for horizontal, vertical, down-right diagonal
        and up-right diagonal, in that order.
    """
    def reach(step_r, step_c):
        # Consecutive p-cells found walking away from (r, c) in one direction.
        steps = 0
        rr, cc = r + step_r, c + step_c
        while in_bounds(rr, cc) and CROSS[rr, cc] and board[rr, cc] == p:
            steps += 1
            rr, cc = rr + step_r, cc + step_c
        return steps

    axes = ((0, 1), (1, 0), (1, 1), (-1, 1))
    return [1 + reach(dr, dc) + reach(-dr, -dc) for dr, dc in axes]


class SuperTicTacToe(py_environment.PyEnvironment):
    """
    Simplified PyEnvironment for the Super Tic-Tac-Toe game, used by the unit tests.

    The playable area is the cross-shaped region of the 12x12 board. _step()
    places one piece for the side to move and flips the turn; no opponent is
    built in. Episodes end when any line of 4 is formed or the cross is full.
    """

    def __init__(self, opponent="random"):
        """
        Initialize environment state and specs.

        Args:
            opponent (str): Opponent policy name (unused in stub).
        """
        # Initialise the tf_agents base class before setting up our own state.
        super().__init__()
        self._obs_spec = array_spec.BoundedArraySpec(
            (BOARD, BOARD, 3), np.float32, 0, 1)
        self._act_spec = array_spec.BoundedArraySpec(
            (), np.int32, 0, N_CELLS - 1)
        self._state = np.zeros((BOARD, BOARD), np.int8)
        self._turn  = X
        self._done  = False

    def observation_spec(self):
        """Return the observation spec."""
        return self._obs_spec

    def action_spec(self):
        """Return the action spec."""
        return self._act_spec

    def _reset(self):
        """
        Reset the environment to the initial state.

        Returns:
            TimeStep: A restart time_step with initial observation.
        """
        self._state = np.zeros((BOARD, BOARD), np.int8)
        self._turn  = X
        self._done  = False
        return ts.restart(self._obs())

    def _obs(self):
        """
        Encode the current state as an observation tensor.

        Observation planes: X positions, O positions, current player flag
        (1.0 everywhere when X is to move, 0.0 otherwise).

        Returns:
            np.ndarray: Observation of shape (BOARD, BOARD, 3).
        """
        x_plane = (self._state == X).astype(np.float32)
        o_plane = (self._state == O).astype(np.float32)
        player_val = 1.0 if self._turn == X else 0.0
        player_plane = np.full_like(x_plane, player_val, np.float32)
        return np.stack([x_plane, o_plane, player_plane], axis=-1)

    def _step(self, action):
        """
        Apply the action to the environment.

        Simplified stub: places piece if valid, flips turn,
        checks terminal, and returns appropriate TimeStep. An out-of-range
        index or an occupied cell places nothing but still flips the turn.

        Args:
            action (int): Index into legal moves (0..N_CELLS-1).

        Returns:
            TimeStep: transition, termination, or restart.
        """
        if self._done:
            return self.reset()
        # Apply move if in legal range
        if 0 <= action < N_CELLS:
            r, c = IDX2COORD[action]
            if self._state[r, c] == EMPTY and CROSS[r, c]:
                self._state[r, c] = self._turn
        # Flip turn
        self._turn = -self._turn
        # Check for terminal state
        self._done, winner = self._check_terminal()
        if self._done:
            return ts.termination(self._obs(), self._final_reward(winner))
        return ts.transition(self._obs(), reward=R_LEGAL, discount=1.0)

    @staticmethod
    def _final_reward(w):
        """
        Compute final reward based on winner.

        Args:
            w (int or None): Winner (X, O, or 0 for draw).

        Returns:
            float: Scaled reward (±3.0 or 0 for draw).
        """
        if w == X:
            return 3.0 * SCALE
        if w == O:
            return -3.0 * SCALE
        return 0.0 * SCALE

    def _check_terminal(self):
        """
        Check if the game has ended by win or draw.

        Returns:
            (bool, int or None): done flag and winner (X/O), 0 for draw, None otherwise.
        """
        # Win if any 4-in-a-row
        for r_idx, c_idx in COORDS:
            p = self._state[r_idx, c_idx]
            if p == EMPTY:
                continue
            lengths = _count_line(self._state, r_idx, c_idx, p)
            if lengths and max(lengths) >= 4:
                return True, p
        # Draw if no empty cells remain
        if np.all(self._state[CROSS] != EMPTY):
            return True, 0
        return False, None

    def _mockable_positional_reward(self, r, c, p):
        """
        Positional reward hook for mocking/testing.

        Args:
            r, c (int): Coordinates of placement.
            p (int): Player who moved.

        Returns:
            float: Base legal reward.
        """
        return R_LEGAL

    def _apply_move(self, idx, player, agent_move=True):
        """
        Apply a move with randomization: either place at idx or neighbor.

        Args:
            idx (int): Preferred move index.
            player (int): Player making the move.
            agent_move (bool): True if move by agent (affects turn and reward).

        Returns:
            float: Reward for the move (legal, pass, or illegal).
        """
        r, c = IDX2COORD[int(idx)]
        # Illegal if occupied or outside play area.
        # NOTE(review): for a non-agent move this path leaves the turn unchanged.
        if not (in_bounds(r, c) and CROSS[r, c] and self._state[r, c] == EMPTY):
            if agent_move:
                self._turn = -player
                return R_ILLEGAL
            return 0.0
        # Randomly decide direct placement or neighbor fallback
        if random.random() < 0.5:
            self._state[r, c] = player
            raw = self._mockable_positional_reward(r, c, player) if agent_move else 0.0
            self._turn = -player
            return raw * SCALE
        # Fallback placement in a random legal neighbor
        neighbors = []
        for dr, dc in NEIGH8:
            nr, nc = r + dr, c + dc
            if in_bounds(nr, nc) and CROSS[nr, nc] and self._state[nr, nc] == EMPTY:
                neighbors.append((nr, nc))
        if not neighbors:
            # Nowhere to drift to: the move is lost and the turn passes
            self._turn = -player
            return R_PASS_EDGE * SCALE if agent_move else 0.0
        rr, cc = random.choice(neighbors)
        self._state[rr, cc] = player
        raw = self._mockable_positional_reward(rr, cc, player) if agent_move else 0.0
        self._turn = -player
        return raw * SCALE


class RandomPolicy(py_policy.PyPolicy):
    """
    Policy that selects uniformly among the empty playable cells.
    """
    def __init__(self, time_step_spec, action_spec):
        super().__init__(time_step_spec, action_spec)

    def _action(self, time_step_obj, policy_state=()):
        """
        Pick a random empty cell of the cross from the observation.

        Args:
            time_step_obj (TimeStep): Current environment timestep.
            policy_state: Policy state (unused).

        Returns:
            PolicyStep: Chosen action as np.int32; action 0 when no cell is empty.
        """
        planes = time_step_obj.observation
        # X plane minus O plane: 0 wherever the two planes agree (empty cell)
        signed_board = planes[..., 0] - planes[..., 1]
        legal = []
        for action_idx, (row, col) in enumerate(IDX2COORD):
            if CROSS[row, col] and signed_board[row, col] == EMPTY:
                legal.append(action_idx)
        picked = random.choice(legal) if legal else 0
        return policy_step.PolicyStep(action=np.int32(picked), state=(), info=())


class RuleBasedPolicy(py_policy.PyPolicy):
    """
    A simple policy that takes a winning move if available, else random.
    """
    def __init__(self, time_step_spec, action_spec):
        super().__init__(time_step_spec, action_spec)

    def _action(self, time_step_obj, policy_state=()):
        """
        Evaluate each legal move to see if it wins immediately; otherwise random.

        The policy plays O. A move wins when it creates a line of at least 4
        in any of the four directions.

        Args:
            time_step_obj (TimeStep): Current observation.
            policy_state: Policy state (unused).

        Returns:
            PolicyStep: Chosen action as np.int32; action 0 when no cell is empty.
        """
        obs = time_step_obj.observation
        x_plane, o_plane = obs[..., 0], obs[..., 1]

        # Rebuild the signed board once; candidates are tried on it in place
        board = np.zeros((BOARD, BOARD), np.int8)
        board[x_plane == 1] = X
        board[o_plane == 1] = O

        legal = [i for i, (r, c) in enumerate(IDX2COORD)
                 if x_plane[r, c] == 0 and o_plane[r, c] == 0 and CROSS[r, c]]

        # Try to win as O: first legal cell (in action order) completing a line
        for idx in legal:
            r, c = IDX2COORD[idx]
            board[r, c] = O
            wins = max(_count_line(board, r, c, O)) >= 4
            board[r, c] = EMPTY  # undo the trial placement
            if wins:
                return policy_step.PolicyStep(action=np.int32(idx), state=(), info=())

        # Otherwise pick random legal
        if not legal:
            return policy_step.PolicyStep(action=np.int32(0), state=(), info=())
        choice = random.choice(legal)
        return policy_step.PolicyStep(action=np.int32(choice), state=(), info=())


def legal_mask(obs_tensor):
    """
    Compute a mask of legal moves (1.0 legal, 0.0 illegal) from observation.

    Args:
        obs_tensor (tf.Tensor): A single observation of shape
            [BOARD, BOARD, 3] or a batch of shape [batch, BOARD, BOARD, 3].

    Returns:
        tf.Tensor: Mask of shape [N_CELLS] for a single observation, or
        [batch, N_CELLS] for a batch (a batch of one keeps its batch axis).
    """
    obs_f = tf.cast(obs_tensor, tf.float32)
    # Use the static rank so this also works on symbolic (graph-mode) tensors
    unbatched = obs_f.shape.rank == 3
    if unbatched:
        obs_f = tf.expand_dims(obs_f, 0)
    # Empty plane = not occupied by X or O
    empty = 1.0 - tf.reduce_max(obs_f[..., :2], axis=-1)
    flat   = tf.reshape(empty, (tf.shape(empty)[0], BOARD * BOARD))
    mask   = tf.gather(flat, IDX_LIN, axis=1)
    # Drop the batch axis only if it was added above
    if unbatched:
        mask = tf.squeeze(mask, axis=0)
    return mask


# Unit test classes
class TestBoardUtils(unittest.TestCase):
    """Tests for board utility functions (in_bounds, _count_line)."""
    def test_in_bounds_valid_center(self):
        """A cell inside the central 4x4 block of the cross is in bounds."""
        # (C0 + 1, C0 + 1) lies one step inside the block that starts at (C0, C0).
        self.assertTrue(in_bounds(C0 + 1, C0 + 1))

    # ... (other test methods remain unchanged, relying on implicit naming) ...

class TestSuperTicTacToe(unittest.TestCase):
    """Tests for SuperTicTacToe environment core functionality."""
    def setUp(self):
        """Create a fresh environment with an empty board before every test."""
        self.env = SuperTicTacToe()
        # _reset() is called directly, so the public reset() wrapper is bypassed.
        self.env._reset()

    # ... test methods ...

class TestPolicies(unittest.TestCase):
    """Tests for RandomPolicy and RuleBasedPolicy behavior."""
    def setUp(self):
        """Build the specs that the policy constructors take."""
        # Same observation and action specs as SuperTicTacToe.__init__ declares.
        self.obs_spec = array_spec.BoundedArraySpec(
            (BOARD, BOARD, 3), np.float32, 0, 1)
        self.act_spec = array_spec.BoundedArraySpec(
            (), np.int32, 0, N_CELLS - 1)
        # Time-step spec derived from the observation spec.
        self.time_step_spec = ts.time_step_spec(self.obs_spec)

    # ... test methods ...

class TestNetworkFunctions(unittest.TestCase):
    """Tests for the legal_mask helper (observation -> per-action legality mask)."""
    # ... test methods ...


def run_tests_in_colab():
    """
    Loader and runner for executing unit tests in Colab.

    Builds one suite from the four test classes of this cell, runs it with
    verbose output and prints a summary, listing every error and failure.
    """
    # Hide GPUs so the tests run on CPU. TensorFlow refuses to change device
    # visibility once its runtime has initialised the GPUs (for example after
    # an earlier cell created tensors); in that case keep the current devices
    # and run the tests anyway.
    try:
        tf.config.set_visible_devices([], 'GPU')
    except RuntimeError as err:
        print(f"Could not hide GPUs, running tests on the current devices: {err}")

    loader = unittest.TestLoader()
    suite = unittest.TestSuite()
    for case in (TestBoardUtils, TestSuperTicTacToe, TestPolicies, TestNetworkFunctions):
        suite.addTests(loader.loadTestsFromTestCase(case))

    runner = unittest.TextTestRunner(verbosity=2)
    result = runner.run(suite)
    if result.wasSuccessful():
        print("\nAll selected unit tests passed successfully!")
        return

    print("\nSome unit tests FAILED.")
    if result.errors:
        print("\nErrors:")
        for test, err in result.errors:
            print(f"  {test}: {err}")
    if result.failures:
        print("\nFailures:")
        for test, fail in result.failures:
            print(f"  {test}: {fail}")

run_tests_in_colab()


test_count_line_horizontal_three_x (__main__.TestBoardUtils.test_count_line_horizontal_three_x) ... ok
test_count_line_single_piece (__main__.TestBoardUtils.test_count_line_single_piece) ... ok
test_count_line_vertical_three_o (__main__.TestBoardUtils.test_count_line_vertical_three_o) ... ok
test_in_bounds_invalid_outside_board (__main__.TestBoardUtils.test_in_bounds_invalid_outside_board) ... ok
test_in_bounds_valid_center (__main__.TestBoardUtils.test_in_bounds_valid_center) ... ok
test_in_bounds_valid_corner_if_board_large_enough (__main__.TestBoardUtils.test_in_bounds_valid_corner_if_board_large_enough) ... ok
test_in_bounds_valid_cross_arm_edge (__main__.TestBoardUtils.test_in_bounds_valid_cross_arm_edge) ... ok
  self._minimum[self._minimum == -np.inf] = low
  self._minimum[self._minimum == np.inf] = high
  self._maximum[self._maximum == -np.inf] = low
  self._maximum[self._maximum == np.inf] = high
ok
test_apply_move_legal_neighbor (__main__.TestSuperTicTacToe.test_apply_move_le


All selected unit tests passed successfully!
