In [1]:
import torch
import torch.nn as nn
import torch.optim as optim

# import torch.distributions as D
from torch.distributions import Categorical
import numpy as np
import sys
from debug_side_channel import DebugSideChannel
from torch.distributions import Normal

In [3]:
# Number of discrete action levels; valid action indices run from 0 to ACTION_SIZE - 1.
ACTION_SIZE = 5

In [21]:
class Actor(nn.Module):
    """Tanh-squashed Gaussian policy whose output is mapped onto discrete action levels.

    A two-hidden-layer MLP produces the mean and (clamped) log standard deviation
    of a diagonal Gaussian. Samples are squashed with ``tanh`` into [-1, 1], then
    affinely rescaled to [0, action_size] and clamped to [0, action_size - 1].

    Args:
        state_dim: Number of features in a state vector.
        action_dim: Number of action dimensions produced per state.
        action_size: Number of discrete levels per action dimension.
        hidden_size: Width of both hidden layers.
        init_w: Half-width of the uniform init used for the two output heads.
        log_std_min: Lower clamp applied to the predicted log-std.
        log_std_max: Upper clamp applied to the predicted log-std.
    """

    def __init__(self, state_dim, action_dim, action_size, hidden_size, init_w=3e-3, log_std_min=-20, log_std_max=2):
        super().__init__()
        # Stored so the action scaling depends on this instance, not on a notebook global.
        self.action_size = action_size
        self.init_w = init_w
        self.log_std_min = log_std_min
        self.log_std_max = log_std_max

        self.fc1 = nn.Linear(state_dim, hidden_size)
        self.fc2 = nn.Linear(hidden_size, hidden_size)

        # Output heads start near zero so the initial policy is close to N(0, 1).
        self.mu = nn.Linear(hidden_size, action_dim)
        self.mu.weight.data.uniform_(-init_w, init_w)
        self.mu.bias.data.uniform_(-init_w, init_w)

        self.log_std = nn.Linear(hidden_size, action_dim)
        self.log_std.weight.data.uniform_(-init_w, init_w)
        self.log_std.bias.data.uniform_(-init_w, init_w)

    def forward(self, state):
        """Return ``(mu, log_std)`` of the pre-squash Gaussian for ``state``.

        ``log_std`` is clamped to ``[log_std_min, log_std_max]``.
        """
        x = torch.relu(self.fc1(state))
        x = torch.relu(self.fc2(x))
        mu = self.mu(x)
        log_std = torch.clamp(self.log_std(x), self.log_std_min, self.log_std_max)
        return mu, log_std

    def _prepare_state(self, state):
        """Convert ``state`` to a float32 tensor on the model's device and add a leading axis."""
        device = self.mu.weight.device
        return torch.as_tensor(state, dtype=torch.float32, device=device).unsqueeze(0)

    def _sample_squashed(self, state):
        """Draw ``u = mu + std * eps`` and return ``(mu, std, u, tanh(u))``."""
        mu, log_std = self(state)
        std = log_std.exp()
        # Independent noise per element; a single scalar draw would correlate
        # every batch entry and action dimension.
        pre_tanh = mu + std * torch.randn_like(mu)
        return mu, std, pre_tanh, torch.tanh(pre_tanh)

    def _scale(self, squashed):
        """Map a tanh output in [-1, 1] to the range [0, action_size]."""
        return (squashed + 1) * (self.action_size / 2)

    def evaluate(self, state, epsilon=1e-6):
        """Sample a differentiable action and its log-probability.

        Args:
            state: Array-like or tensor of state features; a leading axis is added.
            epsilon: Small constant keeping the argument of ``log`` positive.

        Returns:
            ``(action, log_prob)`` where ``action`` lies in [0, action_size - 1]
            (continuous, not rounded) and ``log_prob`` is the Gaussian log-density
            of the pre-squash sample with the tanh change-of-variables correction.
            The constant Jacobian of the affine rescaling and the effect of the
            clamp are not included.
        """
        state = self._prepare_state(state)
        mu, std, pre_tanh, squashed = self._sample_squashed(state)
        action = torch.clamp(self._scale(squashed), 0, self.action_size - 1)
        # The correction must use the tanh output in [-1, 1]; using the rescaled
        # action makes 1 - a**2 negative and the log NaN.
        log_prob = Normal(mu, std).log_prob(pre_tanh) - torch.log(1 - squashed.pow(2) + epsilon)
        return action, log_prob

    def select_action(self, state):
        """Sample a discrete action for acting in the environment.

        Args:
            state: Array-like or tensor of state features; a leading axis is added.

        Returns:
            NumPy float32 array of integer-valued actions in [0, action_size - 1],
            with the leading axis added by this method stripped again.
        """
        state = self._prepare_state(state)
        # No gradients are needed when acting.
        with torch.no_grad():
            _, _, _, squashed = self._sample_squashed(state)
            action = torch.round(self._scale(squashed))
            action = torch.clamp(action, 0, self.action_size - 1)
        return action.cpu().numpy()[0]

In [22]:
# Single example observation with 6 features, shaped (1, 6), used to smoke-test the actor.
state = torch.tensor([[  5.5000,   0.5000,   0.5000,  -0.0000, 180.0000,  -0.0000]])
# Pass ACTION_SIZE instead of a literal 5 so the argument cannot drift from the config constant.
a = Actor(6, 1, ACTION_SIZE, 128)

In [23]:
# Call the module itself rather than .forward() so any registered hooks run.
a(state)

(tensor([[0.1432]], grad_fn=<AddmmBackward0>),
 tensor([[0.0181]], grad_fn=<ClampBackward1>))

In [52]:
# The action is sampled, so the displayed value varies between runs unless the torch RNG is seeded.
a.select_action(state)

array([[4.]], dtype=float32)

In [None]:
class PolicyNetwork(nn.Module):
    """Tanh-squashed Gaussian policy producing actions in [-1, 1].

    Args:
        num_inputs: Number of features in a state vector.
        num_actions: Number of action dimensions.
        hidden_size: Width of both hidden layers.
        init_w: Half-width of the uniform init used for the two output heads.
        log_std_min: Lower clamp applied to the predicted log-std.
        log_std_max: Upper clamp applied to the predicted log-std.
    """

    def __init__(self, num_inputs, num_actions, hidden_size, init_w=3e-3, log_std_min=-20, log_std_max=2):
        super().__init__()

        self.log_std_min = log_std_min
        self.log_std_max = log_std_max

        self.linear1 = nn.Linear(num_inputs, hidden_size)
        self.linear2 = nn.Linear(hidden_size, hidden_size)

        self.mean_linear = nn.Linear(hidden_size, num_actions)
        self.mean_linear.weight.data.uniform_(-init_w, init_w)
        self.mean_linear.bias.data.uniform_(-init_w, init_w)

        self.log_std_linear = nn.Linear(hidden_size, num_actions)
        self.log_std_linear.weight.data.uniform_(-init_w, init_w)
        self.log_std_linear.bias.data.uniform_(-init_w, init_w)

    def forward(self, state):
        """Return ``(mean, log_std)`` with ``log_std`` clamped to the configured range."""
        # torch.relu is used because torch.nn.functional is not imported in this notebook.
        x = torch.relu(self.linear1(state))
        x = torch.relu(self.linear2(x))

        mean = self.mean_linear(x)
        log_std = self.log_std_linear(x)
        log_std = torch.clamp(log_std, self.log_std_min, self.log_std_max)

        return mean, log_std

    def evaluate(self, state, epsilon=1e-6):
        """Sample a squashed action and its log-probability.

        Args:
            state: Tensor of state features accepted by ``forward``.
            epsilon: Small constant keeping the argument of ``log`` positive.

        Returns:
            ``(action, log_prob, z, mean, log_std)`` where ``action = tanh(mean + std * z)``
            and ``z`` is the standard-normal draw used for the sample.
        """
        mean, log_std = self.forward(state)
        std = log_std.exp()

        # NOTE(review): z is a single scalar draw shared by every element of the
        # batch; its shape is kept because z is part of the return value.
        z = Normal(0, 1).sample()
        # Use the device of the network output instead of an external `device` global.
        pre_tanh = mean + std * z.to(mean.device)
        action = torch.tanh(pre_tanh)
        log_prob = Normal(mean, std).log_prob(pre_tanh) - torch.log(1 - action.pow(2) + epsilon)
        return action, log_prob, z, mean, log_std

    def get_action(self, state):
        """Sample one action for a single state.

        Args:
            state: Array-like or tensor of state features; a leading axis is added.

        Returns:
            CPU tensor holding the squashed action with the added leading axis removed.
        """
        # Follow the parameters' device rather than relying on a global `device`.
        device = self.mean_linear.weight.device
        state = torch.as_tensor(state, dtype=torch.float32, device=device).unsqueeze(0)
        mean, log_std = self.forward(state)
        std = log_std.exp()

        z = Normal(0, 1).sample().to(device)
        action = torch.tanh(mean + std * z)

        action = action.cpu()
        return action[0]