In [4]:
from functools import partial

import numpy as np
import torch
from jaxtyping import Float
from rich.pretty import pprint as pp
from torch import nn

# Pretty-printer preset: rich's pprint with expand_all=True baked in.
ppe = partial(pp, **{"expand_all": True})

In [None]:
class AttentionLayer(nn.Module):
    """Single-head attention restricted to the edges of a graph.

    Every node attends only to the nodes it is connected to in the adjacency
    matrix ``G`` and receives a new ``d_k``-dimensional embedding.

    Args:
        d_k: Length of one output embedding (width of the Q/K/V projections).
        d: Length of one input feature vector, i.e. the size of the last
            dimension of ``X`` (``X.shape[-1]``).
        bias: If True, add a learnable bias vector to each of Q, K and V.
    """

    def __init__(self, d_k: int, d: int, *, bias: bool = False):
        super().__init__()
        self.W_q = nn.Parameter(torch.randn(d, d_k))
        self.W_k = nn.Parameter(torch.randn(d, d_k))
        self.W_v = nn.Parameter(torch.randn(d, d_k))

        if bias:
            self.b_q = nn.Parameter(torch.randn(d_k))
            self.b_k = nn.Parameter(torch.randn(d_k))
            self.b_v = nn.Parameter(torch.randn(d_k))
        else:
            # Registered as None so the attributes always exist; forward()
            # skips the addition when they are None.
            self.register_parameter("b_q", None)
            self.register_parameter("b_k", None)
            self.register_parameter("b_v", None)

    @staticmethod
    def _project(X: torch.Tensor, W: torch.Tensor, b) -> torch.Tensor:
        """Return ``X @ W``, plus ``b`` when a bias is present (``b`` may be None)."""
        out = X @ W
        return out if b is None else out + b

    def forward(
        self,
        G: "Float[torch.Tensor, 'n n']",
        X: "Float[torch.Tensor, '*batch n d']",
    ) -> "Float[torch.Tensor, '*batch n d_k']":
        """Compute new node embeddings by attending over the edges of ``G``.

        Args:
            G: Adjacency matrix over the ``n`` nodes. Entry ``(i, j) != 0`` lets
                node ``i`` attend to node ``j``; nonzero entries also scale the
                corresponding raw attention score.
            X: Node features with shape ``(..., n, d)``; any leading batch
                dimensions share the same ``G``.

        Returns:
            Embeddings of shape ``(..., n, d_k)``. A node with no edges in ``G``
            gets an all-zero embedding.
        """
        Q = self._project(X, self.W_q, self.b_q)
        K = self._project(X, self.W_k, self.b_k)
        V = self._project(X, self.W_v, self.b_v)

        # transpose(-2, -1) equals .T for 2-D input and also handles batches.
        scores = Q @ K.transpose(-2, -1)

        # Multiplying by G alone leaves non-edges at score 0, which softmax
        # still turns into positive weight. Fill them with the most negative
        # finite value so they receive zero weight. A finite fill (not -inf)
        # keeps softmax and its gradient NaN-free for rows with no edges.
        edge = G != 0
        masked_scores = (scores * G).masked_fill(~edge, torch.finfo(scores.dtype).min)
        A = torch.softmax(masked_scores, dim=-1)
        # A row with no edges would otherwise spread uniform attention over
        # every node; zero it out instead.
        A = A * edge.any(dim=-1, keepdim=True).to(A.dtype)
        Z = A @ V  # New node embeddings
        return Z


In [5]:
from pathlib import Path

import gymnasium as gym
from libs.bert_sac.gymnasium_envs import AntLegsEnv


In [22]:
# MuJoCo model files live under assets/. The path is relative, so resolve()
# anchors it to the kernel's current working directory.
xml_model_dir = Path("assets", "mujoco", "models").resolve()
ant_xml_path = xml_model_dir / "ant-4.xml"

# Built directly instead of via gym.make("mujoco/AntLegs", num_obs_shape=23).
# NOTE(review): the reset() output in this notebook is a 27-element
# observation even though num_obs_shape=23 — confirm what this argument controls.
env = AntLegsEnv(xml_file=str(ant_xml_path), num_obs_shape=23)

In [23]:
# reset() returns an (observation, info) tuple; the observation shown in this
# cell's output has 27 entries and info is an empty dict.
# TODO(review): no seed is passed, so the initial state differs between runs;
# consider env.reset(seed=...) if AntLegsEnv.reset accepts it.
env.reset()

(array([ 0.66193837,  0.99279757,  0.06918375,  0.04694113, -0.08580864,
        -0.01303356,  0.06562912,  0.01811408,  0.08676466, -0.01431188,
        -0.04472134,  0.05985482, -0.02576076, -0.16465954, -0.05354587,
         0.17390433, -0.01276485, -0.03088796, -0.04587678,  0.01184972,
        -0.11630884, -0.03189005, -0.12605349, -0.07652192, -0.02849198,
         0.00916565, -0.0905689 ]),
 {})