In [None]:
import sys; sys.path.insert(0, '..')

import gym
import sys
import torch
import tester
import numpy as np
import matplotlib.pyplot as plt
from IPython.display import clear_output, Javascript

import pytorch_drl.models.actor_critic_models as models
import pytorch_drl.models.gail_models as gail_models

from pytorch_drl.algs.gail import GAIL

%matplotlib inline
%load_ext autoreload
%autoreload 2

In [None]:
# Run on the first CUDA device when PyTorch can see one, otherwise on the CPU.
device_name = "cuda:0" if torch.cuda.is_available() else "cpu"
device = torch.device(device_name)
print("Device:", device)

## 1. Define Utils:

### 1.1 Plotting:

In [None]:
def mean_filter(arr, filter_len):
    """Smooth a 1-D sequence with a trailing moving average.

    The sequence is left-padded with ``filter_len - 1`` copies of its first
    element, so the output has exactly one value per input element and the
    first outputs are not biased towards zero.

    Args:
        arr: 1-D sequence of numbers (list or array).
        filter_len: Window length; must be a positive integer.

    Returns:
        A list of floats with the same length as ``arr``. Element ``i`` is the
        mean of the (padded) window ending at input position ``i``. An empty
        input yields an empty list.

    Raises:
        ValueError: If ``filter_len`` is smaller than 1.
    """
    if filter_len < 1:
        raise ValueError("filter_len must be a positive integer")

    values = np.asarray(arr, dtype=float)
    if values.size == 0:
        # Nothing to pad with (there is no first element) and nothing to smooth.
        return []

    padded = np.concatenate([np.full(filter_len - 1, values[0]), values])
    kernel = np.ones(filter_len) / filter_len
    # "valid" keeps only fully-overlapping windows: len(padded) - filter_len + 1
    # == len(arr) results. The kernel is uniform, so convolution's kernel flip
    # has no effect on the result.
    return list(np.convolve(padded, kernel, mode="valid"))

def plot(scores, n=None):
    """Replace the cell's current output with a fresh score-vs-episode plot.

    Args:
        scores: Sequence of scores, one per x position.
        n: Optional window length. When given, the scores are smoothed with
            ``mean_filter(scores, n)`` before plotting; when ``None`` the raw
            scores are drawn.
    """
    curve = scores if n is None else mean_filter(scores, n)

    # wait=True: the previous output is only cleared once the new one is ready.
    clear_output(True)

    fig, ax = plt.subplots()
    ax.plot(np.arange(len(curve)), curve)
    ax.set_ylabel('Score')
    ax.set_xlabel('Episode')
    plt.show()

## 2. Create environment

In [None]:
env_name = "CartPole-v0"
env = gym.make(env_name)

# Seed every random number generator used in this notebook, not only the
# environment: the numpy and torch global generators are otherwise left
# unseeded, which makes a fresh Restart & Run All non-reproducible.
SEED = 0
env.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)

# First dimension of the observation shape and the number of discrete actions;
# both are used below to size the models.
state_size =  env.observation_space.shape[0]
action_size = env.action_space.n

print("State size:", state_size, "\nAction size:", action_size)

## 4. GAIL Test

In [None]:
# Expert demonstrations are stored per environment under experts/.
expert_path = "experts/ppo_{}".format(env_name)
expert_trajectories = np.load(expert_path)

In [None]:
# Sanity check: show the dimensions of the loaded expert data before training.
print(expert_trajectories.shape)

In [None]:
# Actor-critic MLP sized from the environment's state and action dimensions.
actor_critic = models.ActorCriticMLP(state_size, action_size, env.action_space)

# GAIL discriminator built from the same state and action sizes.
discriminator = gail_models.GAILDiscriminator(state_size, action_size)

# Training hyperparameters. tmax and n_traj are passed to agent.train() below;
# n_env, ppo_epochs and batch_size are passed to the GAIL constructor.
# NOTE(review): tmax is presumably the rollout length per update and n_traj the
# number of rollouts — confirm against GAIL.train.
tmax = 5
n_traj = 2000
n_env = 8
ppo_epochs = 4
batch_size = 32

# Build the GAIL agent from the two models, the expert data and the
# hyperparameters above.
# NOTE(review): the meaning of the keyword arguments (e.g. tau, critic_coef)
# is defined by pytorch_drl.algs.gail.GAIL — see its signature for details.
agent = GAIL(actor_critic,
             discriminator,
             expert_trajectories,
             env_name,
             action_size,
             gamma=0.99, 
             gail_epochs=1,
             ppo_epochs=ppo_epochs,
             lr_ppo=2e-3, 
             lr_discriminator=3e-3,
             tau=0.95,
             n_env=n_env,
             device=device,
             max_grad_norm=0.5,
             critic_coef=0.5,
             entropy_coef=0.01,
             mini_batch_size=batch_size,
             )
   
    
# train the agent
# NOTE(review): 195 is presumably the target score passed to train() as its
# stopping threshold — confirm against GAIL.train.
max_score = 195.
# Name the run after the algorithm and the environment. The "{}" placeholder
# is required: without it str.format() silently discards env_name and every
# environment would end up with the same name, "gail_".
alg_name = "gail_{}".format(env_name)
scores, losses = agent.train(tmax, n_traj,  env, max_score, alg_name)

# plot the training:
# Build the x axis before smoothing; mean_filter returns one value per input
# element, so x and the smoothed scores keep the same length.
x = np.arange(len(scores))
scores = mean_filter(scores, 50)
plt.plot(x, scores, label = "scores")
plt.show()

### 4.1 Trained Agent Demonstration

In [None]:
# Run the trained agent in the environment with rendering enabled.
# NOTE(review): n_times=5 is presumably the number of evaluation episodes —
# confirm against GAIL.test.
agent.test(env, render=True, n_times=5)