In [1]:
import torch 
import torch.nn as nn
from timm import create_model

In [5]:
# Load a pretrained ViT-Tiny from timm and switch it to inference mode.
model = create_model("vit_tiny_patch16_224", pretrained=True)
model.eval()

# Random stand-in for a single 3-channel 224x224 image batch.
img = torch.randn(1, 3, 224, 224)

# Fixed: the original called an undefined name `v`; the model built above is `model`.
preds = model(img) # (1, 1000)

In [3]:
class L2Attention(nn.Module):
    """Multi-head self-attention that scores tokens by negative squared L2 distance.

    Instead of the usual scaled dot product, the attention logits are
    ``-||q_i - k_j||^2 * scale`` where ``scale = head_dim ** -0.5``; a softmax
    over the key axis then turns them into attention weights.

    Args:
        dim: Channel size of the input and output tokens. Must be divisible
            by ``num_heads``.
        num_heads: Number of attention heads.
        qkv_bias: Whether the fused q/k/v projection has a bias term.
        attn_drop: Dropout probability applied to the attention weights.
        proj_drop: Dropout probability applied after the output projection.
    """

    def __init__(
         self,
         dim: int,
         num_heads: int = 8,
         qkv_bias: bool = False,
         attn_drop: float = 0.,
         proj_drop: float = 0.
    ) -> None:
        super().__init__()
        assert dim % num_heads == 0, 'dim should be divisible by num_heads'
        self.num_heads = num_heads
        head_dim = dim // num_heads
        self.scale = head_dim ** -0.5

        # One linear layer produces q, k and v together (3 * dim outputs).
        self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
        self.attn_drop = nn.Dropout(attn_drop)
        self.proj = nn.Linear(dim, dim)
        self.proj_drop = nn.Dropout(proj_drop)

    def forward(
        self,
        x: torch.Tensor
    ) -> torch.Tensor:
        """Apply L2-distance self-attention.

        Args:
            x: Tensor of shape ``(B, N, C)`` with ``C == dim``.

        Returns:
            Tensor of shape ``(B, N, C)``.
        """
        B, N, C = x.shape
        # (B, N, 3*C) -> (3, B, num_heads, N, head_dim)
        qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
        q, k, v = qkv.unbind(0)   # make torchscript happy (cannot use tensor as tuple)

        # ||q_i - k_j||^2 = ||q_i||^2 - 2 * <q_i, k_j> + ||k_j||^2
        dots = q @ k.transpose(-2, -1)                                   # (B, H, N, N)
        q_sq = q.pow(2).sum(dim=-1, keepdim=True)                        # (B, H, N, 1)
        k_sq = k.pow(2).sum(dim=-1, keepdim=True).transpose(-2, -1)      # (B, H, 1, N)
        # Broadcasting expands q_sq along columns and k_sq along rows. This
        # replaces the matmul-with-ones trick, whose torch.ones(...) helper was
        # always created on the default device/dtype and therefore failed for
        # CUDA or non-float32 inputs.
        sq_dist = q_sq - 2 * dots + k_sq

        attn = (-sq_dist * self.scale).softmax(dim=-1)
        attn = self.attn_drop(attn)

        # Merge heads back: (B, H, N, head_dim) -> (B, N, C)
        x = (attn @ v).transpose(1, 2).reshape(B, N, C)
        x = self.proj(x)
        x = self.proj_drop(x)
        return x

In [4]:
# Smoke test: run a (1, 2, 16) random input through the layer and print the output shape.
att = L2Attention(dim=16)
inp = torch.randn(1, 2, 16)
print(att(inp).shape)

torch.Size([1, 2, 16])
