## 필요한 모듈 설치 확인

In [1]:
import gym
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.autograd as autograd
import numpy as np

import random
import os
import pickle
import time
from collections import deque

import matplotlib.pyplot as plt
import seaborn as sns
%matplotlib inline
sns.set()

from wrappers import wrap, wrap_cover, SubprocVecEnv
from runner import Runner

## 하이퍼 파라미터 정의하기

In [2]:
'''PPO Settings'''
TRAJ_LEN = 1000
N_OPT_EPOCHS = 4
ENT_COEF = 1e-2
CLIP_RANGE = 0.1
LAMBDA = 0.95

'''Environment Settings'''
# sequential images to define state
STATE_LEN = 4
# openai gym env name
ENV_NAME = 'BreakoutNoFrameskip-v4'
# number of environments for A2C
N_ENVS = 4
# define gym 
env = SubprocVecEnv([wrap_cover(ENV_NAME) for i in range(N_ENVS)])
# check gym setting
N_ACTIONS = env.action_space.n;print('N_ACTIONS : ',N_ACTIONS) #  6
N_STATES = env.observation_space.shape;print('N_STATES : ',N_STATES) # (4, 84, 84)
# Total simulation step
N_STEP = 10**7
# gamma for MDP
GAMMA = 0.99
# visualize for agent playing
RENDERING = False

'''Training settings'''
# check GPU usage
USE_GPU = torch.cuda.is_available()
print('USE GPU: '+str(USE_GPU))
# mini-batch size
BATCH_SIZE = 32
# learning rage
LR = 1e-4
# clip gradient
MAX_GRAD_NORM = 0.1
ZERO_GP = True
# log optimization
LOG_OPT = False

'''Save&Load Settings'''
# log frequency
LOG_FREQ = 1
# check save/load
SAVE = True
LOAD = False
# paths for predction net, target net, result log
NET_PATH = './data/model/ppo_net.pkl'
DIS_PATH = './data/model/gail_dis_net.pkl'



N_ACTIONS :  4
N_STATES :  (4, 84, 84)
USE GPU: True




## 네트워크 구조 정의하기

In [3]:
class ConvNet(nn.Module):
    def __init__(self):
        super(ConvNet, self).__init__()
        # architecture def
        self.feature_extraction = nn.Sequential(
            nn.Conv2d(STATE_LEN, 32, kernel_size=8, stride=4),
            nn.ReLU(),
            nn.Conv2d(32, 64, kernel_size=4, stride=2),
            nn.ReLU(),
            nn.Conv2d(64, 64, kernel_size=3, stride=1),
            nn.ReLU(),
        )
        self.fc = nn.Linear(7 * 7 * 64, 256)
        # actor
        self.actor = nn.Linear(256, N_ACTIONS)
        # critic
        self.critic = nn.Linear(256, 1)
            
        # parameter initialization
        for m in self.modules():
            if isinstance(m, nn.Conv2d) or isinstance(m, nn.Linear):
                nn.init.orthogonal_(m.weight, gain = np.sqrt(2))
                nn.init.constant_(m.bias, 0.0)
            

    def forward(self, x):
        # x is a tensor of (m, 4, 84, 84)
        x = self.feature_extraction(x / 255.0)
        # x.size(0) : mini-batch size
        x = x.view(x.size(0), -1)
        x = F.relu(self.fc(x))
        # use log_softmax for numerical stability
        action_log_prob = F.log_softmax(self.actor(x), dim=1)
        state_value = self.critic(x)

        return action_log_prob, state_value

    def save(self, PATH):
        torch.save(self.state_dict(),PATH)

    def load(self, PATH):
        self.load_state_dict(torch.load(PATH))
        
class Discriminator(nn.Module):
    def __init__(self):
        super(Discriminator, self).__init__()
        # architecture def
        self.feature_extraction = nn.Sequential(
            nn.Conv2d(STATE_LEN, 32, kernel_size=8, stride=4),
            nn.ReLU(),
            nn.Conv2d(32, 64, kernel_size=4, stride=2),
            nn.ReLU(),
            nn.Conv2d(64, 64, kernel_size=3, stride=1),
            nn.ReLU(),
        )
        self.fc1 = nn.Linear(7 * 7 * 64, 256)
        self.action_feature = nn.Linear(N_ACTIONS, 256)
        # actor
        self.fc2 = nn.Linear(256, 1)
        
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.orthogonal_(m.weight, gain = np.sqrt(2))
                nn.init.constant_(m.bias, 0.0)
            elif isinstance(m, nn.Linear):
                nn.init.kaiming_normal_(m.weight)
                nn.init.constant_(m.bias, 0.0)
            

    def forward(self, x, a):
        # x is a tensor of (m, 4, 84, 84)
        x = self.feature_extraction(x / 255.0)
        # x.size(0) : mini-batch size
        x = x.view(x.size(0), -1)
        a_onehot = torch.zeros(x.size(0), N_ACTIONS)
        if USE_GPU:
            a_onehot = a_onehot.cuda(device=0)
        a_onehot.scatter_(1, a.unsqueeze(1), 1)
        x = F.relu(self.fc1(x) * F.leaky_relu(self.action_feature(a_onehot), negative_slope=2e-1))
        action_value = self.fc2(x)

        return action_value

    def save(self, PATH):
        torch.save(self.state_dict(),PATH)

    def load(self, PATH):
        self.load_state_dict(torch.load(PATH))

## PPO 정의하기

In [4]:
class PPO:
    def __init__(self):
        self.net = ConvNet()
        self.dis_net = Discriminator()
        # use gpu
        if USE_GPU:
            self.net = self.net.cuda()
            self.dis_net = self.dis_net.cuda()
            
        # simulator step conter
        self.memory_counter = 0
        
        # create the replay buffer for expert
        with open( "replay.pkl", "rb" ) as f:
            self.expert_replay_buffer = pickle.load(f)
        
        # define optimizer
        self.optimizer = torch.optim.Adam(self.net.parameters(), lr=LR)
        self.dis_opt = torch.optim.RMSprop(self.dis_net.parameters(), lr=LR, alpha=0.9)
        
        # ppo clip range
        self.clip_range = CLIP_RANGE
        
    def save_model(self):
        self.net.cpu()
        self.dis_net.cpu()
        self.net.save(NET_PATH)
        self.dis_net.save(DIS_PATH)
        if USE_GPU:
            self.net.cuda()
            self.dis_net.cuda()
            
    def load_model(self):
        self.net.cpu()
        self.dis_net.cpu()
        self.net.load(NET_PATH)
        self.dis_net.load(DIS_PATH)
        if USE_GPU:
            self.net.cuda()
            self.dis_net.cuda()
        
    def choose_action(self, x):
        self.memory_counter += 1
        # Assume that x is a np.array of shape (nenvs, 4, 84, 84)
        x = torch.FloatTensor(x)
        if USE_GPU:
            x = x.cuda()
        # get action log probs and state values
        action_log_probs, state_values = self.net(x) # (nenvs, N_ACTIONS)
        probs = F.softmax(action_log_probs, dim=1).data.cpu().numpy()
        probs = (probs+1e-8)/np.sum((probs+1e-8), axis=1, keepdims=True)
        # sample actions
        actions = np.array([np.random.choice(N_ACTIONS,p=probs[i]) for i in range(len(probs))])
        # convert tensor to np.array
        action_log_probs , state_values = action_log_probs.data.cpu().numpy() , state_values.squeeze(1).data.cpu().numpy()
        # calc selected logprob
        selected_log_probs = np.array([action_log_probs[i][actions[i]] for i in range(len(probs))])
        return actions, state_values, selected_log_probs
    
    def reward_dis(self, s, a):
        s = torch.FloatTensor(s)
        a = torch.LongTensor(a)
        
        if USE_GPU:
            s , a = s.cuda(), a.cuda()
        
        d_reward = -torch.log(torch.sigmoid(self.dis_net(s, a)))
        return d_reward.squeeze(1).data.cpu().numpy()
    
    def learn_dis(self, obs, returns, masks, actions, values, selected_log_probs):
        # optimize discriminator
        b_s = torch.FloatTensor(obs)
        b_s.requires_grad = True
        b_a = torch.LongTensor(actions)
        
        # data sample from expert experience replay
        e_state_memory, e_action_memory, e_reward_memory, \
        e_next_state_memory, e_done = self.expert_replay_buffer.sample(len(b_s))
        
        e_s = torch.FloatTensor(e_state_memory)
        e_s.requires_grad = True
        e_a = torch.LongTensor(e_action_memory)
        
        if USE_GPU:
            b_s, b_a, e_s, e_a = b_s.cuda(device=0), b_a.cuda(device=0), e_s.cuda(device=0), e_a.cuda(device=0)
        
        d_policy = self.dis_net(b_s, b_a).squeeze(1) # (m)
        d_expert = self.dis_net(e_s, e_a).squeeze(1) # (m)
        
        d_loss = -torch.log(torch.sigmoid(d_policy)).mean() -torch.log(1-torch.sigmoid(d_expert)).mean()
        loss = d_loss
        
        # calc gradient penalty
        if ZERO_GP:
            b_grad = autograd.grad(d_policy, b_s, create_graph=True,
                            grad_outputs=torch.ones_like(d_policy),
                            retain_graph=True, only_inputs=True)[0].view(BATCH_SIZE, -1)
            e_grad = autograd.grad(d_expert, e_s, create_graph=True,
                            grad_outputs=torch.ones_like(d_expert),
                            retain_graph=True, only_inputs=True)[0].view(BATCH_SIZE, -1)
            b_grad = b_grad.norm(dim=1)
            e_grad = e_grad.norm(dim=1)
            gp_loss = 1e+5 * ((b_grad)**2 + (e_grad)**2).mean()
            loss += gp_loss
        
        self.dis_opt.zero_grad()
        loss.backward()
        self.dis_opt.step()
        
        return round(float(d_loss.item()), 4), round(float(gp_loss.item()), 4)

    def learn(self, obs, returns, masks, actions, values, selected_log_probs):
        
        # calculate the advantages
        advs = returns - values
        advs = (advs - advs.mean())/(advs.std() + 1e-8)
        
        # np.array -> torch.Tensor
        obs = torch.FloatTensor(obs) # (m, 4, 84, 84)
        returns = torch.FloatTensor(returns) # (m)
        advs = torch.FloatTensor(advs) # (m)
        actions = torch.LongTensor(actions) # (m)
        selected_log_probs = torch.FloatTensor(selected_log_probs) # (m)
        values = torch.FloatTensor(values) # (m)
        if USE_GPU:
            obs = obs.cuda()
            returns = returns.cuda()
            advs = advs.cuda()
            actions = actions.cuda()
            selected_log_probs = selected_log_probs.cuda()
            values = values.cuda()
        
        # get action log probs and state values
        action_log_probs, state_values = self.net(obs)
        # (m, N_ACTIONS), (m, 1)
        
        # calc probs
        probs = F.softmax(action_log_probs, dim=1)
        # (m, N_ACTIONS)
        
        # calc entropy loss
        ent_loss = ENT_COEF *((action_log_probs * probs).sum(dim=1)).mean()
        # (1)
        
        # calc log probs
        cur_log_probs = action_log_probs.gather(1,actions.unsqueeze(1))
        # cur : (m, 1)
        ratio = torch.exp(cur_log_probs.squeeze(1)-selected_log_probs)
        # (m)
        
        # actor loss
        surr1 = ratio * advs # (m)
        surr2 = torch.clamp(ratio, 1.0 - self.clip_range, 1.0 + self.clip_range)*advs # (m)
        actor_loss = -torch.min(surr1, surr2).mean() # (1)
        # critic loss
        critic_loss = F.smooth_l1_loss(state_values.squeeze(1), returns) # (1)

        loss = actor_loss + critic_loss + ent_loss # (1)
        
        actor_loss, critic_loss, ent_loss, total_loss = actor_loss.data.cpu().numpy(), \
        critic_loss.data.cpu().numpy(), ent_loss.data.cpu().numpy(), loss.data.cpu().numpy()

        self.optimizer.zero_grad()
        loss.backward()
        nn.utils.clip_grad_norm_(self.net.parameters(), MAX_GRAD_NORM)
        self.optimizer.step()
        
        return round(float(actor_loss), 4), round(float(critic_loss), 4),\
    round(float(ent_loss), 4), round(float(total_loss), 4)

## 학습

In [None]:
ppo = PPO()
runner = Runner(env=env, model=ppo, nsteps=TRAJ_LEN, gamma=GAMMA, lam=LAMBDA)

# model load with check
if LOAD and os.path.isfile(PRED_PATH) and os.path.isfile(TARGET_PATH):
    ppo.load_model()
    pkl_file = open(RESULT_PATH,'rb')
    result = pickle.load(pkl_file)
    pkl_file.close()
    print('Load complete!')
else:
    result = []
    print('Initialize results!')

print('Collecting experience...')

# episode step for accumulate reward 
epinfobuf = deque(maxlen=100)
# in PPO, we iterate over optimization step
nbatch = N_ENVS * TRAJ_LEN
nupdates = N_STEP// nbatch
# check learning time
start_time = time.time()

for update in range(1, nupdates+1):
    # get minibatch
    obs, returns, masks, actions, values, neglogpacs, epinfos = runner.run()
    epinfobuf.extend(epinfos)
    
    inds = np.arange(nbatch)
    for start in range(0, nbatch, BATCH_SIZE):
        end = start + BATCH_SIZE
        mbinds = inds[start:end]
        slices = (arr[mbinds] for arr in (obs, returns, masks, actions, values, neglogpacs))
        d_loss, gp_loss = ppo.learn_dis(*slices)
    
    # calculate loss
    inds = np.arange(nbatch)
    for _ in range(N_OPT_EPOCHS):
        a_losses, c_losses, e_losses, t_losses = list(), list(), list(), list()
        # shuffle indices for i.i.d.
        np.random.shuffle(inds)
        # 0 to batch_size with batch_train_size step
        for start in range(0, nbatch, BATCH_SIZE):
            end = start + BATCH_SIZE
            mbinds = inds[start:end]
            slices = (arr[mbinds] for arr in (obs, returns, masks, actions, values, neglogpacs))
            actor_loss, critic_loss, ent_loss, total_loss = ppo.learn(*slices)
            # save opt log
            a_losses.append(actor_loss)
            c_losses.append(critic_loss)
            e_losses.append(ent_loss)
            t_losses.append(total_loss)
        # print opt log
        if LOG_OPT:
            print('Iter ',_,
                 'actor loss : ',round(float(np.mean(a_losses)), 3),
                 'critic loss : ', round(float(np.mean(c_losses)), 3),
                 'ent loss : ', round(float(np.mean(e_losses)), 3),
                 'total loss : ', round(float(np.mean(t_losses)), 3))
            
    if update % LOG_FREQ == 0:
        # print log and save
        # check time interval
        time_interval = round(time.time() - start_time, 2)
        # calc mean return
        mean_100_ep_return = round(np.mean([epinfo['r'] for epinfo in epinfobuf]),2)
        result.append(mean_100_ep_return)
        # print epi log
        print('N update: ',update,
              '| Mean ep 100 return: ', mean_100_ep_return,
              '/Used Time:',time_interval,
              '/Used Step:',ppo.memory_counter*N_ENVS,
             '| Dis loss : ',d_loss,
             '| GP loss : ', gp_loss)
        # save model
        if SAVE:
            ppo.save_model()

Initialize results!
Collecting experience...
N update:  1 | Mean ep 100 return:  1.71 /Used Time: 9.48 /Used Step: 4004 | Dis loss :  1.2009 | GP loss :  0.3298
N update:  2 | Mean ep 100 return:  1.39 /Used Time: 18.81 /Used Step: 8008 | Dis loss :  0.9293 | GP loss :  0.2144
N update:  3 | Mean ep 100 return:  1.39 /Used Time: 28.15 /Used Step: 12012 | Dis loss :  0.9029 | GP loss :  0.2175
N update:  4 | Mean ep 100 return:  1.34 /Used Time: 37.51 /Used Step: 16016 | Dis loss :  0.8206 | GP loss :  0.1924
N update:  5 | Mean ep 100 return:  1.39 /Used Time: 46.9 /Used Step: 20020 | Dis loss :  0.7839 | GP loss :  0.1717
N update:  6 | Mean ep 100 return:  1.57 /Used Time: 56.43 /Used Step: 24024 | Dis loss :  0.669 | GP loss :  0.1945
N update:  7 | Mean ep 100 return:  1.77 /Used Time: 65.75 /Used Step: 28028 | Dis loss :  0.753 | GP loss :  0.2014
N update:  8 | Mean ep 100 return:  1.88 /Used Time: 75.05 /Used Step: 32032 | Dis loss :  0.7546 | GP loss :  0.1574
N update:  9 | Me

N update:  69 | Mean ep 100 return:  8.99 /Used Time: 626.08 /Used Step: 276276 | Dis loss :  0.7333 | GP loss :  0.1513
N update:  70 | Mean ep 100 return:  8.73 /Used Time: 635.09 /Used Step: 280280 | Dis loss :  0.6195 | GP loss :  0.1391
N update:  71 | Mean ep 100 return:  8.66 /Used Time: 644.13 /Used Step: 284284 | Dis loss :  0.8874 | GP loss :  0.144
N update:  72 | Mean ep 100 return:  8.94 /Used Time: 652.92 /Used Step: 288288 | Dis loss :  0.7037 | GP loss :  0.136
N update:  73 | Mean ep 100 return:  8.87 /Used Time: 661.83 /Used Step: 292292 | Dis loss :  0.6532 | GP loss :  0.1307
N update:  74 | Mean ep 100 return:  9.01 /Used Time: 670.69 /Used Step: 296296 | Dis loss :  0.7187 | GP loss :  0.132
N update:  75 | Mean ep 100 return:  8.92 /Used Time: 679.63 /Used Step: 300300 | Dis loss :  0.5949 | GP loss :  0.1399
N update:  76 | Mean ep 100 return:  8.97 /Used Time: 688.65 /Used Step: 304304 | Dis loss :  0.7229 | GP loss :  0.1277
N update:  77 | Mean ep 100 return:

N update:  137 | Mean ep 100 return:  10.39 /Used Time: 1232.54 /Used Step: 548548 | Dis loss :  0.5809 | GP loss :  0.1093
N update:  138 | Mean ep 100 return:  10.48 /Used Time: 1241.6 /Used Step: 552552 | Dis loss :  0.714 | GP loss :  0.1713
N update:  139 | Mean ep 100 return:  10.55 /Used Time: 1250.62 /Used Step: 556556 | Dis loss :  0.7355 | GP loss :  0.1255
N update:  140 | Mean ep 100 return:  10.65 /Used Time: 1259.61 /Used Step: 560560 | Dis loss :  0.4813 | GP loss :  0.1381
N update:  141 | Mean ep 100 return:  10.85 /Used Time: 1268.5 /Used Step: 564564 | Dis loss :  0.6004 | GP loss :  0.1495
N update:  142 | Mean ep 100 return:  10.92 /Used Time: 1277.57 /Used Step: 568568 | Dis loss :  0.7088 | GP loss :  0.1203
N update:  143 | Mean ep 100 return:  11.09 /Used Time: 1286.51 /Used Step: 572572 | Dis loss :  0.7655 | GP loss :  0.1579
N update:  144 | Mean ep 100 return:  11.27 /Used Time: 1295.51 /Used Step: 576576 | Dis loss :  0.5779 | GP loss :  0.119
N update:  1

N update:  204 | Mean ep 100 return:  16.1 /Used Time: 1826.34 /Used Step: 816816 | Dis loss :  0.5717 | GP loss :  0.1495
N update:  205 | Mean ep 100 return:  16.45 /Used Time: 1835.3 /Used Step: 820820 | Dis loss :  0.539 | GP loss :  0.1128
N update:  206 | Mean ep 100 return:  16.32 /Used Time: 1844.17 /Used Step: 824824 | Dis loss :  0.7377 | GP loss :  0.0787
N update:  207 | Mean ep 100 return:  16.52 /Used Time: 1852.94 /Used Step: 828828 | Dis loss :  0.6032 | GP loss :  0.1595
N update:  208 | Mean ep 100 return:  16.71 /Used Time: 1861.72 /Used Step: 832832 | Dis loss :  0.6077 | GP loss :  0.135
N update:  209 | Mean ep 100 return:  16.8 /Used Time: 1870.54 /Used Step: 836836 | Dis loss :  0.7026 | GP loss :  0.1714
N update:  210 | Mean ep 100 return:  16.65 /Used Time: 1879.39 /Used Step: 840840 | Dis loss :  0.6018 | GP loss :  0.1102
N update:  211 | Mean ep 100 return:  16.65 /Used Time: 1888.09 /Used Step: 844844 | Dis loss :  0.8318 | GP loss :  0.1517
N update:  21

N update:  271 | Mean ep 100 return:  17.6 /Used Time: 2421.3 /Used Step: 1085084 | Dis loss :  0.7504 | GP loss :  0.1237
N update:  272 | Mean ep 100 return:  17.53 /Used Time: 2430.12 /Used Step: 1089088 | Dis loss :  0.7118 | GP loss :  0.1274
N update:  273 | Mean ep 100 return:  17.77 /Used Time: 2439.05 /Used Step: 1093092 | Dis loss :  0.5559 | GP loss :  0.1076
N update:  274 | Mean ep 100 return:  17.91 /Used Time: 2447.89 /Used Step: 1097096 | Dis loss :  0.5095 | GP loss :  0.1167
N update:  275 | Mean ep 100 return:  17.68 /Used Time: 2456.78 /Used Step: 1101100 | Dis loss :  0.7489 | GP loss :  0.1252
N update:  276 | Mean ep 100 return:  17.54 /Used Time: 2465.64 /Used Step: 1105104 | Dis loss :  0.7931 | GP loss :  0.0972
N update:  277 | Mean ep 100 return:  17.2 /Used Time: 2474.46 /Used Step: 1109108 | Dis loss :  0.7548 | GP loss :  0.1758
N update:  278 | Mean ep 100 return:  17.38 /Used Time: 2483.26 /Used Step: 1113112 | Dis loss :  0.4371 | GP loss :  0.1378
N u

N update:  337 | Mean ep 100 return:  15.87 /Used Time: 3003.34 /Used Step: 1349348 | Dis loss :  0.7806 | GP loss :  0.1331
N update:  338 | Mean ep 100 return:  15.64 /Used Time: 3012.19 /Used Step: 1353352 | Dis loss :  0.7657 | GP loss :  0.1248
N update:  339 | Mean ep 100 return:  15.44 /Used Time: 3020.96 /Used Step: 1357356 | Dis loss :  0.5631 | GP loss :  0.1017
N update:  340 | Mean ep 100 return:  15.46 /Used Time: 3029.77 /Used Step: 1361360 | Dis loss :  0.5624 | GP loss :  0.115
N update:  341 | Mean ep 100 return:  15.29 /Used Time: 3038.6 /Used Step: 1365364 | Dis loss :  0.5421 | GP loss :  0.0843
N update:  342 | Mean ep 100 return:  14.87 /Used Time: 3047.47 /Used Step: 1369368 | Dis loss :  0.5997 | GP loss :  0.1482
N update:  343 | Mean ep 100 return:  14.84 /Used Time: 3056.27 /Used Step: 1373372 | Dis loss :  0.6611 | GP loss :  0.0812
N update:  344 | Mean ep 100 return:  15.14 /Used Time: 3065.02 /Used Step: 1377376 | Dis loss :  0.5966 | GP loss :  0.1002
N 

N update:  403 | Mean ep 100 return:  19.92 /Used Time: 3584.94 /Used Step: 1613612 | Dis loss :  0.7632 | GP loss :  0.1007
N update:  404 | Mean ep 100 return:  19.67 /Used Time: 3593.81 /Used Step: 1617616 | Dis loss :  0.5677 | GP loss :  0.1674
N update:  405 | Mean ep 100 return:  19.34 /Used Time: 3602.77 /Used Step: 1621620 | Dis loss :  0.8188 | GP loss :  0.1336
N update:  406 | Mean ep 100 return:  19.28 /Used Time: 3611.61 /Used Step: 1625624 | Dis loss :  0.7066 | GP loss :  0.1166
N update:  407 | Mean ep 100 return:  19.46 /Used Time: 3620.51 /Used Step: 1629628 | Dis loss :  0.56 | GP loss :  0.0999
N update:  408 | Mean ep 100 return:  19.52 /Used Time: 3629.28 /Used Step: 1633632 | Dis loss :  0.6473 | GP loss :  0.1293
N update:  409 | Mean ep 100 return:  19.15 /Used Time: 3638.04 /Used Step: 1637636 | Dis loss :  0.8686 | GP loss :  0.1439
N update:  410 | Mean ep 100 return:  18.86 /Used Time: 3646.92 /Used Step: 1641640 | Dis loss :  0.6284 | GP loss :  0.1184
N 

N update:  469 | Mean ep 100 return:  21.48 /Used Time: 4166.96 /Used Step: 1877876 | Dis loss :  0.7636 | GP loss :  0.0913
N update:  470 | Mean ep 100 return:  21.39 /Used Time: 4175.73 /Used Step: 1881880 | Dis loss :  0.7257 | GP loss :  0.1125
N update:  471 | Mean ep 100 return:  21.26 /Used Time: 4184.47 /Used Step: 1885884 | Dis loss :  0.7075 | GP loss :  0.1285
N update:  472 | Mean ep 100 return:  21.35 /Used Time: 4193.19 /Used Step: 1889888 | Dis loss :  0.6653 | GP loss :  0.169
N update:  473 | Mean ep 100 return:  21.15 /Used Time: 4202.1 /Used Step: 1893892 | Dis loss :  0.534 | GP loss :  0.0684
N update:  474 | Mean ep 100 return:  21.04 /Used Time: 4211.0 /Used Step: 1897896 | Dis loss :  0.6731 | GP loss :  0.126
N update:  475 | Mean ep 100 return:  20.84 /Used Time: 4219.82 /Used Step: 1901900 | Dis loss :  0.8311 | GP loss :  0.0652
N update:  476 | Mean ep 100 return:  20.58 /Used Time: 4228.7 /Used Step: 1905904 | Dis loss :  0.4399 | GP loss :  0.1137
N upda

N update:  535 | Mean ep 100 return:  20.96 /Used Time: 4747.97 /Used Step: 2142140 | Dis loss :  0.6249 | GP loss :  0.1594
N update:  536 | Mean ep 100 return:  21.0 /Used Time: 4756.83 /Used Step: 2146144 | Dis loss :  0.6095 | GP loss :  0.1554
N update:  537 | Mean ep 100 return:  21.12 /Used Time: 4765.5 /Used Step: 2150148 | Dis loss :  0.5659 | GP loss :  0.1299
N update:  538 | Mean ep 100 return:  20.8 /Used Time: 4774.29 /Used Step: 2154152 | Dis loss :  0.7446 | GP loss :  0.134
N update:  539 | Mean ep 100 return:  21.14 /Used Time: 4783.03 /Used Step: 2158156 | Dis loss :  0.7803 | GP loss :  0.1328
N update:  540 | Mean ep 100 return:  21.03 /Used Time: 4791.73 /Used Step: 2162160 | Dis loss :  0.4879 | GP loss :  0.123
N update:  541 | Mean ep 100 return:  20.94 /Used Time: 4800.63 /Used Step: 2166164 | Dis loss :  0.8283 | GP loss :  0.139
N update:  542 | Mean ep 100 return:  21.11 /Used Time: 4809.57 /Used Step: 2170168 | Dis loss :  0.7844 | GP loss :  0.1029
N upda

N update:  601 | Mean ep 100 return:  18.86 /Used Time: 5327.89 /Used Step: 2406404 | Dis loss :  0.5849 | GP loss :  0.1009
N update:  602 | Mean ep 100 return:  18.95 /Used Time: 5336.63 /Used Step: 2410408 | Dis loss :  0.8253 | GP loss :  0.1014
N update:  603 | Mean ep 100 return:  19.21 /Used Time: 5345.53 /Used Step: 2414412 | Dis loss :  0.7026 | GP loss :  0.1289
N update:  604 | Mean ep 100 return:  19.17 /Used Time: 5354.32 /Used Step: 2418416 | Dis loss :  0.5054 | GP loss :  0.1271
N update:  605 | Mean ep 100 return:  19.21 /Used Time: 5363.11 /Used Step: 2422420 | Dis loss :  0.5967 | GP loss :  0.1319
N update:  606 | Mean ep 100 return:  19.22 /Used Time: 5371.73 /Used Step: 2426424 | Dis loss :  0.6074 | GP loss :  0.1115
N update:  607 | Mean ep 100 return:  18.69 /Used Time: 5380.51 /Used Step: 2430428 | Dis loss :  0.4739 | GP loss :  0.0841
N update:  608 | Mean ep 100 return:  18.78 /Used Time: 5389.25 /Used Step: 2434432 | Dis loss :  0.4963 | GP loss :  0.1015


N update:  667 | Mean ep 100 return:  20.15 /Used Time: 5908.1 /Used Step: 2670668 | Dis loss :  0.4774 | GP loss :  0.1441
N update:  668 | Mean ep 100 return:  20.02 /Used Time: 5916.84 /Used Step: 2674672 | Dis loss :  0.838 | GP loss :  0.0918
N update:  669 | Mean ep 100 return:  19.72 /Used Time: 5925.62 /Used Step: 2678676 | Dis loss :  1.1784 | GP loss :  0.099
N update:  670 | Mean ep 100 return:  19.77 /Used Time: 5934.37 /Used Step: 2682680 | Dis loss :  0.8122 | GP loss :  0.1123
N update:  671 | Mean ep 100 return:  20.18 /Used Time: 5943.18 /Used Step: 2686684 | Dis loss :  0.55 | GP loss :  0.123
N update:  672 | Mean ep 100 return:  20.39 /Used Time: 5952.1 /Used Step: 2690688 | Dis loss :  0.5762 | GP loss :  0.0923
N update:  673 | Mean ep 100 return:  20.29 /Used Time: 5960.95 /Used Step: 2694692 | Dis loss :  0.718 | GP loss :  0.1165
N update:  674 | Mean ep 100 return:  20.29 /Used Time: 5969.76 /Used Step: 2698696 | Dis loss :  0.6101 | GP loss :  0.0814
N update

N update:  733 | Mean ep 100 return:  22.12 /Used Time: 6489.32 /Used Step: 2934932 | Dis loss :  0.4709 | GP loss :  0.1268
N update:  734 | Mean ep 100 return:  22.5 /Used Time: 6498.15 /Used Step: 2938936 | Dis loss :  0.5049 | GP loss :  0.1684
N update:  735 | Mean ep 100 return:  22.4 /Used Time: 6506.94 /Used Step: 2942940 | Dis loss :  0.8159 | GP loss :  0.133
N update:  736 | Mean ep 100 return:  22.75 /Used Time: 6515.66 /Used Step: 2946944 | Dis loss :  0.4574 | GP loss :  0.1462
N update:  737 | Mean ep 100 return:  23.17 /Used Time: 6524.66 /Used Step: 2950948 | Dis loss :  0.5003 | GP loss :  0.139
N update:  738 | Mean ep 100 return:  23.31 /Used Time: 6533.55 /Used Step: 2954952 | Dis loss :  0.7243 | GP loss :  0.1208
N update:  739 | Mean ep 100 return:  23.28 /Used Time: 6542.5 /Used Step: 2958956 | Dis loss :  1.3114 | GP loss :  0.1144
N update:  740 | Mean ep 100 return:  23.64 /Used Time: 6551.45 /Used Step: 2962960 | Dis loss :  0.7076 | GP loss :  0.0971
N upd

## 결과 시각화

In [None]:
plt.plot(range(len(result)), result)
plt.tight_layout()
plt.show()

In [None]:
from matplotlib import animation

def display_frames_as_gif(frames):
    patch = plt.imshow(frames[0])
    plt.axis('off')
    def animate(i):
        patch.set_data(frames[i])
        
    anim = animation.FuncAnimation(plt.gcf(), animate, frames = len(frames), interval=5)
    anim.save('./gail_breakout_result.gif', writer='imagemagick', fps=30)

In [None]:
env = wrap(gym.make('BreakoutNoFrameskip-v4'))
s = np.array(env.reset())
total_reward = 0
frames = []

for t in range(10000):
    # Render into buffer. 
    frames.append(env.render(mode = 'rgb_array'))
    a, v, l = ppo.choose_action(np.expand_dims(s,axis=0))
    # take action and get next state
    s_, r, done, info = env.step(a)
    s_ = np.array(s_)
    total_reward += r
    if done:
        break
    s = s_
env.close()
print('Total Reward : %.2f'%total_reward)
display_frames_as_gif(frames)

![alt text](./ppo_pong_result.gif "segment")