In [1]:
import myenv
import gym
from gym import wrappers
import numpy as np
import time
import chainer
import chainer.functions as F
import chainer.links as L
import chainerrl

In [2]:
class QFunction(chainer.Chain):
    def __init__(self, obs_size, n_actions, n_hidden_channels=100):
        super(QFunction, self).__init__()
        with self.init_scope():
            self.l0 = L.Linear(obs_size, n_hidden_channels)
            self.l1 = L.Linear(n_hidden_channels, n_hidden_channels)
            self.l2 = L.Linear(n_hidden_channels, n_actions)
    def __call__(self, x, test=False):
        h = F.tanh(self.l0(x))
        h = F.tanh(self.l1(h))
        return chainerrl.action_value.DiscreteActionValue(self.l2(h))

In [3]:
env = gym.make('RobotArmODE-v0')
# env = wrappers.Monitor(env, 'videos', video_callable=(lambda ep: ep % 10 == 0), force=True)



In [4]:
gamma = 0.99
alpha = 0.5
max_number_of_steps = 2000
num_episodes = 2000

In [5]:
q_func = QFunction(env.observation_space.shape[0], env.action_space.n)
optimizer = chainer.optimizers.Adam(eps=1e-2)
optimizer.setup(q_func)
explorer = chainerrl.explorers.LinearDecayEpsilonGreedy(
    start_epsilon = 1.0,
    end_epsilon = 0.1,
    decay_steps = num_episodes,
    random_action_func = env.action_space.sample
)
replay_buffer = chainerrl.replay_buffer.PrioritizedReplayBuffer(capacity=10**6)
phi = lambda x: x.astype(np.float32, copy=False)

In [6]:
agent = chainerrl.agents.DoubleDQN(
    q_func, optimizer, replay_buffer, gamma, explorer,
    replay_start_size = 500,
    update_interval = 1,
    target_update_interval = 100,
    phi = phi
)

In [7]:
for episode in range(num_episodes):
    observation = env.reset()
    done = False
    reward = 0
    R = 0
    for t in range(max_number_of_steps):
        if episode%10==0:
            env.render('human')
        action = agent.act_and_train(observation, reward)
        observation, reward, done, info = env.step(action)
        R += reward
        if done and R < 2000:
            break

    agent.stop_episode_and_train(observation, reward, done)
    if episode % 10 == 0:
        agent.save('agent+1/agent_' + str(episode))
        print('episode: ', episode, 'R: ', R, 'statistics', agent.get_statistics())
    

('episode: ', 0, 'R: ', 0.01089999999999999, 'statistics', [(u'average_q', 0.05076845685854456), (u'average_loss', 0)])
('episode: ', 10, 'R: ', 0.003399999999999998, 'statistics', [(u'average_q', 0.2074540048386371), (u'average_loss', 0.011078575534531539)])
('episode: ', 20, 'R: ', 0.010999999999999989, 'statistics', [(u'average_q', 0.22889498924886534), (u'average_loss', 0.0019207919314112086)])
('episode: ', 30, 'R: ', 0.012099999999999982, 'statistics', [(u'average_q', 0.26047172869573887), (u'average_loss', 0.00143284559458547)])
('episode: ', 40, 'R: ', 0.0026999999999999993, 'statistics', [(u'average_q', 0.272646646510645), (u'average_loss', 0.0009117878618989727)])
('episode: ', 50, 'R: ', 0.002899999999999999, 'statistics', [(u'average_q', 0.27362321314863536), (u'average_loss', 0.0008968515446461502)])
('episode: ', 60, 'R: ', 0.0023, 'statistics', [(u'average_q', 0.2642198282708353), (u'average_loss', 0.0007946235729329568)])
('episode: ', 70, 'R: ', 0.0024999999999999996, 

('episode: ', 590, 'R: ', 0.004099999999999998, 'statistics', [(u'average_q', 1.2488805786463664), (u'average_loss', 0.006855806334827227)])
('episode: ', 600, 'R: ', 6.196476510973164, 'statistics', [(u'average_q', 1.3877672557026721), (u'average_loss', 0.007885902387737363)])
('episode: ', 610, 'R: ', 0.01739999999999995, 'statistics', [(u'average_q', 1.670273467714801), (u'average_loss', 0.007732596193610922)])
('episode: ', 620, 'R: ', 9.060049423649312, 'statistics', [(u'average_q', 2.0677501173529484), (u'average_loss', 0.012937171386781158)])
('episode: ', 630, 'R: ', 0.0036999999999999976, 'statistics', [(u'average_q', 2.314173758123834), (u'average_loss', 0.011732658914686835)])
('episode: ', 640, 'R: ', 4.682407378798407, 'statistics', [(u'average_q', 2.441927398842279), (u'average_loss', 0.011290651176212025)])
('episode: ', 650, 'R: ', 0.004299999999999998, 'statistics', [(u'average_q', 2.526669762996352), (u'average_loss', 0.01054893311610251)])
('episode: ', 660, 'R: ', 0

('episode: ', 1200, 'R: ', 0.0022, 'statistics', [(u'average_q', 228.63531172986723), (u'average_loss', 3.3129940464205565)])
('episode: ', 1210, 'R: ', 0.0023, 'statistics', [(u'average_q', 228.83951968616012), (u'average_loss', 3.424887849094066)])
('episode: ', 1220, 'R: ', 33.62632449199737, 'statistics', [(u'average_q', 227.68377551173796), (u'average_loss', 3.499802306812263)])
('episode: ', 1230, 'R: ', 0.0022, 'statistics', [(u'average_q', 245.42219711435675), (u'average_loss', 3.4781678849749444)])
('episode: ', 1240, 'R: ', 0.0021000000000000003, 'statistics', [(u'average_q', 243.66216692706718), (u'average_loss', 3.624863759880083)])
('episode: ', 1250, 'R: ', 0.0039999999999999975, 'statistics', [(u'average_q', 241.65310601175813), (u'average_loss', 3.4786185909621063)])
('episode: ', 1260, 'R: ', 0.005700000000000002, 'statistics', [(u'average_q', 233.5021076086083), (u'average_loss', 3.1319261459280003)])
('episode: ', 1270, 'R: ', 0.003499999999999998, 'statistics', [(u'

KeyboardInterrupt: 