In [168]:
import os
import random
from collections import deque, namedtuple
from typing import *

import torch
import torch.nn as nn
import torch.optim as optim
from tqdm.auto import tqdm, trange

from game.api import BlackjackWrapper
from game.game_models import *

In [169]:
# Run on the first CUDA GPU when PyTorch can see one; otherwise stay on the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda:0")
else:
    device = torch.device("cpu")
device

device(type='cpu')

In [170]:
class BlackjackDQN(nn.Module):
    """
    Two-headed Q-network over a flattened game state.

    A shared trunk feeds two output heads:
    1. a bet head producing one Q-value per entry of ``bet_choices``
    2. a card head producing one Q-value per entry of ``card_choices``

    Whenever both heads are concatenated, the bet Q-values come first, so the
    index of a card action in the concatenated vector is offset by
    ``len(bet_choices)``.
    """

    def __init__(self, in_features: int, bet_choices: List[float], card_choices: List[bool], epsilon: float, min_epsilon: float):
        """
        Args:
            in_features: width of the flattened state fed to the trunk.
            bet_choices: bet values selectable through the bet head.
            card_choices: card actions selectable through the card head.
            epsilon: base of the exploration schedule (``epsilon ** num_steps``).
            min_epsilon: lower bound on the exploration probability.
        """
        super().__init__()
        self.bet_choices = bet_choices
        self.card_choices = card_choices
        self.epsilon = epsilon
        self.min_epsilon = min_epsilon
        # Trunk shared by both heads; every linear layer is followed by LeakyReLU.
        self.init_layers = self._mlp(
            (in_features, 64, 128, 256, 128, 64, 32), activate_last=True
        )
        # Head emitting raw Q-values for the bet choices (no final activation).
        self.bet_layers = self._mlp((32, 16, len(bet_choices)), activate_last=False)
        # Head emitting raw Q-values for the card choices (no final activation).
        self.card_layers = self._mlp((32, 16, 8, len(card_choices)), activate_last=False)

    @staticmethod
    def _mlp(widths, activate_last: bool) -> nn.Sequential:
        """Stack Linear layers over consecutive ``widths`` with LeakyReLU in between.

        A LeakyReLU follows every Linear layer except the last one, which only
        gets one when ``activate_last`` is true.
        """
        modules = []
        final_index = len(widths) - 2
        for position, (fan_in, fan_out) in enumerate(zip(widths, widths[1:])):
            modules.append(nn.Linear(fan_in, fan_out))
            if activate_last or position < final_index:
                modules.append(nn.LeakyReLU())
        return nn.Sequential(*modules)

    def forward(self, x) -> Tuple[torch.Tensor, torch.Tensor]:
        """Return ``(bet_q_values, card_q_values)`` for the given state(s)."""
        shared = self.init_layers(x)
        return self.bet_layers(shared), self.card_layers(shared)

    def batched_forward_with_concat(self, x) -> torch.Tensor:
        """Return bet and card Q-values joined along the last dimension."""
        return torch.cat(self.forward(x), dim=-1)

    def _pick_index(self, q_values: torch.Tensor, num_choices: int, allow_explore: bool, num_steps: int) -> int:
        """Epsilon-greedy index selection over one head's Q-values.

        With exploration allowed, a uniformly random index is returned with
        probability ``max(epsilon ** num_steps, min_epsilon)``; otherwise the
        argmax of ``q_values`` is returned.
        """
        if allow_explore:
            explore_probability = max(self.epsilon ** num_steps, self.min_epsilon)
            if random.random() < explore_probability:
                # explore
                return random.randrange(num_choices)
        # exploit
        return q_values.argmax().item()

    @staticmethod
    def _pack(bet_q: torch.Tensor, card_q: torch.Tensor) -> torch.Tensor:
        """Concatenate both heads, add a leading dimension and move to the CPU."""
        return torch.cat((bet_q, card_q), dim=-1).unsqueeze(0).cpu()

    def get_bet_percent(
        self, normalized_state, allow_explore: bool, num_steps: int
    ) -> Tuple[float, int, torch.Tensor]:
        """Choose a bet for ``normalized_state``.

        Returns:
            The chosen entry of ``bet_choices``, its index, and the
            concatenated Q-values of both heads.
        """
        bet_q, card_q = self.forward(normalized_state)
        chosen = self._pick_index(bet_q, len(self.bet_choices), allow_explore, num_steps)
        return self.bet_choices[chosen], chosen, self._pack(bet_q, card_q)


    def get_card_action(self, normalized_state, allow_explore: bool, num_steps: int) -> Tuple[bool, int, torch.Tensor]:
        """Choose a card action for ``normalized_state``.

        Returns:
            The chosen entry of ``card_choices``, its index shifted by
            ``len(bet_choices)`` (its position in the concatenated output),
            and the concatenated Q-values of both heads.
        """
        bet_q, card_q = self.forward(normalized_state)
        chosen = self._pick_index(card_q, len(self.card_choices), allow_explore, num_steps)
        return self.card_choices[chosen], chosen + len(self.bet_choices), self._pack(bet_q, card_q)

In [171]:
# Record of a single environment step as stored in the replay buffer.
Transition = namedtuple("Transition", ["state", "action", "next_state", "reward"])


class ReplayBuffer(deque):
    """Bounded deque of transitions that supports uniform random sampling.

    Once ``capacity`` items are stored, appending a new one discards the
    oldest (standard ``deque(maxlen=...)`` behaviour).
    """

    def __init__(self, capacity):
        """Create an empty buffer holding at most ``capacity`` transitions."""
        super().__init__(maxlen=capacity)

    def push(self, transition: Transition):
        """Append ``transition`` at the newest end of the buffer."""
        self.append(transition)

    def sample(self, batch_size: int):
        """Return ``batch_size`` distinct stored transitions chosen at random.

        Raises:
            ValueError: if ``batch_size`` exceeds the number of stored items
                (raised by ``random.sample``).
        """
        return random.sample(self, k=batch_size)

In [172]:
# Q-learning target
gamma = 0.9           # discount applied to the bootstrapped next-state value

# Optimisation
learning_rate = 0.001  # step size handed to the optimizer
batch_size = 32        # transitions sampled from the replay buffer per update

# Rollout limits
num_eps = 1000         # number of training episodes
max_steps = 1000       # upper bound on steps within a single episode

# Exploration schedule and target-network tracking
epsilon = 0.9          # base of the exploration probability (raised to a step count)
min_epsilon = 0.05     # floor for the exploration probability
tau = 0.005            # weight of the policy weights in the soft target update

In [173]:
# Discrete action sets: the bet head picks one of these values, the card head
# picks between the two card actions.
bet_choices = [0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9, 1.0]
card_choices = [True, False]

# Build two identically configured networks (policy first, then target), each
# moved to the selected device.
policy_model, target_model = (
    BlackjackDQN(
        in_features=GameState.get_state_size(),
        bet_choices=bet_choices,
        card_choices=card_choices,
        epsilon=epsilon,
        min_epsilon=min_epsilon,
    ).to(device)
    for _ in range(2)
)
# Start the target network from exactly the policy network's weights.
target_model.load_state_dict(policy_model.state_dict())

game_wrapper = BlackjackWrapper()
# Only the policy network is optimised directly.
optimizer = optim.Adam(policy_model.parameters(), lr=learning_rate)
replay_buffer = ReplayBuffer(10000)

In [174]:
def train(
    policy_model: BlackjackDQN,
    optimizer: optim.Optimizer,
    replay_buffer: ReplayBuffer,
    batch_size: int,
    gamma: float,
):
    """Run one optimisation step of ``policy_model`` on a sampled minibatch.

    Does nothing until ``replay_buffer`` holds at least ``batch_size``
    transitions. Otherwise samples a batch, builds the one-step target
    ``reward + gamma * max_a Q_target(next_state, a)`` (0 for terminal
    transitions, i.e. those whose ``next_state`` is ``None``), and minimises
    the Huber loss between that target and ``Q_policy(state, action)``.

    Besides its arguments, this function reads the module-level names
    ``target_model``, ``device`` and ``Transition``.

    Args:
        policy_model: network being optimised.
        optimizer: optimizer stepping ``policy_model``'s parameters.
        replay_buffer: source of stored transitions.
        batch_size: number of transitions per update.
        gamma: discount factor applied to the next-state value.
    """
    if len(replay_buffer) < batch_size:
        return
    transitions = replay_buffer.sample(batch_size)
    # Transpose a batch of transitions into one Transition of batched fields.
    batch = Transition(*zip(*transitions))

    # Terminal transitions carry next_state=None; remember which are not.
    non_final_mask = torch.tensor(
        [s is not None for s in batch.next_state], device=device, dtype=torch.bool
    )
    non_final_next_states = [s for s in batch.next_state if s is not None]
    state_batch = torch.cat(batch.state)
    action_batch = torch.cat(batch.action).unsqueeze(-1)
    reward_batch = torch.cat(batch.reward)

    # Q(s_t, a): evaluate every action with the policy model, then pick the
    # column of the action that was actually taken in each transition.
    q_batch = policy_model.batched_forward_with_concat(state_batch)
    state_action_values = q_batch.gather(dim=1, index=action_batch)

    # V(s_{t+1}) from the target model; stays 0 for terminal transitions.
    next_state_values = torch.zeros(batch_size, device=device)
    # torch.cat raises on an empty list, so skip the target-model pass when
    # every sampled transition is terminal.
    if non_final_next_states:
        with torch.no_grad():
            next_q = target_model.batched_forward_with_concat(
                torch.cat(non_final_next_states)
            )
            # NOTE(review): this max runs over the bet AND card heads together;
            # confirm that is intended for states where only one kind of
            # action is available.
            next_state_values[non_final_mask] = next_q.max(dim=1)[0]
    # Expected Q values (one-step bootstrapped target).
    expected_state_action_values = (next_state_values * gamma) + reward_batch

    # Compute Huber loss
    criterion = nn.SmoothL1Loss()
    loss = criterion(state_action_values, expected_state_action_values.unsqueeze(1))

    # Optimize the model
    optimizer.zero_grad()
    loss.backward()
    # In-place gradient clipping
    torch.nn.utils.clip_grad_value_(policy_model.parameters(), 100)
    optimizer.step()

In [175]:
total_reward = 0
# Log roughly 100 times over the run; the max() keeps the modulo below from
# dividing by zero when num_eps < 100.
log_every = max(1, num_eps // 100)

for i_episode in trange(num_eps):
    game_wrapper.reset()
    game_state = game_wrapper.get_state()
    state = game_state.torch_flatten(device)
    for i_step in range(max_steps):
        # The first step of an episode places the bet; every later step is a
        # card decision. The Q-values returned here are not used for learning,
        # so skip building an autograd graph during action selection.
        with torch.no_grad():
            if i_step == 0:
                bet_percent, action, q_values = policy_model.get_bet_percent(
                    normalized_state=state, allow_explore=True, num_steps=i_episode
                )
            else:
                card_action, action, q_values = policy_model.get_card_action(
                    normalized_state=state, allow_explore=True, num_steps=i_episode
                )
        if i_step == 0:
            outcome = game_wrapper.bet_step(bet_percent)
        else:
            outcome = game_wrapper.card_step(take_card=card_action)
        terminated = outcome.terminated
        reward = outcome.reward
        # torch.tensor (unlike the legacy torch.Tensor constructor) accepts a
        # non-CPU device, so this also works when `device` is a CUDA device.
        reward_tensor = torch.tensor([reward], dtype=torch.float32, device=device)
        action_tensor = torch.tensor([action], dtype=torch.int64, device=device)
        total_reward += reward
        # Terminal transitions are stored with next_state=None.
        next_state = (
            outcome.new_state.torch_flatten(device)
            if not terminated else None
        )

        # Store the transition in memory
        replay_buffer.push(Transition(state, action_tensor, next_state, reward_tensor))

        # Move to the next state
        state = next_state

        # Perform one step of the optimization (on the policy network)
        train(
            policy_model=policy_model,
            optimizer=optimizer,
            replay_buffer=replay_buffer,
            batch_size=batch_size,
            gamma=gamma,
        )

        # Soft update: target <- tau * policy + (1 - tau) * target
        target_state_dict = target_model.state_dict()
        policy_state_dict = policy_model.state_dict()
        for key in target_state_dict:
            target_state_dict[key] = tau * policy_state_dict[key] + (1 - tau) * target_state_dict[key]
        target_model.load_state_dict(target_state_dict)

        if terminated:
            break

    if i_episode % log_every == 0:
        tqdm.write(f"Episode {i_episode}\t\tRunning Average Score: {round(total_reward / (i_episode + 1), 3)}")

  0%|          | 0/1000 [00:00<?, ?it/s]

Episode 0		Running Average Score: 1.0
Episode 10		Running Average Score: 1.0
Episode 20		Running Average Score: 1.0
Episode 30		Running Average Score: 1.0
Episode 40		Running Average Score: 1.0
Episode 50		Running Average Score: 1.0
Episode 60		Running Average Score: 1.0
Episode 70		Running Average Score: 1.0
Episode 80		Running Average Score: 1.0
Episode 90		Running Average Score: 1.0
Episode 100		Running Average Score: 1.0
Episode 110		Running Average Score: 1.0
Episode 120		Running Average Score: 1.0
Episode 130		Running Average Score: 1.0
Episode 140		Running Average Score: 1.0
Episode 150		Running Average Score: 1.0
Episode 160		Running Average Score: 1.0
Episode 170		Running Average Score: 1.0
Episode 180		Running Average Score: 1.0
Episode 190		Running Average Score: 1.0
Episode 200		Running Average Score: 1.0
Episode 210		Running Average Score: 1.0
Episode 220		Running Average Score: 1.0
Episode 230		Running Average Score: 1.0
Episode 240		Running Average Score: 1.0
Episode 250

In [176]:
# Paths are resolved from the current working directory: the project root is
# taken to be its parent, and the model lives under "<project>/models/".
model_name = "blackjack_dqn"
proj_path = os.path.join(os.getcwd(), os.pardir)
model_path = os.path.join(proj_path, "models", model_name)