In [1]:
from flash_ansr.models.encoders.set_encoder import SetEncoder
from flash_ansr.models.encoders.set_transformer import ISAB, PMA, SAB
from flash_ansr.models.transformer_utils import PositionalEncoding
import torch
from torch import nn

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

In [2]:
class SetTransformer(SetEncoder):
    # https://github.com/juho-lee/set_transformer
    def __init__(
            self,
            input_size: int,
            output_size: int,
            n_seeds: int,
            hidden_size: int = 512,
            n_enc_isab: int = 2,
            n_dec_sab: int = 2,
            n_induce: int | list[int] = 64,
            n_heads: int = 4,
            layer_norm: bool = False) -> None:
        super().__init__()
        if n_enc_isab < 1:
            raise ValueError(f"Number of ISABs in encoder `n_enc_isab` ({n_enc_isab}) must be greater than 0")

        if n_dec_sab < 0:
            raise ValueError(f"Number of SABs in decoder `n_dec_sab` ({n_dec_sab}) cannot be negative")

        if isinstance(n_induce, int):
            n_induce = [n_induce] * n_enc_isab
        elif len(n_induce) != n_enc_isab:
            raise ValueError(
                f"Number of inducing points `n_induce` ({n_induce}) must be an integer or a list of length {n_enc_isab}")

        self.enc = nn.Sequential(
            ISAB(input_size, hidden_size, n_heads, n_induce[0], layer_norm),
            *[ISAB(hidden_size, hidden_size, n_heads, n_induce[i + 1], layer_norm) for i in range(n_enc_isab - 1)])

        self.dec = nn.Sequential(
            PMA(hidden_size, n_heads, n_seeds, layer_norm),
            *[SAB(hidden_size, hidden_size, n_heads, layer_norm) for _ in range(n_dec_sab)],
            nn.Linear(hidden_size, output_size))

    def forward(self, X: torch.Tensor) -> torch.Tensor:
        return self.dec(self.enc(X))


In [3]:
s = SetTransformer(13, 13, 64).to(device)

In [4]:
X = torch.rand(128, 70, 13).to(device)
print(s(X).shape)

torch.Size([128, 64, 13])


In [5]:
class FlatSetTransformer(SetEncoder):
    # https://github.com/juho-lee/set_transformer
    def __init__(
            self,
            input_size: int,
            output_size: int,
            n_seeds: int,
            hidden_size: int = 512,
            n_enc_isab: int = 2,
            n_dec_sab: int = 2,
            n_induce: int | list[int] = 64,
            n_heads: int = 4,
            layer_norm: bool = False,
            add_positional_encoding: bool = True) -> None:
        super().__init__()
        if n_enc_isab < 1:
            raise ValueError(f"Number of ISABs in encoder `n_enc_isab` ({n_enc_isab}) must be greater than 0")

        if n_dec_sab < 0:
            raise ValueError(f"Number of SABs in decoder `n_dec_sab` ({n_dec_sab}) cannot be negative")

        if isinstance(n_induce, int):
            n_induce = [n_induce] * n_enc_isab
        elif len(n_induce) != n_enc_isab:
            raise ValueError(
                f"Number of inducing points `n_induce` ({n_induce}) must be an integer or a list of length {n_enc_isab}")

        self.enc = nn.Sequential(
            ISAB(input_size, hidden_size, n_heads, n_induce[0], layer_norm),
            *[ISAB(hidden_size, hidden_size, n_heads, n_induce[i + 1], layer_norm) for i in range(n_enc_isab - 1)])

        self.dec = nn.Sequential(
            PMA(hidden_size, n_heads, n_seeds, layer_norm),
            *[SAB(hidden_size, hidden_size, n_heads, layer_norm) for _ in range(n_dec_sab)],
            nn.Linear(hidden_size, output_size))
        
        self.add_positional_encoding = add_positional_encoding
        self.positional_encoding_out = PositionalEncoding()

    def forward(self, X: torch.Tensor) -> torch.Tensor:
        B, M, D, E = X.shape
        out = self.dec(self.enc(X.reshape(B, M * D, E)))

        if self.add_positional_encoding:
            E = X.size(-1)

            # Positional encoding
            X = X + self.positional_encoding_out(shape=(D, E), device=X.device)
        return out
        

In [6]:
s = FlatSetTransformer(16, 13, 64, n_enc_isab=5, hidden_size=512).to(device)
s.n_params

13099533

In [7]:
X = torch.rand(128, 512, 3, 16)
print(s(X.to(device)).shape)

torch.Size([128, 64, 13])


In [5]:
class AlternatingSetTransformer(SetEncoder):
    # https://github.com/juho-lee/set_transformer
    def __init__(
            self,
            input_embedding_size: int,
            output_embedding_size: int,
            n_seeds: int,
            input_dimension_size: int | None = None,  # unused
            hidden_size: int = 512,
            n_enc_isab: int = 2,
            n_dec_sab: int = 2,
            n_induce: int | list[int] = 64,
            n_heads: int = 4,
            layer_norm: bool = False,
            add_positional_encoding: bool = True) -> None:
        super().__init__()
        if n_enc_isab < 1:
            raise ValueError(f"Number of ISABs in encoder `n_enc_isab` ({n_enc_isab}) must be greater than 0")

        if n_dec_sab < 0:
            raise ValueError(f"Number of SABs in decoder `n_dec_sab` ({n_dec_sab}) cannot be negative")

        if isinstance(n_induce, int):
            n_induce = [n_induce] * n_enc_isab
        elif len(n_induce) != n_enc_isab:
            raise ValueError(
                f"Number of inducing points `n_induce` ({n_induce}) must be an integer or a list of length {n_enc_isab}")

        self.linear_in = nn.Linear(input_embedding_size, hidden_size)
        self.enc = nn.ModuleList([ISAB(hidden_size, hidden_size, n_heads, n_induce[i], layer_norm) for i in range(n_enc_isab)])
        self.pma = PMA(hidden_size, n_heads, n_seeds, layer_norm)
        self.dec = nn.ModuleList([SAB(hidden_size, hidden_size, n_heads, layer_norm) for _ in range(n_dec_sab)])
        self.linear_out = nn.Linear(hidden_size, output_embedding_size)
        self.positional_encoding_out = PositionalEncoding()

        self.add_positional_encoding = add_positional_encoding

    def forward(self, X: torch.Tensor) -> torch.Tensor:
        # X: (B, M, D, E)
        # With batch size B, M points, D dimensions, and each value being encoded with E dimensions
        B, M, D, _ = X.size()

        # X (B, M, D, E) -> (B, M, D, E')
        X = self.linear_in(X)
        E_prime = X.size(-1)

        for i, isab in enumerate(self.enc):
            if i % 2 == 0:
                # Alternate to self attention between points
                # X (B, M, D, E') -> (B * D, M, E')
                X = X.transpose(1, 2).reshape(B * D, M, E_prime)
                X = isab(X)

                # X (B * D, M, E') -> (B, M, D, E')
                X = X.view(B, D, M, E_prime).transpose(1, 2)
            else:
                # Self attention between dimensions
                # X (B, M, D, E') -> (B * M, D, E')
                X = X.reshape(B * M, D, E_prime)
                X = isab(X)

                # X (B * M, D, E') -> (B, M, D, E')
                X = X.view(B, M, D, E_prime)

        # Pooling through PMA
        # X (B, M, D, E') -> (B * D, M, E')
        X = X.transpose(1, 2).reshape(B * D, M, E_prime)

        # X (B * D, M, E') -> (B * D, S, E')
        X = self.pma(X)
        S = X.size(1)

        # X (B * D, S, E') -> (B, S, D, E')
        X = X.view(B, D, S, E_prime).transpose(1, 2)

        for i, sab in enumerate(self.dec):
            if i % 2 == 0:
                # Alternate to self attention between points
                # X (B, S, D, E') -> (B * D, S, E')
                X = X.transpose(1, 2).reshape(B * D, S, E_prime)
                X = sab(X)

                # X (B * D, S, E') -> (B, S, D, E')
                X = X.view(B, D, S, E_prime).transpose(1, 2)
            else:
                # Self attention between dimensions
                # X (B, S, D, E') -> (B * S, D, E')
                X = X.reshape(B * S, D, E_prime)
                X = sab(X)

                # X (B * S, D, E') -> (B, S, D, E')
                X = X.view(B, S, D, E_prime)

        # X (B, S, D, E') -> (B, S, D, E)
        X = self.linear_out(X)

        if self.add_positional_encoding:
            E = X.size(-1)

            # Positional encoding
            # X (B, S, D, E) -> (B * S, D, E)
            X = X.view(B * S, D, E)
            X = X + self.positional_encoding_out(X)

            # X (B * S, D, E) -> (B, S, D, E)
            X = X.view(B, S, D, E)

        return X

In [6]:
s = AlternatingSetTransformer(16, 32, 3).to(device)

In [7]:
# X: (B, M, D, E)
X = torch.rand(128, 70, 13, 16).to(device)
with torch.no_grad():
    print(s(X).shape)

torch.Size([128, 3, 13, 32])


In [8]:
assert False

AssertionError: 

In [8]:
s = AlternatingSetTransformer(11, 5, 3, n_enc_isab=2).to(device)

In [None]:
#  Verify equivariance in the D dimension
X = torch.rand(2, 7, 13, 11).to(device)
Y = s(X, add_positional_encoding=False)

random_permutation = torch.randperm(X.size(2))
X_permuted = X[:, :, random_permutation]
Y_permuted = s(X_permuted, add_positional_encoding=False)
assert torch.allclose(Y_permuted, Y[:, :, random_permutation], atol=1e-6)

TypeError: AlternatingSetTransformer.forward() got an unexpected keyword argument 'add_positional_encoding'

In [10]:
#  Verify equivariance in the D dimension is broken when adding positional encoding
X = torch.rand(2, 7, 13, 11).to(device)
Y = s(X)
random_permutation = torch.randperm(X.size(2))
X_permuted = X[:, :, random_permutation]
Y_permuted = s(X_permuted)
assert not torch.allclose(Y_permuted, Y[:, :, random_permutation])