In [1]:
import lightning as L
import torch
import torch.nn as nn
import math
from typing import Tuple

In [2]:
# Sinusoidal positional encoding (fixed, non-learned).
class SinusoidalPositionalEncoding(nn.Module):
    """Add fixed sinusoidal position signals to a 3-D (batch, seq, feature) tensor.

    For position ``pos`` (index along dim 1) and feature index ``i`` (dim 2),
    the added value is ``sin(pos * omega[i // 2])`` for even ``i`` and
    ``cos(pos * omega[i // 2])`` for odd ``i``, where
    ``omega[k] = base ** (-2k / d_model)``.

    Args:
        d_model: Feature size the encoding is built for. Must be positive;
            odd values are supported (the last feature is a sin term).
        base: Base of the geometric frequency progression (default 10000).

    Raises:
        ValueError: If ``d_model`` is not positive.
    """

    def __init__(self, d_model: int, base: float = 10000.0):
        super().__init__()
        if d_model <= 0:
            raise ValueError(f"d_model must be positive, got {d_model}.")
        self.d_model = d_model

        # One frequency per (sin, cos) feature pair; ceil(d_model / 2) so an odd
        # d_model still has a frequency for its final (sin-only) feature.
        n_freqs = (d_model + 1) // 2
        # Exponents computed in float64 (like the Python-float formula) and then
        # stored in the default dtype, matching a torch.zeros(...)-built buffer.
        exponents = -2.0 * torch.arange(n_freqs, dtype=torch.float64) / d_model
        omega = torch.pow(base, exponents).to(torch.get_default_dtype())

        self.register_buffer("omega", omega)

    def forward(self, input: torch.Tensor) -> torch.Tensor:
        """Return ``input`` plus the positional encoding; ``input`` is not modified.

        Args:
            input: 3-D tensor; dim 1 is treated as the position axis and dim 2
                as the feature axis, whose size must not exceed ``d_model``.

        Returns:
            Tensor with the same shape, dtype and device as ``input``.

        Raises:
            ValueError: If ``input`` is not 3-D or its feature size exceeds ``d_model``.
        """
        if input.dim() != 3:
            raise ValueError(f"Wrong input dimension. Your dimension is: {input.dim()}, but should be: 3")

        seq_len, n_features = input.size(1), input.size(2)
        if n_features > self.d_model:
            raise ValueError(
                f"Input feature size {n_features} exceeds d_model={self.d_model} of the positional encoding."
            )

        # (seq_len, n_freqs) table of pos * omega[k], built on the buffer's device.
        positions = torch.arange(seq_len, device=self.omega.device, dtype=self.omega.dtype)
        angles = positions.unsqueeze(1) * self.omega.unsqueeze(0)

        encoding = torch.zeros(seq_len, n_features, device=self.omega.device, dtype=self.omega.dtype)
        # Feature 2k takes sin(angle_k), feature 2k+1 takes cos(angle_k).
        encoding[:, 0::2] = torch.sin(angles[:, : (n_features + 1) // 2])
        encoding[:, 1::2] = torch.cos(angles[:, : n_features // 2])

        # Cast to the input's dtype/device so the result keeps the input's dtype
        # (as the element-wise formulation did) and broadcasts over the batch dim.
        return input + encoding.to(device=input.device, dtype=input.dtype)

In [3]:
class MultiHeadAttention(nn.Module):
    """Multi-head scaled dot-product self-attention without masking.

    Args:
        nheads: Number of attention heads; must be positive and divide ``d_model``.
        d_model: Feature size of the input and output tokens.

    Raises:
        ValueError: If ``nheads`` is not positive or does not divide ``d_model``.
    """

    def __init__(self, nheads: int, d_model: int):
        if nheads <= 0:
            raise ValueError(f"nheads must be positive, got {nheads}.")
        if d_model % nheads != 0:
            raise ValueError(f"Can't divide {d_model} (d_model) by {nheads} (nheads).")

        super().__init__()

        self.nheads = nheads
        self.d_head = d_model // nheads

        self.w_q = nn.Linear(d_model, d_model, bias=True)
        self.w_k = nn.Linear(d_model, d_model, bias=True)
        self.w_v = nn.Linear(d_model, d_model, bias=True)
        self.output = nn.Linear(d_model, d_model, bias=True)

    def _split_heads(self, x: torch.Tensor) -> torch.Tensor:
        """Reshape (batch, seq, nheads * d_head) to (batch, nheads, seq, d_head)."""
        return x.reshape(x.size(0), x.size(1), self.nheads, self.d_head).transpose(1, 2)

    def split(self, q: torch.Tensor, k: torch.Tensor, v: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """Split query, key and value projections into per-head tensors.

        Each argument is 3-D with last dim ``nheads * d_head``; each result has
        shape (dim0, nheads, dim1, d_head).
        """
        return self._split_heads(q), self._split_heads(k), self._split_heads(v)

    def attention(self, q: torch.Tensor, k: torch.Tensor, v: torch.Tensor) -> torch.Tensor:
        """Scaled dot-product attention ``softmax(q k^T / sqrt(d_head)) v``.

        Operates on the last two dims; any leading dims (e.g. batch and head)
        are broadcast by ``@``, so all heads can be processed in one call.
        """
        score_input = (q @ k.transpose(-2, -1)) / math.sqrt(self.d_head)
        score = torch.softmax(score_input, dim=-1)
        return score @ v

    def forward(self, input: torch.Tensor) -> torch.Tensor:
        """Self-attention over a 3-D input whose last dim is ``d_model``.

        Returns a tensor of the same shape as ``input``.
        """
        batch, seq_len = input.size(0), input.size(1)

        q, k, v = self.split(self.w_q(input), self.w_k(input), self.w_v(input))

        # All heads at once: (batch, nheads, seq, d_head). Replaces the former
        # per-head Python loop + torch.stack, which produced the same tensor.
        heads = self.attention(q, k, v)

        # Merge heads back to (batch, seq, nheads * d_head) before the output projection.
        merged = heads.transpose(1, 2).reshape(batch, seq_len, -1)
        return self.output(merged)

In [None]:
class KeypointTransformer(nn.Module):
    """Project input features to ``d_model`` and add sinusoidal positional encoding.

    Work in progress: only the input projection and the positional encoding
    are constructed so far.

    Args:
        feature_len: Size of the last dim of the input (features per step).
        d_model: Model (token) size the features are projected to.
        nheads: Number of attention heads. Accepted but not used yet.
        dim_ff: Feed-forward hidden size. Accepted but not used yet.
        batch_size: Batch size. Accepted but not used yet.
    """

    def __init__(
        self,
        feature_len: int,
        d_model: int,
        nheads: int = 4,
        dim_ff: int = 512,
        batch_size: int = 32,
    ):
        super().__init__()

        # Wrapped in nn.Sequential, so parameters are named input_proj.0.weight / .bias.
        projection = nn.Linear(feature_len, d_model)
        self.input_proj = nn.Sequential(projection)

        self.pos_encode = SinusoidalPositionalEncoding(d_model=d_model)

    def forward(self, input: torch.Tensor) -> torch.Tensor:
        """Project the last dim to ``d_model``, then add the positional encoding.

        ``input`` must be 3-D (the positional encoding rejects other ranks);
        dim 1 is used as the position axis.
        """
        return self.pos_encode(self.input_proj(input))

    

In [None]:
class LitKeypointTransformer(L.LightningModule):
    """Lightning wrapper for the keypoint transformer (placeholder).

    Defines no model, training step or optimizer yet; it only initialises the
    LightningModule base class so that submodules and parameters can be
    assigned on it once it is implemented.
    """

    def __init__(self):
        # LightningModule subclasses nn.Module, whose __init__ must run before
        # any submodule/parameter is assigned on the instance; the previous
        # `pass` body skipped it, leaving the module uninitialised.
        super().__init__()