In [31]:
from random import choice
from math import inf
import time

 
class Game():
    """Two-player k-in-a-row game on a square board with an alpha-beta AI.

    The board (``self.current_state``) is a list of rows; every cell holds
    'X', 'O' or '.' (empty).  'X' always moves first and is the maximising
    player in the search, 'O' is the minimising player.

    Configuration can be passed to the constructor.  Any value left as
    ``None`` is looked up, at the moment it is needed, in the module-level
    names ``n`` (board side), ``k`` (marks in a row needed to win) and
    ``player`` (the mark controlled by the human), so a bare ``Game()`` keeps
    working in a notebook that defines those globals.
    """

    # Unit steps for the four line orientations: row, column, both diagonals.
    _DIRECTIONS = ((0, 1), (1, 0), (1, 1), (1, -1))

    # Search value of every finished game, always from X's point of view.
    _SCORES = {'X': 1, 'O': -1, '.': 0}

    def __init__(self, size=None, win_length=None, human=None):
        """Create a game and set up an empty board.

        Args:
            size: Board side length; ``None`` means "use the global ``n``".
            win_length: Marks in a row needed to win; ``None`` means "use the
                global ``k``".
            human: Mark ('X' or 'O') played from the keyboard; ``None`` means
                "use the global ``player``".  Any other value makes the
                computer play both sides.
        """
        self._size = size
        self._win_length = win_length
        self._human = human
        self.initialize_game()

    @property
    def size(self):
        """Board side length (falls back to the module-level ``n``)."""
        return n if self._size is None else self._size

    @property
    def win_length(self):
        """Marks in a row needed to win (falls back to the module-level ``k``)."""
        return k if self._win_length is None else self._win_length

    @property
    def human(self):
        """Mark played by the human (falls back to the module-level ``player``)."""
        return player if self._human is None else self._human

    def empty_board(self, n):
        """Return a new ``n`` x ``n`` board filled with '.'.

        Every row is a distinct list, so writing to one cell never affects
        another row.
        """
        return [['.'] * n for _ in range(n)]

    def initialize_game(self):
        """Reset the board to empty and give the first move to 'X'."""
        self.current_state = self.empty_board(self.size)

        # Player X always plays first
        self.player_turn = 'X'

    def draw_board(self):
        """Print the board as a grid of '| c |' cells followed by a blank line."""
        rule = '-----' * len(self.current_state)
        print(rule)
        for row in self.current_state:
            print(''.join(f'| {cell} |' for cell in row))
            print(rule)
        print()

    def empty_squares(self, brd):
        """Return the ``[row, col]`` pairs of all empty cells of ``brd``, row by row."""
        return [[i, j]
                for i, row in enumerate(brd)
                for j, cell in enumerate(row)
                if cell == '.']

    # Determines if the made move is a legal move
    def is_valid(self, px, py):
        """Return True when (px, py) is on the board and that cell is empty."""
        side = len(self.current_state)
        # The bounds test must come first: it guards the indexing that
        # follows, and rejects negative values that would otherwise wrap.
        return (0 <= px < side and 0 <= py < side
                and self.current_state[px][py] == '.')

    def _has_line(self, mark, length):
        """Return True if ``mark`` fills ``length`` consecutive cells of a line.

        Every cell holding ``mark`` is tried as the start of a run in each of
        the four orientations; a run is only inspected when its last cell is
        still on the board, so runs never wrap around an edge.
        """
        board = self.current_state
        side = len(board)
        reach = length - 1
        for row in range(side):
            for col in range(side):
                if board[row][col] != mark:
                    continue
                for d_row, d_col in self._DIRECTIONS:
                    last_row = row + d_row * reach
                    last_col = col + d_col * reach
                    if not (0 <= last_row < side and 0 <= last_col < side):
                        continue
                    if all(board[row + d_row * step][col + d_col * step] == mark
                           for step in range(1, length)):
                        return True
        return False

    # Checks if the game has ended and returns the winner in each case
    def is_terminal(self):
        """Classify the current position.

        Returns:
            'X' or 'O' when that mark has ``win_length`` in a row (in a row,
            a column or any diagonal long enough to hold the run), '.' when
            the board is full without a winner (draw), and None while the
            game is still in progress.  Should both marks have a line, which
            cannot happen through legal play, 'X' is reported.
        """
        length = self.win_length
        for mark in ('X', 'O'):
            if self._has_line(mark, length):
                return mark

        # No winner: the game goes on as long as any square is still empty.
        if any('.' in row for row in self.current_state):
            return None

        # It's a tie!
        return '.'

    # Player 'X' is max
    def max_alpha_beta(self, alpha, beta):
        """Evaluate the position with 'X' to move using alpha-beta pruning.

        Args:
            alpha: Best value the maximiser is already guaranteed.
            beta: Best value the minimiser is already guaranteed.

        Returns:
            ``(value, row, col)`` where value is 1 (X wins), 0 (draw) or
            -1 (O wins) and (row, col) is the first move reaching it.  For a
            finished position the coordinates are the placeholder (0, 0).
            The board is left exactly as it was found.
        """
        result = self.is_terminal()
        if result is not None:
            return (self._SCORES[result], 0, 0)

        # Start below the worst possible outcome so any move improves on it.
        maxv = -inf
        px = None
        py = None
        board = self.current_state

        for i, j in self.empty_squares(board):
            # Try the move, let Min answer, then take the move back.
            board[i][j] = 'X'
            v = self.min_alpha_beta(alpha, beta)[0]
            board[i][j] = '.'

            if v > maxv:
                maxv, px, py = v, i, j
                if maxv > alpha:
                    alpha = maxv

            # Min already has a better alternative elsewhere: stop searching.
            if maxv >= beta:
                break

        return (maxv, px, py)

    # Player 'O' is min
    def min_alpha_beta(self, alpha, beta):
        """Evaluate the position with 'O' to move using alpha-beta pruning.

        Mirror image of ``max_alpha_beta``: same arguments and the same
        ``(value, row, col)`` result, but the lowest value is selected.
        """
        result = self.is_terminal()
        if result is not None:
            return (self._SCORES[result], 0, 0)

        # Start above the best possible outcome so any move improves on it.
        minv = inf
        qx = None
        qy = None
        board = self.current_state

        for i, j in self.empty_squares(board):
            # Try the move, let Max answer, then take the move back.
            board[i][j] = 'O'
            v = self.max_alpha_beta(alpha, beta)[0]
            board[i][j] = '.'

            if v < minv:
                minv, qx, qy = v, i, j
                if minv < beta:
                    beta = minv

            # Max already has a better alternative elsewhere: stop searching.
            if minv <= alpha:
                break

        return (minv, qx, qy)

    def _ask_human_move(self):
        """Prompt until the human enters the coordinates of a legal move.

        Non-numeric input is treated like any other illegal move instead of
        aborting the game with a ValueError.
        """
        while True:
            try:
                row = int(input('Insert the X coordinate: '))
                col = int(input('Insert the Y coordinate: '))
            except ValueError:
                print('Invalid move! Try again.')
                continue

            if self.is_valid(row, col):
                return row, col
            print('Invalid move! Try again.')

    def _pick_ai_move(self, mark):
        """Search for the best move for ``mark``, report timing, return the move."""
        search = self.max_alpha_beta if mark == 'X' else self.min_alpha_beta
        start = time.perf_counter()
        _, row, col = search(-inf, inf)
        end = time.perf_counter()
        print('Evaluation time: {}s'.format(round(end - start, 4)))
        print('Recommended move: X = {}, Y = {}'.format(row, col))
        return row, col

    def play_game(self):
        """Run one game to completion on the console.

        Each turn the board is drawn, then the side to move either types its
        move (when it is the human's mark) or has it chosen by the search;
        the "recommended" move printed for the computer is the one it plays.
        When the game ends the outcome is printed and stored in
        ``self.result``, and the board is reset for the next game.
        """
        while True:
            self.draw_board()
            self.result = self.is_terminal()

            if self.result is not None:
                if self.result == '.':
                    print("Draw!")
                else:
                    print(f'The winner is {self.result}!')
                self.initialize_game()
                return

            mark = self.player_turn
            if self.human == mark:
                row, col = self._ask_human_move()
            else:
                row, col = self._pick_ai_move(mark)

            self.current_state[row][col] = mark
            self.player_turn = 'O' if mark == 'X' else 'X'
        
             

In [33]:
# Game settings read by the Game class: board side length, number of marks
# in a row needed to win, and the mark played by the human.
n, k, player = 4, 2, 'X'


def main():
    """Create a game with the settings above and play it to the end."""
    Game().play_game()


if __name__ == "__main__":
    main()

--------------------
| . || . || . || . |
--------------------
| . || . || . || . |
--------------------
| . || . || . || . |
--------------------
| . || . || . || . |
--------------------

Insert the X coordinate: 2
Insert the Y coordinate: 2
--------------------
| . || . || . || . |
--------------------
| . || . || . || . |
--------------------
| . || . || X || . |
--------------------
| . || . || . || . |
--------------------

Evaluation time: 0.9315s
Recommended move: X = 0, Y = 0
--------------------
| O || . || . || . |
--------------------
| . || . || . || . |
--------------------
| . || . || X || . |
--------------------
| . || . || . || . |
--------------------

Insert the X coordinate: 3
Insert the Y coordinate: 2
--------------------
| O || . || . || . |
--------------------
| . || . || . || . |
--------------------
| . || . || X || . |
--------------------
| . || . || X || . |
--------------------

The winner is X!
