In [9]:
import torch 
from liptrf.models import ViT

In [10]:
# Smoke test: build a ViT and push one random image through it.
torch.manual_seed(0)  # reproducible weight init and random input

v = ViT(
    image_size = 256,
    patch_size = 32,
    num_classes = 1000,
    dim = 1024,
    depth = 6,
    heads = 16,
    mlp_dim = 2048,
    dropout = 0.1,
    emb_dropout = 0.1
)
# Inference mode: with dropout = 0.1 / emb_dropout = 0.1 the forward pass would
# otherwise be stochastic. (Assumes ViT is a torch nn.Module -- TODO confirm.)
v.eval()

img = torch.randn(1, 3, 256, 256)  # one 3-channel 256x256 image (NCHW assumed)

with torch.no_grad():  # inference only; no autograd graph needed
    preds = v(img)  # expected shape (1, 1000)

In [12]:
# One row of num_classes (= 1000) logits for the single input image.
assert preds.shape == (1, 1000), preds.shape
preds.shape

torch.Size([1, 1000])

In [6]:
import torch 
import torch.nn as nn

from einops import rearrange, repeat
from einops.layers.torch import Rearrange

In [7]:
class Attention(nn.Module):
    """Multi-head scaled dot-product self-attention.

    Args:
        dim: feature size of the input's last axis; also the output feature size.
        heads: number of attention heads.
        dim_head: feature size of each head, so q, k and v are each
            ``heads * dim_head`` wide.
        dropout: drop probability applied to the attention weights and after
            the output projection.
    """

    def __init__(self, dim, heads = 8, dim_head = 64, dropout = 0.):
        super().__init__()
        qkv_width = heads * dim_head

        self.heads = heads
        self.scale = dim_head ** -0.5

        # Construction order is kept (attend, dropout, to_qkv, to_out): it fixes
        # parameter order and the order in which init draws random numbers.
        self.attend = nn.Softmax(dim = -1)
        self.dropout = nn.Dropout(dropout)
        self.to_qkv = nn.Linear(dim, 3 * qkv_width, bias = False)

        if heads != 1 or dim_head != dim:
            self.to_out = nn.Sequential(
                nn.Linear(qkv_width, dim),
                nn.Dropout(dropout),
            )
        else:
            # A single head that is already `dim` wide needs no output projection.
            self.to_out = nn.Identity()

    def forward(self, x):
        """Self-attend over the token axis of a 3-D ``(batch, tokens, dim)`` input.

        Returns a tensor whose last axis has size ``dim``.
        """
        batch, tokens, _ = x.shape

        def to_heads(part):
            # (batch, tokens, heads * d) -> (batch, heads, tokens, d)
            per_head = part.shape[-1] // self.heads
            return part.reshape(batch, tokens, self.heads, per_head).transpose(1, 2)

        q, k, v = (to_heads(part) for part in self.to_qkv(x).chunk(3, dim = -1))

        scores = torch.matmul(q, k.transpose(-1, -2)) * self.scale
        weights = self.dropout(self.attend(scores))

        # (batch, heads, tokens, d) -> (batch, tokens, heads * d)
        mixed = torch.matmul(weights, v).transpose(1, 2).flatten(2)
        return self.to_out(mixed)

In [181]:
class L2Attention(nn.Module):
    """Multi-head self-attention scored by negative squared L2 distance.

    The attention logits are ``-||q_i - k_j||^2 * scale`` instead of the
    dot-product ``q_i . k_j * scale`` used by ``Attention``. The QKV projection,
    softmax, dropout and output projection mirror ``Attention``.

    Args:
        dim: feature size of the input's last axis; also the output feature size.
        heads: number of attention heads.
        dim_head: feature size of each head, so q, k and v are each
            ``heads * dim_head`` wide.
        dropout: drop probability applied to the attention weights and after
            the output projection.
    """

    def __init__(
        self,
        dim: int,
        heads: int = 8,
        dim_head: int = 64,
        dropout: float = 0.
    ) -> None:
        super().__init__()
        # `dim_head` sets the per-head width (as in `Attention`); using `dim * heads`
        # here ignored `dim_head` even though `scale` is derived from it.
        inner_dim = dim_head * heads
        project_out = not (heads == 1 and dim_head == dim)

        self.heads = heads
        self.scale = dim_head ** -0.5

        self.attend = nn.Softmax(dim = -1)
        self.dropout = nn.Dropout(dropout)

        self.to_qkv = nn.Linear(dim, inner_dim * 3, bias = False)

        self.to_out = nn.Sequential(
            nn.Linear(inner_dim, dim),
            nn.Dropout(dropout)
        ) if project_out else nn.Identity()

    def _split_heads(self, t: torch.Tensor) -> torch.Tensor:
        """Reshape ``(b, n, heads * d)`` to ``(b, heads, n, d)``."""
        b, n, width = t.shape
        return t.reshape(b, n, self.heads, width // self.heads).transpose(1, 2)

    @staticmethod
    def _merge_heads(t: torch.Tensor) -> torch.Tensor:
        """Reshape ``(b, heads, n, d)`` back to ``(b, n, heads * d)``."""
        return t.transpose(1, 2).flatten(2)

    def forward(
        self,
        x: torch.Tensor
    ) -> torch.Tensor:
        """Self-attend over the token axis of a 3-D ``(batch, tokens, dim)`` input.

        Returns a tensor whose last axis has size ``dim``.
        """
        qkv = self.to_qkv(x).chunk(3, dim = -1)
        q, k, v = map(self._split_heads, qkv)

        # ||q_i - k_j||^2 = ||q_i||^2 - 2 q_i.k_j + ||k_j||^2, assembled by
        # broadcasting (b, h, n, 1) against (b, h, 1, n). No helper tensors are
        # allocated, so device and dtype always follow q and k.
        q_sq = q.pow(2).sum(dim = -1, keepdim = True)
        k_sq = k.pow(2).sum(dim = -1, keepdim = True).transpose(-1, -2)
        dots = torch.matmul(q, k.transpose(-1, -2))
        sq_dist = q_sq - 2 * dots + k_sq

        attn = self.attend(-sq_dist * self.scale)
        attn = self.dropout(attn)

        out = torch.matmul(attn, v)
        return self.to_out(self._merge_heads(out))

In [182]:
# Smoke test: a batch of one 2-token sequence with 4 features through L2Attention.
torch.manual_seed(0)  # reproducible weights and input
att = L2Attention(4)
inp = torch.randn(1, 2, 4)  # (batch, seq_len, dim)
with torch.no_grad():  # inference only
    att_out = att(inp)
# The block should map (batch, seq_len, dim) back to the same shape.
assert att_out.shape == inp.shape, att_out.shape
att_out.shape

torch.Size([1, 2, 4])
