<a href="https://colab.research.google.com/github/siddharthmishra11/ML_sprinklr_code/blob/master/VitPose.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
!pip install einops

Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/
Collecting einops
  Downloading einops-0.4.1-py3-none-any.whl (28 kB)
Installing collected packages: einops
Successfully installed einops-0.4.1


In [None]:
import torch
import torch.nn.functional as F
import matplotlib.pyplot as plt

from torch import nn
from torch import Tensor
from PIL import Image
from torchvision.transforms import Compose, Resize, ToTensor
from einops import rearrange, reduce, repeat
from einops.layers.torch import Rearrange, Reduce
from torchsummary import summary

In [None]:
class PatchEmbedding(nn.Module):
    """Cut an image batch into non-overlapping patches and embed each patch.

    The patch projection is a convolution whose kernel size and stride both
    equal ``patch_size``, so every patch is seen exactly once.

    Args:
        in_channels: Number of channels of the input images.
        patch_size: Side length (in pixels) of each square patch.
        emb_size: Width of the embedding produced for each patch.
    """

    def __init__(self, in_channels: int = 3, patch_size: int = 16, emb_size: int = 768):
        # nn.Module must be initialised before attributes are set on it;
        # assigning first only happens to work for plain Python values.
        super().__init__()
        self.patch_size = patch_size
        self.projection = nn.Sequential(
            # using a conv layer instead of a linear one -> performance gains
            nn.Conv2d(in_channels, emb_size, kernel_size=patch_size, stride=patch_size),
            # flatten the patch grid into a token sequence: (b, e, h, w) -> (b, h*w, e)
            Rearrange('b e (h) (w) -> b (h w) e'),
        )

    def forward(self, x: Tensor) -> Tensor:
        """Return patch embeddings of shape ``(batch, num_patches, emb_size)``.

        ``x`` is an image batch laid out as ``(batch, in_channels, height, width)``.
        """
        return self.projection(x)

In [None]:
class MultiHeadAttention(nn.Module):
    """Multi-head self-attention with a fused query/key/value projection.

    Args:
        emb_size: Width of the token embeddings; must be divisible by
            ``num_heads``.
        num_heads: Number of attention heads.
        dropout: Dropout probability applied to the attention weights.

    Raises:
        ValueError: If ``emb_size`` is not divisible by ``num_heads``.
    """

    def __init__(self, emb_size: int = 768, num_heads: int = 8, dropout: float = 0):
        super().__init__()
        if emb_size % num_heads != 0:
            raise ValueError(
                f"emb_size ({emb_size}) must be divisible by num_heads ({num_heads})"
            )
        self.emb_size = emb_size
        self.num_heads = num_heads
        # fuse the queries, keys and values in one matrix
        self.qkv = nn.Linear(emb_size, emb_size * 3)
        self.att_drop = nn.Dropout(dropout)
        self.projection = nn.Linear(emb_size, emb_size)

    def forward(self, x: Tensor, mask: Tensor = None) -> Tensor:
        """Apply self-attention to a token sequence.

        Args:
            x: Tokens of shape ``(batch, seq_len, emb_size)``.
            mask: Optional boolean tensor; positions where it is ``False``
                are excluded from attention. It is inverted with ``~`` and
                must broadcast against ``(batch, num_heads, seq_len, seq_len)``.

        Returns:
            Tensor of shape ``(batch, seq_len, emb_size)``.
        """
        # split keys, queries and values in num_heads
        qkv = rearrange(self.qkv(x), "b n (h d qkv) -> (qkv) b h n d", h=self.num_heads, qkv=3)
        queries, keys, values = qkv[0], qkv[1], qkv[2]
        # sum up over the last axis
        energy = torch.einsum('bhqd, bhkd -> bhqk', queries, keys)  # batch, num_heads, query_len, key_len
        # Scale the logits *before* the softmax, by sqrt of the per-head width.
        # Dividing after the softmax would leave rows that no longer sum to 1.
        energy = energy / (queries.shape[-1] ** 0.5)
        if mask is not None:
            # masked_fill is out-of-place, so the result has to be kept.
            fill_value = torch.finfo(energy.dtype).min
            energy = energy.masked_fill(~mask, fill_value)

        att = F.softmax(energy, dim=-1)
        att = self.att_drop(att)
        # weighted sum of the values over the key axis
        out = torch.einsum('bhal, bhlv -> bhav', att, values)
        # merge the heads back into a single embedding axis
        out = rearrange(out, "b h n d -> b n (h d)")
        return self.projection(out)
    

In [None]:
class ResidualAdd(nn.Module):
    """Wrap a callable ``fn`` with a skip connection: ``fn(x) + x``.

    Args:
        fn: Module (or callable) whose output has the same shape as its input.
    """

    def __init__(self, fn):
        super().__init__()
        self.fn = fn

    def forward(self, x, **kwargs):
        """Return ``fn(x, **kwargs) + x``.

        The sum is computed out-of-place. An in-place ``+=`` would overwrite
        the caller's tensor whenever ``fn`` returns its input unchanged, and
        raises for leaf tensors that require grad.
        """
        return self.fn(x, **kwargs) + x

In [None]:
class FeedForwardBlock(nn.Sequential):
    """Position-wise MLP: expand, GELU, dropout, then project back.

    Args:
        emb_size: Width of the input and output features.
        expansion: Multiplier giving the hidden width ``expansion * emb_size``.
        drop_p: Dropout probability applied after the activation.
    """

    def __init__(self, emb_size: int, expansion: int = 4, drop_p: float = 0.):
        hidden_size = expansion * emb_size
        stages = [
            nn.Linear(emb_size, hidden_size),
            nn.GELU(),
            nn.Dropout(drop_p),
            nn.Linear(hidden_size, emb_size),
        ]
        super().__init__(*stages)

In [None]:
class TransformerEncoderBlock(nn.Sequential):
    """One pre-norm encoder layer: attention and MLP, each with a skip connection.

    Args:
        emb_size: Width of the token embeddings.
        drop_p: Dropout probability applied at the end of each branch.
        forward_expansion: Hidden-width multiplier of the feed-forward branch.
        forward_drop_p: Dropout probability inside the feed-forward branch.
        **kwargs: Forwarded to ``MultiHeadAttention`` (e.g. ``num_heads``).
    """

    def __init__(self,
                 emb_size: int = 768,
                 drop_p: float = 0.,
                 forward_expansion: int = 4,
                 forward_drop_p: float = 0.,
                 ** kwargs):
        # LayerNorm -> self-attention -> dropout
        attention_branch = nn.Sequential(
            nn.LayerNorm(emb_size),
            MultiHeadAttention(emb_size, **kwargs),
            nn.Dropout(drop_p),
        )
        # LayerNorm -> feed-forward MLP -> dropout
        feed_forward_branch = nn.Sequential(
            nn.LayerNorm(emb_size),
            FeedForwardBlock(
                emb_size, expansion=forward_expansion, drop_p=forward_drop_p),
            nn.Dropout(drop_p),
        )
        super().__init__(
            ResidualAdd(attention_branch),
            ResidualAdd(feed_forward_branch),
        )

In [None]:
class deconv_layer(nn.Module):
    """Upsampling stage: 2x transposed convolution, LayerNorm, then ReLU.

    Args:
        in_channels: Channels of the incoming feature map.
        out_channels: Channels of the produced feature map.
        h: Height of the feature map *after* the 2x upsampling.
        w: Width of the feature map *after* the 2x upsampling.

    ``h`` and ``w`` parameterise the LayerNorm, which normalises over the
    last two (spatial) axes, so they must match the upsampled size.
    """

    def __init__(self, in_channels: int = 196, out_channels: int = 196, h: int = 28, w: int = 28):
        super().__init__()
        self.deconv = nn.Sequential(
            # kernel 2 / stride 2 doubles both spatial dimensions
            nn.ConvTranspose2d(in_channels, out_channels, kernel_size=(2, 2), stride=(2, 2)),
            nn.LayerNorm((h, w)),
            nn.ReLU(),
        )

    def forward(self, x: Tensor) -> Tensor:
        """Map ``(b, in_channels, h/2, w/2)`` to ``(b, out_channels, h, w)``."""
        # The stack is stored as ``self.deconv``; there is no
        # ``self.deconv_layer`` attribute on this module.
        return self.deconv(x)

In [None]:
class convert1dTo2d(nn.Module):
    """Fold a patch-token sequence back into a square 2D feature map.

    Args:
        patch_size: Side length of the patches the image was split into.
        img_size: Side length of the (square) input image.

    The grid has ``img_size // patch_size`` patches along each side.
    """

    def __init__(self, patch_size: int = 16, img_size: int = 224):
        super().__init__()
        self.patch_size = patch_size
        self.img_size = img_size

    def forward(self, x: Tensor) -> Tensor:
        """Reshape ``(b, grid * grid, n)`` tokens into a ``(b, n, grid, grid)`` map."""
        # Integer division: axis lengths handed to rearrange must be ints,
        # whereas ``/`` would yield a float such as 14.0.
        grid = self.img_size // self.patch_size
        return rearrange(x, "b (h w) n -> b n h w", h=grid, w=grid)

In [None]:
class classic_decoder(nn.Sequential):
    """Decoder head: tokens -> 2D map -> two 2x deconv stages -> 1x1 conv.

    Args:
        in_channels: Channels of the token embeddings entering the decoder.
        out_channels: Channels produced by each deconvolution stage.
        nk: Number of output channels of the final 1x1 convolution.
        h: Height of the patch grid before upsampling.
        w: Width of the patch grid before upsampling.

    Note: ``convert1dTo2d`` is built with its default arguments, so the token
    sequence is folded using that class's default image/patch sizes.
    """

    def __init__(self, in_channels: int = 768, out_channels: int = 768, nk: int = 17, h: int = 14, w: int = 14):
        super().__init__(
            convert1dTo2d(),
            deconv_layer(in_channels, out_channels, 2 * h, 2 * w),
            # From here on the feature map has ``out_channels`` channels, so the
            # later stages must consume ``out_channels`` (identical to the old
            # wiring whenever in_channels == out_channels).
            deconv_layer(out_channels, out_channels, 4 * h, 4 * w),
            nn.Conv2d(out_channels, nk, kernel_size=(1, 1), stride=(1, 1)),
        )

In [None]:
class simple_decoder(nn.Sequential):
    """Lightweight head: bilinear upsampling, ReLU, then a 3x3 convolution.

    Args:
        scale_factor: Spatial upsampling factor.
        in_channels: Channels of the incoming feature map.
        out_channels: Channels of the produced feature map.
    """

    def __init__(self, scale_factor: int = 2, in_channels: int = 17, out_channels: int = 17):
        super().__init__(
            # The upsampling layer is a module in ``torch.nn`` (not in
            # ``torch.nn.functional``), and ``scale_factor`` has to be passed
            # by keyword because the first positional argument is ``size``.
            # bilinear + align_corners=True matches nn.UpsamplingBilinear2d.
            nn.Upsample(scale_factor=scale_factor, mode="bilinear", align_corners=True),
            nn.ReLU(),
            # 3x3 conv with padding 1 keeps the spatial size unchanged
            nn.Conv2d(in_channels, out_channels, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)),
        )


In [None]:
class ViTPose(nn.Sequential):
    """ViT backbone followed by the classic and simple decoder heads.

    Pipeline: patch embedding -> ``depth`` transformer encoder blocks ->
    ``classic_decoder`` -> ``simple_decoder``.

    Args:
        in_channels: Channels of the input images.
        patch_size: Side length of each square patch.
        emb_size: Width of the token embeddings.
        img_size: Kept for interface compatibility; not forwarded to any
            sub-module here.
        depth: Number of stacked ``TransformerEncoderBlock`` layers.
        n_classes: Kept for interface compatibility; unused.
        nk: Number of output channels of the decoder heads.
        h: Height of the patch grid handed to ``classic_decoder``.
        w: Width of the patch grid handed to ``classic_decoder``.
        s: Upsampling factor of ``simple_decoder``.
    """

    def __init__(self,
                 in_channels: int = 3,
                 patch_size: int = 16,
                 emb_size: int = 768,
                 img_size: int = 224,
                 depth: int = 12,
                 n_classes: int = 1000,
                 nk: int = 17,
                 h: int = 14,
                 w: int = 14,
                 s: int = 2
                 ):
        # ``depth`` is the number of encoder blocks to stack, not an argument
        # of a single block, so build one block per layer.
        encoder = nn.Sequential(
            *[TransformerEncoderBlock(emb_size=emb_size) for _ in range(depth)]
        )
        super().__init__(
            # PatchEmbedding takes (in_channels, patch_size, emb_size) only.
            PatchEmbedding(in_channels, patch_size, emb_size),
            encoder,
            classic_decoder(emb_size, emb_size, nk, h, w),
            simple_decoder(s, nk, nk),
        )