In [None]:
import gym
from gym import spaces
import robo_gym
from robo_gym.wrappers.exception_handling import ExceptionHandling
import numpy as np
import pfrl
import torch
from torch import distributions, nn
import matplotlib.pyplot as plt

In [None]:
class WrapPyTorch(gym.ObservationWrapper):
    """Observation wrapper that prepends a channel axis so maps can be fed to ``nn.Conv2d``.

    The wrapped environment must expose ``map_size``. Observations are declared, and returned,
    as float32 arrays of shape ``(1, map_size, map_size)``.
    """

    def __init__(self, env=None):
        super().__init__(env)
        # NOTE(review): bounds [-1, 100] are taken as given; confirm they match the env's map values.
        self.observation_space = spaces.Box(
            low=-1, high=100, shape=(1, env.map_size, env.map_size), dtype=np.float32
        )

    def observation(self, observation):
        """Add a leading channel axis and cast to the float32 dtype declared in observation_space."""
        # The raw map dtype is not guaranteed to be float32. The conv layers hold float32
        # weights, so cast here to keep observations consistent with the declared space.
        return np.expand_dims(observation, axis=0).astype(np.float32, copy=False)

    def reset(self, **kwargs):
        """Reset the wrapped env (forwarding kwargs such as new_room) and wrap the first observation."""
        return self.observation(self.env.reset(**kwargs))

In [None]:
target_machine_ip = 'robot-server'  # or another machine, e.g. 'xxx.xxx.xxx.xxx'

# Build the environment stack, innermost first:
# robo_gym env -> exception handling -> channel-first observations.
env = WrapPyTorch(
    ExceptionHandling(
        gym.make('CubeSearchInCubeRoomObsMapOnly-v0', ip=target_machine_ip, gui=True)
    )
)

# Options for the very first reset: generate a fresh room and a fresh agent pose.
initial_reset_kwargs = dict(
    new_room=True,
    new_agent_pose=True,
    obstacle_count=40,
    room_length_max=15.0,
    room_mass_min=150.0,
    room_mass_max=160.0,
)
state = env.reset(**initial_reset_kwargs)
print(state)

In [None]:
# Cache the environment's spaces and flattened sizes; later cells use these names.
timestep_limit = env.spec.max_episode_steps
obs_space, action_space = env.observation_space, env.action_space
obs_size, action_size = obs_space.low.size, action_space.low.size

for summary_line in (
    f'timelimit: \t{timestep_limit}',
    f'obs_space: \t{obs_space} \naction_space: \t{action_space}',
    f'obs_size: \t{obs_size}',
    f'action_size: \t{action_size}',
):
    print(summary_line)

In [None]:
def conv2d_size_out(size, kernel_size=5, stride=2, padding=0, dilation=1):
    """Return the length of one spatial axis after an ``nn.Conv2d`` layer.

    Implements PyTorch's Conv2d output-size formula
    ``floor((size + 2*padding - dilation*(kernel_size - 1) - 1) / stride) + 1``.
    With the default ``padding=0, dilation=1`` this reduces to the previous
    ``(size - (kernel_size - 1) - 1) // stride + 1``.

    Args:
        size: Input length along the axis.
        kernel_size: Convolution kernel length along the axis.
        stride: Convolution stride.
        padding: Zero-padding added to both sides of the axis.
        dilation: Spacing between kernel elements.

    Returns:
        The output length as an int. It can be zero or negative when the input is too
        small; callers should validate.
    """
    return (size + 2 * padding - dilation * (kernel_size - 1) - 1) // stride + 1
        
def make_conv2d_layer(width, height):
    """Build the convolutional feature extractor for a single-channel input map.

    Args:
        width: Size of the input's last (width) axis.
        height: Size of the input's second-to-last (height) axis.

    Returns:
        A ``(features, linear_input_size)`` pair. ``features`` is an ``nn.Sequential`` that maps
        a ``(batch, 1, height, width)`` tensor to a flattened ``(batch, linear_input_size)`` tensor.

    Raises:
        ValueError: If ``width`` or ``height`` is too small for the conv stack to produce any
            output, e.g. when a channel axis of size 1 is passed by mistake.
    """
    # One row per conv layer: (in_channels, out_channels, kernel_size, stride).
    # Both the output-size arithmetic and the layers come from this table,
    # so they cannot drift apart.
    conv_specs = (
        (1, 32, 4, 4),   # e.g. 128 -> 32
        (32, 64, 4, 4),  # e.g. 32 -> 8
        (64, 64, 3, 1),  # e.g. 8 -> 6
    )

    convW, convH = width, height
    for _, _, kernel, stride in conv_specs:
        convW = conv2d_size_out(convW, kernel, stride)
        convH = conv2d_size_out(convH, kernel, stride)
    if convW <= 0 or convH <= 0:
        # Fail early with a clear message instead of a negative-dimension error in nn.Linear.
        raise ValueError(
            f'input map of width={width}, height={height} is too small for the conv stack '
            f'(output would be {convW}x{convH})'
        )

    linear_input_size = convW * convH * conv_specs[-1][1]
    print('size:', linear_input_size)

    layers = []
    for in_channels, out_channels, kernel, stride in conv_specs:
        layers += [nn.Conv2d(in_channels, out_channels, kernel_size=kernel, stride=stride), nn.ELU()]
    layers.append(nn.Flatten())
    return nn.Sequential(*layers), linear_input_size

def make_linear_layer(linear_input_size, out_size):
    """Build the fully connected head: two 256-unit ReLU hidden layers, then a linear output.

    Args:
        linear_input_size: Number of input features.
        out_size: Number of outputs of the final (un-activated) linear layer.

    Returns:
        An ``nn.Sequential`` of ``Linear, ReLU, Linear, ReLU, Linear``.
    """
    widths = [linear_input_size, 256, 256]
    layers = []
    for fan_in, fan_out in zip(widths, widths[1:]):
        layers += [nn.Linear(fan_in, fan_out), nn.ReLU()]
    layers.append(nn.Linear(widths[-1], out_size))
    return nn.Sequential(*layers)

In [None]:
def squashed_diagonal_gaussian_head(x):
    """Turn raw network outputs into a tanh-squashed diagonal Gaussian distribution.

    Args:
        x: Tensor whose last axis has ``2 * action_size`` entries. The first half holds the
            Gaussian means; the second half holds the log standard deviations.

    Returns:
        A ``TransformedDistribution``: an independent Normal over the last axis, pushed
        through ``tanh``, so samples lie in (-1, 1).

    Note:
        ``action_size`` is read from the notebook's global scope, not passed by the caller.
    """
    assert x.shape[-1] == action_size * 2
    # Split along the same (last) axis the assert checks, so inputs with any number of
    # leading batch dimensions (including none) are handled consistently.
    mean, log_scale = torch.chunk(x, 2, dim=-1)
    # Bound the standard deviation to [e^-20, e^2] to avoid degenerate or exploding scales.
    log_scale = torch.clamp(log_scale, -20.0, 2.0)
    base_distribution = distributions.Independent(
        distributions.Normal(loc=mean, scale=torch.exp(log_scale)), 1
    )
    # cache_size=1 is required for numerical stability
    return distributions.transformed_distribution.TransformedDistribution(
        base_distribution, [distributions.transforms.TanhTransform(cache_size=1)]
    )

In [None]:
class PolicyFunction(nn.Module):
    """Stochastic policy over a single-channel observation map.

    A conv feature extractor feeds an MLP that emits ``2 * action_size`` values (means and
    log-scales). These are turned into a tanh-squashed diagonal Gaussian over actions.
    NOTE(review): the head's shape assert uses the notebook-global ``action_size``, not
    the ``action_size`` passed here; keep the two equal.
    """

    def __init__(self, width, height, action_size):
        super().__init__()
        features, flat_size = make_conv2d_layer(width, height)
        # Attribute names are kept unchanged because they determine the state_dict keys.
        self.selectTrackFeatures = features
        self.linear_input_size = flat_size
        self.fc1 = make_linear_layer(flat_size, 2 * action_size)

    def forward(self, state):
        """Map a batch of observation maps to an action distribution."""
        encoded = self.selectTrackFeatures(state)
        head_params = self.fc1(encoded)
        return squashed_diagonal_gaussian_head(head_params)

obs_map_shape = obs_space.low.shape
print(obs_map_shape)
# WrapPyTorch declares observations as (1, map_size, map_size), i.e. channel axis first.
# Take width/height from the LAST two axes. Indexing [0]/[1] passed the channel axis (size 1)
# as the width, which made the flattened conv size negative and nn.Linear fail.
map_height, map_width = obs_map_shape[-2], obs_map_shape[-1]
policy = PolicyFunction(map_width, map_height, action_size)
policy_optimizer = torch.optim.Adam(policy.parameters(), lr=3e-4)

In [None]:
class QFunction(nn.Module):
    """State-action value network over a single-channel observation map.

    The map is encoded by the conv stack, the action vector is appended to the flattened
    features, and an MLP regresses a single value.
    """

    def __init__(self, width, height, action_size):
        super().__init__()
        features, flat_size = make_conv2d_layer(width, height)
        # Attribute names are kept unchanged because they determine the state_dict keys.
        self.selectTrackFeatures = features
        self.linear_input_size = flat_size
        self.fc1 = make_linear_layer(flat_size + action_size, 1)

    def forward(self, state_and_action):
        """Score a ``(state, action)`` pair; the output's last axis has size 1."""
        state, action = state_and_action[0], state_and_action[1]
        joint = torch.cat((self.selectTrackFeatures(state), action), dim=-1)
        return self.fc1(joint)

# obs_map_shape is (1, map_size, map_size) with the channel axis first. Take width/height
# from the last two axes so the size-1 channel axis is never treated as a spatial dimension.
q_func1 = QFunction(obs_map_shape[-1], obs_map_shape[-2], action_size)
q_func2 = QFunction(obs_map_shape[-1], obs_map_shape[-2], action_size)
q_func1_optimizer = torch.optim.Adam(q_func1.parameters(), lr=3e-4)
q_func2_optimizer = torch.optim.Adam(q_func2.parameters(), lr=3e-4)

In [None]:
# Replay memory with a capacity of 10**6 entries, shared by the SAC agent below.
replay_capacity = 10 ** 6
rbuf = pfrl.replay_buffers.ReplayBuffer(replay_capacity)

In [None]:
def burnin_action_func():
    """Return a uniformly random float32 action within the global action_space bounds.

    Passed to the agent as its burn-in action source, used before the policy has been updated.
    """
    lower, upper = action_space.low, action_space.high
    random_action = np.random.uniform(lower, upper)
    return random_action.astype(np.float32)

In [None]:
gamma = 0.99
replay_start_size = 10000
# Use the first CUDA device when available. Otherwise pass -1, which pfrl treats as
# "run on CPU", so the notebook no longer crashes on machines without CUDA.
gpu = 0 if torch.cuda.is_available() else -1
batch_size = 256
entropy_target = -action_size  # common SAC heuristic: minus the action dimensionality
temperature_optimizer_lr = 3e-4

agent = pfrl.agents.SoftActorCritic(
    policy,
    q_func1,
    q_func2,
    policy_optimizer,
    q_func1_optimizer,
    q_func2_optimizer,
    rbuf,
    gamma=gamma,
    replay_start_size=replay_start_size,
    gpu=gpu,
    minibatch_size=batch_size,
    burnin_action_func=burnin_action_func,
    entropy_target=entropy_target,
    temperature_optimizer_lr=temperature_optimizer_lr,
)

In [None]:
import logging

logger = logging.getLogger()
logger.setLevel(logging.INFO)
logging.info("Start training")

n_episodes = 1000
max_episode_len = 50  # forced reset after this many steps
eval_step = 10        # log the mean return every `eval_step` episodes
total_R = 0

for i in range(1, n_episodes + 1):
    obs = env.reset(new_room=False, new_agent_pose=True)
    R = 0  # return (sum of rewards) of this episode
    # Step counter t runs 1..max_episode_len; the last step is flagged as a reset.
    for t in range(1, max_episode_len + 1):
        action = agent.act(obs)
        obs, reward, done, _ = env.step(action)
        R += reward
        reset = t == max_episode_len
        agent.observe(obs, reward, done, reset)
        if done or reset:
            break

    total_R += R
    if i % eval_step == 0:
        logging.info(f'episode: {i}, R_mean: {total_R/eval_step} \nstatistics: {agent.get_statistics()}')
        total_R = 0

logging.info('Finished')