Copyright **`(c)`** 2023 Giovanni Squillero `<giovanni.squillero@polito.it>`  
[`https://github.com/squillero/computational-intelligence`](https://github.com/squillero/computational-intelligence)  
Free for personal or classroom use; see [`LICENSE.md`](https://github.com/squillero/computational-intelligence/blob/master/LICENSE.md) for details.  

# LAB10

Use reinforcement learning to devise a tic-tac-toe player.

### Deadlines:

* Submission: [Dies Natalis Solis Invicti](https://en.wikipedia.org/wiki/Sol_Invictus)
* Reviews: [Befana](https://en.wikipedia.org/wiki/Befana)

Notes:

* Reviews will be assigned  on Monday, December 4
* You need to commit in order to be selected as a reviewer (ie. better to commit an empty work than not to commit)

In [8]:
from itertools import combinations
from collections import namedtuple, defaultdict
from random import choice, uniform
from copy import deepcopy

from tqdm.notebook import trange
import numpy as np

In [9]:
State = namedtuple("State", ["x", "o"])
MAGIC = [2, 7, 6, 9, 5, 1, 4, 3, 8]
winning_formations = [
    [2, 7, 6], [9, 5, 1], [4, 3, 8],  # Rows
    [2, 9, 4], [7, 5, 3], [6, 1, 8],  # Columns
    [2, 5, 8], [6, 5, 4]              # Diagonals
]

In [10]:
def win(elements):
    """Checks if elements is winning"""
    return any(sum(c) == 15 for c in combinations(elements, 3))


def state_value(pos: State, action):
    """Evaluate state: +1 first player wins"""
    state_x = deepcopy(pos.x)
    if action is not None:
        state_x.add(action)
    state_o = deepcopy(pos.o)
    if action is not None:
        state_o.add(action)
    if win(state_x):
        return 1
    elif win(state_o):
        return -1
    elif intermediate_reward_condition(pos, action):
        return 0.5
    else:
        return 0


def print_board(pos):
    """Nicely prints the board"""
    for r in range(3):
        for c in range(3):
            i = r * 3 + c
            if MAGIC[i] in pos.x:
                print("X", end="")
            elif MAGIC[i] in pos.o:
                print("O", end="")
            else:
                print(".", end="")
        print()
    print()


def intermediate_reward_condition(state, action):
    # Check if a player has made a move forming a potential winning formation
    # For example, checking if two cells in a row/column/diagonal are taken by the same player
    # and the third cell is available

    for formation in winning_formations:
        # check for two cells taken by the player and the third cell available
        new_state = deepcopy(state)
        new_state.x.add(action)
        x_old_count = sum(1 for cell in formation if cell in state.x)
        x_count = sum(1 for cell in formation if cell in new_state.x)
        o_count = sum(1 for cell in formation if cell in state.o)
        if x_count == 2 and o_count == 0 and x_old_count == 1:
            empty_cell = [
                cell
                for cell in formation
                if cell not in new_state.x and cell not in state.o
            ]
            if empty_cell:
                return True

    return False

In [11]:
def epsilon_greedy_policy(Qtable, state, epsilon, available):
    state_key = state_to_key(state)
    random_int = uniform(0, 1)
    if random_int > epsilon:
        actions = sorted(Qtable[state_key], reverse=True)
        if len(actions) == 0:
            action = choice(list(available))
        for a in actions:
            if a in available:
                action = a
                break
    else:
        action = choice(list(available))
    return action


def state_to_key(state):
    return (frozenset(state.x), frozenset(state.o))


def ply(state, epsilon, learning_rate, gamma, Qtable, available):
    action = epsilon_greedy_policy(Qtable, state, epsilon, available)

    reward = state_value(state, action)

    new_state = deepcopy(state)
    new_state.x.add(action)

    hashable_state = state_to_key(state)
    hashable_new_state = state_to_key(new_state)

    current_Q = Qtable[hashable_state][action]
    next_max_Q = (
        max(Qtable[hashable_new_state].values())
        if hashable_new_state in Qtable and Qtable[hashable_new_state]
        else 0
    )

    # Q-value update equation including the action component
    Qtable[hashable_state][action] += learning_rate * (
        reward + gamma * next_max_Q - current_Q
    )

    return action

In [12]:
def train(n_training_episodes, learning_rate, min_epsilon, max_epsilon, decay_rate, gamma, Qtable):
    for episode in range(n_training_episodes):
        # q-learn player is first
        epsilon = min_epsilon + (max_epsilon - min_epsilon) * np.exp(
            -decay_rate * episode
        )
        state = State(set(), set())

        available = set(range(1, 9 + 1))
        while available:
            x = ply(state, epsilon, learning_rate, gamma, Qtable, available)
            if x is not None:
                state.x.add(x)
                available.remove(x)
            if win(state.x) or not available:
                break

            o = choice(list(available))
            state.o.add(o)
            available.remove(o)

        # q-learn player is second
        state = State(set(), set())

        available = set(range(1, 9 + 1))
        while available:
            o = choice(list(available))
            state.o.add(o)
            available.remove(o)
            if not available:
                break

            x = ply(state, epsilon, learning_rate, gamma, Qtable, available)
            if x is not None:
                state.x.add(x)
                available.remove(x)
            if win(state.x) or not available:
                break

In [13]:
def test(n_training_episodes, Qtable, min_epsilon, max_epsilon, decay_rate):
    # evaluation function of q-player against an opponent who plays randomly for n_training_episodes
    # with q-learn player playing first and n_training_episodes playing second
    wins = 0
    draws = 0
    for episode in range(n_training_episodes):
        epsilon = min_epsilon + (max_epsilon - min_epsilon) * np.exp(
            -decay_rate * episode
        )
        
        # q-learn player is first
        state = State(set(), set())
        available = set(range(1, 9 + 1))

        while available:
            x = epsilon_greedy_policy(Qtable, state, epsilon, available)
            state.x.add(x)
            available.remove(x)
            if win(state.x):
                wins += 1
                break
            if not available:
                draws += 1
                break

            o = choice(list(available))
            state.o.add(o)
            available.remove(o)
            if win(state.o):
                break
            if not available:
                draws += 1
                break

        # q-learn player is second
        state = State(set(), set())
        available = set(range(1, 9 + 1))

        while available:
            o = choice(list(available))
            state.o.add(o)
            available.remove(o)
            if win(state.o) or not available:
                break
            if not available:
                draws += 1
                break

            x = epsilon_greedy_policy(Qtable, state, epsilon, available)
            state.x.add(x)
            available.remove(x)
            if win(state.x):
                wins += 1
                break
            if not available:
                draws += 1
                break

    return wins, draws

In [14]:
Qtable_ttt = defaultdict(lambda: defaultdict(lambda: 0.0))

n_training_episodes = 5_000
learning_rate = 0.2

n_eval_episodes = 1000

gamma = 0.3

max_epsilon = 1.0
min_epsilon = 0.05
decay_rate = 0.0002

train(n_training_episodes, learning_rate, min_epsilon, max_epsilon, gamma, decay_rate, Qtable_ttt)

wins, draws = test(n_eval_episodes, Qtable_ttt, min_epsilon, max_epsilon, decay_rate)
print(f"Q-Learn agent wins {wins * 100 / (2 * n_eval_episodes)}% and draws {draws * 100 / (2 * n_eval_episodes)}% of the time")

Q-Learn agent wins 43.2% and draws 5.9% of the time
