# Atari 游戏 BreakoutDeterministic-v4

In [3]:
%matplotlib inline
import os
import sys
import time
import itertools
import logging

import numpy as np
np.random.seed(0)
import pandas as pd
import gym
import tensorflow as tf
from tensorflow import keras
from PIL import Image
import matplotlib.pyplot as plt

logging.basicConfig(stream=sys.stdout, level=logging.DEBUG,
        format='%(asctime)s [%(levelname)s] %(message)s')

2022-05-07 19:14:14,235 [DEBUG] Loaded backend module://matplotlib_inline.backend_inline version unknown.


In [4]:
env_spec_id = 'BreakoutDeterministic-v4'
# env_spec_id = 'PongDeterministic-v4'
# env_spec_id = 'SeaquestDeterministic-v4'
# env_spec_id = 'SpaceInvadersDeterministic-v4'
# env_spec_id = 'BeamRiderDeterministic-v4'
env = gym.make(env_spec_id)
print('观测空间 = {}'.format(env.observation_space))
print('动作空间 = {}'.format(env.action_space))
print('回合最大步数 = {}'.format(env._max_episode_steps))
env.seed(0)

DependencyNotInstalled: No module named 'atari_py'. (HINT: you can install Atari dependencies by running 'pip install gym[atari]'.)

### 深度 Q 网络智能体
经验回放

In [None]:
class DQNReplayer:
    def __init__(self, capacity):
        self.memory = pd.DataFrame(index=range(capacity),
                columns=['observation', 'action', 'reward',
                'next_observation', 'done'])
        self.i = 0
        self.count = 0
        self.capacity = capacity
    
    def store(self, *args):
        self.memory.loc[self.i] = args
        self.i = (self.i + 1) % self.capacity
        self.count = min(self.count + 1, self.capacity)
        
    def sample(self, size):
        indices = np.random.choice(self.count, size=size)
        return tuple(np.stack(self.memory.loc[indices, field]) for \
                field in self.memory.columns)

智能体

In [None]:
class DQNAgent:
    def __init__(self, env, input_shape, learning_rate=0.00025,
            load_path=None, gamma=0.99,
            replay_memory_size=1000000, batch_size=32,
            replay_start_size=0,
            epsilon=1., epsilon_decrease_rate=9e-7, min_epsilon=0.1,
            random_initial_steps=0,
            clip_reward=True, rescale_state=True,
            update_freq=1, target_network_update_freq=1):
        
        self.action_n = env.action_space.n
        self.gamma = gamma
        
        # 经验回放参数
        self.replay_memory_size = replay_memory_size
        self.replay_start_size = replay_start_size
        self.batch_size = batch_size
        self.replayer = DQNReplayer(replay_memory_size)
        
        # 图像输入参数
        self.img_shape = (input_shape[-1], input_shape[-2])
        self.img_stack = input_shape[-3]
        
        # 探索参数
        self.epsilon = epsilon
        self.epsilon_decrease_rate = epsilon_decrease_rate
        self.min_epsilon = min_epsilon
        self.random_initial_steps = random_initial_steps
        
        self.clip_reward = clip_reward
        self.rescale_state = rescale_state
        
        self.update_freq = update_freq
        self.target_network_update_freq = target_network_update_freq
        
        # 评估网络
        self.evaluate_net = self.build_network(
                input_shape=input_shape, output_size=self.action_n,
                conv_activation=tf.nn.relu,
                fc_hidden_sizes=[512,], fc_activation=tf.nn.relu,
                learning_rate=learning_rate, load_path=load_path)
        self.evaluate_net.summary() # 输出网络结构
        # 目标网络
        self.target_net = self.build_network(
                input_shape=input_shape, output_size=self.action_n,
                conv_activation=tf.nn.relu,
                fc_hidden_sizes=[512,], fc_activation=tf.nn.relu,
                )
        self.update_target_network()
        
        # 初始化计数值
        self.step = 0
        self.fit_count = 0


    def build_network(self, input_shape, output_size, conv_activation,
            fc_hidden_sizes, fc_activation, output_activation=None,
            learning_rate=0.001, load_path=None):
        # 网络输入格式: (样本, 通道, 行, 列)
        model = keras.models.Sequential()
        # tf 要求从 (样本, 通道, 行, 列) 改为 (样本, 行, 列, 通道)
        model.add(keras.layers.Permute((2, 3, 1), input_shape=input_shape))
        
        # 卷积层
        model.add(keras.layers.Conv2D(32, 8, strides=4,
                activation=conv_activation))
        model.add(keras.layers.Conv2D(64, 4, strides=2,
                activation=conv_activation))
        model.add(keras.layers.Conv2D(64, 3, strides=1,
                activation=conv_activation))
        
        model.add(keras.layers.Flatten())
        
        # 全连接层
        for hidden_size in fc_hidden_sizes:
            model.add(keras.layers.Dense(hidden_size,
                    activation=fc_activation))
        model.add(keras.layers.Dense(output_size,
                activation=output_activation))

        if load_path is not None:
            logging.info('载入网络权重 {}.'.format(load_path))
            model.load_weights(load_path)

        try: # tf2
            optimizer = keras.optimizers.RMSprop(learning_rate, 0.95,
                    momentum=0.95, epsilon=0.01)
        except: # tf1
            optimizer = tf.train.RMSPropOptimizer(learning_rate, 0.95,
                    momentum=0.95, epsilon=0.01)
        model.compile(loss=keras.losses.mse, optimizer=optimizer)
        return model
        
    def get_next_state(self, state=None, observation=None):
        img = Image.fromarray(observation, 'RGB') 
        img = img.resize(self.img_shape).convert('L') # 改大小,变灰度
        img = np.asarray(img.getdata(), dtype=np.uint8
                ).reshape(img.size[1], img.size[0]) # 转成 np.array
        
        # 堆叠图像
        if state is None:
            next_state = np.array([img,] * self.img_stack) # 初始化
        else:
            next_state = np.append(state[1:], [img,], axis=0)
        return next_state
    
    def decide(self, state, test=False, step=None):
        if step is not None and step < self.random_initial_steps:
            epsilon = 1.
        elif test:
            epsilon = 0.05
        else:
            epsilon = self.epsilon
        if np.random.rand() < epsilon:
            action = np.random.choice(self.action_n)
        else:
            if self.rescale_state:
                state = state / 128. - 1.
            q_values = self.evaluate_net.predict(state[np.newaxis])[0]
            action = np.argmax(q_values)
        return action

    def learn(self, state, action, reward, next_state, done):
        self.replayer.store(state, action, reward, next_state, done)

        self.step += 1
        
        if self.step % self.update_freq == 0 and \
                self.replayer.count >= self.replay_start_size:
            states, actions, rewards, next_states, dones = \
                    self.replayer.sample(self.batch_size) # 回放

            if self.rescale_state:
                states = states / 128. - 1.
                next_states = next_states / 128. - 1.
            if self.clip_reward:
                rewards = np.clip(rewards, -1., 1.)
            
            next_qs = self.target_net.predict(next_states)
            next_max_qs = next_qs.max(axis=-1)
            targets = self.evaluate_net.predict(states)
            targets[range(self.batch_size), actions] = rewards + \
                    self.gamma * next_max_qs * (1. - dones)

            h = self.evaluate_net.fit(states, targets, verbose=0)
            self.fit_count += 1
            
            if self.fit_count % 100 == 0:
                logging.info('训练 {}, 回合 {}, 存储大小 {}, 损失 {}' \
                        .format(self.fit_count, self.epsilon,
                        self.replayer.count, h.history['loss'][0]))
            
            if self.fit_count % self.target_network_update_freq == 0:
                self.update_target_network()
        
        # 更新 epsilon 的值：线性下降
        if self.step >= self.replay_start_size:
            self.epsilon = max(self.epsilon - self.epsilon_decrease_rate,
                               self.min_epsilon)

    def update_target_network(self): # 更新目标网络
        self.target_net.set_weights(self.evaluate_net.get_weights())
        logging.info('目标网络已更新')

    def save_network(self, path): # 保存网络
        dirname = os.path.dirname(save_path)
        if not os.path.exists(dirname):
            os.makedirs(dirname)
            logging.info('创建文件夹 {}'.format(dirname))
        self.evaluate_net.save_weights(path)
        logging.info('网络权重已保存 {}'.format(path))

测试

In [None]:
def test(env, agent, episodes=100, render=False, verbose=True):
    steps, episode_rewards = [], []
    for episode in range(episodes):
        episode_reward = 0
        observation = env.reset()
        state = agent.get_next_state(None, observation)
        for step in itertools.count():
            if render:
                env.render()
            action = agent.decide(state, test=True, step=step)
            observation, reward, done, info = env.step(action)
            state = agent.get_next_state(state, observation)
            episode_reward += reward
            if done:
                break
        step += 1
        steps.append(step)
        episode_rewards.append(episode_reward)
        logging.info('[测试] 回合 {}: 步骤 {}, 奖励 {}, 步数 {}'
                .format(episode, step, episode_reward, np.sum(steps)))
            
    if verbose:
        logging.info('[测试小结] 步数: 平均 = {}, 最小 = {}, 最大 = {}.' \
                .format(np.mean(steps), np.min(steps), np.max(steps)))
        logging.info('[测试小结] 奖励: 平均 = {}, 最小 = {}, 最大 = {}' \
                .format(np.mean(episode_rewards), np.min(episode_rewards),
                np.max(episode_rewards)))
    return episode_rewards

参数设置

In [None]:
render = False
load_path = None
save_path = './output/' + env.unwrapped.spec.id + '-' + \
        time.strftime('%Y%m%d-%H%M%S') + '/model.h5'

"""
Nature 文章使用的参数, 运行极慢, 请勿轻易尝试
"""
input_shape = (4, 110, 84) # 输入网络大小
batch_size = 32
replay_memory_size = 1000000
target_network_update_freq = 10000
gamma = 0.99
update_freq = 4 # 训练网络的间隔
learning_rate = 0.00025 # 优化器学习率
epsilon = 1. # 初始探索率
min_epsilon = 0.1 # 最终探索率
epsilon_decrease = 9e-7 # 探索减小速度
replay_start_size = 50000 # 开始训练前的经验数
random_initial_steps = 30 # 每个回合开始时随机步数
frames = 50000000 # 整个算法的总训练步数
test_freq = 50000 # 验证智能体的步数间隔
test_episodes = 100 # 每次验证智能体的回合数


"""
小规模参数, 运行时间数小时, 有一点点训练效果
"""
batch_size = 32
replay_memory_size = 50000
target_network_update_freq = 4000
replay_start_size = 10000
random_initial_steps = 30
frames = 100000
test_freq = 25000
test_episodes = 50


# """
# 超小规模参数, 数分钟即可运行完, 基本没有训练效果
# """
# batch_size = 6
# replay_memory_size = 5000
# target_network_update_freq = 80
# replay_start_size = 500
# random_initial_steps = 30
# frames = 7500
# test_freq = 2500
# test_episodes = 10

训练

In [None]:
agent = DQNAgent(env, input_shape=input_shape, batch_size=batch_size,
        replay_memory_size=replay_memory_size,
        learning_rate=learning_rate, gamma=gamma,
        epsilon=epsilon, epsilon_decrease_rate=epsilon_decrease,
        min_epsilon=min_epsilon, random_initial_steps=random_initial_steps,
        load_path=load_path,
        update_freq=update_freq,
        target_network_update_freq=target_network_update_freq)

logging.info("训练开始")

frame = 0
max_mean_episode_reward = float("-inf")
for episode in itertools.count():
    observation = env.reset()
    episode_reward = 0
    state = agent.get_next_state(None, observation)
    for step in itertools.count():
        if render:
            env.render()
        frame += 1
        action = agent.decide(state, step=step)
        observation, reward, done, _ = env.step(action)
        next_state = agent.get_next_state(state, observation)
        episode_reward += reward
        agent.learn(state, action, reward, next_state, done)
        
        # 验证
        if frame % test_freq == 0 or \
                (done and (frame + 1) % test_freq == 0):
            test_episode_rewards = test(env=env,
                    agent=agent, episodes=test_episodes, render=render)
            if max_mean_episode_reward < np.mean(test_episode_rewards):
                max_mean_episode_reward = np.mean(test_episode_rewards)
                agent.save_network(save_path)
                path = save_path[:-2] + str(agent.fit_count) + '.h5'
                agent.save_network(path)
        
        if done:
            step += 1
            frame += 1
            break
        state = next_state
    
    logging.info("回合 {}, 步数 {}, 奖励 {}, 总步数 {}".format(
            episode, step, episode_reward, frame))
    
    if frame > frames:
        break

logging.info("训练结束")

_________________________________________________________________
Layer (type)                 Output Shape              Param #   
permute (Permute)            (None, 110, 84, 4)        0         
_________________________________________________________________
conv2d (Conv2D)              (None, 26, 20, 32)        8224      
_________________________________________________________________
conv2d_1 (Conv2D)            (None, 12, 9, 64)         32832     
_________________________________________________________________
conv2d_2 (Conv2D)            (None, 10, 7, 64)         36928     
_________________________________________________________________
flatten (Flatten)            (None, 4480)              0         
_________________________________________________________________
dense (Dense)                (None, 512)               2294272   
_________________________________________________________________
dense_1 (Dense)              (None, 4)                 2052      
Total para

2019-01-01 10:27:53,321 [INFO] 回合 130, 步数 134, 奖励 0.0, 总步数 22991
2019-01-01 10:28:13,858 [INFO] 回合 131, 步数 263, 奖励 3.0, 总步数 23255
2019-01-01 10:28:19,902 [INFO] 训练 5800, 回合 0.9791209000006874, 存储大小 23200, 损失 5.217835496296175e-05
2019-01-01 10:28:30,777 [INFO] 回合 132, 步数 218, 奖励 2.0, 总步数 23474
2019-01-01 10:28:51,042 [INFO] 训练 5900, 回合 0.9787609000006993, 存储大小 23600, 损失 0.007824686355888844
2019-01-01 10:28:57,085 [INFO] 回合 133, 步数 338, 奖励 4.0, 总步数 23813
2019-01-01 10:29:11,382 [INFO] 回合 134, 步数 180, 奖励 1.0, 总步数 23994
2019-01-01 10:29:21,461 [INFO] 回合 135, 步数 132, 奖励 0.0, 总步数 24127
2019-01-01 10:29:22,398 [INFO] 训练 6000, 回合 0.9784009000007111, 存储大小 24000, 损失 1.397706273564836e-05
2019-01-01 10:29:31,734 [INFO] 回合 136, 步数 129, 奖励 0.0, 总步数 24257
2019-01-01 10:29:47,879 [INFO] 回合 137, 步数 205, 奖励 2.0, 总步数 24463
2019-01-01 10:29:53,961 [INFO] 训练 6100, 回合 0.978040900000723, 存储大小 24400, 损失 2.0121200577705167e-05
2019-01-01 10:30:03,074 [INFO] 回合 138, 步数 191, 奖励 1.0, 总步数 24655
2019-01-01 10:30

2019-01-01 10:45:05,302 [INFO] 回合 172, 步数 141, 奖励 0.0, 总步数 30002
2019-01-01 10:45:19,812 [INFO] 训练 7500, 回合 0.9730009000008889, 存储大小 30000, 损失 0.0003025174664799124
2019-01-01 10:45:22,921 [INFO] 回合 173, 步数 207, 奖励 2.0, 总步数 30210
2019-01-01 10:45:34,267 [INFO] 回合 174, 步数 135, 奖励 0.0, 总步数 30346
2019-01-01 10:45:46,185 [INFO] 回合 175, 步数 139, 奖励 0.0, 总步数 30486
2019-01-01 10:45:54,190 [INFO] 训练 7600, 回合 0.9726409000009008, 存储大小 30400, 损失 2.0472414689720608e-05
2019-01-01 10:45:58,735 [INFO] 回合 176, 步数 143, 奖励 0.0, 总步数 30630
2019-01-01 10:46:12,331 [INFO] 回合 177, 步数 157, 奖励 1.0, 总步数 30788
2019-01-01 10:46:29,105 [INFO] 训练 7700, 回合 0.9722809000009126, 存储大小 30800, 损失 6.645798566751182e-05
2019-01-01 10:46:35,785 [INFO] 回合 178, 步数 269, 奖励 3.0, 总步数 31058
2019-01-01 10:46:46,923 [INFO] 回合 179, 步数 128, 奖励 0.0, 总步数 31187
2019-01-01 10:47:03,259 [INFO] 训练 7800, 回合 0.9719209000009245, 存储大小 31200, 损失 0.007749106734991074
2019-01-01 10:47:05,624 [INFO] 回合 180, 步数 222, 奖励 2.0, 总步数 31410
2019-01-01 10:4

2019-01-01 11:05:08,538 [INFO] 回合 248, 步数 161, 奖励 1.0, 总步数 43070
2019-01-01 11:05:21,077 [INFO] 回合 249, 步数 135, 奖励 0.0, 总步数 43206
2019-01-01 11:05:34,195 [INFO] 回合 250, 步数 134, 奖励 0.0, 总步数 43341
2019-01-01 11:05:44,978 [INFO] 训练 10800, 回合 0.96112090000128, 存储大小 43200, 损失 2.434356792946346e-05
2019-01-01 11:05:46,613 [INFO] 回合 251, 步数 126, 奖励 0.0, 总步数 43468
2019-01-01 11:06:06,756 [INFO] 回合 252, 步数 212, 奖励 2.0, 总步数 43681
2019-01-01 11:06:23,237 [INFO] 训练 10900, 回合 0.9607609000012919, 存储大小 43600, 损失 0.0006142361089587212
2019-01-01 11:06:23,307 [INFO] 回合 253, 步数 175, 奖励 1.0, 总步数 43857
2019-01-01 11:06:40,528 [INFO] 回合 254, 步数 186, 奖励 1.0, 总步数 44044
2019-01-01 11:06:59,930 [INFO] 训练 11000, 回合 0.9604009000013037, 存储大小 44000, 损失 0.0003221797524020076
2019-01-01 11:07:03,088 [INFO] 回合 255, 步数 246, 奖励 3.0, 总步数 44291
2019-01-01 11:07:19,069 [INFO] 回合 256, 步数 161, 奖励 1.0, 总步数 44453
2019-01-01 11:07:38,609 [INFO] 训练 11100, 回合 0.9600409000013156, 存储大小 44400, 损失 1.502894883742556e-05
2019-01-01 11

2019-01-01 11:35:19,412 [INFO] 回合 289, 步数 166, 奖励 1.0, 总步数 50603
2019-01-01 11:35:31,518 [INFO] 训练 12600, 回合 0.9546409000014934, 存储大小 50000, 损失 5.707787295250455e-06
2019-01-01 11:35:37,473 [INFO] 回合 290, 步数 131, 奖励 0.0, 总步数 50735
2019-01-01 11:35:57,136 [INFO] 回合 291, 步数 145, 奖励 0.0, 总步数 50881
2019-01-01 11:36:15,088 [INFO] 回合 292, 步数 132, 奖励 0.0, 总步数 51014
2019-01-01 11:36:25,956 [INFO] 训练 12700, 回合 0.9542809000015052, 存储大小 50000, 损失 3.099538298556581e-05
2019-01-01 11:36:40,743 [INFO] 回合 293, 步数 188, 奖励 2.0, 总步数 51203
2019-01-01 11:37:01,986 [INFO] 回合 294, 步数 160, 奖励 1.0, 总步数 51364
2019-01-01 11:37:19,918 [INFO] 训练 12800, 回合 0.9539209000015171, 存储大小 50000, 损失 2.4297041818499565e-05
2019-01-01 11:37:46,797 [INFO] 回合 295, 步数 331, 奖励 5.0, 总步数 51696
2019-01-01 11:38:08,670 [INFO] 回合 296, 步数 162, 奖励 0.0, 总步数 51859
2019-01-01 11:38:14,046 [INFO] 训练 12900, 回合 0.9535609000015289, 存储大小 50000, 损失 2.4311124434461817e-05
2019-01-01 11:38:34,182 [INFO] 回合 297, 步数 187, 奖励 1.0, 总步数 52047
2019-01-0

2019-01-01 12:04:52,185 [INFO] 回合 363, 步数 187, 奖励 1.0, 总步数 64139
2019-01-01 12:05:06,640 [INFO] 回合 364, 步数 129, 奖励 0.0, 总步数 64269
2019-01-01 12:05:17,233 [INFO] 训练 16000, 回合 0.9424009000018964, 存储大小 50000, 损失 1.6038948160712607e-05
2019-01-01 12:05:17,280 [INFO] 目标网络已更新
2019-01-01 12:05:28,783 [INFO] 回合 365, 步数 202, 奖励 2.0, 总步数 64472
2019-01-01 12:05:57,993 [INFO] 回合 366, 步数 262, 奖励 3.0, 总步数 64735
2019-01-01 12:06:01,405 [INFO] 训练 16100, 回合 0.9420409000019082, 存储大小 50000, 损失 3.77221658709459e-05
2019-01-01 12:06:22,403 [INFO] 回合 367, 步数 225, 奖励 2.0, 总步数 64961
2019-01-01 12:06:45,427 [INFO] 训练 16200, 回合 0.9416809000019201, 存储大小 50000, 损失 1.456985319236992e-05
2019-01-01 12:06:47,798 [INFO] 回合 368, 步数 229, 奖励 2.0, 总步数 65191
2019-01-01 12:07:03,149 [INFO] 回合 369, 步数 141, 奖励 0.0, 总步数 65333
2019-01-01 12:07:18,066 [INFO] 回合 370, 步数 135, 奖励 0.0, 总步数 65469
2019-01-01 12:07:29,328 [INFO] 训练 16300, 回合 0.9413209000019319, 存储大小 50000, 损失 6.307307558017783e-06
2019-01-01 12:07:38,324 [INFO] 回合 371

2019-01-01 12:26:41,386 [INFO] [测试] 回合 22: 步骤 229, 奖励 3.0, 步数 3818
2019-01-01 12:26:44,528 [INFO] [测试] 回合 23: 步骤 127, 奖励 0.0, 步数 3945
2019-01-01 12:26:49,858 [INFO] [测试] 回合 24: 步骤 202, 奖励 2.0, 步数 4147
2019-01-01 12:26:54,289 [INFO] [测试] 回合 25: 步骤 171, 奖励 1.0, 步数 4318
2019-01-01 12:26:57,293 [INFO] [测试] 回合 26: 步骤 123, 奖励 0.0, 步数 4441
2019-01-01 12:27:00,332 [INFO] [测试] 回合 27: 步骤 127, 奖励 0.0, 步数 4568
2019-01-01 12:27:03,463 [INFO] [测试] 回合 28: 步骤 128, 奖励 0.0, 步数 4696
2019-01-01 12:27:08,586 [INFO] [测试] 回合 29: 步骤 198, 奖励 2.0, 步数 4894
2019-01-01 12:27:11,673 [INFO] [测试] 回合 30: 步骤 128, 奖励 0.0, 步数 5022
2019-01-01 12:27:14,856 [INFO] [测试] 回合 31: 步骤 129, 奖励 0.0, 步数 5151
2019-01-01 12:27:18,079 [INFO] [测试] 回合 32: 步骤 130, 奖励 0.0, 步数 5281
2019-01-01 12:27:22,342 [INFO] [测试] 回合 33: 步骤 168, 奖励 1.0, 步数 5449
2019-01-01 12:27:25,422 [INFO] [测试] 回合 34: 步骤 125, 奖励 0.0, 步数 5574
2019-01-01 12:27:28,585 [INFO] [测试] 回合 35: 步骤 131, 奖励 0.0, 步数 5705
2019-01-01 12:27:33,879 [INFO] [测试] 回合 36: 步骤 203, 奖励 2.0, 步数 

2019-01-01 12:43:50,333 [INFO] 训练 21200, 回合 0.9236809000025127, 存储大小 50000, 损失 7.574947085231543e-05
2019-01-01 12:44:04,509 [INFO] 回合 481, 步数 182, 奖励 1.0, 总步数 85440
2019-01-01 12:44:16,750 [INFO] 回合 482, 步数 135, 奖励 0.0, 总步数 85576
2019-01-01 12:44:26,530 [INFO] 训练 21300, 回合 0.9233209000025245, 存储大小 50000, 损失 1.6860989489941858e-05
2019-01-01 12:44:35,943 [INFO] 回合 483, 步数 214, 奖励 2.0, 总步数 85791
2019-01-01 12:44:50,599 [INFO] 回合 484, 步数 161, 奖励 1.0, 总步数 85953
2019-01-01 12:45:02,459 [INFO] 训练 21400, 回合 0.9229609000025364, 存储大小 50000, 损失 8.802755473880097e-06
2019-01-01 12:45:02,459 [INFO] 回合 485, 步数 132, 奖励 0.0, 总步数 86086
2019-01-01 12:45:14,848 [INFO] 回合 486, 步数 139, 奖励 0.0, 总步数 86226
2019-01-01 12:45:27,474 [INFO] 回合 487, 步数 138, 奖励 0.0, 总步数 86365
2019-01-01 12:45:38,668 [INFO] 训练 21500, 回合 0.9226009000025482, 存储大小 50000, 损失 1.6201463949983008e-05
2019-01-01 12:45:40,229 [INFO] 回合 488, 步数 142, 奖励 0.0, 总步数 86508
2019-01-01 12:46:10,024 [INFO] 回合 489, 步数 330, 奖励 5.0, 总步数 86839
2019-01-0

2019-01-01 13:03:22,411 [INFO] 回合 556, 步数 227, 奖励 2.0, 总步数 98498
2019-01-01 13:03:27,727 [INFO] 训练 24500, 回合 0.9118009000029038, 存储大小 50000, 损失 4.359880040283315e-05
2019-01-01 13:03:40,839 [INFO] 回合 557, 步数 207, 奖励 2.0, 总步数 98706
2019-01-01 13:03:58,330 [INFO] 回合 558, 步数 196, 奖励 2.0, 总步数 98903
2019-01-01 13:04:03,333 [INFO] 训练 24600, 回合 0.9114409000029157, 存储大小 50000, 损失 7.627880222571548e-06
2019-01-01 13:04:13,087 [INFO] 回合 559, 步数 165, 奖励 1.0, 总步数 99069
2019-01-01 13:04:31,090 [INFO] 回合 560, 步数 200, 奖励 2.0, 总步数 99270
2019-01-01 13:04:39,208 [INFO] 训练 24700, 回合 0.9110809000029275, 存储大小 50000, 损失 0.007810384035110474
2019-01-01 13:04:50,972 [INFO] 回合 561, 步数 224, 奖励 2.0, 总步数 99495
2019-01-01 13:05:02,690 [INFO] 回合 562, 步数 132, 奖励 0.0, 总步数 99628
2019-01-01 13:05:14,715 [INFO] 训练 24800, 回合 0.9107209000029394, 存储大小 50000, 损失 5.626908387057483e-05
2019-01-01 13:05:17,703 [INFO] 回合 563, 步数 169, 奖励 1.0, 总步数 99798
2019-01-01 13:05:29,546 [INFO] 回合 564, 步数 133, 奖励 0.0, 总步数 99932
2019-01-01 1

测试

In [None]:
test_agent = DQNAgent(env, input_shape=input_shape, load_path=save_path)
test_episode_rewards = test(env, test_agent, episodes=test_episodes)
print('平均回合奖励 = {}'.format(np.mean(test_episode_rewards)))

2019-01-01 13:12:06,576 [INFO] 载入网络权重 ./output/BreakoutDeterministic-v4-20190123-100223/model.h5.
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
permute_2 (Permute)          (None, 110, 84, 4)        0         
_________________________________________________________________
conv2d_6 (Conv2D)            (None, 26, 20, 32)        8224      
_________________________________________________________________
conv2d_7 (Conv2D)            (None, 12, 9, 64)         32832     
_________________________________________________________________
conv2d_8 (Conv2D)            (None, 10, 7, 64)         36928     
_________________________________________________________________
flatten_2 (Flatten)          (None, 4480)              0         
_________________________________________________________________
dense_4 (Dense)              (None, 512)               2294272   
____________________________________________