In [None]:
# Import libraries
import math
import random
import numpy as np
import matplotlib as mpl
import matplotlib.pyplot as plt
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
import time

from torch.distributions import Normal
from mlagents_envs.environment import UnityEnvironment
from IPython.display import clear_output

In [None]:
# Check for CUDA
# Prefer the GPU when PyTorch reports one, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(device)

# Neural Network Architecture

In [None]:
# Function to initialize weights of NN from normal distribution
def init_weights(m):
    """Initialise an ``nn.Linear`` layer in place; leave any other module untouched.

    Weights are drawn from a normal distribution (mean 0.0, std 0.1) and the
    bias, when the layer has one, is filled with the constant 0.5.  The
    function is meant to be handed to ``nn.Module.apply``.

    Args:
        m: Any ``nn.Module``; only ``nn.Linear`` instances are modified.
    """
    if not isinstance(m, nn.Linear):
        return
    nn.init.normal_(m.weight, mean=0.0, std=0.1)
    # A layer built with bias=False has ``bias`` set to None; skip it instead of crashing.
    if m.bias is not None:
        nn.init.constant_(m.bias, 0.5)

In [None]:
# Actor-Critic Neural Network
class ActorCritic(nn.Module):
    """Two independent MLPs (critic and actor) fed with the same observation.

    The critic maps an observation to a single value.  The actor maps it to
    ``outputs`` action means squashed by a final ``Tanh``.  The spread of the
    action distribution comes from ``log_std``, a learnable parameter that does
    not depend on the observation.

    Args:
        inputs: Size of the observation vector.
        outputs: Number of action components.
        hidden_size: Width of the first hidden layer; the following hidden
            layers use ``int(hidden_size / 2)`` units.
        std: Initial value of every ``log_std`` entry.  Despite the name this
            is the *log* of the standard deviation, so the default 0 yields a
            standard deviation of ``exp(0) = 1``.
    """

    def __init__(self, inputs, outputs, hidden_size, std=0):
        super().__init__()
        half = int(hidden_size / 2)

        # The critic is built before the actor and the layer order inside each
        # Sequential is unchanged, so parameter names (critic.0.weight, ...)
        # and the order in which layers are initialised stay the same.
        critic_layers = [
            nn.Linear(inputs, hidden_size),
            nn.LayerNorm(hidden_size),
            nn.Tanh(),
            nn.Linear(hidden_size, half),
            nn.Tanh(),
            nn.Linear(half, half),
            nn.LayerNorm(half),
            nn.Tanh(),
            nn.Linear(half, 1),
        ]
        self.critic = nn.Sequential(*critic_layers)

        actor_layers = [
            nn.Linear(inputs, hidden_size),
            nn.LayerNorm(hidden_size),
            nn.Tanh(),
            nn.Linear(hidden_size, half),
            nn.Tanh(),
            nn.Linear(half, half),
            nn.LayerNorm(half),
            nn.Tanh(),
            nn.Linear(half, outputs),
            nn.Tanh(),
        ]
        self.actor = nn.Sequential(*actor_layers)

        self.log_std = nn.Parameter(torch.ones(outputs) * std)
        # Re-initialise every Linear layer created above.
        self.apply(init_weights)

    def forward(self, x):
        """Return ``(dist, value)`` for the observation batch ``x``.

        ``dist`` is a ``Normal`` whose mean is the actor output and whose
        standard deviation is ``exp(log_std)`` broadcast to the mean's shape;
        ``value`` is the critic output.
        """
        mean = self.actor(x)
        value = self.critic(x)
        spread = torch.exp(self.log_std).expand_as(mean)
        return Normal(mean, spread), value

# Plot

In [None]:
# Function to plot rewards
def plot(frame_idx, rewards):
    """Redraw the reward curve in place of the previous cell output.

    Args:
        frame_idx: Current frame counter, shown in the title.
        rewards: Sequence of rewards to plot; its last element is shown in the
            title, so it must not be empty.
    """
    # wait=True: the old output is only cleared once the new figure is ready.
    clear_output(wait=True)
    fig = plt.figure(figsize=(40, 8))
    # Left-most panel of a 1x3 grid.
    axes = fig.add_subplot(131)
    axes.set_title('Frame %s. reward: %s' % (frame_idx, rewards[-1]))

    axes.plot(rewards)
    plt.show()

# Run NN: evaluate the current policy in the environment

In [None]:
# Test of environment
def test_env(max_frames=1000):
    """Run one evaluation episode with the current policy and return its summed reward.

    Uses the notebook globals ``env``, ``behaviorName``, ``model`` and
    ``device``.  The environment is reset first; actions are then sampled from
    the policy distribution until the environment reports a terminal step or
    ``max_frames`` environment steps have been taken.  The reward observed on
    the step that hits ``max_frames`` is not added to the total.

    Args:
        max_frames: Upper bound on the number of environment steps.

    Returns:
        The accumulated reward: ``0`` if nothing was collected, otherwise the
        result of adding the reported reward values onto 0 (presumably a NumPy
        array with one entry per agent - confirm against the ML-Agents version
        in use).

    Raises:
        RuntimeError: If no agent requests a decision right after the reset,
            so there is no observation to feed to the policy.
    """
    env.reset()
    decision_steps, _ = env.get_steps(behaviorName)
    if len(decision_steps) == 0:
        raise RuntimeError(
            "No agent requested a decision after env.reset(); cannot start the test episode."
        )
    state = decision_steps.obs[0]

    total_reward = 0
    frame_count = 0
    # Pure evaluation: no gradients are needed, so skip building autograd graphs.
    with torch.no_grad():
        while True:
            frame_count += 1
            # as_tensor accepts both a fresh NumPy observation and a tensor kept
            # from the previous iteration (when no new decision step arrived).
            state = torch.as_tensor(state, dtype=torch.float32, device=device)
            dist, _ = model(state)
            action = dist.sample()
            env.set_actions(behaviorName, np.array(action.cpu()))
            env.step()
            decision_steps, terminal_steps = env.get_steps(behaviorName)

            # Frame budget exhausted: stop without counting this step's reward.
            if frame_count >= max_frames:
                break
            if len(decision_steps) > 0:
                state = decision_steps.obs[0]
                total_reward += decision_steps.reward
            if len(terminal_steps) > 0:
                # Episode finished: add the final reward and stop.
                total_reward += terminal_steps.reward
                break
    return total_reward

# Generalized Advantage Estimator

In [None]:
# Compute GAE
def compute_gae(next_value, rewards, masks, vals,
               gamma=0.99, tau=0.95):
    """Compute Generalized Advantage Estimation returns for one rollout.

    For every step, walking backwards through the rollout:

        delta = r_t + gamma * V_{t+1} * mask_t - V_t
        gae   = delta + gamma * tau * mask_t * gae
        return_t = gae + V_t

    Args:
        next_value: Value estimate for the state following the last step.
        rewards: Per-step rewards.
        masks: Per-step multipliers; a 0 stops both the bootstrap from the
            next value and the propagation of the running advantage.
        vals: Per-step value estimates (any sequence; it is not modified).
        gamma: Discount factor.
        tau: GAE smoothing factor.

    Returns:
        A list with one return per step, in the same order as ``rewards``
        (empty when ``rewards`` is empty).
    """
    # Work on a copy extended with the bootstrap value so vals[step + 1] is always valid.
    vals = list(vals) + [next_value]
    gae = 0
    reversed_returns = []
    for step in reversed(range(len(rewards))):
        delta = rewards[step] + gamma * vals[step + 1] * masks[step] - vals[step]
        gae = delta + gamma * tau * gae * masks[step]
        # Append and reverse once at the end rather than insert(0, ...) on every step.
        reversed_returns.append(gae + vals[step])
    reversed_returns.reverse()
    return reversed_returns

# Proximal Policy Optimization

In [None]:
# Batch sampling
def ppo_iter(mini_batch_size, states, actions, log_probs,
            returns, advantage):
    """Yield random mini-batches taken from one rollout.

    ``states.size(0) // mini_batch_size`` batches are produced.  Row indices
    are drawn independently for every batch with ``np.random.randint``, i.e.
    with replacement, so a row may appear several times or not at all.

    Yields:
        A 5-tuple ``(states, actions, log_probs, returns, advantage)`` in which
        every tensor holds the same randomly selected rows.
    """
    total_rows = states.size(0)
    fields = (states, actions, log_probs, returns, advantage)
    remaining = total_rows // mini_batch_size
    while remaining > 0:
        picked = np.random.randint(0, total_rows, mini_batch_size)
        yield tuple(field[picked, :] for field in fields)
        remaining -= 1

In [None]:
# Update PPO weights
def ppo_update(ppo_epochs, mini_batch_size, states, actions, 
               log_probs, returns, advantages, clip_param=0.2):
    """Run the clipped-surrogate PPO update on one rollout.

    Uses the notebook globals ``model``, ``optimizer`` and ``ppo_iter``.  For
    every mini-batch the loss is

        0.5 * critic_loss + actor_loss - 0.001 * entropy

    where ``actor_loss`` is the negated mean of the clipped surrogate
    objective and ``critic_loss`` is the mean squared error between the
    returns and the predicted values.

    Args:
        ppo_epochs: Number of passes over the rollout.
        mini_batch_size: Rows per mini-batch.
        states, actions, log_probs, returns, advantages: Rollout tensors with
            one row per collected step.
        clip_param: Half-width of the interval the probability ratio is
            clipped to around 1.

    Returns:
        A 0-d tensor, detached from the autograd graph, holding the mean loss
        over all mini-batch updates (zero when no mini-batch was produced,
        e.g. when the rollout has fewer rows than ``mini_batch_size``).
    """
    # Detached running sum, so past autograd graphs are not kept referenced.
    loss_sum = torch.zeros((), device=states.device)
    n_updates = 0
    for _ in range(ppo_epochs):
        for state, action, old_log_probs, return_, advantage in ppo_iter(mini_batch_size, states, actions, log_probs, returns, advantages):
            dist, value = model(state)
            entropy = dist.entropy().mean()
            new_log_probs = dist.log_prob(action)

            # Probability ratio between the updated and the data-collecting policy.
            ratio = (new_log_probs - old_log_probs).exp()
            surr1 = ratio * advantage
            surr2 = torch.clamp(ratio, 1.0 - clip_param, 1.0 + clip_param) * advantage

            actor_loss = -torch.min(surr1, surr2).mean()
            critic_loss = (return_ - value).pow(2).mean()

            loss = 0.5 * critic_loss + actor_loss - 0.001 * entropy

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            loss_sum += loss.detach()
            n_updates += 1

    # Always return a tensor so callers can use .item(), even with zero updates.
    if n_updates == 0:
        return loss_sum
    return loss_sum / n_updates

# Start learning

In [None]:
# Connect to Unity
# Tries to attach to a Unity instance on base port 5004 (file_name=None means
# no executable is launched from here).  Failed attempts are retried after a
# one-second pause, up to 12 attempts in total.  On success the globals
# `env`, `behaviorNames`, `behaviorName` and `behavior_spec` are set.
n = True
counter = 0
while(n):
    candidate_env = None
    try:
        print('Connecting...')
        candidate_env = UnityEnvironment(file_name=None, base_port=5004)
        candidate_env.reset()
        behaviorNames = list(candidate_env.behavior_specs.keys())
        behaviorName = behaviorNames[0]
        behavior_spec = candidate_env.behavior_specs[behaviorName]
        # Publish the environment only once it is fully usable.
        env = candidate_env
        print('Connected...')
        n = False
    except Exception as connect_error:
        # `Exception` (not a bare except) so Ctrl-C / kernel interrupt still stops the loop.
        if candidate_env is not None:
            # Release a half-initialised environment so the next attempt can reuse the port.
            try:
                candidate_env.close()
            except Exception:
                pass  # best-effort cleanup; the original failure is what gets reported
        counter += 1
        time.sleep(1)
        if (counter > 11):
            print('Connection failed...')
            print('Last error: {!r}'.format(connect_error))
            n = False

In [None]:
# Inputs and outputs of NN
# Network input size: leading dimension of the first observation shape.
# Network output size: the behavior's action shape.
num_inputs, num_outputs = behavior_spec.observation_shapes[0][0], behavior_spec.action_shape
print("Inputs: {}, Outputs: {}".format(num_inputs, num_outputs))

In [None]:
# Hyper parameters of NN

# Network and optimizer
hidden_size = 256        # width of the first hidden layer passed to ActorCritic
lr = 0.0001              # learning rate handed to Adam

# PPO rollout / update
num_steps = 20           # environment steps collected before every PPO update
mini_batch_size = 5      # rows per mini-batch inside ppo_update
ppo_epochs = 4           # passes over each rollout inside ppo_update

# Training length and evaluation
max_frames = 150000      # upper bound on training steps
max_steps = 800          # step limit passed to test_env during evaluation
threshold_reward = 185   # training stops early once the mean test reward exceeds this
frame_idx = 0            # step counter (set to 0 again when training starts)

In [None]:
# Create NN
# Build the actor-critic network, move it to the selected device and attach an Adam optimizer.
model = ActorCritic(inputs=num_inputs, outputs=num_outputs, hidden_size=hidden_size)
model = model.to(device)
# Uncomment to start from saved weights instead of a fresh initialisation:
#model.load_state_dict(torch.load('PPO_Good.dat'))
optimizer = optim.Adam(params=model.parameters(), lr=lr)

In [None]:
# Get states and actions
# Sanity check: reset the environment, read the first observation, sample one
# action from the untrained policy and print what came back.
env.reset()
step_result = env.get_steps(behaviorName)
DecisionSteps = step_result[0]
state, reward = DecisionSteps.obs[0], DecisionSteps.reward
dist, value = model(torch.FloatTensor(state).to(device))
action = dist.sample()
print(action)
print("Reward: {}".format(reward))
# Components of the first observation row, each followed by ', '.
print(''.join(str(s) + ', ' for s in state[0]), end='')

In [None]:
# Start learning process
# Each pass of the outer loop collects `num_steps` transitions, turns them into
# GAE returns and runs one PPO update.  Training ends when `max_frames` steps
# have been taken or an evaluation exceeds `threshold_reward`.
env.reset()
early_stop = False
rr = []  # receives the same evaluation rewards as test_rewards
test_rewards = []
mean_loss = []  # one entry per PPO update: the value returned by ppo_update
frame_idx = 0
while frame_idx < max_frames and not early_stop:
    # Per-rollout buffers, rebuilt for every update.
    log_probs = []
    values = []
    states = []
    actions = []
    rewards = []
    masks = []
    entropy = 0  # NOTE(review): accumulated below but never read afterwards
    
    for _ in range(num_steps):
        # Read the current observation straight from the environment.
        step_result = env.get_steps(behaviorName)
        DecisionSteps = step_result[0]
        TerminalSteps = step_result[1]
        
        # Prefer agents awaiting a decision; otherwise fall back to the terminal observation.
        state = []
        if (len(DecisionSteps) > 0):
            state = DecisionSteps.obs[0]
        elif (len(TerminalSteps) > 0):
            state = TerminalSteps.obs[0]
        
        state = torch.FloatTensor(state).to(device)
        dist, value = model(state)
        action = dist.sample()
        if(int(torch.isnan(torch.min(action))) == 1): # the sampled action contains NaN
            # NOTE(review): indexes action[0][1], so this report assumes at least two action components.
            print("Error: distribution=", dist, "%.2f, %.2f" % (float(action[0][0]), float(action[0][1])))
            
        env.set_actions(behaviorName, np.array(action.cpu()))
        env.step()
        
        # Outcome of the step that was just taken.
        step_result = env.get_steps(behaviorName)
        DecisionSteps = step_result[0]
        TerminalSteps = step_result[1]
        
        next_state = []
        reward = []
        mask = 0  # NOTE(review): stays a plain 0 (not a list) if neither step set below is populated
        
        if (len(DecisionSteps) > 0):
            next_state = DecisionSteps.obs[0]
            reward = DecisionSteps.reward
            mask = [1.0]  # episode continues: compute_gae keeps bootstrapping
        if (len(TerminalSteps) > 0):
            # A terminal step overrides the decision-step values read above.
            next_state = TerminalSteps.obs[0]
            reward = TerminalSteps.reward
            mask = [0.0]  # episode ended: compute_gae multiplies the bootstrap by 0
        # NOTE(review): the mask always has one element - presumably a single agent is assumed; confirm.
        
        log_prob = dist.log_prob(action)
        entropy += dist.entropy().mean()
        
        log_probs.append(log_prob)
        values.append(value)
        rewards.append(torch.FloatTensor(reward).unsqueeze(1).to(device))
        masks.append(torch.FloatTensor(mask).unsqueeze(1).to(device))
        
        states.append(state)
        actions.append(action)
        
        state = next_state  # overwritten at the top of the next iteration
        frame_idx += 1
            
        # Every 800 frames: average 10 evaluation episodes, plot them and check the
        # early-stop threshold.  test_env resets the environment, so the rollout
        # continues from a fresh reset afterwards.
        if frame_idx % 800 == 0:
            test_reward = np.mean([test_env(max_steps) for _ in range(10)])
            test_rewards.append(test_reward)
            plot(frame_idx, test_rewards)
            rr.append(test_reward)
            if test_reward > threshold_reward: early_stop = True
            
    # Bootstrap value for the state that follows the last collected step.
    next_state = torch.FloatTensor(next_state).to(device)
    _, next_value = model(next_state)
    returns = compute_gae(next_value, rewards, masks, values)
    
    # Concatenate the per-step lists into batch tensors; detach the ones that
    # came out of the model so the PPO update does not backpropagate through the rollout.
    returns = torch.cat(returns).detach()
    rewards = torch.cat(rewards).detach()
    log_probs = torch.cat(log_probs).detach()
    values = torch.cat(values).detach()
    states = torch.cat(states)
    actions = torch.cat(actions)
    advantage = returns - values
    m_loss = ppo_update(ppo_epochs, mini_batch_size, states, actions, log_probs, returns, advantage)
    mean_loss.append(m_loss.item())
    # NOTE(review): no visible cell saves the trained weights, yet the test section loads 'PPO_Good.dat'.

In [None]:
# Loss recorded after every PPO update during training.
loss_fig, loss_ax = plt.subplots(figsize=(12, 6))
loss_ax.plot(mean_loss)
loss_ax.set_title('Loss')

In [None]:
# Reset the Unity environment after the training loop.
env.reset()

In [None]:
# Close the connection to Unity; the test section below opens a new one on the same base port (5004).
env.close()

# Test

In [None]:
# Connect to Unity
# Re-attaches to a Unity instance on base port 5004 for the test run
# (file_name=None means no executable is launched from here).  Failed attempts
# are retried after a one-second pause, up to 12 attempts in total.  On success
# the globals `env`, `behaviorNames`, `behaviorName` and `behavior_spec` are set.
n = True
counter = 0
while(n):
    candidate_env = None
    try:
        print('Connecting...')
        candidate_env = UnityEnvironment(file_name=None, base_port=5004)
        candidate_env.reset()
        behaviorNames = list(candidate_env.behavior_specs.keys())
        behaviorName = behaviorNames[0]
        behavior_spec = candidate_env.behavior_specs[behaviorName]
        # Publish the environment only once it is fully usable.
        env = candidate_env
        print('Connected...')
        n = False
    except Exception as connect_error:
        # `Exception` (not a bare except) so Ctrl-C / kernel interrupt still stops the loop.
        if candidate_env is not None:
            # Release a half-initialised environment so the next attempt can reuse the port.
            try:
                candidate_env.close()
            except Exception:
                pass  # best-effort cleanup; the original failure is what gets reported
        counter += 1
        time.sleep(1)
        if (counter > 11):
            print('Connection failed...')
            print('Last error: {!r}'.format(connect_error))
            n = False

In [None]:
# Load states dictionary to model
# NOTE: no visible cell in this notebook writes 'PPO_Good.dat'; the file must already exist on disk.
# NOTE(security): torch.load unpickles the file - only load checkpoints from a trusted source.
model = ActorCritic(num_inputs, num_outputs, hidden_size).to(device)
# map_location lets a checkpoint saved on a GPU load on a CPU-only machine (and vice versa).
model.load_state_dict(torch.load('PPO_Good.dat', map_location=device))

In [None]:
# Test environment
# Run a single evaluation episode of at most 5000 steps with the loaded policy.
# The one-second pause sits between the reset and the start of the episode.
env.reset()
time.sleep(1)
test_env(max_frames=5000)

In [None]:
# Close the connection to Unity now that the test run is finished.
env.close()