In [1]:
from tetris import TetrisEnv

# Single Tetris environment instance; later cells wrap it and render from it.
env = TetrisEnv()

In [2]:
from stable_baselines.common.vec_env import DummyVecEnv

def _make_env():
    """Factory handed to DummyVecEnv; returns the already-created `env`."""
    return env


# Wrap the single environment in a vectorized-env interface.
dummy_env = DummyVecEnv([_make_env])

In [10]:
import tensorflow as tf
import numpy as np

from stable_baselines import DQN
from stable_baselines.a2c.utils import conv, linear, conv_to_fc
from stable_baselines.deepq.policies import FeedForwardPolicy

def cnn(scaled_images, **kwargs):
    """Feature extractor: three ReLU conv layers followed by one ReLU dense layer.

    Every conv layer uses a 2x2 filter with stride 1; the filter counts are
    32, 64 and 64. The conv output is flattened and fed to a 512-unit
    fully connected layer.

    :param scaled_images: input tensor fed to the first conv layer
    :param kwargs: extra keyword arguments forwarded to each ``conv`` call
    :return: the activated output of the final fully connected layer
    """
    # (variable scope name, number of filters) for each conv layer, in order.
    conv_specs = (('c1', 32), ('c2', 64), ('c3', 64))

    hidden = scaled_images
    for scope, n_filters in conv_specs:
        hidden = tf.nn.relu(
            conv(hidden, scope, n_filters=n_filters, filter_size=2, stride=1,
                 init_scale=np.sqrt(2), **kwargs)
        )

    flattened = conv_to_fc(hidden)
    return tf.nn.relu(linear(flattened, 'fc1', n_hidden=512, init_scale=np.sqrt(2)))

class CnnPolicy(FeedForwardPolicy):
    """FeedForwardPolicy preconfigured to use the custom ``cnn`` feature extractor."""

    def __init__(self, *args, **kwargs):
        # Forward everything unchanged and pin the extractor settings.
        super().__init__(*args, cnn_extractor=cnn, feature_extraction='cnn', **kwargs)

In [11]:
# Build the DQN agent on the vectorized env, logging to TensorBoard under ./log.
model = DQN(
    policy=CnnPolicy,
    env=dummy_env,
    verbose=0,
    tensorboard_log='./log',
)

Instructions for updating:
Use tf.cast instead.


In [12]:
# Train for 1000 timesteps without resetting the model's timestep counter.
model.learn(total_timesteps=1000, reset_num_timesteps=False)

<stable_baselines.deepq.dqn.DQN at 0x13f7830f0>

In [14]:
# Play with the trained model until a step reports `done`, rendering after every step.
obs = dummy_env.reset()

while True:
    action, states = model.predict(obs)
    obs, rew, done, info = dummy_env.step(action)
    env.render()
    if done:
        break

In [18]:
# Name the checkpoint after the number of timesteps the model has seen.
save_path = 'dqn-{}'.format(model.num_timesteps)
model.save(save_path)