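# Crawler

Train a DDPG agent (via `ddpg_learning.ddpg`) on the Unity ML-Agents Crawler environment, plot the per-episode training scores, and then watch the trained agent run one episode.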

In [None]:
import numpy as np
import matplotlib.pyplot as plt
import torch

from mlagents_envs.environment import UnityEnvironment
from gym_unity.envs import UnityToGymWrapper
import numpy as np

from agent import Agent
from ddpg_learning import ddpg

In [None]:
# Initialize the Environment
# Forward slashes keep the relative path portable; the original backslash form
# only resolved on Windows. NOTE(review): point this at the build for your platform.
UNITY_ENV_PATH = "Crawler/UnityEnvironment.app"

unity_env = UnityEnvironment(file_name=UNITY_ENV_PATH)
env = UnityToGymWrapper(unity_env)

# The networks are built from these sizes, so read them from the wrapped
# environment instead of hard-coding them (20 actions / 172 observations for the
# build this notebook was originally written against).
assert len(env.observation_space.shape) == 1, "expected a flat vector observation"
action_size = env.action_space.shape[0]
state_size = env.observation_space.shape[0]

# Number of agents controlled through the gym wrapper (single-agent setup).
num_agents = 1


In [None]:
# Quick check: displays True when PyTorch can see a CUDA GPU (the agent cell selects "cpu" otherwise).
torch.cuda.is_available()

In [None]:
# Initialize the Agent with the given hyperparameters

BUFFER_SIZE = int(5e5)     # replay buffer size
BATCH_SIZE = 128           # batch size
GAMMA = 0.99               # discount factor
TAU = 0.01                 # for soft update of target parameters
BETA = 0.4                 # initial `beta` passed to the agent (presumably the prioritized-replay importance-sampling exponent — confirm in agent.py)
BETA_INCREMENT = 0.000001  # increment applied to `beta` (presumably per learning step — confirm in agent.py)
LR_ACTOR = 1e-4            # learning rate of the actor
LR_CRITIC = 1e-3           # learning rate of the critic
UPDATE_EVERY = 1           # how often to update the network
SEED = 42                  # random seed handed to the agent

# Set to an episode number (e.g. 1000) to resume training from
# checkpoints/checkpoint_<network>_<episode>.pth; None starts from scratch.
RESUME_FROM_EPISODE = None

device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")  # Torch device to use

agent = Agent(state_size=state_size,
              action_size=action_size,
              num_agents=num_agents,
              buffer_size=BUFFER_SIZE,
              batch_size=BATCH_SIZE,
              gamma=GAMMA,
              tau=TAU,
              beta=BETA,
              beta_increment=BETA_INCREMENT,
              learning_rate_actor=LR_ACTOR,
              learning_rate_critic=LR_CRITIC,
              device=device,
              update_every=UPDATE_EVERY,
              random_seed=SEED)

if RESUME_FROM_EPISODE is not None:
    for network_name in ("actor_local", "critic_local", "actor_target", "critic_target"):
        checkpoint_path = f"checkpoints/checkpoint_{network_name}_{RESUME_FROM_EPISODE}.pth"
        # map_location lets checkpoints saved on a GPU load on a CPU-only machine.
        getattr(agent, network_name).load_state_dict(torch.load(checkpoint_path, map_location=device))

In [None]:
# Train the agent with DDPG. AVERAGE_SCORE_SOLVED is handed to `ddpg` as the
# "solved" threshold; `ddpg` returns the score history and the episode count.
AVERAGE_SCORE_SOLVED = 2000

scores, num_episodes_solved = ddpg(
    env=env,
    agent=agent,
    num_agents=num_agents,
    average_score_solved=AVERAGE_SCORE_SOLVED,
)


In [None]:
# Plot the training session (the per-episode scores returned by `ddpg`).

def plot_scores(scores, title='Scores per episode', color='royalblue', rolling_window=None):
    """Draw the score of every training episode as a line chart.

    Args:
        scores: sequence of per-episode scores, e.g. the list returned by ``ddpg``.
        title: figure title.
        color: colour of the per-episode score line.
        rolling_window: if given (and no longer than ``scores``), also overlay the
            mean of the last ``rolling_window`` episodes, which makes the trend
            towards a "solved" average easier to see.

    Returns:
        The matplotlib Axes the scores were drawn on.
    """
    fig, ax = plt.subplots()
    ax.plot(scores, color=color)
    if rolling_window and len(scores) >= rolling_window:
        kernel = np.ones(rolling_window) / rolling_window
        rolling_mean = np.convolve(scores, kernel, mode='valid')
        # The i-th 'valid' mean covers episodes i .. i + rolling_window - 1; plot it at its last episode.
        episodes = np.arange(rolling_window - 1, len(scores))
        ax.plot(episodes, rolling_mean, color='darkorange', label=f'{rolling_window}-episode mean')
        ax.legend()
    ax.set(title=title, xlabel='episode #', ylabel='score')
    plt.show()
    return ax

plot_scores(scores);


In [None]:
# See the trained agent in action: run one episode without exploration noise.

# Set to True to evaluate saved actor weights instead of the freshly trained agent.
USE_SAVED_CHECKPOINT = False
SAVED_ACTOR_CHECKPOINT = 'solved_checkpoint_actor.pth'

if USE_SAVED_CHECKPOINT:
    # Only the local actor is used to select actions below.
    agent.actor_local.load_state_dict(torch.load(SAVED_ACTOR_CHECKPOINT, map_location=device))

# `env` is a UnityToGymWrapper, so use the gym API: reset() -> observation,
# step(action) -> (observation, reward, done, info). The old brain-based API
# (`env_info[brain_name]`, `vector_observations`, `local_done`) does not apply here.
observation = env.reset()
eval_scores = np.zeros(num_agents)  # separate name so the training `scores` stay plottable
while True:
    # NOTE(review): agent.act is assumed to take a (num_agents, state_size) batch and
    # return (num_agents, action_size) actions, as the original brain-API code did — confirm in agent.py.
    states = np.reshape(observation, (num_agents, state_size))
    actions = agent.act(states, add_noise=False)
    observation, reward, done, _ = env.step(np.reshape(actions, -1))  # single agent: flatten to (action_size,)
    eval_scores += reward
    if done:
        break

print("Average Score: {}".format(np.mean(eval_scores)))


In [None]:
# Close the Unity environment once training and evaluation are finished.
env.close()