In [61]:
import sys; sys.path.insert(0, '..')

import gym
import sys
import torch
import tester
import numpy as np
import matplotlib.pyplot as plt
from IPython.display import clear_output, Javascript

import pytorch_drl.models.ppo_models as models

from pytorch_drl.algs.trpo import TRPO

%matplotlib inline
%load_ext autoreload
%autoreload 2

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [62]:
# Use the first CUDA GPU when torch can see one; otherwise run on the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda:0")
else:
    device = torch.device("cpu")
print("Device:", device)

Device: cuda:0


## 1. Define Utils:

### 1.1 Plotting:

In [63]:
def mean_filter(arr, filter_len):
    """Smooth a 1-D sequence with a causal moving average.

    The input is left-padded with ``filter_len - 1`` zeros, so the output has
    exactly one value per input element. As a consequence the first
    ``filter_len - 1`` outputs are averaged together with those zeros and are
    biased toward zero.

    Args:
        arr: 1-D sequence (list or array) of numbers to smooth.
        filter_len: Window size in samples; must be at least 1.

    Returns:
        A list of numpy floats with the same length as ``arr``.

    Raises:
        ValueError: If ``filter_len`` is smaller than 1.
    """
    # A zero-length window used to "work" by returning len(arr) + 1 zeros;
    # fail loudly instead of producing a misleading curve.
    if filter_len < 1:
        raise ValueError("filter_len must be >= 1, got {}".format(filter_len))

    padded = np.concatenate([[0] * (filter_len - 1), arr])
    kernel = np.ones(filter_len) / filter_len

    # One window per original element; window `start` covers
    # padded[start : start + filter_len].
    n_windows = len(padded) - filter_len + 1
    return [
        np.sum(kernel * padded[start:start + filter_len])
        for start in range(n_windows)
    ]

def plot(scores, n=None):
    """Replace the cell's current output with a fresh score curve.

    Args:
        scores: Sequence of scores to draw, indexed from 0.
        n: Optional window length; when given, the scores are first passed
            through ``mean_filter(scores, n)``.
    """
    curve = scores if n is None else mean_filter(scores, n)

    # wait=True: the previous output is only cleared once the new figure is
    # ready to be shown.
    clear_output(True)

    figure = plt.figure()
    axes = figure.add_subplot(111)
    axes.plot(np.arange(len(curve)), curve)
    axes.set_ylabel('Score')
    axes.set_xlabel('Episode')
    plt.show()

## 2. Create environment

In [64]:
# Environment to train on. Swap in "LunarLander-v2" here to try that task
# instead (it used to be assigned and immediately overwritten).
env_name = "CartPole-v0"
env = gym.make(env_name)

# Seed every global RNG in play, not only the environment, so that a fresh
# Restart & Run All is repeatable.
# NOTE(review): assumes the agent draws its randomness from the numpy / torch
# global generators — confirm against pytorch_drl.
SEED = 0
env.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)

# Sizes read from the environment's observation and action spaces.
state_size = env.observation_space.shape[0]
action_size = env.action_space.n

print("State size:", state_size, "\nAction size:", action_size)

State size: 4 
Action size: 2


## 3. Define the actor and critic networks

In [121]:
# Models: each network class is paired with the positional arguments that
# will be used to instantiate it.
actor_constructor, actor_args = models.ActorNetwork, (state_size, action_size)
critic_constructor, critic_args = models.CriticNetwork, (state_size, action_size)


## 4. TRPO Test

In [None]:
# Training budget passed to agent.train(); the exact meaning of each value is
# defined by TRPO.train (not visible in this notebook).
tmax = 1000
n_traj = 2000
# Presumably the number of parallel environments. This variable used to be set
# to 1 but was ignored while a literal 8 was passed to TRPO; 8 is kept because
# that is the value that actually took effect.
n_env = 8

# actor_constructor / actor_args / critic_constructor / critic_args are
# defined in section 3 above and reused here instead of being redefined.
agent = TRPO(actor_constructor,
             actor_args,
             critic_constructor,
             critic_args,
             critic_use_bfgs=False,
             critic_lr_sgd=1e-3,
             critic_lr_bfgs=0.1,
             critic_reg_bfgs=1e-3,
             max_kl=1e-2,
             backtrack_alpha=0.5,
             backtrack_steps=10,
             damping_coeff=0.1,
             env_name=env_name,
             gamma=0.99,
             tau=0.97,
             n_env=n_env,
             device=device,
             normalize_rewards=True,
             max_grad_norm=None,
            )

# train the agent
scores, losses = agent.train(tmax, n_traj, env)

# plot the training: the raw `scores` are left untouched; the smoothed copy
# gets its own name so this part can be re-run without smoothing twice.
smoothing_window = 50
scores_smoothed = mean_filter(scores, smoothing_window)

fig, ax = plt.subplots()
ax.plot(np.arange(len(scores_smoothed)), scores_smoothed, label="scores")
ax.set_xlabel("Episode")
ax.set_ylabel("Score")
ax.legend()  # the label above is only displayed once a legend is drawn
plt.show()

# Debug notes kept from an earlier run. They were a bare string literal at the
# end of the cell, which Jupyter would echo as the cell's Out value.
#
# tensor([ 7.6780e-05,  3.0375e-04, -5.4383e-05,  ..., -1.2598e-03,
#          7.4972e-02, -7.4972e-02]
#
# tensor([-6.1298e-05,  7.2905e-05,  9.5473e-05,  ...,  6.8445e-04,
#          1.0478e-01, -1.0478e-01])

Ep: 0; Score: 20.0, Loss: 245.16111755371094
Ep: 4; Score: 14.5, Loss: 259.6360168457031
Ep: 8; Score: 12.666666666666666, Loss: 226.7269744873047
Ep: 12; Score: 12.0, Loss: 191.65011596679688
Ep: 16; Score: 11.6, Loss: 163.02273559570312
Ep: 20; Score: 11.0, Loss: 140.6450958251953
Ep: 24; Score: 10.714285714285714, Loss: 107.5106201171875
Ep: 28; Score: 10.625, Loss: 84.97537994384766
Ep: 32; Score: 10.444444444444445, Loss: 70.53718566894531
Ep: 36; Score: 10.4, Loss: 61.9267463684082
Ep: 40; Score: 10.363636363636363, Loss: 52.350830078125
Ep: 44; Score: 10.25, Loss: 36.73377227783203
Ep: 48; Score: 10.076923076923077, Loss: 25.562862396240234
Ep: 52; Score: 10.071428571428571, Loss: 20.513202667236328
Ep: 56; Score: 10.066666666666666, Loss: 17.730745315551758
Ep: 60; Score: 10.0625, Loss: 15.272039413452148
Ep: 64; Score: 10.0, Loss: 13.624656677246094
Ep: 68; Score: 10.0, Loss: 12.221139907836914
Ep: 72; Score: 9.894736842105264, Loss: 10.931342124938965
Ep: 76; Score: 9.85, Los

### 4.1 Trained Agent Demonstration

In [9]:
# Roll out the trained agent with rendering and logging switched on.
# NOTE(review): render=True presumably needs a display — confirm before
# running this cell on a headless machine.
tester.test_agent(
    agent,
    env,
    max_t=200,
    render=True,
    num_of_episodes=5,
    log=True,
)

200.0
200.0
200.0
200.0
200.0
