In [3]:
from ast import literal_eval 

from glob import glob 

from PIL import Image 
import numpy as np 

import torch 
import torch.nn as nn

from einops import rearrange, repeat
from einops.layers.torch import Rearrange

from liptrf.models.vit import ViT, L2Attention, FeedForward

In [5]:
v = ViT(
    image_size = 256,
    patch_size = 32,
    num_classes = 1000,
    dim = 1024,
    depth = 6,
    heads = 16,
    mlp_dim = 2048,
    dropout = 0.1,
    emb_dropout = 0.1, 
    attention_type='L2'
).cuda()

img = torch.randn(1, 3, 256, 256).cuda()

preds = v(img) # (1, 1000)

In [19]:
ff = FeedForward(2, 2)
ff.net[0].weight

Parameter containing:
tensor([[-0.5506, -0.2802],
        [-0.4900,  0.0121]], requires_grad=True)

In [36]:
fin = open('../imagenet-sample-images/imagenet1000_clsidx_to_labels.txt', 'r')
class_map = literal_eval(fin.read())
fin.close()

In [4]:
image_names = glob('../imagenet-sample-images/*JPEG')

In [5]:
class L2Attention(nn.Module):
    def __init__(
         self, 
         dim: int, 
         num_heads: int = 8, 
         qkv_bias: bool = False, 
         attn_drop: float = 0., 
         proj_drop: float = 0.
    ) -> None:
        super().__init__()
        assert dim % num_heads == 0, 'dim should be divisible by num_heads'
        self.num_heads = num_heads
        head_dim = dim // num_heads
        self.scale = head_dim ** -0.5

        self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
        self.attn_drop = nn.Dropout(attn_drop)
        self.proj = nn.Linear(dim, dim)
        self.proj_drop = nn.Dropout(proj_drop)
        
    def forward(
        self,
        x: torch.tensor
    ) -> torch.tensor:
        B, N, C = x.shape
        qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
        q, k, v = qkv.unbind(0)   # make torchscript happy (cannot use tensor as tuple)

        dots = q @ k.transpose(-2, -1)
        q_l2 = torch.pow(q.norm(dim=-1, p=2), 2).unsqueeze(-1)
        k_l2 = torch.pow(k.norm(dim=-1, p=2), 2).unsqueeze(-1)
        q_l2 = torch.matmul(q_l2, torch.ones(q_l2.shape).transpose(-1, -2))
        k_l2 = torch.matmul(torch.ones(k_l2.shape), k_l2.transpose(-1, -2))
        
        attn = (-1 * (q_l2 - 2 * dots + k_l2) * self.scale).softmax(dim=-1)
        attn = self.attn_drop(attn)

        x = (attn @ v).transpose(1, 2).reshape(B, N, C)
        x = self.proj(x)
        x = self.proj_drop(x)
        return x

In [43]:
class L2Attention(nn.Module):
    def __init__(self, dim, heads = 8, dim_head = 64, dropout = 0.):
        super().__init__()
        inner_dim = dim_head *  heads
        project_out = not (heads == 1 and dim_head == dim)

        self.heads = heads
        self.scale = dim_head ** -0.5

        self.attend = nn.Softmax(dim = -1)
        self.dropout = nn.Dropout(dropout)

        self.to_qv = nn.Linear(dim, inner_dim * 2, bias = False)

        self.to_out = nn.Sequential(
            nn.Linear(inner_dim, dim),
            nn.Dropout(dropout)
        ) if project_out else nn.Identity()

    def forward(self, x):
        qv = self.to_qv(x).chunk(2, dim = -1)
        q, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h = self.heads), qv)

        dots = q @ q.transpose(-2, -1)
        q_l2 = torch.pow(q.norm(dim=-1, p=2), 2).unsqueeze(-1)
        k_l2 = torch.pow(q.norm(dim=-1, p=2), 2).unsqueeze(-1)
        q_l2 = torch.matmul(q_l2, torch.ones(q_l2.shape).transpose(-1, -2))
        k_l2 = torch.matmul(torch.ones(k_l2.shape), k_l2.transpose(-1, -2))
        
        attn = (-1 * (q_l2 - 2 * dots + k_l2) * self.scale).softmax(dim=-1)
        attn = self.dropout(attn)
        
        out = torch.matmul(attn, v)
        out = rearrange(out, 'b h n d -> b n (h d)')
        return self.to_out(out)

In [46]:
att = L2Attention(dim=16)
inp = torch.randn(1, 2, 16)
print (att(inp).shape)

torch.Size([1, 2, 16])
torch.Size([1, 8, 2, 1]) torch.Size([1, 8, 2, 64])
torch.Size([1, 2, 16])


In [None]:
idx = 6
print (image_names[idx])
img = Image.open(image_names[idx])
img = np.asarray(img.resize((224, 224))).transpose(2, 0, 1) / 255.
img = torch.from_numpy(np.asarray([img])).float()
probas = model(img)
new_probas = l2_model(img)
print (torch.mean((probas - new_probas)**2))
print (class_map[torch.argmax(probas).item()])
print (class_map[torch.argmax(new_probas).item()])

../imagenet-sample-images/n04208210_shovel.JPEG
tensor(4.2448, grad_fn=<MeanBackward0>)
shovel
moving van
