# Import the necessary modules 

In [1]:
import torch
import gym
from collections import defaultdict
import matplotlib
import matplotlib.pyplot as plt
from mpl_toolkits.mplot3d import Axes3D
# Create the environment registered under the gym id 'Blackjack-v0'; the
# functions defined below take it as their `env` argument.
env = gym.make('Blackjack-v0')

Use an on-policy (epsilon-greedy) strategy to choose actions while generating each episode.

In [2]:
def run_episode(env, Q, epsilon, n_action):
  """Play a single episode, choosing actions epsilon-greedily with respect to Q.

  Args:
    env: environment exposing ``reset()`` (returns the initial state) and
      ``step(action)`` (returns ``(state, reward, is_done, info)``).
    Q: mapping from state to a 1-D tensor of action values; indexed as
      ``Q[state]`` for every state encountered.
    epsilon: total probability mass spread uniformly over all actions.
    n_action: number of available actions.

  Returns:
    A ``(states, actions, rewards)`` tuple of equal-length lists, where
    ``states[i]`` is the state in which ``actions[i]`` was taken and
    ``rewards[i]`` is the reward received for it.
  """
  visited, chosen, collected = [], [], []
  observation = env.reset()
  while True:
    # Every action gets epsilon / n_action; the greedy action additionally
    # receives the remaining 1 - epsilon.
    action_probs = torch.ones(n_action) * epsilon / n_action
    greedy = torch.argmax(Q[observation]).item()
    action_probs[greedy] += 1.0 - epsilon
    picked = torch.multinomial(action_probs, 1).item()
    # Record the state the action was taken in, before stepping the env.
    visited.append(observation)
    chosen.append(picked)
    observation, reward, finished, _ = env.step(picked)
    collected.append(reward)
    if finished:
      return visited, chosen, collected

In [3]:
def mc_control_epsilon_greedy(env, gamma, n_episode, epsilon):
  """On-policy first-visit Monte Carlo control with an epsilon-greedy policy.

  Args:
    env: environment exposing ``action_space.n`` plus the ``reset``/``step``
      interface consumed by ``run_episode``.
    gamma: discount factor applied when accumulating returns.
    n_episode: number of episodes to simulate.
    epsilon: exploration rate forwarded to ``run_episode``.

  Returns:
    ``(Q, policy)`` where ``Q`` is a ``defaultdict`` mapping each state to a
    tensor of averaged returns per action, and ``policy`` maps each state
    present in ``Q`` to the index of its highest-valued action.
  """
  n_action = env.action_space.n
  G_sum = defaultdict(float)
  N = defaultdict(int)
  # Start every new state at zero. torch.empty would hand back uninitialized
  # memory, so any action never visited in a state would keep a garbage value
  # and distort argmax both while acting and in the final policy.
  Q = defaultdict(lambda: torch.zeros(n_action))
  for _ in range(n_episode):
    states_t, actions_t, rewards_t = run_episode(env, Q, epsilon, n_action)
    return_t = 0
    G = {}
    # Walk the episode backwards so each step's return is built from the one
    # after it; an earlier occurrence of the same (state, action) overwrites a
    # later one, leaving the first-visit return in G.
    for state_t, action_t, reward_t in zip(states_t[::-1], actions_t[::-1], rewards_t[::-1]):
      return_t = gamma * return_t + reward_t
      G[(state_t, action_t)] = return_t
    for state_action, return_t in G.items():
      state, action = state_action
      # Only states whose first component is at most 21 are updated
      # (presumably the player's hand total -- confirm against the env docs).
      if state[0] <= 21:
        G_sum[state_action] += return_t
        N[state_action] += 1
        Q[state][action] = G_sum[state_action] / N[state_action]
  policy = {state: torch.argmax(actions).item() for state, actions in Q.items()}
  return Q, policy

In [4]:
# Hyperparameters for the Monte Carlo control run below.
gamma = 1  # discount factor; with 1, later rewards count as much as earlier ones
n_episode = 500000  # number of episodes to simulate
epsilon = 0.1  # probability mass the behaviour policy spreads uniformly over all actions
# Learn the action-value table and the greedy policy derived from it.
optimal_Q, optimal_policy = mc_control_epsilon_greedy(env, gamma, n_episode, epsilon)

In [5]:
# Dump the whole learned action-value table: one tensor of per-action values for each state.
print(optimal_Q)

defaultdict(<function mc_control_epsilon_greedy.<locals>.<lambda> at 0x7fa777f12560>, {(18, 8, False): tensor([ 0.0973, -0.4852]), (20, 10, False): tensor([ 0.4295, -0.8822]), (13, 4, False): tensor([-0.2063, -0.3918]), (21, 10, True): tensor([ 0.8870, -0.0044]), (20, 4, False): tensor([ 0.6562, -0.8789]), (16, 6, False): tensor([-0.1636, -0.3736]), (19, 10, True): tensor([-0.0549, -0.1411]), (14, 8, False): tensor([-0.4454, -0.3935]), (6, 10, False): tensor([-0.7030, -0.4390]), (16, 5, True): tensor([-0.0977,  0.0812]), (14, 6, False): tensor([-0.1624, -0.3523]), (13, 2, False): tensor([-0.2939, -0.3816]), (12, 10, False): tensor([-0.6007, -0.4531]), (6, 6, False): tensor([-0.1498, -0.1786]), (10, 9, False): tensor([-0.5300,  0.0773]), (7, 10, False): tensor([-0.6170, -0.3770]), (9, 7, False): tensor([-0.5217,  0.1383]), (13, 7, True): tensor([-0.5294,  0.0238]), (21, 8, True): tensor([0.9322, 0.0455]), (20, 10, True): tensor([ 0.4624, -0.1776]), (17, 10, False): tensor([-0.4668, -0.6

# Helpful resources

[Blackjack game](https://en.wikipedia.org/wiki/Blackjack)