Simple rules of this blackjack variant:

Game is played with an infinite deck of cards

Each draw from deck results in value between 1 and 10 (uniformly distributed) with a color of red (probability 1/4) or black (probability 3/4)

At the start of the game both the player and dealer draw one black card

Each turn the player may either stick or hit

If player hits, then she draws another card from the deck

If player sticks, she receives no further cards

The value of player cards is added (black cards) or subtracted (red cards)

If player’s sum exceeds 21, or becomes less than 1, then she “goes bust” and loses the game (reward -1)
If the player sticks, then the dealer starts taking turns. The dealer always sticks on any sum of 16 or greater, and hits otherwise. If the dealer goes bust, then the player wins; otherwise the outcome-win (reward +1), lose (reward -1), or draw (reward 0) - is the player with the largest sum.

In [1]:
import random

import pandas as pd

import math

random.seed(42)

def win_rate(game_record):
    record = pd.Series(game_record)
    return record.value_counts()[1] / len(record)

### Exact Q-learning

In [2]:
class Weirdjack:
    def __init__(self, alpha=0.3, gamma=0.9, epsilon=0.1, stop_point=50000):
        self.alpha = alpha  # step size/learning rate
        self.gamma = gamma  # discount factor
        self.epsilon = epsilon  # for epsilon - greedy
        self.record = []
        self.stop = stop_point

        # Initialise q_values
        self.q_values = {}
        for i in range(21):
            for d in range(10):
                self.q_values[(i+1,d+1)] = {'hit':0,'stick':0}

    def play_game(self):
        state = self.draw_card(black=True)  # Draw black card
        dealer_start = self.draw_card(black=True) # dealer's card
        while state > 0 and state <= 21:
            action = self.select_e_greedy_a((state,dealer_start))  # Select action based on strategy
            if action == 'hit':  ## can result in bust terminal state
                card = self.draw_card()
                successor = state + card
                if self.is_bust(successor):
                    self.update_q_value((state,dealer_start), action, -1)
                    self.record.append(-1)
                    break
                else:
                    self.update_q_value((state,dealer_start), action, 0, successor)
                    state += card
            if action == 'stick':
                dealer = self.dealer(dealer_start)
                result = self.score_reward(state, dealer)
                self.update_q_value((state,dealer_start), action, result)
                self.record.append(result)
                break

    def update_q_value(self, state, action, reward, successor=False):
        ## q_new = q_old + alpha * [reward + gamma * V(s_prime) - q_old]
        if len(self.record) >= self.stop:
            pass
        else:
            old = self.q_values[state][action]
            if successor == False:  # stick, go to terminal state
                v = 0
            elif self.is_bust(successor):  # If successor is bust state, terminal
                v = 0
            else:  # Update q value with V(s_prime)
                v = max([self.q_values[(successor,state[1])]['hit'], self.q_values[(successor,state[1])]['stick']])
            self.q_values[state][action] = old + self.alpha * (reward + self.gamma * v - old)

    def is_bust(self, state):
        return (state > 21) or (state < 1)

    def score_reward(self, player_state, dealer_state):
        ## This function runs when player chooses to stick - the other terminal state
        if self.is_bust(dealer_state):  ## if dealer is bust
            return 1
        else:
            if dealer_state == player_state:
                return 0
            elif player_state > dealer_state:
                return 1
            else:
                return -1

    def select_greedy_a(self, state):
        ## Select best action for a state greedily
        if self.q_values[state]['hit'] > self.q_values[state]['stick']:
            return 'hit'
        elif self.q_values[state]['stick'] > self.q_values[state]['hit']:
            return 'stick'
        else:  ## If tie, random
            return random.choice(['hit', 'stick'])

    def select_e_greedy_a(self, state):
        ## Select best action for a state with e-greedy strategy
        rng = random.random()
        greed = self.select_greedy_a(state)
        if len(self.record) >= self.stop:
            return greed
        if rng > self.epsilon:
            return greed
        else:
            if greed == 'hit':
                return 'stick'
            else:
                return 'hit'

    def dealer(self,start):
        start = start
        while start > 0 and start < 16:
            start += self.draw_card()
        return start

    def draw_card(self, black=False):
        num = random.randint(1, 10)
        col = random.randint(1, 4)
        if black:
            return num
        if col == 1:
            return num * -1
        else:
            return num
        
    def compute_policy(self):
        policy_dict = {}
        for state in self.q_values.keys():
            if self.q_values[state]['hit'] >= self.q_values[state]['stick']:
                policy_dict[state] = 'hit'
            else:
                policy_dict[state] = 'stick'
        return policy_dict

In [3]:
game = Weirdjack(stop_point=2000000)
for _ in range(2010000):
    game.play_game()

In [4]:
win_rate(game.record)

0.36934427860696517

In [5]:
# First 500 games
print(win_rate(game.record[:500]))
# Subsequent 500 games
print(win_rate(game.record[500:1000]))
# Subsequent 1k games
print(win_rate(game.record[1000:2000]))
# Subsequent 1k games
print(win_rate(game.record[2000:3000]))
# Last 10k games
print(win_rate(game.record[-10000:]))

0.354
0.406
0.379
0.382
0.3825


In [6]:
game.compute_policy()

{(1, 1): 'hit',
 (1, 2): 'hit',
 (1, 3): 'stick',
 (1, 4): 'hit',
 (1, 5): 'stick',
 (1, 6): 'hit',
 (1, 7): 'hit',
 (1, 8): 'stick',
 (1, 9): 'hit',
 (1, 10): 'hit',
 (2, 1): 'hit',
 (2, 2): 'hit',
 (2, 3): 'hit',
 (2, 4): 'stick',
 (2, 5): 'hit',
 (2, 6): 'hit',
 (2, 7): 'hit',
 (2, 8): 'stick',
 (2, 9): 'hit',
 (2, 10): 'hit',
 (3, 1): 'hit',
 (3, 2): 'hit',
 (3, 3): 'hit',
 (3, 4): 'hit',
 (3, 5): 'stick',
 (3, 6): 'hit',
 (3, 7): 'hit',
 (3, 8): 'hit',
 (3, 9): 'hit',
 (3, 10): 'hit',
 (4, 1): 'hit',
 (4, 2): 'stick',
 (4, 3): 'hit',
 (4, 4): 'stick',
 (4, 5): 'stick',
 (4, 6): 'stick',
 (4, 7): 'hit',
 (4, 8): 'hit',
 (4, 9): 'hit',
 (4, 10): 'hit',
 (5, 1): 'stick',
 (5, 2): 'hit',
 (5, 3): 'hit',
 (5, 4): 'hit',
 (5, 5): 'hit',
 (5, 6): 'hit',
 (5, 7): 'hit',
 (5, 8): 'hit',
 (5, 9): 'hit',
 (5, 10): 'hit',
 (6, 1): 'hit',
 (6, 2): 'hit',
 (6, 3): 'stick',
 (6, 4): 'hit',
 (6, 5): 'hit',
 (6, 6): 'hit',
 (6, 7): 'hit',
 (6, 8): 'hit',
 (6, 9): 'stick',
 (6, 10): 'hit',
 (7, 1):

### Approximate Q-learning

In [7]:
class Weirdjack2:
    def __init__(self, alpha=0.3, gamma=0.7, epsilon=0.1, stop_point=50000):
        self.alpha = alpha  # step size/learning rate
        self.gamma = gamma  # discount factor
        self.epsilon = epsilon  # for epsilon - greedy
        self.record = []
        self.stop = stop_point

        # Initialise weights
        self.weights = {
            'f1': 0,
            'f2': 0,
            'f3': 0,
            'f4': 0
        }

    def f1(self, state, action):
        ## A linear Boolean feature that punishes sticking at low player states.
        if action == 'stick':
            return (state[0]-21)/21
        return 0

    def f2(self, state, action):
        ### Incentivise hitting at low player state. Scales linearly from player state 1-21 (scaled from 1-0)
        if action == 'hit':
            return (21-state[0])/21
        return 0
    
    def f3(self,state,action):
        # Bias, always returns 1
        return 1
    
    def f4(self,state,action):
        # If dealer state is higher (val is positive), increase q-value for hitting
        # If dealer state is lower (val is negative), increase q-value for sticking
        val = (state[0]-state[1])/20
        if action == 'hit':
            return val 
        else:
            return -val
            
    def play_game(self):
        player_state = self.draw_card(black=True)  # Draw black card
        dealer_start = self.draw_card(black=True)  # Dealer's black card
        while not self.is_bust(player_state):
            action = self.select_e_greedy_a((player_state,dealer_start))  # Select action based on strategy
            if action == 'hit':  ## can result in bust terminal state
                card = self.draw_card()
                successor = player_state + card
                if self.is_bust(successor):
                    self.update_all_weights((player_state,dealer_start), action, -1)
                    self.record.append(-1)
                    break
                else:
                    self.update_all_weights((player_state,dealer_start), action, 0, successor)
                    player_state += card
            if action == 'stick':
                dealer = self.dealer(dealer_start)
                result = self.score_reward(player_state, dealer)
                self.update_all_weights((player_state,dealer_start), action, result)
                self.record.append(result)
                break

    def update_all_weights(self, state, action, reward, successor=False):
        if len(self.record) >= self.stop:
            pass
        else:
            for weight in self.weights.keys():
                self.update_weights(weight, state, action, reward, successor)

    def update_weights(self, feature, state, action, reward, successor=False):
        ## w_new = w_old + alpha * [reward + gamma * max q of successor state - q(s,a)] * f(s,a)
        ## State here is a tuple of (player_state,dealer_start)
        w_old = self.weights[feature]+0
        qsa = self.compute_q_value(state, action)
        fsa = getattr(self, feature)(state, action)

        if successor == False:  # terminal state has no successor state
            v = 0
        elif self.is_bust(successor):  # If successor is bust state, terminal
            v = 0
        else:  # Update q value with V(s_prime)
            v = max([self.compute_q_value((successor,state[1]), 'hit'), 
                     self.compute_q_value((successor,state[1]), 'stick')])

        self.weights[feature] = w_old + self.alpha * (reward + self.gamma * v - qsa) * fsa

    def is_bust(self, state):
        return (state > 21) or (state < 1)

    def score_reward(self, player_state, dealer_state):
        ## This function runs when player chooses to stick
        if self.is_bust(dealer_state):  ## if dealer is bust
            return 1
        else:
            if dealer_state == player_state:
                return 0
            elif player_state > dealer_state:
                return 1
            else:
                return -1

    def compute_q_value(self, state, action):
        ## Q(s,a) = w1 * f1(s,a) + w2 * f2(s,a) + ....
        return sum(self.weights[f_name] * getattr(self, f_name)(state, action) for f_name in self.weights.keys())

    def select_greedy_a(self, state):
        ## Select best action for a state greedily
        hit_q = self.compute_q_value(state, 'hit')
        stick_q = self.compute_q_value(state, 'stick')
        if hit_q > stick_q:
            return 'hit'
        elif stick_q > hit_q:
            return 'stick'
        else:  ## If tie, random
            return random.choice(['hit', 'stick'])

    def select_e_greedy_a(self, state):
        ## Select best action for a state with e-greedy strategy
        rng = random.random()
        greed = self.select_greedy_a(state)
        if len(self.record) >= self.stop:
            return greed
        if rng > self.epsilon:
            return greed
        else:
            if greed == 'hit':
                return 'stick'
            else:
                return 'hit'

    def dealer(self,start):
        start = start
        while start > 0 and start < 16:
            start += self.draw_card()
        return start

    def draw_card(self, black=False):
        num = random.randint(1, 10)
        col = random.randint(1, 4)
        if black:
            return num
        if col == 1:
            return num * -1
        else:
            return num

    def compute_q_dict(self):
        q_values = {}
        for i in range(21):
            for d in range(10):
                state = (i+1,d+1)
                q_values[state] = {'hit': self.compute_q_value(state, 'hit'), 'stick': self.compute_q_value(state, 'stick')}
        return q_values
    
    def compute_policy(self):
        q_values = self.compute_q_dict()
        policy_dict = {}
        for state in q_values.keys():
            if q_values[state]['hit'] >= q_values[state]['stick']:
                policy_dict[state] = 'hit'
            else:
                policy_dict[state] = 'stick'
        return policy_dict

In [8]:
game2 = Weirdjack2(stop_point=500,gamma=0.8,epsilon=0,alpha=0.2)
for _ in range(10500):
    game2.play_game()

In [9]:
game2.weights

{'f1': 2.3537400871673615,
 'f2': -0.8824392189842414,
 'f3': 0.33990113805437827,
 'f4': -0.38849252886021934}

In [10]:
win_rate(game2.record)

0.39885714285714285

In [11]:
# First 500 games
print(win_rate(game2.record[:500]))
# Subsequent 500 games
print(win_rate(game2.record[500:1000]))
# Subsequent 1k games
print(win_rate(game2.record[1000:2000]))
# Last 10k games
print(win_rate(game2.record[-10000:]))

0.37
0.398
0.362
0.4003


In [42]:
game2.compute_policy()

{(1, 1): 'hit',
 (1, 2): 'hit',
 (1, 3): 'hit',
 (1, 4): 'hit',
 (1, 5): 'hit',
 (1, 6): 'hit',
 (1, 7): 'hit',
 (1, 8): 'hit',
 (1, 9): 'hit',
 (1, 10): 'hit',
 (2, 1): 'hit',
 (2, 2): 'hit',
 (2, 3): 'hit',
 (2, 4): 'hit',
 (2, 5): 'hit',
 (2, 6): 'hit',
 (2, 7): 'hit',
 (2, 8): 'hit',
 (2, 9): 'hit',
 (2, 10): 'hit',
 (3, 1): 'hit',
 (3, 2): 'hit',
 (3, 3): 'hit',
 (3, 4): 'hit',
 (3, 5): 'hit',
 (3, 6): 'hit',
 (3, 7): 'hit',
 (3, 8): 'hit',
 (3, 9): 'hit',
 (3, 10): 'hit',
 (4, 1): 'hit',
 (4, 2): 'hit',
 (4, 3): 'hit',
 (4, 4): 'hit',
 (4, 5): 'hit',
 (4, 6): 'hit',
 (4, 7): 'hit',
 (4, 8): 'hit',
 (4, 9): 'hit',
 (4, 10): 'hit',
 (5, 1): 'hit',
 (5, 2): 'hit',
 (5, 3): 'hit',
 (5, 4): 'hit',
 (5, 5): 'hit',
 (5, 6): 'hit',
 (5, 7): 'hit',
 (5, 8): 'hit',
 (5, 9): 'hit',
 (5, 10): 'hit',
 (6, 1): 'hit',
 (6, 2): 'hit',
 (6, 3): 'hit',
 (6, 4): 'hit',
 (6, 5): 'hit',
 (6, 6): 'hit',
 (6, 7): 'hit',
 (6, 8): 'hit',
 (6, 9): 'hit',
 (6, 10): 'hit',
 (7, 1): 'hit',
 (7, 2): 'hit',
 (