# 可視化のための依存環境&ライブラリ

In [1]:
import gym
import tensorflow as tf
from tensorflow import keras
from tensorflow.keras.models import Sequential
from tensorflow.keras.layers import Flatten, Dense, Dropout, Activation
import numpy as np
import random
import matplotlib
import matplotlib.pyplot as plt
%matplotlib inline
from tqdm import tqdm
from collections import deque
# d = Display()
# d.start()

In [2]:
class QNetwork():
  """Q-value approximator holding an online model and a "teacher" copy.

  Both networks share one architecture: Dense(128)-ReLU-Dense(128)-ReLU
  followed by a linear Dense layer with one output per action, compiled
  with MSE loss and Adam. ``model`` is the network that gets trained;
  ``teacher_model`` only receives weight copies via ``update_teacher()``.

  Args:
    action_space: number of discrete actions (size of the output layer).
    state_dim: length of the state vector fed to the networks. Defaults to
      4, the value that was previously hard-coded.
    learning_rate: Adam learning rate used when compiling both networks.
      Defaults to 0.01, the value that was previously hard-coded.
  """

  def __init__(self, action_space, state_dim=4, learning_rate=0.01):
    self.model = self._build_model(action_space, state_dim, learning_rate)
    self.teacher_model = self._build_model(action_space, state_dim, learning_rate)
    # Start with the teacher identical to the freshly initialised online model.
    self.update_teacher()

  @staticmethod
  def _build_model(action_space, state_dim, learning_rate):
    """Build and compile one MLP mapping a state vector to per-action Q-values."""
    model = Sequential()
    model.add(Dense(128, input_dim=state_dim))
    model.add(Activation("relu"))
    model.add(Dense(128))
    model.add(Activation("relu"))
    model.add(Dense(action_space))
    opt = keras.optimizers.Adam(learning_rate=learning_rate)
    model.compile(loss='mse', optimizer=opt)
    return model

  def predict(self, x):
    """Return the online model's outputs for the batch of states ``x``."""
    # verbose=0: this is called once per environment step; newer Keras
    # versions otherwise print a progress bar on every call.
    return self.model.predict(x, verbose=0)

  def predict_by_teacher(self, x):
    """Return the teacher model's outputs for the batch of states ``x``."""
    return self.teacher_model.predict(x, verbose=0)

  def update_teacher(self):
    """Copy the online model's current weights into the teacher model."""
    self.teacher_model.set_weights(self.model.get_weights())
  

  

In [3]:
class Trainer():
  """Experience-replay trainer for a ``QNetwork`` with a periodically synced teacher.

  Assumes the classic gym API used below: ``env.reset()`` returns the
  observation and ``env.step(a)`` returns ``(obs, reward, done, info)``.

  Args:
    env: environment to train on.
    max_len: capacity of the replay memory (oldest transitions are dropped).
    batch: minibatch size sampled by ``update()``.
    gamma: discount factor applied to the teacher's bootstrap value.
    n_actions: number of discrete actions. Defaults to 2 (CartPole), the
      value that was previously hard-coded.
  """
  def __init__(self, env, max_len=1024, batch=64, gamma=0.95, n_actions=2):
    self.gamma = gamma
    self.env = env
    self.env.reset()
    self.n_actions = n_actions
    self.QNet = QNetwork(n_actions)
    self.experiences = deque(maxlen=max_len)
    self.batch_size = batch
    self.training = False

  def train(self):
    """Run the default training loop (see ``run_episode``)."""
    self.run_episode()

  def step(self, num, s):
    """Take action ``num`` from state ``s`` and record the transition.

    The transition is appended to replay memory as
    ``[state, next_state, reward, done, action]``.

    Returns:
      ``(next_state, reward, done)``.
    """
    next_s, reward, d, _ = self.env.step(num)
    self.experiences.append([s, next_s, reward, d, num])
    return next_s, reward, d

  def policy(self, s, epsilon=0.1):
    """Epsilon-greedy action: random with probability ``epsilon``, else argmax Q."""
    if np.random.random() <= epsilon:
      return np.random.randint(self.n_actions)
    return np.argmax(self.QNet.predict(np.array([np.array(s)])))

  def run_episode(self, times=400):
    """Run ``times`` epsilon-greedy episodes, training between episodes.

    The teacher is synced every 10 episodes (starting with the first). Once
    the replay memory holds at least one batch, one ``update()`` is done at
    the start of every subsequent episode. Each episode's length is printed.

    Returns:
      The result of ``plt.plot`` on the list of episode lengths.
    """
    episode_lengths = []
    for cnt in tqdm(range(times)):
      s = self.env.reset()

      if cnt % 10 == 0:
        self.QNet.update_teacher()

      if self.training:
        self.update()

      stand_count = 0
      done = False
      while not done:
        stand_count += 1
        next_s, reward, done = self.step(self.policy(s), s)
        # `>=` rather than `==`: the memory may already hold more than one
        # batch (e.g. after show()), and an equality test would then never
        # fire, so training would never start.
        if not self.training and len(self.experiences) >= self.batch_size:
          self.training = True
        s = next_s

      episode_lengths.append(stand_count)
      print(stand_count)

    return plt.plot(episode_lengths)

  def update(self):
    """Fit the online network on one random minibatch from replay memory.

    For each sampled transition, the target for the action taken is
    ``reward`` when the episode ended there, otherwise
    ``reward + gamma * max(teacher Q-values of next_state)``. Targets for
    the other actions are the online network's own current predictions.
    """
    batch = random.sample(self.experiences, self.batch_size)
    states = np.array([e[0] for e in batch])
    next_states = np.array([e[1] for e in batch])
    rewards = np.array([e[2] for e in batch], dtype=float)
    dones = np.array([e[3] for e in batch], dtype=bool)
    actions = np.array([e[4] for e in batch], dtype=int)

    # One batched forward pass per network instead of 2 * batch_size
    # single-sample predict() calls.
    target = np.array(self.QNet.predict(states), dtype=float)
    next_q = self.QNet.predict_by_teacher(next_states)
    bootstrap = np.where(dones, 0.0, self.gamma * np.max(next_q, axis=1))
    # Overwrite only the taken action's column; works for any action count.
    target[np.arange(len(batch)), actions] = rewards + bootstrap

    self.QNet.model.fit(states, target)

  def show(self):
    """Play one greedy episode, rendering after reset and after every step.

    Actions go through ``step()``, so these transitions are also appended to
    the replay memory.
    """
    s = self.env.reset()
    self.env.render()
    done = False
    while not done:
      action = np.argmax(self.QNet.predict(np.array([np.array(s)])))
      s, reward, done = self.step(action, s)
      self.env.render()




In [4]:
ENV = "CartPole-v1"
# Wrap a fresh CartPole environment in a trainer, keep a handle to that same
# environment object, then restore previously saved online-network weights.
trainer = Trainer(gym.make(ENV))
env = trainer.env
trainer.QNet.model.load_weights("param.hdf5")

In [12]:
# trainer.QNet.model.save_weights('param.hdf5')

In [15]:
# from google.colab import files
# files.download("./param.hdf5")

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

In [7]:
# Play and render one greedy episode with the current weights.
Trainer.show(trainer)

In [None]:
# Quick visual check of the environment: take 100 random steps and redraw
# the rendered frame inline after each one.
from IPython import display  # `display.clear_output` below needs the module; it was never imported

env = gym.make("CartPole-v1")
env.reset()
img = plt.imshow(env.render('rgb_array'))
plt.axis('off')  # same axes every frame, so turning the axis off once suffices
for _ in range(100):
    # Random action for now; ideally the action should come from the DNN.
    o, r, done, i = env.step(env.action_space.sample())

    display.clear_output(wait=True)
    img.set_data(env.render('rgb_array'))
    display.display(plt.gcf())

    if done:
        env.reset()
env.close()