In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import math

# Run every demo below on the GPU when CUDA is available, otherwise on the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

In [2]:
import torch
import torch.nn as nn
import torch.nn.functional as F

class AxisAttentionBase(nn.Module):
    """
    Single-head scaled dot-product attention for rank-3 inputs.

    The einsum patterns used in ``forward`` ('bqd', 'bkd') fix the layout to
    (batch, items, features) for both the query and the key-value tensor.
    """

    def __init__(
        self,
        size_Q: int,
        size_KV: int,
        size: int) -> None:
        """
        Args:
            size_Q: Feature size of the query input
            size_KV: Feature size of the key-value input
            size: Feature size of the projected q/k/v and of the output
        """
        super().__init__()

        # Also used as the softmax temperature (sqrt(size)) in forward().
        self.size = size

        # Input projections into the shared attention space.
        self.W_q = nn.Linear(size_Q, size)
        self.W_k = nn.Linear(size_KV, size)
        self.W_v = nn.Linear(size_KV, size)

        # Position-wise layer applied to the attended values.
        self.fc_o = nn.Linear(size, size)

    def forward(self, query: torch.Tensor, key_value: torch.Tensor) -> torch.Tensor:
        """
        Attend from every query item to every key-value item.

        Args:
            query: Tensor of shape (batch, n_query, size_Q)
            key_value: Tensor of shape (batch, n_key, size_KV)

        Returns:
            Tensor of shape (batch, n_query, size)
        """
        queries: torch.Tensor = self.W_q(query)
        keys: torch.Tensor = self.W_k(key_value)
        values: torch.Tensor = self.W_v(key_value)

        # Pairwise query/key similarities, normalized over the key axis.
        scale = math.sqrt(self.size)
        weights = F.softmax(torch.einsum('bqd,bkd->bqk', queries, keys) / scale, dim=-1)

        # Weighted sum of values, followed by a ReLU feed-forward residual.
        attended = torch.einsum('bqk,bkd->bqd', weights, values)
        return attended + F.relu(self.fc_o(attended))

In [3]:
# Smoke test: self-attention on a (128, 10, 512) batch; only the output shape is inspected.
aa = AxisAttentionBase(512, 512, 512).to(device)
X = torch.randn(128, 10, 512).to(device)
output = aa(X, X)
print(output.shape)

torch.Size([128, 10, 512])


In [4]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import math
import string

class AxisAttention(nn.Module):
    """
    Single-head scaled dot-product attention along one chosen axis of an N-D tensor.

    All axes other than the attention axis and the trailing feature axis are
    treated as batch axes and must have matching sizes in query and key-value.
    The einsum label 'a' is the query position, 'b' the key position and 'c'
    the feature axis; batch axes are labelled from 'd' onwards.
    """

    def __init__(
        self,
        dim_Q: int,
        dim_KV: int,
        dim_attn: int,
        dropout: float = 0.0,
        axis: int = 1,
    ) -> None:
        """
        Initialize an AxisAttention module.

        Args:
            dim_Q: Input dimension of query tensor (also the output dimension)
            dim_KV: Input dimension of key-value tensor
            dim_attn: Internal attention dimension
            dropout: Dropout probability applied to the attention weights after softmax
            axis: Axis to perform attention over (default: 1); negative values
                count from the end and the trailing feature axis is not allowed
        """
        super().__init__()

        self.dim_attn = dim_attn
        self.axis = axis

        # Projection layers
        self.proj_q = nn.Linear(dim_Q, dim_attn)
        self.proj_k = nn.Linear(dim_KV, dim_attn)
        self.proj_v = nn.Linear(dim_KV, dim_attn)

        # Output projection back to the query feature size so the residual add is valid
        self.proj_out = nn.Linear(dim_attn, dim_Q)

        # Dropout for attention weights
        self.dropout = nn.Dropout(dropout)

        # Einsum sub-expressions, valid only for inputs of rank `_cached_ndim`
        self.cached_expressions: dict = {}
        self._cached_ndim = None

        # Scores are MULTIPLIED by this factor, so it has to be 1/sqrt(d)
        self.attention_normalizer = 1.0 / math.sqrt(dim_attn)

    def _build_expressions(self, ndim: int) -> None:
        """Validate the attention axis for rank-`ndim` inputs and (re)build the einsum strings."""
        axis = self.axis if self.axis >= 0 else ndim + self.axis
        if not 0 <= axis < ndim - 1:
            raise ValueError(
                f"Attention axis must address a non-feature dimension. "
                f"Got axis={self.axis} for tensors with {ndim} dimensions.")

        # 'a', 'b' and 'c' are reserved, the batch axes take the remaining letters
        available_letters = string.ascii_lowercase[3:]
        if ndim - 1 > len(available_letters):
            raise ValueError(f"Tensor has too many dimensions: {ndim}")
        dims = available_letters[:ndim - 1]

        # Labels of the batch axes located before / after the attention axis
        batch_dims = ''.join(dims[i] for i in range(ndim - 1) if i != axis)
        before, after = batch_dims[:axis], batch_dims[axis:]

        self.cached_expressions = {
            "q_expr": before + 'a' + after + 'c',
            "k_expr": before + 'b' + after + 'c',
            "v_expr": before + 'b' + after + 'c',
            # Output matches the query layout
            "out_expr": before + 'a' + after + 'c',
            # Attention map carries both the query and the key position
            "attn_expr": before + 'ab' + after,
        }
        self._cached_ndim = ndim

    def forward(self, query: torch.Tensor, key_value: torch.Tensor, verbose: bool = False) -> torch.Tensor:
        """
        Perform attention along the specified axis.

        Args:
            query: Query tensor of shape (..., dim_Q)
            key_value: Key-value tensor of shape (..., dim_KV) with the same
                number of dimensions as `query`
            verbose: Whether to print the einsum expressions that are used

        Returns:
            Output tensor with the same shape as query

        Raises:
            ValueError: If the ranks of the inputs differ or the attention axis
                does not address a non-feature dimension
        """
        if query.ndim != key_value.ndim:
            raise ValueError(
                f"query and key_value must have the same number of dimensions. "
                f"Got {query.ndim=} and {key_value.ndim=}.")

        # The expressions depend on the input rank, so rebuild them when it changes
        if self._cached_ndim != query.ndim:
            self._build_expressions(query.ndim)
        q_expr = self.cached_expressions["q_expr"]
        k_expr = self.cached_expressions["k_expr"]
        v_expr = self.cached_expressions["v_expr"]
        attn_expr = self.cached_expressions["attn_expr"]
        out_expr = self.cached_expressions["out_expr"]

        # Project inputs to attention space
        q = self.proj_q(query)      # (..., dim_attn)
        k = self.proj_k(key_value)  # (..., dim_attn)
        v = self.proj_v(key_value)  # (..., dim_attn)

        # Compute scaled attention scores
        scores_equation = f'{q_expr},{k_expr}->{attn_expr}'
        if verbose:
            print(scores_equation)
        attn = torch.einsum(scores_equation, q, k) * self.attention_normalizer

        # Normalize over the key position 'b', which is NOT the last axis of the
        # attention map whenever batch axes follow the attention axis
        attn = F.softmax(attn, dim=attn_expr.index('b'))
        attn = self.dropout(attn)

        # Apply attention to values
        values_equation = f'{attn_expr},{v_expr}->{out_expr}'
        if verbose:
            print(values_equation)
        output = torch.einsum(values_equation, attn, v)

        # Residual connection around the projected attention output
        return query + self.proj_out(output)

In [5]:
# Order 1 tensor

In [6]:
# Self-attention over axis 1 (default) of a 3-D input; verbose=True prints the einsum expressions.
aa = AxisAttention(512, 512, 512).to(device)
X = torch.randn(128, 10, 512).to(device)
output = aa(X, X, verbose=True)
print(output.shape)

dac,dbc->dab
dab,dbc->dac
torch.Size([128, 10, 512])


In [7]:
# Cross-attention: 10 query items attend to 3 key-value items along axis 1;
# the output keeps the query's shape.
aa = AxisAttention(512, 512, 512).to(device)
X = torch.randn(128, 10, 512).to(device)
Y = torch.randn(128, 3, 512).to(device)
output = aa(X, Y, verbose=True)
print(output.shape)

dac,dbc->dab
dab,dbc->dac
torch.Size([128, 10, 512])


In [8]:
# Order 2 tensor

In [9]:
# Self-attention over axis 1 (default) of a 4-D input; axis 2 acts as an extra batch axis.
aa = AxisAttention(512, 512, 512).to(device)
X = torch.randn(128, 10, 3, 512).to(device)
output = aa(X, X, verbose=True)
print(output.shape)

dafc,dbfc->dabf
dabf,dbfc->dafc
torch.Size([128, 10, 3, 512])


In [10]:
# Same 4-D input, but attending over axis 2 (size 3) instead of axis 1.
aa = AxisAttention(512, 512, 512, axis=2).to(device)
X = torch.randn(128, 10, 3, 512).to(device)
output = aa(X, X, verbose=True)
print(output.shape)

deac,debc->deab
deab,debc->deac
torch.Size([128, 10, 3, 512])


In [None]:
# 4-D cross-attention over axis 1: query has 10 items, key-value has 3;
# the remaining batch axes (sizes 128 and 3) match between X and Y.
aa = AxisAttention(512, 512, 512, axis=1).to(device)
X = torch.randn(128, 10, 3, 512).to(device)
Y = torch.randn(128, 3, 3, 512).to(device)
output = aa(X, Y, verbose=True)
print(output.shape)

dafc,dbfc->dabf
dabf,dbfc->dafc
torch.Size([128, 10, 3, 512])


In [12]:
# Order 3 tensor

In [13]:
# 5-D input, self-attention over axis 1 (size 10).
aa = AxisAttention(512, 512, 512, axis=1).to(device)
X = torch.randn(128, 10, 3, 6, 512).to(device)
output = aa(X, X, verbose=True)
print(output.shape)

dafgc,dbfgc->dabfg
dabfg,dbfgc->dafgc
torch.Size([128, 10, 3, 6, 512])


In [14]:
# 5-D input, self-attention over axis 2 (size 3).
aa = AxisAttention(512, 512, 512, axis=2).to(device)
X = torch.randn(128, 10, 3, 6, 512).to(device)
output = aa(X, X, verbose=True)
print(output.shape)

deagc,debgc->deabg
deabg,debgc->deagc
torch.Size([128, 10, 3, 6, 512])


In [15]:
# 5-D input, self-attention over axis 3 (size 6), the last axis before the features.
aa = AxisAttention(512, 512, 512, axis=3).to(device)
X = torch.randn(128, 10, 3, 6, 512).to(device)
output = aa(X, X, verbose=True)
print(output.shape)

defac,defbc->defab
defab,defbc->defac
torch.Size([128, 10, 3, 6, 512])


In [16]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import math
import string

class MultiheadAxisAttention(nn.Module):
    def __init__(
        self,
        dim_Q: int,
        dim_KV: int,
        dim_out: int | None = None,
        num_heads: int = 8,
        head_dim: int | None = None,
        dropout: float = 0.0,
        axis: int = 1,
    ) -> None:
        """
        Initialize a MultiheadAxisAttention module.

        Args:
            dim_Q: Input dimension of query tensor
            dim_KV: Input dimension of key-value tensor
            dim_out: Output dimension (defaults to dim_Q if None)
            num_heads: Number of attention heads
            head_dim: Dimension of each attention head (if None, calculated as dim_Q // num_heads)
            dropout: Dropout probability after softmax
            axis: Axis to perform attention over (default: 1)
        """
        super().__init__()

        self.num_heads = num_heads
        self.axis = axis

        # Set output dimension
        self.dim_out = dim_out if dim_out is not None else dim_Q

        # Set head dimension
        self.head_dim = head_dim if head_dim is not None else dim_Q // num_heads
        if self.head_dim * num_heads != dim_Q and head_dim is None:
            raise ValueError(f"dim_Q ({dim_Q}) must be divisible by num_heads ({num_heads})")

        # Total dimension for all heads combined
        self.total_head_dim = self.head_dim * num_heads

        # Projection layers
        self.proj_q = nn.Linear(dim_Q, self.total_head_dim)
        self.proj_k = nn.Linear(dim_KV, self.total_head_dim)
        self.proj_v = nn.Linear(dim_KV, self.total_head_dim)

        # Output projection
        self.proj_out = nn.Linear(self.total_head_dim, self.dim_out)

        # Dropout for attention weights
        self.dropout = nn.Dropout(dropout)

        # Cache for einsum expressions
        self.cached_expressions: dict[str, str] = {}

        # Scaling factor for attention scores
        self.attention_normalizer = 1.0 / math.sqrt(self.head_dim)

    def _generate_einsum_expressions(self, query_shape: tuple[int, ...]) -> None:
        """Generate and cache einsum expressions based on input tensor shapes"""
        # Ensure axis is positive
        axis = self.axis if self.axis >= 0 else len(query_shape) + self.axis
        if axis >= len(query_shape) - 1:
            raise ValueError(f"Axis cannot be the last dimension. Got {axis=} but tensor has {len(query_shape)} dimensions.")

        # Generate einsum notation
        ndim = len(query_shape)

        # Reserve 'd' for head dimension, 'a' for query attention, 'b' for key attention, 'c' for embedding
        available_letters = string.ascii_lowercase[4:]  # Skip a, b, c, d
        if ndim - 2 > len(available_letters):  # -2 for attention axis and embedding dim
            raise ValueError(f"Tensor has too many dimensions: {ndim}")

        # Assign dimension labels
        batch_dims = []
        letter_idx = 0

        for i in range(ndim - 1):  # Exclude the last dimension (embedding)
            if i == axis:
                # For the attention axis, use 'a'
                batch_dims.append('a')
            else:
                # For other dimensions, use remaining letters
                batch_dims.append(available_letters[letter_idx])
                letter_idx += 1

        batch_expr = ''.join(batch_dims)

        # Create expressions with head dimension 'h'
        # Query: batch_dims + head + embedding
        self.cached_expressions["q_expr"] = batch_expr + 'dc'  # FIXME

        # Key and Value: replace 'a' with 'b' for the attention dimension
        k_batch = batch_expr.replace('a', 'b')
        self.cached_expressions["k_expr"] = k_batch + 'dc'
        self.cached_expressions["v_expr"] = k_batch + 'dc'

        # Attention map: batch_dims (with both 'a' and 'b') + head
        attn_batch = batch_expr.replace('a', 'ab')
        self.cached_expressions["attn_expr"] = attn_batch + 'd'

        # Output: same as query but with embedding dimension
        self.cached_expressions["out_expr"] = batch_expr + 'dc'

    def forward(self, query: torch.Tensor, key_value: torch.Tensor, verbose: bool = False) -> torch.Tensor:
        """
        Perform multi-head attention along the specified axis.

        Args:
            query: Query tensor of shape (..., dim_Q)
            key_value: Key-value tensor of shape (..., dim_KV)
            verbose: Whether to print intermediate shapes and expressions

        Returns:
            Output tensor with shape (..., dim_out)
        """
        batch_shape = query.shape[:-1]

        # Project inputs to attention space
        q = self.proj_q(query)      # (..., total_head_dim)
        k = self.proj_k(key_value)  # (..., total_head_dim)
        v = self.proj_v(key_value)  # (..., total_head_dim)

        # Reshape to separate head dimension
        q = q.view(*batch_shape, self.num_heads, self.head_dim)
        k = k.view(*key_value.shape[:-1], self.num_heads, self.head_dim)
        v = v.view(*key_value.shape[:-1], self.num_heads, self.head_dim)

        # Generate einsum expressions if not cached
        if not self.cached_expressions:
            self._generate_einsum_expressions(query.shape)

        # Compute attention scores
        if verbose:
            print(f"Q shape: {q.shape}")
            print(f"K shape: {k.shape}")
            print(f"Einsum: {self.cached_expressions['q_expr']},{self.cached_expressions['k_expr']}->{self.cached_expressions['attn_expr']}")

        attn = torch.einsum(
            f"{self.cached_expressions['q_expr']},{self.cached_expressions['k_expr']}->{self.cached_expressions['attn_expr']}", 
            q, k
        ) * self.attention_normalizer

        # Apply softmax along the key dimension (dimension 'b')
        attn = F.softmax(attn, dim=-2)  # -2 corresponds to 'b' dimension
        attn = self.dropout(attn)

        if verbose:
            print(f"Attention shape: {attn.shape}")
            print(f"V shape: {v.shape}")
            print(f"Einsum: {self.cached_expressions['attn_expr']},{self.cached_expressions['v_expr']}->{self.cached_expressions['out_expr']}")

        # Apply attention to values
        output = torch.einsum(
            f"{self.cached_expressions['attn_expr']},{self.cached_expressions['v_expr']}->{self.cached_expressions['out_expr']}", 
            attn, v
        )

        if verbose:
            print(f"Output before reshape: {output.shape}")

        # Reshape back to combine heads
        output = output.reshape(*batch_shape, self.total_head_dim)

        # Apply output projection
        output = self.proj_out(output)

        # Add residual connection
        output = query + output

        return output


In [17]:
# Order 1 tensor

In [18]:
# Multi-head self-attention (default 8 heads, so 64 features per head) over axis 1 of a 3-D input.
aa = MultiheadAxisAttention(512, 512, 512).to(device)
X = torch.randn(128, 10, 512).to(device)
output = aa(X, X, verbose=True)
print(output.shape)

Q shape: torch.Size([128, 10, 8, 64])
K shape: torch.Size([128, 10, 8, 64])
Einsum: eadc,ebdc->eabd
Attention shape: torch.Size([128, 10, 10, 8])
V shape: torch.Size([128, 10, 8, 64])
Einsum: eabd,ebdc->eadc
Output before reshape: torch.Size([128, 10, 8, 64])
torch.Size([128, 10, 512])


In [19]:
# Order 2 tensor

In [20]:
# Multi-head self-attention over axis 1 (default) of a 4-D input.
aa = MultiheadAxisAttention(512, 512, 512).to(device)
X = torch.randn(128, 10, 3, 512).to(device)
output = aa(X, X, verbose=True)
print(output.shape)

Q shape: torch.Size([128, 10, 3, 8, 64])
K shape: torch.Size([128, 10, 3, 8, 64])
Einsum: eafdc,ebfdc->eabfd
Attention shape: torch.Size([128, 10, 10, 3, 8])
V shape: torch.Size([128, 10, 3, 8, 64])
Einsum: eabfd,ebfdc->eafdc
Output before reshape: torch.Size([128, 10, 3, 8, 64])
torch.Size([128, 10, 3, 512])


In [21]:
# Multi-head self-attention over axis 2 (size 3) of a 4-D input.
aa = MultiheadAxisAttention(512, 512, 512, axis=2).to(device)
X = torch.randn(128, 10, 3, 512).to(device)
output = aa(X, X, verbose=True)
print(output.shape)

Q shape: torch.Size([128, 10, 3, 8, 64])
K shape: torch.Size([128, 10, 3, 8, 64])
Einsum: efadc,efbdc->efabd
Attention shape: torch.Size([128, 10, 3, 3, 8])
V shape: torch.Size([128, 10, 3, 8, 64])
Einsum: efabd,efbdc->efadc
Output before reshape: torch.Size([128, 10, 3, 8, 64])
torch.Size([128, 10, 3, 512])


In [22]:
# Multi-head cross-attention over axis 1: 10 query items attend to 3 key-value items.
aa = MultiheadAxisAttention(512, 512, 512, axis=1).to(device)
X = torch.randn(128, 10, 3, 512).to(device)
Y = torch.randn(128, 3, 3, 512).to(device)
output = aa(X, Y, verbose=True)
print(output.shape)

Q shape: torch.Size([128, 10, 3, 8, 64])
K shape: torch.Size([128, 3, 3, 8, 64])
Einsum: eafdc,ebfdc->eabfd
Attention shape: torch.Size([128, 10, 3, 3, 8])
V shape: torch.Size([128, 3, 3, 8, 64])
Einsum: eabfd,ebfdc->eafdc
Output before reshape: torch.Size([128, 10, 3, 8, 64])
torch.Size([128, 10, 3, 512])


In [23]:
# Order 3 tensor

In [24]:
# 5-D input, multi-head self-attention over axis 1 (size 10).
aa = MultiheadAxisAttention(512, 512, 512, axis=1).to(device)
X = torch.randn(128, 10, 3, 6, 512).to(device)
output = aa(X, X, verbose=True)
print(output.shape)

Q shape: torch.Size([128, 10, 3, 6, 8, 64])
K shape: torch.Size([128, 10, 3, 6, 8, 64])
Einsum: eafgdc,ebfgdc->eabfgd
Attention shape: torch.Size([128, 10, 10, 3, 6, 8])
V shape: torch.Size([128, 10, 3, 6, 8, 64])
Einsum: eabfgd,ebfgdc->eafgdc
Output before reshape: torch.Size([128, 10, 3, 6, 8, 64])
torch.Size([128, 10, 3, 6, 512])


In [25]:
# 5-D input, multi-head self-attention over axis 2 (size 3).
aa = MultiheadAxisAttention(512, 512, 512, axis=2).to(device)
X = torch.randn(128, 10, 3, 6, 512).to(device)
output = aa(X, X, verbose=True)
print(output.shape)

Q shape: torch.Size([128, 10, 3, 6, 8, 64])
K shape: torch.Size([128, 10, 3, 6, 8, 64])
Einsum: efagdc,efbgdc->efabgd
Attention shape: torch.Size([128, 10, 3, 3, 6, 8])
V shape: torch.Size([128, 10, 3, 6, 8, 64])
Einsum: efabgd,efbgdc->efagdc
Output before reshape: torch.Size([128, 10, 3, 6, 8, 64])
torch.Size([128, 10, 3, 6, 512])


In [26]:
# 5-D input, multi-head self-attention over axis 3 (size 6), the last axis before the features.
aa = MultiheadAxisAttention(512, 512, 512, axis=3).to(device)
X = torch.randn(128, 10, 3, 6, 512).to(device)
output = aa(X, X, verbose=True)
print(output.shape)

Q shape: torch.Size([128, 10, 3, 6, 8, 64])
K shape: torch.Size([128, 10, 3, 6, 8, 64])
Einsum: efgadc,efgbdc->efgabd
Attention shape: torch.Size([128, 10, 3, 6, 6, 8])
V shape: torch.Size([128, 10, 3, 6, 8, 64])
Einsum: efgabd,efgbdc->efgadc
Output before reshape: torch.Size([128, 10, 3, 6, 8, 64])
torch.Size([128, 10, 3, 6, 512])


---

In [27]:
class AxisSAB(nn.Module):
    """
    Set attention block: the input attends to itself along one axis.

    Adapted from https://github.com/juho-lee/set_transformer
    """
    def __init__(self, input_size: int, output_size: int, n_heads: int, axis: int = 1) -> None:
        """
        Args:
            input_size: Feature size of the input (used for both query and key-value)
            output_size: Passed to AxisAttention as its internal attention size
            n_heads: Accepted for interface compatibility; not used by this block
            axis: Axis to perform attention over (default: 1)
        """
        super().__init__()
        self.mab = AxisAttention(
            dim_Q=input_size,
            dim_KV=input_size,
            dim_attn=output_size,
            axis=axis,
        )

    def forward(self, X: torch.Tensor) -> torch.Tensor:
        """Return the self-attention of X, using X as query and as key-value."""
        attended = self.mab(X, X)
        return attended

In [28]:
# Self-attention block over axis 1 (default) of a 3-D input; only the output shape is inspected.
asab = AxisSAB(512, 512, 8).to(device)
X = torch.randn(128, 10, 512).to(device)
output = asab(X)
print(output.shape)

torch.Size([128, 10, 512])


In [29]:
# Self-attention block over axis 2 (size 3) of a 4-D input.
asab = AxisSAB(512, 512, 8, axis=2).to(device)
X = torch.randn(128, 10, 3, 512).to(device)
output = asab(X)
print(output.shape)

torch.Size([128, 10, 3, 512])


In [30]:
class AxisISAB(nn.Module):
    """
    Induced set attention block along one axis.

    A learned set of `n_induce` inducing points first attends to the input
    (`mab0`), then the input attends to that summary (`mab1`).

    Adapted from https://github.com/juho-lee/set_transformer
    """
    def __init__(self, input_size: int, output_size: int, n_heads: int, n_induce: int, axis: int = 1) -> None:
        """
        Args:
            input_size: Feature size of the input
            output_size: Feature size of the inducing points and of the attention space
            n_heads: Accepted for interface compatibility; not used by this block
            n_induce: Number of learned inducing points
            axis: Axis to perform attention over (default: 1)
        """
        super().__init__()

        self.inducing_points = nn.Parameter(torch.Tensor(n_induce, output_size))
        nn.init.xavier_uniform_(self.inducing_points)

        # mab0: inducing points (query) attend to the input (key-value)
        self.mab0 = AxisAttention(output_size, input_size, output_size, axis=axis)
        # mab1: the input (query) attends to the induced summary (key-value)
        self.mab1 = AxisAttention(input_size, output_size, output_size, axis=axis)

        self.axis = axis
        self.n_induce = n_induce

    def forward(self, X: torch.Tensor) -> torch.Tensor:
        """
        Args:
            X: Input tensor of shape (..., input_size)

        Returns:
            The result of `mab1(X, H)`, where H is the induced summary of X
        """
        # Format inducing points to align with the attention axis
        formatted_induction_points = self._format_inducing_points(X)

        # Apply attention mechanisms
        H = self.mab0(formatted_induction_points, X)
        return self.mab1(X, H)

    def _format_inducing_points(self, X: torch.Tensor) -> torch.Tensor:
        """
        Format inducing points to align with the attention axis of X.

        Args:
            X: Input tensor of shape (..., input_size)

        Returns:
            Formatted inducing points tensor with shape matching X except:
            - The attention axis dimension is replaced with n_induce
            - The last dimension is output_size
        """
        # Ensure axis is positive
        axis = self.axis if self.axis >= 0 else X.ndim + self.axis

        # Singleton view: n_induce on the attention axis, 1 on every other
        # non-feature axis, so that the points broadcast over the batch axes
        view_shape = [1] * (X.ndim - 1)  # -1 because points already has the last dimension
        view_shape[axis] = self.n_induce

        # Target sizes of the non-feature axes: those of X, with n_induce on the attention axis
        target_shape = list(X.shape[:-1])
        target_shape[axis] = self.n_induce

        points = self.inducing_points.view(*view_shape, -1)

        # Expand without copying. The feature axis keeps its own size (-1):
        # it is output_size, which need not equal X.shape[-1] (input_size).
        return points.expand(*target_shape, -1)


In [31]:
# Induced attention block with 3 inducing points over axis 1 (default) of a 3-D input.
aisab = AxisISAB(512, 512, 8, 3).to(device)
X = torch.randn(128, 10, 512).to(device)
output = aisab(X)
print(output.shape)

torch.Size([128, 10, 512])


In [32]:
# Induced attention block with 3 inducing points over axis 1 of a 4-D input.
aisab = AxisISAB(512, 512, 8, 3, axis=1).to(device)
X = torch.randn(128, 10, 3, 512).to(device)
output = aisab(X)
print(output.shape)

torch.Size([128, 10, 3, 512])


In [33]:
class AxisPMA(nn.Module):
    """
    Attention pooling along one axis using learned seed vectors.

    The seeds act as queries and the input as key-value, so the attention axis
    of the result has `n_seeds` entries.
    Adapted from https://github.com/juho-lee/set_transformer
    """
    def __init__(self, size: int, n_heads: int, n_seeds: int, axis: int = 1) -> None:
        """
        Build the seed parameter and the attention block.

        Args:
            size: Feature size shared by seeds, input and output
            n_heads: Accepted for interface compatibility; not used by this block
            n_seeds: Number of learned seed vectors (length of the pooled axis)
            axis: Axis that is pooled (default: 1)
        """
        super().__init__()

        # Learned queries, one row per output element.
        self.S = nn.Parameter(torch.Tensor(n_seeds, size))
        nn.init.xavier_uniform_(self.S)

        self.mab = AxisAttention(size, size, size,  axis=axis)
        self.axis = axis
        self.n_seeds = n_seeds

    def forward(self, X: torch.Tensor) -> torch.Tensor:
        """
        Pool X along the attention axis.

        Args:
            X: Input tensor of shape (..., size)

        Returns:
            Tensor shaped like X, except that the attention axis has n_seeds entries
        """
        queries = self._format_seeds(X)
        return self.mab(queries, X)

    def _format_seeds(self, X: torch.Tensor) -> torch.Tensor:
        """
        Broadcast the seed vectors over every non-attention batch axis of X.

        Args:
            X: Input tensor of shape (..., size)

        Returns:
            A view of the seeds with X's shape, except that the attention axis
            has n_seeds entries
        """
        # Resolve a negative axis against the rank of X.
        axis = self.axis if self.axis >= 0 else X.ndim + self.axis

        # Sizes of all non-feature axes, with the attention axis swapped for n_seeds.
        expanded_dims = list(X.shape[:-1])
        expanded_dims[axis] = self.n_seeds

        # Same rank, but singleton everywhere except on the attention axis.
        singleton_dims = [1] * len(expanded_dims)
        singleton_dims[axis] = self.n_seeds

        # view() inserts the singleton axes, expand() broadcasts them without copying.
        return self.S.view(*singleton_dims, -1).expand(*expanded_dims, X.shape[-1])

In [34]:
# Pooling with 3 seeds over axis 1 (default): the 10 items are reduced to 3.
apma = AxisPMA(512, 8, 3).to(device)
X = torch.randn(128, 10, 512).to(device)
output = apma(X)
print(output.shape)

torch.Size([128, 3, 512])


In [35]:
# Pooling with 3 seeds over axis 2: the 16 items on that axis are reduced to 3.
apma = AxisPMA(512, 8, 3, axis=2).to(device)
X = torch.randn(128, 10, 16, 512).to(device)
output = apma(X)
print(output.shape)

torch.Size([128, 10, 3, 512])


---

In [36]:
from flash_ansr.models.encoders.set_encoder import SetEncoder
from flash_ansr.models.transformer_utils import PositionalEncoding

class AlternatingSetTransformer(SetEncoder):
    """
    Set-transformer style encoder whose attention blocks alternate between axis 1 and axis 2.

    Pipeline: linear_in -> ISAB stack -> PMA over axis 1 -> SAB stack ->
    linear_out -> optional positional encoding.

    Adapted from https://github.com/juho-lee/set_transformer
    """
    def __init__(
            self,
            input_embedding_size: int,
            input_dimension_size: int,
            output_embedding_size: int,
            n_seeds: int,
            hidden_size: int = 512,
            n_enc_isab: int = 2,
            n_dec_sab: int = 2,
            n_induce: int | list[int] = 64,
            n_heads: int = 4,
            add_positional_encoding: bool = True) -> None:
        """
        Args:
            input_embedding_size: Feature size of the input
            input_dimension_size: Not used by this class; kept in the signature for callers
            output_embedding_size: Feature size of the output
            n_seeds: Number of seed vectors of the pooling block (size of axis 1 after pooling)
            hidden_size: Feature size used between `linear_in` and `linear_out`
            n_enc_isab: Number of ISAB blocks in the encoder (at least 1)
            n_dec_sab: Number of SAB blocks in the decoder (may be 0)
            n_induce: Inducing points per ISAB; a single int is used for every block
            n_heads: Forwarded to the attention blocks
            add_positional_encoding: Whether `forward` adds the positional encoding to the output

        Raises:
            ValueError: If `n_enc_isab` < 1, `n_dec_sab` < 0 or `n_induce` is a
                list whose length differs from `n_enc_isab`
        """
        super().__init__()
        if n_enc_isab < 1:
            raise ValueError(f"Number of ISABs in encoder `n_enc_isab` ({n_enc_isab}) must be greater than 0")

        if n_dec_sab < 0:
            raise ValueError(f"Number of SABs in decoder `n_dec_sab` ({n_dec_sab}) cannot be negative")

        if isinstance(n_induce, int):
            n_induce = [n_induce] * n_enc_isab
        elif len(n_induce) != n_enc_isab:
            raise ValueError(
                f"Number of inducing points `n_induce` ({n_induce}) must be an integer or a list of length {n_enc_isab}")

        # Remember the flag so that forward() can honor it.
        # NOTE(review): assumes SetEncoder does not define an attribute of this name — confirm.
        self.add_positional_encoding = add_positional_encoding

        self.linear_in = nn.Linear(input_embedding_size, hidden_size)
        # Even-indexed blocks attend over axis 1, odd-indexed blocks over axis 2
        self.enc = nn.Sequential(*[AxisISAB(hidden_size, hidden_size, n_heads, n_induce[i], axis = 1 + (i % 2)) for i in range(n_enc_isab)])
        self.pma = AxisPMA(hidden_size, n_heads, n_seeds, axis=1)
        self.dec = nn.Sequential(*[AxisSAB(hidden_size, hidden_size, n_heads, axis = 1 + (i % 2)) for i in range(n_dec_sab)])
        self.linear_out = nn.Linear(hidden_size, output_embedding_size)
        # Always constructed so the module structure does not depend on the flag
        self.positional_encoding_out = PositionalEncoding()

    def forward(self, X: torch.Tensor) -> torch.Tensor:
        """
        Encode X.

        Args:
            X: Input tensor of shape (..., input_embedding_size); the
                positional-encoding branch requires a 4-D result

        Returns:
            The encoded tensor with `output_embedding_size` features; axis 1
            has `n_seeds` entries after pooling
        """
        out = self.linear_out(self.dec(self.pma(self.enc(self.linear_in(X)))))

        if self.add_positional_encoding:
            # The encoding is built for the last two axes and added to `out`
            _, _, D, E = out.shape
            out = out + self.positional_encoding_out(shape=(D, E), device=out.device)

        return out

In [37]:
# Build a small model (hidden size 64, 5 ISABs, 2 SABs) and show its parameter count.
# `n_params` is not defined in this notebook; presumably it is provided by SetEncoder.
ast = AlternatingSetTransformer(16, 3, 64, hidden_size=64, n_seeds=64, n_enc_isab=5, n_dec_sab=2, n_induce=64).to(device)
ast.n_params

246144

In [43]:
# Forward pass on a (128, 256, 3, 16) batch; axis 1 is pooled from 256 entries down to n_seeds=64.
X = torch.randn(128, 256, 3, 16).to(device)
output = ast(X)
print(output.shape)
# Holy mother of memory usage
# NOTE(review): presumably the attention maps dominate memory here — confirm with a profiler.

torch.Size([128, 64, 3, 64])
