In [5]:
from warnings import catch_warnings

import gymnasium as gym
import numpy as np
import torch
import torch.nn.functional as F  # noqa: N812
from skrl.agents.torch.sac import SAC, SAC_DEFAULT_CONFIG
from skrl.envs.wrappers.torch import wrap_env
from skrl.memories.torch import RandomMemory
from skrl.models.torch import DeterministicMixin, GaussianMixin, Model
from skrl.trainers.torch import SequentialTrainer
from skrl.utils import set_seed
from torch import nn
from transformers import DistilBertConfig, DistilBertModel
from typing_extensions import override

In [6]:
set_seed()  # seed for reproducibility, e.g. `set_seed(42)` for fixed seed

with catch_warnings(action="ignore"):
    gym_env = gym.make("Pendulum-v1") # load the environment
    env = wrap_env(gym_env, wrapper="gymnasium") # wrap the environment

[38;20m[skrl:INFO] Seed: 1638628088[0m
[38;20m[skrl:INFO] Environment class: gymnasium.core.Wrapper, gymnasium.utils.record_constructor.RecordConstructorArgs[0m
[38;20m[skrl:INFO] Environment wrapper: gymnasium[0m


In [47]:
from collections.abc import Mapping
from typing import Any


class CustomModel(nn.Module):
    def __init__(
        self,
        num_struct_elements: int,
        attention_mask: torch.LongTensor,
        components_mask: torch.LongTensor,
        device,
    ):
        self.bert_config = DistilBertConfig(
            vocab_size=10000,
            hidden_size=1,
            num_hidden_layers=2,
            num_attention_heads=1,
            intermediate_size=100,
            hidden_act="gelu",
            hidden_dropout_prob=0.1,
            attention_probs_dropout_prob=0.1,
            max_position_embeddings=50,
            type_vocab_size=2,
            initializer_range=0.02,
            layer_norm_eps=1e-12,
            pad_token_id=0,
            position_embedding_type="absolute",
            use_cache=True,
            classifier_dropout=None,
        )
        super().__init__()
        self.num_struct_elements = num_struct_elements
        self.attention_mask = attention_mask.to(device)
        self.components_mask = components_mask.to(device)
        self.ones_vector = torch.ones(self.num_struct_elements, 1).to(device)
        self.distilbert_1 = DistilBertModel(self.bert_config)
        self.distilbert_2 = DistilBertModel(self.bert_config)

    def forward(self, inputs_embeds: torch.Tensor):
        attention_mask = self.attention_mask
        components_mask = self.components_mask

        embeds = inputs_embeds.repeat(self.num_struct_elements, 1)
        embeds.unsqueeze_(-1)

        outputs_1 = self.distilbert_1(
            # input_ids=torch.ones(input_ids.size()),
            inputs_embeds=embeds,
            attention_mask=attention_mask,
        )

        last_hidden_state_1 = outputs_1["last_hidden_state"]

        input_2 = torch.sum(last_hidden_state_1, dim=2)
        input_2.mul_(components_mask)  # summing through columns
        input_2 = torch.sum(input_2, dim=0)

        input_2 = self.ones_vector @ input_2.view(1, input_2.size()[0])

        input_2 = input_2.view(input_2.size()[0], input_2.size()[1], 1)

        outputs_2 = self.distilbert_2(
            # input_ids=torch.ones(input_2.size()),
            inputs_embeds=input_2,
            attention_mask=attention_mask,
        )
        last_hidden_state_2 = outputs_2["last_hidden_state"]

        input_2 = torch.sum(last_hidden_state_2, dim=2)
        input_2.mul_(components_mask)
        # summing through columns
        input_2 = torch.sum(input_2, dim=0)
        return input_2


class SoftQNetwork(Model):
    """Custom BERT enabled Critic for SAC approach."""

    @override
    def __init__(
        self,
        observation_space,
        action_space,
        device,
        num_struct_elements: int,
        att_mask: torch.LongTensor,
        components_mask: torch.LongTensor,
    ):
        super().__init__(observation_space, action_space, device)

        self.preprocess_layer = CustomModel(
            num_struct_elements=num_struct_elements,
            attention_mask=att_mask,
            components_mask=components_mask,
            device=device,
        )
        self.fc1 = nn.Linear(
            np.array(observation_space.shape).prod() + np.prod(observation_space.shape),
            256,
        )
        self.fc2 = nn.Linear(256, 256)
        self.fc3 = nn.Linear(256, 1)

    @override
    def compute(self, inputs: Mapping[str, torch.Tensor | Any], role: str = ""):
        obs = inputs["states"]
        action = inputs["taken_actions"]
        obs = self.preprocess_layer(obs).unsqueeze(0)
        x = torch.cat([obs, action], 1)
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        x = self.fc3(x)
        return x, {}


class Actor(Model):
    LOG_STD_MAX = 2
    LOG_STD_MIN = -5

    def __init__(
        self,
        observation_space,
        action_space,
        device,
        num_struct_elements: int,
        att_mask: torch.LongTensor,
        components_mask: torch.LongTensor,
    ):
        super().__init__(observation_space, action_space, device)
        self.preprocess_layer = CustomModel(
            num_struct_elements=num_struct_elements,
            attention_mask=att_mask,
            components_mask=components_mask,
            device=device,
        )
        self.fc1 = nn.Linear(np.array(observation_space.shape).prod(), 256)
        self.fc2 = nn.Linear(256, 256)
        self.fc_mean = nn.Linear(256, np.prod(observation_space.shape))
        self.fc_logstd = nn.Linear(256, np.prod(observation_space.shape))
        # # ? action rescaling
        # self.register_buffer(
        #     "action_scale",
        #     torch.tensor((env.action_space.high - env.action_space.low) / 2.0, dtype=torch.float32),
        # )
        # self.register_buffer(
        #     "action_bias",
        #     torch.tensor((env.action_space.high + env.action_space.low) / 2.0, dtype=torch.float32),
        # )

    @override
    def compute(self, inputs: Mapping[str, torch.Tensor | Any], role: str = ""):
        x = inputs["states"]
        x = self.preprocess_layer(x).unsqueeze(0)
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        mean = self.fc_mean(x)
        log_std = self.fc_logstd(x)
        log_std = torch.tanh(log_std)
        log_std = self.LOG_STD_MIN + 0.5 * (self.LOG_STD_MAX - self.LOG_STD_MIN) * (log_std + 1)
        return mean, log_std

    @override
    def act(self, inputs: Mapping[str, torch.Tensor], role: str = ''):
        mean, log_std = self.compute(inputs)
        std = log_std.exp()
        normal = torch.distributions.Normal(mean, std)
        x_t = normal.rsample()  # for reparameterization trick (mean + std * N(0,1))
        y_t = torch.tanh(x_t)
        action = y_t * self.action_scale + self.action_bias
        log_prob = normal.log_prob(x_t)
        # Enforcing Action Bound
        log_prob -= torch.log(self.action_scale * (1 - y_t.pow(2)) + 1e-6)
        log_prob = log_prob.sum(1, keepdim=True)
        mean = torch.tanh(mean) * self.action_scale + self.action_bias
        return action, log_prob, {"mean": mean}


In [48]:
device = env.device

# instantiate a memory as experience replay
memory = RandomMemory(memory_size=20000, num_envs=env.num_envs, device=device, replacement=False)

# configuration
n_legs = 4
num_struct_elements = 9
att_mask = torch.from_numpy(
    np.array(
        [
            [1] * 6 + [0, 1, 0, 1, 0, 1, 0] + [1] * 7 + [0, 1, 0, 1, 0, 1, 0],
            [1] * 7 + [0] * 6 + [1] * 8 + [0] * 6,
            [1] * 5 + [0] * 2 + [1] * 2 + [0] * 4 + [1] * 6 + [0] * 2 + [1] * 2 + [0] * 4,
            [1] * 5 + [0] * 4 + [1] * 2 + [0] * 2 + [1] * 6 + [0] * 4 + [1] * 2 + [0] * 2,
            [1] * 5 + [0] * 6 + [1] * 2 + [0] * 0 + [1] * 6 + [0] * 6 + [1] * 2 + [0] * 0,
            [0] * 5 + [1] * 2 + [0] * 12 + [1] * 2 + [0] * 6,
            [0] * 7 + [1] * 2 + [0] * 12 + [1] * 2 + [0] * 4,
            [0] * 9 + [1] * 2 + [0] * 12 + [1] * 2 + [0] * 2,
            [0] * 11 + [1] * 2 + [0] * 12 + [1] * 2 + [0] * 0,
        ]
    )
).to(torch.int64)
components_mask = (
    torch.from_numpy(
        np.array(
            [
                [1] * 5 + [0] * 8 + [1] * 6 + [0] * 8,
                [0] * 5 + [1] + [0] * 13 + [1] + [0] * 7,
                [0] * 6 + [1] + [0] * 13 + [1] + [0] * 6,
                [0] * 7 + [1] + [0] * 13 + [1] + [0] * 5,
                [0] * 8 + [1] + [0] * 13 + [1] + [0] * 4,
                [0] * 9 + [1] + [0] * 13 + [1] + [0] * 3,
                [0] * 10 + [1] + [0] * 13 + [1] + [0] * 2,
                [0] * 11 + [1] + [0] * 13 + [1] + [0] * 1,
                [0] * 12 + [1] + [0] * 13 + [1] + [0] * 0,
            ]
        )
    ).to(torch.int64),
)[0]

  logger.warn(


In [49]:
# instantiate the agent's models (function approximators).
# SAC requires 5 models, visit its documentation for more detail

models = {}
models["policy"] = Actor(
    env.observation_space,
    env.action_space,
    device,
    num_struct_elements=num_struct_elements,
    att_mask=att_mask,  # type: ignore
    components_mask=components_mask,  # type: ignore
)
models["critic_1"] = SoftQNetwork(
    env.observation_space,
    env.action_space,
    device,
    num_struct_elements=num_struct_elements,
    att_mask=att_mask,  # type: ignore
    components_mask=components_mask,  # type: ignore
)
models["critic_2"] = SoftQNetwork(
    env.observation_space,
    env.action_space,
    device,
    num_struct_elements=num_struct_elements,
    att_mask=att_mask,  # type: ignore
    components_mask=components_mask,  # type: ignore
)
models["target_critic_1"] = SoftQNetwork(
    env.observation_space,
    env.action_space,
    device,
    num_struct_elements=num_struct_elements,
    att_mask=att_mask,  # type: ignore
    components_mask=components_mask,  # type: ignore
)
models["target_critic_2"] = SoftQNetwork(
    env.observation_space,
    env.action_space,
    device,
    num_struct_elements=num_struct_elements,
    att_mask=att_mask,  # type: ignore
    components_mask=components_mask,  # type: ignore
)


In [50]:
# initialize models' parameters (weights and biases)
for model in models.values():
    model.init_parameters(method_name="normal_", mean=0.0, std=0.1)

In [51]:
cfg = SAC_DEFAULT_CONFIG.copy()
cfg["discount_factor"] = 0.98
cfg["batch_size"] = 100
cfg["random_timesteps"] = 0
cfg["learning_starts"] = 1000
cfg["learn_entropy"] = True
# logging to TensorBoard and write checkpoints (in timesteps)
cfg["experiment"]["write_interval"] = 75
cfg["experiment"]["checkpoint_interval"] = 750
cfg["experiment"]["directory"] = "runs/torch/Pendulum"

agent = SAC(
    models=models,
    memory=memory,
    cfg=cfg,
    observation_space=env.observation_space, # type: ignore
    action_space=env.action_space, # type: ignore
    device=device,
)

In [52]:
# configure and instantiate the RL trainer
cfg_trainer = {"timesteps": 15000, "headless": True}
trainer = SequentialTrainer(cfg=cfg_trainer, env=env, agents=[agent]) # type: ignore

In [53]:
# start training
trainer.train()

  logger.warn(


  0%|          | 0/15000 [00:00<?, ?it/s]

  0%|          | 0/15000 [00:01<?, ?it/s]


RuntimeError: shape '[9, 1, 1, 3]' is invalid for input of size 243