In [1]:
import numpy as np

In [53]:
class Board:
    # Shared codes: a cell holds U, O or X; State.is_terminal() returns one of U, O, X or N.
    U = 0  # free cell; as an is_terminal() result: game still in progress
    O = 1  # O's mark; as an is_terminal() result: O has a full line
    X = 2  # X's mark; as an is_terminal() result: X has a full line
    N = -1  # is_terminal() result only: board full and nobody won (draw)
    LENGTH = 3  # side of the square board

In [166]:
class State:
    """A tic-tac-toe position packed into a single base-3 integer.

    Cell ``(i, j)`` is base-3 digit number ``i * Board.LENGTH + j`` of
    ``code`` and holds ``Board.U`` (free), ``Board.O`` or ``Board.X``.
    ``set_and_get_next_state`` returns a new ``State`` rather than mutating
    this one.
    """

    def __init__(self, code):
        self.code = code

    def get_code(self):
        """Return the integer encoding of this position."""
        return self.code

    @staticmethod
    def _weight(i, j):
        """Return the base-3 place value of cell (i, j).

        Raises:
            IndexError: if (i, j) lies off the board. Without this check an
                off-board cell reads as free, and writing it corrupts the code
                (a negative index even yields a float place value).
        """
        if not (0 <= i < Board.LENGTH and 0 <= j < Board.LENGTH):
            raise IndexError("cell ({}, {}) is outside the board".format(i, j))
        return 3 ** (i * Board.LENGTH + j)

    def get(self, i, j):
        """Return the symbol at row ``i``, column ``j``."""
        return (self.code // self._weight(i, j)) % 3

    def is_free(self, i, j):
        """Return True if cell (i, j) is empty."""
        return self.get(i, j) == Board.U

    def set_and_get_next_state(self, i, j, sym):
        """Return a new State with ``sym`` written into cell (i, j).

        Raises:
            IndexError: if (i, j) is off the board.
            ValueError: if the cell is occupied, or ``sym`` is not
                ``Board.X``/``Board.O`` (any other digit would carry into the
                neighbouring cell and corrupt the encoding).
        """
        if not self.is_free(i, j):
            raise ValueError("Cell is full")
        if sym not in (Board.X, Board.O):
            raise ValueError("Symbol must be Board.X or Board.O")
        return State(self.get_code() + sym * self._weight(i, j))

    def _lines(self):
        """Yield each possible winning line as a list of cell values:
        rows, then columns, then the main and anti diagonals."""
        n = Board.LENGTH
        for i in range(n):
            yield [self.get(i, j) for j in range(n)]
        for j in range(n):
            yield [self.get(i, j) for i in range(n)]
        yield [self.get(i, i) for i in range(n)]
        yield [self.get(i, n - i - 1) for i in range(n)]

    def is_terminal(self):
        """Return the outcome of this position.

        Returns the symbol (``Board.X`` or ``Board.O``) filling the first
        complete line found in ``_lines`` order; otherwise ``Board.N`` if no
        cell is free (draw); otherwise ``Board.U`` (game continues).
        """
        for line in self._lines():
            first = line[0]
            if first != Board.U and all(v == first for v in line):
                return first

        n = Board.LENGTH
        if all(self.get(i, j) != Board.U for i in range(n) for j in range(n)):
            return Board.N
        return Board.U

    def print(self):
        """Print the board one row per line, each cell as its digit (0 free, 1 O, 2 X)."""
        for i in range(Board.LENGTH):
            print("".join(str(self.get(i, j)) for j in range(Board.LENGTH)))

In [167]:
class Environment:
    """Holds the board position that the players read from and write to."""

    def __init__(self):
        # A fresh environment starts from the empty board (every digit 0).
        self.state = State(0)

    def set_state(self, state):
        """Replace the current position with ``state``."""
        self.state = state

    def get_state(self):
        """Return the current position."""
        return self.state

In [168]:
class Agent:
    """Tabular value-function tic-tac-toe player for a single symbol.

    ``values`` maps every base-3 board code (all ``3 ** (LENGTH * LENGTH)``
    of them, reachable or not) to this agent's estimate of that position.
    ``eps`` is the exploration probability used while training, and ``lam``
    is the step size of the backward update in ``update``.
    """

    def _code_value(self, code):
        """Return the initial estimate for ``code``.

        The estimate is 1 if ``self.sym`` has won, 0 if the other symbol has
        won, and 0.5 for draws and for positions still in play.
        """
        winner = State(code).is_terminal()
        if winner in (Board.U, Board.N):
            return 0.5
        return 1 if winner == self.sym else 0

    def __init__(self, sym, eps, lam):
        self.sym = sym
        self.eps = eps
        self.lam = lam
        n_codes = 3 ** (Board.LENGTH * Board.LENGTH)
        self.values = {code: self._code_value(code) for code in range(n_codes)}
        self.history = []

    def _should_pick_randomly(self):
        """Return True with probability ``eps`` (epsilon-greedy explore/exploit)."""
        return bool(np.random.random() < self.eps)

    def make_move(self, env, train=True):
        """Place ``self.sym`` on a free cell and store the resulting state in ``env``.

        When ``train`` is true, the move is uniformly random with probability
        ``eps``. Otherwise, and always when ``train`` is false, the agent picks
        the successor with the highest value. Ties go to the first free cell
        in row-major order, because ``np.argmax`` returns the first maximum.

        Raises:
            Exception: if the board has no free cell.
        """
        current = env.get_state()
        next_states = [
            current.set_and_get_next_state(i, j, self.sym)
            for i in range(Board.LENGTH)
            for j in range(Board.LENGTH)
            if current.is_free(i, j)
        ]

        if not next_states:
            # Dump the offending board before failing, to aid debugging.
            print(current.get_code())
            current.print()
            raise Exception("Empty next states")

        if train and self._should_pick_randomly():
            next_state = next_states[np.random.randint(0, len(next_states))]
        else:
            scores = [self.values[s.get_code()] for s in next_states]
            next_state = next_states[int(np.argmax(scores))]

        env.set_state(next_state)

    def insert_state_in_history(self, state):
        """Append ``state`` to the positions recorded for the current episode."""
        self.history.append(state)

    def reset_history(self):
        """Forget the positions recorded for the current episode."""
        self.history = []

    def update(self, env):
        """Propagate values backwards through ``history``.

        The method walks from the newest recorded position to the oldest. Each
        non-terminal position's value moves a fraction ``lam`` toward the value
        of the position recorded after it. The newest position only serves as
        the first target and is never updated itself. This means a history
        that ends mid-game no longer fails with UnboundLocalError. ``env`` is
        unused; it is kept so callers' signatures still match.
        """
        target = None  # code of the position recorded after `state`
        for state in reversed(self.history):
            code = state.get_code()
            # is_terminal() returns Board.U (== 0, the only falsy outcome) while play continues.
            if target is not None and not state.is_terminal():
                self.values[code] += self.lam * (self.values[target] - self.values[code])
            target = code

In [176]:
class Game:
    """Trains two self-play agents and lets a human play against one of them."""

    def __init__(self):
        # Both agents are created by train(); they stay None until then.
        self.pX = None
        self.p0 = None

    def train(self, episodes, eps=0.05, lam=0.2):
        """Train fresh X and O agents by playing them against each other.

        Args:
            episodes: number of self-play games.
            eps: exploration probability given to both agents.
            lam: update step size given to both agents.

        Agents from any earlier call are replaced.
        """
        self.pX = Agent(Board.X, eps, lam)
        self.p0 = Agent(Board.O, eps, lam)

        for _ in range(episodes):
            env = Environment()
            # Both agents record every position, including the opponent's moves.
            self._record(env.get_state())

            current = self.pX  # X always moves first
            while env.get_state().is_terminal() == Board.U:
                current.make_move(env)
                self._record(env.get_state())
                current = self.p0 if current is self.pX else self.pX

            self.pX.update(env)
            self.p0.update(env)
            self.pX.reset_history()
            self.p0.reset_history()

    def _record(self, state):
        """Append ``state`` to both agents' episode histories."""
        self.pX.insert_state_in_history(state)
        self.p0.insert_state_in_history(state)

    def get_pX(self):
        """Return the X agent (None before train())."""
        return self.pX

    def get_p0(self):
        """Return the O agent (None before train())."""
        return self.p0

    @staticmethod
    def _parse_move(text):
        """Parse ``"row,col"`` into an on-board (row, col) tuple, or return None if invalid."""
        parts = text.split(',')
        if len(parts) != 2:
            return None
        try:
            i, j = (int(p.strip()) for p in parts)
        except ValueError:
            return None
        if not (0 <= i < Board.LENGTH and 0 <= j < Board.LENGTH):
            return None
        return i, j

    def play_human(self, human_sym=Board.X):
        """Play one console game: the human enters ``row,col`` moves as ``human_sym``.

        The trained agent for the other symbol answers each move. Malformed or
        off-board input, and occupied cells, are reported and re-prompted.

        Raises:
            ValueError: if ``human_sym`` is not ``Board.X`` or ``Board.O``.
            RuntimeError: if ``train`` has not been called yet.
        """
        if human_sym not in (Board.X, Board.O):
            raise ValueError("human_sym must be Board.X or Board.O")
        player = self.pX if human_sym == Board.O else self.p0
        if player is None:
            raise RuntimeError("call train() before play_human()")

        env = Environment()
        current = Board.X  # X always moves first
        while env.get_state().is_terminal() == Board.U:
            if current == human_sym:
                print("---------------------")
                env.get_state().print()
                pos = self._parse_move(input())
                if pos is None:
                    print("Invalid move, expected row,col")
                    continue
                if not env.get_state().is_free(pos[0], pos[1]):
                    print("Busy cell")
                    continue
                env.set_state(env.get_state().set_and_get_next_state(pos[0], pos[1], human_sym))
            else:
                player.make_move(env, train=False)

            current = Board.O if current == Board.X else Board.X

        winner = env.get_state().is_terminal()
        if winner == Board.X:
            print("X won")
        elif winner == Board.O:
            print("0 won")
        else:
            print("Draw")

In [177]:
# Train both self-play agents from scratch for 10000 episodes.
game = Game()
game.train(episodes=10000)

In [181]:
# Play interactively as X against the trained O agent; moves are typed as "row,col".
game.play_human(human_sym=Board.X)

---------------------
000
000
000
1,0
---------------------
100
200
000
2,0
---------------------
101
200
200
1,1
0 won
