### Import packages

In [None]:
!pip install -q numpy
!pip install -q matplotlib
!pip install -q mujoco
!command -v ffmpeg >/dev/null || (apt update && apt install -y ffmpeg)
!pip install -q mediapy

%env MUJOCO_GL=egl

[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m4.3/4.3 MB[0m [31m56.8 MB/s[0m eta [36m0:00:00[0m
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m207.8/207.8 kB[0m [31m18.7 MB/s[0m eta [36m0:00:00[0m
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m1.6/1.6 MB[0m [31m28.4 MB/s[0m eta [36m0:00:00[0m
[?25henv: MUJOCO_GL=egl


In [None]:
import os
import numpy as np
import mediapy as media
import matplotlib.pyplot as plt
import mujoco
from copy import deepcopy
import torch.nn.functional as F

import scipy.signal
import time

import torch
import torch.nn as nn
from torch.optim import Adam
from torch.distributions.normal import Normal

### Define MuJoCo Environment

In [None]:
# MJCF (MuJoCo XML) description of the half-cheetah model. The string is parsed
# by HalfCheetahEnv.initialize_simulation() via mujoco.MjModel.from_xml_string.
xml_string="""<!-- Cheetah Model

    The state space is populated with joints in the order that they are
    defined in this file. The actuators also operate on joints.

    State-Space (name/joint/parameter):
        - rootx     slider      position (m)
        - rootz     slider      position (m)
        - rooty     hinge       angle (rad)
        - bthigh    hinge       angle (rad)
        - bshin     hinge       angle (rad)
        - bfoot     hinge       angle (rad)
        - fthigh    hinge       angle (rad)
        - fshin     hinge       angle (rad)
        - ffoot     hinge       angle (rad)
        - rootx     slider      velocity (m/s)
        - rootz     slider      velocity (m/s)
        - rooty     hinge       angular velocity (rad/s)
        - bthigh    hinge       angular velocity (rad/s)
        - bshin     hinge       angular velocity (rad/s)
        - bfoot     hinge       angular velocity (rad/s)
        - fthigh    hinge       angular velocity (rad/s)
        - fshin     hinge       angular velocity (rad/s)
        - ffoot     hinge       angular velocity (rad/s)

    Actuators (name/actuator/parameter):
        - bthigh    hinge       torque (N m)
        - bshin     hinge       torque (N m)
        - bfoot     hinge       torque (N m)
        - fthigh    hinge       torque (N m)
        - fshin     hinge       torque (N m)
        - ffoot     hinge       torque (N m)

-->
<mujoco model="cheetah">
  <compiler angle="radian" coordinate="local" inertiafromgeom="true" settotalmass="14"/>
  <default>
    <joint armature=".1" damping=".01" limited="true" solimplimit="0 .8 .03" solreflimit=".02 1" stiffness="8"/>
    <geom conaffinity="0" condim="3" contype="1" friction=".4 .1 .1" rgba="0.8 0.6 .4 1" solimp="0.0 0.8 0.01" solref="0.02 1"/>
    <motor ctrllimited="true" ctrlrange="-1 1"/>
  </default>
  <size nstack="300000" nuser_geom="1"/>
  <option gravity="0 0 -9.81" timestep="0.01"/>
  <asset>
    <texture builtin="gradient" height="100" rgb1="1 1 1" rgb2="0 0 0" type="skybox" width="100"/>
    <texture builtin="flat" height="1278" mark="cross" markrgb="1 1 1" name="texgeom" random="0.01" rgb1="0.8 0.6 0.4" rgb2="0.8 0.6 0.4" type="cube" width="127"/>
    <texture builtin="checker" height="100" name="texplane" rgb1="0 0 0" rgb2="0.8 0.8 0.8" type="2d" width="100"/>
    <material name="MatPlane" reflectance="0.5" shininess="1" specular="1" texrepeat="60 60" texture="texplane"/>
    <material name="geom" texture="texgeom" texuniform="true"/>
  </asset>
  <worldbody>
    <light cutoff="100" diffuse="1 1 1" dir="-0 0 -1.3" directional="true" exponent="1" pos="0 0 1.3" specular=".1 .1 .1"/>
    <geom conaffinity="1" condim="3" material="MatPlane" name="floor" pos="0 0 0" rgba="0.8 0.9 0.8 1" size="40 40 40" type="plane"/>
    <body name="torso" pos="0 0 .7">
      <camera name="track" mode="trackcom" pos="0 -3 0.3" xyaxes="1 0 0 0 0 1"/>
      <joint armature="0" axis="1 0 0" damping="0" limited="false" name="rootx" pos="0 0 0" stiffness="0" type="slide"/>
      <joint armature="0" axis="0 0 1" damping="0" limited="false" name="rootz" pos="0 0 0" stiffness="0" type="slide"/>
      <joint armature="0" axis="0 1 0" damping="0" limited="false" name="rooty" pos="0 0 0" stiffness="0" type="hinge"/>
      <geom fromto="-.5 0 0 .5 0 0" name="torso" size="0.046" type="capsule"/>
      <geom axisangle="0 1 0 .87" name="head" pos=".6 0 .1" size="0.046 .15" type="capsule"/>
      <!-- <site name='tip'  pos='.15 0 .11'/>-->
      <body name="bthigh" pos="-.5 0 0">
        <joint axis="0 1 0" damping="6" name="bthigh" pos="0 0 0" range="-.52 1.05" stiffness="240" type="hinge"/>
        <geom axisangle="0 1 0 -3.8" name="bthigh" pos=".1 0 -.13" size="0.046 .145" type="capsule"/>
        <body name="bshin" pos=".16 0 -.25">
          <joint axis="0 1 0" damping="4.5" name="bshin" pos="0 0 0" range="-.785 .785" stiffness="180" type="hinge"/>
          <geom axisangle="0 1 0 -2.03" name="bshin" pos="-.14 0 -.07" rgba="0.9 0.6 0.6 1" size="0.046 .15" type="capsule"/>
          <body name="bfoot" pos="-.28 0 -.14">
            <joint axis="0 1 0" damping="3" name="bfoot" pos="0 0 0" range="-.4 .785" stiffness="120" type="hinge"/>
            <geom axisangle="0 1 0 -.27" name="bfoot" pos=".03 0 -.097" rgba="0.9 0.6 0.6 1" size="0.046 .094" type="capsule"/>
          </body>
        </body>
      </body>
      <body name="fthigh" pos=".5 0 0">
        <joint axis="0 1 0" damping="4.5" name="fthigh" pos="0 0 0" range="-1 .7" stiffness="180" type="hinge"/>
        <geom axisangle="0 1 0 .52" name="fthigh" pos="-.07 0 -.12" size="0.046 .133" type="capsule"/>
        <body name="fshin" pos="-.14 0 -.24">
          <joint axis="0 1 0" damping="3" name="fshin" pos="0 0 0" range="-1.2 .87" stiffness="120" type="hinge"/>
          <geom axisangle="0 1 0 -.6" name="fshin" pos=".065 0 -.09" rgba="0.9 0.6 0.6 1" size="0.046 .106" type="capsule"/>
          <body name="ffoot" pos=".13 0 -.18">
            <joint axis="0 1 0" damping="1.5" name="ffoot" pos="0 0 0" range="-.5 .5" stiffness="60" type="hinge"/>
            <geom axisangle="0 1 0 -.6" name="ffoot" pos=".045 0 -.07" rgba="0.9 0.6 0.6 1" size="0.046 .07" type="capsule"/>
          </body>
        </body>
      </body>
    </body>
  </worldbody>
  <actuator>
    <motor gear="120" joint="bthigh" name="bthigh"/>
    <motor gear="90" joint="bshin" name="bshin"/>
    <motor gear="60" joint="bfoot" name="bfoot"/>
    <motor gear="120" joint="fthigh" name="fthigh"/>
    <motor gear="60" joint="fshin" name="fshin"/>
    <motor gear="30" joint="ffoot" name="ffoot"/>
  </actuator>
</mujoco>"""

In [None]:
class HalfCheetahEnv():
  """HalfCheetah environment implemented directly on the MuJoCo Python bindings.

  The model is compiled from the module-level `xml_string`. One call to
  `step` advances the simulation by `frame_skip` physics steps and returns a
  reward of `forward_reward_weight * x_velocity - ctrl_cost_weight * sum(action**2)`.
  Episodes never terminate on their own (`terminated` is always False); the
  caller is expected to enforce a time limit.

  Args:
    frame_skip: number of physics steps simulated per call to `step`.
    forward_reward_weight: multiplier on the forward (x) velocity reward.
    ctrl_cost_weight: multiplier on the squared-action penalty.
    reset_noise_scale: magnitude of the noise added to the initial state in `reset`.
  """

  def __init__(
      self,
      frame_skip=5,
      forward_reward_weight=1.0,
      ctrl_cost_weight=0.1,
      reset_noise_scale=0.1
      ):

    self.frame_skip = frame_skip
    self.forward_reward_weight = forward_reward_weight
    self.ctrl_cost_weight = ctrl_cost_weight
    self.reset_noise_scale = reset_noise_scale

    self.initialize_simulation()
    # Snapshot of the freshly reset state; `reset` perturbs these values.
    self.init_qpos = self.data.qpos.ravel().copy()
    self.init_qvel = self.data.qvel.ravel().copy()
    # Simulated time covered by one call to `step`.
    self.dt = self.model.opt.timestep * self.frame_skip

    self.observation_dim = 17
    self.action_dim = 6
    self.action_limit = 1.

  def initialize_simulation(self):
    """Compile the model from `xml_string` and create the data and renderer objects."""
    self.model = mujoco.MjModel.from_xml_string(xml_string)
    self.data = mujoco.MjData(self.model)
    mujoco.mj_resetData(self.model, self.data)
    self.renderer = mujoco.Renderer(self.model)

  def reset_simulation(self):
    """Reset the MuJoCo data to the model's default state."""
    mujoco.mj_resetData(self.model, self.data)

  def step_mujoco_simulation(self, ctrl, n_frames):
    """Apply `ctrl` to the actuators and advance the physics by `n_frames` steps."""
    self.data.ctrl[:] = ctrl
    mujoco.mj_step(self.model, self.data, nstep=n_frames)
    # The renderer scene is refreshed lazily in `render`, so pure simulation
    # (training, data collection) does not pay for scene updates.

  def set_state(self, qpos, qvel):
    """Overwrite positions and velocities, then recompute derived quantities."""
    self.data.qpos[:] = np.copy(qpos)
    self.data.qvel[:] = np.copy(qvel)
    # NOTE(review): presumably `data.act` is empty when `model.na == 0`, which
    # would make this assignment a no-op -- confirm against the MuJoCo bindings.
    if self.model.na == 0:
      self.data.act[:] = None
    mujoco.mj_forward(self.model, self.data)

  def sample_action(self):
    """Return an action drawn uniformly from [-action_limit, action_limit) per dimension."""
    return (2.*np.random.uniform(size=(self.action_dim,)) - 1)*self.action_limit

  def step(self, action):
    """Advance the environment by one control step.

    Returns:
      (observation, reward, terminated, info) where `terminated` is always
      False and `info` holds the x position, x velocity and both reward terms.
    """
    x_position_before = self.data.qpos[0]
    self.step_mujoco_simulation(action, self.frame_skip)
    x_position_after = self.data.qpos[0]
    # Finite-difference velocity of qpos[0] over the control step.
    x_velocity = (x_position_after - x_position_before) / self.dt

    # Rewards
    ctrl_cost = self.ctrl_cost_weight * np.sum(np.square(action))
    forward_reward = self.forward_reward_weight * x_velocity
    observation = self.get_obs()
    reward = forward_reward - ctrl_cost
    terminated = False
    info = {
        "x_position": x_position_after,
        "x_velocity": x_velocity,
        "reward_run": forward_reward,
        "reward_ctrl": -ctrl_cost,
    }
    return observation, reward, terminated, info

  def get_obs(self):
    """Return qpos without its first element, followed by qvel, as one flat array."""
    position = self.data.qpos.flat.copy()
    velocity = self.data.qvel.flat.copy()
    # Drop qpos[0] (the coordinate `step` uses as the x position).
    position = position[1:]

    observation = np.concatenate((position, velocity)).ravel()
    return observation

  def render(self):
    """Render the current simulation state from camera 0 and return the image."""
    # Refresh the scene here so the frame always reflects the current state,
    # including right after `reset` when no step has been taken yet.
    self.renderer.update_scene(self.data, 0)
    return self.renderer.render()

  def reset(self):
    """Reset to the initial state plus random noise and return the first observation."""
    self.reset_simulation()
    noise_low = -self.reset_noise_scale
    noise_high = self.reset_noise_scale
    # Uniform noise on positions, scaled Gaussian noise on velocities.
    qpos = self.init_qpos + np.random.uniform(
        low=noise_low, high=noise_high, size=self.model.nq
    )
    qvel = (
        self.init_qvel
        + self.reset_noise_scale * np.random.standard_normal(self.model.nv)
    )
    self.set_state(qpos, qvel)
    observation = self.get_obs()
    return observation

In [None]:
# Pick the first CUDA device when one is visible, otherwise fall back to CPU.
if torch.cuda.is_available():
    device = torch.device('cuda:0')
    torch.cuda.empty_cache()
    print("Device set to : " + str(torch.cuda.get_device_name(device)))
else:
    device = torch.device('cpu')
    print("Device set to : cpu")

Device set to : cpu


# Soft Actor Critic

### Buffer 클래스 정의

- 저장 해야 할 것들
  - state, $S_{t}$
  - next state, $S_{t+1}$
  - action, $A_{t}$
  - rewards, $R_{t+1}$
  - done (terminal information), $d_{t+1}$

- Off-policy 알고리즘이므로 정해진 buffer size 만큼 모든 데이터를 수집 가능.

In [None]:
# Utility functions
def combined_shape(length, shape=None):
    """Build the shape of a buffer that stores `length` items of shape `shape`.

    `shape` may be None (returns `(length,)`), a scalar (returns
    `(length, shape)`) or an iterable (returns `(length, *shape)`).
    """
    if shape is None:
        return (length,)
    if np.isscalar(shape):
        return (length, shape)
    return (length, *shape)

def count_vars(module):
    """Return the total number of scalar entries across `module.parameters()`."""
    total = 0
    for param in module.parameters():
        total += np.prod(param.shape)
    return total

def discount_cumsum(x, discount):
    """Return the discounted cumulative sums of `x` along axis 0.

    Element i of the result is x[i] + discount * x[i+1] + discount**2 * x[i+2] + ...
    """
    numerator = [1]
    denominator = [1, float(-discount)]
    # Filtering the reversed sequence accumulates from the end towards the start.
    reversed_x = x[::-1]
    filtered = scipy.signal.lfilter(numerator, denominator, reversed_x, axis=0)
    return filtered[::-1]

In [None]:
class ReplayBuffer:
    """
    Fixed-capacity FIFO store of (obs, act, rew, next_obs, done) transitions
    for SAC agents. Once full, the oldest entries are overwritten.
    """

    def __init__(self, obs_dim, act_dim, size):
        obs_shape = combined_shape(size, obs_dim)
        self.obs_buf = np.zeros(obs_shape, dtype=np.float32)
        self.obs2_buf = np.zeros(obs_shape, dtype=np.float32)
        self.act_buf = np.zeros(combined_shape(size, act_dim), dtype=np.float32)
        self.rew_buf = np.zeros(size, dtype=np.float32)
        self.done_buf = np.zeros(size, dtype=np.float32)
        # idx: next write slot, size: number of valid entries, max_size: capacity.
        self.idx = 0
        self.size = 0
        self.max_size = size

    def store(self, obs, act, rew, next_obs, done):
        """Write one transition at the current slot and advance the write pointer."""
        slot = self.idx
        self.obs_buf[slot] = obs
        self.obs2_buf[slot] = next_obs
        self.act_buf[slot] = act
        self.rew_buf[slot] = rew
        self.done_buf[slot] = done
        # Wrap around so the oldest transition is overwritten next.
        self.idx = (slot + 1) % self.max_size
        self.size = min(self.size + 1, self.max_size)

    def sample_batch(self, batch_size=32):
        """Sample `batch_size` stored transitions uniformly (with replacement).

        Returns a dict of float32 tensors keyed by obs, obs2, act, rew, done.
        """
        idxs = np.random.randint(0, self.size, size=batch_size)
        fields = (
            ('obs', self.obs_buf),
            ('obs2', self.obs2_buf),
            ('act', self.act_buf),
            ('rew', self.rew_buf),
            ('done', self.done_buf),
        )
        return {key: torch.as_tensor(buf[idxs], dtype=torch.float32)
                for key, buf in fields}

### Policy Network and Q Network

- Policy Network: Observation이 입력, 평균 $\mu_{t}$, 표준편차 $\Sigma_{t}$를 출력하는 네트워크 설계. 정책 함수의 분포로 Squashed Gaussian 분포를 제안하여 사용함
  - Squashed Gaussian Distribution
  - $u_{t}\sim\mathcal{N}(\mu_{t},\Sigma_{t})$
  - $a_{t}=$tanh$(u_{t})$
  - $\pi(a_{t}|s_{t})=\mathcal{N}(u_{t}|\mu_{t},\Sigma_{t})\big|$det$\left(\frac{da}{du}\right)\big|^{-1}$
  - $\big|$det$\left(\frac{da}{du}\right)\big|^{-1}=\prod_{i=1}^{d}(1-$tanh$(u_{i})^{2})^{-1}$
- Q Network: Observation과 Action이 입력, 1차원의 Q(s,a) 값을 출력하는 네트워크 설계

In [None]:
def mlp(sizes, activation, output_activation=nn.Identity):
    """Build a fully connected network as an `nn.Sequential`.

    Consecutive entries of `sizes` define the Linear layers. Every Linear is
    followed by `activation()`, except the last one, which is followed by
    `output_activation()`.
    """
    last = len(sizes) - 2
    layers = []
    for j, (n_in, n_out) in enumerate(zip(sizes[:-1], sizes[1:])):
        layers.append(nn.Linear(n_in, n_out))
        layers.append(activation() if j < last else output_activation())
    return nn.Sequential(*layers)

# Bounds applied to the predicted log standard deviation before exponentiation.
LOG_STD_MAX = 2
LOG_STD_MIN = -20

class SquashedGaussianMLPActor(nn.Module):
    """Gaussian policy whose samples are squashed by tanh and scaled by `act_limit`.

    A shared MLP trunk feeds two linear heads that output the mean and the
    log standard deviation of a diagonal Gaussian over pre-squash actions.
    """

    def __init__(self, obs_dim, act_dim, hidden_sizes, activation, act_limit):
        super().__init__()
        # Trunk: `activation` is used after every layer, including the last hidden one.
        self.net = mlp([obs_dim] + list(hidden_sizes), activation, activation)
        self.mu_layer = nn.Linear(hidden_sizes[-1], act_dim)
        self.log_std_layer = nn.Linear(hidden_sizes[-1], act_dim)
        self.act_limit = act_limit

    def forward(self, obs, deterministic=False, with_logprob=True):
        """Compute an action and, optionally, its log-probability.

        Args:
            obs: observation tensor; the last dimension is the feature
                dimension. Batched and unbatched inputs are both accepted.
            deterministic: if True, use the Gaussian mean instead of sampling.
            with_logprob: if False, skip the log-probability and return None
                in its place.

        Returns:
            (pi_action, logp_pi): the squashed, scaled action and its
            log-probability summed over the last dimension (or None).
        """
        net_out = self.net(obs)
        mu = self.mu_layer(net_out)
        log_std = self.log_std_layer(net_out)
        log_std = torch.clamp(log_std, LOG_STD_MIN, LOG_STD_MAX)
        std = torch.exp(log_std)

        # Pre-squash distribution and sample
        pi_distribution = Normal(mu, std)
        if deterministic:
            # Only used for evaluating policy at test time.
            pi_action = mu
        else:
            # rsample keeps the sample differentiable w.r.t. mu and std.
            pi_action = pi_distribution.rsample()

        if with_logprob:
            logp_pi = pi_distribution.log_prob(pi_action).sum(axis=-1)
            # Change-of-variables correction for the tanh squashing, using
            # log(1 - tanh(u)^2) = 2*(log 2 - u - softplus(-2u)) for stability.
            # Reduce over the last axis (same as the line above) so that an
            # unbatched 1-D observation works as well.
            logp_pi -= (2*(np.log(2) - pi_action - F.softplus(-2*pi_action))).sum(axis=-1)
        else:
            logp_pi = None

        pi_action = torch.tanh(pi_action)
        pi_action = self.act_limit * pi_action

        return pi_action, logp_pi


class MLPQFunction(nn.Module):
    """Q-network mapping a concatenated (observation, action) pair to one value."""

    def __init__(self, obs_dim, act_dim, hidden_sizes, activation):
        super().__init__()
        layer_sizes = [obs_dim + act_dim] + list(hidden_sizes) + [1]
        self.q = mlp(layer_sizes, activation)

    def forward(self, obs, act):
        """Return Q(obs, act) with the trailing size-1 output dimension removed."""
        joint_input = torch.cat([obs, act], dim=-1)
        # Dropping the last axis is essential so the result matches the
        # per-sample shape of rewards and log-probabilities.
        return self.q(joint_input).squeeze(-1)

class MLPActorCritic(nn.Module):
    """Container holding the SAC policy (`pi`) and two Q-functions (`q1`, `q2`)."""

    def __init__(self, observation_dim, action_dim, action_limit, hidden_sizes=(256,256),
                 activation=nn.ReLU):
        super().__init__()

        # Policy first, then the two critics (same construction order as before).
        self.pi = SquashedGaussianMLPActor(
            observation_dim, action_dim, hidden_sizes, activation, action_limit)
        self.q1 = MLPQFunction(observation_dim, action_dim, hidden_sizes, activation)
        self.q2 = MLPQFunction(observation_dim, action_dim, hidden_sizes, activation)

    def act(self, obs, deterministic=False):
        """Return the policy's action for `obs` as a NumPy array, without gradients."""
        with torch.no_grad():
            # The log-probability is not needed when only acting.
            action, _ = self.pi(obs, deterministic, False)
            return action.numpy()

Environment 설정

In [None]:
# Create the environment and read off the dimensions the networks are built from.
env = HalfCheetahEnv()
obs_dim, act_dim, act_limit = env.observation_dim, env.action_dim, env.action_limit

Policy 네트워크 Q1, Q2 네트워크 정의 및 각 네트워크별 optimizer 정의

In [None]:
lr = 1e-3

# Online actor-critic and a deep-copied target network.
ac = MLPActorCritic(obs_dim, act_dim, act_limit)
ac_targ = deepcopy(ac)

# The target network is never trained by an optimizer; it is only moved
# towards the online network by polyak averaging.
for p in ac_targ.parameters():
    p.requires_grad = False

# Parameters of both Q-networks, optimized jointly by one optimizer.
q_params = [*ac.q1.parameters(), *ac.q2.parameters()]

# Report the size of each network.
var_counts = tuple(count_vars(module) for module in (ac.pi, ac.q1, ac.q2))
print('\nNumber of parameters: \t pi: %d, \t q1: %d, \t q2: %d\n'%var_counts)

# One optimizer for the policy, one for the two Q-functions.
pi_optimizer = Adam(ac.pi.parameters(), lr=lr)
q_optimizer = Adam(q_params, lr=lr)


Number of parameters: 	 pi: 73484, 	 q1: 72193, 	 q2: 72193



Replay buffer 생성

In [None]:
# Maximum number of transitions kept in the buffer.
replay_size = 200_000

# Experience buffer
replay_buffer = ReplayBuffer(obs_dim, act_dim, replay_size)

에피소드 수집 및 buffer에 저장

In [None]:
total_steps = 1000
max_ep_len = 1000

o = env.reset()
ep_ret, ep_len = 0, 0

# Fill the replay buffer with transitions generated by uniformly random actions.
for t in range(total_steps):

    a = env.sample_action()

    # Advance the environment by one control step.
    o2, r, d, _ = env.step(a)
    ep_ret += r
    ep_len += 1

    # Reaching the time limit is an artificial cut-off, not a real terminal
    # state, so it must not be stored as "done".
    timed_out = ep_len == max_ep_len
    if timed_out:
        d = False

    replay_buffer.store(o, a, r, o2, d)

    # Easy to forget: the next transition starts from the new observation.
    o = o2

    # Start a new episode after a terminal state or a time-out.
    if d or timed_out:
        o = env.reset()
        ep_ret, ep_len = 0, 0

Buffer에서 batch sampling

In [None]:
batch_size = 256
batch = replay_buffer.sample_batch(batch_size)
# Unpack the batch into (obs, action, reward, next obs, done).
o, a, r, o2, d = (batch[key] for key in ('obs', 'act', 'rew', 'obs2', 'done'))

Q function loss 계산 및 gradient step 진행

In [None]:
gamma = 0.99
alpha = 0.2

# Current Q estimates for the replayed (state, action) pairs.
q1 = ac.q1(o, a)
q2 = ac.q2(o, a)

# Build the Bellman target without tracking gradients.
with torch.no_grad():
    # Next actions are sampled from the *current* policy, not the target one.
    a2, logp_a2 = ac.pi(o2)

    # Clipped double-Q: take the smaller of the two target Q-values.
    q1_pi_targ = ac_targ.q1(o2, a2)
    q2_pi_targ = ac_targ.q2(o2, a2)
    q_pi_targ = torch.min(q1_pi_targ, q2_pi_targ)
    soft_value = q_pi_targ - alpha * logp_a2
    backup = r + gamma * (1 - d) * soft_value

# Mean squared error of each Q-function against the shared target.
loss_q1 = (q1 - backup).pow(2).mean()
loss_q2 = (q2 - backup).pow(2).mean()
loss_q = loss_q1 + loss_q2

# One gradient step on both Q-functions.
q_optimizer.zero_grad()
loss_q.backward()
q_optimizer.step()

Policy loss 계산 및 진행

In [None]:
# Sample actions from the current policy for the replayed states. They are
# stored under their own name so the replayed batch actions in `a` stay intact
# and the Q-loss cell can be re-run safely afterwards.
a_pi, logp_pi = ac.pi(o)

# Freeze the Q-networks: only the policy is optimized in this step, so
# gradients with respect to the Q parameters would be wasted work.
for p in q_params:
    p.requires_grad = False

q1_pi = ac.q1(o, a_pi)
q2_pi = ac.q2(o, a_pi)
q_pi = torch.min(q1_pi, q2_pi)

# Entropy-regularized policy loss
loss_pi = (alpha * logp_pi - q_pi).mean()

# Next run one gradient descent step for pi.
pi_optimizer.zero_grad()
loss_pi.backward()
pi_optimizer.step()

# Unfreeze the Q-networks so the next Q-function step can train them again.
for p in q_params:
    p.requires_grad = True

Target network에 soft update 적용

In [None]:
polyak = 0.995
# Soft update: move every target parameter a fraction (1 - polyak) of the way
# towards the corresponding online parameter, in place.
with torch.no_grad():
    for p, p_targ in zip(ac.parameters(), ac_targ.parameters()):
        p_targ.data.mul_(polyak).add_((1 - polyak) * p.data)

### Soft Actor Critic
- 앞서 만든 클래스들과 함수들을 모두 합쳐 SAC 알고리즘 구현
- Q Loss:
  - Target 값 계산: $y_t=r_{t+1}+\gamma (1-d_{t+1}) \left(\min_{j\in\{1,2\}}Q_{\phi_{\text{targ},j}}(s_{t+1},\tilde{a}) - \alpha \log \pi_{\theta}(\tilde{a}|s_{t+1})\right)$ (Q 값은 target network에서 계산)
    - $\tilde{a}\sim\pi_{\theta}(\cdot|s_{t+1})$: Next state에 대한 action은 policy network에서 sampling 하여 사용하는 것이 핵심.
  - MSE Loss: $\sum_{i=1}^{B}\left(Q_{j}(s_{t_i},a_{t_i}) - y_{t_i}\right)^{2}$
  - 실제 논문의 구현과 차이가 있음.
- Policy Loss:
  - Maximize value (loss로는 부호를 바꿔 최소화): $-\sum_{i=1}^{B}\left(Q_{\pi}(s_{t_i},\tilde{a}) - \alpha \log \pi_{\theta}(\tilde{a}|s_{t_i})\right)$
      - $\tilde{a}\sim\pi_{\theta}(\cdot|s_{t_i})$: 현재 state $s_{t_i}$에 대한 action을 buffer의 action 대신 policy network에서 (reparameterization으로) sampling 하여 사용하는 것이 핵심.
      - $Q_{\pi}(s_{t_i},\tilde{a}) = \min_{j\in(1,2)}Q_{\phi_{j}}(s_{t_i},\tilde{a})$: 보수적인 Q Value를 사용하는 것이 Overestimation을 방지할 수 있음.


In [None]:
def sac(env_fn, actor_critic=MLPActorCritic, ac_kwargs=None, seed=0,
        steps_per_epoch=5000, epochs=5, replay_size=int(1e6), gamma=0.99,
        polyak=0.995, lr=1e-3, alpha=0.2, batch_size=100, start_steps=10000,
        update_after=1000, update_every=50, num_test_episodes=40, max_ep_len=1000):
    """Train a Soft Actor-Critic agent and return it together with training logs.

    Args:
        env_fn: zero-argument callable returning an environment that exposes
            `observation_dim`, `action_dim`, `action_limit`, `reset()`,
            `step(action)` and `sample_action()`. It is called twice (training
            and test environments).
        actor_critic: constructor called as
            `actor_critic(obs_dim, act_dim, act_limit, **ac_kwargs)`; the result
            must provide `pi`, `q1`, `q2` and `act`.
        ac_kwargs: extra keyword arguments for `actor_critic` (None means none).
        seed: seed for the global torch and NumPy random generators.
        steps_per_epoch: environment steps per epoch.
        epochs: number of epochs; total steps are `steps_per_epoch * epochs`.
        replay_size: capacity of the replay buffer.
        gamma: discount factor.
        polyak: interpolation factor for the target-network soft update.
        lr: learning rate of both Adam optimizers.
        alpha: entropy regularization coefficient.
        batch_size: minibatch size per gradient update.
        start_steps: number of initial steps that use uniformly random actions.
        update_after: first step index at which gradient updates may happen.
        update_every: environment steps between update rounds; each round
            performs `update_every` gradient updates.
        num_test_episodes: deterministic test episodes run at the end of each epoch.
        max_ep_len: maximum episode length before a forced reset.

    Returns:
        Tuple `(ac, EpRet, EpLen, Q1Vals, Q2Vals, TotalEnvInteracts, LossPi,
        LossQ, Time)`: the trained actor-critic followed by the logged lists.
    """
    # Use a fresh dict per call instead of a shared mutable default argument.
    if ac_kwargs is None:
        ac_kwargs = {}

    torch.manual_seed(seed)
    np.random.seed(seed)

    EpRet = []
    TestEpRet = []
    EpLen = []
    TestEpLen = []
    TotalEnvInteracts = []
    Q1Vals = []
    Q2Vals = []
    LogPi = []
    LossPi = []
    LossQ = []
    Time = []

    def _tail_stat(values, reduce_fn, window=10):
        """Reduce the last `window` entries of `values`; NaN when nothing is logged yet."""
        tail = values[-window:]
        if len(tail) == 0:
            # e.g. no episode has finished yet, or no update has run yet.
            return float('nan')
        return reduce_fn(tail)

    env, test_env = env_fn(), env_fn()
    obs_dim = env.observation_dim
    act_dim = env.action_dim
    act_limit = env.action_limit

    # Create actor-critic module and target networks
    ac = actor_critic(obs_dim, act_dim, act_limit, **ac_kwargs)
    ac_targ = deepcopy(ac)

    # List of parameters for both Q-networks (save this for convenience)
    q_params = list(ac.q1.parameters()) + list(ac.q2.parameters())

    # Count variables (protip: try to get a feel for how different size networks behave!)
    var_counts = tuple(count_vars(module) for module in [ac.pi, ac.q1, ac.q2])
    print('\nNumber of parameters: \t pi: %d, \t q1: %d, \t q2: %d\n'%var_counts)

    # Set up optimizers for policy and q-function
    pi_optimizer = Adam(ac.pi.parameters(), lr=lr)
    q_optimizer = Adam(q_params, lr=lr)

    # Freeze target networks with respect to optimizers (only update via polyak averaging)
    for p in ac_targ.parameters():
        p.requires_grad = False

    # Experience buffer
    replay_buffer = ReplayBuffer(obs_dim=obs_dim, act_dim=act_dim, size=replay_size)

    # Set up function for computing SAC Q-losses
    def compute_loss_q(data):
        o, a, r, o2, d = data['obs'], data['act'], data['rew'], data['obs2'], data['done']

        q1 = ac.q1(o,a)
        q2 = ac.q2(o,a)

        # Bellman backup for Q functions
        with torch.no_grad():
            # Target actions come from *current* policy
            a2, logp_a2 = ac.pi(o2)

            # Target Q-values
            q1_pi_targ = ac_targ.q1(o2, a2)
            q2_pi_targ = ac_targ.q2(o2, a2)
            q_pi_targ = torch.min(q1_pi_targ, q2_pi_targ)
            backup = r + gamma * (1 - d) * (q_pi_targ - alpha * logp_a2)

        # MSE loss against Bellman backup
        loss_q1 = ((q1 - backup)**2).mean()
        loss_q2 = ((q2 - backup)**2).mean()
        loss_q = loss_q1 + loss_q2

        # Useful info for logging
        q_info = dict(Q1Vals=q1.detach().numpy(),
                      Q2Vals=q2.detach().numpy())

        return loss_q, q_info

    # Set up function for computing SAC pi loss
    def compute_loss_pi(data):
        o = data['obs']
        a, logp_pi = ac.pi(o)
        q1_pi = ac.q1(o, a)
        q2_pi = ac.q2(o, a)
        q_pi = torch.min(q1_pi, q2_pi)

        # Entropy-regularized policy loss
        loss_pi = (alpha * logp_pi - q_pi).mean()

        # Useful info for logging
        pi_info = dict(LogPi=logp_pi.detach().numpy())

        return loss_pi, pi_info

    def update(data):
        # First run one gradient descent step for Q1 and Q2
        q_optimizer.zero_grad()
        loss_q, q_info = compute_loss_q(data)
        loss_q.backward()
        q_optimizer.step()

        # Record things
        LossQ.append(loss_q.item())
        Q1Vals.append(q_info['Q1Vals'])
        Q2Vals.append(q_info['Q2Vals'])

        # Freeze Q-networks so you don't waste computational effort
        # computing gradients for them during the policy learning step.
        for p in q_params:
            p.requires_grad = False

        # Next run one gradient descent step for pi.
        pi_optimizer.zero_grad()
        loss_pi, pi_info = compute_loss_pi(data)
        loss_pi.backward()
        pi_optimizer.step()

        # Unfreeze Q-networks so they can be optimized at the next SAC update.
        for p in q_params:
            p.requires_grad = True

        # Record things
        LossPi.append(loss_pi.item())
        LogPi.append(pi_info['LogPi'])

        # Finally, update target networks by polyak averaging.
        with torch.no_grad():
            for p, p_targ in zip(ac.parameters(), ac_targ.parameters()):
                # NB: We use an in-place operations "mul_", "add_" to update target
                # params, as opposed to "mul" and "add", which would make new tensors.
                p_targ.data.mul_(polyak)
                p_targ.data.add_((1 - polyak) * p.data)

    def get_action(o, deterministic=False):
        return ac.act(torch.as_tensor(o, dtype=torch.float32),
                      deterministic)

    def test_agent():
        for j in range(num_test_episodes):
            o, d, ep_ret, ep_len = test_env.reset(), False, 0, 0
            while not(d or (ep_len == max_ep_len)):
                # Take deterministic actions at test time
                o, r, d, _ = test_env.step(get_action(o, True))
                ep_ret += r
                ep_len += 1
            TestEpRet.append(ep_ret)
            TestEpLen.append(ep_len)

    # Prepare for interaction with environment
    total_steps = steps_per_epoch * epochs
    start_time = time.time()
    o, ep_ret, ep_len = env.reset(), 0, 0

    # Main loop: collect experience in env and update/log each epoch
    for t in range(total_steps):

        # Until start_steps have elapsed, randomly sample actions
        # from a uniform distribution for better exploration. Afterwards,
        # use the learned policy.
        if t > start_steps:
            a = get_action(o)
        else:
            a = env.sample_action()

        # Step the env
        o2, r, d, _ = env.step(a)
        ep_ret += r
        ep_len += 1

        # Ignore the "done" signal if it comes from hitting the time
        # horizon (that is, when it's an artificial terminal signal
        # that isn't based on the agent's state)
        d = False if ep_len==max_ep_len else d

        # Store experience to replay buffer
        replay_buffer.store(o, a, r, o2, d)

        # Super critical, easy to overlook step: make sure to update
        # most recent observation!
        o = o2

        # End of trajectory handling
        if d or (ep_len == max_ep_len):
            EpRet.append(ep_ret)
            EpLen.append(ep_len)
            o, ep_ret, ep_len = env.reset(), 0, 0

        # Update handling
        if t >= update_after and t % update_every == 0:
            for j in range(update_every):
                batch = replay_buffer.sample_batch(batch_size)
                update(data=batch)

        # End of epoch handling
        if (t+1) % steps_per_epoch == 0:
            epoch = (t+1) // steps_per_epoch

            # Test the performance of the deterministic version of the agent.
            test_agent()

            TotalEnvInteracts.append(t)
            Time.append(time.time()-start_time)
            # Statistics cover the 10 most recent entries of each log; lists
            # that are still empty are reported as NaN instead of raising.
            print(
                f'[Epoch:{epoch}] '
                f'TestEpRet:{_tail_stat(TestEpRet, np.min):8.2f} < {_tail_stat(TestEpRet, np.mean):8.2f} < {_tail_stat(TestEpRet, np.max):8.2f}, '
                f'TestEpLen:{_tail_stat(TestEpLen, np.mean):8.2f}, '
                f'EpRet:{_tail_stat(EpRet, np.min):8.2f} < {_tail_stat(EpRet, np.mean):8.2f} < {_tail_stat(EpRet, np.max):8.2f}, '
                f'EpLen:{_tail_stat(EpLen, np.mean):8.2f}, '
                f'Q1Vals:{_tail_stat(Q1Vals, np.mean):8.2f}, '
                f'Q2Vals:{_tail_stat(Q2Vals, np.mean):8.2f}, '
                f'TotalEnvInteracts:{TotalEnvInteracts[-1]:8d}, '
                f'LossPi:{_tail_stat(LossPi, np.mean):8.2f}, '
                f'LossQ:{_tail_stat(LossQ, np.mean):8.2f}, '
                f'Time:{Time[-1]:8.2f}'
            )
    return ac, EpRet, EpLen, Q1Vals, Q2Vals, TotalEnvInteracts, LossPi, LossQ, Time

In [None]:
# NOTE: this rebinds the name `sac` from the training function to the tuple it
# returns (later cells read the trained actor-critic as `sac[0]`). Re-running
# this cell therefore fails until the cell defining `sac` is run again.
sac = sac(
    HalfCheetahEnv,  # called with no arguments to build the train and test envs
    actor_critic=MLPActorCritic,
    ac_kwargs=dict(hidden_sizes=[256, 256]),
    gamma=0.99,
    seed=0,
    epochs=10,
)


Number of parameters: 	 pi: 73484, 	 q1: 72193, 	 q2: 72193

[Epoch:1] TestEpRet:  140.53 <   302.88 <   349.67, TestEpLen: 1000.00, EpRet: -408.55 <  -302.33 <  -229.21, EpLen: 1000.00, Q1Vals:    9.69, Q2Vals:    9.67, TotalEnvInteracts:    4999, LossPi:  -10.71, LossQ:    0.98, Time:   93.69
[Epoch:2] TestEpRet: -160.29 <     8.75 <    61.12, TestEpLen: 1000.00, EpRet: -454.36 <  -323.10 <  -229.21, EpLen: 1000.00, Q1Vals:   34.96, Q2Vals:   34.95, TotalEnvInteracts:    9999, LossPi:  -36.67, LossQ:    3.49, Time:  198.58
[Epoch:3] TestEpRet:  661.19 <   907.47 <  1004.76, TestEpLen: 1000.00, EpRet: -454.36 <  -198.35 <    93.83, EpLen: 1000.00, Q1Vals:   37.75, Q2Vals:   37.71, TotalEnvInteracts:   14999, LossPi:  -38.42, LossQ:    2.74, Time:  300.90
[Epoch:4] TestEpRet: 1470.78 <  1523.01 <  1572.72, TestEpLen: 1000.00, EpRet: -303.67 <   210.28 <   837.50, EpLen: 1000.00, Q1Vals:   43.00, Q2Vals:   43.01, TotalEnvInteracts:   19999, LossPi:  -43.58, LossQ:    3.19, Time:  403.3

### Test

In [None]:
# Roll out the trained policy deterministically for 1000 steps, recording a frame per step.
env = HalfCheetahEnv()
imgs = []

obs = env.reset()
for t in range(1000):
  with torch.no_grad():
    action = sac[0].act(torch.as_tensor(obs, dtype=torch.float32), True)
  obs, reward, terminated, info = env.step(action)
  imgs.append(env.render())

# One frame was recorded per env.step, so play back at 1/dt frames per second.
media.show_video(imgs, fps=1/env.dt)

In [None]:
num_trajs = 100
max_ep_len = 1000

# Per-transition records plus the return of every finished episode.
observations, actions, next_observations = [], [], []
rewards, terminals = [], []
ravg = []

o = env.reset()
ep_ret, ep_len = 0, 0
for t in range(num_trajs*max_ep_len):
  with torch.no_grad():
    a = sac[0].act(torch.as_tensor(o, dtype=torch.float32), True)

  # Advance the simulation by one control step.
  o2, r, d, _ = env.step(a)
  ep_ret += r
  ep_len += 1
  # A time-limit cut-off is not stored as a terminal transition.
  timed_out = ep_len == max_ep_len
  if timed_out:
    d = False

  # Store the transition.
  observations.append(o)
  actions.append(a)
  next_observations.append(o2)
  rewards.append(r)
  terminals.append(d)

  # Update the observation; start a new episode when the current one ends.
  o = o2
  if d or timed_out:
      ravg.append(ep_ret)
      o = env.reset()
      ep_ret, ep_len = 0, 0
print(f'오프라인 데이터 셋의 평균 리턴 : {np.mean(ravg)}')

# Convert the collected lists to arrays with fixed dtypes.
observations = np.array(observations, dtype=np.float32)
actions = np.array(actions, dtype=np.float32)
next_observations = np.array(next_observations, dtype=np.float32)
rewards = np.array(rewards, dtype=np.float32)
terminals = np.array(terminals, dtype=np.bool_)

# Bundle everything into one dictionary and save it as the offline dataset.
dataset = dict(
    observations=observations,
    actions=actions,
    next_observations=next_observations,
    rewards=rewards,
    terminals=terminals,
    ravg=ravg,
)

import pickle

with open('expert_dataset.pickle','wb') as f:
    pickle.dump(dataset, f)