### Trainable Fourier Positional Encodings (PEs)
Paper: Learnable Fourier Features for Multi-Dimensional Spatial Positional Encoding (https://arxiv.org/pdf/2106.02795.pdf)  
Original Implementation: willGuimont (https://github.com/willGuimont/learnable_fourier_positional_encoding)

In [9]:
import numpy as np
import torch
import torch.nn as nn

In [10]:
class FourierPE(nn.Module):
    """Learnable Fourier-feature positional encoding for multi-dimensional positions.

    Each of the G positional groups holds an M-dimensional position. Every group
    is projected by a learnable matrix ``Wr`` (M -> F/2), turned into Fourier
    features ``[cos(x Wr), sin(x Wr)] / sqrt(F)``, passed through a shared
    two-layer GELU MLP (F -> H -> D/G), and the G per-group outputs are
    concatenated into a single D-dimensional encoding.

    Args:
        G: number of positional groups.
        M: dimensionality of each positional value.
        F: Fourier feature dimension; must be even (half cosines, half sines).
        H: hidden-layer dimension of the MLP.
        D: positional-encoding dimension; must be divisible by G.
        gamma: scale used to initialise ``Wr`` (see ``init_weights``).

    Raises:
        ValueError: if ``F`` is odd, ``G`` is not positive, or ``D`` is not
            divisible by ``G``.
    """

    def __init__(self, G: int, M: int, F: int, H: int, D: int, gamma: float):
        super().__init__()

        # Fail fast on hyper-parameters that would otherwise only break inside
        # forward(): an odd F leaves the cos/sin concatenation one feature short
        # of the MLP input width, and D % G != 0 makes the per-group outputs
        # impossible to reshape into a D-dimensional encoding.
        if F % 2 != 0:
            raise ValueError(f"F must be even, got F={F}")
        if G <= 0 or D % G != 0:
            raise ValueError(
                f"D must be divisible by a positive G, got D={D}, G={G}"
            )

        self.G = G
        self.M = M

        # Hyperparameters
        self.F = F
        self.H = H
        self.D = D
        self.gamma = gamma

        # Projection matrix Wr: M -> F/2 (cos and sin each contribute F/2 features).
        self.Wr = nn.Linear(self.M, self.F // 2, bias=False)

        # Per-group MLP shared across groups: F -> H -> D/G.
        self.mlp = nn.Sequential(
            nn.Linear(self.F, self.H, bias=True),
            nn.GELU(),
            nn.Linear(self.H, self.D // self.G)
        )

        self.init_weights()

    def init_weights(self):
        """Initialise ``Wr`` from a zero-mean normal with std ``gamma ** -2``."""
        # NOTE(review): the paper specifies N(0, gamma^-2); under the usual
        # N(mean, variance) convention that would be std = 1 / gamma. This code
        # uses std = gamma ** -2 — confirm which scale is intended.
        nn.init.normal_(self.Wr.weight.data, mean=0, std=self.gamma ** -2)

    def forward(self, x):
        """Produce positional encodings for ``x``.

        Args:
            x: tensor of shape ``[N, G, M]``. Any number of leading batch
                dimensions is accepted (``[..., G, M]``); they are carried
                through unchanged.

        Returns:
            Tensor of shape ``[N, D]`` (or ``[..., D]``): the G per-group
            encodings of width D/G concatenated in group order.
        """
        lead_shape = x.shape[:-2]

        # Fourier features: [..., G, F/2] -> [..., G, F], scaled by 1/sqrt(F).
        projected = self.Wr(x)
        fourier = 1 / np.sqrt(self.F) * torch.cat(
            [torch.cos(projected), torch.sin(projected)], dim=-1
        )

        # Per-group MLP: [..., G, F] -> [..., G, D/G].
        Y = self.mlp(fourier)

        # Concatenate the G group encodings: [..., G, D/G] -> [..., D].
        return Y.reshape(*lead_shape, self.D)

In [12]:
# Smoke test: run the encoder on random positions and report the output shape.
if __name__ == '__main__':
    G = 3
    M = 17

    # 97 samples, each with G groups of M-dimensional positional values.
    x = torch.randn(97, G, M)
    enc = FourierPE(G=G, M=M, F=768, H=32, D=768, gamma=10)

    # One D-dimensional encoding is expected per sample.
    pex = enc(x)
    print(pex.shape)

torch.Size([97, 768])


In [13]:
# Display the positional encodings computed by the smoke-test cell.
print(pex, end="\n")

tensor([[-0.0965,  0.0046,  0.0877,  ..., -0.1737, -0.0644,  0.0150],
        [-0.0967,  0.0047,  0.0879,  ..., -0.1740, -0.0649,  0.0148],
        [-0.0968,  0.0048,  0.0880,  ..., -0.1738, -0.0646,  0.0149],
        ...,
        [-0.0966,  0.0046,  0.0878,  ..., -0.1737, -0.0648,  0.0151],
        [-0.0968,  0.0047,  0.0878,  ..., -0.1735, -0.0650,  0.0149],
        [-0.0967,  0.0046,  0.0879,  ..., -0.1732, -0.0647,  0.0150]],
       grad_fn=<ReshapeAliasBackward0>)
