In [14]:
import torch
import random

nets = random.sample([1, 2, 3, 4], 4)  # random permutation of 1..4

# Each element becomes a shape-(1,) tensor; stacking them along a new dim=1
# yields a single (1, len(nets)) row -- the same stacking pattern Critic.forward
# uses to collect per-network outputs.
torch.stack([torch.tensor([net_id]) for net_id in nets], dim=1)

tensor([[2, 4, 3, 1]])

In [29]:
import torch
import torch.nn as nn
from torch.nn import functional as F

class LayerNorm(nn.Module):
    """Layer normalization over the last dimension with a learnable gain and bias.

    Differs from ``torch.nn.LayerNorm`` in two ways:

    * It divides by ``Tensor.std``, which is PyTorch's default Bessel-corrected
      estimate, rather than the biased estimate.
    * ``eps`` is added to the standard deviation, not to the variance.

    NOTE: with a last dimension of size 1 the corrected std is undefined
    (PyTorch yields nan).

    Args:
        features: size of the last dimension; shape of ``gamma`` and ``beta``.
        eps: small constant added to the std to avoid division by zero.
    """

    def __init__(self, features, eps=1e-6):
        super(LayerNorm, self).__init__()
        # Learnable per-feature scale (initialised to 1) and shift (initialised to 0).
        self.gamma = nn.Parameter(torch.ones(features))
        self.beta = nn.Parameter(torch.zeros(features))
        self.eps = eps

    def forward(self, x):
        """Normalize ``x`` along its last axis, then apply ``gamma`` and ``beta``."""
        mu = torch.mean(x, dim=-1, keepdim=True)
        sigma = torch.std(x, dim=-1, keepdim=True)
        centered = x - mu
        return self.gamma * centered / (sigma + self.eps) + self.beta

class Q_network(nn.Module):
    """Two-hidden-layer MLP mapping a concatenated (state, action) vector to one value.

    Args:
        args: config object; reads ``state_dim``, ``action_dim``, ``use_ln`` and
            ``device``.
        l1: width of hidden layer 1.
        l2: width of hidden layer 2 (and therefore the output layer's fan-in).
        l3: unused; kept only so existing positional/keyword calls still work.
            The network has two hidden layers, so the output layer must consume
            ``l2`` features. Previously it was built with ``l3`` inputs, which
            crashed in ``forward`` whenever ``l3 != l2``.
    """

    def __init__(self, args, l1=400, l2=300, l3=300):
        super(Q_network, self).__init__()
        self.args = args
        self.l1 = l1
        # Hidden layer 1 (input interface): consumes the concatenated state/action vector.
        self.w_l1 = nn.Linear(args.state_dim + args.action_dim, l1)
        # Hidden layer 2
        self.w_l2 = nn.Linear(l1, l2)
        if self.args.use_ln:
            self.lnorm1 = LayerNorm(l1)
            self.lnorm2 = LayerNorm(l2)
        # Output layer: its input is hidden layer 2's activations, so fan-in is l2.
        self.w_out = nn.Linear(l2, 1)
        # Scale the freshly initialised output weights and bias by 0.1 (keeps
        # initial outputs small). no_grad() replaces the legacy `.data` access.
        with torch.no_grad():
            self.w_out.weight.mul_(0.1)
            self.w_out.bias.mul_(0.1)
        self.to(self.args.device)

    def forward(self, input_):
        """Map ``input_`` (last dim = state_dim + action_dim) to a last dim of size 1."""
        # Hidden layer 1 (input interface)
        out = self.w_l1(input_)
        if self.args.use_ln:
            out = self.lnorm1(out)
        out = F.leaky_relu(out)
        # Hidden layer 2
        out = self.w_l2(out)
        if self.args.use_ln:
            out = self.lnorm2(out)
        out = F.leaky_relu(out)
        # Output interface
        return self.w_out(out)

class Critic(nn.Module):
    """Ensemble critic with two groups (Q1, Q2) of ``n_nets`` independent Q-networks.

    Args:
        args: config object passed unchanged to every ``Q_network``.
        n_nets: number of networks in each group. Must be at least 1 because
            ``forward`` stacks the per-network outputs.

    Raises:
        ValueError: if ``n_nets < 1``.
    """

    def __init__(self, args, n_nets = 2):
        super(Critic, self).__init__()
        self.args = args

        if n_nets < 1:
            raise ValueError(f"n_nets must be >= 1, got {n_nets}")

        # Hidden-layer widths shared by every network in the ensemble.
        l1 = 400
        l2 = 300
        l3 = l2

        # nn.ModuleList, unlike a plain Python list, registers each network as a
        # submodule, so parameters(), state_dict(), .to() and train()/eval()
        # reach all of them. With plain lists only the last network of each
        # group was registered (through attributes overwritten on every loop
        # iteration), so an optimizer built from critic.parameters() silently
        # skipped the others.
        # NOTE: state_dict keys are now "Q1_nets.<i>..." / "Q2_nets.<i>...".
        self.Q1_nets = nn.ModuleList()
        self.Q2_nets = nn.ModuleList()
        # Build Q1/Q2 pairs interleaved, in the original order, so a seeded RNG
        # produces the same initial weights as before.
        for _ in range(n_nets):
            self.Q1_nets.append(Q_network(args, l1, l2, l3))
            self.Q2_nets.append(Q_network(args, l1, l2, l3))

    @property
    def Q1_network(self):
        """Last network of the Q1 group (kept for code that read the old attribute)."""
        return self.Q1_nets[-1]

    @property
    def Q2_network(self):
        """Last network of the Q2 group (kept for code that read the old attribute)."""
        return self.Q2_nets[-1]

    def forward(self, state, action):
        """Evaluate every network in both groups on the concatenated (state, action).

        Args:
            state: state tensor; concatenated with ``action`` along the last dim.
            action: action tensor.

        Returns:
            ``(quantiles_Q1, quantiles_Q2)``: each group's per-network outputs
            stacked along dim 1. For a (batch, features) input this has shape
            (batch, n_nets, 1).
        """
        sa = torch.cat((state, action), dim=-1)
        quantiles_Q1 = torch.stack([net(sa) for net in self.Q1_nets], dim=1)
        quantiles_Q2 = torch.stack([net(sa) for net in self.Q2_nets], dim=1)
        return quantiles_Q1, quantiles_Q2

In [30]:
class Args:
    """Minimal config object for the critic experiment.

    Each field read by ``Q_network``/``Critic`` is a keyword argument whose
    default equals the previously hard-coded value, so ``Args()`` is unchanged.

    Args:
        state_dim: width of the state vector.
        action_dim: width of the action vector.
        device: device the networks are moved to.
        use_ln: whether the Q-networks apply ``LayerNorm`` after each hidden layer.
    """
    def __init__(self, state_dim=3, action_dim=1, device='cpu', use_ln=False):
        self.state_dim = state_dim
        self.action_dim = action_dim
        self.device = device
        self.use_ln = use_ln

# Seed torch's RNG before building the critic so its random weight
# initialisation, and the torch.rand inputs drawn in the following cell, are
# reproducible under Restart & Run All.
SEED = 0
torch.manual_seed(SEED)

args = Args()

critic = Critic(args)

In [31]:
# One random (state, action) pair. Widths come from `args` so they stay in sync
# with the critic's configuration instead of repeating 3 and 1.
test_input = torch.rand(1, args.state_dim)
test_action = torch.rand(1, args.action_dim)

# Call the module itself rather than .forward() so nn.Module hooks run.
quantiles_Q1, quantiles_Q2 = critic(state=test_input, action=test_action)

In [36]:
# Ensemble disagreement: variance across the Q2 networks (stacked on dim 1), averaged.
torch.var(quantiles_Q2, dim=1).mean()

tensor(1.1701e-05, grad_fn=<MeanBackward0>)