In [None]:
import lightning as L
import torch
import torch.nn as nn

In [None]:
# Sinusoidal
class PositionalEncoding(nn.Module):
    """Add fixed sinusoidal positional encodings to a 2-D ``(position, feature)`` tensor.

    For position ``pos`` and feature index ``i`` the value added to the input is
    ``sin(pos * omega[i // 2])`` when ``i`` is even and ``cos(pos * omega[i // 2])``
    when ``i`` is odd, where ``omega[k] = 10000 ** (-2k / d_model)``.

    Args:
        d_model: Width the frequencies are scaled against. ``ceil(d_model / 2)``
            frequencies are stored in the ``omega`` buffer.
    """

    def __init__(self, d_model: int):
        super().__init__()

        # Even feature indices use sin and odd ones use cos. An odd d_model therefore
        # needs ceil(d_model / 2) frequencies (d_model // 2 would be one short).
        num_freqs = (d_model + 1) // 2
        # Evaluate the powers in float64 (like Python float `**`) before storing them.
        exponents = -2 * torch.arange(num_freqs, dtype=torch.float64) / d_model
        omega = (10000.0 ** exponents).to(torch.get_default_dtype())

        self.register_buffer("omega", omega)

    def forward(self, input: torch.Tensor) -> torch.Tensor:
        """Return ``input`` plus the sinusoidal positional encoding.

        Args:
            input: 2-D tensor indexed as ``(position, feature)``.

        Returns:
            A new tensor with the same shape and dtype as ``input``.

        Raises:
            ValueError: if ``input`` is not 2-D, or has more features than this
                encoding has frequencies for (``2 * len(omega)``).
        """
        if input.dim() != 2:
            raise ValueError(
                f"PositionalEncoding expects a 2-D (position, feature) tensor, "
                f"got a {input.dim()}-D tensor"
            )

        seq_len, num_features = input.shape
        num_sin = (num_features + 1) // 2  # number of even feature indices
        num_cos = num_features // 2        # number of odd feature indices
        if num_sin > self.omega.numel():
            raise ValueError(
                f"input has {num_features} features but this encoding supports "
                f"at most {2 * self.omega.numel()}"
            )

        positions = torch.arange(seq_len, dtype=self.omega.dtype, device=self.omega.device)
        # angles[pos, k] = pos * omega[k]
        angles = positions[:, None] * self.omega[None, :num_sin]

        pos_encode = torch.empty(seq_len, num_features, dtype=angles.dtype, device=angles.device)
        pos_encode[:, 0::2] = torch.sin(angles)
        pos_encode[:, 1::2] = torch.cos(angles[:, :num_cos])

        # Add at the promoted precision, then cast back so the output dtype matches input.
        return (input + pos_encode.to(input.device)).to(input.dtype)

In [None]:
class KeypointTransformer(nn.Module):
    """Project keypoint feature vectors to ``d_model`` and add positional encodings.

    Work in progress: only the input projection and the positional encoding are
    wired up so far. ``head`` and ``dim_ff`` are accepted but not used yet.

    Args:
        feature_len: Size of the last dimension of the input features.
        d_model: Model (token) width produced by the input projection.
        head: Intended number of attention heads (currently unused).
        dim_ff: Intended feed-forward hidden width (currently unused).
    """

    def __init__(self,
                 feature_len: int,
                 d_model: int,
                 head: int = 4,
                 dim_ff: int = 512,
                 ):
        super().__init__()

        # Kept as a Sequential so parameter names stay `input_proj.0.*`.
        self.input_proj = nn.Sequential(
            nn.Linear(feature_len, d_model)
        )

        self.pos_encode = PositionalEncoding(d_model=d_model)
        # TODO: build the transformer encoder layers from `head` and `dim_ff`.

    def forward(self, input):
        """Apply the input projection, then add positional encodings.

        Args:
            input: Tensor whose last dimension is ``feature_len``. PositionalEncoding
                as written in this notebook only accepts 2-D input, so this presumably
                expects ``(seq_len, feature_len)``. TODO: confirm the intended layout.

        Returns:
            The projected, position-encoded tensor with last dimension ``d_model``.
        """
        x = self.input_proj(input)
        x = self.pos_encode(x)
        return x

    

In [None]:
class LitKeypointTransformer(L.LightningModule):
    """Lightning module stub, presumably meant to wrap KeypointTransformer.

    No model, training step, or optimizer configuration is defined yet.
    """

    def __init__(self):
        # LightningModule derives from nn.Module. Skipping its __init__ leaves the
        # parameter/submodule registries uninitialised, so registering any
        # submodule or parameter later raises AttributeError.
        super().__init__()
        # TODO: construct the model and define training_step / configure_optimizers.