In [1]:
import torch
from torch import nn
from torch import optim
import torch.nn.functional as F
from torch.distributions.normal import Normal

from octopus.replay_buffer import ReplayBuffer

In [3]:
class ActorNetwork(nn.Module):
    """Gaussian policy network producing tanh-squashed, scaled actions.

    A two-hidden-layer MLP trunk feeds two linear heads that output the mean
    and the standard deviation of a diagonal Normal distribution.
    ``sample_normal`` draws from that distribution, squashes the sample with
    ``tanh``, scales it by ``max_action`` and returns the matching
    log-probability.

    Args:
        n_observations: Size of the input feature vector.
        n_actions: Number of action dimensions (size of each head's output).
        max_action: Factor the ``tanh`` output is multiplied by.
        fc1_dim: Width of the first hidden layer.
        fc2_dim: Width of the second hidden layer.
    """

    def __init__(self, n_observations, n_actions, max_action, fc1_dim=256, fc2_dim=256):
        super().__init__()
        # Small constant used both as the lower bound for sigma and as the
        # epsilon that keeps log(1 - tanh(u)^2) finite.
        self.reparam_noise = 1e-6
        self.max_action = max_action

        # Shared feature extractor for both heads.
        self.layers = nn.Sequential(
            nn.Linear(n_observations, fc1_dim),
            nn.ReLU(),
            nn.Linear(fc1_dim, fc2_dim),
            nn.ReLU(),
        )

        self.mu = nn.Linear(fc2_dim, n_actions)
        self.sigma = nn.Linear(fc2_dim, n_actions)

    def forward(self, state):
        """Return ``(mu, sigma)`` of the action distribution for ``state``.

        ``sigma`` is clamped to ``[self.reparam_noise, 1]``.
        """
        features = self.layers(state)

        # Both heads consume the trunk's hidden features, not the raw state:
        # they are declared with ``fc2_dim`` inputs.
        mu = self.mu(features)

        # Standard deviation of the distribution, kept strictly positive.
        sigma = self.sigma(features)
        sigma = torch.clamp(sigma, min=self.reparam_noise, max=1)

        return mu, sigma

    def sample_normal(self, state, reparametrize=True):
        """Sample bounded actions and their log-probabilities.

        Args:
            state: Input passed to ``forward``.
            reparametrize: If true, sample with ``rsample`` (reparametrization
                trick) so gradients flow through the sample; otherwise use
                ``sample``.

        Returns:
            A tuple ``(actions, log_probs)``. ``actions`` equals
            ``tanh(u) * max_action`` where ``u`` is the Gaussian sample.
            ``log_probs`` is summed over dimension 1 with ``keepdim=True``.
        """
        mu, sigma = self.forward(state)
        probs = Normal(mu, sigma)

        if reparametrize:
            # Reparametrization trick: the noise is sampled independently of
            # the parameters, so the sample stays differentiable.
            raw_actions = probs.rsample()
        else:
            raw_actions = probs.sample()

        # Squash to (-1, 1), then scale to [-max_action, max_action].
        squashed = torch.tanh(raw_actions)
        actions = squashed * self.max_action

        # The Gaussian density must be evaluated at the pre-tanh sample.
        log_probs = probs.log_prob(raw_actions)

        # SAC paper, appendix C ("Enforcing Action Bounds"): change of
        # variables for the tanh squashing. The unscaled tanh output is used
        # so that 1 - tanh(u)^2 stays in [0, 1] whatever ``max_action`` is.
        # The constant Jacobian term of the ``max_action`` scaling is not
        # included.
        log_probs = log_probs - torch.log(1 - squashed.pow(2) + self.reparam_noise)
        log_probs = log_probs.sum(1, keepdim=True)

        return actions, log_probs