In [4]:
from functools import partial

import numpy as np
import torch
from jaxtyping import Float
from rich.pretty import pprint as pp
from torch import nn

# Shortcut for rich's pprint (imported as `pp`) with the `expand_all=True`
# option preset; any extra keyword arguments are forwarded to `pp`.
ppe = partial(pp, expand_all=True)

In [None]:
class AttentionLayer(nn.Module):
    """Single-head dot-product attention restricted to the edges of a graph.

    Each node attends only to the nodes it is connected to in the adjacency
    matrix ``G`` and receives a new ``d_k``-dimensional embedding.
    """

    def __init__(self, d_k: int, d: int, *, bias: bool = False):
        """
        Args:
            d_k: Length of each projected query/key/value embedding.
            d: Number of input features per node, i.e. ``X.shape[-1]``
                (the number of columns of ``X``).
            bias: If True, add a learnable bias to the Q/K/V projections.
        """
        super().__init__()
        # Projections map d input features to d_k embedding features.
        self.W_q = nn.Parameter(torch.randn(d, d_k))
        self.W_k = nn.Parameter(torch.randn(d, d_k))
        self.W_v = nn.Parameter(torch.randn(d, d_k))

        if bias:
            self.b_q = nn.Parameter(torch.randn(d_k))
            self.b_k = nn.Parameter(torch.randn(d_k))
            self.b_v = nn.Parameter(torch.randn(d_k))
        else:
            # Registered as None so the attributes always exist;
            # `_project` skips the addition when a bias is None.
            self.register_parameter("b_q", None)
            self.register_parameter("b_k", None)
            self.register_parameter("b_v", None)

    @staticmethod
    def _project(X, W, b):
        """Return ``X @ W``, plus ``b`` when a bias is present."""
        out = X @ W
        return out if b is None else out + b

    def forward(
        self,
        G: 'Float[torch.Tensor, "..."]',
        X: 'Float[torch.Tensor, "..."]',
    ):
        """Compute new node embeddings by attending over graph neighbours.

        Args:
            G: Adjacency matrix of shape ``(n, n)`` (or broadcastable to the
                ``(..., n, n)`` attention scores). Non-zero entries mark
                edges: node ``i`` attends to node ``j`` only when
                ``G[..., i, j] != 0``.
            X: Node features of shape ``(..., n, d)``.

        Returns:
            Tensor of shape ``(..., n, d_k)``. A node with no edges in ``G``
            gets an all-zero embedding.
        """
        Q = self._project(X, self.W_q, self.b_q)
        K = self._project(X, self.W_k, self.b_k)
        V = self._project(X, self.W_v, self.b_v)

        # transpose(-2, -1) rather than .T so batched inputs also work.
        scores = Q @ K.transpose(-2, -1)

        # Fill non-edges with -inf so they receive exactly zero weight after
        # the softmax. Multiplying scores by G would only set them to 0,
        # which still gets exp(0) weight.
        non_edge = G == 0
        masked_scores = scores.masked_fill(non_edge, float("-inf"))
        A = torch.softmax(masked_scores, dim=-1)
        # A row with no edges is all -inf, so its softmax is NaN; zero it.
        # In other rows the masked entries are already 0, so this is a no-op.
        A = A.masked_fill(non_edge, 0.0)
        Z = A @ V  # New Node Embeddings
        return Z
