In [1]:
import os
import random

import gym
from gym import wrappers
import chainer as C
import numpy as np

In [2]:
n_episodes = 1000
rng = random.Random()
rng.seed(42)

In [3]:
class Model(C.Chain):
    def __init__(self):
        super().__init__()
        state_size = 4
        action_size = 2
        self.l1 = C.links.Linear(state_size, 16)
        self.l2 = C.links.Linear(16, action_size)
        
    def __call__(self, state):
        h = C.functions.relu(self.l1(state))
        return self.l2(h)


In [20]:
class Agent:
    def __init__(self):
        self._lr = 0.1
        self._epsilon = 1.0
        self._gamma = 0.95
        self._exps = []
        self._last_act = 0
        self._state = None
        self._model = Model()
        self._optim = C.optimizers.SGD(lr=self._lr)
        self._optim.setup(self._model)
    
    def act(self, state):
        if rng.random() < self._epsilon:
            self._last_act = rng.choice([0, 1])
        else:
            x = state.reshape((1, -1))
            self._last_act = np.argmax(self._model(C.Variable(x)).data)
        self._state = state
        return self._last_act
    
    def _target(self, experience):
        # y = rew + gam * Qmax(action, next_state)
        y = np.float32(experience[2])
        if experience[3] is not None:
            # next state x
            x = experience[3].reshape((1, -1)).astype(np.float32)
            y += self._gamma * np.max(self._model(C.Variable(x)).data)
        return y
    
    def reward(self, reward, next_state):
        self._epsilon -= 1e-3
        self._exps.append((self._state, self._last_act, reward, next_state))
        batch_size = 32
        #sample a batch
        sample = rng.sample(self._exps, k=min(batch_size, len(self._exps)))
        #eval batch
        states = np.stack([s[0] for s in sample])
        actions = np.stack([s[1] for s in sample]).astype(np.int32)
        y = C.Variable(np.stack([self._target(s) for s in sample]).astype(np.float32))
        #calc loss
        self._model.cleargrads()
        qs = self._model(C.Variable(states))
        q = qs[np.arange(len(sample)), actions]
        loss = C.functions.mean_squared_error(y, q)
        #XXX print('loss', float(loss.data))
        loss.backward()
        self._optim.update()

agent = Agent()
print([agent.act(np.ones(4, dtype=np.float32)) for i in range(50)])
agent.reward(1.0, np.ones(4, dtype=np.float32))

[1, 1, 0, 0, 1, 0, 1, 0, 0, 0, 0, 0, 0, 0, 1, 1, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 1, 1, 0, 1, 1, 1, 0, 0, 1, 0, 1, 1, 0, 0, 1, 0, 1, 1, 0, 1, 0, 0]


In [21]:
env = gym.wrappers.Monitor(gym.make('CartPole-v0'),
                           directory='out',
                           force=True)

env.seed(0)
agent = Agent()

for ep in range(n_episodes):
    state = env.reset()
    done = False
    while not done:
        state, reward, done, _ = env.step(agent.act(np.array(state, dtype=np.float32)))
        state = np.array(state, dtype=np.float32)
        agent.reward(reward, None if done else state)
    print(ep, 'steps', env.get_episode_lengths()[-1])


[33mWARN: gym.spaces.Box autodetected dtype as <class 'numpy.float32'>. Please provide explicit dtype.[0m
0 steps 23
1 steps 21
2 steps 18
3 steps 21
4 steps 24
5 steps 11
6 steps 12
7 steps 15
8 steps 16
9 steps 19
10 steps 40
11 steps 13
12 steps 17
13 steps 27
14 steps 11
15 steps 13
16 steps 25
17 steps 19
18 steps 25
19 steps 21
20 steps 21
21 steps 13
22 steps 26
23 steps 14
24 steps 20
25 steps 12
26 steps 16
27 steps 21
28 steps 10
29 steps 12
30 steps 18
31 steps 10
32 steps 19
33 steps 18
34 steps 22
35 steps 10
36 steps 12
37 steps 11
38 steps 28
39 steps 9
40 steps 11
41 steps 9
42 steps 14
43 steps 9
44 steps 10
45 steps 13
46 steps 25
47 steps 12
48 steps 13
49 steps 10
50 steps 10
51 steps 8
52 steps 10
53 steps 8
54 steps 8
55 steps 9
56 steps 10
57 steps 13
58 steps 11
59 steps 10
60 steps 11
61 steps 11
62 steps 11
63 steps 12
64 steps 10
65 steps 15
66 steps 9
67 steps 8
68 steps 10
69 steps 8
70 steps 10
71 steps 10
72 steps 9
73 steps 17
74 steps 8
75 steps 10
76

KeyboardInterrupt: 