# Tries

In [4]:
%load_ext autoreload
%autoreload 2

In [5]:
from functools import partial
from pathlib import Path

import gymnasium as gym
import numpy as np
import torch
from beartype import beartype as typechecker
from jaxtyping import Float, jaxtyped
from rich.pretty import pprint as pp
from torch import Tensor, nn
from torch.nn import functional as F

from libs.bert_sac.gymnasium_envs import AntLegsEnv

ppe = partial(pp, expand_all=True)

torch.set_default_device("cpu")
torch.set_default_dtype(torch.float64)

## Attention seeking

### Explorations

In [6]:
X = torch.rand((4, 1))
X.shape

torch.Size([4, 1])

In [7]:
W_q = torch.rand((1, 10))
W_q.shape

torch.Size([1, 10])

In [8]:
Q = X @ W_q
Q.shape

torch.Size([4, 10])

In [9]:
W_k = torch.rand((1, 10))
W_k.shape

torch.Size([1, 10])

In [10]:
K = X @ W_k
K.shape

torch.Size([4, 10])

In [11]:
score = Q @ K.T
score.shape

torch.Size([4, 4])

In [12]:
score.shape

torch.Size([4, 4])

In [13]:
soft_score = F.softmax(score, dim=0)
soft_score

tensor([[0.2471, 0.2480, 0.2530, 0.2384],
        [0.2324, 0.2339, 0.2452, 0.2195],
        [0.1301, 0.1346, 0.1824, 0.1004],
        [0.3903, 0.3834, 0.3194, 0.4417]])

In [14]:
W_v = torch.rand((1, 10))
V = X @ W_v
V.shape

torch.Size([4, 10])

In [15]:
soft_score_value = soft_score @ V
soft_score_value.shape

torch.Size([4, 10])

In [16]:
torch.sum(soft_score_value, dim=1)

tensor([2.4464, 2.2954, 1.2750, 3.9657])

### Determination

In [17]:
batch_size = 500
num_obs = 27
hidden_dim = 768

In [27]:
X = torch.rand((batch_size, num_obs, 1))
X.shape

torch.Size([500, 27, 1])

In [28]:
W_q = torch.rand((1, 1, hidden_dim))
W_k = torch.rand((1, 1, hidden_dim))
W_v = torch.rand((1, 1, hidden_dim))
W_q.shape

torch.Size([1, 1, 768])

In [29]:
Q = X @ W_q
K = X @ W_k
Q.shape, K.shape

(torch.Size([500, 27, 768]), torch.Size([500, 27, 768]))

In [33]:
torch.rand((1, 4, 1)) * torch.rand((1, 4, 1))

tensor([[[0.2093],
         [0.2668],
         [0.2135],
         [0.1477]]])

In [34]:
@jaxtyped(typechecker=typechecker)
def self_attention(
    X: Float[Tensor, "batch num_obs 1"],
    W_q: Float[Tensor, "1 1 hidden_dim"],
    W_k: Float[Tensor, "1 1 hidden_dim"],
    W_v: Float[Tensor, "1 1 hidden_dim"],
) -> Float[Tensor, "batch num_obs"]:
    Q = X @ W_q
    K = X @ W_k
    V = X @ W_v

    scores = Q @ K.mT
    soft_scores = F.softmax(scores, dim=1)
    soft_scores_value = soft_scores @ V
    Z = torch.sum(soft_scores_value, dim=2)

    return Z


In [35]:
self_attention(X, W_q, W_k, W_v).shape

torch.Size([500, 27])

### Attention Layer

In [36]:
from libs.bert_sac.models import AttentionLayer

In [63]:
INPUT = torch.tensor([[2], [3], [4], [5]], dtype=torch.float64).unsqueeze(0)
MASK = torch.tensor(
    [[1, 1, 1, 1], [1, 1, 0, 0], [1, 0, 1, 1], [1, 0, 1, 1]], dtype=torch.float64
).unsqueeze(0)
INPUT.shape, MASK.shape

(torch.Size([1, 4, 1]), torch.Size([1, 4, 4]))

In [64]:
att = AttentionLayer(mask=MASK, hidden_dim=10)

In [65]:
att(INPUT).shape

torch.Size([1, 4])

In [None]:
MASK = torch.sum(MASK, dim=0)

In [None]:
MASK

tensor([4., 2., 3., 3.], device='cuda:0')

In [66]:
at1 = AttentionLayer(mask=MASK, hidden_dim=3)
at2 = AttentionLayer(mask=MASK, hidden_dim=2)
at3 = AttentionLayer(mask=MASK, hidden_dim=1)

In [76]:
out1 = at1(INPUT)
out1.unsqueeze(-1).shape

torch.Size([1, 4, 1])

In [68]:
INPUT.shape

torch.Size([1, 4, 1])

In [79]:
out2 = at2(out1.unsqueeze(-1))
out2.shape

torch.Size([1, 4])

In [81]:
class Unsqueeze(nn.Module):
    def __init__(self, dim: int) -> None:
        super().__init__()
        self.dim = dim

    def forward(self, input):
        return torch.unsqueeze(input, dim=self.dim)


In [97]:
class Logger(nn.Module):
    def __init__(self) -> None:
        super().__init__()

    def forward(self, input):
        pp(input.shape)
        return input

In [100]:
attention = nn.Sequential(
    AttentionLayer(mask=MASK, hidden_dim=3),
    Unsqueeze(dim=-1),
    # nn.LogSigmoid(),
    AttentionLayer(mask=MASK, hidden_dim=6),
    Unsqueeze(dim=-1),
    # nn.LogSigmoid(),
    AttentionLayer(mask=MASK, hidden_dim=10),
    nn.Softmax(dim=1),
)
attention

Sequential(
  (0): AttentionLayer(
    (softmax): Softmax(dim=1)
  )
  (1): Logger()
  (2): Unsqueeze()
  (3): Logger()
  (4): AttentionLayer(
    (softmax): Softmax(dim=1)
  )
  (5): Unsqueeze()
  (6): AttentionLayer(
    (softmax): Softmax(dim=1)
  )
  (7): Softmax(dim=1)
)

In [101]:
attention(INPUT)

tensor([[1.0000e+00, 1.6234e-99, 1.6304e-99, 1.6234e-99]],
       grad_fn=<SoftmaxBackward0>)

In [None]:
xml_model_dir = Path("assets/mujoco/models/").resolve()

# env = gym.make("mujoco/AntLegs", num_obs_shape=23)
env = AntLegsEnv(xml_file=str(xml_model_dir / "ant-4.xml"), num_obs_shape=23)
env.reset()

(array([ 0.65527404,  0.99311952, -0.02701614, -0.05445851, -0.10009006,
        -0.00443353, -0.08069813,  0.06857853, -0.06214799,  0.0304708 ,
        -0.04546053, -0.01395165,  0.01609419,  0.13267775,  0.0665999 ,
         0.07760071,  0.14569279, -0.05596373, -0.14798958,  0.08728911,
         0.07033081, -0.01377987,  0.34725854,  0.10455345, -0.07564038,
        -0.0821203 , -0.04508681]),
 {})

In [None]:
att_mask = [
    [1] * 6 + [0, 1, 0, 1, 0, 1, 0] + [1] * 7 + [0, 1, 0, 1, 0, 1, 0],
    [1] * 6 + [0, 1, 0, 1, 0, 1, 0] + [1] * 7 + [0, 1, 0, 1, 0, 1, 0],
    [1] * 6 + [0, 1, 0, 1, 0, 1, 0] + [1] * 7 + [0, 1, 0, 1, 0, 1, 0],
    [1] * 6 + [0, 1, 0, 1, 0, 1, 0] + [1] * 7 + [0, 1, 0, 1, 0, 1, 0],
    [1] * 6 + [0, 1, 0, 1, 0, 1, 0] + [1] * 7 + [0, 1, 0, 1, 0, 1, 0],
    [1] * 7 + [0] * 6 + [1] * 8 + [0] * 6,
    [1] * 5 + [0] * 2 + [1] * 2 + [0] * 4 + [1] * 6 + [0] * 2 + [1] * 2 + [0] * 4,
    [1] * 5 + [0] * 4 + [1] * 2 + [0] * 2 + [1] * 6 + [0] * 4 + [1] * 2 + [0] * 2,
    [1] * 5 + [0] * 6 + [1] * 2 + [0] * 0 + [1] * 6 + [0] * 6 + [1] * 2 + [0] * 0,
    [0] * 5 + [1] * 2 + [0] * 12 + [1] * 2 + [0] * 6,
    [0] * 7 + [1] * 2 + [0] * 12 + [1] * 2 + [0] * 4,
    [0] * 9 + [1] * 2 + [0] * 12 + [1] * 2 + [0] * 2,
    [0] * 11 + [1] * 2 + [0] * 12 + [1] * 2 + [0] * 0,
    [1] * 6 + [0, 1, 0, 1, 0, 1, 0] + [1] * 7 + [0, 1, 0, 1, 0, 1, 0],
    [1] * 6 + [0, 1, 0, 1, 0, 1, 0] + [1] * 7 + [0, 1, 0, 1, 0, 1, 0],
    [1] * 6 + [0, 1, 0, 1, 0, 1, 0] + [1] * 7 + [0, 1, 0, 1, 0, 1, 0],
    [1] * 6 + [0, 1, 0, 1, 0, 1, 0] + [1] * 7 + [0, 1, 0, 1, 0, 1, 0],
    [1] * 6 + [0, 1, 0, 1, 0, 1, 0] + [1] * 7 + [0, 1, 0, 1, 0, 1, 0],
    [1] * 6 + [0, 1, 0, 1, 0, 1, 0] + [1] * 7 + [0, 1, 0, 1, 0, 1, 0],
    [1] * 7 + [0] * 6 + [1] * 8 + [0] * 6,
    [1] * 5 + [0] * 2 + [1] * 2 + [0] * 4 + [1] * 6 + [0] * 2 + [1] * 2 + [0] * 4,
    [1] * 5 + [0] * 4 + [1] * 2 + [0] * 2 + [1] * 6 + [0] * 4 + [1] * 2 + [0] * 2,
    [1] * 5 + [0] * 6 + [1] * 2 + [0] * 0 + [1] * 6 + [0] * 6 + [1] * 2 + [0] * 0,
    [0] * 5 + [1] * 2 + [0] * 12 + [1] * 2 + [0] * 6,
    [0] * 7 + [1] * 2 + [0] * 12 + [1] * 2 + [0] * 4,
    [0] * 9 + [1] * 2 + [0] * 12 + [1] * 2 + [0] * 2,
    [0] * 11 + [1] * 2 + [0] * 12 + [1] * 2 + [0] * 0,
]
np.array(att_mask)

array([[1, 1, 1, 1, 1, 1, 0, 1, 0, 1, 0, 1, 0, 1, 1, 1, 1, 1, 1, 1, 0, 1,
        0, 1, 0, 1, 0],
       [1, 1, 1, 1, 1, 1, 0, 1, 0, 1, 0, 1, 0, 1, 1, 1, 1, 1, 1, 1, 0, 1,
        0, 1, 0, 1, 0],
       [1, 1, 1, 1, 1, 1, 0, 1, 0, 1, 0, 1, 0, 1, 1, 1, 1, 1, 1, 1, 0, 1,
        0, 1, 0, 1, 0],
       [1, 1, 1, 1, 1, 1, 0, 1, 0, 1, 0, 1, 0, 1, 1, 1, 1, 1, 1, 1, 0, 1,
        0, 1, 0, 1, 0],
       [1, 1, 1, 1, 1, 1, 0, 1, 0, 1, 0, 1, 0, 1, 1, 1, 1, 1, 1, 1, 0, 1,
        0, 1, 0, 1, 0],
       [1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 0,
        0, 0, 0, 0, 0],
       [1, 1, 1, 1, 1, 0, 0, 1, 1, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 0, 0, 1,
        1, 0, 0, 0, 0],
       [1, 1, 1, 1, 1, 0, 0, 0, 0, 1, 1, 0, 0, 1, 1, 1, 1, 1, 1, 0, 0, 0,
        0, 1, 1, 0, 0],
       [1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0,
        0, 0, 0, 1, 1],
       [0, 0, 0, 0, 0, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 0,
        0, 0, 0, 0, 0],
       [0, 0, 0, 0, 