In [102]:
import torch

In [105]:
class DeepSetDemo(torch.nn.Module):
    """Minimal Deep Sets model: per-element encoder, symmetric pooling, readout.

    Every set element is encoded independently by ``fc1`` -> ReLU -> ``fc2``.
    The encodings are then reduced over ``pooling_dim`` with an order-independent
    operation (sum, mean or max). Finally ``fc_out`` maps the pooled vector to
    ``n_out`` outputs.

    Args:
        n_in: Number of input features per set element (last dim of the input).
        n_out: Number of output features.
        n_hidden_channels: Width of the per-element encoding.
        pooling: One of ``'sum'``, ``'mean'`` or ``'max'``.
        pooling_dim: Dimension of the encoded tensor that indexes set elements.
            The default ``-2`` matches inputs shaped ``(..., set_size, n_in)``,
            since the Linear layers act on the last dimension.

    Raises:
        ValueError: If ``pooling`` is not one of the supported methods.
    """

    def __init__(self, n_in: int, n_out: int = 1, n_hidden_channels: int = 64, pooling: str = 'sum', pooling_dim: int = -2):
        super().__init__()
        self.fc1 = torch.nn.Linear(n_in, n_hidden_channels)
        self.fc2 = torch.nn.Linear(n_hidden_channels, n_hidden_channels)
        self.fc_out = torch.nn.Linear(n_hidden_channels, n_out)

        if pooling == 'sum':
            self.pooling = torch.sum
        elif pooling == 'mean':
            self.pooling = torch.mean
        elif pooling == 'max':
            # torch.max(x, dim=...) returns a (values, indices) namedtuple that
            # fc_out cannot consume; torch.amax returns just the values tensor.
            self.pooling = torch.amax
        else:
            raise ValueError(f"Pooling method {pooling} not supported")

        self.pooling_dim = pooling_dim

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Encode each element, pool over ``pooling_dim``, and read out.

        Args:
            x: Tensor whose last dimension has size ``n_in``; ``pooling_dim``
                is assumed to index set elements (e.g. ``(batch, set_size, n_in)``
                with the default ``-2``).

        Returns:
            Tensor with ``pooling_dim`` reduced away and last dimension ``n_out``
            (``(batch, n_out)`` for a ``(batch, set_size, n_in)`` input).
        """
        x = torch.relu(self.fc1(x))
        x = self.fc2(x)
        x = self.pooling(x, dim=self.pooling_dim)
        return self.fc_out(x)


In [109]:
torch.manual_seed(0)  # make the random weights and inputs below reproducible
model = DeepSetDemo(n_in = 4, n_out = 1, n_hidden_channels=64, pooling='sum', pooling_dim=-2)

# 20 sets of 70 elements with 4 features each: shape (batch, set_size, n_in)
x = torch.randn(20, 70, 4)

y = model(x)
print(y.shape)

torch.Size([20, 1])


In [110]:
class DeepSetTransformerPooling(torch.nn.Module):
    """Deep Sets model whose pooling step is attention from one learned query.

    Set elements are embedded by ``fc1`` + ReLU. A single trainable query
    vector then attends over all embedded elements via
    ``torch.nn.MultiheadAttention``. No positional information is added to the
    keys/values, so the pooled vector does not depend on element order.
    ``fc_out`` maps the pooled vector to ``n_out`` outputs.

    Args:
        n_in: Number of input features per set element.
        n_out: Number of output features.
        n_hidden_channels: Embedding width; ``MultiheadAttention`` requires it
            to be divisible by ``num_heads``.
        num_heads: Number of attention heads.
    """

    def __init__(self, n_in: int, n_out: int = 1, n_hidden_channels: int = 64, num_heads=8):
        super().__init__()
        width = n_hidden_channels
        self.fc1 = torch.nn.Linear(n_in, width)

        # Pooling: one trainable query token that attends over the whole set.
        self.query = torch.nn.Parameter(torch.randn(1, 1, width))
        self.pooling = torch.nn.MultiheadAttention(width, num_heads=num_heads, batch_first=True)

        self.fc_out = torch.nn.Linear(width, n_out)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Pool a batch of sets into one output vector per set.

        Args:
            x: Tensor of shape ``(batch, set_size, n_in)`` (attention is built
                with ``batch_first=True``).

        Returns:
            Tensor of shape ``(batch, n_out)``.
        """
        embedded = torch.relu(self.fc1(x))
        batch_size = embedded.shape[0]
        # Share the single learned query across the batch: (batch, 1, width).
        queries = self.query.expand(batch_size, -1, -1)
        pooled, _attn_weights = self.pooling(queries, embedded, embedded)
        # Drop the length-1 query axis: (batch, 1, n_out) -> (batch, n_out).
        return self.fc_out(pooled).squeeze(1)


In [111]:
# Two outputs per set; the 64 hidden channels are split across 8 attention heads.
model = DeepSetTransformerPooling(n_in=4, n_out=2, num_heads=8, n_hidden_channels=64)

In [113]:
# Reuses the (20, 70, 4) batch `x` created in the DeepSetDemo cell above.
y = model(x)
print(y.size())


torch.Size([20, 2])
