## 필요한 모듈 설치 확인

In [1]:
import gym
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
from replay_memory import ReplayBuffer, PrioritizedReplayBuffer

import random
import os
import pickle
import time
from collections import deque
from copy import deepcopy
from tqdm import tqdm

import matplotlib.pyplot as plt
import seaborn as sns
%matplotlib inline
sns.set()

from wrappers import wrap, wrap_cover, SubprocVecEnv

## 하이퍼 파라미터 정의하기

In [2]:
# ---------------------------------------------------------------------------
# DQN settings
# ---------------------------------------------------------------------------
# Number of sequential frames stacked to define one state.
STATE_LEN = 4
# Target-network sync interval, counted in learn() calls.
TARGET_REPLACE_ITER = 1
# Number of stored transitions required before learning starts.
LEARN_START = 10 ** 3
# Capacity of the (prioritized) experience replay memory.
MEMORY_CAPACITY = 10 ** 5
# Learning interval, counted in stored transitions.
LEARN_FREQ = 1
# Number of quantile samples per forward pass.
N_QUANT = 64
# N_QUANT evenly spaced quantile fractions in (0, 1].
QUANTS = np.linspace(0.0, 1.0, N_QUANT + 1)[1:]

# ---------------------------------------------------------------------------
# Environment settings
# ---------------------------------------------------------------------------
# Number of environments stepped in parallel.
N_ENVS = 4
# OpenAI gym environment name.
ENV_NAME = 'BreakoutNoFrameskip-v4'
# One wrap_cover(ENV_NAME) entry per parallel environment.
env = SubprocVecEnv([wrap_cover(ENV_NAME) for _ in range(N_ENVS)])
N_ACTIONS = env.action_space.n
N_STATES = env.observation_space.shape
# Total number of simulator steps (summed over all environments).
STEP_NUM = 10 ** 7
# Discount factor of the MDP.
GAMMA = 0.99
# Whether to render the agent while it plays.
RENDERING = False

# ---------------------------------------------------------------------------
# Training settings
# ---------------------------------------------------------------------------
# Use the GPU whenever CUDA is available.
USE_GPU = torch.cuda.is_available()
print('USE GPU: ' + str(USE_GPU))
# Mini-batch size.
BATCH_SIZE = 32
# Learning rate.
LR = 1e-4
# Initial epsilon of the epsilon-greedy policy.
EPSILON = 1.0

# ---------------------------------------------------------------------------
# Save & load settings
# ---------------------------------------------------------------------------
# Flags for saving / loading checkpoints.
SAVE = True
LOAD = False
# Save (and log) frequency, counted in training-loop iterations.
SAVE_FREQ = 10 ** 3
# Paths for the prediction net, the target net and the result log.
PRED_PATH = './data/model/iqn_pred_net.pkl'
TARGET_PATH = './data/model/iqn_target_net.pkl'
RESULT_PATH = './data/plots/result.pkl'



USE GPU: True




## 네트워크 구조 정의하기

In [3]:
class ConvNet(nn.Module):
    """Convolutional implicit quantile network (IQN).

    A convolutional torso encodes the stacked frames. A cosine embedding of
    randomly sampled quantile fractions ``tau`` modulates that encoding, and
    two fully connected layers map every (state, tau) pair to one quantile
    value per action.
    """

    def __init__(self):
        super(ConvNet, self).__init__()
        # Convolutional torso; nn.Sequential keeps the layer stack compact.
        self.feature_extraction = nn.Sequential(
            nn.Conv2d(STATE_LEN, 32, kernel_size=8, stride=4),
            nn.ReLU(),
            nn.Conv2d(32, 64, kernel_size=4, stride=2),
            nn.ReLU(),
            nn.Conv2d(64, 64, kernel_size=3, stride=1),
            nn.ReLU(),
        )
        # Quantile embedding: a bias-free linear map applied to each cosine
        # feature, plus a separate learned bias added after averaging.
        self.phi = nn.Linear(1, 7 * 7 * 64, bias=False)
        self.phi_bias = nn.Parameter(torch.zeros(7 * 7 * 64))
        self.fc = nn.Linear(7 * 7 * 64, 512)

        # Per-action quantile values of the return distribution.
        self.fc_q = nn.Linear(512, N_ACTIONS)

        # Weight initialisation: orthogonal for conv layers, Kaiming-normal
        # for linear layers, zeros for every bias.
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.orthogonal_(m.weight, gain=np.sqrt(2))
                if m.bias is not None:
                    nn.init.constant_(m.bias, 0.0)
            elif isinstance(m, nn.Linear):
                nn.init.kaiming_normal_(m.weight)
                if m.bias is not None:
                    nn.init.constant_(m.bias, 0.0)

    def forward(self, x):
        """Compute quantile values for freshly sampled quantile fractions.

        Args:
            x: batch of raw observations with values in [0, 255].
                Assumes shape (m, STATE_LEN, 84, 84), which is what the
                hard-coded 7 * 7 * 64 flatten size implies - TODO confirm
                against the env wrappers.

        Returns:
            A tuple ``(action_value, tau)`` where ``action_value`` has shape
            (m, N_ACTIONS, N_QUANT) and ``tau`` has shape (N_QUANT, 1) and
            holds the quantile fractions sampled for this call.
        """
        x = self.feature_extraction(x / 255.0)  # conv feature maps
        # Sample on the CPU (keeps the RNG stream unchanged), then move to
        # whatever device the input lives on instead of relying on USE_GPU.
        tau = torch.rand(N_QUANT, 1).to(x.device)  # (N_QUANT, 1)
        quants = torch.arange(0, N_QUANT, 1.0).to(x.device)  # (N_QUANT,)
        # cos(pi * i * tau); the pi literal is kept as-is so that outputs of
        # existing checkpoints stay bit-for-bit comparable.
        cos_trans = torch.cos(quants * tau * 3.141592).unsqueeze(2)
        # (N_QUANT, N_QUANT, 1)
        rand_feat = F.relu(
            self.phi(cos_trans).mean(dim=1) + self.phi_bias.unsqueeze(0)
        ).unsqueeze(0)
        # (1, N_QUANT, 7 * 7 * 64)
        x = x.view(x.size(0), -1).unsqueeze(1)  # (m, 1, 7 * 7 * 64)
        x = x * rand_feat  # (m, N_QUANT, 7 * 7 * 64)
        x = F.relu(self.fc(x))  # (m, N_QUANT, 512)

        # The IQN output is the set of quantile values of the return
        # distribution, one row per action.
        action_value = self.fc_q(x).transpose(1, 2)  # (m, N_ACTIONS, N_QUANT)

        return action_value, tau

    def save(self, PATH):
        """Write this network's state dict to ``PATH``."""
        torch.save(self.state_dict(), PATH)

    def load(self, PATH):
        """Load a state dict from ``PATH`` into this network.

        The checkpoint is mapped onto the device the network currently lives
        on, so a checkpoint written on a GPU machine also loads on CPU.
        """
        device = next(self.parameters()).device
        self.load_state_dict(torch.load(PATH, map_location=device))

## IQN 정의하기

In [4]:
class DQN(object):
    """IQN agent: prediction/target networks, replay buffer and optimizer."""

    def __init__(self):
        self.pred_net, self.target_net = ConvNet(), ConvNet()
        # Start with the target network identical to the prediction network.
        self.update_target(self.target_net, self.pred_net, 1.0)
        if USE_GPU:
            self.pred_net.cuda()
            self.target_net.cuda()

        # Number of transitions stored so far.
        self.memory_counter = 0
        # Number of learn() calls so far (drives the target update schedule).
        self.learn_step_counter = 0

        # Uniform experience replay buffer.
        self.replay_buffer = ReplayBuffer(MEMORY_CAPACITY)

        # Only the prediction network is optimised.
        self.optimizer = torch.optim.Adam(self.pred_net.parameters(), lr=LR)

    def update_target(self, target, pred, update_rate):
        """Blend ``pred``'s parameters into ``target`` in place.

        ``update_rate`` of 1.0 is a hard copy; smaller values give the soft
        update ``target = (1 - rate) * target + rate * pred``.
        """
        for target_param, pred_param in zip(target.parameters(), pred.parameters()):
            target_param.data.copy_((1.0 - update_rate) \
                                    * target_param.data + update_rate*pred_param.data)

    def save_model(self):
        """Save the prediction and target networks to their configured paths."""
        # Create the checkpoint directories first; torch.save does not.
        for path in (PRED_PATH, TARGET_PATH):
            directory = os.path.dirname(path)
            if directory:
                os.makedirs(directory, exist_ok=True)
        self.pred_net.save(PRED_PATH)
        self.target_net.save(TARGET_PATH)

    def load_model(self):
        """Load the prediction and target networks from their configured paths."""
        self.pred_net.load(PRED_PATH)
        self.target_net.load(TARGET_PATH)

    def choose_action(self, x, EPSILON):
        """Pick one epsilon-greedy action per observation in the batch ``x``.

        With probability ``EPSILON`` the whole batch acts uniformly at random;
        otherwise every entry takes the action with the highest mean quantile
        value. Returns a numpy array with one action index per observation.
        """
        x = torch.FloatTensor(x)
        if USE_GPU:
            x = x.cuda()

        if np.random.uniform() >= EPSILON:
            # Greedy case. Acting never needs gradients, so skip building
            # the autograd graph.
            with torch.no_grad():
                action_value, _ = self.pred_net(x)  # (N_ENVS, N_ACTIONS, N_QUANT)
                action_value = action_value.mean(dim=2)
                action = torch.argmax(action_value, dim=1).cpu().numpy()
        else:
            # Random exploration case.
            action = np.random.randint(0, N_ACTIONS, (x.size(0)))
        return action

    def store_transition(self, s, a, r, s_, done):
        """Add one transition to the replay buffer and count it."""
        self.memory_counter += 1
        self.replay_buffer.add(s, a, r, s_, float(done))

    def learn(self):
        """Run one optimisation step on a mini-batch sampled from replay."""
        self.learn_step_counter += 1
        # Soft target-network update on the configured schedule.
        if self.learn_step_counter % TARGET_REPLACE_ITER == 0:
            self.update_target(self.target_net, self.pred_net, 1e-2)

        b_s, b_a, b_r, b_s_, b_d = self.replay_buffer.sample(BATCH_SIZE)

        b_s = torch.FloatTensor(b_s)
        b_a = torch.LongTensor(b_a)
        b_r = torch.FloatTensor(b_r)
        b_s_ = torch.FloatTensor(b_s_)
        b_d = torch.FloatTensor(b_d)

        if USE_GPU:
            b_s, b_a, b_r, b_s_, b_d = b_s.cuda(), b_a.cuda(), b_r.cuda(), b_s_.cuda(), b_d.cuda()

        # Quantile values predicted for the actions that were actually taken.
        q_eval, q_eval_tau = self.pred_net(b_s)  # (m, N_ACTIONS, N_QUANT), (N_QUANT, 1)
        mb_size = q_eval.size(0)
        rows = torch.arange(mb_size, device=q_eval.device)
        # Vectorised row-wise selection instead of a Python loop per sample.
        q_eval = q_eval[rows, b_a.view(-1)]  # (m, N_QUANT)
        q_eval = q_eval.unsqueeze(2)  # (m, N_QUANT, 1)
        # Note: dim 1 indexes the present quantiles, dim 2 the next quantiles.

        # Bootstrapped target; no gradient flows through the target network.
        with torch.no_grad():
            q_next, _ = self.target_net(b_s_)  # (m, N_ACTIONS, N_QUANT)
            best_actions = q_next.mean(dim=2).argmax(dim=1)  # (m)
            q_next = q_next[rows, best_actions]  # (m, N_QUANT)
            q_target = b_r.unsqueeze(1) + GAMMA * (1. - b_d.unsqueeze(1)) * q_next
            # (m, N_QUANT)
            q_target = q_target.unsqueeze(1)  # (m, 1, N_QUANT)

        # Quantile Huber loss.
        u = q_target - q_eval  # (m, N_QUANT, N_QUANT)
        tau = q_eval_tau.unsqueeze(0)  # (1, N_QUANT, 1)
        # tau belongs to the present (predicted) quantiles.
        weight = torch.abs(tau - u.le(0.).float())  # (m, N_QUANT, N_QUANT)
        # Broadcast explicitly so smooth_l1_loss gets equally shaped operands
        # and does not emit a size-mismatch warning on every step.
        pred_b, target_b = torch.broadcast_tensors(q_eval, q_target)
        loss = F.smooth_l1_loss(pred_b, target_b, reduction='none')
        # (m, N_QUANT, N_QUANT)
        loss = torch.mean(weight * loss, dim=1).mean(dim=1)  # (m)

        # Importance weights: all ones for uniform replay. Kept so that a
        # prioritized buffer can plug its per-sample weights in here.
        b_w = torch.ones_like(loss)
        loss = torch.mean(b_w * loss)

        # Backpropagate and update the prediction network.
        self.optimizer.zero_grad()
        loss.backward()
        self.optimizer.step()

## 학습

In [None]:
dqn = DQN()

# Make sure the checkpoint / log directories exist before the first save.
if SAVE:
    for _path in (PRED_PATH, TARGET_PATH, RESULT_PATH):
        _directory = os.path.dirname(_path)
        if _directory:
            os.makedirs(_directory, exist_ok=True)

# Load the model (and the result log, when present) if requested.
if LOAD and os.path.isfile(PRED_PATH) and os.path.isfile(TARGET_PATH):
    dqn.load_model()
    if os.path.isfile(RESULT_PATH):
        # NOTE: pickle must only be used on files this notebook wrote itself.
        with open(RESULT_PATH, 'rb') as pkl_file:
            # The log is stored as an ndarray; convert it back to a list so
            # that result.append() keeps working after a reload.
            result = list(pickle.load(pkl_file))
    else:
        result = []
    print('Load complete!')
else:
    result = []
    print('Initialize results!')

print('Collecting experience...')

# Episode statistics of the most recent 100 finished episodes.
epinfobuf = deque(maxlen=100)
# Wall-clock reference for the elapsed-time log.
start_time = time.time()

# Reset all environments.
s = np.array(env.reset())

for step in tqdm(range(1, STEP_NUM//N_ENVS+1)):
    a = dqn.choose_action(s, EPSILON)

    # Take the actions and observe the next states.
    s_, r, done, infos = env.step(a)
    # Collect statistics of episodes that finished on this step.
    for info in infos:
        maybeepinfo = info.get('episode')
        if maybeepinfo:
            epinfobuf.append(maybeepinfo)
    s_ = np.array(s_)

    # Clip rewards to {-1, 0, 1} for numerical stability.
    clip_r = np.sign(r)

    # Store one transition per environment.
    for i in range(N_ENVS):
        dqn.store_transition(s[i], a[i], clip_r[i], s_[i], done[i])

    # Anneal epsilon (exploration schedule).
    if step <= int(1e+3):
        # Linear decay from 1.0 to 0.1 over the first 1e3 iterations.
        EPSILON -= 0.9/1e+3
    elif step <= int(1e+4):
        # Linear decay from 0.1 to 0.01 until iteration 1e4, then constant.
        EPSILON -= 0.09/(1e+4 - 1e+3)

    # Learn once LEARN_START transitions are stored, every LEARN_FREQ
    # transitions (memory_counter grows by N_ENVS per iteration).
    if (LEARN_START <= dqn.memory_counter) and (dqn.memory_counter % LEARN_FREQ == 0):
        dqn.learn()

    # Periodic logging and checkpointing.
    if step % SAVE_FREQ == 0:
        # Elapsed wall-clock time.
        time_interval = round(time.time() - start_time, 2)
        # Mean return over the recent episodes; NaN until one has finished
        # (avoids numpy's empty-mean RuntimeWarning).
        if epinfobuf:
            mean_100_ep_return = round(np.mean([epinfo['r'] for epinfo in epinfobuf]),2)
        else:
            mean_100_ep_return = float('nan')
        result.append(mean_100_ep_return)
        # Print log.
        print('Used Step:',dqn.memory_counter,
              'EPS: ', round(EPSILON, 3),
              '| Mean ep 100 return: ', mean_100_ep_return,
              '| Used Time:',time_interval)
        # Save the model and the result log (honours the SAVE flag).
        if SAVE:
            dqn.save_model()
            with open(RESULT_PATH, 'wb') as pkl_file:
                pickle.dump(np.array(result), pkl_file)

    s = s_

    if RENDERING:
        env.render()

  0%|          | 0/2500000 [00:00<?, ?it/s]

Initialize results!
Collecting experience...


  0%|          | 1004/2500000 [00:18<15:36:33, 44.47it/s]

Used Step: 4000 EPS:  0.1 | Mean ep 100 return:  1.43 | Used Time: 18.28


  0%|          | 2008/2500000 [00:40<14:59:37, 46.28it/s]

Used Step: 8000 EPS:  0.09 | Mean ep 100 return:  1.45 | Used Time: 40.23


  0%|          | 3005/2500000 [01:02<14:59:03, 46.29it/s]

Used Step: 12000 EPS:  0.08 | Mean ep 100 return:  1.34 | Used Time: 61.92


  0%|          | 4007/2500000 [01:24<15:29:39, 44.75it/s]

Used Step: 16000 EPS:  0.07 | Mean ep 100 return:  1.24 | Used Time: 83.89


  0%|          | 5004/2500000 [01:45<15:52:30, 43.66it/s]

Used Step: 20000 EPS:  0.06 | Mean ep 100 return:  1.24 | Used Time: 105.55


  0%|          | 6004/2500000 [02:07<15:03:27, 46.01it/s]

Used Step: 24000 EPS:  0.05 | Mean ep 100 return:  1.25 | Used Time: 127.37


  0%|          | 7009/2500000 [02:29<14:39:52, 47.22it/s]

Used Step: 28000 EPS:  0.04 | Mean ep 100 return:  1.37 | Used Time: 149.28


  0%|          | 8004/2500000 [02:51<15:47:19, 43.84it/s]

Used Step: 32000 EPS:  0.03 | Mean ep 100 return:  1.83 | Used Time: 171.03


  0%|          | 9004/2500000 [03:12<15:34:18, 44.44it/s]

Used Step: 36000 EPS:  0.02 | Mean ep 100 return:  2.14 | Used Time: 192.84


  0%|          | 10008/2500000 [03:34<14:48:48, 46.69it/s]

Used Step: 40000 EPS:  0.01 | Mean ep 100 return:  2.4 | Used Time: 214.79


  0%|          | 11008/2500000 [03:56<16:09:02, 42.81it/s]

Used Step: 44000 EPS:  0.01 | Mean ep 100 return:  2.64 | Used Time: 236.38


  0%|          | 12008/2500000 [04:18<14:56:32, 46.25it/s]

Used Step: 48000 EPS:  0.01 | Mean ep 100 return:  3.04 | Used Time: 258.52


  1%|          | 13008/2500000 [04:40<15:15:51, 45.26it/s]

Used Step: 52000 EPS:  0.01 | Mean ep 100 return:  3.27 | Used Time: 280.29


  1%|          | 14008/2500000 [05:02<15:04:21, 45.82it/s]

Used Step: 56000 EPS:  0.01 | Mean ep 100 return:  3.23 | Used Time: 301.92


  1%|          | 15008/2500000 [05:23<15:12:06, 45.41it/s]

Used Step: 60000 EPS:  0.01 | Mean ep 100 return:  3.34 | Used Time: 323.67


  1%|          | 16004/2500000 [05:45<15:08:20, 45.58it/s]

Used Step: 64000 EPS:  0.01 | Mean ep 100 return:  3.54 | Used Time: 345.35


  1%|          | 17004/2500000 [06:07<15:19:52, 44.99it/s]

Used Step: 68000 EPS:  0.01 | Mean ep 100 return:  3.74 | Used Time: 367.0


  1%|          | 18008/2500000 [06:29<15:22:42, 44.83it/s]

Used Step: 72000 EPS:  0.01 | Mean ep 100 return:  3.82 | Used Time: 389.41


  1%|          | 19008/2500000 [06:51<15:13:14, 45.28it/s]

Used Step: 76000 EPS:  0.01 | Mean ep 100 return:  4.11 | Used Time: 411.15


  1%|          | 20008/2500000 [07:13<16:23:48, 42.01it/s]

Used Step: 80000 EPS:  0.01 | Mean ep 100 return:  4.46 | Used Time: 433.05


  1%|          | 21008/2500000 [07:35<15:40:16, 43.94it/s]

Used Step: 84000 EPS:  0.01 | Mean ep 100 return:  4.74 | Used Time: 454.95


  1%|          | 22008/2500000 [07:56<14:55:27, 46.12it/s]

Used Step: 88000 EPS:  0.01 | Mean ep 100 return:  5.09 | Used Time: 476.43


  1%|          | 23009/2500000 [08:18<14:56:23, 46.06it/s]

Used Step: 92000 EPS:  0.01 | Mean ep 100 return:  5.48 | Used Time: 498.08


  1%|          | 24004/2500000 [08:39<14:52:45, 46.22it/s]

Used Step: 96000 EPS:  0.01 | Mean ep 100 return:  5.82 | Used Time: 519.41


  1%|          | 25009/2500000 [09:01<14:52:49, 46.20it/s]

Used Step: 100000 EPS:  0.01 | Mean ep 100 return:  6.3 | Used Time: 540.95


  1%|          | 26004/2500000 [09:22<14:56:05, 46.01it/s]

Used Step: 104000 EPS:  0.01 | Mean ep 100 return:  6.69 | Used Time: 562.24


  1%|          | 27004/2500000 [09:43<15:37:03, 43.98it/s]

Used Step: 108000 EPS:  0.01 | Mean ep 100 return:  7.07 | Used Time: 583.88


  1%|          | 28009/2500000 [10:05<14:44:51, 46.56it/s]

Used Step: 112000 EPS:  0.01 | Mean ep 100 return:  7.33 | Used Time: 605.59


  1%|          | 29005/2500000 [10:27<14:45:56, 46.49it/s]

Used Step: 116000 EPS:  0.01 | Mean ep 100 return:  7.67 | Used Time: 627.02


  1%|          | 30006/2500000 [10:48<15:52:16, 43.23it/s]

Used Step: 120000 EPS:  0.01 | Mean ep 100 return:  8.02 | Used Time: 648.34


  1%|          | 31007/2500000 [11:09<14:59:05, 45.77it/s]

Used Step: 124000 EPS:  0.01 | Mean ep 100 return:  8.25 | Used Time: 669.83


  1%|▏         | 32009/2500000 [11:31<14:47:13, 46.36it/s]

Used Step: 128000 EPS:  0.01 | Mean ep 100 return:  8.58 | Used Time: 691.48


  1%|▏         | 33004/2500000 [11:52<14:50:11, 46.19it/s]

Used Step: 132000 EPS:  0.01 | Mean ep 100 return:  8.76 | Used Time: 712.88


  1%|▏         | 34005/2500000 [12:14<15:11:26, 45.09it/s]

Used Step: 136000 EPS:  0.01 | Mean ep 100 return:  9.07 | Used Time: 734.35


  1%|▏         | 35006/2500000 [12:35<14:32:15, 47.10it/s]

Used Step: 140000 EPS:  0.01 | Mean ep 100 return:  9.32 | Used Time: 755.74


  1%|▏         | 36006/2500000 [12:57<14:57:19, 45.77it/s]

Used Step: 144000 EPS:  0.01 | Mean ep 100 return:  9.62 | Used Time: 777.18


  1%|▏         | 37006/2500000 [13:18<14:52:35, 45.99it/s]

Used Step: 148000 EPS:  0.01 | Mean ep 100 return:  9.69 | Used Time: 798.61


  2%|▏         | 38006/2500000 [13:40<14:47:55, 46.21it/s]

Used Step: 152000 EPS:  0.01 | Mean ep 100 return:  10.12 | Used Time: 819.96


  2%|▏         | 39006/2500000 [14:01<14:43:33, 46.42it/s]

Used Step: 156000 EPS:  0.01 | Mean ep 100 return:  10.47 | Used Time: 841.3


  2%|▏         | 40006/2500000 [14:22<15:01:17, 45.49it/s]

Used Step: 160000 EPS:  0.01 | Mean ep 100 return:  10.81 | Used Time: 862.75


  2%|▏         | 41006/2500000 [14:44<14:44:21, 46.34it/s]

Used Step: 164000 EPS:  0.01 | Mean ep 100 return:  11.01 | Used Time: 883.96


  2%|▏         | 42006/2500000 [15:05<14:41:16, 46.49it/s]

Used Step: 168000 EPS:  0.01 | Mean ep 100 return:  11.48 | Used Time: 905.43


  2%|▏         | 43007/2500000 [15:26<14:55:15, 45.74it/s]

Used Step: 172000 EPS:  0.01 | Mean ep 100 return:  11.67 | Used Time: 926.81


  2%|▏         | 44008/2500000 [15:48<15:08:29, 45.06it/s]

Used Step: 176000 EPS:  0.01 | Mean ep 100 return:  11.55 | Used Time: 948.2


  2%|▏         | 45008/2500000 [16:09<15:08:39, 45.03it/s]

Used Step: 180000 EPS:  0.01 | Mean ep 100 return:  11.62 | Used Time: 969.69


  2%|▏         | 46006/2500000 [16:31<14:56:37, 45.62it/s]

Used Step: 184000 EPS:  0.01 | Mean ep 100 return:  11.94 | Used Time: 991.03


  2%|▏         | 47006/2500000 [16:52<14:47:04, 46.09it/s]

Used Step: 188000 EPS:  0.01 | Mean ep 100 return:  12.36 | Used Time: 1012.22


  2%|▏         | 48007/2500000 [17:13<14:33:39, 46.78it/s]

Used Step: 192000 EPS:  0.01 | Mean ep 100 return:  12.62 | Used Time: 1033.48


  2%|▏         | 49007/2500000 [17:34<15:04:44, 45.15it/s]

Used Step: 196000 EPS:  0.01 | Mean ep 100 return:  12.59 | Used Time: 1054.81


  2%|▏         | 50007/2500000 [17:56<14:41:12, 46.34it/s]

Used Step: 200000 EPS:  0.01 | Mean ep 100 return:  12.51 | Used Time: 1076.17


  2%|▏         | 51007/2500000 [18:17<14:44:30, 46.15it/s]

Used Step: 204000 EPS:  0.01 | Mean ep 100 return:  12.79 | Used Time: 1097.58


  2%|▏         | 52007/2500000 [18:39<14:47:39, 45.96it/s]

Used Step: 208000 EPS:  0.01 | Mean ep 100 return:  12.78 | Used Time: 1119.02


  2%|▏         | 53004/2500000 [19:00<14:53:09, 45.66it/s]

Used Step: 212000 EPS:  0.01 | Mean ep 100 return:  12.86 | Used Time: 1140.3


  2%|▏         | 54005/2500000 [19:21<14:33:14, 46.68it/s]

Used Step: 216000 EPS:  0.01 | Mean ep 100 return:  13.0 | Used Time: 1161.54


  2%|▏         | 55007/2500000 [19:42<14:36:52, 46.47it/s]

Used Step: 220000 EPS:  0.01 | Mean ep 100 return:  13.1 | Used Time: 1182.61


  2%|▏         | 56005/2500000 [20:03<14:36:03, 46.50it/s]

Used Step: 224000 EPS:  0.01 | Mean ep 100 return:  13.08 | Used Time: 1203.88


  2%|▏         | 57007/2500000 [20:25<14:42:33, 46.13it/s]

Used Step: 228000 EPS:  0.01 | Mean ep 100 return:  13.46 | Used Time: 1224.99


  2%|▏         | 58006/2500000 [20:46<14:23:31, 47.13it/s]

Used Step: 232000 EPS:  0.01 | Mean ep 100 return:  13.81 | Used Time: 1246.03


  2%|▏         | 59007/2500000 [21:07<14:31:18, 46.69it/s]

Used Step: 236000 EPS:  0.01 | Mean ep 100 return:  14.35 | Used Time: 1267.2


  2%|▏         | 60007/2500000 [21:28<14:52:00, 45.59it/s]

Used Step: 240000 EPS:  0.01 | Mean ep 100 return:  14.6 | Used Time: 1288.38


  2%|▏         | 61005/2500000 [21:49<14:45:57, 45.88it/s]

Used Step: 244000 EPS:  0.01 | Mean ep 100 return:  14.89 | Used Time: 1309.53


  2%|▏         | 62005/2500000 [22:10<14:23:22, 47.06it/s]

Used Step: 248000 EPS:  0.01 | Mean ep 100 return:  15.3 | Used Time: 1330.75


  3%|▎         | 63009/2500000 [22:32<14:29:48, 46.70it/s]

Used Step: 252000 EPS:  0.01 | Mean ep 100 return:  15.26 | Used Time: 1352.12


  3%|▎         | 64005/2500000 [22:54<18:18:40, 36.95it/s]

Used Step: 256000 EPS:  0.01 | Mean ep 100 return:  15.54 | Used Time: 1374.15


  3%|▎         | 65007/2500000 [23:20<17:54:38, 37.76it/s]

Used Step: 260000 EPS:  0.01 | Mean ep 100 return:  15.97 | Used Time: 1399.86


  3%|▎         | 66007/2500000 [23:45<16:52:23, 40.07it/s]

Used Step: 264000 EPS:  0.01 | Mean ep 100 return:  16.56 | Used Time: 1425.26


  3%|▎         | 67005/2500000 [24:10<17:57:15, 37.64it/s]

Used Step: 268000 EPS:  0.01 | Mean ep 100 return:  16.62 | Used Time: 1450.85


  3%|▎         | 68005/2500000 [24:36<18:04:30, 37.37it/s]

Used Step: 272000 EPS:  0.01 | Mean ep 100 return:  17.06 | Used Time: 1476.62


  3%|▎         | 69006/2500000 [25:02<17:36:53, 38.34it/s]

Used Step: 276000 EPS:  0.01 | Mean ep 100 return:  17.26 | Used Time: 1502.22


  3%|▎         | 70007/2500000 [25:28<17:27:07, 38.68it/s]

Used Step: 280000 EPS:  0.01 | Mean ep 100 return:  17.38 | Used Time: 1527.94


  3%|▎         | 70933/2500000 [25:51<17:32:00, 38.48it/s]

## 결과 시각화

In [None]:
# Plot the training curve. `result` holds the mean 100-episode return that
# the training loop appends every SAVE_FREQ iterations; the previously
# referenced name `entire_ep_rs` is not defined anywhere in this notebook.
plt.plot(range(len(result)), result)
plt.tight_layout()
plt.show()

In [None]:
from matplotlib import animation

def display_frames_as_gif(frames, save_path='./iqn_breakout_result.gif', fps=30):
    """Render a sequence of image frames and save it as an animated GIF.

    Args:
        frames: non-empty sequence of images accepted by ``plt.imshow``.
        save_path: destination file; defaults to the path this notebook
            embeds in its final markdown cell.
        fps: frames per second of the saved animation.

    Raises:
        ValueError: if ``frames`` is empty.
    """
    if len(frames) == 0:
        raise ValueError('frames must contain at least one image')

    patch = plt.imshow(frames[0])
    plt.axis('off')

    def animate(i):
        # Swap the displayed image for frame i.
        patch.set_data(frames[i])

    anim = animation.FuncAnimation(plt.gcf(), animate, frames=len(frames), interval=5)
    anim.save(save_path, writer='imagemagick', fps=fps)

In [None]:
# Play one episode with the trained agent and record it as a GIF.
# NOTE: this rebinds `env`, replacing the vectorized training environment.
env = wrap(gym.make('BreakoutNoFrameskip-v4'))
# Small exploration rate for evaluation; matches the final training epsilon.
EVAL_EPSILON = 0.01
s = np.array(env.reset())
total_reward = 0
frames = []

try:
    for t in range(10000):
        # Render into buffer.
        frames.append(env.render(mode = 'rgb_array'))
        # choose_action works on a batch and returns one action per entry,
        # so add a batch axis and take the single resulting action.
        a = dqn.choose_action(np.expand_dims(s, axis=0), EVAL_EPSILON)[0]
        # Take the action and observe the next state.
        s_, r, done, info = env.step(a)
        s_ = np.array(s_)
        total_reward += r
        if done:
            break
        s = s_
finally:
    # Release the emulator even if the rollout fails part-way.
    env.close()
print('Total Reward : %.2f'%total_reward)
display_frames_as_gif(frames)

![alt text](./iqn_breakout_result.gif "segment")