In [1]:
import numpy as np
import numba as nb

In [2]:
# Notebook-wide defaults.
# NOTE(review): SHAPE and K are not referenced by any visible cell (the later
# run() call passes a literal shape and k) — confirm whether they are still needed.
SHAPE = (7, 8)  # presumably (rows, cols), matching how create_grid's shape is indexed — TODO confirm
K = 4  # presumably the line length needed to win — TODO confirm
PLAYERS = [1, 2]  # tokens written into the grid; 0 is treated as an empty cell by the checks below

In [3]:
@nb.njit(nb.int8[:,:](nb.types.UniTuple(nb.int8, 2)))
def create_grid(shape: tuple[int, int]) -> np.ndarray:
    """Allocate an all-zero int8 board.

    Args:
        shape: Two-element tuple giving the dimensions of the board.

    Returns:
        A new 2-D int8 array of the requested shape filled with zeros.
    """
    empty_board = np.zeros(shape, np.int8)
    return empty_board

In [4]:
from typing import NamedTuple

# Immutable record with the fields `grid`, `player` and `time`.
# Declared with the (name, type) list form: passing the fields as keyword
# arguments to NamedTuple() is deprecated since Python 3.13.
State = NamedTuple('State', [('grid', np.ndarray), ('player', int), ('time', int)])


In [5]:
@nb.njit(nb.int8[:,:](nb.int8[:,:], nb.int8, nb.int8))
def transition(grid: np.ndarray, player: int, action: int) -> np.ndarray:
    """Drop `player`'s token into column `action`.

    The token lands in the cell directly above the first non-zero cell found
    when scanning the column from row 0 downwards, or in the last row when the
    column is empty.

    Args:
        grid: 2-D int8 board; it is modified in place.
        player: Value written into the landing cell.
        action: Index of the column to play.

    Returns:
        The same `grid` array that was passed in, after the token is placed.

    Raises:
        ValueError: If the column has no free cell (its row-0 cell is already
            occupied, or the grid has no rows).
    """
    n_rows = grid.shape[0]
    # Default landing row for an empty column is the last row
    landing_row = n_rows - 1
    for row in range(n_rows):
        if grid[row, action] != 0:
            landing_row = row - 1
            break
    # A full column used to yield index -1, which wraps around and silently
    # overwrites the token in the last row; reject the move instead.
    if landing_row < 0:
        raise ValueError("cannot place a token in a full column")
    grid[landing_row, action] = player
    return grid

In [6]:
@nb.njit(nb.int8[:](nb.int8[:,:]))
def generate_actions(grid: np.ndarray) -> np.ndarray:
    """Build a per-column mask of playable columns.

    Args:
        grid: 2-D int8 board where 0 marks an empty cell.

    Returns:
        An int8 vector with one entry per column: 1 when the column still
        contains at least one empty cell, 0 otherwise.
    """
    n_cols = grid.shape[1]
    playable = np.zeros(n_cols, np.int8)
    for col in range(n_cols):
        # A column is playable as soon as any of its cells is still empty
        if np.any(grid[:, col] == 0):
            playable[col] = 1
    return playable

In [7]:
@nb.njit(nb.boolean(nb.int8[:]))
def check_line(line: np.ndarray) -> bool:
    """Tell whether every cell of `line` holds the same non-zero value.

    Args:
        line: 1-D int8 vector of cell values.

    Returns:
        True when the first cell is non-zero and the minimum equals the
        maximum (i.e. all cells are equal); False otherwise.
    """
    # An empty first cell can never be part of a completed line
    if line[0] == 0:
        return False
    lowest = np.min(line)
    highest = np.max(line)
    return lowest == highest

@nb.njit(nb.boolean(nb.int8[:,:], nb.int8))
def check_horizontal_lines(grid: np.ndarray, k: int) -> bool:
    """Search every row for `k` consecutive equal, non-zero cells.

    Args:
        grid: 2-D int8 board.
        k: Length of the line to look for.

    Returns:
        True as soon as one length-`k` window of a row satisfies
        `check_line`; False when no row contains such a window.
    """
    n_rows, n_cols = grid.shape
    # Number of distinct start columns for a window of length k
    n_starts = n_cols - k + 1
    for row in range(n_rows):
        cells = grid[row, :]
        assert cells.ndim == 1 and cells.shape[0] == grid.shape[1]
        for start in range(n_starts):
            # Window of k cells beginning at column `start`
            window = cells[start : start + k]
            assert window.shape[0] == k
            if check_line(window):
                return True
    return False

@nb.njit(nb.boolean(nb.int8[:,:], nb.int8))
def check_vertical_lines(grid: np.ndarray, k: int) -> bool:
    """Search every column for `k` consecutive equal, non-zero cells.

    Args:
        grid: 2-D int8 board.
        k: Length of the line to look for.

    Returns:
        True as soon as one length-`k` window of a column satisfies
        `check_line`; False when no column contains such a window.
    """
    n_rows, n_cols = grid.shape
    # Number of distinct start rows for a window of length k
    n_starts = n_rows - k + 1
    for col in range(n_cols):
        cells = grid[:, col]
        assert cells.ndim == 1 and cells.shape[0] == grid.shape[0]
        for start in range(n_starts):
            # Window of k cells beginning at row `start`
            window = cells[start : start + k]
            assert window.shape[0] == k
            if check_line(window):
                return True
    return False

@nb.njit(nb.boolean(nb.int8[:,:], nb.int8))
def check_diagonal_lines(grid: np.ndarray, k: int) -> bool:
    """Search the board for a diagonal of `k` equal, non-zero cells.

    Every k x k window of the grid is inspected; both its main diagonal and
    its anti-diagonal are tested with `check_line`.

    Args:
        grid: 2-D int8 board.
        k: Length of the line to look for.

    Returns:
        True as soon as either diagonal of some window is complete,
        False otherwise.
    """
    n_rows, n_cols = grid.shape
    # Window origins must lie inside the grid and leave room for k cells
    row_starts = min(n_rows, n_rows - k + 1)
    col_starts = min(n_cols, n_cols - k + 1)
    for top in range(row_starts):
        for left in range(col_starts):
            window = grid[top : top + k, left : left + k]
            assert window.shape[0] == window.shape[1] == k
            # Top-left to bottom-right, then top-right to bottom-left
            main_diagonal = np.diag(window)
            anti_diagonal = np.diag(np.fliplr(window))
            if check_line(main_diagonal) or check_line(anti_diagonal):
                return True
    return False

@nb.njit(nb.boolean(nb.int8[:,:]))
def check_tie(grid: np.ndarray) -> bool:
    """Tell whether the board has no empty (zero) cell left.

    Args:
        grid: 2-D int8 board where 0 marks an empty cell.

    Returns:
        True when every cell is non-zero, False otherwise. Winning lines are
        not examined here.
    """
    has_empty_cell = np.any(grid == 0)
    return not has_empty_cell

@nb.njit(nb.boolean(nb.int8[:,:], nb.int8))
def terminal(grid: np.ndarray, k: int) -> bool:
    """Tell whether the game is over for the given board.

    Args:
        grid: 2-D int8 board.
        k: Length of a winning line.

    Returns:
        True when a horizontal, vertical or diagonal line of length `k` is
        complete, or when the board is full; False otherwise.
    """
    # Checks run in the same order as a short-circuiting `or` chain
    if check_horizontal_lines(grid, k):
        return True
    if check_vertical_lines(grid, k):
        return True
    if check_diagonal_lines(grid, k):
        return True
    return check_tie(grid)

In [8]:
@nb.njit
def run(shape: tuple[int, int], k: int, players: list[int]) -> tuple[np.ndarray, int]:
    """Play one game with uniformly random moves until it is terminal.

    Args:
        shape: Dimensions of the board handed to `create_grid`.
        k: Length of a winning line, handed to `terminal`.
        players: Tokens of the two players; they alternate, starting with
            `players[0]`.

    Returns:
        A `(grid, time)` tuple: the final board and the number of moves
        that were played.
    """
    grid = create_grid(shape)
    time = 0

    while not terminal(grid, k):
        # Players alternate every move, starting with players[0]
        player = players[time % 2]
        # Mask of columns that still have a free cell
        actions = generate_actions(grid)
        # Pick one of the playable column indices uniformly at random
        action = np.random.choice(np.argwhere(actions == 1)[:, 0])
        grid = transition(grid, player, action)
        time += 1

    return grid, time

In [None]:
from typing import TypedDict

# Shape = tuple[int, int]
# Config = TypedDict('Config', {'shape': Shape, 'k': int, 'players': list[int]})

In [3]:
from typing import Optional
from numba.experimental import jitclass
from connectx.types import Config, Grid, Info, Action, Actions, State
import connectx.functional as cxf
from connectx.renderer import terminal_render as render

# List of (attribute name, type) pairs, the form a numba jitclass spec takes.
# NOTE(review): nothing in the visible code applies this spec — Game below is a
# plain Python class — and the names here lack the leading underscore used by
# Game's attributes (`_config`, `_grid`, ...). Confirm whether it is still needed.
spec = [
    ("config", Config),
    ("grid", Grid),
    ("info", Info),
    ("actions", Actions)
]

class Game():
    """Stateful wrapper around the functional `cxf` game helpers.

    The instance keeps the current grid, the bookkeeping `info` mapping
    (`"active"`: index of the player to move, `"time"`: number of moves played)
    and the action mask returned by `cxf.generate_actions`.
    """

    def __init__(self, config: Config) -> None:
        """Store the configuration; call `start` before playing.

        Args:
            config: Mapping read through the keys `"shape"`, `"k"` and
                `"players"`.
        """
        self._config: Config = config
        self._grid: Grid = None
        self._actions: Actions = None
        self._info: Info = None

    def start(self, state: Optional[State] = None) -> tuple[State, Actions]:
        """Begin a new game or resume from `state`.

        Args:
            state: Optional mapping with the keys `"grid"` and `"info"`. When
                omitted, a fresh grid is created and the counters start at 0.

        Returns:
            The current state and the current action mask.
        """
        if state is None:
            self._grid = cxf.create_grid(self._config["shape"])
            self._info = {"active": 0, "time": 0}
        else:
            self._grid = state["grid"]
            # Copy the mapping: transition() updates it in place and must not
            # change the state object owned by the caller.
            self._info = dict(state["info"])

        self._actions = cxf.generate_actions(self._grid)

        return self.state, self.actions

    def transition(self, action: Action) -> tuple[State, Actions]:
        """Play `action` for the active player and advance the turn.

        Args:
            action: Move handed to `cxf.place_token`.

        Returns:
            The state and the action mask after the move.
        """
        players = self._config["players"]
        token = players[self._info["active"]]
        self._grid = cxf.place_token(self._grid, token, action)

        self._info["time"] += 1
        # Rotate through however many players are configured
        # (identical to `% 2` for the two-player setup).
        self._info["active"] = self._info["time"] % len(players)

        self._actions = cxf.generate_actions(self._grid)

        return self.state, self.actions

    def terminal(self) -> bool:
        """Return the result of `cxf.terminal` for the current grid and `k`."""
        return cxf.terminal(self._grid, self._config["k"])

    def render(self) -> None:
        """Render the grid, the move count and the active player's token."""
        active_token = self._config["players"][self._info["active"]]
        render(self._grid, self._info["time"], active_token)

    @property
    def state(self) -> State:
        """Mapping with the current `"grid"` and `"info"` objects (not copies)."""
        return {
            "grid": self._grid,
            "info": self._info
        }

    @property
    def actions(self) -> Actions:
        """Action mask computed by the most recent `start`/`transition` call."""
        return self._actions

In [5]:
# Smoke test: configure a game with shape (6, 7), k=4 and tokens 1 and 2
game = Game(Config(shape=(6, 7), k=4, players=[1, 2]))

# print(game._grid)
# Start from a fresh grid, then play action 0 for the first player
game.start()
game.transition(0)

array([[0, 0, 0, 0, 0, 0, 0],
       [0, 0, 0, 0, 0, 0, 0],
       [0, 0, 0, 0, 0, 0, 0],
       [0, 0, 0, 0, 0, 0, 0],
       [0, 0, 0, 0, 0, 0, 0],
       [1, 0, 0, 0, 0, 0, 0]], dtype=uint8)

In [None]:
@nb.vectorize
def parallel():
    """Placeholder with no parameters and no body yet.

    NOTE(review): the cell has no execution count, so it may never have been
    run; confirm whether this stub should be implemented or removed.
    """
    pass

In [27]:
def get_report(grid: np.ndarray, time: int, players: list[int]) -> dict[str, int]:
    """Summarise a finished game.

    Args:
        grid: Final board.
        time: Number of moves that were played.
        players: Tokens of the two players, in playing order.

    Returns:
        A dict with the keys `"winner"` and `"time"`. `"winner"` is 0 when the
        board is full, otherwise the token of the player who made the last
        move (`players[(time - 1) % 2]`).
    """
    # NOTE(review): check_tie only tests for a full board, so a game won by the
    # move that fills the last cell is reported as a tie — confirm intent.
    if check_tie(grid):
        winner = 0
    else:
        # The player who moved last is the one who completed a line
        winner = players[(time - 1) % 2]

    return {"winner": winner, "time": time}

In [33]:
# Play one random game on a (25, 28) board with k=10, then summarise it
grid, time = run((25, 28), 10, PLAYERS)
get_report(grid, time, PLAYERS)

{'winner': 1, 'time': 233}

In [77]:
%%timeit
# Timing of generate_actions on the `grid` left in the kernel.
# NOTE(review): an identical cell below reports a much slower time; presumably
# the two were run against different definitions (jitted vs. plain) — TODO confirm.
generate_actions(grid)

1.05 µs ± 7.68 ns per loop (mean ± std. dev. of 7 runs, 1,000,000 loops each)


In [74]:
%%timeit
# Timing of generate_actions on the `grid` left in the kernel.
# NOTE(review): duplicate of the cell above with a much slower result;
# presumably run against a different definition of the function — TODO confirm.
generate_actions(grid)

27.9 µs ± 289 ns per loop (mean ± std. dev. of 7 runs, 10,000 loops each)


In [67]:
%%timeit
# NOTE(review): `k` is not assigned in any visible cell (the config cell
# defines `K`), so this relies on leftover kernel state — confirm the intended value.
check_horizontal_lines(grid, k)

30 µs ± 694 ns per loop (mean ± std. dev. of 7 runs, 10,000 loops each)


In [65]:
%%timeit
# NOTE(review): `k` is not assigned in any visible cell (the config cell
# defines `K`), so this relies on leftover kernel state — confirm the intended value.
# Duplicate of the cell above with a much faster result; presumably run
# against a different definition of the function — TODO confirm.
check_horizontal_lines(grid, k)

353 ns ± 4.71 ns per loop (mean ± std. dev. of 7 runs, 1,000,000 loops each)
