In [1]:
import random
from collections import deque

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F

In [2]:
from Models import actor
from Models import critic

In [3]:
# Run on the GPU when PyTorch can see one; otherwise fall back to the CPU.
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')

In [15]:
class Agent():
    """D4PG-style agent holding local/target actor and critic networks.

    This cell implements construction and action selection (`act`); learning
    and target-network soft updates are not defined here.
    """

    def __init__(self, n_states = 33, n_actions = 4, actor_hidden = 100, 
                 critic_hidden = 100, seed = 0, roll_out = 5, replay_buffer_size = 1e6, 
                 replay_batch = 128, lr_actor = 5e-5,  lr_critic = 5e-5, epsilon = 0.3, 
                 tau = 1e-3,  gamma = 1, update_interval = 4, noise_fn = np.random.normal):
        """Build the networks, their optimizers and the replay memory.

        Args:
            n_states: size of the state vector fed to actor and critic.
            n_actions: size of the action vector produced by the actor.
            actor_hidden: hidden units in the 1st layer of the actor network.
            critic_hidden: hidden units in the 1st layer of the critic network.
            seed: seed passed to the `actor` / `critic` constructors.
            roll_out: number of steps for n-step bootstrapping.
            replay_buffer_size: replay memory capacity; stored as an int so
                float literals such as 1e6 are accepted.
            replay_batch: number of memories sampled per training step.
            lr_actor: Adam learning rate for the local actor.
            lr_critic: Adam learning rate for the local critic.
            epsilon: scale applied to the exploration noise in `act`.
            tau: coefficient for soft updates of the target networks.
            gamma: discount factor.
            update_interval: steps between successive learning updates.
            noise_fn: callable accepting a `size=` keyword and returning an
                array of that shape (default: standard normal samples).
        """
        self.n_states = n_states
        self.n_actions = n_actions
        self.actor_hidden = actor_hidden # hidden nodes in the 1st layer of actor network
        self.critic_hidden = critic_hidden # hidden nodes in the 1st layer of critic network
        self.seed = seed
        self.roll_out = roll_out # roll out steps for n-step bootstrap; taken to be same as in D4PG paper
        # capacity must be an integer (the 1e6 default is a float literal)
        self.replay_buffer = int(replay_buffer_size)
        self.replay_batch = replay_batch # batch of memories to sample during training
        self.lr_actor = lr_actor # this was taken to same as the value in the D4PG paper for hard tasks
        self.lr_critic = lr_critic # taken from the D4PG paper
        self.epsilon = epsilon # to scale the noise before mixing with the actions; same as in D4PG paper
        self.tau = tau # for soft updates of the target networks
        # discount factor; intentionally kept at 1 (undiscounted).
        # We want the reacher to stay in the goal position as long as possible,
        # and keeping gamma = 1 encourages the agent to increase its holding time.
        self.gamma = gamma
        self.update_every = update_interval # steps between successive updates
        # noise function; the D4PG paper reported that using a normal distribution
        # instead of OU noise does not affect performance (OU noise may be tried later)
        self.noise = noise_fn

        self.local_actor = actor(self.n_states, self.n_actions, self.actor_hidden, self.seed).to(device)
        self.local_critic = critic(self.n_states, self.n_actions, self.critic_hidden, self.seed).to(device)

        self.target_actor = actor(self.n_states, self.n_actions, self.actor_hidden, self.seed).to(device)
        self.target_critic = critic(self.n_states, self.n_actions, self.critic_hidden, self.seed).to(device)

        # start the target networks as exact copies of the local networks
        self.target_critic.load_state_dict(self.local_critic.state_dict())
        self.target_actor.load_state_dict(self.local_actor.state_dict())

        # optimizers for the local actor and local critic
        self.actor_optim = torch.optim.Adam(self.local_actor.parameters(), lr = self.lr_actor)
        self.critic_optim = torch.optim.Adam(self.local_critic.parameters(), lr = self.lr_critic)

        # steps counter to keep track of steps passed between updates
        self.t_steps = 0

        # replay memory; NOTE: ReplayBuffer is defined in a later cell, which
        # must be executed before an Agent is instantiated
        self.memory = ReplayBuffer()

    def act(self, states, add_noise = True):
        """Return actions for the given states, optionally with exploration noise.

        Args:
            states: array-like of states; for the multi-agent case a batch of
                states (presumably shape (n_agents, n_states) - TODO confirm).
            add_noise: when True (default), add `epsilon * noise_fn(size=...)`
                to the actor output; when False, return the actor output as-is.

        Returns:
            numpy array with the same shape as the actor output.
        """
        states = torch.as_tensor(states, dtype=torch.float32, device=device)
        self.local_actor.eval()
        try:
            with torch.no_grad():
                actions = self.local_actor(states).cpu().numpy()
        finally:
            # restore training mode even if the forward pass raises
            self.local_actor.train()
        if add_noise:
            # epsilon scales the noise before it is mixed with the actions
            noise = self.noise(size = actions.shape)
            actions = actions + self.epsilon * noise
        return actions
            
           

In [16]:
class ReplayBuffer():
    """Fixed-capacity FIFO memory of experience tuples with uniform sampling.

    Experiences are stored as plain tuples, so any layout (single transitions
    or n-step roll-outs) can be kept. Once `buffer_size` entries are stored,
    adding a new one discards the oldest.
    """

    def __init__(self, buffer_size = int(1e6), batch_size = 128, seed = 0):
        """
        Args:
            buffer_size: maximum number of experiences kept (cast to int, so
                float values such as 1e6 are accepted).
            batch_size: default number of experiences returned by `sample`.
            seed: seed for the buffer's private random generator.
        """
        self.buffer_size = int(buffer_size)
        self.batch_size = batch_size
        self.memory = deque(maxlen=self.buffer_size)
        # private RNG: reproducible sampling without touching global `random` state
        self.rng = random.Random(seed)

    def add(self, *experience):
        """Store the positional arguments as one experience tuple.

        A call with no arguments stores nothing.
        """
        if experience:
            self.memory.append(experience)

    def sample(self, batch_size = None):
        """Return a list of stored experiences drawn uniformly without replacement.

        Args:
            batch_size: number of experiences wanted; defaults to the value
                given at construction.

        Returns:
            A list of at most `batch_size` experience tuples. It holds fewer when
            the buffer holds fewer, and is [] when the buffer is empty.
        """
        k = self.batch_size if batch_size is None else batch_size
        return self.rng.sample(self.memory, min(k, len(self.memory)))

    def __len__(self):
        """Number of experiences currently stored."""
        return len(self.memory)