In [6]:
import os
import random

import gym
from gym import wrappers
import chainer as C
import numpy as np

In [7]:
n_episodes = 1000
rng = random.Random()
rng.seed(42)

In [8]:
class Model(C.Chain):
    def __init__(self):
        super().__init__()
        state_size = 4
        action_size = 2
        hidden_size = 16
        # Chain.init_scope is necessary for gradient book-keeping to be set up
        # for all the links defined below, otherwise errors are not
        # propagated back through the graph
        with self.init_scope():
            self.l1 = C.links.Linear(state_size, hidden_size)
            self.l2 = C.links.Linear(hidden_size, action_size)

        
    def __call__(self, state):
        h = C.functions.relu(self.l1(state))
        return self.l2(h)


In [9]:
class Agent:
    def __init__(self):
        self._lr = 0.01
        self._epsilon = 1.0
        self._gamma = 0.95
        # experience buffer
        # {'state', 'action', 'reward', 'next_state'}
        self._exps = []
        self._last_act = 0
        self._state = None
        self._model = Model()
        self._optim = C.optimizers.SGD(lr=self._lr)
        self._optim.setup(self._model)
    
    def act(self, state):
        if rng.random() < self._epsilon:
            self._last_act = rng.choice([0, 1])
        else:
            x = state.reshape((1, -1))
            self._last_act = np.argmax(self._model(C.Variable(x)).data)
        self._state = state
        return self._last_act
    
    def _target(self, experience):
        # y = reward + gamma * Qmax(action, next_state)
        y = np.float32(experience['reward'])
        if experience['next_state'] is not None:
            # next state x
            x = experience['next_state'].reshape((1, -1)).astype(np.float32)
            y += self._gamma * np.max(self._model(C.Variable(x)).data)
        return y
    
    def _make_exp(self, state, action, reward, next_state):
        return dict(state=state, action=action, reward=reward, next_state=next_state)
    
    def reward(self, reward, next_state):
        self._epsilon -= 1e-3
        self._exps.append(self._make_exp(self._state, self._last_act, reward, next_state))
        batch_size = 32
        #sample a batch
        sample = rng.sample(self._exps, k=min(batch_size, len(self._exps)))
        #eval batch
        states = np.stack([s['state'] for s in sample])
        actions = np.stack([s['action'] for s in sample]).astype(np.int32)
        y = C.Variable(np.stack([self._target(s) for s in sample]).astype(np.float32))
        #calc loss
        self._model.cleargrads()
        qs = self._model(C.Variable(states))
        q = qs[np.arange(len(sample)), actions]
        loss = C.functions.mean_squared_error(y, q)
        #XXX print('loss', float(loss.data))
        loss.backward()
        self._optim.update()

agent = Agent()
print([agent.act(np.ones(4, dtype=np.float32)) for i in range(50)])
agent.reward(1.0, np.ones(4, dtype=np.float32))

[0, 0, 0, 0, 0, 0, 0, 1, 1, 0, 0, 1, 0, 1, 1, 1, 0, 0, 1, 1, 1, 0, 0, 0, 0, 1, 1, 1, 1, 0, 0, 0, 1, 0, 0, 0, 1, 0, 1, 1, 1, 0, 1, 1, 1, 0, 0, 0, 0, 1]


In [10]:
env = gym.wrappers.Monitor(gym.make('CartPole-v0'),
                           directory='out',
                           force=True)

env.seed(0)
agent = Agent()

for ep in range(n_episodes):
    state = env.reset()
    done = False
    while not done:
        state, reward, done, _ = env.step(agent.act(np.array(state, dtype=np.float32)))
        state = np.array(state, dtype=np.float32)
        agent.reward(reward, None if done else state)
    print(ep, 'steps', env.get_episode_lengths()[-1])


0 steps 33
1 steps 14
2 steps 11
3 steps 15
4 steps 10
5 steps 13
6 steps 43
7 steps 13
8 steps 39
9 steps 34
10 steps 28
11 steps 18
12 steps 19
13 steps 40
14 steps 12
15 steps 15
16 steps 13
17 steps 19
18 steps 60
19 steps 10
20 steps 80
21 steps 26
22 steps 57
23 steps 27
24 steps 56
25 steps 43
26 steps 100
27 steps 40
28 steps 43
29 steps 74
30 steps 122
31 steps 80
32 steps 64
33 steps 99
34 steps 95
35 steps 200
36 steps 96
37 steps 200
38 steps 139
39 steps 200
40 steps 163
41 steps 150
42 steps 200
43 steps 200
44 steps 200
45 steps 150
46 steps 175
47 steps 200
48 steps 196
49 steps 200
50 steps 200
51 steps 194
52 steps 144
53 steps 156


KeyboardInterrupt: 