### Import packages

In [None]:
!pip install -q numpy
!pip install -q matplotlib
!pip install -q mujoco
!command -v ffmpeg >/dev/null || (apt update && apt install -y ffmpeg)
!pip install -q mediapy

%env MUJOCO_GL=egl

[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m4.3/4.3 MB[0m [31m33.9 MB/s[0m eta [36m0:00:00[0m
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m207.8/207.8 kB[0m [31m20.6 MB/s[0m eta [36m0:00:00[0m
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m1.6/1.6 MB[0m [31m52.2 MB/s[0m eta [36m0:00:00[0m
[?25henv: MUJOCO_GL=egl


In [None]:
import os
import numpy as np
import mediapy as media
import matplotlib.pyplot as plt
import mujoco

import scipy.signal
import time

import torch
import torch.nn as nn
from torch.optim import Adam
from torch.distributions.normal import Normal


### Define MuJoCo Environment

In [None]:
# MJCF (MuJoCo XML) description of the planar "cheetah" model that
# HalfCheetahEnv loads with mujoco.MjModel.from_xml_string.
# It defines a torso with three unactuated root joints (rootx, rootz, rooty)
# and six motor-driven hinge joints (back/front thigh, shin and foot).
# The physics timestep is 0.01 s and every motor's control range is [-1, 1].
# NOTE: the leading <!-- ... --> block is part of the string itself.
xml_string="""<!-- Cheetah Model

    The state space is populated with joints in the order that they are
    defined in this file. The actuators also operate on joints.

    State-Space (name/joint/parameter):
        - rootx     slider      position (m)
        - rootz     slider      position (m)
        - rooty     hinge       angle (rad)
        - bthigh    hinge       angle (rad)
        - bshin     hinge       angle (rad)
        - bfoot     hinge       angle (rad)
        - fthigh    hinge       angle (rad)
        - fshin     hinge       angle (rad)
        - ffoot     hinge       angle (rad)
        - rootx     slider      velocity (m/s)
        - rootz     slider      velocity (m/s)
        - rooty     hinge       angular velocity (rad/s)
        - bthigh    hinge       angular velocity (rad/s)
        - bshin     hinge       angular velocity (rad/s)
        - bfoot     hinge       angular velocity (rad/s)
        - fthigh    hinge       angular velocity (rad/s)
        - fshin     hinge       angular velocity (rad/s)
        - ffoot     hinge       angular velocity (rad/s)

    Actuators (name/actuator/parameter):
        - bthigh    hinge       torque (N m)
        - bshin     hinge       torque (N m)
        - bfoot     hinge       torque (N m)
        - fthigh    hinge       torque (N m)
        - fshin     hinge       torque (N m)
        - ffoot     hinge       torque (N m)

-->
<mujoco model="cheetah">
  <compiler angle="radian" coordinate="local" inertiafromgeom="true" settotalmass="14"/>
  <default>
    <joint armature=".1" damping=".01" limited="true" solimplimit="0 .8 .03" solreflimit=".02 1" stiffness="8"/>
    <geom conaffinity="0" condim="3" contype="1" friction=".4 .1 .1" rgba="0.8 0.6 .4 1" solimp="0.0 0.8 0.01" solref="0.02 1"/>
    <motor ctrllimited="true" ctrlrange="-1 1"/>
  </default>
  <size nstack="300000" nuser_geom="1"/>
  <option gravity="0 0 -9.81" timestep="0.01"/>
  <asset>
    <texture builtin="gradient" height="100" rgb1="1 1 1" rgb2="0 0 0" type="skybox" width="100"/>
    <texture builtin="flat" height="1278" mark="cross" markrgb="1 1 1" name="texgeom" random="0.01" rgb1="0.8 0.6 0.4" rgb2="0.8 0.6 0.4" type="cube" width="127"/>
    <texture builtin="checker" height="100" name="texplane" rgb1="0 0 0" rgb2="0.8 0.8 0.8" type="2d" width="100"/>
    <material name="MatPlane" reflectance="0.5" shininess="1" specular="1" texrepeat="60 60" texture="texplane"/>
    <material name="geom" texture="texgeom" texuniform="true"/>
  </asset>
  <worldbody>
    <light cutoff="100" diffuse="1 1 1" dir="-0 0 -1.3" directional="true" exponent="1" pos="0 0 1.3" specular=".1 .1 .1"/>
    <geom conaffinity="1" condim="3" material="MatPlane" name="floor" pos="0 0 0" rgba="0.8 0.9 0.8 1" size="40 40 40" type="plane"/>
    <body name="torso" pos="0 0 .7">
      <camera name="track" mode="trackcom" pos="0 -3 0.3" xyaxes="1 0 0 0 0 1"/>
      <joint armature="0" axis="1 0 0" damping="0" limited="false" name="rootx" pos="0 0 0" stiffness="0" type="slide"/>
      <joint armature="0" axis="0 0 1" damping="0" limited="false" name="rootz" pos="0 0 0" stiffness="0" type="slide"/>
      <joint armature="0" axis="0 1 0" damping="0" limited="false" name="rooty" pos="0 0 0" stiffness="0" type="hinge"/>
      <geom fromto="-.5 0 0 .5 0 0" name="torso" size="0.046" type="capsule"/>
      <geom axisangle="0 1 0 .87" name="head" pos=".6 0 .1" size="0.046 .15" type="capsule"/>
      <!-- <site name='tip'  pos='.15 0 .11'/>-->
      <body name="bthigh" pos="-.5 0 0">
        <joint axis="0 1 0" damping="6" name="bthigh" pos="0 0 0" range="-.52 1.05" stiffness="240" type="hinge"/>
        <geom axisangle="0 1 0 -3.8" name="bthigh" pos=".1 0 -.13" size="0.046 .145" type="capsule"/>
        <body name="bshin" pos=".16 0 -.25">
          <joint axis="0 1 0" damping="4.5" name="bshin" pos="0 0 0" range="-.785 .785" stiffness="180" type="hinge"/>
          <geom axisangle="0 1 0 -2.03" name="bshin" pos="-.14 0 -.07" rgba="0.9 0.6 0.6 1" size="0.046 .15" type="capsule"/>
          <body name="bfoot" pos="-.28 0 -.14">
            <joint axis="0 1 0" damping="3" name="bfoot" pos="0 0 0" range="-.4 .785" stiffness="120" type="hinge"/>
            <geom axisangle="0 1 0 -.27" name="bfoot" pos=".03 0 -.097" rgba="0.9 0.6 0.6 1" size="0.046 .094" type="capsule"/>
          </body>
        </body>
      </body>
      <body name="fthigh" pos=".5 0 0">
        <joint axis="0 1 0" damping="4.5" name="fthigh" pos="0 0 0" range="-1 .7" stiffness="180" type="hinge"/>
        <geom axisangle="0 1 0 .52" name="fthigh" pos="-.07 0 -.12" size="0.046 .133" type="capsule"/>
        <body name="fshin" pos="-.14 0 -.24">
          <joint axis="0 1 0" damping="3" name="fshin" pos="0 0 0" range="-1.2 .87" stiffness="120" type="hinge"/>
          <geom axisangle="0 1 0 -.6" name="fshin" pos=".065 0 -.09" rgba="0.9 0.6 0.6 1" size="0.046 .106" type="capsule"/>
          <body name="ffoot" pos=".13 0 -.18">
            <joint axis="0 1 0" damping="1.5" name="ffoot" pos="0 0 0" range="-.5 .5" stiffness="60" type="hinge"/>
            <geom axisangle="0 1 0 -.6" name="ffoot" pos=".045 0 -.07" rgba="0.9 0.6 0.6 1" size="0.046 .07" type="capsule"/>
          </body>
        </body>
      </body>
    </body>
  </worldbody>
  <actuator>
    <motor gear="120" joint="bthigh" name="bthigh"/>
    <motor gear="90" joint="bshin" name="bshin"/>
    <motor gear="60" joint="bfoot" name="bfoot"/>
    <motor gear="120" joint="fthigh" name="fthigh"/>
    <motor gear="60" joint="fshin" name="fshin"/>
    <motor gear="30" joint="ffoot" name="ffoot"/>
  </actuator>
</mujoco>"""

In [None]:
class HalfCheetahEnv():
  """Half-cheetah locomotion environment built directly on the MuJoCo bindings.

  The model is compiled from the module-level ``xml_string``. Each ``step``
  holds one action for ``frame_skip`` physics steps and returns a reward equal
  to the weighted forward (x) velocity minus a weighted quadratic control
  cost. The environment never terminates on its own (``terminated`` is always
  ``False``); callers are expected to impose their own episode length.

  Args:
    frame_skip: number of physics steps simulated per ``step`` call.
    forward_reward_weight: multiplier applied to the forward velocity.
    ctrl_cost_weight: multiplier applied to the sum of squared actions.
    reset_noise_scale: size of the random perturbation ``reset`` adds to the
      initial joint positions (uniform) and velocities (Gaussian).
  """

  def __init__(
      self,
      frame_skip=5,
      forward_reward_weight=1.0,
      ctrl_cost_weight=0.1,
      reset_noise_scale=0.1
      ):

    self.frame_skip = frame_skip
    self.forward_reward_weight = forward_reward_weight
    self.ctrl_cost_weight = ctrl_cost_weight
    self.reset_noise_scale = reset_noise_scale

    self.initialize_simulation()
    # Snapshot of the freshly reset state; reset() perturbs around it.
    self.init_qpos = self.data.qpos.ravel().copy()
    self.init_qvel = self.data.qvel.ravel().copy()
    # Simulated time covered by one call to step().
    self.dt = self.model.opt.timestep * self.frame_skip

    # 8 positions (qpos without root x) + 9 velocities, see get_obs().
    self.observation_dim = 17
    self.action_dim = 6
    self.action_limit = 1.

  def initialize_simulation(self):
    """Compile the model, allocate its data and create the off-screen renderer."""
    self.model = mujoco.MjModel.from_xml_string(xml_string)
    self.data = mujoco.MjData(self.model)
    mujoco.mj_resetData(self.model, self.data)
    self.renderer = mujoco.Renderer(self.model)

  def reset_simulation(self):
    """Reset the MuJoCo data to the model's default state."""
    mujoco.mj_resetData(self.model, self.data)

  def step_mujoco_simulation(self, ctrl, n_frames):
    """Write ``ctrl`` to the actuators and advance physics by ``n_frames`` steps.

    The render scene is intentionally not refreshed here: doing so on every
    physics step is wasted work during training. ``render`` refreshes it on
    demand instead.
    """
    self.data.ctrl[:] = ctrl
    mujoco.mj_step(self.model, self.data, nstep=n_frames)

  def set_state(self, qpos, qvel):
    """Overwrite joint positions/velocities and recompute derived quantities."""
    self.data.qpos[:] = np.copy(qpos)
    self.data.qvel[:] = np.copy(qvel)
    mujoco.mj_forward(self.model, self.data)

  def sample_action(self):
    """Return a random action drawn uniformly from [-action_limit, action_limit)."""
    return (2.*np.random.uniform(size=(self.action_dim,)) - 1)*self.action_limit

  def step(self, action):
    """Apply ``action`` for ``frame_skip`` physics steps.

    Returns:
      A tuple ``(observation, reward, terminated, info)`` where ``terminated``
      is always ``False`` and ``info`` holds the x position, x velocity and
      the two reward components.
    """
    x_position_before = self.data.qpos[0]
    self.step_mujoco_simulation(action, self.frame_skip)
    x_position_after = self.data.qpos[0]
    # Average forward velocity over the whole control step.
    x_velocity = (x_position_after - x_position_before) / self.dt

    # Rewards
    ctrl_cost = self.ctrl_cost_weight * np.sum(np.square(action))
    forward_reward = self.forward_reward_weight * x_velocity
    observation = self.get_obs()
    reward = forward_reward - ctrl_cost
    terminated = False
    info = {
        "x_position": x_position_after,
        "x_velocity": x_velocity,
        "reward_run": forward_reward,
        "reward_ctrl": -ctrl_cost,
    }
    return observation, reward, terminated, info

  def get_obs(self):
    """Return the flat observation: ``qpos[1:]`` followed by ``qvel``.

    The first position entry (root x) is dropped from the observation.
    """
    position = self.data.qpos.flat.copy()
    velocity = self.data.qvel.flat.copy()
    position = position[1:]

    observation = np.concatenate((position, velocity)).ravel()
    return observation

  def render(self):
    """Render the current simulation state from camera 0 and return the image.

    The scene is refreshed here so the frame always reflects the latest
    state, including immediately after ``reset``.
    """
    self.renderer.update_scene(self.data, 0)
    return self.renderer.render()

  def reset(self):
    """Reset to a randomly perturbed initial state and return its observation."""
    self.reset_simulation()
    noise_low = -self.reset_noise_scale
    noise_high = self.reset_noise_scale
    qpos = self.init_qpos + np.random.uniform(
        low=noise_low, high=noise_high, size=self.model.nq
    )
    qvel = (
        self.init_qvel
        + self.reset_noise_scale * np.random.standard_normal(self.model.nv)
    )
    self.set_state(qpos, qvel)
    observation = self.get_obs()
    return observation

In [None]:
# Pick the compute device: the first CUDA GPU when one is visible, else the CPU.
if torch.cuda.is_available():
    device = torch.device('cuda:0')
    torch.cuda.empty_cache()
    print("Device set to : " + str(torch.cuda.get_device_name(device)))
else:
    device = torch.device('cpu')
    print("Device set to : cpu")

Device set to : cpu


# Proximal Policy Optimization

### Buffer 클래스: Buffer 클래스는 에피소드를 저장하고 가공하는 역할

- 저장 해야 할 것들
  - state, $S_{t}$
  - action, $A_{t}$
  - reward, $R_{t+1}$
  - value: $V_{\phi}(S_{t})$
  - log policy distribution (old): $\log(\pi_{\theta}(A_{t}|S_{t}))$

- Buffer 클래스에서 계산해줘야 하는 것들
  - advantage:

    $\delta_{t}=R_{t+1}+\gamma V_{\phi}(S_{t+1}) - V_{\phi}(S_{t})$

    $A_{t}^{(\lambda)}=\delta_{t}+\gamma \lambda A_{t+1}^{(\lambda)}$
  - return:

    $G_{t} = R_{t+1} + \gamma G_{t+1}$

- PPO는 On-policy 알고리즘으로 Policy의 업데이트 이후 Buffer 초기화 작업 필요
  - Buffer의 크기 = Transition 수집 횟수
  - 정해진 횟수만큼 Transition을 모두 수집하면 이전에 수집된 데이터가 모두 새 데이터로 덮어써짐. 별도의 Refresh 작업이 필요하지 않음

In [None]:
class PPOBuffer:
    """Fixed-size store for one epoch of on-policy experience.

    Holds observations, actions, rewards, value estimates and log-probabilities
    for ``size`` transitions. ``finish_path`` turns the rewards/values of the
    trajectory collected since the previous call into GAE-Lambda advantages
    and rewards-to-go. ``get`` returns everything as torch tensors and rewinds
    the write position so the next epoch overwrites the old data.

    Args:
        obs_dim: observation size (scalar or shape tuple).
        act_dim: action size (scalar or shape tuple).
        size: number of transitions the buffer holds.
        gamma: discount factor.
        lam: GAE-Lambda smoothing factor.
    """

    def __init__(self, obs_dim, act_dim, size, gamma=0.99, lam=0.95):
        self.obs_buf = np.zeros(combined_shape(size, obs_dim), dtype=np.float32)
        self.act_buf = np.zeros(combined_shape(size, act_dim), dtype=np.float32)
        self.adv_buf = np.zeros(size, dtype=np.float32)
        self.rew_buf = np.zeros(size, dtype=np.float32)
        self.ret_buf = np.zeros(size, dtype=np.float32)
        self.val_buf = np.zeros(size, dtype=np.float32)
        self.logp_buf = np.zeros(size, dtype=np.float32)
        self.gamma, self.lam = gamma, lam
        # idx: next write position; path_start_idx: start of the open trajectory.
        self.idx, self.path_start_idx, self.max_size = 0, 0, size

    def store(self, obs, act, rew, val, logp):
        """Append one transition.

        Raises:
            IndexError: if the buffer already holds ``max_size`` transitions.
        """
        if self.idx >= self.max_size:
            raise IndexError(
                "PPOBuffer is full (%d transitions); call get() before storing more."
                % self.max_size
            )
        self.obs_buf[self.idx] = obs
        self.act_buf[self.idx] = act
        self.rew_buf[self.idx] = rew
        self.val_buf[self.idx] = val
        self.logp_buf[self.idx] = logp
        self.idx += 1

    def finish_path(self, last_val=0):
        """Close the current trajectory and fill in its advantages and returns.

        Args:
            last_val: value used to bootstrap beyond the last stored step
                (0 when the trajectory truly ended).
        """
        path_slice = slice(self.path_start_idx, self.idx)
        rews = np.append(self.rew_buf[path_slice], last_val)
        vals = np.append(self.val_buf[path_slice], last_val)

        # the next two lines implement GAE-Lambda advantage calculation
        deltas = rews[:-1] + self.gamma * vals[1:] - vals[:-1]
        self.adv_buf[path_slice] = discount_cumsum(deltas, self.gamma * self.lam)

        # the next line computes rewards-to-go, to be targets for the value function
        self.ret_buf[path_slice] = discount_cumsum(rews, self.gamma)[:-1]

        self.path_start_idx = self.idx

    def get(self):
        """Return all stored data as float32 tensors and rewind the buffer.

        Advantages are normalised to zero mean and unit standard deviation.
        A tiny epsilon in the denominator keeps the result finite when every
        advantage is identical.

        Returns:
            dict with keys ``obs``, ``act``, ``ret``, ``adv`` and ``logp``.

        Raises:
            ValueError: if the buffer is not completely filled; otherwise
                stale entries would be mixed into the batch.
        """
        if self.idx != self.max_size:
            raise ValueError(
                "PPOBuffer.get() requires a full buffer (%d/%d transitions stored)."
                % (self.idx, self.max_size)
            )
        self.idx, self.path_start_idx = 0, 0
        # the next two lines implement the advantage normalization trick
        adv_mean, adv_std = np.mean(self.adv_buf), np.std(self.adv_buf)
        self.adv_buf = (self.adv_buf - adv_mean) / (adv_std + np.float32(1e-8))
        data = dict(obs=self.obs_buf, act=self.act_buf, ret=self.ret_buf,
                    adv=self.adv_buf, logp=self.logp_buf)
        return {k: torch.as_tensor(v, dtype=torch.float32) for k,v in data.items()}

### Policy Network and Value Network
- Policy network: Observation이 입력, $\mu_{t}$를 출력으로 하는 네트워크 설계
  - $\mu_{t}=\pi_{\theta}(s)$
  - $\sigma = \exp(\text{log\_std})$ (log_std는 state와 무관한 학습 가능한 파라미터)
- Value network: Observation이 입력, 1차원 Value를 출력하는 네트워크 설계

In [None]:
# Utility functions
def combined_shape(length, shape=None):
    """Build the shape tuple ``(length, *shape)``.

    ``shape`` may be omitted (result is ``(length,)``), a scalar, or any
    iterable of dimensions.
    """
    if shape is None:
        trailing = ()
    elif np.isscalar(shape):
        trailing = (shape,)
    else:
        trailing = tuple(shape)
    return (length,) + trailing

def count_vars(module):
    """Return the total number of scalar entries across ``module.parameters()``."""
    total = 0
    for param in module.parameters():
        total += np.prod(param.shape)
    return total

def discount_cumsum(x, discount):
    """Discounted cumulative sum along axis 0.

    For input ``[x0, x1, x2]`` the result is
    ``[x0 + d*x1 + d^2*x2, x1 + d*x2, x2]`` with ``d = discount``.
    """
    reversed_x = x[::-1]
    # A first-order IIR filter over the reversed sequence accumulates
    # y[n] = x[n] + discount * y[n-1].
    accumulated = scipy.signal.lfilter([1], [1, float(-discount)], reversed_x, axis=0)
    return accumulated[::-1]

PyTorch 네트워크 클래스 정의

In [None]:
def mlp(sizes, activation, output_activation=nn.Identity):
    """Build a fully connected network as an ``nn.Sequential``.

    Consecutive entries of ``sizes`` define one ``nn.Linear`` layer each.
    Every layer is followed by ``activation()``, except the last one, which
    is followed by ``output_activation()``.
    """
    last_layer = len(sizes) - 2
    modules = []
    for idx, (n_in, n_out) in enumerate(zip(sizes[:-1], sizes[1:])):
        modules.append(nn.Linear(n_in, n_out))
        act_cls = output_activation if idx >= last_layer else activation
        modules.append(act_cls())
    return nn.Sequential(*modules)

class MLPGaussianActor(nn.Module):
    """Diagonal-Gaussian policy: an MLP outputs the mean, and a learnable
    observation-independent ``log_std`` vector gives the standard deviation."""

    def __init__(self, obs_dim, act_dim, hidden_sizes, activation):
        super().__init__()
        # Every action dimension starts with log-std -0.5.
        initial_log_std = np.full(act_dim, -0.5, dtype=np.float32)
        self.log_std = torch.nn.Parameter(torch.as_tensor(initial_log_std))
        layer_sizes = [obs_dim, *hidden_sizes, act_dim]
        self.mu_net = mlp(layer_sizes, activation)

    def _distribution(self, obs):
        """Return the Normal action distribution for ``obs``."""
        return Normal(self.mu_net(obs), torch.exp(self.log_std))

    def _log_prob_from_distribution(self, pi, act):
        """Joint log-probability of ``act`` under ``pi``."""
        per_dimension = pi.log_prob(act)
        # Normal scores each action dimension separately; sum over the last axis.
        return per_dimension.sum(axis=-1)

    def _get_mode(self, obs):
        """Return the distribution mean (the deterministic action)."""
        return self.mu_net(obs)

    def forward(self, obs, act=None):
        """Return ``(distribution, log_prob)``; ``log_prob`` is ``None`` when
        no action is supplied."""
        dist = self._distribution(obs)
        if act is None:
            return dist, None
        return dist, self._log_prob_from_distribution(dist, act)


class MLPCritic(nn.Module):
    """State-value network: an MLP that maps an observation to a single value."""

    def __init__(self, obs_dim, hidden_sizes, activation):
        super().__init__()
        layer_sizes = [obs_dim, *hidden_sizes, 1]
        self.v_net = mlp(layer_sizes, activation)

    def forward(self, obs):
        """Return the value estimate with the trailing size-1 axis removed."""
        values = self.v_net(obs)
        # Drop the last axis so the result lines up with per-step targets.
        return values.squeeze(-1)

class MLPActorCritic(nn.Module):
    """Bundle of a Gaussian MLP policy (``pi``) and an MLP value function (``v``).

    Args:
        obs_dim: observation size.
        act_dim: action size.
        hidden_sizes: hidden-layer widths shared by both networks.
        activation: activation class used between hidden layers.
    """

    def __init__(self, obs_dim, act_dim,
                 hidden_sizes=(64,64), activation=nn.Tanh):
        super().__init__()

        # policy builder depends on action space
        self.pi = MLPGaussianActor(obs_dim, act_dim, hidden_sizes, activation)

        # build value function
        self.v  = MLPCritic(obs_dim, hidden_sizes, activation)

    def step(self, obs):
        """Sample an action for ``obs`` without tracking gradients.

        Returns:
            Tuple of numpy arrays ``(action, value, log_prob)``.
        """
        with torch.no_grad():
            pi = self.pi._distribution(obs)
            a = pi.sample()
            logp_a = self.pi._log_prob_from_distribution(pi, a)
            v = self.v(obs)
        # .cpu() is a no-op for CPU tensors and makes .numpy() valid otherwise.
        return a.cpu().numpy(), v.cpu().numpy(), logp_a.cpu().numpy()

    def act(self, obs):
        """Return the deterministic (mean) action for ``obs`` as a numpy array.

        Runs under ``torch.no_grad()`` so it also works when the caller has
        not disabled gradient tracking; ``.numpy()`` cannot be called on a
        tensor that requires grad.
        """
        with torch.no_grad():
            mode = self.pi._get_mode(obs)
        return mode.cpu().numpy()

Environment 생성, Buffer 생성.

In [None]:
# Rollout settings: transitions per epoch, discount factor, GAE lambda.
steps_per_epoch, gamma, lam = 4000, 0.99, 0.97

# Build the environment and read its observation/action sizes.
env = HalfCheetahEnv()
obs_dim, act_dim = env.observation_dim, env.action_dim

# Experience buffer sized to hold exactly one epoch of transitions.
buf = PPOBuffer(obs_dim, act_dim, size=steps_per_epoch, gamma=gamma, lam=lam)

Actor Critic 네트워크 생성 및 Optimizer 초기화.

In [None]:
# Network width and the two learning rates (policy / value function).
hidden_sizes = [64, 64]
pi_lr, vf_lr = 3e-4, 1e-3

# Actor-critic holding the policy (ac.pi) and the value function (ac.v).
ac = MLPActorCritic(env.observation_dim, env.action_dim, hidden_sizes)

# Report the parameter count of each network.
var_counts = (count_vars(ac.pi), count_vars(ac.v))
print('\nNumber of parameters: \t pi: %d, \t v: %d\n'%var_counts)

# One optimizer per network so each can use its own learning rate.
pi_optimizer = Adam(ac.pi.parameters(), lr=pi_lr)
vf_optimizer = Adam(ac.v.parameters(), lr=vf_lr)


Number of parameters: 	 pi: 5708, 	 v: 5377



에피소드 수집.
  - Environment 초기화. 초기 상태 $s_0$
  - Policy 로부터 다음 정보 획득: $a_{t}, v(s_{t}), \log(\pi(a_{t}|s_{t}))$
    - $a_{t} = \mu_{t} + \sigma_{t}\epsilon_{t}$
    - $v(s_{t})$: Advantage를 계산하기 위해 필요
    - $\log(\pi(a_{t}|s_{t}))$: Ratio를 계산하기 위해 필요
  - $s_{t+1}, r_{t+1}, d_{t}$ 획득
  - Buffer에 저장
  - 에피소드 종료 체크, 환경 초기화, $s_0$ 획득

In [None]:
max_ep_len = 1000

# Start the clock and the first episode.
start_time = time.time()
o, ep_ret, ep_len = env.reset(), 0, 0

# Fill the buffer with exactly one epoch of on-policy transitions.
for t in range(steps_per_epoch):
    a, v, logp = ac.step(torch.as_tensor(o, dtype=torch.float32))

    next_o, r, d, _ = env.step(a)
    ep_ret += r
    ep_len += 1

    # Store the transition that was just taken from observation `o`.
    buf.store(o, a, r, v, logp)

    # Advance to the next observation before any bootstrapping below.
    o = next_o

    timeout = ep_len == max_ep_len
    terminal = d or timeout
    epoch_ended = t == steps_per_epoch - 1

    if not (terminal or epoch_ended):
        continue

    # The trajectory was cut by the time limit or the epoch boundary rather
    # than ending for real: bootstrap from the value of the current state.
    if timeout or epoch_ended:
        _, v, _ = ac.step(torch.as_tensor(o, dtype=torch.float32))
    else:
        v = 0
    buf.finish_path(v)
    o, ep_ret, ep_len = env.reset(), 0, 0

Buffer에서 data를 불러오기

In [None]:
# Pull the whole epoch out of the buffer as tensors and unpack the fields
# used by the policy and value updates.
data = buf.get()
obs, act, adv, logp_old, ret = (
    data[key] for key in ('obs', 'act', 'adv', 'logp', 'ret')
)

Policy 네트워크 업데이트. 이때 업데이트된 Policy와 이전 Policy 사이의 근사 KL divergence가 `1.5 * target_kl`을 넘으면 업데이트 중지.

In [None]:
target_kl = 0.02
clip_ratio = 0.2
train_pi_iters = 80

# Repeated gradient steps on the clipped surrogate objective.
for i in range(train_pi_iters):
    pi_optimizer.zero_grad()

    # Probability ratio between the current policy and the one that
    # collected the data.
    pi, logp = ac.pi(obs, act)
    ratio = torch.exp(logp - logp_old)
    clip_adv = ratio.clamp(1 - clip_ratio, 1 + clip_ratio) * adv
    loss_pi = -torch.min(ratio * adv, clip_adv).mean()

    # Stop as soon as the sample estimate of the KL divergence from the old
    # policy exceeds 1.5x the target.
    approx_kl = (logp_old - logp).mean().item()
    kl = np.mean(approx_kl)
    if kl > 1.5 * target_kl:
        print('Early stopping at step %d due to reaching max kl.'%i)
        break

    loss_pi.backward()
    pi_optimizer.step()

Value 네트워크 업데이트

In [None]:
train_v_iters = 80

# Fit the value network to the rewards-to-go with a mean-squared error.
for i in range(train_v_iters):
    vf_optimizer.zero_grad()
    loss_v = torch.square(ac.v(obs) - ret).mean()
    loss_v.backward()
    vf_optimizer.step()

### PPO

- 앞서 만든 클래스들과 함수들을 모두 합쳐 PPO 알고리즘 구현
- Policy Loss: Clipped Surrogate Loss
  - $L_{CLIP}(\theta):=\frac{1}{B}\sum_{t=1}^{B} \min\left( r_{t}(\theta)A_{t},\ \text{clip}(r_{t}(\theta),1-\epsilon, 1+\epsilon)A_{t}\right)$ (코드에서는 $-L_{CLIP}$을 최소화)
  - $r_{t}(\theta)=\frac{\pi_{\theta}(a_{t}|s_{t})}{\pi_{\theta_{k}}(a_{t}|s_{t})}$: $\pi_{\theta_{k}}(a_{t}|s_{t})$를 미리 Buffer에 저장해두기
- Value Loss
  - $L_{V}(\phi):=\frac{1}{B}\sum_{t=1}^{B} (G_{t} - V_{\phi}(s_{t}))^{2}$

In [None]:
def ppo(env_fn, actor_critic=MLPActorCritic, ac_kwargs=dict(), seed=0,
        steps_per_epoch=4000, epochs=50, gamma=0.99, clip_ratio=0.2, pi_lr=3e-4,
        vf_lr=1e-3, train_pi_iters=80, train_v_iters=80, lam=0.97, max_ep_len=1000,
        target_kl=0.01, save_freq=10):
    """Train an actor-critic with PPO-Clip and return it with its training logs.

    Each epoch collects ``steps_per_epoch`` transitions with the current
    policy, then runs up to ``train_pi_iters`` policy steps (stopping early
    once the approximate KL exceeds ``1.5 * target_kl``) followed by
    ``train_v_iters`` value-function steps, and prints one summary line.

    Args:
        env_fn: zero-argument callable returning an environment that exposes
            ``observation_dim``, ``action_dim``, ``reset()`` and ``step(a)``.
        actor_critic: constructor called as
            ``actor_critic(obs_dim, act_dim, **ac_kwargs)``.
        ac_kwargs: extra keyword arguments for ``actor_critic`` (not mutated).
        seed: seed for the torch and numpy global generators.
        steps_per_epoch: transitions collected per epoch (= buffer size).
        epochs: number of collect/update cycles.
        gamma: discount factor.
        clip_ratio: PPO clipping parameter epsilon.
        pi_lr: policy learning rate.
        vf_lr: value-function learning rate.
        train_pi_iters: maximum policy gradient steps per epoch.
        train_v_iters: value-function gradient steps per epoch.
        lam: GAE-Lambda parameter.
        max_ep_len: episodes are cut (and bootstrapped) after this many steps.
        target_kl: KL target used for early stopping of the policy update.
        save_freq: accepted for interface compatibility; currently unused.

    Returns:
        Tuple ``(ac, EpRet, EpLen, VVals, TotalEnvInteracts, LossPi, LossV,
        Entropy, KL, Time)``: the trained actor-critic followed by the
        per-episode / per-step / per-epoch log lists.
    """

    # Random seed
    torch.manual_seed(seed)
    np.random.seed(seed)

    # Training logs (per episode, per step or per epoch depending on the list).
    EpRet = []
    EpLen = []
    VVals = []
    TotalEnvInteracts = []
    LossPi = []
    LossV = []
    DeltaLossPi = []
    DeltaLossV = []
    Entropy = []
    KL = []
    ClipFrac = []
    Time = []

    def recent(values, reducer, window=10):
        """Reduce the last ``window`` entries; NaN when nothing is logged yet.

        Guards the epoch summary against empty logs, e.g. when no episode has
        finished by the end of the first epoch.
        """
        return reducer(values[-window:]) if values else float('nan')

    # Instantiate environment
    env = env_fn()
    obs_dim = env.observation_dim
    act_dim = env.action_dim

    # Create actor-critic module
    ac = actor_critic(env.observation_dim, env.action_dim, **ac_kwargs)

    # Count variables
    var_counts = tuple(count_vars(module) for module in [ac.pi, ac.v])
    print('\nNumber of parameters: \t pi: %d, \t v: %d\n'%var_counts)

    # Set up experience buffer
    buf = PPOBuffer(obs_dim, act_dim, steps_per_epoch, gamma, lam)

    # Set up function for computing PPO policy loss
    def compute_loss_pi(data):
        obs, act, adv, logp_old = data['obs'], data['act'], data['adv'], data['logp']

        # Policy loss: negative clipped surrogate objective.
        pi, logp = ac.pi(obs, act)
        ratio = torch.exp(logp - logp_old)
        clip_adv = torch.clamp(ratio, 1-clip_ratio, 1+clip_ratio) * adv
        loss_pi = -(torch.min(ratio * adv, clip_adv)).mean()

        # Useful extra info
        approx_kl = (logp_old - logp).mean().item()
        ent = pi.entropy().mean().item()
        clipped = ratio.gt(1+clip_ratio) | ratio.lt(1-clip_ratio)
        clipfrac = torch.as_tensor(clipped, dtype=torch.float32).mean().item()
        pi_info = dict(kl=approx_kl, ent=ent, cf=clipfrac)

        return loss_pi, pi_info

    # Set up function for computing value loss
    def compute_loss_v(data):
        obs, ret = data['obs'], data['ret']
        return ((ac.v(obs) - ret)**2).mean()

    # Set up optimizers for policy and value function
    pi_optimizer = Adam(ac.pi.parameters(), lr=pi_lr)
    vf_optimizer = Adam(ac.v.parameters(), lr=vf_lr)

    def update():
        data = buf.get()

        # Losses before any gradient step. They also serve as the "final"
        # values if an iteration count is zero, so the logging below never
        # touches an undefined name.
        loss_pi, pi_info_old = compute_loss_pi(data)
        pi_info = pi_info_old
        pi_l_old = loss_pi.item()
        loss_v = compute_loss_v(data)
        v_l_old = loss_v.item()

        # Train policy with multiple steps of gradient descent
        for i in range(train_pi_iters):
            pi_optimizer.zero_grad()
            loss_pi, pi_info = compute_loss_pi(data)
            kl = np.mean(pi_info['kl'])
            if kl > 1.5 * target_kl:
                print('Early stopping at step %d due to reaching max kl.'%i)
                break
            loss_pi.backward()
            pi_optimizer.step()

        # Value function learning
        for i in range(train_v_iters):
            vf_optimizer.zero_grad()
            loss_v = compute_loss_v(data)
            loss_v.backward()
            vf_optimizer.step()

        # Log changes from update
        LossPi.append(pi_l_old)
        LossV.append(v_l_old)
        KL.append(pi_info['kl'])
        Entropy.append(pi_info_old['ent'])
        ClipFrac.append(pi_info['cf'])
        DeltaLossPi.append(loss_pi.item() - pi_l_old)
        DeltaLossV.append(loss_v.item() - v_l_old)

    # Prepare for interaction with environment
    start_time = time.time()
    o, ep_ret, ep_len = env.reset(), 0, 0

    # Main loop: collect experience in env and update/log each epoch
    for epoch in range(epochs):
        for t in range(steps_per_epoch):
            a, v, logp = ac.step(torch.as_tensor(o, dtype=torch.float32))

            next_o, r, d, _ = env.step(a)
            ep_ret += r
            ep_len += 1

            # save and log
            buf.store(o, a, r, v, logp)
            VVals.append(v)

            # Update obs (critical!)
            o = next_o

            timeout = ep_len == max_ep_len
            terminal = d or timeout
            epoch_ended = t==steps_per_epoch-1

            if terminal or epoch_ended:
                if epoch_ended and not(terminal):
                    print('Warning: trajectory cut off by epoch at %d steps.'%ep_len, flush=True)
                # if trajectory didn't reach terminal state, bootstrap value target
                if timeout or epoch_ended:
                    _, v, _ = ac.step(torch.as_tensor(o, dtype=torch.float32))
                else:
                    v = 0
                buf.finish_path(v)
                if terminal:
                    # only log episode statistics for episodes that ran to their end
                    EpRet.append(ep_ret)
                    EpLen.append(ep_len)
                o, ep_ret, ep_len = env.reset(), 0, 0

        # Perform PPO update!
        update()

        TotalEnvInteracts.append((epoch+1)*steps_per_epoch)
        Time.append(time.time()-start_time)

        # Summary over the 10 most recent entries of each log.
        print(
            f'[Epoch:{epoch}] '
            f'EpRet:{recent(EpRet, np.min):8.2f} < {recent(EpRet, np.mean):8.2f} < {recent(EpRet, np.max):8.2f}, '
            f'EpLen:{recent(EpLen, np.mean):8.2f}, '
            f'VVals:{recent(VVals, np.mean):8.2f}, '
            f'TotalEnvInteracts:{TotalEnvInteracts[-1]:8d}, '
            f'LossPi:{recent(LossPi, np.mean):8.2f}, '
            f'LossV:{recent(LossV, np.mean):8.2f}, '
            f'Entropy:{recent(Entropy, np.mean):8.2f}, '
            f'KL:{recent(KL, np.mean):8.2f}, '
            f'Time:{Time[-1]:8.2f}'
        )
    return ac, EpRet, EpLen, VVals, TotalEnvInteracts, LossPi, LossV, Entropy, KL, Time

### Run PPO

In [None]:
# ppo returns a tuple; `ac` holds the whole tuple, so the trained
# actor-critic itself is `ac[0]` and the remaining entries are the logs.
ac = ppo(
    HalfCheetahEnv,
    actor_critic=MLPActorCritic,
    ac_kwargs={'hidden_sizes': [64, 64]},
    gamma=0.99,
    seed=0,
    steps_per_epoch=5000,
    epochs=50,
)


Number of parameters: 	 pi: 5708, 	 v: 5377

[Epoch:0] EpRet: -438.42 <  -278.33 <  -158.13, EpLen: 1000.00, VVals:   -0.05, TotalEnvInteracts:    5000, LossPi:    0.00, LossV: 1014.67, Entropy:    0.92, KL:    0.01, Time:    5.13
Early stopping at step 29 due to reaching max kl.
[Epoch:1] EpRet: -438.42 <  -278.91 <  -158.13, EpLen: 1000.00, VVals:   -5.45, TotalEnvInteracts:   10000, LossPi:    0.00, LossV:  759.97, Entropy:    0.91, KL:    0.02, Time:    9.70
Early stopping at step 55 due to reaching max kl.
[Epoch:2] EpRet: -325.14 <  -259.48 <  -138.61, EpLen: 1000.00, VVals:  -16.03, TotalEnvInteracts:   15000, LossPi:    0.00, LossV:  575.38, Entropy:    0.91, KL:    0.02, Time:   15.32
[Epoch:3] EpRet: -284.86 <  -234.16 <  -138.61, EpLen: 1000.00, VVals:  -18.81, TotalEnvInteracts:   20000, LossPi:    0.00, LossV:  479.04, Entropy:    0.91, KL:    0.01, Time:   20.37
Early stopping at step 6 due to reaching max kl.
[Epoch:4] EpRet: -358.64 <  -218.28 <  -106.98, EpLen: 1000.0

### Test

In [None]:
# Roll out the trained policy's deterministic (mean) action for 1000 steps
# in a fresh environment and record one frame per step.
env = HalfCheetahEnv()
imgs = []

obs = env.reset()
for t in range(1000):
    with torch.no_grad():
        obs = torch.as_tensor(obs, dtype=torch.float32)
        action = ac[0].act(obs)  # ac[0] is the trained actor-critic
    obs, reward, terminated, info = env.step(action)
    imgs.append(env.render())

# Play back at simulated real time: one frame per env.dt seconds.
media.show_video(imgs, fps=1/env.dt)