# **MATH 4060 - Markov Decision Processes and Reinforcement Learning - Final Project**

# Can't Stop AI

In [None]:
import itertools
import jax.numpy as jnp
from jax import random as jrandom
from jax import nn as jnn
from jax import jit
import numpy as np
import random
import time
import sys
import jax

# Code Given

## Variable descriptions and Helper Function

In [None]:
'''Implementation of the board game "CANT STOP"

How the game is represented in Python:

-----Game parameters (that do not change while the game is being played)

N_PLAYERS
   A positive integer
  This is the number of players playing
  In the classic game rules, N_Players = 4

N_COL_TO_WIN
   A positive integer
   This is the number of columns you need to claim to win the game
   In the classic game rules, N_Col_To_Win = 3

N_MAX_RUNNERS
   A positive integer
   This is the maximum number of runners you can have
   In the classic game rules, N_Max_Runners = 3

PLAYER_COL_STATE_INIT
   A vector of shape (11,) of non-negative integers
   This is the number of squares in each game column
   In the classic game rules, this is [3,5,7,9,11,13,11,9,7,5,3]

NOTE on column labelling:
   In the game, the columns are labeled 2-12 (corresponding to dice rolls)
   In Python, the columns locations are indexed 0-10
   This means that to translate from column in Python to columns in the game,
   one must often add or subtract 2 from the column indices. 

---Variables: (that represent what is going on in the game as it is played)

active_player_index
   An index from the range [0,N_PLAYERS-1] indicating whose turn it currently is

player_col_state
   An array of shape (N_players,11) of integers
   Each row is the number of squares remaining for that player in each col
   NOTE: 
     This is the number of squares REMAINING, these start at PLAYER_COL_STATE_INIT
     and count DOWN to zero as the game progresses. When they get to zero, the player has claimed the column
   WARNING:
     We will not prohibit these from being negative even though it doesn't mean anything in the game
     (This can happen if the player goes past the number needed to claim the column)

illegal_col
   A vector of shape (11,) of boolean
   Contains the information on which columns are still in play
   (columns that have been claimed by a player are not legal to play in anymore)

runner_col_state
  A vector of shape (11,) of non-negative integers
  Indicates the current state of how far the runners have advanced in each column
  A zero indicates that there is no runner in that column at all
  NOTES:
   1. count_nonzero(runner_col) should not exceed N_Max_Runners for legal runner states
   2. Since player_col_state counts DOWN to 0, runner_col is SUBTRACTED from player_col_state when the player chooses to stop rolling

dice_rolls
  A vector of shape (4,) of integers [1,6] indicating the outcome of the 4 dice rolls

runner_col_choices
  A vector of shape (N_choices, 11) of non-negative integers
  Indicates the available CHOICES the player has of where the runners could be
  This corresponds to legal choices for choosing pairings of the dice
  NOTES:
    1. By the rules of the game, N_choices can be at most 6
    2. If N_choices = 0, then this indicates that there are no legal moves and the player has busted

roll_again
  A boolean on whether or not the player wants to roll again'''



## Dice Functions
Note that we will compress the four 6-sided dice rolls down to an integer in the range $[0,1295]$ which we call the `diceNum`. This is done by thinking of the roll as a 4 digit number in base 6.

In [None]:
def diceRoll_to_diceNum(diceRollArray):
  '''Encode four 6-sided dice as a single diceNum in [0,1295].

  The roll is read as a 4 digit number in base 6, with the first die as the
  least significant digit.
  '''
  #Input:
  #  diceRollArray = the four dice faces, each a number 1-6 (they start at 1, not 0!)
  #Output:
  #  the diceNum, i.e. the sum over the dice of (face - 1) * 6**position
  place_values = jnp.array([1,6,36,216])
  zero_based_faces = jnp.array(diceRollArray) - 1
  return jnp.inner(place_values, zero_based_faces)

def diceNum_to_diceRoll(diceNum):
  '''Decode a diceNum in [0,1295] back into an array of shape (4,) holding the 4 dice rolls'''
  #Input:
  #  diceNum = the base 6 encoding of the roll (first die is the least significant digit)
  #Output:
  #  the four dice faces, each a number 1-6 (they start at 1, not 0!)
  place_values = jnp.array([1,6,36,216])

  #Extract the four base 6 digits (each 0-5) ...
  base6_digits = (diceNum // place_values) % 6
  #... and shift them back to dice faces 1-6
  return 1 + base6_digits

In [None]:
def calculate_runnerDicePairArray():
  '''Precompute, for every dice roll and every pairing of two dice, the one hot column that is hit.

  Returns an array of shape (6,11,1296): out[p,:,n] is the one hot (11,) vector
  of the column hit when dice roll number n (the diceNum) comes up and
  pairing p (one of the 6 ways to pick 2 of the 4 dice) is summed.'''
  #Input:
  #  Nothing!
  #Output:
  #  an array of shape (6,11,1296) of 0/1 values, one hot along the column axis (axis 1)
  #  This precomputed output is used when computing the runners that can occur in the game

  #The 6 ways to choose 2 out of 4 dice, as 0/1 masks over (dice 1, dice 2, dice 3, dice 4)
  #  Pairing 0 = choose dice 1 and dice 2
  #  Pairing 1 = choose dice 1 and dice 3
  #  ...
  #  Pairing 5 = choose dice 3 and dice 4
  #  Note that pairing p and pairing 5-p are complementary (together they use all four dice)
  pair_masks = jnp.array([[1,1,0,0],[1,0,1,0],[1,0,0,1],[0,1,1,0],[0,1,0,1],[0,0,1,1]])

  #dice_faces has shape (4,6,6,6,6): dice_faces[i,a,b,c,d] is the (0-based) face
  #  of dice i when the dice come up a,b,c,d (a,b,c,d all run from 0 to 5)
  dice_faces = jnp.indices((6,6,6,6))

  #pair_sums has shape (6,6,6,6,6): the (a,b,c,d,p) entry is the sum of the two
  #  0-based faces picked by pairing p. It lies in 0-10, so it is already the Python column index.
  pair_sums = jnp.einsum("iabcd,pi->abcdp", dice_faces, pair_masks)

  #One hot encode the column index -> shape (6,6,6,6,6,11)
  pair_one_hot = jnn.one_hot(pair_sums, 11)

  #Collapse the four dice axes into a single axis of length 1296.
  #  order='F' makes the first dice vary fastest, so the flat index is a + 6b + 36c + 216d
  by_dice_num = jnp.reshape(pair_one_hot, (1296,6,11), order='F')

  #Reorder the axes from (diceNum, pairing, col) to the desired (pairing, col, diceNum)
  return jnp.transpose(by_dice_num, (1,2,0))

## Helper functions for dealing with runners

In [None]:
@jit
def calculate_player_N_col_claimed(player_col_state):
  '''Count how many columns each player has claimed (the player "scores") from the board state'''
  #Input:
  #  player_col_state = An int array of size (N_players, 11) with the squares REMAINING for each player in each column
  #Output:
  #  An int vector of size (N_players,) with the number of columns each player has claimed
  #  (a column counts as claimed once its remaining squares are <= 0)
  is_claimed = (player_col_state <= 0)
  return jnp.count_nonzero(is_claimed, axis=1)

@jit
def calculate_illegal_col(player_col_state):
  '''Find which columns are ILLEGAL to play in (i.e. the columns already claimed by some player)'''
  #Input:
  #  player_col_state = An int array of size (N_players, 11) with the squares REMAINING for each player in each column
  #Output:
  #  A boolean vector of size (11,) that is True for every column in which at least
  #  one player has <= 0 squares remaining
  claimed_by_player = (player_col_state <= 0)
  return claimed_by_player.any(axis=0)

@jit
def are_runners_legal(runner_col_states, illegal_col, N_MAX_RUNNERS=3):
  '''Decide, for a whole batch of runner states at once, which ones are legal'''
  #Input:
  #  runner_col_states = an int array of size (N,11) of runner positions
  #  illegal_col = a boolean vector of size (11,) with which columns are illegal
  #  N_MAX_RUNNERS = the maximum number of runners allowed at once
  #Output:
  #  a boolean vector of size (N,) with which of the runner_col_states are legal

  #A state breaks the rules if it uses more than N_MAX_RUNNERS runners ...
  too_many_runners = (jnp.count_nonzero(runner_col_states, axis=1) > N_MAX_RUNNERS)

  #... or if ANY column holds a runner while also being an illegal column
  has_runner = (runner_col_states != 0)
  runner_in_illegal_col = jnp.any(jnp.logical_and(has_runner, illegal_col), axis=1)

  #Legal means neither rule is broken
  return jnp.logical_not(jnp.logical_or(too_many_runners, runner_in_illegal_col))

#Global lookup table of shape (6,11,1296), computed once here and read by the functions below
DicePairArray = calculate_runnerDicePairArray()
@jit
def generate_all_choices_and_legality(dice_num,player_col_state,runner_col_state, N_MAX_RUNNERS=3): 
  '''Computes out ALL the 9 possible moves based on the dice and
  whether or not they are legal based on the current state and dice'''
  #Input:
  #  dice_num = the dice roll encoded as a number in [0,1295]
  #  player_col_state = An int array of size (N_players, 11) with the squares REMAINING for each player in each column
  #  runner_col_state = a vector of size (11,) with the current runner positions
  #  N_MAX_RUNNERS = the maximum number of runners allowed at once
  #Output: A tuple (all_runner_choices, all_runner_choices_legal)
  #  all_runner_choices = an array of size (9,11) with the runner state after each candidate move
  #     rows 0-2: play BOTH pairs (i.e. double) for each of the 3 ways to split the dice
  #     rows 3-5: play only the pair that contains dice 1
  #     rows 6-8: play only the pair that does not contain dice 1
  #  all_runner_choices_legal = a boolean vector of size (9,) with which rows are legal
  #NOTE:
  #  The global array DicePairArray is assumed to exist already
  #  (it is calculated with calculate_runnerDicePairArray)

  illegal_col = calculate_illegal_col(player_col_state)

  #Use the dice_num to lookup the runner possibilities from DicePairArray
  dice_sums_with_1_cols = DicePairArray[0:3,:,dice_num]
  #This 5:2:-1 gets the pairings in reverse order so that they are complementary
  #  to the pairings from dice_sums_with_1_cols
  dice_sums_without_1_cols = DicePairArray[5:2:-1,:,dice_num]

  #Calculate all the 9 possible moves of playing both pairs (i.e. double) and with any single pair
  # (We will work out which are legal moves afterwards!)
  double_runner_choices = runner_col_state + dice_sums_with_1_cols + dice_sums_without_1_cols
  single_runner_choices_with_1 = runner_col_state + dice_sums_with_1_cols 
  single_runner_choices_without_1 = runner_col_state + dice_sums_without_1_cols

  #Compute if the choices with both pairing played (i.e. double) are legal
  are_double_runners_legal = are_runners_legal(double_runner_choices,illegal_col, N_MAX_RUNNERS)
  are_double_runners_illegal = jnp.logical_not(are_double_runners_legal)

  #The moves with a single pair are only legal if the corresponding move with both pairs is illegal 
  #  (i.e. its legal to play only one pair iff after you play it, playing the next move is not legal)
  #  This means we can compute if they are legal on their own first and then
  #  logical_and it with the double runners

  #  first check if they would be ok on their own.
  are_single_runners_with_1_ok = are_runners_legal(single_runner_choices_with_1,illegal_col, N_MAX_RUNNERS)
  are_single_runners_without_1_ok = are_runners_legal(single_runner_choices_without_1,illegal_col, N_MAX_RUNNERS)

  #  then we logical and it with the double runners to only legalize these moves if playing both was illegal
  are_single_runners_with_1_legal = jnp.logical_and(are_double_runners_illegal,are_single_runners_with_1_ok)
  are_single_runners_without_1_legal = jnp.logical_and(are_double_runners_illegal,are_single_runners_without_1_ok)

  #Combine everything together to be outputed
  #  (vstack is used because row_stack is only a deprecated alias of it)
  all_runner_choices = jnp.vstack([double_runner_choices,single_runner_choices_with_1,single_runner_choices_without_1]) 
  all_runner_choices_legal = jnp.concatenate([are_double_runners_legal, are_single_runners_with_1_legal, are_single_runners_without_1_legal])

  return all_runner_choices, all_runner_choices_legal

@jit
def update_player_col_state(active_player_index, player_col_state, runner_col_state):
  '''Bank the runners: move the active player's pieces forward by the amount on the runners
    (This is called when a player ends their turn by choice)'''
  #Input:
  #  active_player_index = index of whose turn it is
  #  player_col_state = int array of size (N_player, 11) with squares remaining in each column
  #  runner_col_state = int vector of size (11,) with runner locations
  #Output:
  #  A copy of player_col_state in which the active player's row has been reduced by the runners.
  #  The other rows are unchanged.

  #Squares remaining count DOWN, so banking is a subtraction ...
  remaining_after_move = player_col_state[active_player_index] - runner_col_state
  #... with a floor of zero so overshooting a column never goes negative
  banked_row = jnp.maximum(remaining_after_move, 0)

  return player_col_state.at[active_player_index].set(banked_row)

## Helper Functions for AI

In [None]:
#This function also assumes that the DicePairArray global variable has been calculated already
@jit
def prob_to_miss_targets(targets):
  '''Probability that a roll of the 4 dice cannot hit ANY of the target columns, whichever way it is paired'''
  #Input:
  #  targets = a boolean array of shape (11,) with which columns are targets
  #Output:
  #  A real number: the fraction of the 1296 dice rolls for which none of the 6 pairings lands on a target

  #hits_per_pairing[p,n] is > 0 exactly when pairing p of dice roll n lands on a target column
  hits_per_pairing = jnp.einsum("pcn,c->pn", DicePairArray, targets)

  #A dice roll hits if at least one of its 6 pairings hits
  roll_hits_target = jnp.any(hits_per_pairing > 0, axis=0)
  n_hitting_rolls = jnp.count_nonzero(roll_hits_target)

  #Everything that does not hit is a miss
  return (1296 - n_hitting_rolls)/1296

@jit
def cant_stop_bust_probability(runner_col,illegal_col):
  '''Compute the probability of busting if we were to roll again in Can't Stop'''
  #Input:
  #  runner_col = an array of shape (11,) of integers with the runner locations
  #  illegal_col = an array of shape (11,) of boolean with which columns are illegal to play in
  #Output:
  #  The probability that the next roll misses every column we are allowed to advance in
  #NOTE:
  #  We assume N_Max_Runners = 3 for this one!
  MAX_RUNNERS = 3

  #Which columns currently hold a runner, and how many runners that is
  has_runner = (runner_col > 0)
  n_runners = jnp.sum(has_runner)

  #Columns that save us from busting if we hit them:
  # -a column that already has a runner is always safe
  # -while we still have a spare runner, every non-illegal column is safe as well
  have_spare_runner = (n_runners < MAX_RUNNERS)
  open_cols = jnp.logical_not(illegal_col)
  safe_cols = jnp.logical_or(has_runner, jnp.logical_and(have_spare_runner, open_cols))

  #We bust exactly when the roll misses all of the safe columns
  return prob_to_miss_targets(safe_cols)

## Three simple AIs

Note that the AIs always return a tuple choice_index, roll_again 

choice_index = a number 0-8 indicating which choice they want to make 

roll_again = a boolean of whether or not they want to roll again 

In [None]:
@jit
def pure_random_AI(active_player_index, player_col_state, choices, legal, random_key):
  '''An AI that picks both its move and whether to roll again purely at random.'''
  #Input:
  #  active_player_index = An int with whose turn it currently is (not used by this AI)
  #  player_col_state = An int array of size (N_players, 11) with the squares REMAINING per player (not used by this AI)
  #  choices = An array of size (9,11) with the 9 possible choices available to the player (not used by this AI)
  #  legal = An array of size (9,) with whether or not each of the 9 choices are legal
  #  random_key = A JAX random key; it is split so the two decisions use independent randomness
  #Output: A tuple (choice_index, roll_again)
  #  1st entry: choice_index = An integer index into legal of the choice to be played
  #  2nd entry: roll_again = 0 or 1 (an integer, not a boolean) for whether or not to roll again

  choice_key, roll_key = jrandom.split(random_key)

  #Give every choice a non-negative random score and zero out the illegal ones.
  #  All the scores have the same distribution, so the winner is a random legal choice
  tie_break_scores = jnp.abs( jax.random.normal(choice_key, jnp.shape(legal)) )
  choice_index = jnp.argmax(legal*tie_break_scores, axis=0)

  #Flip a fair coin: the index of the larger of two independent normals is 0 or 1
  coin = jax.random.normal(roll_key, (2,) )
  roll_again = jnp.argmax(coin, axis=0)

  return choice_index, roll_again

@jit
def random_timid_AI(active_player_index, player_col_state, choices, legal, random_key):
  '''An AI that picks its move (somewhat) randomly, and is then timid about rolling again'''
  #Inputs/Outputs same as for the pure_random_AI, except that roll_again is a boolean here
  #
  #Rule for rolling again:
  #  roll again if the chosen move has <=2 runners, stop if it already has 3 (or more)

  #Non-negative random score per choice with the illegal ones zeroed out.
  #  All the scores have the same distribution, so the winner is a random legal choice
  masked_scores = legal*jnp.abs( jax.random.normal(random_key, (9,) ) )
  choice_index = jnp.argmax(masked_scores)

  #Count the runners in the state we would end up in
  n_runners = jnp.count_nonzero( choices[choice_index] )
  roll_again = (n_runners < 3)

  return choice_index, roll_again

@jit
def runner_weights_AI(active_player_index, player_col_state, choices, legal, random_key):
  '''An AI that scores the runner positions of each choice, taking into account that some columns are better than others'''
  #Inputs/Ouputs same as other AIs (random_key is accepted but not used)

  REROLL_THRESHOLD = 13.0 #Keep rolling while the score of our choice is below this
  COLUMN_WEIGHTS = jnp.array([6,5,4,3,2,1,2,3,4,5,6]) #Value of one runner step in each column

  #Score of a choice = weighted sum of its runner positions
  choice_scores = jnp.matmul(choices, COLUMN_WEIGHTS)

  #Illegal choices get score 0, then take the best remaining one
  choice_index = jnp.argmax( legal*choice_scores )

  picked_score = choice_scores[choice_index]
  n_runners = jnp.count_nonzero( choices[choice_index] ) #Number of runners in our choice

  #Reroll if we have at most 1 runner, or if our score is still under the threshold
  roll_again = jnp.logical_or(n_runners <= 1, picked_score < REROLL_THRESHOLD)

  return choice_index, roll_again


# benchmark AIs
# Batched (vmap) versions of the three simple AIs. According to in_axes:
#   active_player_index is shared by the whole batch (None),
#   player_col_state and choices are batched along their axis 2,
#   legal is batched along its axis 1,
#   random_key is batched along its axis 0 (one key per batch entry).
# Both outputs (choice_index, roll_again) are stacked along axis 0.
random_timid_AI_vmap = jax.jit(jax.vmap( random_timid_AI, in_axes=(None,2,2,1,0), out_axes=0 ))
runner_weights_AI_vmap = jax.jit(jax.vmap( runner_weights_AI, in_axes=(None,2,2,1,0), out_axes=0 ))
pure_random_AI_vmap = jax.jit(jax.vmap( pure_random_AI, in_axes=(None,2,2,1,0), out_axes=0 ))

# Feature Helper Functions


## Calculating the advance probability

Here we calculate the probability that we advance runners by one or two positions in each column on any given roll. This is done by creating a (3888, 11) array of the possible choices for all rolls, and then counting the number of 1's and 2's in each column. This is the frequency with which we advance by 1 or 2 in each column. If we advance by 2, then we must have rolled two pairs of dice with identical sums, so there are only 3 different ways to combine the dice; thus we divide by 3 * 1296 to get the probability. If we advance by 1, then we did not roll two pairs with identical sums, so there are 6 different ways to combine the dice; thus we divide by 6 * 1296. This gives us the probability to advance by 1 in each column.

### Purpose

This array is used to calculate the true expected advancement of the runners on any given roll. In practice, I hard code this array into my AI so that it doesn't need to be global or recalculated every time.

The true expected advancement can be calculated by the following formula:

\begin{align*}
  \mathbb{E}[\text{R}] = \text{R} + (\text{P}[0] + 2\text{P}[1])\left((\text{R} > 0) + (1 - \text{I}) \cdot (3 - N_{\text{runners}})\right) \tag{1}
\end{align*}

Where R is the runner column state, P[0] is the array of probabilities to advance by 1 in each column, P[1] is the array of probabilities to advance by 2 in each column, I is the array of illegal columns, and $N_{\text{runners}}$ is the number of runners currently in play.

In [None]:
@jit
def calc_adv_probs():
  '''Build the table of probabilities to advance by 1 or by 2 in each column on a single roll'''
  # Input: Nothing
  # Ouput: A (2,11) array where [0, :] is an array of probabilities of advancing by 1 in each column
  #                       and   [1, :] is an array of probabilities of advancing by 2 in each column

  # one hot column hit by each of the 6 pairings of every roll -> shape (6,6,6,6,6,11)
  dice_faces = jnp.indices((6,6,6,6))
  pair_masks = jnp.array([[1,1,0,0],[1,0,1,0],[1,0,0,1],[0,1,1,0],[0,1,0,1],[0,0,1,1]])
  pair_sums = jnp.einsum("iabcd,pi->abcdp", dice_faces, pair_masks)
  pair_one_hot = jnn.one_hot(pair_sums, 11)

  # each roll can be split into two pairs in 3 ways (pairing p together with pairing 5-p);
  # adding the two one hot vectors of a split gives the progress gained in each column
  complementary_pairs = jnp.array([[0, 5], [1, 4], [2, 3]])
  both_pairs = pair_one_hot[:, :, :, :, complementary_pairs] # shape (6,6,6,6,3,2,11)
  all_combinations = jnp.reshape(jnp.sum(both_pairs, axis=5), (3888, 11)) # one row per (roll, split)

  # count the rows that advance by exactly 1 / exactly 2 in each column and normalise
  prob_adv_1_col = jnp.count_nonzero(all_combinations == 1, axis=0)/7776 # (6 * 1296)
  prob_adv_2_col = jnp.count_nonzero(all_combinations == 2, axis=0)/3888 # (3 * 1296)

  return jnp.stack([prob_adv_1_col, prob_adv_2_col])

def calc_adv_probs_test():
  '''Sanity check for calc_adv_probs: all the entries together must sum to 1'''
  # Input: Nothing
  # Output: Nothing (prints the two row totals; fails an assertion if the grand total is not 1)
  prob_adv_1, prob_adv_2 = calc_adv_probs()

  # the total over both rows and all columns should be 1
  total = jnp.sum(prob_adv_1 + prob_adv_2)
  assert(jnp.isclose(total, 1))

  print("Probability to advance by 1 in any column", jnp.sum(prob_adv_1))
  print("Probability to advance by 2 in any column", jnp.sum(prob_adv_2))

# Run the sanity check right away (it prints the two totals and asserts that they sum to 1)
calc_adv_probs_test()

Probability to advance by 1 in any column 0.8873456
Probability to advance by 2 in any column 0.112654306


# Train AI

This AI just has a weights parameter so that it can be modified while training without needing a global variable.

In [None]:
@jit
def Train_AI(w, active_player_index, player_col_state, choices, legal):
  '''An AI function driven by the weights w, which are passed in explicitly as the first argument'''
  #Input:
  # w - the weights
  # active_player_index - whose turn it is
  # player_col_state - distance each player is away from claiming each column
  # choices - an array of possible choices that the AI can make
  # legal - an array of which choices are legal
  #Output:
  # best_choice_index - index of the legal choice with the highest value under Train_v_func.
  #                     (if no choice is legal, every masked value is 0 and argmax gives index 0)
  # roll_again - True when the estimated value of rolling again is strictly higher than the value of staying.

  # everything the value/Q functions need about the board, bundled in argument order
  illegal_col = calculate_illegal_col(player_col_state)
  player_N_col_claimed = calculate_player_N_col_claimed(player_col_state)
  board = (w, active_player_index, player_col_state, player_N_col_claimed, illegal_col)

  # evaluate all the choices in a single call: one value per choice
  choice_values = Train_v_func(*board, choices)

  # multiplying by "legal" sets the value of any illegal choice to 0,
  # so the maximum is effectively taken over the legal choices only
  best_choice_index = jnp.argmax(choice_values * legal)

  # with the choice fixed, compare the two Q values from the resulting runner state
  chosen_runner_col_state = choices[best_choice_index]
  q_roll_again_val = Train_q_roll_again(*board, chosen_runner_col_state)
  q_stay_val = Train_q_stay(*board, chosen_runner_col_state)

  # roll again only if that is strictly better than staying
  roll_again = q_roll_again_val > q_stay_val

  return best_choice_index, roll_again

def reflect_column_weights(w):
  '''Mirror 6 weights around the middle column so that they cover all 11 columns'''
  #Input:
  # w - array of shape (6,) with the weights for columns 0-5
  #Output:
  # array of shape (11,): w[0],...,w[5] followed by w[4],...,w[0]

  # start from zeros, copy the 6 weights into columns 0-5,
  # then fill columns 6-10 with the first 5 weights in reverse order
  return jnp.zeros(11).at[:6].set(w[:6]).at[6:].set(w[4::-1])

@jit
def Train_v_func(w,active_player_index, player_col_state, player_N_col_claimed, illegal_col, runner_col_state):
  '''Value function for the AI: a sigmoid of a weighted sum of board features, used as the estimated probability that the active player wins'''
  #Input:
  # w - the weights (entries 0-32 are used, laid out as described below)
  # active_player_index - whose turn it is
  # player_col_state - distance each player is away from claiming each column
  # player_N_col_claimed - the number of columns each player has claimed
  # illegal_col - the current columns that cannot be played in
  # runner_col_state - the distance runner has travelled in each column
  #Output:
  # A number in (0,1): the estimated probability to win in the given state.
  #NOTE:
  # The opponent is taken to be player (1 - active_player_index), i.e. this is written for 2 players.

  opponent_index = 1 - active_player_index

  # Layout of w. Each per-column group holds 6 weights, which are reflected
  # around the middle column (by symmetry) to cover all 11 columns.
  runner_w = reflect_column_weights(w[0:6])            # runner locations
  my_progress_w = reflect_column_weights(w[6:12])      # how far we have advanced in each column
  opp_progress_w = reflect_column_weights(w[12:18])    # how far the enemy has advanced in each column
  my_claim_w = reflect_column_weights(w[18:24])        # which columns we have claimed
  opp_claim_w = reflect_column_weights(w[24:30])       # which columns the enemy has claimed
  my_N_claimed_w = w[30]                               # number of columns we have claimed
  opp_N_claimed_w = w[31]                              # number of columns the enemy has claimed
  affine_w = w[32]                                     # constant term

  # Progress = squares already climbed in each column, zeroed out in the illegal columns
  column_lengths = jnp.array([3, 5, 7, 9, 11, 13, 11, 9, 7, 5, 3])
  open_col = 1 - illegal_col
  my_progress = (column_lengths - player_col_state[active_player_index])*open_col
  opp_progress = (column_lengths - player_col_state[opponent_index])*open_col

  # 0/1 indicator of the columns each player has claimed (fewer than 1 square remaining)
  my_claimed = jnp.where(player_col_state[active_player_index] < 1, 1, 0)
  opp_claimed = jnp.where(player_col_state[opponent_index] < 1, 1, 0)

  # Add up the features: our own count for us, the enemy's count against us
  score = jnp.inner(runner_w, runner_col_state)
  score += jnp.inner(my_progress_w, my_progress)
  score -= jnp.inner(opp_progress_w, opp_progress)
  score += player_N_col_claimed[active_player_index] * my_N_claimed_w
  score -= player_N_col_claimed[opponent_index] * opp_N_claimed_w
  score += jnp.inner(my_claim_w, my_claimed)
  score -= jnp.inner(opp_claim_w, opp_claimed)
  score += affine_w

  # Squash with a sigmoid so that the value is always between 0 and 1
  return jnn.sigmoid(score)

@jit
def Train_q_roll_again(w,active_player_index, player_col_state, player_N_col_claimed, illegal_col, runner_col_state):
  '''Approximate the value of rolling again, expressed through the value function v'''
  #Input:
  # w - the weights
  # active_player_index - whose turn it is
  # player_col_state - distance each player is away from claiming each column
  # player_N_col_claimed - the number of columns each player has claimed
  # illegal_col - the current columns that cannot be played in
  # runner_col_state - the distance runner has travelled in each column
  #Output:
  # An estimate of the probability to win given that we choose to roll again:
  #   p_bust * (value if we bust) + (1 - p_bust) * (value of the expected runner state)

  # probabilities of advancing by 1 (row 0) or by 2 (row 1) in each column.
  # these are the hard coded values of calc_adv_probs() so they are not recalculated here
  adv_prob = jnp.array([[0.02700617, 0.05246913, 0.07638889, 0.09876543, 0.11959876, 0.13888888, 0.11959876, 0.09876543, 0.07638889, 0.05246913, 0.02700617], [0.0007716, 0.00308642, 0.00694444, 0.01234568, 0.01929012, 0.02777778, 0.01929012, 0.01234568, 0.00694444, 0.00308642, 0.0007716]])

  # --- if we bust ---
  # the runners are lost and it becomes the other player's turn on an unchanged board,
  # so our value is 1 - (their value)
  no_runners = jnp.zeros(11, dtype=jnp.dtype('u1'))
  opponent_value_after_bust = Train_v_func(w, 1 - active_player_index, player_col_state, player_N_col_claimed, illegal_col, no_runners)
  bust_value = 1 - opponent_value_after_bust

  # --- if we do not bust ---
  # expected runner col state as in equation (1): columns that already hold a runner get one
  # share of the expected step, non-illegal columns get one share per runner still unused
  N_runner = jnp.count_nonzero(runner_col_state)
  expected_step = 1 * adv_prob[0] + 2 * adv_prob[1]
  step_share = (runner_col_state > 0) + (1 - illegal_col) * (3 - N_runner)
  advance_runner = runner_col_state + expected_step * step_share
  advance_value = Train_v_func(w, active_player_index, player_col_state, player_N_col_claimed, illegal_col, advance_runner)

  # --- mix the two outcomes using the probability of busting ---
  p_bust = cant_stop_bust_probability(runner_col_state,illegal_col)
  return p_bust * bust_value + (1 - p_bust) * advance_value

@jit
def Train_q_stay(w,active_player_index, player_col_state, player_N_col_claimed, illegal_col, runner_col_state):
  '''Value (according to the value function v) of stopping now and banking the runners'''
  #Input:
  # w - the weights
  # active_player_index - whose turn it is
  # player_col_state - distance each player is away from claiming each column
  # player_N_col_claimed - the number of columns each player has claimed (recomputed after banking)
  # illegal_col - the current columns that cannot be played in (recomputed after banking)
  # runner_col_state - the distance runner has travelled in each column
  #Output:
  # An estimate of the probability to win given that we choose to stay.

  # staying banks the runners, which changes the board, the claimed counts and the illegal columns
  banked_col_state = update_player_col_state(active_player_index,player_col_state,runner_col_state)
  banked_N_col_claimed = calculate_player_N_col_claimed(banked_col_state)
  banked_illegal_col = calculate_illegal_col(banked_col_state)

  # the next turn starts with no runners on the board
  no_runners = jnp.zeros(11, dtype=jnp.dtype('u1'))

  # v_func estimates the win probability of whoever is the active player, and after
  # staying that is the other player, so P(we win) = 1 - P(other player wins)
  opponent_value = Train_v_func(w, 1 - active_player_index, banked_col_state, banked_N_col_claimed, banked_illegal_col, no_runners)
  return 1 - opponent_value


## AI value function Gradient
# Gradient of Train_v_func with respect to its argument 0 (the weights w)
grad_Train_v_func = jax.jit(jax.grad(Train_v_func, 0))

## VMAP AI
# Batched versions. In every in_axes tuple, None means the argument is shared by the whole batch.
#   Train_AI_vmap: w and active_player_index shared; player_col_state and choices batched
#                  along axis 2, legal along axis 1; outputs stacked along axis 0.
#   Train_v_func_vmap: w and active_player_index shared; player_col_state batched along axis 2;
#                  player_N_col_claimed, illegal_col and runner_col_state along axis 1.
#   grad_Train_v_func_vmap: same inputs as Train_v_func_vmap; out_axes=1 puts the batch
#                  on axis 1 of the result, so the weight index is on axis 0.
Train_AI_vmap = jax.jit(jax.vmap(Train_AI, in_axes=(None, None, 2, 2, 1), out_axes=0 ))
Train_v_func_vmap = jax.jit(jax.vmap(Train_v_func, in_axes=(None, None, 2, 1, 1, 1), out_axes=0))
grad_Train_v_func_vmap=jax.jit(jax.vmap(jax.grad(Train_v_func, 0), in_axes=(None, None, 2, 1, 1, 1), out_axes=1))

# Training

**Method:** I chose to use SARSA for linear function approximation and learned the weights directly on the value function.

**State Space:** The state space consists of three parts *(active_player_index, player_col_state, runner_col_state)*.

* *active_player_index* - $0$ or $1$.
* *player_col_state* - array of shape $(2, 11)$ with the distance for each player to claim each column.
* *runner_col_state* - array of shape $(11, )$ with the distance runners have advanced in each column.

**Action Space:** 

* *choice* - number between $0$ and $8$ of the $9$ possible ways to advance runners.
* *roll_again* - True if we choose to roll again, False if we choose to stay.

**Reward Space:** When transitioning to a terminal state the active player gets a reward of $1$.

**Action Selection:** Following SARSA, we choose the action through an $\epsilon$-greedy selection. We choose a purely random action with probability $\epsilon$ and we choose the best action with probability $1 - \epsilon$. As we train, $\epsilon$ is reduced and should converge to $0$. Throughout training $\epsilon$ was updated using the following rule: $\epsilon = \frac{0.1}{\sqrt{i+1}}$.

**Learning Rate Selection:** I used a learning rate of $\alpha = \frac{0.02}{\sqrt{i+1}}$, where $i$ is the current iteration. I chose this by trying a number of different learning rates and found that this one produced the best results; the others I tried were not as successful.

**Update Rules:** I have two update rules. One when we are not in a terminal state, and one when we are in a terminal state.

- Mid game update rule: $w = w + \alpha(v_{func}(S') - v_{func}(S))\nabla_w v_{func}(S)$
- Terminal update rule: $w = w + \alpha(1 - v_{func}(S))\nabla_w v_{func}(S)$

The terminal update rule is used when we are in a terminal state, in this case the active player has won the game. Therefore, $v_{func}(S') = 1$.

### Training Helper Function

In [None]:
def epsilon_action(w, player_col_state, choices, legality, random_key, epsilon=0.10):
  '''Epsilon-greedy action selection: explore with probability epsilon, otherwise defer to Train_AI'''
  #Input:
  # w - the weights
  # player_col_state - distance each player is away from claiming each column
  # choices - an array of possible choices that the AI can make
  # legality - an array of which choices are legal
  # random_key - a jax random key (only consumed when exploring)
  # epsilon - the probability to choose a purely random action
  #Output:
  # (choice, roll_again) - the index of the chosen runner option and whether to roll again

  # The explore/exploit coin flip uses NumPy's global RNG, not the jax key
  explore = np.random.random() < epsilon

  choice_key, roll_key = jrandom.split(random_key)

  if not explore:
    # greedy action (Train_AI is always asked to play as player 0 here)
    choice, roll_again = Train_AI(w, 0, player_col_state, choices, legality)
    return (choice, roll_again)

  # Exploration: give every option a strictly non-negative i.i.d. score and zero
  # out the illegal ones, so the argmax lands on a uniformly random legal option.
  noise = jnp.abs(jax.random.normal(choice_key, jnp.shape(legality)))
  choice = jnp.argmax(legality * noise, axis=0)

  # roll again (1) or stay (0) with equal probability
  roll_again = jrandom.randint(roll_key, jnp.shape(choice), 0, 2)
  return (choice, roll_again)

def update_v_func_mid(w, active_player_index, next_active_player_index, c_state, n_state, learning_rate):
  '''One update of the weights for a transition into a non-terminal state'''
  #Input:
  # w - the weights
  # active_player_index - the index of the current player
  # next_active_player_index - the index of the player to move in the next state
  # c_state - tuple (player_col_state, runner_col_state) for the current state
  # n_state - tuple (player_col_state, runner_col_state) for the next state
  # learning_rate - the step size for the update rule
  #Output:
  # new weights after the update rule

  cur_cols, cur_runners = c_state[0], c_state[1]
  nxt_cols, nxt_runners = n_state[0], n_state[1]

  # full argument lists for the value function / its gradient
  cur_args = (w, active_player_index, cur_cols,
              calculate_player_N_col_claimed(cur_cols), calculate_illegal_col(cur_cols), cur_runners)
  nxt_args = (w, next_active_player_index, nxt_cols,
              calculate_player_N_col_claimed(nxt_cols), calculate_illegal_col(nxt_cols), nxt_runners)

  v_now = Train_v_func(*cur_args)

  # The value function is from the point of view of whoever is to move. If the
  # turn passed to the other player, our chance of winning in the next state
  # is one minus the value they see.
  v_next = Train_v_func(*nxt_args)
  if active_player_index != next_active_player_index:
    v_next = 1 - v_next

  # the update rule [ w += alpha*(v(s') - v(s))*grad_v(s) ]
  td_error = v_next - v_now
  return w + learning_rate*(td_error)*grad_Train_v_func(*cur_args)

def update_v_func_terminal(w, active_player_index, c_state, learning_rate):
  '''One update of the weights when the next state is terminal, i.e. the target value is 1'''
  #Input:
  # w - current weights
  # active_player_index - the index of the player whose turn it currently is
  # c_state - the current state (player_col_state, runner_col_state)
  # learning_rate - the learning rate
  #Output:
  # new weights

  cols, runners = c_state[0], c_state[1]

  # full argument list for the value function / its gradient
  v_args = (w, active_player_index, cols,
            calculate_player_N_col_claimed(cols), calculate_illegal_col(cols), runners)

  # the update rule [ w += alpha*(1 - v(s))*grad_v(s) ]
  td_error = 1 - Train_v_func(*v_args)
  return w + learning_rate*(td_error)*grad_Train_v_func(*v_args)
  
def roll_and_get_choices(key, player_col_state, runner_col_state, N_MAX_RUNNERS=3):
  '''Make one dice roll and list the runner choices it allows'''
  #Input:
  # key - a jax random key used for this roll
  # player_col_state - distance each player is away from claiming each column
  # runner_col_state - the distance runners have advanced in each column this turn
  # N_MAX_RUNNERS - the maximum number of runners allowed
  #Output:
  # (runner_choices, runner_legal) exactly as returned by generate_all_choices_and_legality

  # The roll is encoded as one integer in [0, 1296). Sample a 0-d array (shape ())
  # so that int() receives a true scalar: converting a shape-(1,) array to a
  # Python int is deprecated in NumPy and rejected by recent JAX releases.
  dice_num = int(jax.random.randint(key, (), 0, 1296))
  runner_choices, runner_legal = generate_all_choices_and_legality(dice_num, player_col_state, runner_col_state, N_MAX_RUNNERS)

  return runner_choices, runner_legal

### Simulator

In [None]:
def train_ai_sim(w, key, enemy, learning_rate=5, epsilon=0.1, verbose=False, N_PLAYERS=2, N_COL_TO_WIN=5, N_MAX_RUNNERS=3, PLAYER_COL_STATE_INIT=[3,5,7,9,11,13,11,9,7,5,3]):
  '''Play one training game, updating the weights after every roll and every change of turn'''
  #Input:
  # w - the weights being trained
  # key - a jax random key; it is split before every roll
  # enemy - the AI function used for every player other than player 0
  # learning_rate - the step size passed to the update rules
  # epsilon - the probability that player 0 takes a purely random action
  # verbose - accepted for interface compatibility; currently unused
  # N_PLAYERS, N_COL_TO_WIN, N_MAX_RUNNERS - game parameters
  # PLAYER_COL_STATE_INIT - number of squares in each column (only read, never mutated)
  #Output:
  # (won, w) - boolean array marking which player reached N_COL_TO_WIN columns, and the new weights

  # initialize the game state
  player_col_state = jnp.tile(jnp.array(PLAYER_COL_STATE_INIT, dtype=jnp.dtype('i1')), (N_PLAYERS, 1))
  runner_col_state = jnp.zeros(11, dtype=jnp.dtype('u1'))
  active_player_index = random.randint(0,N_PLAYERS-1)

  game_in_progress = True
  while game_in_progress:

    roll_again_state = not_busted_state = True
    while roll_again_state:
      # roll die (one fresh subkey for the dice, one for the acting player)
      key, subkey_1, subkey_2 = jax.random.split(key, 3)
      runner_choices, runner_legal = roll_and_get_choices(subkey_1, player_col_state, runner_col_state, N_MAX_RUNNERS)

      # Default for a bust: the runners do not move on this roll. Setting it here
      # guarantees next_runner_col_state is always defined and never stale.
      next_runner_col_state = runner_col_state

      # check if we have busted
      any_legal_choices = jnp.any(runner_legal)
      not_busted_state = not_busted_state and any_legal_choices

      # if not busted choose an action
      if not_busted_state:
        # if it is our turn then we use epsilon greedy, otherwise we use the enemy AI.
        # Only one branch runs, so both can use subkey_2; `key` itself must not be
        # passed on because it is split again on the next roll.
        if active_player_index == 0:
          choice_index, roll_again_state = epsilon_action(w, player_col_state, runner_choices, runner_legal, subkey_2, epsilon=epsilon)
        else:
          choice_index, roll_again_state = enemy(active_player_index, player_col_state, runner_choices, runner_legal, subkey_2)

        # if we choose an illegal action we have busted!
        not_busted_state = not_busted_state and runner_legal[choice_index]

        # get next runner col state
        next_runner_col_state = runner_choices[choice_index]

      # if you are busted you are not allowed to roll again!
      roll_again_state = roll_again_state and not_busted_state

      # we update the value function with only the change in the runner column
      c_state = (player_col_state, runner_col_state)
      n_state = (player_col_state, next_runner_col_state)
      w = update_v_func_mid(w, active_player_index, active_player_index, c_state, n_state, learning_rate)

      runner_col_state = next_runner_col_state

    # end of players turn, calculate the next state

    # if we bust then we should not advance player col
    next_runner_col_state = runner_col_state * not_busted_state
    next_player_col_state = update_player_col_state(active_player_index, player_col_state, next_runner_col_state)

    # at the start of the next turn, runners should be zero and we go to next player
    next_runner_col_state = jnp.zeros(11, dtype=jnp.dtype('u1'))
    next_active_player_index = (active_player_index + 1) % N_PLAYERS

    # check if game is still in progress
    player_N_col_claimed = calculate_player_N_col_claimed(next_player_col_state)
    game_in_progress = not jnp.any(player_N_col_claimed >= N_COL_TO_WIN)

    if game_in_progress:
      # update value function when we switch turn
      c_state = (player_col_state, runner_col_state)
      n_state = (next_player_col_state, next_runner_col_state)
      w = update_v_func_mid(w, active_player_index, next_active_player_index, c_state, n_state, learning_rate)

      # update the current state
      player_col_state = next_player_col_state
      active_player_index = next_active_player_index
      runner_col_state = next_runner_col_state

  # give reward if we win! The state variables were deliberately not advanced on
  # the winning turn, so this is the last state seen by the winning player.
  c_state = (player_col_state, runner_col_state)
  w = update_v_func_terminal(w, active_player_index, c_state, learning_rate)

  return player_N_col_claimed >= N_COL_TO_WIN , w

In [None]:
def train_ai(weights, n_runs, enemy, verbose=True):
  '''Train the AI'''
  #Input:
  # weights - the weights to start training from
  # n_runs - the number of training games to play (an integer)
  # enemy - the enemy AI to train against
  # verbose - optional parameter, forwarded to the simulator
  #Output:
  # The new weights

  # generate a new random key seeded with time
  key = jrandom.PRNGKey(int(time.time()))

  # initialize the score to 0s
  score = jnp.zeros(2)

  # A plain Python range keeps the loop counter on the host instead of pulling
  # one device scalar per iteration out of a jnp array.
  for i in range(n_runs):

    # these are the best hyperparameters that I could find
    learning_rate=0.02/jnp.sqrt(i+1)
    epsilon=0.1/jnp.sqrt(i+1)

    # Every 100 games we output the current run, the score, learning rate, epsilon, and new weights.
    if i % 100 == 0:
      print(f"Run {i+1} of {n_runs}")
      print(f"Score: {score}")
      print(f"Learning Rate: {learning_rate}")
      print(f"Epsilon: {epsilon}")
      print(f"Weights: {weights}")

    # advance the key, run the simulator, and update the score
    key, _ = jrandom.split(key)
    wins, weights = train_ai_sim(weights, key, enemy, learning_rate=learning_rate, epsilon=epsilon, verbose=verbose)
    score += wins

  # Print the final score and the final weights.
  print(f"After {n_runs} games, first player wins {100*score[0]/n_runs:.2f}% of the time")
  print(f"New weights: {weights}")

  return weights

In [None]:
# Wider lines so the printed weight vectors wrap less
jnp.set_printoptions(linewidth=100)

# Start all 33 weights at 0.1 and train for 10000 games against runner_weights_AI
w = 0.1 * jnp.ones(33)
new_w = train_ai(weights=w, n_runs=10000, enemy=runner_weights_AI, verbose=False)

Run 1 of 10000
Score: [0. 0.]
Learning Rate: 0.019999999552965164
Epsilon: 0.10000000149011612
Weights: [0.1 0.1 0.1 0.1 0.1 0.1 0.1 0.1 0.1 0.1 0.1 0.1 0.1 0.1 0.1 0.1 0.1 0.1 0.1 0.1 0.1 0.1 0.1 0.1
 0.1 0.1 0.1 0.1 0.1 0.1 0.1 0.1 0.1]


# Testing


In [None]:
def simulate_game_parr(Player_AI, Verbose = False, random_key=None, Starting_Player=0, N_PARR=100, N_PLAYERS=2, N_COL_TO_WIN=5, N_MAX_RUNNERS=3, PLAYER_COL_STATE_INIT=None):
  '''Run a simulation of the game Can't Stop! running N_PARR games in parallel'''
  #Note that the TURN order for all the games is the same, so who is the starting player matters!

  #Input:
  #   Player_AI = List with the functions for player AIs (indexed by player)
  #   Verbose = whether or not to print out a play-by-play of the game
  #   random_key = A jax random key used to make all the dice rolls for the game
  #                (if None, a key seeded from the current time is created)
  #   Starting_Player = index of the player who takes the first turn
  #   N_PARR = number of games simulated side by side
  #   N_PLAYERS, N_COL_TO_WIN, N_MAX_RUNNERS = game parameters
  #   PLAYER_COL_STATE_INIT = optional starting player_col_state shared by all games
  #                (if None, every player starts from the classic column lengths)
  #Output:
  #  For each player, the number of the N_PARR games in which that player
  #  reached N_COL_TO_WIN claimed columns
  #  (presumably shape (N_PLAYERS,) - depends on calculate_player_N_col_claimed)

  #Initialize game state
  # NOTE: `is None` rather than `== None`: these arguments may be arrays, for
  # which `==` is an elementwise operator rather than an identity test.
  if PLAYER_COL_STATE_INIT is None:
    player_col_state = jnp.tile(jnp.array([3,5,7,9,11,13,11,9,7,5,3],dtype=jnp.dtype('i1')),(N_PARR, N_PLAYERS, 1))
  else:
    player_col_state = jnp.tile( PLAYER_COL_STATE_INIT, (N_PARR, 1, 1))

  player_col_state = jnp.transpose(player_col_state, (1,2,0)) #Move it so the N_PARR dimension is LAST

  if random_key is None:
    random_key = jrandom.PRNGKey(int(time.time()))

  #Record which games are in progress (to make sure we don't mess with those games)
  game_in_progress = jnp.ones(N_PARR, dtype=bool)

  #Note that the player whose turn it is the SAME across all games
  #This means it is important to run the function more than once so there is no bias towards whoever plays first
  # (start one before Starting_Player because the index is advanced at the top of the loop)
  active_player_index = Starting_Player - 1

  #Main loop that goes until all games are over

  turn_num = 0
  roll_num = 0
  while jnp.any(game_in_progress): #This will loop until the game ends
    turn_num += 1

    #Update whose turn it is
    active_player_index = (active_player_index + 1) % N_PLAYERS

    if Verbose : print("Player ",active_player_index,":")

    #Reset runners and "busted"/"roll again" flags
    runner_col_state = jnp.zeros( (11,N_PARR) ,dtype=jnp.dtype('u1'))

    #roll_again_state and not_busted_state keep track of whether or not the player has
    #chosen to roll again and/or busted yet
    #This is vector of length N_PARR and is only true for games still in progress
    roll_again_state = game_in_progress
    not_busted_state = game_in_progress

    #Loop while the player is choosing to roll on their turn
    while jnp.any( roll_again_state ):
      roll_num += 1

      #This represents a random dice roll for each N_PARR simulation
      # (1296 = 6**4 is the number of possibilities for 4 6-sided dice)
      random_key, subkey_1, subkey_2 = jax.random.split(random_key,3)
      dice_num = jrandom.randint(subkey_1, (N_PARR,) , 0,1296)

      if Verbose : print("----DiceNums: ",dice_num)

      #Generate all 9 possible runner choices and whether or not they are legal
      runner_choices, runner_legal = generate_all_choices_and_legality(dice_num,player_col_state,runner_col_state, N_MAX_RUNNERS)

      #Update the busted state: you are only not busted if you have at least one legal choice
      any_legal_choices = jnp.any(runner_legal,axis=0)

      #To stay in the game, you must be in the game already and have legal choices
      # OR you chose to stop rolling before this turn
      chose_to_stop_rolling_or_legal_choices = jnp.logical_or(any_legal_choices, jnp.logical_not(roll_again_state))
      not_busted_state = jnp.logical_and(not_busted_state, chose_to_stop_rolling_or_legal_choices)

      #Send the choices to the AI to choose from
      #Note that we run this even for ones where we've already busted (we just make sure to do nothing with this data)
      active_player_AI = Player_AI[active_player_index]

      parr_random_keys = jax.random.split(subkey_2, N_PARR)
      choice_index, new_roll_again_state = active_player_AI(active_player_index, player_col_state, runner_choices, runner_legal, parr_random_keys)

      #Ensure the choice you made was legal. If you make an illegal choice we count it as if you busted.
      choice_was_legal = runner_legal[choice_index, jnp.arange(N_PARR)]
      chose_to_stop_rolling_or_choice_was_legal = jnp.logical_or(choice_was_legal, jnp.logical_not(roll_again_state))
      not_busted_state = jnp.logical_and(not_busted_state, chose_to_stop_rolling_or_choice_was_legal)

      #Find the runner position if they would advance according to the choices, choosing choice_index[i] for simulation number i
      new_runner_col_state = jnp.transpose(runner_choices[choice_index,:,jnp.arange(N_PARR)], (1,0) )

      #Update the runners to these new positions only if you were still rolling!
      # If you don't meet this criteria, then your roll doesnt count and runners stay where they are
      runner_col_state = jnp.where( roll_again_state, new_runner_col_state, runner_col_state)

      #Update roll_again_state for next round.
      #In order to roll_again next round, three things must all happen:
      #1. You are not busted
      #2. You chose to roll again last time
      #3. You chose to roll again this time
      roll_again_state = jnp.logical_and( jnp.logical_and( not_busted_state,roll_again_state), new_roll_again_state)

      if Verbose : print("----Roll Iteration: ", roll_num,"\n", "Busted_State: ",jnp.logical_not(not_busted_state),"Roll_Again_State: ",roll_again_state)

    #-----------------------
    #End of the players turn:
    #-----------------------

    #This line resets the runners of anyone who had busted to zero
    runner_col_state = runner_col_state * not_busted_state

    #Update the player positions!
    player_col_state = update_player_col_state(active_player_index,player_col_state,runner_col_state)

    if Verbose:
      for i in range(N_PARR):
        print(f"player_col_state Sim #{i} \n",player_col_state[:,:,i])

    player_N_col_claimed = calculate_player_N_col_claimed(player_col_state)

    #A game stays in progress only while no player has reached N_COL_TO_WIN columns
    game_in_progress = jnp.logical_and( game_in_progress, jnp.all(player_N_col_claimed < N_COL_TO_WIN, axis=0))
    if Verbose: print(f"Turn # {turn_num}. Games in progress: {jnp.sum(game_in_progress)}")

  #At the end of this loop, one player has won!
  if Verbose :
    print("GAME OVER!")
    print(f"Number of rolls simulated {roll_num}")
    print(f"Final number of columns claimed \n {player_N_col_claimed}")

  return jnp.sum( player_N_col_claimed >= N_COL_TO_WIN , axis=1)

In [None]:
def monte_carlo_test_state(AI, n_games, init_state=None, Verbose=False):
  '''Monte Carlo estimate of player 0's win probability from n_games parallel simulated games'''
  #Input:
  # AI - list of AI functions, one per player
  # n_games - number of games to simulate in parallel
  # init_state - optional starting player_col_state (None lets the simulator use its default)
  # Verbose - whether the simulator prints a play-by-play
  #Output:
  # fraction of the n_games won by player 0
  wins_per_player = simulate_game_parr(AI, Verbose=Verbose, N_PARR=n_games, PLAYER_COL_STATE_INIT=init_state)
  return wins_per_player[0]/n_games

def monte_carlo_and_v_func(w, enemy_AI, n_games, init_state):
  '''Print the Monte Carlo win rate of Good_AI_vmap vs enemy_AI next to the value function estimate for init_state'''
  # simulated win rate of player 0 starting from init_state
  mc_value = monte_carlo_test_state([Good_AI_vmap, enemy_AI], n_games, init_state)

  # value function estimate for player 0 at the start of a turn (no runners placed)
  n_claimed = calculate_player_N_col_claimed(init_state)
  blocked_cols = calculate_illegal_col(init_state)
  no_runners = jnp.zeros(11)
  value = Good_v_func(w, 0, init_state, n_claimed, blocked_cols, no_runners)

  print(f"MC = {mc_value}, v_func = {value}")

In [None]:
def compare_MC_to_v_func_test(w, enemy_AI, n_games):
  '''For a fixed list of starting positions, print the Monte Carlo win probability
      next to the v_func estimate of winning; the two numbers should be similar'''

  # !! THIS FUNCTION CAN TAKE A LONG TIME TO RUN !!

  # untouched board for one player
  fresh = [3,5,7,9,11,13,11,9,7,5,3]

  # each entry is (player 0 row, player 1 row) of squares remaining per column
  test_positions = [
    ([1,1,1,1,1,1,1,1,1,1,1], [1,1,1,1,1,1,1,1,1,1,1]),
    (fresh,                   [0,0,3,9,11,13,11,9,7,0,0]),
    ([3,5,7,3,11,4,11,1,3,5,3], [0,5,3,9,11,1,11,9,7,2,0]),
    (fresh,                   [0,5,7,9,0,0,0,9,7,5,0]),
    (fresh,                   fresh),
    ([3,5,0,9,11,13,11,9,7,5,3], [3,5,0,9,11,13,11,9,7,5,3]),
    (fresh,                   [3,5,7,9,11,0,11,9,7,5,3]),
    (fresh,                   [3,5,7,9,11,0,0,9,7,5,3]),
    (fresh,                   [3,5,7,9,11,0,0,0,7,5,3]),
    (fresh,                   [3,5,7,9,11,0,0,0,0,5,3]),
    (fresh,                   [3,5,7,9,11,0,0,0,0,0,3]),
    (fresh,                   [0,5,7,9,11,13,11,9,7,5,0]),
    ([0,5,7,9,11,0,11,9,7,5,0], [0,0,7,9,11,13,11,9,7,5,0]),
    ([3,5,7,9,0,13,0,9,7,5,3], [0,5,7,9,11,13,11,9,7,5,0]),
  ]

  for player_row, enemy_row in test_positions:
    init_state = jnp.array([player_row, enemy_row])
    monte_carlo_and_v_func(w, enemy_AI, n_games, init_state)

In [None]:
# Win rate of Good_AI (player 0) against the random timid AI, estimated from 500 games
mc_value = monte_carlo_test_state(AI=[Good_AI_vmap, random_timid_AI_vmap], n_games=500)
print(mc_value)

0.97800004


In [None]:
# The 33 trained weights (the same values that are hard-coded inside Good_AI).
# Layout, as unpacked by Good_v_func: 5 groups of 6 column weights followed by
# the two "N columns claimed" weights and the affine (intercept) weight.
w = jnp.array([ 0.1035391 ,  0.02324084,  0.00629771, -0.01452476,  0.00655654,  0.00639482,  0.28484833,  0.20598306,
  0.12112736,  0.05217794  ,0.11130193  ,0.08123927  ,0.27438858  ,0.2049303   ,0.11780114  ,0.03907273,
  0.10480063,  0.07496218  ,0.20313375  ,0.2138879   ,0.20972085  ,0.24800974  ,0.19582418  ,0.12072659,
  0.2825165 ,  0.23228781  ,0.24480774  ,0.11446477  ,0.1909898   ,0.09358135  ,0.69109344  ,0.6588065,
  0.23544775])
# Good_AI plays both sides here; each test position is simulated 1000 times.
compare_MC_to_v_func_test(w, Good_AI_vmap, 1000)

MC = 0.718000054359436, v_func = 0.6771398782730103
MC = 0.0010000000474974513, v_func = 0.019828803837299347
MC = 0.16600000858306885, v_func = 0.15622536838054657
MC = 0.0, v_func = 0.016315674409270287
MC = 0.562000036239624, v_func = 0.5585915446281433
MC = 0.5750000476837158, v_func = 0.5579010248184204
MC = 0.3530000150203705, v_func = 0.37356802821159363
MC = 0.2370000183582306, v_func = 0.20314764976501465
MC = 0.10000000149011612, v_func = 0.10526865720748901
MC = 0.032999999821186066, v_func = 0.045493628829717636
MC = 0.0, v_func = 0.019176321104168892
MC = 0.19200000166893005, v_func = 0.16148796677589417
MC = 0.6060000061988831, v_func = 0.5154905319213867
MC = 0.5470000505447388, v_func = 0.5316169857978821


# Final AI! 


This AI was trained for 10000 games against itself and has achieved:

- ~99% win rate **vs** pure random AI
- ~95% win rate **vs** random timid AI
- ~75% win rate **vs** runner weights AI

Description of Features:

|Feature (# weights)    | Description  |
|-----|-----|
|Runner Location (6)|The distance the runners have advanced in each column this turn|
|Player Column (6)|The distance the player is from claiming each column|
|Enemy Column (6)|The distance the enemy is from claiming each column|
|Player Col Claimed (6)|The columns the player has claimed|
|Enemy Col Claimed (6)|The columns the enemy has claimed|
|Player N Col Claimed (1)|The number of columns the player has claimed|
|Enemy N Col Claimed (1)| The number of columns the enemy has claimed|
|Affine (1)| Allows for affine functions (non-zero intercept)|

### Improvement through training (by % wins of 1000 games against runner weights AI)

| # Games Trained | Win % (of 1000 games) |
|--|--|
| 0 | 00.0 |
| 100 | 17.4 |
| 1000 | 43.0 |
| 2000 | 59.2 |
| 3000 | 62.2 |
| 5000 | 65.6 |
| 10000 (against itself) | 75.0 |

### Areas to Improve

- When an opponent is close to claiming a column, the AI knows to stay away from the column.
  - However, it also avoids the corresponding column on the opposite side of the board (by symmetry).
    - I would like to introduce a new feature that uses the **runner location**, **player column**, and **enemy column**




In [None]:
@jit
def Good_AI(active_player_index, player_col_state, choices, legal, random_key):
  '''An AI function that plays greedily with respect to Good_v_func, using weights hard-coded below'''
  #Input:
  # active_player_index - whose turn it is
  # player_col_state - distance each player is away from claiming each column
  # choices - an array of possible choices that the AI can make
  # legal - an array of which choices are legal
  # random_key - unused; accepted so the signature matches the other AIs
  #Output:
  # best_choice_index - if there is a legal choice, this is the legal choice that gives the highest estimated probability to win;
  #                     if there is no legal choice every masked value is 0, so argmax returns index 0.
  # roll_again - choice whether to roll again or stay - this is whatever choice gives the highest estimated probability to win.

  # hardcoded weights
  w = jnp.array([0.1035391, 0.02324084, 0.00629771, -0.01452476, 0.00655654, 0.00639482, 0.28484833, 0.20598306, 0.12112736, 0.05217794, 0.11130193, 0.08123927, 0.27438858, 0.2049303, 0.11780114, 0.03907273, 0.10480063, 0.07496218, 0.20313375, 0.2138879, 0.20972085, 0.24800974, 0.19582418, 0.12072659, 0.2825165, 0.23228781, 0.24480774, 0.11446477, 0.1909898, 0.09358135, 0.69109344, 0.6588065, 0.23544775])

  #Calculate N columns claimed and the illegal columns to be used in the AI
  player_N_col_claimed = calculate_player_N_col_claimed(player_col_state)
  illegal_col = calculate_illegal_col(player_col_state)

  #This passes all the choices into the value function at once
  # The result is a vector of shape (9,) with the value of each of the 9 options
  all_vals = Good_v_func(w, active_player_index, player_col_state, player_N_col_claimed, illegal_col, choices)

  #Take the maximum but only over those that are legal!
  # (multiplying by "legal" sets the value of any illegal moves to 0; the values
  #  come out of a sigmoid, so a legal move always beats a masked one)
  best_choice_index = jnp.argmax(all_vals*legal)

  #Deciding whether or not to roll again (uses the Q function)
  #Compute the value for rolling again and staying from that state.
  runner_col_state = choices[best_choice_index]

  # the value if we roll again versus the value if we stay
  q_roll_again_val = Good_q_roll_again(w, active_player_index, player_col_state, player_N_col_claimed, illegal_col, runner_col_state)
  q_stay_val = Good_q_stay(w, active_player_index, player_col_state, player_N_col_claimed, illegal_col, runner_col_state)

  # checks if the value of rolling again is higher than the value of staying
  # in this case we should roll again
  roll_again = q_roll_again_val > q_stay_val

  return best_choice_index, roll_again

def reflect_column_weights(w):
  '''Mirror 6 weights around the middle column so that they cover all 11 columns'''
  #Input:
  # w - array of shape (6,) of the weights to reflect
  #Output:
  # array of shape (11,): w[0..5] followed by w[4], w[3], ..., w[0]

  left_and_middle = w[0:6]
  mirrored_right = w[4::-1]
  return jnp.zeros(11).at[:6].set(left_and_middle).at[6:].set(mirrored_right)

@jit
def Good_v_func(w,active_player_index, player_col_state, player_N_col_claimed, illegal_col, runner_col_state):
  '''Value function for the AI. Estimates the probability of the active player to win. (i.e. the value of the game where +1 for a win and 0 for a loss)'''
  # Weight layout (33 entries):
  #   w[0:6]   runner location          w[6:12]  our progress in each column
  #   w[12:18] enemy progress           w[18:24] columns we have claimed
  #   w[24:30] columns enemy claimed    w[30]    number of columns we claimed
  #   w[31]    number of columns the enemy claimed      w[32] affine term
  # Every block of 6 is mirrored around the middle column to cover all 11 columns.
  runner_w, own_progress_w, enemy_progress_w, own_claim_w, enemy_claim_w = (
    reflect_column_weights(w[start:start + 6]) for start in range(0, 30, 6)
  )
  own_N_claimed_w = w[30]
  enemy_N_claimed_w = w[31]
  intercept_w = w[32]

  me = active_player_index
  enemy = 1 - active_player_index  # two-player game

  # squares already advanced in each column, with illegal columns zeroed out
  column_lengths = jnp.array([3, 5, 7, 9, 11, 13, 11, 9, 7, 5, 3])
  open_col = 1 - illegal_col
  own_progress = (column_lengths - player_col_state[me])*open_col
  enemy_progress = (column_lengths - player_col_state[enemy])*open_col

  # indicator per column: 1 where that player has no squares remaining
  own_claimed = jnp.where(player_col_state[me] < 1, 1, 0)
  enemy_claimed = jnp.where(player_col_state[enemy] < 1, 1, 0)

  # linear score: our features count for us, the enemy's features against us
  score = (
    jnp.inner(runner_w, runner_col_state)
    + jnp.inner(own_progress_w, own_progress)
    - jnp.inner(enemy_progress_w, enemy_progress)
    + player_N_col_claimed[me] * own_N_claimed_w
    - player_N_col_claimed[enemy] * enemy_N_claimed_w
    + jnp.inner(own_claim_w, own_claimed)
    - jnp.inner(enemy_claim_w, enemy_claimed)
    + intercept_w
  )

  #Apply a sigmoid to the score so that the value function is always between 0 and 1
  return jnn.sigmoid(score)

@jit
def Good_q_roll_again(w,active_player_index, player_col_state, player_N_col_claimed, illegal_col, runner_col_state):
  '''Approximate the value of rolling again, expressed through the value function v'''
  #Input:
  # w - the weights
  # active_player_index - whose turn it is
  # player_col_state - distance each player is away from claiming each column
  # player_N_col_claimed - the number of columns each player has claimed
  # illegal_col - the current columns that cannot be played in
  # runner_col_state - the distance runner has travelled in each column
  #Output:
  # An estimate of the probability to win given that we choose to roll again.

  # hard-coded per-column probabilities of advancing by 1 and by 2 on a roll
  p_advance_one = jnp.array([0.02700617, 0.05246913, 0.07638889, 0.09876543, 0.11959876, 0.13888888, 0.11959876, 0.09876543, 0.07638889, 0.05246913, 0.02700617])
  p_advance_two = jnp.array([0.0007716, 0.00308642, 0.00694444, 0.01234568, 0.01929012, 0.02777778, 0.01929012, 0.01234568, 0.00694444, 0.00308642, 0.0007716])
  expected_step = 1 * p_advance_one + 2 * p_advance_two

  # Bust: the runners are lost and the other player moves next, so our value is
  # one minus the value they see with no runners on the board.
  no_runners = jnp.zeros(11, dtype=jnp.dtype('u1'))
  value_if_bust = 1 - Good_v_func(w, 1 - active_player_index, player_col_state, player_N_col_claimed, illegal_col, no_runners)

  # No bust: move the runners by the expected advance. Columns that already hold
  # a runner count once; every legal column also counts once per unused runner
  # (out of 3).
  runners_placed = jnp.count_nonzero(runner_col_state)
  advance_weight = (runner_col_state > 0) + (1 - illegal_col) * (3 - runners_placed)
  expected_runner = runner_col_state + expected_step * advance_weight
  value_if_safe = Good_v_func(w, active_player_index, player_col_state, player_N_col_claimed, illegal_col, expected_runner)

  # mix the two outcomes by the probability of busting
  bust_prob = cant_stop_bust_probability(runner_col_state,illegal_col)
  return bust_prob * value_if_bust + (1 - bust_prob) * value_if_safe

@jit
def Good_q_stay(w,active_player_index, player_col_state, player_N_col_claimed, illegal_col, runner_col_state):
  '''Value of staying (ending the turn now), expressed through the value function v'''
  #Input:
  # w - the weights
  # active_player_index - whose turn it is
  # player_col_state - distance each player is away from claiming each column
  # player_N_col_claimed - the number of columns each player has claimed
  # illegal_col - the current columns that cannot be played in
  # runner_col_state - the distance runner has travelled in each column
  #Output:
  # An estimate of the probability to win given that we choose to stay.

  # staying banks the runners: recompute the board as it will look afterwards
  banked_col_state = update_player_col_state(active_player_index,player_col_state,runner_col_state)
  banked_N_claimed = calculate_player_N_col_claimed(banked_col_state)
  banked_illegal_col = calculate_illegal_col(banked_col_state)

  # the next turn starts with no runners placed
  no_runners = jnp.zeros(11, dtype=jnp.dtype('u1'))

  # v estimates the win probability of whoever is to move, and after staying
  # that is the other player: P(we win) = 1 - P(other player wins)
  opponent_value = Good_v_func(w, 1 - active_player_index, banked_col_state, banked_N_claimed, banked_illegal_col, no_runners)
  return 1 - opponent_value

# Batched version of Good_AI built with jax.vmap.
# in_axes gives, for each positional argument of Good_AI, the axis that holds the batch:
#   argument 1       -> None (not batched, the same value is used for every batch entry)
#   arguments 2 and 3 -> axis 2
#   argument 4       -> axis 1
#   argument 5       -> axis 0
# The results are stacked along axis 0 of the output.
Good_AI_vmap = jax.vmap(
  Good_AI,
  in_axes=(None, 2, 2, 1, 0),
  out_axes=0,
)

### Weights over time
Weights:

After 100 Games: 

[0.07617867, 0.03101422, 0.02089879, 0.01368502, 0.0078908,  0.00360821, 0.10811867, 0.09612512, 0.09594401,
 0.11058875, 0.09506108, 0.05413193, 0.09675168, 0.10194193 ,0.09625064, 0.10252529, 0.09696958 ,0.04687759,
 0.10842792, 0.11154703, 0.11574828, 0.11095291, 0.10790873 ,0.10080186, 0.10217196, 0.09962047 ,0.1082868,
 0.10183044, 0.10172137, 0.09769899, 0.1553869 , 0.11133002 ,0.12203544]

After 1000 Games:

[0.0409123 , 0.00835584, 0.00770477, 0.00658723, 0.00382924, 0.00194145, 0.10266912, 0.08842536 ,0.08659112,
 0.06801992, 0.0505174 , 0.04613237, 0.09144995, 0.09072597, 0.0901193 , 0.0740177 , 0.05003512 ,0.04198749,
 0.10769407, 0.1147806 , 0.13379127, 0.12638041, 0.11165017, 0.10124733, 0.10638765, 0.1089701  ,0.12664862,
 0.11273444, 0.10065236, 0.09947819, 0.1955417 , 0.15487361, 0.13705587]

After 2000 Games:

 [0.03452276, 0.0067145,  0.00645219, 0.00588437, 0.00472891, 0.00477389, 0.10040201, 0.08637246, 0.07309923,
 0.06045679 ,0.04568635, 0.02996191 ,0.09392036 ,0.09121624 ,0.07328227 ,0.06145701 ,0.04417024 ,0.03131298,
 0.10900202 ,0.12652776, 0.13936937 ,0.12773804 ,0.1169558  ,0.10040541 ,0.11235666 ,0.122553   ,0.13532351,
 0.11369655 ,0.10309941, 0.09740062 ,0.21999891 ,0.18443018 ,0.13609365]

 After 3000 Games:

 [0.03016623 ,0.0074396,  0.00928442, 0.00579329, 0.00699066, 0.00678289, 0.10306276, 0.07727478, 0.07287303,
 0.05113595 ,0.03762379, 0.02894804, 0.09859531 ,0.08377281 ,0.07241321 ,0.05437551 ,0.03759806 ,0.03042124,
 0.11911649 ,0.1203796 , 0.14076614, 0.13266423 ,0.12066723 ,0.10322393 ,0.12343781 ,0.11599024 ,0.13647184,
 0.1169002  ,0.10455729, 0.09850085, 0.2368132  ,0.19585846 ,0.13411663]

 After 5000 Games:

 [ 0.03540692 , 0.0107718 ,  0.00804657 ,-0.00699463 , 0.00481385 , 0.00375036 , 0.10548824 , 0.07742934,
  0.06493578  ,0.01399624 , 0.0463188  , 0.03736152 , 0.1020679  , 0.08202828 , 0.06522607 , 0.01529308,
  0.04672097 , 0.03780698 , 0.13181761 , 0.13382572 , 0.13381454 , 0.13444152 , 0.11956096 , 0.10478327,
  0.14010414 , 0.13326086 , 0.13248673 , 0.11152916 , 0.10718318 , 0.098153   , 0.25824648 , 0.22271901,
  0.12247454]

After retraining against itself 10 times (1000 games each):

 [ 0.1035391 ,  0.02324084,  0.00629771, -0.01452476,  0.00655654,  0.00639482,  0.28484833,  0.20598306,
  0.12112736,  0.05217794  ,0.11130193  ,0.08123927  ,0.27438858  ,0.2049303   ,0.11780114  ,0.03907273,
  0.10480063,  0.07496218  ,0.20313375  ,0.2138879   ,0.20972085  ,0.24800974  ,0.19582418  ,0.12072659,
  0.2825165 ,  0.23228781  ,0.24480774  ,0.11446477  ,0.1909898   ,0.09358135  ,0.69109344  ,0.6588065,
  0.23544775]