In [None]:
# ✅ 1. Install PyTorch (if needed)
!pip install torch

# ✅ 2. Import dependencies
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import random
import numpy as np
from collections import deque

# ✅ 3. Define constants
BOARD_SIZE = 7               # board is BOARD_SIZE x BOARD_SIZE
GAMES_PER_ITERATION = 5      # self-play games per training iteration (small for testing; raise for real runs)
REPLAY_BUFFER_SIZE = 50000   # max samples kept; the oldest are evicted first once full
BATCH_SIZE = 64
EPOCHS = 2                   # passes over the replay buffer per training iteration
NUM_ITERATIONS = 10000       # full-training setting; lower it (e.g. 10) for a quick smoke test
MODEL_PATH = "side_stacker_model.pth"
SEED = 42                    # seeds self-play sampling, buffer shuffling and weight init

# Seed every RNG the pipeline uses so runs are reproducible
# (CUDA kernels may still introduce some nondeterminism on GPU).
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")


In [2]:
# ✅ 4. Model definition
class SideStackerNet(nn.Module):
    """Small convolutional policy/value network for Side Stacker.

    Expects a float tensor of shape (batch, 2, board_size, board_size), e.g. as
    built by board_to_tensor. Two 3x3 conv layers (padding keeps the spatial
    size) feed two linear heads:

    * policy: (batch, 2 * board_size) unnormalised logits, one per (row, side).
    * value:  (batch, 1) squashed into (-1, 1) with tanh.
    """

    def __init__(self, board_size=7):
        super().__init__()
        hidden_channels = 64
        flat_features = hidden_channels * board_size * board_size
        # Attribute names double as state_dict keys of saved checkpoints, so
        # they (and their creation order) must stay as they are.
        self.conv1 = nn.Conv2d(2, hidden_channels, kernel_size=3, padding=1)
        self.conv2 = nn.Conv2d(hidden_channels, hidden_channels, kernel_size=3, padding=1)
        self.policy_head = nn.Linear(flat_features, 2 * board_size)
        self.value_head = nn.Linear(flat_features, 1)

    def forward(self, x):
        features = F.relu(self.conv2(F.relu(self.conv1(x))))
        features = features.view(features.shape[0], -1)
        return self.policy_head(features), self.value_head(features).tanh()


In [3]:
# ✅ 5. Game utils
def board_to_tensor(board, player):
    """Encode `board` as two binary planes from `player`'s point of view.

    Plane 0 is 1 where a cell equals `player`, plane 1 is 1 where a cell equals
    `-player`; everything else is 0 in both planes.

    Returns:
        A float32 tensor of shape (1, 2, rows, cols) moved to the global `device`.
    """
    def plane(owner):
        return [[int(value == owner) for value in cells] for cells in board]

    planes = [plane(player), plane(-player)]
    batched = torch.tensor(planes, dtype=torch.float32).unsqueeze(0)
    return batched.to(device)

def get_valid_moves(board):
    """List legal moves on `board` as (row_index, side) pairs, side in {'L', 'R'}.

    A row is playable while it holds an empty cell (0). 'L' is offered for every
    playable row. 'R' is offered as well unless the leftmost and rightmost empty
    cells are the same cell, so a single remaining slot is listed only once.
    """
    moves = []
    for row_index, cells in enumerate(board):
        try:
            first_empty = cells.index(0)
        except ValueError:
            continue  # row is full
        moves.append((row_index, 'L'))
        last_empty = len(cells) - 1 - cells[::-1].index(0)
        if last_empty != first_empty:
            moves.append((row_index, 'R'))
    return moves

def apply_move(board, row, direction, player):
    """Return a copy of `board` with `player`'s piece slid into `row`.

    'L' fills the first empty (0) cell scanning from the left; any other
    direction fills the first empty cell scanning from the right. The row's own
    length is scanned (not the global BOARD_SIZE), matching get_valid_moves.
    The input board is never modified; if the row has no empty cell the copy is
    returned unchanged.
    """
    new_board = [list(r) for r in board]
    cells = new_board[row]
    columns = range(len(cells))
    if direction != 'L':
        columns = reversed(columns)
    for col in columns:
        if cells[col] == 0:
            cells[col] = player
            break
    return new_board

def check_winner(board, win_length=4):
    """Return 1 or -1 if that player has `win_length` aligned pieces, else 0.

    Rows, columns, and both diagonal directions are scanned. Cells are assumed
    to hold 1, -1 or 0 (empty): a window wins when its sum is +/- win_length.
    If several lines contain a winning window, the first one found (rows, then
    columns, then diagonals) decides the result.

    Args:
        board: rectangular grid given as a sequence of equal-length rows.
        win_length: number of aligned pieces needed to win (default 4).
    """
    n_rows = len(board)
    n_cols = len(board[0]) if n_rows else 0

    def check_line(line):
        for start in range(len(line) - win_length + 1):
            total = sum(line[start:start + win_length])
            if total == win_length:
                return 1
            if total == -win_length:
                return -1
        return 0

    for row in board:
        if (res := check_line(row)) != 0:
            return res
    for col in zip(*board):
        if (res := check_line(col)) != 0:
            return res
    # Diagonal k: the down-right one holds cells with j - i == (n_cols - 1) - k,
    # the down-left one holds cells with i + j == k. Every row is tried and the
    # column bounds-checked, so no in-bounds cell of either diagonal is dropped
    # (the previous anti-diagonal row range truncated many of them).
    for k in range(n_rows + n_cols - 1):
        offset = k - (n_cols - 1)
        diag = [board[i][i - offset] for i in range(n_rows) if 0 <= i - offset < n_cols]
        anti = [board[i][k - i] for i in range(n_rows) if 0 <= k - i < n_cols]
        if (res := check_line(diag)) != 0:
            return res
        if (res := check_line(anti)) != 0:
            return res
    return 0


In [10]:
# ✅ 6. Self-play training logic
def _sample_legal_move(logits, legal_indices):
    """Sample one policy index from softmax(logits) restricted to `legal_indices`.

    Softmax over only the legal logits equals renormalising the full softmax
    over the legal moves, but it is computed in float64 so the probabilities
    passed to np.random.choice sum to 1 within its tolerance. Falls back to a
    uniform choice if the result is not finite (e.g. NaN/inf logits).
    """
    probs = torch.softmax(logits[legal_indices].double(), dim=0).cpu().numpy()
    if not np.all(np.isfinite(probs)):
        probs = np.full(len(legal_indices), 1.0 / len(legal_indices))
    return legal_indices[np.random.choice(len(legal_indices), p=probs)]


def self_play_game(model):
    """Play one game of `model` against itself and return its training samples.

    Moves are sampled from the model's policy restricted to legal moves. The
    game ends when check_winner reports a winner or no legal move remains
    (draw).

    Returns:
        A list with one (state, policy_target, value_target) tuple per move:
        state is board_to_tensor(board, mover) without the batch dim,
        policy_target is a float32 one-hot over the 2 * BOARD_SIZE actions
        marking the move played, and value_target is a 1-element tensor that is
        +1.0 if the mover went on to win, -1.0 if they lost, 0.0 for a draw.
    """
    # Fixed action order of the policy head: index 2*row -> (row, 'L'), 2*row + 1 -> (row, 'R').
    move_map = [(row, d) for row in range(BOARD_SIZE) for d in ('L', 'R')]
    board = [[0] * BOARD_SIZE for _ in range(BOARD_SIZE)]
    player = 1
    winner = 0  # stays 0 if the board fills up without a winner
    history = []

    while True:
        valid_moves = set(get_valid_moves(board))
        legal_indices = [i for i, move in enumerate(move_map) if move in valid_moves]
        if not legal_indices:
            break

        state_tensor = board_to_tensor(board, player)
        with torch.no_grad():
            logits, _ = model(state_tensor)
        chosen_index = _sample_legal_move(logits[0], legal_indices)

        policy_tensor = torch.zeros(len(move_map), dtype=torch.float32)
        policy_tensor[chosen_index] = 1.0
        history.append((state_tensor.squeeze(0), policy_tensor, player))

        row, direction = move_map[chosen_index]
        board = apply_move(board, row, direction, player)
        winner = check_winner(board)
        if winner != 0:
            break
        player *= -1

    if winner != 0:
        return [(s, p, torch.tensor([1.0 if pl == winner else -1.0])) for s, p, pl in history]
    return [(s, p, torch.tensor([0.0])) for s, p, pl in history]


In [11]:
# ✅ 7. Train loop
def train_model(model, replay_buffer, optimizer=None):
    """Fit `model` on replay samples for EPOCHS passes of BATCH_SIZE mini-batches.

    Each sample is a (state, policy_target, value_target) tuple of tensors. The
    loss is cross-entropy between predicted policy logits and the policy target
    (class probabilities) plus MSE between predicted value and value target.
    Trailing samples that do not fill a whole batch are skipped each epoch.

    Args:
        model: network returning (policy_logits, value) for a batch of states.
        replay_buffer: any sized iterable of samples (list, deque, ...). It is
            copied before shuffling, so the caller's container is not reordered.
        optimizer: optional optimizer to reuse across calls so Adam's moment
            estimates persist; a fresh Adam(lr=0.001) is created when omitted.

    Returns:
        The optimizer that was used, so callers can pass it back next time.
    """
    if optimizer is None:
        optimizer = optim.Adam(model.parameters(), lr=0.001)
    samples = list(replay_buffer)
    model.train()
    for _epoch in range(EPOCHS):
        random.shuffle(samples)
        # Only start offsets that leave a full batch (incomplete tail is dropped).
        for start in range(0, len(samples) - BATCH_SIZE + 1, BATCH_SIZE):
            batch = samples[start:start + BATCH_SIZE]
            states, policies, values = zip(*batch)
            states = torch.stack(states).to(device)
            policies = torch.stack(policies).to(device)
            values = torch.stack(values).to(device)

            pred_policies, pred_values = model(states)
            # reshape(-1) rather than squeeze(): keeps a 1-D shape even for a batch of 1.
            loss = F.cross_entropy(pred_policies, policies) + F.mse_loss(
                pred_values.reshape(-1), values.reshape(-1)
            )
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
    return optimizer


In [None]:
# ✅ 8. Main run function
def run_training():
    """Run the self-play / training loop, checkpointing after every iteration.

    Each of NUM_ITERATIONS iterations plays GAMES_PER_ITERATION self-play games,
    appends their samples to a bounded replay buffer (oldest samples drop out
    once REPLAY_BUFFER_SIZE is reached), trains on a list copy of the buffer,
    and overwrites MODEL_PATH with the current weights.
    """
    net = SideStackerNet().to(device)
    buffer = deque(maxlen=REPLAY_BUFFER_SIZE)

    for iteration in range(1, NUM_ITERATIONS + 1):
        print(f"--- Iteration {iteration}/{NUM_ITERATIONS} ---")
        for _game in range(GAMES_PER_ITERATION):
            buffer.extend(self_play_game(net))

        print(f"Training on {len(buffer)} samples...")
        snapshot = list(buffer)
        train_model(net, snapshot)

        # Checkpoint every iteration so an interrupted run keeps its latest weights.
        torch.save(net.state_dict(), MODEL_PATH)

    print("✅ Training complete.")


run_training()


In [None]:
# ✅ 9. Download the trained model
# google.colab is only available inside Google Colab; run this cell there.
from google.colab import files
# Use MODEL_PATH (the file run_training saves to) instead of a duplicated literal.
files.download(MODEL_PATH)
