## Soft Actor Critic

ref: 

- https://pfrl.readthedocs.io/en/latest/agents.html
- https://github.com/pfnet/pfrl/blob/master/examples/mujoco/reproduction/soft_actor_critic/train_soft_actor_critic.py
- https://github.com/pfnet/pfrl/blob/master/pfrl/agents/soft_actor_critic.py

In [7]:
import pfrl
from pfrl import experiments, replay_buffers, utils
from pfrl.nn.lmbda import Lambda

import torch
from torch import distributions, nn
import gym
import numpy
from skimage.transform import rescale

import robo_gym
from robo_gym.wrappers.exception_handling import ExceptionHandling

In [8]:
target_machine_ip = 'localhost'
num_envs = 3

def make_env(process_idx, test):
    env = gym.make('CubeRoomSearch-v0', ip=target_machine_ip, gui=False)
    env = ExceptionHandling(env)
    
    # Unwrap TimiLimit wrapper
    assert isinstance(env, gym.wrappers.TimeLimit)
    env = env.env

    env = pfrl.wrappers.CastObservationToFloat32(env)

    return env

def make_batch_env(test):
    return pfrl.envs.MultiprocessVectorEnv(
        [
            functools.partial(make_env, idx, test)
            for idx, env in enumerate(range(num_envs))
        ]
    )

In [9]:
sample_env = make_env(process_idx=0, test=False)

timestep_limit = sample_env.spec.max_episode_steps
obs_space = sample_env.observation_space
action_space = sample_env.action_space
print("Observation space:", obs_space)
print("Action space:", action_space)

Starting new Robot Server | Tentative 1
Successfully started Robot Server at localhost:58189


AssertionError: 

In [None]:
rescale = 128
obs_size = obs_space['agent_pose'].low.size + rescale**2
action_size = action_space.low.size

In [None]:
def squashed_diagonal_gaussian_head(x):
    assert x.shape[-1] == action_size * 2
    mean, log_scale = torch.chunk(x, 2, dim=1)
    log_scale = torch.clamp(log_scale, -20.0, 2.0)
    var = torch.exp(log_scale * 2)
    base_distribution = distributions.Independent(
        distributions.Normal(loc=mean, scale=torch.sqrt(var)), 1
    )
    # cache_size=1 is required for numerical stability
    return distributions.transformed_distribution.TransformedDistribution(
        base_distribution, [distributions.transforms.TanhTransform(cache_size=1)]
    )

policy = nn.Sequential(
    nn.Linear(obs_size, 256),
    nn.ReLU(),
    nn.Linear(256, 256),
    nn.ReLU(),
    nn.Linear(256, action_size * 2),
    Lambda(squashed_diagonal_gaussian_head),
)
torch.nn.init.xavier_uniform_(policy[0].weight)
torch.nn.init.xavier_uniform_(policy[2].weight)
torch.nn.init.xavier_uniform_(policy[4].weight, gain=1.0)

policy_optimizer = torch.optim.Adam(policy.parameters(), lr=3e-4)

In [None]:
def make_q_func_with_optimizer():
    q_func = nn.Sequential(
        pfrl.nn.ConcatObsAndAction(),
        nn.Linear(obs_size + action_size, 256),
        nn.ReLU(),
        nn.Linear(256, 256),
        nn.ReLU(),
        nn.Linear(256, 1),
    )
    torch.nn.init.xavier_uniform_(q_func[1].weight)
    torch.nn.init.xavier_uniform_(q_func[3].weight)
    torch.nn.init.xavier_uniform_(q_func[5].weight)
    q_func_optimizer = torch.optim.Adam(q_func.parameters(), lr=3e-4)
    return q_func, q_func_optimizer

q_func1, q_func1_optimizer = make_q_func_with_optimizer()
q_func2, q_func2_optimizer = make_q_func_with_optimizer()

rbuf = replay_buffers.ReplayBuffer(10 ** 6)

In [None]:
gamma = 0.99
replay_start_size = 1000
gpu = -1
batch_size = 64
burnin_action_func = action_space.sample
entropy_target = -action_size
temperature_optimizer_lr=3e-4

In [None]:
agent = pfrl.agents.SoftActorCritic(
    policy,
    q_func1,
    q_func2,
    policy_optimizer,
    q_func1_optimizer,
    q_func2_optimizer,
    rbuf,
    gamma=gamma,
    replay_start_size=replay_start_size,
    gpu=gpu,
    minibatch_size=batch_size,
    burnin_action_func=burnin_action_func,
    entropy_target=entropy_target,
    temperature_optimizer_lr=temperature_optimizer_lr,
)