<a href="https://colab.research.google.com/github/raghav-sanagavarapu/Connect4RL/blob/main/Connect4(TicTacToe_Style).ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
import numpy as np
import pickle

# Board height (number of rows) and width (number of columns) of the grid.
rows, columns = 6, 7

class State:
    """Connect-4 game environment for two players.

    The board is a float array of shape ``(rows, columns)`` holding ``1`` for
    player-1 pieces, ``-1`` for player-2 pieces and ``0`` for empty cells.
    A piece dropped into a column lands in its highest-indexed empty row, so
    row 0 is the top row as printed by ``showBoard``.

    Players are expected to provide ``chooseAction(positions, board, symbol)``,
    ``addState(hash)``, ``feedReward(reward)`` and ``reset()``; ``play2`` also
    reads their ``name``.
    """

    def __init__(self, p1, p2):
        self.board = np.zeros((rows, columns))
        self.p1 = p1
        self.p2 = p2
        self.isEnd = False
        self.boardHash = None
        # p1 always moves first and plays symbol 1; p2 plays -1.
        self.playerSymbol = 1

    def getHash(self):
        """Return the board as a string key and cache it in ``self.boardHash``.

        The key is ``str`` of the row-major flattened board; it must stay in
        the same format as ``Player.getHash``, which looks these keys up.
        """
        self.boardHash = str(self.board.reshape(-1))
        return self.boardHash

    def reset(self):
        """Clear the board and give the first move back to p1."""
        self.board = np.zeros((rows, columns))
        self.boardHash = None
        self.isEnd = False
        self.playerSymbol = 1

    def availablePositions(self):
        """Return one ``(row, col)`` landing cell per non-full column, left to right.

        The landing cell is the lowest empty cell of the column (largest row
        index). Positions are tuples so they can index ``self.board`` directly.
        """
        n_rows, n_cols = self.board.shape
        positions = []
        for col in range(n_cols):
            for row in range(n_rows - 1, -1, -1):
                if self.board[row, col] == 0:
                    positions.append((row, col))
                    break  # only the lowest empty cell of a column is playable
        return positions

    def updateState(self, position):
        """Place the current player's symbol at ``position`` and pass the turn."""
        self.board[position] = self.playerSymbol
        self.playerSymbol = -1 if self.playerSymbol == 1 else 1

    def _lineOwner(self, row, col, d_row, d_col):
        """Return the symbol at (row, col) if it and the next three cells in
        direction (d_row, d_col) all hold that same non-zero symbol, else 0.

        The caller must make sure all four cells lie on the board.
        """
        first = self.board[row, col]
        if first == 0:
            return 0
        for step in range(1, 4):
            if self.board[row + step * d_row, col + step * d_col] != first:
                return 0
        return first

    def winner(self):
        """Detect the end of the game and update ``self.isEnd``.

        Returns:
            The symbol (``1`` or ``-1``) of the first four-in-a-row found,
            ``0`` if the board is full without a winner (tie), or ``None``
            if the game is still in progress.
        """
        n_rows, n_cols = self.board.shape
        # Starting cells (and directions) of every possible four-cell line, in
        # the order they are checked. Row 0 is the top row, so "down" means
        # increasing row index.
        line_starts = (
            # Horizontal lines, scanned row by row.
            ((r, c, 0, 1) for r in range(n_rows) for c in range(n_cols - 3)),
            # Vertical lines going down, scanned column by column.
            ((r, c, 1, 0) for c in range(n_cols) for r in range(n_rows - 3)),
            # Diagonals going down-right.
            ((r, c, 1, 1) for r in range(n_rows - 3) for c in range(n_cols - 3)),
            # Diagonals going down-left.
            ((r, c, 1, -1) for r in range(n_rows - 3) for c in range(3, n_cols)),
        )
        for group in line_starts:
            for row, col, d_row, d_col in group:
                owner = self._lineOwner(row, col, d_row, d_col)
                if owner:
                    self.isEnd = True
                    return owner

        # A full board with no four-in-a-row is a tie.
        if not self.availablePositions():
            self.isEnd = True
            return 0

        # No winner or tie yet.
        self.isEnd = False
        return None

    def giveReward(self):
        """Propagate the end-of-game reward to both players.

        A p1 win gives p1 ``1`` and p2 ``-10``; a p2 win the reverse. Any
        other result (a tie, or a game still in progress) gives p1 ``0.1``
        and p2 ``0.5``.
        """
        result = self.winner()
        if result == 1:
            self.p1.feedReward(1)
            self.p2.feedReward(-10)
        elif result == -1:
            self.p1.feedReward(-10)
            self.p2.feedReward(1)
        else:
            self.p1.feedReward(0.1)
            self.p2.feedReward(0.5)

    def _trainingMove(self, player):
        """Let ``player`` move, record the resulting board hash for it, and
        return ``winner()`` for the new board."""
        positions = self.availablePositions()
        action = player.chooseAction(positions, self.board, self.playerSymbol)
        self.updateState(action)
        player.addState(self.getHash())
        return self.winner()

    def _endTrainingGame(self):
        """Hand out rewards, then reset both players and the board."""
        self.giveReward()
        self.p1.reset()
        self.p2.reset()
        self.reset()

    def play(self, rounds=100):
        """Train p1 and p2 against each other for ``rounds`` games.

        Every move's resulting board hash is recorded by the player who made
        it; when a game ends (win or tie) rewards are propagated and both the
        players' histories and the board are reset. Progress is printed every
        1000 games.
        """
        for i in range(rounds):
            if i % 1000 == 0:
                print("Rounds {}".format(i))
            while not self.isEnd:
                # Player 1: the game can only end on the mover's own turn.
                if self._trainingMove(self.p1) is not None:
                    self._endTrainingGame()
                    break
                # Player 2
                if self._trainingMove(self.p2) is not None:
                    self._endTrainingGame()
                    break

    def _displayedMove(self, player):
        """Let ``player`` move, print the board and return ``winner()``."""
        positions = self.availablePositions()
        action = player.chooseAction(positions, self.board, self.playerSymbol)
        self.updateState(action)
        self.showBoard()
        return self.winner()

    def play2(self):
        """Play one game with the board printed after every move.

        Announces the winner's name (or a tie) and resets the board. No
        learning hooks are called.
        """
        while not self.isEnd:
            # Player 1
            win = self._displayedMove(self.p1)
            if win is not None:
                if win == 1:
                    print(self.p1.name, "wins!")
                else:
                    print("tie!")
                self.reset()
                break

            # Player 2
            win = self._displayedMove(self.p2)
            if win is not None:
                if win == -1:
                    print(self.p2.name, "wins!")
                else:
                    print("tie!")
                self.reset()
                break

    def showBoard(self):
        """Print the board: 'x' for p1 (1), 'o' for p2 (-1), blank for empty.

        Any other cell value is shown as '?' instead of crashing.
        """
        tokens = {1: 'x', -1: 'o', 0: ' '}
        # Column header followed by the original 12-space padding line.
        print("  0   1   2   3   4   5   6\n" + " " * 12)
        for board_row in self.board:
            print('-----------------------------')
            out = '| '
            for value in board_row:
                out += tokens.get(value, '?') + ' | '
            print(out)
        print('-----------------------------')

In [2]:
class Player:
    """Tabular value-learning agent for Connect-4.

    ``states_value`` maps board hashes (see ``getHash``) to learned values.
    During a game the hashes of the boards this player produced are collected
    in ``states`` (``State.play`` records the board after each of this
    player's moves); ``feedReward`` then propagates the final reward backwards
    through them.
    """

    def __init__(self, name, exp_rate=0.3):
        self.name = name
        self.states = []  # board hashes recorded for this player in the current game
        self.lr = 0.2  # learning rate of the value update
        self.exp_rate = exp_rate  # probability of playing a random move
        self.decay_gamma = 0.9  # discount applied to the propagated reward
        self.states_value = {}  # board hash -> learned value

    def getHash(self, board):
        """Return the string key of ``board``: ``str`` of its row-major flattening."""
        return str(board.reshape(-1))

    def chooseAction(self, positions, board, symbol):
        """Choose a move from ``positions`` epsilon-greedily.

        With probability ``exp_rate`` a uniformly random position is returned.
        Otherwise every candidate is scored by the learned value of the board
        it would produce (``symbol`` placed on a copy of ``board``; boards not
        in ``states_value`` score 0), and a random one of the best-scoring
        candidates is returned. ``board`` itself is not modified.

        Raises:
            ValueError: if ``positions`` is empty.
        """
        if not positions:
            raise ValueError("no available positions to choose from")
        if np.random.uniform(0, 1) < self.exp_rate:
            # take random action
            return positions[np.random.choice(len(positions))]

        # Look one move ahead: State.play records the board *after* a move,
        # so values are keyed by post-move boards, not by the board to move on.
        values = []
        for pos in positions:
            next_board = board.copy()
            next_board[pos] = symbol
            values.append(self.states_value.get(self.getHash(next_board), 0))
        best = max(values)
        candidates = [pos for pos, value in zip(positions, values) if value == best]
        # Break ties randomly so an untrained policy still plays varied moves
        # instead of always the first column.
        return candidates[np.random.choice(len(candidates))]

    def addState(self, state):
        """Record a board hash reached during the current game."""
        self.states.append(state)

    def feedReward(self, reward):
        """At the end of a game, propagate ``reward`` backwards through ``states``.

        Each state's value moves towards the discounted reward by ``lr``; the
        updated value then becomes the reward for the state before it.
        """
        for st in reversed(self.states):
            self.states_value.setdefault(st, 0)
            self.states_value[st] += self.lr * (self.decay_gamma * reward - self.states_value[st])
            reward = self.states_value[st]

    def reset(self):
        """Forget the states recorded in the current game."""
        self.states = []

    def savePolicy(self, file='policy_p1'):
        """Pickle ``states_value`` to ``file`` (default ``'policy_p1'``)."""
        with open(file, 'wb') as fw:
            pickle.dump(self.states_value, fw)

    def loadPolicy(self, file):
        """Replace ``states_value`` with the table pickled in ``file``.

        Only load files you trust: unpickling can execute arbitrary code.
        """
        with open(file, 'rb') as fr:
            self.states_value = pickle.load(fr)


In [3]:


class HumanPlayer:
    """Interactive player that reads its moves (column numbers) from stdin.

    The learning hooks (addState, feedReward, reset) are no-ops so a human
    can take a Player's seat in State.
    """

    def __init__(self, name):
        self.name = name

    def chooseAction(self, positions, board, playerSymbol):
        """Prompt until the user enters a playable column and return its position.

        ``positions`` holds ``(row, col)`` landing cells; the one whose column
        matches the answer is returned. Non-numeric answers and full or
        out-of-range columns are rejected and the prompt repeats, instead of
        returning ``None`` (which the caller would then use to index the
        board). ``board`` and ``playerSymbol`` are unused.

        Raises:
            ValueError: if ``positions`` is empty.
        """
        open_columns = {}
        for pos in positions:
            # Keep the first landing cell per column, as the original scan did.
            open_columns.setdefault(pos[1], pos)
        if not open_columns:
            raise ValueError("no available positions to choose from")
        while True:
            reply = input("Which column?")
            try:
                col = int(reply)
            except ValueError:
                print("Please enter a column number.")
                continue
            if col in open_columns:
                return open_columns[col]
            print("Column {} is not playable; choose one of {}.".format(col, sorted(open_columns)))

    def addState(self, state):
        pass

    def feedReward(self, reward):
        pass

    def reset(self):
        pass



In [4]:
if __name__ == "__main__":
    # Training: two learning agents play each other, then p1's value table
    # is saved to 'policy_p1'.
    p1 = Player("p1")
    p2 = Player("p2")

    st = State(p1, p2)
    print("training...")
    st.play(50)
    p1.savePolicy()

    # Play with a human. Create the greedy computer player FIRST and load the
    # trained policy into it; loading before rebinding p1 discarded the policy.
    p1 = Player("computer", exp_rate=0)
    p1.loadPolicy("policy_p1")

    p2 = HumanPlayer('Human')

    st = State(p1, p2)
    st.play2()

training...
Rounds 0
  0   1   2   3   4   5   6
            
-----------------------------
|   |   |   |   |   |   |   | 
-----------------------------
|   |   |   |   |   |   |   | 
-----------------------------
|   |   |   |   |   |   |   | 
-----------------------------
|   |   |   |   |   |   |   | 
-----------------------------
|   |   |   |   |   |   |   | 
-----------------------------
|   | x |   |   |   |   |   | 
-----------------------------
Which column?5
  0   1   2   3   4   5   6
            
-----------------------------
|   |   |   |   |   |   |   | 
-----------------------------
|   |   |   |   |   |   |   | 
-----------------------------
|   |   |   |   |   |   |   | 
-----------------------------
|   |   |   |   |   |   |   | 
-----------------------------
|   |   |   |   |   |   |   | 
-----------------------------
|   | x |   |   |   | o |   | 
-----------------------------
  0   1   2   3   4   5   6
            
-----------------------------
|   |   |   |   |  