In [1]:
import random
from typing import *
from collections import deque

import numpy as np
from tqdm.auto import tqdm, trange

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import torch.distributions as distributions

from game.api import BlackjackWrapper
from game.game_models import *

In [2]:
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
device

device(type='cpu')

In [3]:
class BlackjackPolicyModel(nn.Module):
    """
    Model that accepts a flattened state and outputs 2 values:
    1. Bet percentage from 0 to 1
    2. Probability of taking a card (hit) from 0 to 1
    """
    def __init__(self, in_features: int):
        super().__init__()
        # common layers shared by both outputs
        self.init_layers = nn.Sequential(
            nn.Linear(in_features, 64),
            nn.Linear(64, 128),
            nn.Linear(128, 256),
            nn.Linear(256, 128),
            nn.Linear(128, 64),
            nn.Linear(64, 16),
        )
        # layers for bet percentage output
        self.bet_layers = nn.Sequential(
            nn.Linear(16, 4),
            nn.Linear(4, 1),
        )
        self.bet_act = nn.Sigmoid()
        # layers for card action output
        self.card_layers = nn.Sequential(
            nn.Linear(16, 8),
            nn.Linear(8, 4),
            nn.Linear(4, 1),
        )
        self.card_act = nn.Sigmoid()

    def forward(self, x) -> Tuple[torch.Tensor, torch.Tensor]:
        x = self.init_layers(x)
        x1 = self.bet_layers(x)
        x1 = self.bet_act(x1)
        x2 = self.card_layers(x)
        x2 = self.card_act(x2)
        return x1, x2

    def get_bet_percent(self, normalized_state) -> torch.Tensor:
        state = torch.from_numpy(normalized_state).float().unsqueeze(0).to(device)
        bet_percent, _ = self.forward(state)
        return bet_percent.cpu()

    def get_card_action(self, normalized_state) -> torch.Tensor:
        state = torch.from_numpy(normalized_state).float().unsqueeze(0).to(device)
        _, card_prob = self.forward(state)
        return card_prob.cpu()

In [4]:
def reinforce(
    game_wrapper: BlackjackWrapper,
    policy_model: BlackjackPolicyModel,
    optimizer: optim.Optimizer,
    num_eps: int,
    gamma: float,
    log_step: Optional[int] = None
):
    print("Starting RL training process...")
    log_step: int = log_step or max(num_eps // 100, 1)
    eps_scores: List[float] = []

    for i_episode in trange(num_eps):
        saved_outputs = []
        rewards: List[float] = []
        game_wrapper.reset()
        starting_state = True
        terminated = False
        state = game_wrapper.get_state()

        while not terminated:
            if starting_state:
                bet_percent = policy_model.get_bet_percent(state.flatten())
                saved_outputs.append(bet_percent)
                outcome = game_wrapper.bet_step(bet_percent)
            else:
                card_action = policy_model.get_card_action(state.flatten())
                saved_outputs.append(card_action)
                outcome = game_wrapper.card_step(take_card=card_action.item() > random.random())
            state = outcome.new_state
            terminated = outcome.terminated
            rewards.append(outcome.reward)
            starting_state = False

        n_steps = len(rewards)
        eps_scores.append(sum(rewards))
        returns = deque(maxlen=n_steps)

        for t in range(n_steps)[::-1]:
            disc_return_t = returns[0] if len(returns) > 0 else 0
            returns.appendleft(gamma * disc_return_t + rewards[t])

        returns = torch.tensor(returns)
        # normalize returns
        returns = (returns - returns.mean()) / (returns.std() + 1e-8)

        model_loss_arr = []
        for output, pred_return in zip(saved_outputs, returns):
            model_loss_arr.append(-output * pred_return)
        model_loss = torch.cat(model_loss_arr).sum()

        optimizer.zero_grad()
        model_loss.backward()
        optimizer.step()

        if i_episode % log_step == 0:
            tqdm.write(f"Episode {i_episode}\tRunning Average Score: {round(np.mean(eps_scores).item(), 3)}")

    return eps_scores

In [5]:
gamma = 0.9
learning_rate = 1e-3
num_eps = 10000

In [6]:
game_wrapper = BlackjackWrapper()
model = BlackjackPolicyModel(in_features=GameState.get_state_size())
optimizer = optim.Adam(model.parameters(), lr=learning_rate)

In [7]:
reinforce(
    game_wrapper=game_wrapper,
    policy_model=model,
    optimizer=optimizer,
    num_eps=num_eps,
    gamma=gamma,
)

Starting RL training process...


  0%|          | 0/10000 [00:00<?, ?it/s]

Episode 0	Running Average Score: 1.0
Episode 100	Running Average Score: 1.0
Episode 200	Running Average Score: 1.0
Episode 300	Running Average Score: 1.0
Episode 400	Running Average Score: 1.0
Episode 500	Running Average Score: 1.0
Episode 600	Running Average Score: 1.0
Episode 700	Running Average Score: 1.0
Episode 800	Running Average Score: 1.0
Episode 900	Running Average Score: 1.0
Episode 1000	Running Average Score: 1.0
Episode 1100	Running Average Score: 1.0
Episode 1200	Running Average Score: 1.0
Episode 1300	Running Average Score: 1.0
Episode 1400	Running Average Score: 1.0
Episode 1500	Running Average Score: 1.0
Episode 1600	Running Average Score: 1.0
Episode 1700	Running Average Score: 1.0
Episode 1800	Running Average Score: 1.0
Episode 1900	Running Average Score: 1.0
Episode 2000	Running Average Score: 1.0
Episode 2100	Running Average Score: 1.0
Episode 2200	Running Average Score: 1.0
Episode 2300	Running Average Score: 1.0
Episode 2400	Running Average Score: 1.0
Episode 2500

[1.0,
 1.0,
 1.0,
 1.0,
 1.0,
 1.0,
 1.0,
 1.0,
 1.0,
 1.0,
 1.0,
 1.0,
 1.0,
 1.0,
 1.0,
 1.0,
 1.0,
 1.0,
 1.0,
 1.0,
 1.0,
 1.0,
 1.0,
 1.0,
 1.0,
 1.0,
 1.0,
 1.0,
 1.0,
 1.0,
 1.0,
 1.0,
 1.0,
 1.0,
 1.0,
 1.0,
 1.0,
 1.0,
 1.0,
 1.0,
 1.0,
 1.0,
 1.0,
 1.0,
 1.0,
 1.0,
 1.0,
 1.0,
 1.0,
 1.0,
 1.0,
 1.0,
 1.0,
 1.0,
 1.0,
 1.0,
 1.0,
 1.0,
 1.0,
 1.0,
 1.0,
 1.0,
 1.0,
 1.0,
 1.0,
 1.0,
 1.0,
 1.0,
 1.0,
 1.0,
 1.0,
 1.0,
 1.0,
 1.0,
 1.0,
 1.0,
 1.0,
 1.0,
 1.0,
 1.0,
 1.0,
 1.0,
 1.0,
 1.0,
 1.0,
 1.0,
 1.0,
 1.0,
 1.0,
 1.0,
 1.0,
 1.0,
 1.0,
 1.0,
 1.0,
 1.0,
 1.0,
 1.0,
 1.0,
 1.0,
 1.0,
 1.0,
 1.0,
 1.0,
 1.0,
 1.0,
 1.0,
 1.0,
 1.0,
 1.0,
 1.0,
 1.0,
 1.0,
 1.0,
 1.0,
 1.0,
 1.0,
 1.0,
 1.0,
 1.0,
 1.0,
 1.0,
 1.0,
 1.0,
 1.0,
 1.0,
 1.0,
 1.0,
 1.0,
 1.0,
 1.0,
 1.0,
 1.0,
 1.0,
 1.0,
 1.0,
 1.0,
 1.0,
 1.0,
 1.0,
 1.0,
 1.0,
 1.0,
 1.0,
 1.0,
 1.0,
 1.0,
 1.0,
 1.0,
 1.0,
 1.0,
 1.0,
 1.0,
 1.0,
 1.0,
 1.0,
 1.0,
 1.0,
 1.0,
 1.0,
 1.0,
 1.0,
 1.0,
 1.0,
 1.0,
 1.0,
 1.0