# Pytorch scratchpad

In [1]:
from dataclasses import dataclass

from typing import Any, Optional, Sequence, Type, TypeVar
from typing import override
import os
import importlib

import torch

from rgi.core import base

In [2]:
from rgi.games import connect4

# reimport connect4
importlib.reload(connect4)
GameState = connect4.GameState
BatchGameState = connect4.BatchGameState
Action = connect4.Action
BatchAction = connect4.BatchAction
PlayerId = connect4.PlayerId


In [3]:
class Connect4Scratch(connect4.Connect4Game):

    @override
    def next_state(self, game_state: GameState, action: Action) -> GameState:
        """Find the lowest empty row in the selected column and return the updated game state."""
        if action not in self.legal_actions(game_state):
            raise ValueError(
                f"Invalid move: Invalid column '{action}' no in {self._all_column_ids}"
            )

        for row in range(1, self.height + 1):
            if (row, action) not in game_state.board:
                new_board = game_state.board.set((row, action), game_state.current_player)
                winner = self._calculate_winner(
                    new_board, action, row, game_state.current_player
                )
                next_player: PlayerId = 2 if game_state.current_player == 1 else 1
                return GameState(
                    board=new_board, current_player=next_player, winner=winner
                )

        raise ValueError("Invalid move: column is full")

    def _calculate_winner(
        self, board: torch.Tensor, col: int, row: int, player: PlayerId
    ) -> Optional[PlayerId]:
        """Check if the last move made at (row, col) by 'player' wins the game."""
        directions = [
            ((1, 0), (-1, 0)),  # Vertical
            ((0, 1), (0, -1)),  # Horizontal
            ((1, 1), (-1, -1)),  # Diagonal /
            ((1, -1), (-1, 1)),  # Diagonal \
        ]

        def count_in_direction(delta_row: int, delta_col: int) -> int:
            """Count consecutive pieces in one direction."""
            count = 0
            current_row, current_col = row + delta_row, col + delta_col
            while 1 <= current_row <= self.height and 1 <= current_col <= self.width:
                if board.get((current_row, current_col)) == player:
                    count += 1
                    current_row += delta_row
                    current_col += delta_col
                else:
                    break
            return count

        for (delta_row1, delta_col1), (delta_row2, delta_col2) in directions:
            consecutive_count = (
                count_in_direction(delta_row1, delta_col1)
                + count_in_direction(delta_row2, delta_col2)
                + 1  # Include the current piece
            )
            if consecutive_count >= self.connect_length:
                return player

        return None  # No winner yet

    @override
    def is_terminal(self, game_state: GameState) -> bool:
        if game_state.winner is not None:
            return True
        return all((self.height, col) in game_state.board for col in self._all_column_ids)

    @override
    def reward(self, game_state: GameState, player_id: PlayerId) -> float:
        if game_state.winner == player_id:
            return 1.0
        elif game_state.winner is not None:
            return -1.0
        return 0.0

    @override
    def pretty_str(self, game_state: GameState) -> str:
        return (
            "\n".join(
                "|"
                + "|".join(
                    " ●○"[game_state.board.get((row, col), 0)]
                    for col in self._all_column_ids
                )
                + "|"
                for row in reversed(
                    self._all_row_ids
                )  # Start from the top row and work down
            )
            + "\n+"
            + "-+" * self.width
        )

    def parse_board(self, board_str: str, current_player: PlayerId) -> GameState:
        """Parses the output of pretty_str into a GameState."""
        rows = board_str.strip().split("\n")[:-1]  # Skip the bottom border row
        board: Map[TPosition, int] = Map()
        for r, row in enumerate(reversed(rows), start=1):
            row_cells = row.strip().split("|")[1:-1]  # Extract cells between borders
            for c, cell in enumerate(row_cells, start=1):
                if cell == "●":
                    board = board.set((r, c), 1)  # Player 1
                elif cell == "○":
                    board = board.set((r, c), 2)  # Player 2
        return GameState(board=board, current_player=current_player)


In [4]:
g = Connect4Scratch()
initial_state = g.initial_state()

assert g.current_player_id(initial_state) == 1
assert g.all_actions() == (1, 2, 3, 4, 5, 6, 7)
# assert g.legal_actions(initial_state) == (2, 3, 4, 5, 7)

b = initial_state.board
b[0, 0] = 1
b[0, 5] = -1
# assert g.legal_actions(initial_state) == (2, 3, 4, 5, 7)

g.legal_actions(initial_state)

    # @override
    # def legal_actions(self, game_state: GameState) -> Sequence[Action]:
    #     return ((game_state.board[0] == 0).nonzero().squeeze()+1).tolist()

tuple(((b[0] == 0).nonzero().squeeze()+1).tolist())


(2, 3, 4, 5, 7)

In [10]:
import numpy as np
values_int = [1, 2, 3, 4, 5]
values_int_2d = [[1, 2, 3, 4, 5], [2,4,6,8,10]]
values_np = [np.array(values_int), np.array(values_int)*2]
values_torch = [torch.tensor(values_int), torch.tensor(values_int)*2]
values_float = [1.0, 2.0, 3.0, 4.0, 5.0]
values_float_2d = [[1.0, 2.0, 3.0, 4.0, 5.0], [2.0,4.0,6.0,8.0,10.0]]

print(torch.stack([torch.tensor(value) for value in values_int]))
print(torch.stack([torch.tensor(value) for value in values_int_2d]))
print(torch.stack([torch.tensor(value) for value in values_np]))
print(torch.stack([torch.tensor(value) for value in values_torch]))
print(torch.stack([torch.tensor(value) for value in values_float]))
print(torch.stack([torch.tensor(value) for value in values_float_2d]))


tensor([1, 2, 3, 4, 5])
tensor([[ 1,  2,  3,  4,  5],
        [ 2,  4,  6,  8, 10]])
tensor([[ 1,  2,  3,  4,  5],
        [ 2,  4,  6,  8, 10]])
tensor([[ 1,  2,  3,  4,  5],
        [ 2,  4,  6,  8, 10]])
tensor([1., 2., 3., 4., 5.])
tensor([[ 1.,  2.,  3.,  4.,  5.],
        [ 2.,  4.,  6.,  8., 10.]])


  print(torch.stack([torch.tensor(value) for value in values_torch]))


RuntimeError: stack expects each tensor to be equal size, but got [3] at entry 0 and [2] at entry 1

In [12]:
board_size = 8
board = torch.zeros((board_size, board_size), dtype=torch.int8)
mid = board_size // 2
# board[mid - 1 : mid + 1, mid - 1 : mid + 1] = torch.tensor([[2, 1], [1, 2]], dtype=torch.int8)
board[mid - 1, mid - 1] = 2
board[mid - 1, mid] = 1
board[mid, mid - 1] = 1
board[mid, mid] = 2

board

    # @override
    # def initial_state(self) -> OthelloState:
    #     board = torch.zeros((self.board_size, self.board_size), dtype=torch.int8)
    #     mid = self.board_size // 2
    #     board[mid - 1 : mid + 1, mid - 1 : mid + 1] = torch.tensor([[2, 1], [1, 2]], dtype=torch.int8)
    #     return OthelloState(board=board, current_player=1, is_terminal=False)

tensor([[0, 0, 0, 0, 0, 0, 0, 0],
        [0, 0, 0, 0, 0, 0, 0, 0],
        [0, 0, 0, 0, 0, 0, 0, 0],
        [0, 0, 0, 2, 1, 0, 0, 0],
        [0, 0, 0, 1, 2, 0, 0, 0],
        [0, 0, 0, 0, 0, 0, 0, 0],
        [0, 0, 0, 0, 0, 0, 0, 0],
        [0, 0, 0, 0, 0, 0, 0, 0]], dtype=torch.int8)