Copyright **`(c)`** 2023 Giovanni Squillero `<giovanni.squillero@polito.it>`  
[`https://github.com/squillero/computational-intelligence`](https://github.com/squillero/computational-intelligence)  
Free for personal or classroom use; see [`LICENSE.md`](https://github.com/squillero/computational-intelligence/blob/master/LICENSE.md) for details.  

# LAB10

Use reinforcement learning to devise a tic-tac-toe player.

### Deadlines:

* Submission: [Dies Natalis Solis Invicti](https://en.wikipedia.org/wiki/Sol_Invictus)
* Reviews: [Befana](https://en.wikipedia.org/wiki/Befana)

Notes:

* Reviews will be assigned  on Monday, December 4
* You need to commit in order to be selected as a reviewer (ie. better to commit an empty work than not to commit)

In [74]:
from itertools import combinations
from collections import namedtuple, defaultdict
from random import choice, uniform
from copy import deepcopy

import numpy as np

In [75]:
State = namedtuple("State", ["x", "o"])
MAGIC = [2, 7, 6, 9, 5, 1, 4, 3, 8]

In [76]:
def win(elements):
    """Checks if elements is winning"""
    return any(sum(c) == 15 for c in combinations(elements, 3))


def state_value(pos: State):
    """Evaluate state: +1 first player wins"""
    if win(pos.x):
        return 1
    elif win(pos.o):
        return -1
    else:
        return 0
    
    
def print_board(pos):
    """Nicely prints the board"""
    for r in range(3):
        for c in range(3):
            i = r * 3 + c
            if MAGIC[i] in pos.x:
                print('X', end='')
            elif MAGIC[i] in pos.o:
                print('O', end='')
            else:
                print('.', end='')
        print()
    print()

In [77]:
def epsilon_greedy_policy(Qtable, state, epsilon, available):
    state_key = state_to_key(state)
    random_int = uniform(0, 1)
    if random_int > epsilon:
        actions = sorted(Qtable[state_key], reverse=True)
        if len(actions) == 0:
            action = choice(list(available))
        for a in actions:
            if a in available:
                action = a
                break
    else:
        action = choice(list(available))
    return action


def greedy_policy(Qtable, state, available):
    state_key = state_to_key(state)
    actions = sorted(Qtable[state_key], reverse=True)
    if len(actions) == 0:
        action = choice(list(available))
    for a in actions:
        if a in available:
            action = a
            break
    return action


def state_to_key(state):
    return (frozenset(state.x), frozenset(state.o))


def ply(state, epsilon, learning_rate, gamma, Qtable, available):
    action = epsilon_greedy_policy(Qtable, state, epsilon, available)
    available.remove(action)

    new_state = deepcopy(state)
    new_state.x.add(action)
    reward = state_value(new_state)

    hashable_state = state_to_key(state)
    hashable_new_state = state_to_key(new_state)

    current_Q = Qtable[hashable_state][action]
    next_max_Q = (
        max(Qtable[hashable_new_state].values()) if hashable_new_state in Qtable else 0
    )
    Qtable[hashable_state][action] = (1 - learning_rate) * current_Q + learning_rate * (
        reward + gamma * next_max_Q
    )

    return new_state


Qtable_ttt = defaultdict(lambda: defaultdict(lambda: 0.0))

n_training_episodes = 500_000
learning_rate = 0.5

n_eval_episodes = 1000

gamma = 0.5

max_epsilon = 1.0
min_epsilon = 0.05
decay_rate = 0.0003


def train(
    n_training_episodes,
    learning_rate,
    min_epsilon,
    max_epsilon,
    decay_rate,
    gamma,
    Qtable,
):
    for episode in range(n_training_episodes):
        epsilon = min_epsilon + (max_epsilon - min_epsilon) * np.exp(
            -decay_rate * episode
        )
        state = State(set(), set())

        available = set(range(1, 9 + 1))
        while available:
            state = ply(state, epsilon, learning_rate, gamma, Qtable, available)
            if win(state.x) or not available:
                break

            o = choice(list(available))
            state.o.add(o)
            available.remove(o)
            if win(state.o) or not available:
                break

    return Qtable


def test(n_training_episodes, Qtable):
    wins = 0
    for _ in range(n_training_episodes):
        state = State(set(), set())
        available = set(range(1, 9 + 1))

        while available:
            x = greedy_policy(Qtable, state, available)
            state.x.add(x)
            if win(state.x):
                wins += 1
                break
            if not available:
                break

            o = choice(list(available))
            state.o.add(o)
            available.remove(o)
            if win(state.o):
                break

    for _ in range(n_training_episodes):
        state = State(set(), set())
        available = set(range(1, 9 + 1))

        while available:
            o = choice(list(available))
            state.o.add(o)
            available.remove(o)
            if win(state.o) or not available:
                break

            x = greedy_policy(Qtable, state, available)
            state.x.add(x)
            if win(state.x):
                wins += 1
                break
            if not available:
                break

    return wins


train(
    n_training_episodes, learning_rate, min_epsilon, max_epsilon, gamma, decay_rate, Qtable_ttt
)

print(
    f"Q-Learn agent wins {test(n_eval_episodes, Qtable_ttt) * 100 / 2000}% of the time"
)

Q-Learn agent wins 25.45% of the time
