In [1]:
import myenv
import gym
from gym import wrappers
import numpy as np
import time
import chainer
import chainer.functions as F
import chainer.links as L
import chainerrl

In [2]:
from tqdm import tqdm

In [3]:
# RL hyperparameters.
#   gamma: discount factor, handed to the DQN agent.
#   alpha: NOTE(review) not referenced by any visible cell (the learning rate
#          is given to Adam directly) — confirm whether it is still needed.
gamma, alpha = 0.999, 0.5

In [4]:
class QFunction(chainer.Chain):
    """Fully connected Q-network with two tanh hidden layers.

    Maps an observation to one Q-value per discrete action, wrapped in a
    ChainerRL ``DiscreteActionValue``.

    Args:
        obs_size: number of input features of the first linear layer.
        n_actions: number of discrete actions (width of the output layer).
        n_hidden_channels: width of each of the two hidden layers.
    """

    def __init__(self, obs_size, n_actions, n_hidden_channels=50):
        super(QFunction, self).__init__()
        with self.init_scope():
            # The link names l0/l1/l2 are deliberately unchanged: they name
            # the parameters written out when the agent is saved.
            self.l0 = L.Linear(obs_size, n_hidden_channels)
            self.l1 = L.Linear(n_hidden_channels, n_hidden_channels)
            self.l2 = L.Linear(n_hidden_channels, n_actions)

    def __call__(self, x, test=False):
        """Evaluate the network on a batch of observations.

        Args:
            x: observation batch, presumably float32 of shape
               (batch, obs_size) — TODO confirm against phi/the env.
            test: accepted for call-site compatibility; not used.

        Returns:
            chainerrl.action_value.DiscreteActionValue over l2's output.
        """
        hidden = x
        for layer in (self.l0, self.l1):
            hidden = F.tanh(layer(hidden))
        q_values = self.l2(hidden)
        return chainerrl.action_value.DiscreteActionValue(q_values)

In [5]:
# Training environment, wrapped in gym's Monitor so that episodes
# 0, 10, 20, ... are recorded into ./videos.
# NOTE(review): 'LiftingODE-v0' is not a stock gym id; presumably it is
# registered as a side effect of `import myenv` — confirm.
env = wrappers.Monitor(
    gym.make('LiftingODE-v0'),
    'videos',
    video_callable=lambda episode_id: episode_id % 10 == 0,
    force=True,  # discard monitor output left in 'videos' by earlier runs
)

# Number of training episodes (also passed to the explorer as decay_steps).
num_episodes = 20000



In [6]:
# Q-network sized from the environment's observation and action spaces.
n_obs = env.observation_space.shape[0]
q_func = QFunction(n_obs, env.action_space.n)

# NOTE(review): Adam's first positional parameter is `alpha` (the learning
# rate, default 1e-3), so this trains with lr=1e-2. If `eps=1e-2` was the
# intent, pass it by keyword instead.
optimizer = chainer.optimizers.Adam(1e-2)
optimizer.setup(q_func)

# Epsilon-greedy exploration, annealed linearly from 1.0 down to 0.1.
# NOTE(review): ChainerRL counts `decay_steps` in agent steps, not episodes,
# so epsilon bottoms out after `num_episodes` *steps* — confirm intended.
explorer = chainerrl.explorers.LinearDecayEpsilonGreedy(
    start_epsilon=1.0,
    end_epsilon=0.1,
    decay_steps=num_episodes,
    random_action_func=env.action_space.sample,
)

replay_buffer = chainerrl.replay_buffer.ReplayBuffer(capacity=10 ** 6)


def phi(x):
    """Cast an observation to float32, avoiding a copy if it already is."""
    return x.astype(np.float32, copy=False)

In [7]:
# DQN agent: model updates begin once the replay buffer holds 500
# transitions, run every step, and the target network is synced every
# 100 steps. To train on a GPU, add `gpu=0` to the keyword arguments.
agent = chainerrl.agents.DQN(
    q_func,
    optimizer,
    replay_buffer,
    gamma,
    explorer,
    replay_start_size=500,
    update_interval=1,
    target_update_interval=100,
    phi=phi,
)

In [8]:
# Train for `num_episodes` episodes. Every `log_interval`-th episode is also
# rendered on screen, logged, and checkpointed to agent/agent_<episode>.
log_interval = 10  # matches the Monitor's video_callable (ep % 10 == 0)
for episode in range(num_episodes):
    is_logged_episode = episode % log_interval == 0
    observation = env.reset()
    done = False
    # act_and_train receives the reward produced by the previous env.step;
    # there is none before the first step of an episode, hence 0.
    reward = 0
    R = 0  # undiscounted sum of rewards collected in this episode
    while not done:
        if is_logged_episode:
            env.render('human')
        action = agent.act_and_train(observation, reward)
        observation, reward, done, info = env.step(action)
        R += reward
    # Hand the final transition to the agent and close out the episode.
    agent.stop_episode_and_train(observation, reward, done)
    if is_logged_episode:
        # Format a single string: on a Python 2 kernel (as the saved output
        # with u'' keys indicates) `print(a, b, ...)` prints a tuple repr.
        print('episode: {} R : {} statistics : {}'.format(
            episode, R, agent.get_statistics()))
        agent.save('agent/agent_' + str(episode))

(62.5, 12.5)
('episode: ', 0, 'R :', 5.0, 'statistics :', [(u'average_q', 0.1441309877332883), (u'average_loss', 0)])


KeyboardInterrupt: 