## CartPole DQN with ChainerRL
---

In [2]:
import chainer
import chainer.functions as F
import chainer.links as L
import chainerrl
import gym
import numpy as np
from datetime import datetime as dt

class QFunction(chainer.Chain):
    """Fully connected Q-network for a discrete action space.

    Two tanh hidden layers of width ``n_hidden_channels`` map an observation
    vector to one value per action. The output is wrapped in ChainerRL's
    ``DiscreteActionValue``.

    Args:
        obs_size: Length of the observation vector fed to the first layer.
        n_actions: Number of discrete actions (width of the output layer).
        n_hidden_channels: Width of each hidden layer (default 50).
    """

    def __init__(self, obs_size, n_actions, n_hidden_channels=50):
        super().__init__()
        # Register child links inside init_scope(). Passing them as keyword
        # arguments to Chain.__init__ is the deprecated pre-v2 style.
        with self.init_scope():
            self.l0 = L.Linear(obs_size, n_hidden_channels)
            self.l1 = L.Linear(n_hidden_channels, n_hidden_channels)
            self.l2 = L.Linear(n_hidden_channels, n_actions)

    def __call__(self, x, test=False):
        """Compute action values for a batch of observations.

        Args:
            x: Batch of observations; presumably shaped (batch, obs_size)
                and float32 (the agent's ``phi`` casts to float32) -- confirm.
            test: Unused; kept so existing callers passing it keep working.

        Returns:
            ``chainerrl.action_value.DiscreteActionValue`` wrapping the
            ``n_actions`` outputs of the last layer.
        """
        h = F.tanh(self.l0(x))
        h = F.tanh(self.l1(h))
        return chainerrl.action_value.DiscreteActionValue(self.l2(h))

In [13]:
# Exploratory check: type of the fourth element returned by env.step
# (the output below shows it is a dict).
# NOTE: `env` is only created by the later `env = gym.make(...)` cell, so this
# cell works only after that one has run. It also advances the environment
# by one step with action 0.
type(env.step(0)[3])

dict

In [4]:
# Build the CartPole environment and a DQN agent around QFunction.
env = gym.make('CartPole-v0')

obs_size = env.observation_space.shape[0]  # length of the observation vector
n_actions = env.action_space.n             # number of discrete actions
q_func = QFunction(obs_size, n_actions)

optimizer = chainer.optimizers.Adam(eps=1e-3)
optimizer.setup(q_func)

# Constant (non-annealed) epsilon-greedy exploration: epsilon=0.2, with
# random actions drawn from the environment's action space.
explorer = chainerrl.explorers.ConstantEpsilonGreedy(
    epsilon=0.2, random_action_func=env.action_space.sample)

replay_buffer = chainerrl.replay_buffer.ReplayBuffer(capacity=10**6)


def phi(x):
    """Cast an observation array to float32 (no copy if it already is)."""
    return x.astype(np.float32, copy=False)


agent = chainerrl.agents.DQN(
    q_func, optimizer, replay_buffer, gamma=0.95, explorer=explorer,
    replay_start_size=500, update_interval=1,
    target_update_interval=100, phi=phi)

WARN: gym.spaces.Box autodetected dtype as <class 'numpy.float32'>. Please provide explicit dtype.


In [3]:
# Train the agent with periodic evaluation; output goes under result/.
# NOTE(review): `eval_n_runs` / `max_episode_len` are the keyword names of
# older ChainerRL releases; newer releases renamed them -- confirm against
# the installed version.
chainerrl.experiments.train_agent_with_evaluation(
    agent,
    env,
    steps=2000,
    eval_n_runs=10,
    max_episode_len=2000,
    eval_interval=1000,
    outdir='result/',
)

In [4]:
# Run one episode with the trained agent. Keep the rendered frames and the
# total reward instead of discarding them.
frames = []            # arrays returned by env.render(mode='rgb_array')
episode_return = 0.0   # sum of rewards over the episode
obs = env.reset()
done = False

while not done:
    # render(mode='rgb_array') returns the frame as an array; keep it so it
    # can be displayed or animated later. This holds every frame in memory.
    frames.append(env.render(mode='rgb_array'))
    action = agent.act(obs)
    obs, reward, done, _ = env.step(action)
    episode_return += reward

# Tell the agent the episode is over before starting another one.
agent.stop_episode()
print('episode return:', episode_return, 'steps:', len(frames))

In [10]:
# Scratch check: step once more after the episode above has ended. The output
# below shows done=True and reward 0.0; the episode had already terminated.
# Do not rely on results after termination -- call env.reset() before
# stepping again.
env.step(1)

(array([ 2.53293986,  1.98376962,  0.16458156, -0.12678701]), 0.0, True, {})

In [6]:
# Close the environment when finished. Presumably this also releases any
# viewer created by env.render above -- confirm for the gym version in use.
env.close()