# RL solutions for combinatorial games: Temporal Difference Methods

**The game of Nim and the optimal strategy are defined(informally) in [this notebook](nim_MC_GLIE.ipynb)**

In [1]:
import numpy as np
from tqdm import tqdm

In [2]:
np.random.seed(42)

In [3]:
# A Simple Nim Game
class SimpleNim:
    def __init__(self, state):
        self.heaps = state
        self.turn = 0
        self.num_heaps = len(self.heaps)

    def get_state(self):
        if isinstance(self.heaps, list):
            return tuple(self.heaps)
        else:
            return tuple(self.heaps.tolist())

    def get_num_heaps(self):
        return self.num_heaps

    def get_heap_size(self, heap_index):
        if 0 <= heap_index < self.num_heaps:
            return self.heaps[heap_index]
        return None

    def is_game_over(self):
        return all(pebbles == 0 for pebbles in self.heaps)

    def make_move(self, heap_index, num_pebbles):
        if self.is_game_over():
            return False

        # Validate the move
        if 0 <= heap_index < self.num_heaps and 0 < num_pebbles <= self.heaps[heap_index]:
            self.heaps[heap_index] -= num_pebbles
            self.turn = 1 - self.turn  # Switch turns
            return True
        return False

    def winner(self):
        if self.is_game_over():
            return 1-self.turn
        return None

In [4]:
def play_optimal(game):
    if game.is_game_over():
        return
    heaps = game.get_state()

    nim_sum = 0
    largest = -1
    largest_index = -1
    for i, heap in enumerate(heaps):
        nim_sum ^= heap
        if heap > largest:
            largest = heap
            largest_index = i
    # print(nim_sum)
    if nim_sum != 0:
        game.make_move(largest_index, largest - (largest^nim_sum)) # optimal move
    else:
        game.make_move(largest_index, 1) # take the move that progresses the game to the least possible extent and hope for a blunder

In [5]:
# simple simulation of a nim game with both players using the optimal strategy
heaps = [10] * 3
game = SimpleNim(heaps)
while not game.is_game_over():
    play_optimal(game)
    print(game.get_state())
print("The winner is Player ", game.winner())

(0, 10, 10)
(0, 9, 10)
(0, 9, 9)
(0, 8, 9)
(0, 8, 8)
(0, 7, 8)
(0, 7, 7)
(0, 6, 7)
(0, 6, 6)
(0, 5, 6)
(0, 5, 5)
(0, 4, 5)
(0, 4, 4)
(0, 3, 4)
(0, 3, 3)
(0, 2, 3)
(0, 2, 2)
(0, 1, 2)
(0, 1, 1)
(0, 0, 1)
(0, 0, 0)
The winner is Player  0


## Defining State and Action Spaces:
**The State** is given by a configuration of heaps($\N_0$ means the set of naturals that includes zero):
$$ \mathcal S = \N_0^h $$
where h is the number of heaps
**Actions** are given by two tuples of heap number and number of pebbles to be taken out:
$$ \mathcal A(s) = H(s) \times \N $$ 
where $\N$ does not include zero and $H(s)$ is the set of heaps having non zero number of pebbles in the state s

In [6]:
# for simplicity consider the version of the problem that starts off with same number of pebbles in each heap and 3 heaps
num_heaps = 3
initial_pebbles = 5
max_episodes = 100000

In [7]:
class EpsilonGreedyPolicy:
    def __init__(self):
        self.epsilon = 1

    def update(self, epsilon):
        self.epsilon = epsilon

    def policy(self, s, Q):
        if s not in Q or np.random.random() < self.epsilon:
            index = np.random.randint(0, len(s))
            while s[index] == 0:
                index = np.random.randint(0, len(s))
            return (index, 1)
        else:
            return max(Q[s], key=Q[s].get)

    def final_policy(self, s, Q):
        if s not in Q:
            return (0,0)
        return max(Q[s], key=Q[s].get)

In [8]:
class SimpleGreedyPolicy:
    def update(self, epsilon): pass

    def policy(self, s, Q):
        if s not in Q:
            index = np.random.randint(0, len(s))
            while s[index] == 0:
                index = np.random.randint(0, len(s))
            return (index, 1)
        return max(Q[s], key=Q[s].get)

    def final_policy(self, s, Q):
        if s not in Q:
            return (0,0)
        return max(Q[s], key=Q[s].get)

In [9]:
def sample_step(game, A):
    game.make_move(A[0],A[1])
    play_optimal(game)
    S_next = game.get_state()
    if game.is_game_over() and game.winner() == 0: return 1, S_next
    else: return 0, S_next

In [10]:
# Monte Carlo Control with exploring starts
def sarsa_es(gamma=1, policy=EpsilonGreedyPolicy(), step_size=0.02, es=True):
    Q = {} # take default to be zero

    policy.update(0.05)

    for i in tqdm(range(max_episodes)):
        S = np.random.randint(0, initial_pebbles+1, num_heaps)
        while all(S == 0):
            S = np.random.randint(0, initial_pebbles+1, num_heaps)
        heap_num = np.random.randint(0,num_heaps)
        while S[heap_num] == 0:
            heap_num = np.random.randint(0,num_heaps)
        game = SimpleNim(S)

        S = tuple(S.tolist())
        if es: A = (heap_num, np.random.randint(1, S[heap_num]+1)) # Exploring start
        else: A = policy.policy(S, Q)

        while not game.is_game_over():
            R, S1 = sample_step(game, A)
            if game.is_game_over():
                A1 = (0,0)
            else:
                A1 = policy.policy(S1, Q)

            if S1 not in Q:
                Q[S1] = {}
            if A1 not in Q[S1]:
                Q[S1][A1] = 0
            if S not in Q:
                Q[S] = {}
            if A not in Q[S]:
                Q[S][A] = 0
            target = R + gamma*Q[S1][A1]
            delta = target-Q[S][A]
            Q[S][A] = Q[S][A] + step_size*delta

            A = A1
            S = S1

    return policy, Q

In [11]:
def nim_sum(s):
    xor = 0
    for heap in s:
        xor ^= heap
    return xor

In [12]:
def is_optimal(s, a):
    s = list(s)
    s[a[0]] -= a[1]
    return nim_sum(s) == 0

In [13]:
# Epsilon Greedy and exploring starts
policy, Q = sarsa_es()

print(Q)

100%|██████████| 100000/100000 [00:02<00:00, 42678.26it/s]

{(2, 0, 2): {(2, 1): 0.010172005062978171, (0, 1): 0.006680444109593852, (2, 2): 0.0, (0, 2): 0.0}, (3, 4, 2): {(0, 1): 0.00927048048344161, (2, 2): 0.015009217947385984, (0, 3): 0.019508952893018613, (1, 3): 0.24058532023446186, (0, 2): 0.3065759333657823, (1, 1): 0.24031462903734857, (1, 4): 0.007832213273179445, (1, 2): 0.01722731828429149, (2, 1): 0.39867617387195275}, (1, 0, 1): {(2, 1): 0.014135328443806194, (0, 1): 0.00010372108720674622}, (0, 0, 0): {(0, 0): 0}, (0, 2, 2): {(1, 1): 0.016559484004412456, (2, 2): 0.0, (2, 1): 0.004878302975632682, (1, 2): 0.0}, (2, 2, 4): {(2, 2): 0.011772073058729808, (1, 1): 0.3527179171752397, (0, 2): 0.016182369091054714, (2, 4): 0.3077928672500807, (2, 3): 0.23353689309886658, (0, 1): 0.3244365439014648, (1, 2): 0.010839928441612839, (2, 1): 0.008934603089724375}, (0, 1, 1): {(1, 1): 0.00817020736480991, (2, 1): 0.00826572822382068}, (2, 1, 3): {(1, 1): 0.01590383185144817, (0, 2): 0.03781142814825146, (2, 2): 0.02511684686222652, (2, 3): 0.




In [14]:
import itertools

def evaluate_on_winning_positions(policy, Q):
    fumbles = 0
    num_winning_pos = 0
    for s in itertools.product(range(initial_pebbles), repeat=num_heaps):
        if s == (0)*num_heaps: continue
        action = policy.final_policy(s, Q)
        if nim_sum(s) == 0: continue
        if not is_optimal(s, action):
            fumbles+=1
        num_winning_pos+=1
    return fumbles, num_winning_pos

In [15]:
fumbles, num_winning_pos = evaluate_on_winning_positions(policy, Q)
print('number of fumbles:', fumbles)
print('total number of winning positions:', num_winning_pos)
print('percentage of winning positions fumbled: {:.2f} %'.format(fumbles/num_winning_pos*100))

number of fumbles: 31
total number of winning positions: 106
percentage of winning positions fumbled: 29.25 %


In [16]:
# Epsilon Greedy but no exploring starts
policy, Q = sarsa_es(es=False)

print(Q)

100%|██████████| 100000/100000 [00:02<00:00, 41515.31it/s]

{(0, 3, 3): {(1, 1): 0.037824404674171246, (2, 1): 0.02279573680112294}, (5, 3, 4): {(2, 1): 0.032889260821572425, (0, 1): 0.22305672113661681, (1, 1): 0.02640301088450344}, (0, 2, 2): {(2, 1): 0.0179681704635771, (1, 1): 0.035087939210798266}, (0, 1, 1): {(2, 1): 0.023791689145528913, (1, 1): 0.0005674356323956036}, (0, 0, 0): {(0, 0): 0}, (1, 0, 2): {(0, 1): 0.6094228242112895, (2, 1): 0.44485132244781495}, (2, 0, 3): {(2, 1): 0.5366832371100378, (0, 1): 0.03379270002665906}, (1, 2, 3): {(1, 1): 0.041233031174576384, (2, 1): 0.4806583232284759, (0, 1): 0.033414606951486604}, (4, 2, 4): {(2, 1): 0.47436434935686483, (0, 1): 0.18153909280232938, (1, 1): 0.1777715808455672}, (1, 1, 0): {(0, 1): 0.0013508840528486169, (1, 1): 0.05437806256021763}, (2, 3, 0): {(1, 1): 0.675344746312926, (0, 1): 0.0235864361253679}, (4, 3, 0): {(0, 1): 0.6537626006911748, (1, 1): 0.009016543935712731}, (1, 2, 0): {(0, 1): 0.703323204073772, (1, 1): 0.5371152688761612}, (3, 0, 3): {(2, 1): 0.039075467356460




In [17]:
fumbles, num_winning_pos = evaluate_on_winning_positions(policy, Q)
print('number of fumbles:', fumbles)
print('total number of winning positions:', num_winning_pos)
print('percentage of winning positions fumbled: {:.2f} %'.format(fumbles/num_winning_pos*100))

number of fumbles: 70
total number of winning positions: 106
percentage of winning positions fumbled: 66.04 %
