<a href="https://colab.research.google.com/github/tlqwkrk4471/PytorchBasic/blob/main/Implementing_ViT_in_Pytorch.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
! pip install einops

Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/
Collecting einops
  Downloading einops-0.6.1-py3-none-any.whl (42 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m42.2/42.2 kB[0m [31m2.9 MB/s[0m eta [36m0:00:00[0m
[?25hInstalling collected packages: einops
Successfully installed einops-0.6.1


# Implementing ViT in Pytorch
Reference: https://github.com/FrancescoSaverioZuppichini/ViT

In [2]:
import torch
import torch.nn.functional as F
import matplotlib.pyplot as plt

from torch import nn
from torch import Tensor
from PIL import Image
from torchvision.transforms import Compose, Resize, ToTensor
from einops import rearrange, reduce, repeat
from einops.layers.torch import Rearrange, Reduce
from torchsummary import summary

## Data

In [3]:
#img = Image.open('./meat_data/KakaoTalk_Image_2023-06-13-18-11-22_001.jpeg')

#fig = plt.figure()
#plt.imshow(img)

In [4]:
#img.size

In [5]:
#transform = Compose([Resize((4016, 4016)), ToTensor()])
#x = transform(img)
#x = x.unsqueeze(0)
#x.shape

In [45]:
from torch.utils.data import DataLoader

import torchvision.datasets as dset

# Use the GPU when PyTorch can see one, otherwise fall back to the CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# ToTensor is the only preprocessing step applied to the dataset images.
transform = Compose([ToTensor()])

# Training split: reshuffled by the loader on every pass.
train_data = dset.CIFAR10('./datasets', train=True, download=True, transform=transform)
loader_train = DataLoader(train_data, batch_size=32, shuffle=True)

# Test split: kept in a fixed order so evaluation runs are repeatable;
# shuffling brings no benefit when the data is only used for measurement.
test_data = dset.CIFAR10('./datasets', train=False, download=True, transform=transform)
loader_test = DataLoader(test_data, batch_size=32, shuffle=False)

Files already downloaded and verified
Files already downloaded and verified


In [97]:
# Take the image tensor of the first training sample and add a leading batch axis.
x = torch.unsqueeze(train_data[0][0], dim=0)
x.shape

torch.Size([1, 3, 32, 32])

## Patches Embedding

2D image → sequence of flattened 2D patches

H × W × C → N × (P² · C), where P is the patch size and N = HW / P² is the number of patches

In [99]:
# Cut the image into non-overlapping patch_size x patch_size tiles and
# flatten every tile (pixels and channels) into a single row.
patch_size = 4
patches = rearrange(
    x,
    'b c (h s1) (w s2) -> b (h w) (s1 s2 c)',
    s1=patch_size,
    s2=patch_size,
)
patches.shape

torch.Size([1, 64, 48])

In [100]:
class PatchEmbedding(nn.Module):
  """Split an image batch into flattened patches and linearly project each patch.

  Args:
    in_channels: Number of channels of the input images.
    patch_size: Side length, in pixels, of each square patch.
    emb_size: Size of the embedding produced for every patch.
  """

  def __init__(self, in_channels: int = 3, patch_size: int = 4, emb_size: int = 48):
    # nn.Module has to be initialised before attributes are assigned on the subclass.
    super().__init__()
    self.patch_size = patch_size
    self.projection = nn.Sequential(
      # (b, c, h*s1, w*s2) -> (b, h*w, s1*s2*c): one flattened row per patch.
      Rearrange('b c (h s1) (w s2) -> b (h w) (s1 s2 c)', s1=patch_size, s2=patch_size),
      # Project every flattened patch to emb_size features.
      nn.Linear(patch_size * patch_size * in_channels, emb_size),
    )

  def forward(self, x: Tensor) -> Tensor:
    """Return the patch embeddings, shaped (batch, num_patches, emb_size)."""
    return self.projection(x)

In [101]:
# Sanity check: run a default-configured PatchEmbedding on x and display the output shape.
PatchEmbedding()(x).shape

torch.Size([1, 64, 48])

In [102]:
class PatchEmbedding(nn.Module):
  """Embed image patches with a single strided convolution.

  kernel_size and stride are both patch_size, so every output position of the
  convolution is computed from exactly one non-overlapping patch.

  Args:
    in_channels: Number of channels of the input images.
    patch_size: Side length, in pixels, of each square patch.
    emb_size: Size of the embedding produced for every patch.
  """

  def __init__(self, in_channels: int = 3, patch_size: int = 4, emb_size: int = 48):
    # nn.Module has to be initialised before attributes are assigned on the subclass.
    super().__init__()
    self.patch_size = patch_size
    self.projection = nn.Sequential(
        # The convolution's output channels are the embedding: (b, emb_size, h, w).
        nn.Conv2d(in_channels, emb_size, kernel_size=patch_size, stride=patch_size),
        # Flatten the h x w patch grid into one sequence axis: (b, h*w, emb_size).
        Rearrange('b e h w -> b (h w) e'),
    )

  def forward(self, x: Tensor) -> Tensor:
    """Return the patch embeddings, shaped (batch, num_patches, emb_size)."""
    return self.projection(x)

# Sanity check: run a default-configured PatchEmbedding on x and display the output shape.
PatchEmbedding()(x).shape
# With the default arguments: torch.Size([1, 3, 32, 32]) -> torch.Size([1, 64, 48])

torch.Size([1, 64, 48])

## CLS Token
shape = (1,1,emb_size)

In [103]:
class PatchEmbedding(nn.Module):
  """Convolutional patch embedding with a learnable class token prepended.

  Args:
    in_channels: Number of channels of the input images.
    patch_size: Side length, in pixels, of each square patch.
    emb_size: Size of the embedding produced for every patch.
  """

  def __init__(self, in_channels: int = 3, patch_size: int = 4, emb_size: int = 48):
    # nn.Module has to be initialised before attributes are assigned on the subclass.
    super().__init__()
    self.patch_size = patch_size
    self.projection = nn.Sequential(
        # One output position per non-overlapping patch: (b, emb_size, h, w).
        nn.Conv2d(in_channels, emb_size, kernel_size=patch_size, stride=patch_size),
        # Flatten the patch grid into a sequence: (b, h*w, emb_size).
        Rearrange('b e h w -> b (h w) e'),
    )

    # A single learnable token, stored once and copied per sample in forward().
    self.cls_token = nn.Parameter(torch.randn(1, 1, emb_size))

  def forward(self, x: Tensor) -> Tensor:
    """Return (batch, num_patches + 1, emb_size) with the class token at index 0."""
    batch_size, _, _, _ = x.shape
    x = self.projection(x)
    # Replicate the class token along the batch axis so it can be concatenated.
    cls_tokens = repeat(self.cls_token, '() n e -> b n e', b=batch_size)

    return torch.cat([cls_tokens, x], dim=1)

# Sanity check: the sequence axis should be one longer than before because of the prepended cls token.
PatchEmbedding()(x).shape

torch.Size([1, 65, 48])

## Position Embedding
shape = (n+1, emb_size)

batch size는 broadcasting으로 같은 Position Embedding을 적용하기 때문에 명시하지 않음

Position Embedding은 trainable하기 때문에 Parameter로 설정

In [104]:
class PatchEmbedding(nn.Module):
  """Convolutional patch embedding with a class token and learnable position embeddings.

  Args:
    in_channels: Number of channels of the input images.
    patch_size: Side length, in pixels, of each square patch.
    emb_size: Size of the embedding produced for every patch.
    img_size: Side length, in pixels, of the square input images; it fixes
      how many position embeddings are allocated.
  """

  def __init__(self, in_channels: int = 3, patch_size: int = 4, emb_size: int = 48, img_size: int = 32):
    # nn.Module has to be initialised before attributes are assigned on the subclass.
    super().__init__()
    self.patch_size = patch_size
    self.projection = nn.Sequential(
      # One output position per non-overlapping patch: (b, emb_size, h, w).
      nn.Conv2d(in_channels, emb_size, kernel_size=patch_size, stride=patch_size),
      # Flatten the patch grid into a sequence: (b, h*w, emb_size).
      Rearrange('b e (h) (w) -> b (h w) e'),
    )
    # A single learnable token, stored once and copied per sample in forward().
    self.cls_token = nn.Parameter(torch.randn(1, 1, emb_size))
    # One learnable row per patch plus one for the class token; there is no batch
    # axis because the same embeddings are broadcast over every sample.
    num_patches = (img_size // patch_size) ** 2
    self.positions = nn.Parameter(torch.randn(num_patches + 1, emb_size))

  def forward(self, x: Tensor) -> Tensor:
    """Return (batch, num_patches + 1, emb_size) with position embeddings added."""
    batch_size, _, _, _ = x.shape
    x = self.projection(x)
    # Replicate the class token along the batch axis so it can be concatenated.
    cls_tokens = repeat(self.cls_token, '() n e -> b n e', b=batch_size)

    x = torch.cat([cls_tokens, x], dim=1)

    # Out-of-place add: avoids mutating an autograd-tracked tensor in place.
    return x + self.positions

# Sanity check: adding the position embeddings should leave the output shape unchanged.
PatchEmbedding()(x).shape

torch.Size([1, 65, 48])

# Transformer Encoder
## Multi-Head Attention

In [None]:
class MultiHeadAttention(nn.Module):
