# Imports

In [36]:


import torch
import torch.nn as nn
import torch.nn.functional as F
import chess
from torchsummary import summary
import numpy as np
import random
from tqdm.notebook import tqdm

# Device

In [37]:
# Pick the compute device: CUDA when any GPU is visible, otherwise the CPU.
if not torch.cuda.is_available():
    print("⚠️ CUDA not available, using CPU.")
    device = torch.device("cpu")
else:
    device_count = torch.cuda.device_count()
    print(f"✅ {device_count} CUDA device(s) available:")
    # List every visible GPU by ordinal and name.
    for i in range(device_count):
        print(f"  └─ [{i}] {torch.cuda.get_device_name(i)}")
    device = torch.device("cuda")

✅ 1 CUDA device(s) available:
  └─ [0] NVIDIA GeForce GTX 1650


# Encoding

In [80]:

def board_encoder(board: chess.Board) -> np.ndarray:
    """Encode a position as a (19, 8, 8) float32 stack of feature planes.

    Squares are placed at ``[square // 8, square % 8]``. Plane layout:
      0-5   white pieces, indexed by ``piece_type - 1``
      6-11  black pieces, indexed by ``piece_type - 1``
      12    side to move, filled with ``float(board.turn)``
      13-16 castling rights: white kingside, white queenside,
            black kingside, black queenside (1.0 / 0.0)
      17    halfmove clock divided by 100
      18    one-hot en passant target square (all zero when there is none)
    """
    planes = np.zeros((19, 8, 8), dtype=np.float32)

    # Piece planes: white uses 0-5, black is shifted by 6.
    for sq, pc in board.piece_map().items():
        colour_offset = 0 if pc.color == chess.WHITE else 6
        rank, file = divmod(sq, 8)
        planes[pc.piece_type - 1 + colour_offset, rank, file] = 1

    planes[12, :, :] = float(board.turn)

    castling_flags = (
        board.has_kingside_castling_rights(chess.WHITE),
        board.has_queenside_castling_rights(chess.WHITE),
        board.has_kingside_castling_rights(chess.BLACK),
        board.has_queenside_castling_rights(chess.BLACK),
    )
    for offset, flag in enumerate(castling_flags):
        planes[13 + offset, :, :] = flag

    planes[17, :, :] = board.halfmove_clock / 100.0

    ep = board.ep_square
    if ep is not None:
        rank, file = divmod(ep, 8)
        planes[18, rank, file] = 1.0

    return planes


def move_to_index(move: chess.Move) -> int:
    """Map a move to its slot in the 64 x 73 = 4672 AlphaZero-style policy.

    Each from-square owns 73 consecutive slots:
      0-55   queen-like moves: ``(distance - 1) * 8 + direction``
      56-63  knight jumps
      64-72  under-promotions (knight/rook/bishop x left/straight/right)
    Queen promotions are encoded as the plain queen-like move, so the
    promotion piece cannot be recovered from the index alone.

    Raises:
        ValueError: if the move is a null move, or lies on neither a queen
            line, a knight jump, nor an under-promotion pattern.
    """
    from_square = move.from_square
    to_square = move.to_square
    promotion = move.promotion

    dx = chess.square_file(to_square) - chess.square_file(from_square)
    dy = chess.square_rank(to_square) - chess.square_rank(from_square)

    # Knight jumps occupy slots 56-63.
    knight_deltas = {(1, 2): 0, (2, 1): 1, (2, -1): 2, (1, -2): 3,
                     (-1, -2): 4, (-2, -1): 5, (-2, 1): 6, (-1, 2): 7}
    if (dx, dy) in knight_deltas:
        return from_square * 73 + 56 + knight_deltas[(dx, dy)]

    # Under-promotions occupy slots 64-72; queen promotions fall through
    # to the queen-like encoding below.
    if promotion is not None:
        promo_types = [chess.KNIGHT, chess.ROOK, chess.BISHOP]
        file_steps = [-1, 0, 1]  # left, straight, right
        if dx in file_steps and promotion in promo_types:
            return (from_square * 73 + 64
                    + promo_types.index(promotion) * 3
                    + file_steps.index(dx))

    # Queen-like moves: a unit direction plus a distance of 1..7.
    directions = {(0, 1): 0, (1, 0): 1, (0, -1): 2, (-1, 0): 3,
                  (1, 1): 4, (1, -1): 5, (-1, 1): 6, (-1, -1): 7}
    distance = max(abs(dx), abs(dy))
    on_queen_line = dx == 0 or dy == 0 or abs(dx) == abs(dy)
    if distance == 0 or not on_queen_line:
        raise ValueError(f"Move {move} could not be encoded.")

    step = ((dx > 0) - (dx < 0), (dy > 0) - (dy < 0))
    return from_square * 73 + (distance - 1) * 8 + directions[step]


def index_to_move(index: int) -> chess.Move:
    """Decode a policy index produced by ``move_to_index`` back into a move.

    The index is split into ``from_square = index // 73`` and a 0-72
    sub-index: 0-55 queen-like moves, 56-63 knight jumps, 64-72
    under-promotions. Queen-like slots decode without a promotion piece.

    Raises:
        ValueError: if the index is outside ``[0, 4672)``, names an
            under-promotion from a square not on rank 2 or 7, or would
            move off the board.
    """
    if not 0 <= index < 64 * 73:
        raise ValueError(
            f"Index {index} is outside the policy range [0, {64 * 73}).")

    from_square, sub_index = divmod(index, 73)
    fx = chess.square_file(from_square)
    fy = chess.square_rank(from_square)
    promotion = None

    if sub_index < 56:
        # Queen-like move: 8 directions x distances 1..7.
        directions = [(0, 1), (1, 0), (0, -1), (-1, 0),
                      (1, 1), (1, -1), (-1, 1), (-1, -1)]
        distance, dir_idx = divmod(sub_index, 8)
        distance += 1
        step_x, step_y = directions[dir_idx]
        tx, ty = fx + distance * step_x, fy + distance * step_y
    elif sub_index < 64:
        # Knight jump.
        knight_deltas = [(1, 2), (2, 1), (2, -1), (1, -2),
                         (-1, -2), (-2, -1), (-2, 1), (-1, 2)]
        kdx, kdy = knight_deltas[sub_index - 56]
        tx, ty = fx + kdx, fy + kdy
    else:
        # Under-promotion: piece type x file step, moving to the last rank.
        promo_idx = sub_index - 64
        promotion_types = [chess.KNIGHT, chess.ROOK, chess.BISHOP]
        file_steps = [-1, 0, 1]
        promotion = promotion_types[promo_idx // 3]
        tx = fx + file_steps[promo_idx % 3]
        if fy == 6:
            ty = 7
        elif fy == 1:
            ty = 0
        else:
            raise ValueError("Not a valid promotion square")

    if 0 <= tx < 8 and 0 <= ty < 8:
        return chess.Move(from_square, chess.square(tx, ty),
                          promotion=promotion)

    raise ValueError(f"Index {index} could not be decoded to a legal move.")

In [None]:
def get_random_move():
    """Return a random legal move from a position reached by random play.

    Plays between 0 and 40 random plies from the initial position,
    restarting on a fresh board whenever a game ends. It then samples one
    legal move from the resulting position, falling back to the initial
    position if that position has no legal moves.
    """
    position = chess.Board()
    plies = random.randint(0, 40)
    for _ in range(plies):
        if position.is_game_over():
            position = chess.Board()
        candidates = list(position.legal_moves)
        position.push(random.choice(candidates))

    if not position.legal_moves:
        position = chess.Board()
    return random.choice(list(position.legal_moves))

for _ in tqdm(range(50000)):
    move = get_random_move()
    if 'q' in move.uci(): continue
    try:
        moveCirc = index_to_move(move_to_index(move))
    except:
        print(f"Move: {move}")
    assert move == moveCirc, f"Original: {move}, New: {moveCirc}"

  0%|          | 0/50000 [00:00<?, ?it/s]

# Model

In [5]:
class ResidualBlock(nn.Module):
    """Two 3x3 conv + batch-norm layers wrapped by an identity skip.

    ``padding=1`` keeps the spatial size and both convs keep the channel
    count, so the input can be added straight onto the second batch-norm
    output before the final ReLU.
    """

    def __init__(self, channels):
        super().__init__()
        # Registration order (conv1, bn1, conv2, bn2) fixes parameter
        # initialisation order and state_dict key order.
        self.conv1 = nn.Conv2d(channels, channels, kernel_size=3, padding=1)
        self.bn1 = nn.BatchNorm2d(channels)
        self.conv2 = nn.Conv2d(channels, channels, kernel_size=3, padding=1)
        self.bn2 = nn.BatchNorm2d(channels)

    def forward(self, x):
        hidden = F.relu(self.bn1(self.conv1(x)))
        residual = self.bn2(self.conv2(hidden))
        return F.relu(residual + x)


class ChessNet(nn.Module):
    """AlphaZero-style policy/value network over 8x8 board planes.

    Args:
        in_planes: number of input feature planes (19 matches board_encoder).
        channels: width of the convolutional trunk.
        num_blocks: number of residual blocks in the trunk.
        policy_size: number of policy logits (4672 = 64 squares x 73 move
            slots, matching move_to_index).

    ``forward(x)`` takes a ``(batch, in_planes, 8, 8)`` tensor and returns
    ``(policy_logits, value)``. ``policy_logits`` has shape
    ``(batch, policy_size)``; ``value`` has shape ``(batch, 1)`` and lies in
    [-1, 1] because of the tanh.
    """

    def __init__(self, in_planes=19, channels=64, num_blocks=5,
                 policy_size=4672):
        super().__init__()
        # Shared convolutional trunk.
        self.input_conv = nn.Conv2d(in_planes, channels, kernel_size=3,
                                    padding=1)
        self.bn_input = nn.BatchNorm2d(channels)

        self.res_blocks = nn.Sequential(
            *[ResidualBlock(channels) for _ in range(num_blocks)])

        # Policy head: 1x1 conv down to 2 planes, then one logit per move slot.
        self.policy_conv = nn.Conv2d(channels, 2, kernel_size=1)
        self.policy_bn = nn.BatchNorm2d(2)
        self.policy_fc = nn.Linear(2 * 8 * 8, policy_size)

        # Value head: 1x1 conv down to 1 plane, MLP, tanh to [-1, 1].
        self.value_conv = nn.Conv2d(channels, 1, kernel_size=1)
        self.value_bn = nn.BatchNorm2d(1)
        self.value_fc1 = nn.Linear(8 * 8, 64)
        self.value_fc2 = nn.Linear(64, 1)

    def forward(self, x):
        x = F.relu(self.bn_input(self.input_conv(x)))
        x = self.res_blocks(x)

        # Policy head
        p = F.relu(self.policy_bn(self.policy_conv(x)))
        p = self.policy_fc(torch.flatten(p, 1))

        # Value head
        v = F.relu(self.value_bn(self.value_conv(x)))
        v = F.relu(self.value_fc1(torch.flatten(v, 1)))
        v = torch.tanh(self.value_fc2(v))

        return p, v


# Instantiate the network on the selected device and print a layer summary.
model = ChessNet()
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.to(device)

# torchsummary's `summary` defaults to device="cuda"; pass the actual device
# type so this also works on CPU-only machines. input_size excludes the
# batch dimension (reported as -1 in the table).
summary(model, input_size=(19, 8, 8), device=device.type)

----------------------------------------------------------------
        Layer (type)               Output Shape         Param #
            Conv2d-1             [-1, 64, 8, 8]          11,008
       BatchNorm2d-2             [-1, 64, 8, 8]             128
            Conv2d-3             [-1, 64, 8, 8]          36,928
       BatchNorm2d-4             [-1, 64, 8, 8]             128
            Conv2d-5             [-1, 64, 8, 8]          36,928
       BatchNorm2d-6             [-1, 64, 8, 8]             128
     ResidualBlock-7             [-1, 64, 8, 8]               0
            Conv2d-8             [-1, 64, 8, 8]          36,928
       BatchNorm2d-9             [-1, 64, 8, 8]             128
           Conv2d-10             [-1, 64, 8, 8]          36,928
      BatchNorm2d-11             [-1, 64, 8, 8]             128
    ResidualBlock-12             [-1, 64, 8, 8]               0
           Conv2d-13             [-1, 64, 8, 8]          36,928
      BatchNorm2d-14             [-1, 6

# MCTS Search

In [None]:
import numpy as np
import math
import chess


class MCTSNode:
    """One node of the Monte Carlo search tree.

    Attributes:
        board: position at this node.
        parent: parent node, or None for the root.
        move: move that led from ``parent`` to this node (None at the root).
        children: dict mapping a move to the child MCTSNode it leads to.
        visit_count: number of backups that passed through this node.
        total_value: sum of the values backed up through this node.
        prior: policy prior, assigned after construction by the expander.
    """

    def __init__(self, board: chess.Board, parent=None, move=None):
        self.board = board
        self.parent = parent
        self.move = move
        self.children = {}
        self.visit_count = 0
        self.total_value = 0.0
        self.prior = 0.0

    def is_expanded(self):
        """True once at least one child has been attached."""
        return bool(self.children)

    def value(self):
        """Mean backed-up value; 0 for a node that was never visited."""
        return self.total_value / self.visit_count if self.visit_count else 0

def softmax_temperature(logits, temperature=1.0):
    """Return a temperature-scaled softmax of ``logits`` as float32 probs.

    Lower temperatures sharpen the distribution and higher ones flatten it.
    The maximum is subtracted before exponentiating for numerical stability.

    Raises:
        ValueError: if ``temperature`` is not strictly positive (0 would
            divide by zero and produce NaNs), or if ``logits`` is empty.
    """
    if temperature <= 0:
        raise ValueError(f"temperature must be > 0, got {temperature}")
    scaled = np.asarray(logits, dtype=np.float32) / temperature
    if scaled.size == 0:
        raise ValueError("logits must be non-empty")
    exp_logits = np.exp(scaled - np.max(scaled))
    return exp_logits / np.sum(exp_logits)

def select_child(node, c_puct=1.0):
    """Pick the child with the highest PUCT score from ``node``'s viewpoint.

    backpropagate adds the leaf evaluation at the leaf and flips its sign at
    each step up. Each node's mean value is therefore relative to the side
    to move at that node, i.e. a child's value is from the opponent's point
    of view. It must be negated to get the parent's Q.

    Returns:
        ``(move, child)`` for the best child, or ``(None, None)`` when
        ``node`` has no children.
    """
    best_score = -float('inf')
    best_move = None
    best_child = None

    # sqrt(0) would zero the exploration term on a freshly expanded,
    # never-visited node, making priors irrelevant; clamp N to at least 1.
    exploration_scale = c_puct * math.sqrt(max(node.visit_count, 1))

    for move, child in node.children.items():
        q = -child.value()
        u = exploration_scale * child.prior / (1 + child.visit_count)
        score = q + u
        if score > best_score:
            best_score = score
            best_move = move
            best_child = child
    return best_move, best_child

def expand_node(node, policy_logits, legal_moves, policy_mapping):
    """Attach one child per encodable legal move, with policy priors.

    The softmax spreads probability over every policy slot, including
    illegal moves. The priors are renormalised over the legal moves that
    ``policy_mapping`` covers, so they sum to 1 at each node. Legal moves
    missing from ``policy_mapping`` are skipped.

    Args:
        node: MCTSNode to expand; its ``board`` is copied for each child.
        policy_logits: raw policy logits indexed like ``policy_mapping``.
        legal_moves: iterable of legal moves in ``node.board``.
        policy_mapping: dict from a move's UCI string to its policy index.
    """
    probs = softmax_temperature(policy_logits)

    encodable = []
    for move in legal_moves:
        index = policy_mapping.get(move.uci())
        if index is not None:
            encodable.append((move, index))
    if not encodable:
        return

    legal_mass = float(sum(probs[index] for _, index in encodable))
    for move, index in encodable:
        # Fall back to a uniform prior if the legal moves received no mass.
        if legal_mass > 0:
            prior = probs[index] / legal_mass
        else:
            prior = 1.0 / len(encodable)
        next_board = node.board.copy()
        next_board.push(move)
        child = MCTSNode(next_board, parent=node, move=move)
        child.prior = prior
        node.children[move] = child

def backpropagate(node, value):
    """Record one visit on ``node`` and every ancestor up to the root.

    ``value`` is added at ``node`` and its sign alternates at each step up,
    because consecutive plies belong to opposing players.
    """
    current, signed_value = node, value
    while current:
        current.visit_count += 1
        current.total_value += signed_value
        current = current.parent
        signed_value = -signed_value  # opponent’s perspective

def _evaluate_position(model, encoder, board):
    """Run the network on one position and return (policy_logits, value).

    A PyTorch module (e.g. ChessNet) is run in eval mode under no_grad, on
    the device holding its parameters. Any other object is assumed to expose
    a Keras-style ``predict(batch, verbose=0)`` returning (policy, value).
    """
    batch = np.expand_dims(encoder(board), 0)  # add a leading batch dim of 1
    if isinstance(model, torch.nn.Module):
        try:
            model_device = next(model.parameters()).device
        except StopIteration:
            model_device = torch.device("cpu")
        model.eval()
        with torch.no_grad():
            policy, value = model(
                torch.as_tensor(batch, dtype=torch.float32, device=model_device))
        return policy[0].cpu().numpy(), float(value[0][0])
    policy, value = model.predict(batch, verbose=0)
    return policy[0], float(value[0][0])


def _terminal_value(board):
    """Exact value of a finished game for the side to move.

    Returns -1 if that side is checkmated. Every other ending reported by
    ``is_game_over()`` is a draw and scores 0.
    """
    return -1.0 if board.is_checkmate() else 0.0


def mcts_search(board, model, encoder, policy_mapping, num_simulations=800):
    """Run PUCT Monte Carlo tree search from ``board`` and return the root.

    Args:
        board: position to search from.
        model: policy/value network, either a torch module such as ChessNet
            or an object with a Keras-style ``predict``.
        encoder: callable turning a board into the network input array
            (board_encoder yields a (19, 8, 8) array).
        policy_mapping: dict from a move's UCI string to its policy index.
        num_simulations: number of select / evaluate / backup iterations.

    Returns:
        The root MCTSNode; its children's visit counts form the search policy.
    """
    root = MCTSNode(board)
    policy_logits, _ = _evaluate_position(model, encoder, board)
    expand_node(root, policy_logits, list(board.legal_moves), policy_mapping)

    for _ in range(num_simulations):
        node = root

        # Selection: descend until reaching an unexpanded node.
        while node.is_expanded():
            _, node = select_child(node)

        # Evaluation: finished games are scored exactly; otherwise ask the
        # network and expand the leaf.
        if node.board.is_game_over():
            value = _terminal_value(node.board)
        else:
            policy_logits, value = _evaluate_position(model, encoder, node.board)
            expand_node(node, policy_logits, list(node.board.legal_moves),
                        policy_mapping)

        # Backup
        backpropagate(node, value)

    return root

def choose_action(root, temperature=1.0):
    """Pick a move from the root's children using their visit counts.

    With ``temperature == 0`` the most-visited move is returned. Otherwise
    the move is sampled with probability proportional to
    ``visit_count ** (1 / temperature)``, the AlphaZero rule. A softmax over
    raw counts would make play almost deterministic for any realistic
    number of simulations.

    Raises:
        ValueError: if ``root`` has no children, or ``temperature`` is
            negative.
    """
    if temperature < 0:
        raise ValueError(f"temperature must be >= 0, got {temperature}")
    moves = list(root.children.keys())
    if not moves:
        raise ValueError("Cannot choose an action: root has no children.")
    visit_counts = np.array(
        [child.visit_count for child in root.children.values()],
        dtype=np.float64)

    if temperature == 0:
        return moves[int(np.argmax(visit_counts))]

    peak = visit_counts.max()
    if peak <= 0:
        probs = np.full(len(moves), 1.0 / len(moves))
    else:
        # Normalise by the max first so small temperatures cannot overflow.
        weights = (visit_counts / peak) ** (1.0 / temperature)
        probs = weights / weights.sum()
    return moves[np.random.choice(len(moves), p=probs)]

# Self Play

In [None]:
def play_self_play_game(model, encoder, policy_mapping, num_simulations=100):
    """Play one game against itself with MCTS and collect training samples.

    At every ply it runs ``mcts_search`` and samples a move at temperature
    1. It records the encoded position and the normalised root visit
    distribution, scattered into a vector of length ``len(policy_mapping)``.

    Returns:
        ``(states, policies, outcomes)`` lists of equal length. Outcomes are
        +1 / -1 / 0 from the final result, negated on odd plies, so each
        entry is relative to the side that moved from that position
        (starting with white).
    """
    board = chess.Board()
    states, policies = [], []

    while not board.is_game_over():
        root = mcts_search(board, model, encoder,
                           policy_mapping, num_simulations)
        chosen = choose_action(root, temperature=1.0)

        encoded = encoder(board)
        total_visits = np.array(
            [child.visit_count for child in root.children.values()]).sum()
        target = np.zeros(len(policy_mapping))
        for mv, child in root.children.items():
            target[policy_mapping[mv.uci()]] = child.visit_count / total_visits

        states.append(encoded)
        policies.append(target)
        board.push(chosen)

    z = {"1-0": 1, "0-1": -1}.get(board.result(), 0)  # draws score 0
    outcomes = [z if ply % 2 == 0 else -z for ply in range(len(states))]

    return states, policies, outcomes

# Collect Data

In [None]:
def generate_training_data(model, encoder, policy_mapping, num_games=10,
                           num_simulations=100):
    """Play ``num_games`` self-play games and stack their samples.

    Args:
        num_simulations: MCTS simulations per move, forwarded to
            play_self_play_game (100 matches its default).

    Returns:
        ``(states, policies, values)`` as float32 NumPy arrays, one row per
        recorded position. Values are cast to float so they can be used
        directly as an MSE regression target; integer outcomes would become
        a Long tensor that ``MSELoss`` rejects.
    """
    all_states = []
    all_policies = []
    all_values = []

    for _ in range(num_games):
        states, policies, values = play_self_play_game(
            model, encoder, policy_mapping, num_simulations)
        all_states.extend(states)
        all_policies.extend(policies)
        all_values.extend(values)

    return (np.array(all_states, dtype=np.float32),
            np.array(all_policies, dtype=np.float32),
            np.array(all_values, dtype=np.float32))

# Train the model

In [None]:
def train_model(model, X, P, Z, epochs=10, batch_size=64, device="cpu"):
    """Fit a policy/value network on self-play samples with Adam (lr 1e-3).

    Args:
        model: network whose forward returns ``(policy_logits, value)``.
        X: encoded positions, first dim N (tensor or NumPy array).
        P: target move distributions, shape (N, num_policy_logits).
        Z: value targets (game outcomes), N entries.
        epochs, batch_size, device: usual training-loop settings.

    The loss is cross-entropy between the policy logits and the full target
    distribution (not just its argmax), plus MSE on the value. Inputs are
    converted to float32 tensors, so NumPy arrays and integer outcomes from
    generate_training_data are accepted.

    Raises:
        ValueError: if there are no samples but ``epochs > 0``.
    """
    model.train()
    model.to(device)
    optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
    loss_fn_policy = torch.nn.CrossEntropyLoss()
    loss_fn_value = torch.nn.MSELoss()

    X = torch.as_tensor(X, dtype=torch.float32)
    P = torch.as_tensor(P, dtype=torch.float32)
    Z = torch.as_tensor(Z, dtype=torch.float32).reshape(-1)

    dataset = torch.utils.data.TensorDataset(X, P, Z)
    loader = torch.utils.data.DataLoader(
        dataset, batch_size=batch_size, shuffle=True)
    if epochs > 0 and len(loader) == 0:
        raise ValueError("train_model needs at least one training sample.")

    for epoch in range(epochs):
        total_loss = 0.0
        for xb, pb, zb in loader:
            xb, pb, zb = xb.to(device), pb.to(device), zb.to(device)

            pred_policy, pred_value = model(xb)
            # Probability targets: learn the whole visit distribution.
            loss_p = loss_fn_policy(pred_policy, pb)
            # reshape(-1) (not squeeze) keeps shape (B,) even when B == 1.
            loss_v = loss_fn_value(pred_value.reshape(-1), zb)
            loss = loss_p + loss_v

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            total_loss += loss.item()

        print(f"Epoch {epoch+1}: Loss = {total_loss / len(loader):.4f}")

# Prediction

In [None]:
@torch.no_grad()
def predict(board, model, policy_index_map, device="cpu"):
    """Evaluate a single position with the network (no gradients).

    Args:
        board: position to evaluate; encoded with board_encoder.
        model: network returning ``(policy_logits, value)``; switched to
            eval mode as a side effect.
        policy_index_map: unused; kept for interface compatibility.
        device: device the input is moved to (should match the model's).

    Returns:
        ``(probs, value)``: softmax over all policy logits as a 1-D CPU
        tensor, and the scalar value as a Python float.
    """
    model.eval()
    # board_encoder returns a NumPy array, so convert before adding the batch dim.
    x = torch.from_numpy(board_encoder(board)).unsqueeze(0).to(device)
    logits, value = model(x)
    probs = torch.softmax(logits, dim=1).squeeze(0)
    return probs.cpu(), value.item()

# Start