<a href="https://colab.research.google.com/github/taekyungss/computer_vision_planting_grass/blob/main/VIT_pytorch_%EA%B5%AC%ED%98%84.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

아직 미완 (work in progress — not finished yet)

In [None]:
!pip install einops

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import matplotlib.pyplot as plt
%matplotlib inline

from torch import optim
from torchvision import datasets
import torchvision.transforms as transforms
from torch.utils.data import DataLoader
import os
from torchvision import utils

from torch import nn
from torch import Tensor
from PIL import Image
from torchvision.transforms import Compose, Resize, ToTensor
from einops import rearrange, reduce, repeat
from einops.layers.torch import Rearrange, Reduce
from torchsummary import summary
import numpy as np
import time
import copy
import random
from tqdm.notebook import tqdm
import math

# Device configuration: run on the GPU when CUDA is available, otherwise on the CPU.
if torch.cuda.is_available():
  device = torch.device('cuda')
else:
  device = torch.device('cpu')
print(device)

In [None]:
path2data = '/content/data'

# makedirs(exist_ok=True) is idempotent and also creates any missing parent
# directories, unlike an exists()/mkdir() pair (which fails if the parent is
# absent and can race with another process creating the directory).
os.makedirs(path2data, exist_ok=True)

# ToTensor() here is only an initial transform: train_ds.transform and
# val_ds.transform are reassigned further down in the notebook.
train_ds = datasets.STL10(path2data, split="train", download=True, transform=transforms.ToTensor())
val_ds = datasets.STL10(path2data, split="test", download=True, transform = transforms.ToTensor())

print(len(train_ds))
print(len(val_ds))

In [None]:
# transformation
# Convert each image to a float tensor, then resize it so its shorter side is
# 224 px — the img_size PatchEmbedding uses by default.
transformation = transforms.Compose([
    transforms.ToTensor(),
    transforms.Resize(224)
])

train_ds.transform = transformation
val_ds.transform = transformation

# Shuffle only the training data. Aggregate validation metrics do not depend
# on sample order, and a fixed order keeps evaluation runs reproducible.
train_dl = DataLoader(train_ds, batch_size=32, shuffle=True)
val_dl = DataLoader(val_ds, batch_size=64, shuffle=False)

In [None]:
def show(img, y=None):
  """Draw an image tensor with matplotlib, optionally titled with labels.

  Args:
    img: tensor laid out channels-first (C, H, W); assumed to be on the CPU,
      since it is converted with ``.numpy()``.
    y: optional labels; when given they are shown in the plot title.
  """
  # matplotlib wants channels-last (H, W, C): move axis 0 to the end.
  hwc = np.transpose(img.numpy(), (1, 2, 0))
  plt.imshow(hwc)
  if y is None:
    return
  plt.title("labels: " + str(y))

# Seed both RNGs so the same sample of training images is drawn every run.
np.random.seed(10)
torch.manual_seed(0)

grid_size = 4
rnd_ind = np.random.randint(0, len(train_ds), grid_size)

# Take the image AND its label from the same train_ds entry so the title
# actually describes the picture shown (labels were previously read from
# val_ds at the same indices, i.e. from unrelated images).
samples = [train_ds[i] for i in rnd_ind]
x_grid = [img for img, _ in samples]
y_grid = [label for _, label in samples]

x_grid = utils.make_grid(x_grid, nrow=grid_size, padding=2)
plt.figure(figsize=(10, 10))
show(x_grid, y_grid)

## ViT 구현 (Vision Transformer implementation)

In [None]:
# patch embedding

class PatchEmbedding(nn.Module):
  """Cut an image into non-overlapping patches and embed each one as a token.

  A Conv2d whose kernel size and stride both equal ``patch_size`` projects
  every patch to an ``emb_size``-dim vector. The resulting spatial grid is
  flattened into a token sequence, a learnable class token is prepended, and
  a learnable position embedding is added.

  Args:
    in_channels: number of input image channels.
    patch_size: side length (pixels) of each square patch.
    emb_size: embedding width of each token.
    img_size: image side length the position table is sized for. The input
      must yield exactly (img_size // patch_size) ** 2 patches, otherwise the
      position add cannot broadcast.

  Output shape: (batch, 1 + (img_size // patch_size) ** 2, emb_size).
  """

  def __init__(self, in_channels=3, patch_size = 16, emb_size = 768, img_size = 224):
    super().__init__()
    self.patch_size = patch_size
    num_patches = (img_size // patch_size) ** 2

    # Conv feature map (b, e, h, w) -> token sequence (b, h*w, e).
    self.projection = nn.Sequential(
        nn.Conv2d(in_channels, emb_size, patch_size, stride=patch_size),
        Rearrange('b e (h) (w) -> b (h w) e'),
    )

    # One class token shared by every image, plus one position vector per
    # token (all patches + the class token).
    self.cls_token = nn.Parameter(torch.randn(1,1,emb_size))
    self.positions = nn.Parameter(torch.randn(num_patches + 1, emb_size))

  def forward(self, x:Tensor) -> Tensor:
    tokens = self.projection(x)
    # Broadcast the single class token across the batch dimension.
    cls = self.cls_token.expand(tokens.shape[0], -1, -1)
    tokens = torch.cat([cls, tokens], dim=1)
    return tokens + self.positions

In [None]:
# check

# Sanity check: feed a random batch of 16 RGB 224x224 images through
# PatchEmbedding; with the default arguments the expected shape is
# [16, 1 + 14*14, 768] = [16, 197, 768]. `patch_output` is reused as input by
# the attention / encoder checks further down, so run this cell first.
x = torch.randn(16,3,224,224).to(device)
patch_embedding = PatchEmbedding().to(device)
patch_output = patch_embedding(x)
print("[batch, 1+num of patches, emb_size] = ", patch_output.shape)

In [None]:
# multihead attention
class MultiHeadAttention(nn.Module):
  """Multi-head scaled dot-product self-attention.

  Queries, keys and values are separate linear projections of the same input.
  Each projection is split into ``num_heads`` heads of width
  ``head_dim = emb_size // num_heads``. Per head, the attention weights are
  ``softmax(q @ k^T / sqrt(head_dim))`` over the key axis, and the head output
  is the weighted sum of the values. The heads are concatenated and passed
  through a final linear projection.

  Args:
    emb_size: token embedding width; must be divisible by ``num_heads``.
    num_heads: number of attention heads.
    dropout: dropout probability applied to the attention weights.

  Raises:
    ValueError: if ``emb_size`` is not divisible by ``num_heads``.
  """

  def __init__(self, emb_size = 768, num_heads=8, dropout = 0):
    super().__init__()
    if emb_size % num_heads != 0:
      raise ValueError(
          f"emb_size ({emb_size}) must be divisible by num_heads ({num_heads})")
    self.emb_size = emb_size
    self.num_heads = num_heads
    self.head_dim = emb_size // num_heads
    self.keys = nn.Linear(emb_size, emb_size)
    self.queries = nn.Linear(emb_size, emb_size)
    self.values = nn.Linear(emb_size, emb_size)
    self.att_drop = nn.Dropout(dropout)
    self.projection = nn.Linear(emb_size, emb_size)

  def _split_heads(self, t):
    """Reshape (b, n, heads*head_dim) -> (b, heads, n, head_dim)."""
    b, n, _ = t.shape
    return t.reshape(b, n, self.num_heads, self.head_dim).transpose(1, 2)

  def forward(self, x, mask=None):
    """Apply self-attention over the tokens of ``x``.

    Args:
      x: tensor of shape (batch, tokens, emb_size).
      mask: optional boolean tensor broadcastable to
        (batch, heads, tokens, tokens). True marks query/key pairs that may
        attend; False pairs are excluded before the softmax.

    Returns:
      Tensor with the same shape as ``x``.
    """
    queries = self._split_heads(self.queries(x))
    keys = self._split_heads(self.keys(x))
    values = self._split_heads(self.values(x))

    # Scale the logits (not the probabilities) by 1/sqrt(head_dim) so the
    # softmax does not saturate as the head width grows.
    energy = torch.einsum('bhqd,bhkd->bhqk', queries, keys) / (self.head_dim ** 0.5)

    if mask is not None:
      # masked_fill is out-of-place, so the result must be reassigned.
      energy = energy.masked_fill(~mask, torch.finfo(energy.dtype).min)

    att = F.softmax(energy, dim=-1)
    att = self.att_drop(att)
    # Contract over the shared key axis k: out[q] = sum_k att[q, k] * v[k].
    out = torch.einsum('bhqk,bhkd->bhqd', att, values)
    # (b, heads, n, head_dim) -> (b, n, heads*head_dim)
    out = out.transpose(1, 2).flatten(2)
    return self.projection(out)




In [None]:
# Check the implementation: self-attention should return a tensor with the
# same [batch, tokens, emb_size] shape as patch_output.

MHA = MultiHeadAttention().to(device)
MHA_output = MHA(patch_output)
print(MHA_output.shape)

In [None]:
# Residual block

class ResidualAdd(nn.Module):
  """Wrap a module ``fn`` with a residual (skip) connection: ``fn(x) + x``.

  Args:
    fn: module whose output is expected to have the same shape as its input.
      Extra keyword arguments given to ``forward`` are passed through to it.
  """

  def __init__(self, fn):
    super().__init__()
    self.fn = fn

  def forward(self, x, **kwargs):
    # Out-of-place add: if ``fn`` returns its input tensor unchanged (e.g.
    # nn.Identity), an in-place ``+=`` would double and overwrite the
    # caller's tensor. In-place ops on autograd-tracked outputs can also
    # break backward.
    return self.fn(x, **kwargs) + x

In [None]:
class FeedForwardBlock(nn.Sequential):
  """Position-wise MLP: Linear -> GELU -> Dropout -> Linear.

  Args:
    emb_size: input and output width.
    expansion: hidden width is ``expansion * emb_size``.
    drop_p: dropout probability after the activation.
  """

  def __init__(self, emb_size, expansion = 4, drop_p = 0):
    hidden = expansion * emb_size
    layers = [
        nn.Linear(emb_size, hidden),
        nn.GELU(),
        nn.Dropout(drop_p),
        nn.Linear(hidden, emb_size),
    ]
    super().__init__(*layers)

In [None]:
# check: FeedForwardBlock keeps the last dimension at emb_size (128 here),
# so the printed shape should be torch.Size([16, 1, 128]).
x = torch.randn(16,1,128).to(device)
model = FeedForwardBlock(128).to(device)
output = model(x)
print(output.shape)

In [None]:
# TransformerEncoderBlock

class TransformerEncoderBlock(nn.Sequential):
  """One pre-norm Transformer encoder layer.

  Applies two residual sub-blocks in sequence:
    1. x + Dropout(MultiHeadAttention(LayerNorm(x)))
    2. x + Dropout(FeedForwardBlock(LayerNorm(x)))

  Args:
    emb_size: token embedding width.
    drop_p: dropout probability at the end of each sub-block, before the
      residual add.
    forward_expansion: hidden-width multiplier of the feed-forward block.
    forward_drop_p: dropout probability inside the feed-forward block.
    **kwargs: forwarded to MultiHeadAttention (e.g. num_heads, dropout).
  """

  def __init__(self, emb_size = 768, drop_p=0., forward_expansion = 4, forward_drop_p = 0., **kwargs):
    attention = nn.Sequential(
        nn.LayerNorm(emb_size),
        MultiHeadAttention(emb_size, **kwargs),
        nn.Dropout(drop_p),
    )
    feed_forward = nn.Sequential(
        nn.LayerNorm(emb_size),
        FeedForwardBlock(emb_size, expansion=forward_expansion, drop_p=forward_drop_p),
        nn.Dropout(drop_p),
    )
    super().__init__(ResidualAdd(attention), ResidualAdd(feed_forward))

In [None]:
# check: a single encoder block should keep patch_output's
# [batch, 1 + num_patches, emb_size] shape unchanged.

model = TransformerEncoderBlock().to(device)
# NOTE(review): the trailing .to(device) is redundant — model and input are already on `device`.
output = model(patch_output).to(device)
print(output.shape)

In [None]:
class TransformerEncoder(nn.Sequential):
  """Stack ``depth`` independent TransformerEncoderBlock layers in sequence.

  Every keyword argument in ``**kwargs`` is forwarded unchanged to each
  block's constructor, so all layers share the same hyper-parameters but
  have their own weights.
  """

  def __init__(self, depth =12, **kwargs):
    layers = (TransformerEncoderBlock(**kwargs) for _ in range(depth))
    super().__init__(*layers)


In [None]:
# check: the full encoder stack (not just a single block) should keep
# patch_output's [batch, 1 + num_patches, emb_size] shape unchanged.
model = TransformerEncoder().to(device)
output = model(patch_output)
print(output.shape)

In [None]:
# classificationHead
class classificationHead(nn.Sequential):
  """Map a token sequence to per-class scores.

  Mean-pools over the token axis, applies LayerNorm, then a linear layer:
  (batch, tokens, emb_size) -> (batch, n_classes). No softmax is applied.

  NOTE(review): the mean includes the class token rather than using it
  alone; confirm this pooling choice is intended.
  """

  def __init__(self, emb_size = 768, n_classes = 10):
    pool = Reduce('b n e -> b e', reduction = "mean")
    norm = nn.LayerNorm(emb_size)
    head = nn.Linear(emb_size, n_classes)
    super().__init__(pool, norm, head)

In [None]:
# check: classificationHead averages over the token axis, so a
# (16, 1, 768) input should give torch.Size([16, 10]) class scores.
x = torch.rand(16,1,768).to(device)
model = classificationHead().to(device)
output = model(x)
print(output.shape)

<!--## VIT 코딩 -->