In [1]:
from ast import literal_eval 

from glob import glob 

from PIL import Image 
import numpy as np 

import torch 
import torch.nn as nn
from timm import create_model

In [2]:
fin = open('../imagenet-sample-images/imagenet1000_clsidx_to_labels.txt', 'r')
class_map = literal_eval(fin.read())
fin.close()

In [3]:
image_names = glob('../imagenet-sample-images/*JPEG')

In [4]:
model = create_model("vit_tiny_patch16_224", pretrained=True)
model.eval()

VisionTransformer(
  (patch_embed): PatchEmbed(
    (proj): Conv2d(3, 192, kernel_size=(16, 16), stride=(16, 16))
    (norm): Identity()
  )
  (pos_drop): Dropout(p=0.0, inplace=False)
  (blocks): Sequential(
    (0): Block(
      (norm1): LayerNorm((192,), eps=1e-06, elementwise_affine=True)
      (attn): Attention(
        (qkv): Linear(in_features=192, out_features=576, bias=True)
        (attn_drop): Dropout(p=0.0, inplace=False)
        (proj): Linear(in_features=192, out_features=192, bias=True)
        (proj_drop): Dropout(p=0.0, inplace=False)
      )
      (ls1): Identity()
      (drop_path1): Identity()
      (norm2): LayerNorm((192,), eps=1e-06, elementwise_affine=True)
      (mlp): Mlp(
        (fc1): Linear(in_features=192, out_features=768, bias=True)
        (act): GELU()
        (drop1): Dropout(p=0.0, inplace=False)
        (fc2): Linear(in_features=768, out_features=192, bias=True)
        (drop2): Dropout(p=0.0, inplace=False)
      )
      (ls2): Identity()
      (d

In [20]:
class L2Attention(nn.Module):
    def __init__(
         self, 
         dim: int, 
         num_heads: int = 8, 
         qkv_bias: bool = False, 
         attn_drop: float = 0., 
         proj_drop: float = 0.
    ) -> None:
        super().__init__()
        assert dim % num_heads == 0, 'dim should be divisible by num_heads'
        self.num_heads = num_heads
        head_dim = dim // num_heads
        self.scale = head_dim ** -0.5

        self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
        self.attn_drop = nn.Dropout(attn_drop)
        self.proj = nn.Linear(dim, dim)
        self.proj_drop = nn.Dropout(proj_drop)
        
    def forward(
        self,
        x: torch.tensor
    ) -> torch.tensor:
        B, N, C = x.shape
        qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
        q, k, v = qkv.unbind(0)   # make torchscript happy (cannot use tensor as tuple)

        dots = q @ k.transpose(-2, -1)
        q_l2 = torch.pow(q.norm(dim=-1, p=2), 2).unsqueeze(-1)
        k_l2 = torch.pow(k.norm(dim=-1, p=2), 2).unsqueeze(-1)
        q_l2 = torch.matmul(q_l2, torch.ones(q_l2.shape).transpose(-1, -2))
        k_l2 = torch.matmul(torch.ones(k_l2.shape), k_l2.transpose(-1, -2))
        
        attn = (-1 * (q_l2 - 2 * dots + k_l2) * self.scale).softmax(dim=-1)
        attn = self.attn_drop(attn)

        x = (attn @ v).transpose(1, 2).reshape(B, N, C)
        x = self.proj(x)
        x = self.proj_drop(x)
        return x

In [21]:
att = L2Attention(dim=16)
inp = torch.randn(1, 2, 16)
print (att(inp).shape)

torch.Size([1, 2, 16])


In [22]:
from timm.models.vision_transformer import VisionTransformer, DropPath, Mlp

In [38]:
class LayerScale(nn.Module):
    def __init__(self, dim, init_values=1e-5, inplace=False):
        super().__init__()
        self.inplace = inplace
        self.gamma = nn.Parameter(init_values * torch.ones(dim))

    def forward(self, x):
        return x.mul_(self.gamma) if self.inplace else x * self.gamma
    
class L2Block(nn.Module):

    def __init__(
            self, dim, num_heads, mlp_ratio=4., qkv_bias=False, drop=0., attn_drop=0., init_values=None,
            drop_path=0., act_layer=nn.GELU, norm_layer=nn.LayerNorm):
        super().__init__()
        self.norm1 = norm_layer(dim)
        self.attn = L2Attention(dim, num_heads=num_heads, qkv_bias=qkv_bias, attn_drop=attn_drop, proj_drop=drop)
        self.ls1 = LayerScale(dim, init_values=init_values) if init_values else nn.Identity()
        # NOTE: drop path for stochastic depth, we shall see if this is better than dropout here
        self.drop_path1 = DropPath(drop_path) if drop_path > 0. else nn.Identity()

        self.norm2 = norm_layer(dim)
        mlp_hidden_dim = int(dim * mlp_ratio)
        self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop)
        self.ls2 = LayerScale(dim, init_values=init_values) if init_values else nn.Identity()
        self.drop_path2 = DropPath(drop_path) if drop_path > 0. else nn.Identity()

    def forward(self, x):
        x = x + self.drop_path1(self.ls1(self.attn(self.norm1(x))))
        x = x + self.drop_path2(self.ls2(self.mlp(self.norm2(x))))
        return x

In [39]:
l2_model = create_model("vit_tiny_patch16_224", pretrained=True, block_fn=L2Block)
l2_model.eval()

VisionTransformer(
  (patch_embed): PatchEmbed(
    (proj): Conv2d(3, 192, kernel_size=(16, 16), stride=(16, 16))
    (norm): Identity()
  )
  (pos_drop): Dropout(p=0.0, inplace=False)
  (blocks): Sequential(
    (0): L2Block(
      (norm1): LayerNorm((192,), eps=1e-06, elementwise_affine=True)
      (attn): L2Attention(
        (qkv): Linear(in_features=192, out_features=576, bias=True)
        (attn_drop): Dropout(p=0.0, inplace=False)
        (proj): Linear(in_features=192, out_features=192, bias=True)
        (proj_drop): Dropout(p=0.0, inplace=False)
      )
      (ls1): Identity()
      (drop_path1): Identity()
      (norm2): LayerNorm((192,), eps=1e-06, elementwise_affine=True)
      (mlp): Mlp(
        (fc1): Linear(in_features=192, out_features=768, bias=True)
        (act): GELU()
        (drop1): Dropout(p=0.0, inplace=False)
        (fc2): Linear(in_features=768, out_features=192, bias=True)
        (drop2): Dropout(p=0.0, inplace=False)
      )
      (ls2): Identity()
    

In [47]:
idx = 6
print (image_names[idx])
img = Image.open(image_names[idx])
img = np.asarray(img.resize((224, 224))).transpose(2, 0, 1) / 255.
img = torch.from_numpy(np.asarray([img])).float()
probas = model(img)
new_probas = l2_model(img)
print (torch.mean((probas - new_probas)**2))
print (class_map[torch.argmax(probas).item()])
print (class_map[torch.argmax(new_probas).item()])

../imagenet-sample-images/n04208210_shovel.JPEG
tensor(4.2448, grad_fn=<MeanBackward0>)
shovel
moving van
