In [1]:
import numpy as np
import tensorflow as tf
from tensorflow.contrib.layers import fully_connected
import gym
import gym_tictactoe
from timeit import Timer
from sklearn.preprocessing import LabelBinarizer
from os import path

In [2]:
DEBUG = False

n_input = (3 * 3 * 3) * 3
n_hidden = 200
n_output = 3 * 3 * 3
learning_rate = 0.01

checkpoint_path = './my_dqn_tictactoe_v2_h{}_lr{}.ckpt'.format(n_hidden, learning_rate)
initializer = tf.contrib.layers.variance_scaling_initializer()

env = gym.make('tictactoe-v0')
timer = Timer()

In [3]:
encoder = LabelBinarizer()
encoder.fit(np.array([[0], [1], [2]]))

def convert_game_to_x_state(obs):
    # gym_tictactoe now supports int-encoded world
    world = np.array(obs, dtype=np.float32)
    data = list(map(lambda x: [x], world.flatten()))
    return encoder.transform(data).flatten()

def convert_action_to_step(action, player):
    action = int(action)
    val = 0
    multiplier = 1
    while action:
        val += (action%3)*multiplier
        multiplier *= 10
        action //= 3
    
    return str(player) + str(val).zfill(3)

In [4]:
player_scopes = ['actor', 'critic']
all_logits = []
outputs = []
ys = []
all_network_trainable_vars_by_name = []

X_state = tf.placeholder(shape=(None, n_input), dtype=tf.float32)
global_step = tf.Variable(0, trainable=False, name='global_step')

for scope in player_scopes:
    with tf.variable_scope(scope) as tf_scope:
        hidden = fully_connected(X_state, n_hidden, activation_fn=tf.nn.relu, weights_initializer=initializer)
        logits = fully_connected(hidden, n_output, activation_fn=None, weights_initializer=initializer)
        all_logits.append(logits)

        output = tf.contrib.layers.softmax(logits)
        outputs.append(output)
        
        y = tf.to_float(tf.multinomial(tf.log(output), num_samples=n_output))
        ys.append(y)
        
        trainable_vars = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES, scope=scope)
        network_trainable_vars_by_name = { var.name[len(tf_scope.name):]: var for var in trainable_vars }
        all_network_trainable_vars_by_name.append(network_trainable_vars_by_name)

actor_net = outputs[0]
critic_net = outputs[1]

actor_vars = all_network_trainable_vars_by_name[0]
critic_vars = all_network_trainable_vars_by_name[1]

copy_ops = [actor_var.assign(critic_vars[var_name]) for var_name, actor_var in actor_vars.items()]
copy_critic_to_actor = tf.group(*copy_ops)

X_action = tf.placeholder(tf.int32, shape=[None])
q_value = tf.reduce_sum(critic_net * tf.one_hot(X_action, n_output), axis=1, keepdims=True)

y = tf.placeholder(tf.float32, shape=[None, 1])
cost = tf.reduce_mean(tf.square(y - q_value))

optimizer = tf.train.AdamOptimizer(learning_rate)
training_op = optimizer.minimize(cost, global_step=global_step)

init = tf.global_variables_initializer()
saver = tf.train.Saver()

file_writer = tf.summary.FileWriter('logs', tf.get_default_graph())
file_writer.close()

In [5]:
from collections import deque

eps_min = 0.05
eps_max = 1.0
eps_decay_steps = 100000
replay_mem_size = 1000000
replay_mem = deque([], maxlen=replay_mem_size)


def epsilon_greedy(q_values, step):
    epsilon = max(eps_min, eps_max - (eps_max-eps_min)*step/eps_decay_steps)
    if np.random.rand() < epsilon:
        return np.random.randint(n_output)
    else:
        return np.argmax(q_values)

def sample_mem(batch_size):
    indices = np.random.permutation(len(replay_mem))[:batch_size]
    cols = [[], [], [], [], []]
    for idx in indices:
        memory = replay_mem[idx]
        for col, value in zip(cols, memory):
            col.append(value)
    cols = [np.array(col) for col in cols]
    return (cols[0], cols[1], cols[2].reshape(-1, 1), cols[3], cols[4].reshape(-1, 1))

In [15]:
def play_with_human(env, human_starts=True):
    obs = env.reset()
    done = False
    player = 0
    human_starts_round = human_starts
    with tf.Session() as sess:
        saver.restore(sess, checkpoint_path)
        with tf.variable_scope('actor'):
            while not done:
                if human_starts_round:
                    env.render()
                    step = input()
                    obs, reward, done, info = env.step('{}{}'.format(player%2+1, step))
                    player += 1
                
                if done:
                    break

                human_starts_round=True
                state = convert_game_to_x_state(obs)
                action = np.argmax(actor_net.eval(feed_dict={X_state: [state]}))
                obs, reward, done, info = env.step(convert_action_to_step(action, player%2+1))

                player += 1
    env.render()

In [7]:
n_steps = 1000000
training_start = 1000
training_interval = 3
save_steps = 1000
copy_steps = 100
print_steps = 1000
discount_rate = 0.95
batch_size = 50
done = True

In [12]:
iteration = 0

obs = env.reset()
player_states = [convert_game_to_x_state(obs), None]
done = False
player_next_states = [None, None]
player_rewards = [None, None]
player_actions = [None, None]

with tf.Session() as sess:
    if path.exists(checkpoint_path + '.meta'):
        saver.restore(sess, checkpoint_path)
    else:
        init.run()
    
    s_time = timer.timer()
    while True:
        step = global_step.eval()
        if step >= n_steps:
            break
        
        iteration += 1
        
        q_values = actor_net.eval(feed_dict={X_state: [player_states[0]]})
        action = epsilon_greedy(q_values, step)
        obs, reward, done, info = env.step(convert_action_to_step(action, 1))
        
        if done:
            assert player_states[1] is not None # No games finishes in 1 action
            
            if reward > 0:
                # assign reward to winner and -reward to loser
                replay_mem.append((player_states[0], action, reward, convert_game_to_x_state(obs), 1.0-done))
                replay_mem.append((player_states[1], player_actions[1], -reward, player_next_states[1], 1.0-done))
            
            if reward < 0:
                # This is an illegal move
                replay_mem.append((player_states[0], action, reward, convert_game_to_x_state(obs), 1.0-done))
            
            if reward == 0:
                replay_mem.append((player_states[0], action, 0, convert_game_to_x_state(obs), 1.0-done))
                replay_mem.append((player_states[1], player_actions[1], 0, player_next_states[1], 1.0-done))
                
            obs = env.reset()
            player_states = [convert_game_to_x_state(obs), None]
            done = False
            player_next_states = [None, None]
            player_rewards = [None, None]
            player_actions = [None, None]
            
        else:
            player_next_states[0] = convert_game_to_x_state(obs)
            player_actions[0] = action
            player_rewards[0] = reward

            player_states[1] = player_next_states[0]

            q_values = actor_net.eval(feed_dict={X_state: [player_states[1]]})
            action = epsilon_greedy(q_values, step)
            obs, reward, done, info = env.step(convert_action_to_step(action, 2))

            if done:
                assert player_states[0] is not None # No games finishes in 1 action

                if reward > 0:
                    # assign reward to winner and -reward to loser
                    replay_mem.append((player_states[1], action, reward, convert_game_to_x_state(obs), 1.0-done))
                    replay_mem.append((player_states[0], player_actions[0], -reward, player_next_states[0], 1.0-done))

                if reward < 0:
                    # This is an illegal move
                    replay_mem.append((player_states[1], action, reward, convert_game_to_x_state(obs), 1.0-done))

                if reward == 0:
                    replay_mem.append((player_states[1], action, 0, convert_game_to_x_state(obs), 1.0-done))
                    replay_mem.append((player_states[0], player_actions[0], 0, player_next_states[0], 1.0-done))
                
                obs = env.reset()
                player_states = [convert_game_to_x_state(obs), None]
                done = False
                player_next_states = [None, None]
                player_rewards = [None, None]
                player_actions = [None, None]
                
            else:
                player_next_states[1] = convert_game_to_x_state(obs)
                player_actions[1] = action
                player_rewards[1] = reward
                
                player_states[0] = player_next_states[1]
        
        if iteration % print_steps == 0:
            e_time = timer.timer()
            print('Steps:{}'.format(step), 'Time:{0:.2f}s'.format(e_time-s_time))
            s_time = e_time
        
        if iteration % training_interval != 0:
            continue
        
        # Critic learns
        X_state_val, X_action_val, rewards, X_next_state, _ = (sample_mem(batch_size))
        next_q_values = actor_net.eval(feed_dict={X_state: X_next_state})
        max_next_q_values = np.max(next_q_values, axis=1, keepdims=True)
        y_val = rewards + discount_rate * max_next_q_values
        training_op.run(feed_dict={X_state: X_state_val, X_action: X_action_val, y: y_val})
        
        if step % copy_steps:
            copy_critic_to_actor.run()
            
        if step % save_steps:
            saver.save(sess, checkpoint_path)

Steps:333 Time:13.66s
Steps:666 Time:13.66s
Steps:999 Time:14.13s
Steps:1333 Time:13.69s
Steps:1666 Time:13.43s
Steps:1999 Time:13.52s
Steps:2333 Time:14.64s
Steps:2666 Time:13.80s
Steps:2999 Time:13.35s
Steps:3333 Time:13.36s
Steps:3666 Time:13.48s
Steps:3999 Time:13.48s
Steps:4333 Time:13.71s
Steps:4666 Time:13.75s
Steps:4999 Time:13.60s
Steps:5333 Time:13.85s
Steps:5666 Time:13.72s
Steps:5999 Time:13.73s
Steps:6333 Time:13.73s
Steps:6666 Time:13.73s
Steps:6999 Time:13.73s
Steps:7333 Time:13.52s
Steps:7666 Time:13.40s
Steps:7999 Time:13.51s
Steps:8333 Time:13.47s
Steps:8666 Time:13.44s
Steps:8999 Time:13.46s
Steps:9333 Time:13.41s
Steps:9666 Time:13.47s
Steps:9999 Time:13.43s
Steps:10333 Time:13.45s
Steps:10666 Time:13.43s
Steps:10999 Time:13.52s
Steps:11333 Time:13.47s
Steps:11666 Time:13.44s
Steps:11999 Time:13.48s
Steps:12333 Time:13.45s
Steps:12666 Time:13.50s
Steps:12999 Time:13.45s
Steps:13333 Time:13.46s
Steps:13666 Time:13.51s
Steps:13999 Time:13.48s
Steps:14333 Time:13.48s
S

Steps:114333 Time:14.42s
Steps:114666 Time:14.48s
Steps:114999 Time:14.57s
Steps:115333 Time:14.52s
Steps:115666 Time:14.47s
Steps:115999 Time:14.41s
Steps:116333 Time:14.47s
Steps:116666 Time:14.45s
Steps:116999 Time:14.50s
Steps:117333 Time:14.53s
Steps:117666 Time:14.47s
Steps:117999 Time:14.58s
Steps:118333 Time:14.53s
Steps:118666 Time:14.60s
Steps:118999 Time:14.47s
Steps:119333 Time:14.67s
Steps:119666 Time:14.85s
Steps:119999 Time:14.64s
Steps:120333 Time:14.73s
Steps:120666 Time:14.58s
Steps:120999 Time:14.59s
Steps:121333 Time:14.57s
Steps:121666 Time:14.52s
Steps:121999 Time:14.50s
Steps:122333 Time:14.62s
Steps:122666 Time:14.52s
Steps:122999 Time:14.59s
Steps:123333 Time:14.56s
Steps:123666 Time:14.62s
Steps:123999 Time:14.54s
Steps:124333 Time:14.68s
Steps:124666 Time:14.50s
Steps:124999 Time:14.59s
Steps:125333 Time:14.52s
Steps:125666 Time:14.56s
Steps:125999 Time:14.53s
Steps:126333 Time:14.52s
Steps:126666 Time:14.63s
Steps:126999 Time:14.57s
Steps:127333 Time:14.51s


Steps:223666 Time:15.63s
Steps:223999 Time:15.64s
Steps:224333 Time:15.80s
Steps:224666 Time:15.52s
Steps:224999 Time:15.52s
Steps:225333 Time:15.58s
Steps:225666 Time:15.61s
Steps:225999 Time:15.70s
Steps:226333 Time:15.63s
Steps:226666 Time:15.62s
Steps:226999 Time:15.69s
Steps:227333 Time:15.68s
Steps:227666 Time:15.75s
Steps:227999 Time:15.66s
Steps:228333 Time:15.69s
Steps:228666 Time:15.69s
Steps:228999 Time:15.70s
Steps:229333 Time:15.68s
Steps:229666 Time:15.86s
Steps:229999 Time:15.62s
Steps:230333 Time:15.87s
Steps:230666 Time:15.69s
Steps:230999 Time:15.74s
Steps:231333 Time:15.73s
Steps:231666 Time:15.76s
Steps:231999 Time:15.76s
Steps:232333 Time:15.72s
Steps:232666 Time:15.64s
Steps:232999 Time:15.69s
Steps:233333 Time:15.83s
Steps:233666 Time:15.59s
Steps:233999 Time:15.60s
Steps:234333 Time:15.79s
Steps:234666 Time:15.80s
Steps:234999 Time:15.73s
Steps:235333 Time:15.76s
Steps:235666 Time:15.64s
Steps:235999 Time:15.68s
Steps:236333 Time:15.79s
Steps:236666 Time:15.77s


Steps:332999 Time:16.62s
Steps:333333 Time:16.77s
Steps:333666 Time:16.83s
Steps:333999 Time:16.83s
Steps:334333 Time:16.79s
Steps:334666 Time:16.71s
Steps:334999 Time:16.74s
Steps:335333 Time:16.90s
Steps:335666 Time:16.86s
Steps:335999 Time:16.89s
Steps:336333 Time:16.91s
Steps:336666 Time:16.78s
Steps:336999 Time:16.85s
Steps:337333 Time:16.84s
Steps:337666 Time:16.87s
Steps:337999 Time:16.85s
Steps:338333 Time:16.86s
Steps:338666 Time:16.97s
Steps:338999 Time:16.95s
Steps:339333 Time:16.94s
Steps:339666 Time:16.83s
Steps:339999 Time:16.84s
Steps:340333 Time:16.84s
Steps:340666 Time:16.87s
Steps:340999 Time:16.79s
Steps:341333 Time:16.85s
Steps:341666 Time:17.03s
Steps:341999 Time:16.82s
Steps:342333 Time:16.78s
Steps:342666 Time:16.79s
Steps:342999 Time:16.90s
Steps:343333 Time:16.91s
Steps:343666 Time:16.78s
Steps:343999 Time:16.89s
Steps:344333 Time:16.70s
Steps:344666 Time:16.85s
Steps:344999 Time:16.91s
Steps:345333 Time:16.97s
Steps:345666 Time:16.89s
Steps:345999 Time:16.99s


Steps:442333 Time:18.07s
Steps:442666 Time:18.09s
Steps:442999 Time:18.04s
Steps:443333 Time:18.03s
Steps:443666 Time:18.36s
Steps:443999 Time:18.21s
Steps:444333 Time:18.26s
Steps:444666 Time:18.03s
Steps:444999 Time:18.10s
Steps:445333 Time:18.18s
Steps:445666 Time:18.38s
Steps:445999 Time:18.20s
Steps:446333 Time:17.98s
Steps:446666 Time:18.21s
Steps:446999 Time:18.17s
Steps:447333 Time:18.24s
Steps:447666 Time:18.21s
Steps:447999 Time:18.17s
Steps:448333 Time:18.13s
Steps:448666 Time:18.15s
Steps:448999 Time:18.21s
Steps:449333 Time:18.20s
Steps:449666 Time:18.36s
Steps:449999 Time:18.29s
Steps:450333 Time:18.20s
Steps:450666 Time:18.27s
Steps:450999 Time:18.28s
Steps:451333 Time:18.18s
Steps:451666 Time:18.31s
Steps:451999 Time:18.04s
Steps:452333 Time:18.26s
Steps:452666 Time:18.27s
Steps:452999 Time:18.23s
Steps:453333 Time:18.34s
Steps:453666 Time:18.22s
Steps:453999 Time:18.25s
Steps:454333 Time:18.26s
Steps:454666 Time:18.21s
Steps:454999 Time:18.12s
Steps:455333 Time:18.22s


Steps:551666 Time:19.14s
Steps:551999 Time:19.41s
Steps:552333 Time:19.79s
Steps:552666 Time:19.58s
Steps:552999 Time:19.29s
Steps:553333 Time:19.42s
Steps:553666 Time:19.31s
Steps:553999 Time:19.35s
Steps:554333 Time:19.31s
Steps:554666 Time:19.29s
Steps:554999 Time:19.27s
Steps:555333 Time:19.08s
Steps:555666 Time:19.09s
Steps:555999 Time:19.23s
Steps:556333 Time:19.37s
Steps:556666 Time:19.09s
Steps:556999 Time:19.17s
Steps:557333 Time:19.29s
Steps:557666 Time:19.28s
Steps:557999 Time:19.36s
Steps:558333 Time:19.33s
Steps:558666 Time:19.39s
Steps:558999 Time:19.44s
Steps:559333 Time:19.32s
Steps:559666 Time:19.36s
Steps:559999 Time:19.34s
Steps:560333 Time:19.33s
Steps:560666 Time:19.33s
Steps:560999 Time:19.50s
Steps:561333 Time:19.45s
Steps:561666 Time:19.50s
Steps:561999 Time:19.48s
Steps:562333 Time:19.25s
Steps:562666 Time:19.35s
Steps:562999 Time:19.41s
Steps:563333 Time:19.42s
Steps:563666 Time:19.50s
Steps:563999 Time:19.44s
Steps:564333 Time:19.45s
Steps:564666 Time:19.40s


Steps:660999 Time:20.33s
Steps:661333 Time:20.41s
Steps:661666 Time:20.36s
Steps:661999 Time:20.37s
Steps:662333 Time:20.33s
Steps:662666 Time:20.34s
Steps:662999 Time:20.36s
Steps:663333 Time:20.44s
Steps:663666 Time:20.38s
Steps:663999 Time:20.35s
Steps:664333 Time:20.43s
Steps:664666 Time:20.45s
Steps:664999 Time:20.38s
Steps:665333 Time:20.42s
Steps:665666 Time:20.62s
Steps:665999 Time:20.57s
Steps:666333 Time:20.57s
Steps:666666 Time:20.41s
Steps:666999 Time:20.32s
Steps:667333 Time:20.48s
Steps:667666 Time:20.44s
Steps:667999 Time:20.52s
Steps:668333 Time:20.51s
Steps:668666 Time:20.41s
Steps:668999 Time:20.57s
Steps:669333 Time:20.47s
Steps:669666 Time:20.42s
Steps:669999 Time:20.46s
Steps:670333 Time:20.55s
Steps:670666 Time:20.53s
Steps:670999 Time:20.54s
Steps:671333 Time:20.44s
Steps:671666 Time:20.46s
Steps:671999 Time:20.56s
Steps:672333 Time:20.57s
Steps:672666 Time:20.60s
Steps:672999 Time:20.72s
Steps:673333 Time:20.65s
Steps:673666 Time:20.67s
Steps:673999 Time:20.72s


Steps:770333 Time:21.99s
Steps:770666 Time:21.86s
Steps:770999 Time:22.13s
Steps:771333 Time:22.27s
Steps:771666 Time:21.96s
Steps:771999 Time:21.92s
Steps:772333 Time:22.10s
Steps:772666 Time:21.81s
Steps:772999 Time:22.12s
Steps:773333 Time:22.03s
Steps:773666 Time:22.02s
Steps:773999 Time:21.91s
Steps:774333 Time:22.15s
Steps:774666 Time:22.01s
Steps:774999 Time:21.99s
Steps:775333 Time:22.09s
Steps:775666 Time:22.07s
Steps:775999 Time:22.02s
Steps:776333 Time:22.03s
Steps:776666 Time:22.04s
Steps:776999 Time:22.38s
Steps:777333 Time:22.41s
Steps:777666 Time:22.03s
Steps:777999 Time:22.16s
Steps:778333 Time:22.08s
Steps:778666 Time:22.04s
Steps:778999 Time:22.11s
Steps:779333 Time:21.81s
Steps:779666 Time:22.04s
Steps:779999 Time:22.22s
Steps:780333 Time:22.11s
Steps:780666 Time:22.11s
Steps:780999 Time:22.08s
Steps:781333 Time:22.18s
Steps:781666 Time:22.16s
Steps:781999 Time:22.39s
Steps:782333 Time:22.18s
Steps:782666 Time:22.10s
Steps:782999 Time:22.24s
Steps:783333 Time:22.29s


Steps:879666 Time:23.51s
Steps:879999 Time:23.46s
Steps:880333 Time:23.31s
Steps:880666 Time:23.59s
Steps:880999 Time:23.67s
Steps:881333 Time:23.58s
Steps:881666 Time:23.40s
Steps:881999 Time:23.56s
Steps:882333 Time:23.53s
Steps:882666 Time:23.59s
Steps:882999 Time:23.37s
Steps:883333 Time:23.54s
Steps:883666 Time:23.61s
Steps:883999 Time:23.52s
Steps:884333 Time:23.70s
Steps:884666 Time:23.38s
Steps:884999 Time:23.57s
Steps:885333 Time:23.78s
Steps:885666 Time:23.53s
Steps:885999 Time:23.60s
Steps:886333 Time:23.56s
Steps:886666 Time:23.55s
Steps:886999 Time:23.53s
Steps:887333 Time:23.59s
Steps:887666 Time:23.61s
Steps:887999 Time:23.68s
Steps:888333 Time:23.68s
Steps:888666 Time:23.61s
Steps:888999 Time:23.75s
Steps:889333 Time:23.74s
Steps:889666 Time:23.19s
Steps:889999 Time:23.64s
Steps:890333 Time:23.60s
Steps:890666 Time:23.69s
Steps:890999 Time:23.59s
Steps:891333 Time:23.68s
Steps:891666 Time:23.67s
Steps:891999 Time:23.59s
Steps:892333 Time:23.84s
Steps:892666 Time:23.74s


Steps:988999 Time:25.12s
Steps:989333 Time:25.01s
Steps:989666 Time:25.07s
Steps:989999 Time:25.07s
Steps:990333 Time:25.25s
Steps:990666 Time:25.07s
Steps:990999 Time:25.00s
Steps:991333 Time:25.06s
Steps:991666 Time:25.02s
Steps:991999 Time:25.16s
Steps:992333 Time:25.14s
Steps:992666 Time:25.06s
Steps:992999 Time:25.04s
Steps:993333 Time:25.07s
Steps:993666 Time:25.13s
Steps:993999 Time:25.05s
Steps:994333 Time:25.05s
Steps:994666 Time:25.05s
Steps:994999 Time:25.05s
Steps:995333 Time:25.17s
Steps:995666 Time:25.19s
Steps:995999 Time:25.32s
Steps:996333 Time:25.16s
Steps:996666 Time:25.17s
Steps:996999 Time:25.05s
Steps:997333 Time:25.02s
Steps:997666 Time:25.09s
Steps:997999 Time:25.09s
Steps:998333 Time:25.21s
Steps:998666 Time:25.15s
Steps:998999 Time:25.15s
Steps:999333 Time:25.27s
Steps:999666 Time:25.60s
Steps:999999 Time:25.25s


In [16]:
play_with_human(env, human_starts=True)

INFO:tensorflow:Restoring parameters from ./my_dqn_tictactoe_v2_h200_lr0.01.ckpt
- - -    - - -    - - -    
- - -    - - -    - - -    
- - -    - - -    - - -    
011
- - -    - - -    - - -    
- x -    - o -    - - -    
- - -    - - -    - - -    
001
- - -    - - -    o - -    
x x -    - o -    - - -    
- - -    - - -    - - -    
002
- - -    - - -    O - -    
x x -    - O -    - - -    
x - O    - - -    - - -    


In [20]:
play_with_human(env, human_starts=False)

INFO:tensorflow:Restoring parameters from ./my_dqn_tictactoe_v2_h200_lr0.01.ckpt
- - -    - - -    x - -    
- - -    - - -    - - -    
- - -    - - -    - - -    
111
- - -    x - -    x - -    
- - -    - o -    - - -    
- - -    - - -    - - -    
000
o - x    x - -    x - -    
- - -    - o -    - - -    
- - -    - - -    - - -    
010
o o X    x X -    X - -    
- - -    - o -    - - -    
- - -    - - -    - - -    


In [None]:
with tf.Session() as sess:
    saver.restore(sess, checkpoint_path)
    print(global_step.eval())

In [None]:
env.render()

In [None]:
env._done