In [None]:
# default_exp models.actor_critic

# Actor-critic Model
> Actor and critic networks (PyTorch MLPs) for actor-critic reinforcement learning.

In [None]:
#hide
from nbdev.showdoc import *

In [None]:
#export
from typing import Tuple
import torch
from torch import nn

In [None]:
#export
class Actor(nn.Module):
    """
    Actor Network

    A multi-layer perceptron that maps an embedded state to an action weight
    vector. Every hidden layer is a ``nn.Linear`` followed by ``nn.ReLU``; the
    output layer is a plain ``nn.Linear`` with no activation.
    """

    def __init__(self, embedded_state_size: int, action_weight_size: int, hidden_sizes: Tuple[int, ...]):
        """
        Initialize Actor
        :param embedded_state_size: size of the last dimension of the embedded state input
        :param action_weight_size: size of the last dimension of the produced action weight
        :param hidden_sizes: width of each hidden layer, in order; any number of
            layers is accepted (an empty tuple yields a single linear layer)
        """
        super().__init__()

        # Build the stack from *all* entries of hidden_sizes instead of indexing
        # [0] and [1] only, so extra entries are no longer silently dropped and
        # a single entry no longer raises IndexError. For two entries the layer
        # order (and thus the state-dict keys net.0/net.2/net.4) is unchanged.
        layers = []
        in_features = embedded_state_size
        for hidden_size in hidden_sizes:
            layers.append(nn.Linear(in_features, hidden_size))
            layers.append(nn.ReLU())
            in_features = hidden_size
        layers.append(nn.Linear(in_features, action_weight_size))

        self.net = nn.Sequential(*layers)

    def forward(self, embedded_state: torch.Tensor) -> torch.Tensor:
        """
        Forward
        :param embedded_state: embedded state; last dimension must be embedded_state_size
        :return: action weight; last dimension is action_weight_size
        """
        return self.net(embedded_state)

In [None]:
#export
class Critic(nn.Module):
    """
    Critic Network

    A multi-layer perceptron that scores a (state, action) pair. The embedded
    state and embedded action are concatenated along the last dimension and
    passed through ``nn.Linear`` + ``nn.ReLU`` hidden layers, followed by a
    final ``nn.Linear`` with a single output unit.
    """

    def __init__(self, embedded_state_size: int, embedded_action_size: int, hidden_sizes: Tuple[int, ...]):
        """
        Initialize Critic
        :param embedded_state_size: size of the last dimension of the embedded state input
        :param embedded_action_size: size of the last dimension of the embedded action input
        :param hidden_sizes: width of each hidden layer, in order; any number of
            layers is accepted (an empty tuple yields a single linear layer)
        """
        super().__init__()

        # Build the stack from *all* entries of hidden_sizes instead of indexing
        # [0] and [1] only, so extra entries are no longer silently dropped and
        # a single entry no longer raises IndexError. For two entries the layer
        # order (and thus the state-dict keys net.0/net.2/net.4) is unchanged.
        layers = []
        in_features = embedded_state_size + embedded_action_size
        for hidden_size in hidden_sizes:
            layers.append(nn.Linear(in_features, hidden_size))
            layers.append(nn.ReLU())
            in_features = hidden_size
        layers.append(nn.Linear(in_features, 1))

        self.net = nn.Sequential(*layers)

    def forward(self, embedded_state: torch.Tensor, embedded_action: torch.Tensor) -> torch.Tensor:
        """
        Forward
        :param embedded_state: embedded state; last dimension must be embedded_state_size
        :param embedded_action: embedded action; last dimension must be embedded_action_size
        :return: Q value; last dimension is 1
        """
        # Join state and action on the feature (last) axis before scoring.
        return self.net(torch.cat([embedded_state, embedded_action], dim=-1))

In [None]:
# Seed *before* building the networks: nn.Linear draws its initial weights from
# the global RNG, so seeding only after construction leaves the demo outputs
# dependent on whatever RNG state the kernel happened to be in.
torch.manual_seed(0)

# Demo dimensions for a small actor/critic pair.
embedded_state_size = 64
embedded_action_size = 5
actor_hidden_sizes = (128, 64)
critic_hidden_sizes = (32, 16)

# The actor emits one weight per embedded-action dimension.
actor = Actor(embedded_state_size=embedded_state_size,
              action_weight_size=embedded_action_size,
              hidden_sizes=actor_hidden_sizes)
critic = Critic(embedded_state_size=embedded_state_size,
                embedded_action_size=embedded_action_size,
                hidden_sizes=critic_hidden_sizes)

In [None]:
# Fix the RNG so the random input batch is the same on every run.
torch.manual_seed(0)
# Call the module instance rather than .forward() so nn.Module.__call__ runs
# any registered hooks; the result is the same when no hooks are installed.
actor(torch.rand(2, embedded_state_size))

tensor([[-0.1129, -0.0019,  0.0999, -0.0621,  0.0551],
        [-0.0445, -0.0425,  0.0677, -0.0397,  0.0003]],
       grad_fn=<AddmmBackward0>)

In [None]:
# Fix the RNG so the random state/action batches are the same on every run.
torch.manual_seed(0)
# Call the module instance rather than .forward() so nn.Module.__call__ runs
# any registered hooks; the result is the same when no hooks are installed.
critic(torch.rand(2, embedded_state_size), torch.rand(2, embedded_action_size))

tensor([[-0.1659],
        [-0.1432]], grad_fn=<AddmmBackward0>)

In [None]:
#hide
# Provenance footer: install the `watermark` extension quietly, (re)load it,
# and print author, timestamp, machine info and imported-package versions.
!pip install -q watermark
%reload_ext watermark
# Flags: -a author, -m machine info, -iv versions of imported packages,
# -u "Last updated" label, -t time, -d date.
%watermark -a "Sparsh A." -m -iv -u -t -d

Author: Sparsh A.

Last updated: 2021-12-19 09:57:05

Compiler    : GCC 7.5.0
OS          : Linux
Release     : 5.4.104+
Machine     : x86_64
Processor   : x86_64
CPU cores   : 2
Architecture: 64bit

IPython: 5.5.0
torch  : 1.10.0+cu111

