In [1]:
import os
import torch as T
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import numpy as np

In [2]:
# critic network
class CriticNetwork(nn.Module):
    """Q-value network Q(s, a) built from two hidden ReLU layers and a scalar head.

    Parameters
    ----------
    beta : float
        Learning rate of this network's Adam optimizer.
    input_dims : sequence of int
        Observation shape; only ``input_dims[0]`` is used, so states must be flat
        (this implementation is not for 2-D state representations).
    fc1_dims, fc2_dims : int
        Widths of the two hidden layers.
    n_actions : int
        Action dimensionality; actions are concatenated onto the state at the input.
    chkpt_name : str
        Base name of the checkpoint file (``'_td3'`` is appended).
    chkpt_dir : str
        Directory holding the checkpoint; created on save if it does not exist.
    """
    def __init__(self, beta, input_dims, fc1_dims, fc2_dims, n_actions, chkpt_name, chkpt_dir='tmp/td3'):
        super(CriticNetwork, self).__init__()
        self.beta = beta
        self.input_dims = input_dims
        self.fc1_dims = fc1_dims
        self.fc2_dims = fc2_dims
        self.n_actions = n_actions
        self.checkpoint_name = chkpt_name
        self.checkpoint_dir = chkpt_dir
        self.checkpoint_file = os.path.join(self.checkpoint_dir, self.checkpoint_name + '_td3')

        # Input layer sees the state and the action side by side.
        self.fc1 = nn.Linear(self.input_dims[0] + n_actions, self.fc1_dims)
        self.fc2 = nn.Linear(self.fc1_dims, self.fc2_dims)
        self.q1 = nn.Linear(self.fc2_dims, 1)

        self.optimizer = optim.Adam(self.parameters(), lr=self.beta)
        self.device = T.device('cuda:0' if T.cuda.is_available() else 'cpu')

        self.to(self.device)

    def forward(self, state, action):
        """Return Q(state, action) with shape (batch, 1).

        ``state`` and ``action`` are concatenated along dim 1, so both must be
        2-D ``(batch, features)`` tensors on ``self.device``.
        """
        hidden = F.relu(self.fc1(T.cat([state, action], dim=1)))
        hidden = F.relu(self.fc2(hidden))
        return self.q1(hidden)

    def save_checkpoint(self):
        """Write this network's ``state_dict`` to ``self.checkpoint_file``."""
        print("Saving critic checkpoint...")
        # T.save does not create missing directories; the default 'tmp/td3'
        # usually does not exist on a fresh checkout.
        if self.checkpoint_dir:
            os.makedirs(self.checkpoint_dir, exist_ok=True)
        T.save(self.state_dict(), self.checkpoint_file)
        print("Critic checkpoint saved!")

    def load_checkpoint(self):
        """Load weights from ``self.checkpoint_file`` onto this network's device."""
        print("Loading critic checkpoint...")
        # map_location lets a checkpoint saved on CUDA be restored on a CPU-only host.
        self.load_state_dict(T.load(self.checkpoint_file, map_location=self.device))
        print("Critic checkpoint loaded!")

In [5]:
# actor network
class ActorNetwork(nn.Module):
    """Deterministic policy network mapping a state to an action in [-1, 1]^n_actions.

    Parameters
    ----------
    alpha : float
        Learning rate of this network's Adam optimizer.
    input_dims : sequence of int
        Observation shape; only ``input_dims[0]`` is used, so states must be flat
        (this implementation is not for 2-D state representations).
    fc1_dims, fc2_dims : int
        Widths of the two hidden layers.
    n_actions : int
        Action dimensionality (width of the tanh output layer).
    chkpt_name : str
        Base name of the checkpoint file (``'_td3'`` is appended).
    chkpt_dir : str
        Directory holding the checkpoint; created on save if it does not exist.
    """
    def __init__(self, alpha, input_dims, fc1_dims, fc2_dims, n_actions, chkpt_name, chkpt_dir='tmp/td3'):
        super(ActorNetwork, self).__init__()
        self.alpha = alpha
        self.input_dims = input_dims
        self.fc1_dims = fc1_dims
        self.fc2_dims = fc2_dims
        self.n_actions = n_actions
        self.checkpoint_name = chkpt_name
        self.checkpoint_dir = chkpt_dir
        self.checkpoint_file = os.path.join(self.checkpoint_dir, self.checkpoint_name + '_td3')

        # Use input_dims[0] (as CriticNetwork does): unpacking *input_dims would pass
        # any extra shape entries to nn.Linear as its bias/device arguments.
        self.fc1 = nn.Linear(self.input_dims[0], self.fc1_dims)
        self.fc2 = nn.Linear(self.fc1_dims, self.fc2_dims)
        self.mu = nn.Linear(self.fc2_dims, self.n_actions)

        self.optimizer = optim.Adam(self.parameters(), lr=self.alpha)
        self.device = T.device('cuda:0' if T.cuda.is_available() else 'cpu')

        self.to(self.device)

    def forward(self, state):
        """Return the deterministic action for ``state``, squashed by tanh into [-1, 1]."""
        hidden = F.relu(self.fc1(state))
        hidden = F.relu(self.fc2(hidden))
        return T.tanh(self.mu(hidden))

    def save_checkpoint(self):
        """Write this network's ``state_dict`` to ``self.checkpoint_file``."""
        print("Saving actor checkpoint...")
        # T.save does not create missing directories; the default 'tmp/td3'
        # usually does not exist on a fresh checkout.
        if self.checkpoint_dir:
            os.makedirs(self.checkpoint_dir, exist_ok=True)
        T.save(self.state_dict(), self.checkpoint_file)
        print("Actor checkpoint saved!")

    def load_checkpoint(self):
        """Load weights from ``self.checkpoint_file`` onto this network's device."""
        print("Loading actor checkpoint...")
        # map_location lets a checkpoint saved on CUDA be restored on a CPU-only host.
        self.load_state_dict(T.load(self.checkpoint_file, map_location=self.device))
        print("Actor checkpoint loaded!")

In [None]:
# buffer
class ReplayBuffer():
    """Fixed-capacity circular buffer of (state, action, reward, new_state, done) transitions.

    Parameters
    ----------
    max_size : int or float
        Capacity; cast to ``int`` so callers may pass values such as ``1e6``.
    input_shape : int or sequence of int
        Shape of a single observation.
    action_shape : int or sequence of int
        Number of action dimensions (an int) or a full action shape.

    Raises
    ------
    ValueError
        If the capacity is not positive.
    """
    def __init__(self, max_size, input_shape, action_shape):
        self.mem_size = int(max_size)
        if self.mem_size <= 0:
            raise ValueError("ReplayBuffer max_size must be positive, got %r" % (max_size,))
        # Total number of transitions ever stored; not capped at mem_size, so
        # `mem_cntr % mem_size` is the next slot to overwrite.
        self.mem_cntr = 0

        state_shape = self._as_shape(input_shape)
        act_shape = self._as_shape(action_shape)
        self.state_memory = np.zeros((self.mem_size, *state_shape), dtype=np.float32)
        self.new_state_memory = np.zeros((self.mem_size, *state_shape), dtype=np.float32)
        self.action_memory = np.zeros((self.mem_size, *act_shape), dtype=np.float32)
        self.reward_memory = np.zeros(self.mem_size, dtype=np.float32)
        self.terminal_memory = np.zeros(self.mem_size, dtype=bool)

    @staticmethod
    def _as_shape(shape):
        """Normalize an int or a sequence of ints into a shape tuple."""
        if np.ndim(shape) == 0:
            return (int(shape),)
        return tuple(int(dim) for dim in shape)

    def store_transition(self, state, action, reward, new_state, done):
        """Store one transition, overwriting the oldest entry once the buffer is full."""
        index = self.mem_cntr % self.mem_size
        self.state_memory[index] = state
        self.new_state_memory[index] = new_state
        self.action_memory[index] = action
        self.reward_memory[index] = reward
        self.terminal_memory[index] = done
        self.mem_cntr += 1

    def sample_buffer(self, batch_size):
        """Sample ``batch_size`` stored transitions uniformly (with replacement).

        Returns
        -------
        tuple of np.ndarray
            ``(states, actions, rewards, new_states, dones)``.

        Raises
        ------
        ValueError
            If nothing has been stored yet.
        """
        filled = min(self.mem_cntr, self.mem_size)
        if filled == 0:
            raise ValueError("cannot sample from an empty replay buffer")
        batch = np.random.choice(filled, batch_size)
        return (self.state_memory[batch],
                self.action_memory[batch],
                self.reward_memory[batch],
                self.new_state_memory[batch],
                self.terminal_memory[batch])

In [None]:
# agent
class Agent():
    """TD3 agent holding an actor, twin critics, their target copies, and a replay buffer.

    Parameters
    ----------
    alpha : float
        Actor learning rate.
    beta : float
        Critic learning rate.
    input_dims : sequence of int
        Observation shape, forwarded to the networks and the replay buffer.
    n_actions : int
        Dimensionality of the continuous action.
    gamma : float
        Discount factor (stored on the agent).
    layer1_size, layer2_size : int
        Hidden-layer widths shared by every actor and critic.
    buffer_size : int or float
        Replay-buffer capacity; cast to ``int``.
    batch_size : int
        Mini-batch size (stored on the agent).
    tau : float
        Soft-update coefficient for target networks; 0.005 as in the TD3 paper.
    udate_actor_interval : int
        Actor update period. The misspelled name is kept for compatibility and is
        also exposed as ``self.update_actor_interval``.
    warmup : int
        Number of initial ``choose_action`` calls that ignore the actor.
    max_size : int or float
        Stored only; the visible code sizes the buffer from ``buffer_size``.
    noise : float
        Standard deviation of the Gaussian exploration noise.
    env : object
        Environment whose ``action_space.low`` / ``.high`` bound the actions. Required.

    Raises
    ------
    ValueError
        If ``env`` is not given.
    """
    def __init__(self, alpha, beta, input_dims, n_actions, gamma=0.99,
                 layer1_size=400, layer2_size=300, buffer_size=1e6, batch_size=100,
                 tau=0.005, udate_actor_interval=2, warmup=10000, max_size=10e3,
                 noise=0.1, env=None):
        # `tau` and `env` used to be required parameters placed after defaulted
        # ones (a SyntaxError). Positional order is preserved; `env` is still
        # mandatory and validated here.
        if env is None:
            raise ValueError("Agent requires an `env` whose action_space bounds the actions")
        self.n_actions = n_actions
        self.gamma = gamma
        self.batch_size = batch_size
        self.tau = tau
        self.udate_actor_interval = udate_actor_interval
        self.update_actor_interval = udate_actor_interval
        self.warmup = warmup
        self.max_size = max_size
        self.noise = noise
        self.max_action = env.action_space.high
        self.min_action = env.action_space.low
        self.learn_step_cntr = 0
        self.time_step = 0  # choose_action calls so far; compared against `warmup`

        self.actor = ActorNetwork(alpha, input_dims, layer1_size, layer2_size, n_actions, "actor")
        self.critic_one = CriticNetwork(beta, input_dims, layer1_size, layer2_size, n_actions, "critic_one")
        self.critic_two = CriticNetwork(beta, input_dims, layer1_size, layer2_size, n_actions, "critic_two")

        self.target_actor = ActorNetwork(alpha, input_dims, layer1_size, layer2_size, n_actions, "target_actor")
        self.target_critic_one = CriticNetwork(beta, input_dims, layer1_size, layer2_size, n_actions, "target_critic_one")
        self.target_critic_two = CriticNetwork(beta, input_dims, layer1_size, layer2_size, n_actions, "target_critic_two")

        # Buffer arrays need an integer capacity; the default 1e6 is a float.
        self.memory = ReplayBuffer(int(buffer_size), input_dims, n_actions)

        # Start every target as an exact copy of its online network (a soft
        # update with tau=1), without depending on a helper defined elsewhere.
        for online, target in ((self.actor, self.target_actor),
                               (self.critic_one, self.target_critic_one),
                               (self.critic_two, self.target_critic_two)):
            target.load_state_dict(online.state_dict())

    def choose_action(self, observation):
        """Return an exploration action for ``observation`` as a NumPy array.

        For the first ``warmup`` calls the base action is zero-mean Gaussian noise
        (std ``noise``); afterwards it is the actor's output. Independent Gaussian
        noise (std ``noise``) is then added to every action dimension, and the
        result is clipped elementwise to ``[min_action, max_action]``.
        """
        device = self.actor.device
        if self.time_step < self.warmup:
            # float32 on the actor's device so it can be combined with the
            # device tensors below (np.random returns float64 on the CPU).
            mu = T.tensor(np.random.normal(scale=self.noise, size=(self.n_actions,)),
                          dtype=T.float, device=device)
        else:
            state = T.tensor(observation, dtype=T.float, device=device)
            with T.no_grad():  # acting only; no graph is needed
                mu = self.actor(state)

        exploration = T.tensor(np.random.normal(scale=self.noise, size=tuple(mu.shape)),
                               dtype=T.float, device=device)
        mu_prime = mu + exploration

        # Clamp each action dimension to its own bounds, not just dimension 0's.
        low = T.as_tensor(self.min_action, dtype=T.float, device=device)
        high = T.as_tensor(self.max_action, dtype=T.float, device=device)
        mu_prime = T.maximum(T.minimum(mu_prime, high), low)

        self.time_step += 1

        return mu_prime.cpu().detach().numpy()

    def remember(self, state, action, reward, new_state, done):
        """Store one transition in the replay buffer."""
        self.memory.store_transition(state, action, reward, new_state, done)