In [1]:
import torch, gymnasium
from torch.distributions.categorical import Categorical
import numpy as np
from torch import nn

def ortho_init(layer, scale=np.sqrt(2)):
    """Orthogonally initialise ``layer.weight`` in place and zero its bias.

    Args:
        layer: A module exposing ``weight`` and optionally ``bias``
            (e.g. ``nn.Linear``).
        scale: Gain passed to ``nn.init.orthogonal_``; defaults to sqrt(2).

    Returns:
        The same ``layer`` object, so the call can be nested directly inside
        ``nn.Sequential(...)``.
    """
    nn.init.orthogonal_(layer.weight, gain=scale)
    # Layers built with bias=False have `bias is None`; constant_ would crash on it.
    bias = getattr(layer, "bias", None)
    if bias is not None:
        nn.init.constant_(bias, 0)
    return layer

class MlpSeparate(nn.Module):
    """Actor-critic MLP with separate (non-shared) actor and critic trunks.

    The actor emits one flat logit vector. That vector is split into
    independent categorical heads, one per entry of ``action_nvec``
    (a multi-discrete action space). The critic emits one value per observation.

    Args:
        input_size: Number of observation features (default 16).
        action_nvec: Number of choices for each action head
            (default: four heads with 4 choices each).
        hidden_size: Width of both hidden layers in each trunk (default 256).

    The defaults reproduce the previously hard-coded architecture, so
    ``MlpSeparate()`` and existing state dicts keep working.
    """

    def __init__(self, input_size=16, action_nvec=(4, 4, 4, 4), hidden_size=256):
        super().__init__()
        # Kept as a plain attribute (not a registered buffer) so the state_dict
        # keys stay exactly `critic.*` / `actor.*` and strict load_state_dict on
        # existing checkpoints still matches.
        self.action_nvec = torch.as_tensor(action_nvec)
        n_logits = int(self.action_nvec.sum())
        self.critic = self._mlp(input_size, hidden_size, 1, out_scale=1)
        # A small output gain keeps the initial actor logits close to zero.
        self.actor = self._mlp(input_size, hidden_size, n_logits, out_scale=0.01)

    @staticmethod
    def _mlp(input_size, hidden_size, output_size, out_scale):
        """Build Linear-Tanh-Linear-Tanh-Linear with orthogonal init (output gain ``out_scale``)."""
        return nn.Sequential(
            ortho_init(nn.Linear(input_size, hidden_size)),
            nn.Tanh(),
            ortho_init(nn.Linear(hidden_size, hidden_size)),
            nn.Tanh(),
            ortho_init(nn.Linear(hidden_size, output_size), scale=out_scale),
        )

    def get_value(self, obs):
        """Return the critic output for ``obs`` (last dimension has size 1)."""
        return self.critic(obs)

    def _masked_split_logits(self, obs, mask=None):
        """Run the actor, apply the optional action mask, and split logits per head.

        ``mask`` must be reshapeable to the logits' shape. True (non-zero)
        marks an ALLOWED action; False entries get logit -1e8, i.e. numerically
        zero probability.
        """
        logits = self.actor(obs)
        if mask is not None:
            # Only checks the mask as a whole. A head whose entries are all False
            # ends up uniform over its (disallowed) actions.
            assert mask.any(), mask
            allowed = mask.reshape(logits.shape).bool()
            logits = logits.masked_fill(~allowed, -1e8)
        return torch.split(logits, self.action_nvec.tolist(), dim=1)

    def get_action(self, obs, action=None, mask=None):
        """Sample (or score) a multi-discrete action; return it with log-prob and entropy.

        Args:
            obs: Observation batch. Must be 2-D (batch, features) because the
                logits are split along dim 1.
            action: Optional integer actions of shape (batch, n_heads) to score
                instead of sampling new ones.
            mask: Optional allowed-action mask. It must be reshapeable to the
                logits' shape, with True meaning the action is ALLOWED.

        Returns:
            ``(action, logprob, entropy)``. ``action`` has shape
            (batch, n_heads). ``logprob`` and ``entropy`` are summed over heads,
            giving shape (batch,).
        """
        head_logits = self._masked_split_logits(obs, mask)
        dists = [Categorical(logits=head) for head in head_logits]
        if action is None:
            # One sample per head, stacked to (batch, n_heads). Stacking (unlike
            # torch.tensor on a list of tensors) works for any batch size and
            # stays on the logits' device.
            action = torch.stack([dist.sample() for dist in dists], dim=1)
        logprob = torch.stack([dist.log_prob(a) for a, dist in zip(action.T, dists)])
        entropy = torch.stack([dist.entropy() for dist in dists])
        return action, logprob.sum(dim=0), entropy.sum(dim=0)

    def get_det_action(self, obs, action=None, mask=None):
        """Return the greedy (arg-max) action of every head.

        Args:
            obs: Observation batch of shape (batch, features).
            action: Unused; kept for signature compatibility.
            mask: Optional allowed-action mask, same convention as ``get_action``.

        Returns:
            Long tensor of shape (batch, n_heads).
        """
        head_logits = self._masked_split_logits(obs, mask)
        return torch.stack([head.argmax(dim=1) for head in head_logits], dim=1)

In [41]:
# New floating-point tensors (including model parameters built afterwards) default to float64.
torch.set_default_dtype(torch.double)

In [42]:
# Build the network and restore its saved weights.
# weights_only=True limits unpickling to tensors and plain containers, so a
# tampered checkpoint file cannot execute arbitrary code while loading.
model = MlpSeparate()
model.load_state_dict(torch.load('test.pt', weights_only=True))

<All keys matched successfully>

In [43]:
# Fixed-point output with 15 decimals so logits can be compared digit by digit.
torch.set_printoptions(sci_mode=False, precision=15)

In [44]:
# First observation, as a batch of one (shape 1 x 16).
obsb1 = torch.tensor([[
    1.0, 2.0, 2.0, 2.0,
    0.670000016689301, 0.340000003576279, 0.140000000596046, 0.109999999403954,
    0.560000002384186, 0.959999978542328, 0.230000004172325, 0.949999988079071,
    0.0, 0.670000016689301, 0.0, 0.560000002384186,
]])
model.actor(obsb1)

tensor([[-0.012930031149011, -0.001350823133152,  0.005173190206362,
          0.004375997139879,  0.001556734199947,  0.012896205650492,
         -0.013287034370875,  0.004398340103247,  0.001213372412979,
         -0.015685295074480,  0.031027841928699, -0.002897765298380,
         -0.008842607965760,  0.000788311019257,  0.008520488607345,
         -0.002693830312777]], grad_fn=<AddmmBackward0>)

In [45]:
# Second observation, as a batch of one (shape 1 x 16).
obsb2 = torch.tensor([[
    2.0, 2.0, 2.0, 0.0,
    0.670000016689301, 0.340000003576279, 0.140000000596046, 0.109999999403954,
    0.560000002384186, 0.959999978542328, 0.230000004172325, 0.949999988079071,
    0.109999999403954, 0.0, 0.949999988079071, 0.0,
]])
model.actor(obsb2)

tensor([[    -0.011655277289170,     -0.000043480378212,      0.005238374913708,
              0.002273797044490,      0.001762097236874,      0.010387559115177,
             -0.007799827173139,      0.003275219226455,      0.009683629636783,
             -0.014490072681988,      0.024021823365105,     -0.004761678157189,
             -0.005153438420307,      0.005546502476704,      0.008477908510767,
             -0.002293286252387]], grad_fn=<AddmmBackward0>)

In [47]:
# Batch the two single-row observations defined above instead of re-typing
# their values. Row i of the batched logits can then be compared directly
# with the single-row logits of obsb1 / obsb2.
obs = torch.cat([obsb1, obsb2], dim=0)
model.actor(obs)

tensor([[    -0.012930031149011,     -0.001350823133152,      0.005173190206362,
              0.004375997139879,      0.001556734199947,      0.012896205650492,
             -0.013287034370875,      0.004398340103247,      0.001213372412979,
             -0.015685295074480,      0.031027841928699,     -0.002897765298380,
             -0.008842607965760,      0.000788311019257,      0.008520488607345,
             -0.002693830312777],
        [    -0.011655277289170,     -0.000043480378212,      0.005238374913708,
              0.002273797044490,      0.001762097236874,      0.010387559115177,
             -0.007799827173139,      0.003275219226455,      0.009683629636783,
             -0.014490072681988,      0.024021823365105,     -0.004761678157189,
             -0.005153438420307,      0.005546502476704,      0.008477908510767,
             -0.002293286252387]], grad_fn=<AddmmBackward0>)

In [46]:
# Batch the two single-row observations defined above instead of re-typing
# their values. Row i of the batched logits can then be compared directly
# with the single-row logits of obsb1 / obsb2.
obs = torch.cat([obsb1, obsb2], dim=0)
model.actor(obs)

tensor([[    -0.012930031149011,     -0.001350823133152,      0.005173190206362,
              0.004375997139879,      0.001556734199947,      0.012896205650492,
             -0.013287034370875,      0.004398340103247,      0.001213372412979,
             -0.015685295074480,      0.031027841928699,     -0.002897765298380,
             -0.008842607965760,      0.000788311019257,      0.008520488607345,
             -0.002693830312777],
        [    -0.011655277289170,     -0.000043480378212,      0.005238374913708,
              0.002273797044490,      0.001762097236874,      0.010387559115177,
             -0.007799827173139,      0.003275219226455,      0.009683629636783,
             -0.014490072681988,      0.024021823365105,     -0.004761678157189,
             -0.005153438420307,      0.005546502476704,      0.008477908510767,
             -0.002293286252387]], grad_fn=<AddmmBackward0>)