<a href="https://colab.research.google.com/github/yukitiec/Research/blob/main/VisionTransformer.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

#Import Library

In [1]:
import torch\
import torch.nn as nn
import torch.nn.functional as F

In [None]:
from google.colab import drive
drive.mount('/content/gdrive')

#Make Module

##Input Layer

In [22]:
class VitInputLayer(nn.Module):
  def __init__(self,
               in_channels:int = 1,
               emb_dim:int = 384,
               num_patch_row:int = 2,
               image_size:int = 28):
    """
    in_channels : num of channels of input images
    emb_dim : length of vector after embedded
    num_patch_row : num of patch in height axis
    image size : image size
    """
    super(VitInputLayer,self).__init__()
    self.in_channels = in_channels
    self.emb_dim = emb_dim
    self.num_patch_row = num_patch_row
    self.image_size = image_size

    #num of patch 
    self.num_patch = self.num_patch_row**2

    #size of patch
    self.patch_size = int(self.image_size//self.num_patch_row)

    #make input images into patch and embedded one with Conv2D
    self.patch_emb_layer = nn.Conv2d(
        in_channels = self.in_channels,
        out_channels = self.emb_dim,
        kernel_size = self.patch_size,
        stride = self.patch_size
    )

    #class token
    self.cls_token = nn.Parameter(
        torch.randn(1,1,emb_dim)
    )

    #positional embedding
    #prepare (batch_size+1) vectors for embedded vectors because the header is class token
    self.pos_emb = nn.Parameter(
        torch.randn(1,self.num_patch+1,emb_dim)
    )

  def forward(self,x: torch.Tensor) -> torch.Tensor:
    """
    Args:
      x : input image (B,C,H,W)
        B:Batch size, C:Channel, H:Height, W:Width


    Return:
      z_0 : input for Vit (B,N,D)
        B:Batch size, N : Num of Token D: Length of embedded vectors
    """

    #Patch embedding
    #(B,C,H,W) -> (B,C,H/P,W/P)
    z_0 = self.patch_emb_layer(x)

    #patch flatten
    #(B,C,H/P,W/P) -> (B,D，Np) (D = (P^2*C), after 2 is flatten)
    z_0 = z_0.flatten(2)

    #reshape the matrix
    #(B,D,Np) -> (B,Np,D)
    z_0 = z_0.transpose(1,2)

    #Concatenate class token at the head of embeddings
    #(B,Np,D) -> (B,N,D) N = Np + 1
    #cls token : (1,1,D) -> (B,1,D)
    z_0 = torch.cat(
        [self.cls_token.repeat(repeats = (x.size(0),1,1)), z_0], dim=1
    )

    #Positional Embedding
    #(B,N,D) -> (B,N,D)
    z_0 = z_0 + self.pos_emb

    return z_0

##Multi-Head Self-Attention

In [23]:
class MultiHeadSelfAttention(nn.Module):
  def __init__(self,
               emb_dim:int = 384,
               head:int = 3,
               dropout:float = 0):
    """
    Args:
      emb_dim : the length of embedded vector
      head : num of head
      dropout : the rate of dropout
    """

    super(MultiHeadSelfAttention, self).__init__()
    self.head = head
    self.emb_dim = emb_dim
    self.head_dim = emb_dim // head
    self.sqrt_dh = self.head_dim**0.5 #for attention weight

    #linear layer for q,k,v
    self.w_q = nn.Linear(emb_dim,emb_dim,bias=False)
    self.w_k = nn.Linear(emb_dim,emb_dim,bias=False)
    self.w_v = nn.Linear(emb_dim,emb_dim,bias=False)

    #dropout layer
    self.attn_drop = nn.Dropout(dropout)

    #linear layer for ouput of MHSA
    self.w_o = nn.Sequential(
        nn.Linear(emb_dim,emb_dim),
        nn.Dropout(dropout)
    )

  def forward(self, z:torch.Tensor) -> torch.Tensor:
    """
    Args:
      z: input for MHSA (B,N,D)
        B:Batch size, N: Num of patches, D:length of embedded vectors
    
    Return:
      out: output of MHSA (B,N,D)
    """

    batch_size, num_patch, _ = z.size()

    #embedding
    q = self.w_q(z)
    k = self.w_k(z)
    v = self.w_v(z)

    #split q,k,v for MHSA
    #(B,N,D) -> (B,N,h,D//h)
    q = q.view(batch_size,num_patch,self.head,self.head_dim)
    k = k.view(batch_size,num_patch,self.head,self.head_dim)
    v = v.view(batch_size,num_patch,self.head,self.head_dim)

    #arrange data for self-attention
    #(B,N,h,D//h) -> (B,h,N,D//h)
    q = q.transpose(1,2)
    k = k.transpose(1,2)
    v = v.transpose(1,2)

    #arragen k for attention weight 
    #(B,h,N,D//h) -> (B,h,h//D,N) 
    k_T = k.transpose(2,3)
    #inner dot
    #(B,h,N,D//h)*(B,h,h//D,N) -> (B,h,N,N)
    dots = (q@k_T)/self.sqrt_dh
    #softmax in row axis
    attn = F.softmax(dots,dim=-1)
    #Dropout
    attn = self.attn_drop(attn)

    #get new embeddings
    #(B,h,N,N)*(B,h,N,D//h) -> (B,h,N,D//h)
    out = attn@v
    #(B,h,N,D//h) -> (B,N,h,D//h)
    out = out.transpose(1,2)
    #(B,N,h,D//h) -> (B,N,D)
    out = out.reshape(batch_size,num_patch,self.emb_dim)

    #output layer
    out = self.w_o(out)

    return out

##Encoder

In [24]:
class VitEncoderBlock(nn.Module):
  def __init__(self,
                emb_dim:int = 384,
                head:int = 8,
                hidden_dim:int = 384*4,
                dropout:float = 0):
    """
    Args:
      emb_dim : length of embedded vectors
      head : num of heads in MHSA
      hidden_dim : length of the middle layer of MLP in Encoder Block, here 384*4 as in paper
      dropout : dropout rate
    """

    super(VitEncoderBlock,self).__init__()
    #first LayerNormalization
    self.ln1 = nn.LayerNorm(emb_dim)
    #MHSA
    self.msa = MultiHeadSelfAttention(
        emb_dim = emb_dim,
        head = head,
        dropout = dropout
    )

    #second LayerNormalization
    self.ln2 = nn.LayerNorm(emb_dim)

    #MLP
    self.mlp = nn.Sequential(
        nn.Linear(emb_dim, hidden_dim),
        nn.GELU(),
        nn.Dropout(dropout),
        nn.Linear(hidden_dim,emb_dim),
        nn.Dropout(dropout)
    )

  def forward(self,z:torch.Tensor) ->  torch.Tensor:
    """
    Args:
      z : input for Encoder Block (B,N,D)
    
    Return:
      out:out for Encoder Block (B,N,D)
    """
    #first half
    out = self.msa(self.ln1(z)) + z
    #second half
    out = self.mlp(self.ln2(out)) + out
    return out

###Checking Code

In [5]:
batch_size, channel, height, width = 2,3,32,32
x = torch.randn(batch_size, channel, height, width)
input_layer = VitInputLayer(num_patch_row = 2)
z_0 = input_layer(x)

#check if the shape is (2,5,384)
print("after input layer")
print(z_0.shape)

#MHSA
mhsa = MultiHeadSelfAttention()
out = mhsa(z_0)

print("after MHSA layer")
print(out.shape)

vit_enc = VitEncoderBlock()
z_1 = vit_enc(z_0)

print("after Vit Encoder Block")
print(z_1.shape)

after input layer
torch.Size([2, 5, 384])
after MHSA layer
torch.Size([2, 5, 384])
after Vit Encoder Block
torch.Size([2, 5, 384])


##Visual Transformer

In [25]:
class Vit(nn.Module):
  def __init__(self,
               in_channels:int = 1,
               num_classes:int = 10,
               emb_dim:int = 384,
               num_patch_row:int = 2,
               image_size:int = 28,
               num_blocks:int = 7,
               head:int = 8,
               hidden_dim:int = 384*4,
               dropout:float = 0.):
    """
    Args:
      in_channesl : num of channels of input image
      num_classes : num of classes
      emb_dim : length of embedded vectors
      num_patch_row : num of patch per row
      image_size : image length
      num_blocks : num of Encoder Block
      head : num of heads
      hidden_dim : length of middle layer in Encoder block
      dropout : rate of dropout
    """
    super(Vit, self).__init__()

    #Input layer
    self.input_layer = VitInputLayer(
        in_channels,
        emb_dim,
        num_patch_row,
        image_size
    )

    #Encoder
    self.encoder = nn.Sequential(*[
        VitEncoderBlock(
            emb_dim = emb_dim,
            head = head,
            hidden_dim = hidden_dim,
            dropout = dropout
        )
        for _ in range(num_blocks)
    ])

    #MLP Head
    self.mlp_head = nn.Sequential(
        nn.LayerNorm(emb_dim),
        nn.Linear(emb_dim,num_classes)
    )

  def forward(self, x:torch.Tensor) -> torch.Tensor:
    """
    Args:
      x : input for Visual Transformer (B,C,H,W)
        B:Batch size, C:Num of Channels, H:Height, W:Width
    
    Return:
      out : output of ViT (B,M)
        B:Batch size, M: num of classes
    """

    #Input layer
    #(B,C,H,W) -> (B,N,D)
    #N : (Num of tokens) + 1(Class), D : length of embedded vecotor
    out = self.input_layer(x)

    #Encoder
    #(B,N,D) -> (B,N,D)
    out = self.encoder(out)
    
    #extract class token
    #(B,N,D) -> (B,D)
    cls_token = out[:,0]
    
    #MLP
    #(B,D) -> (B,M)
    pred = self.mlp_head(cls_token)
    
    return pred


###Checking Code

In [7]:
num_classes = 10
batch_size, channel, height, width = 2,3,32,32
x = torch.randn(batch_size,channel,height,width)
print(x.shape)
vit = Vit(in_channels = channel,num_classes = num_classes)
pred = vit(x)

print(pred.shape)

torch.Size([2, 3, 32, 32])
torch.Size([2, 10])


#Data Load

In [8]:
from torch.utils.data import DataLoader
from torchvision import datasets
from torchvision.transforms import ToTensor

In [9]:
# Download training data from open datasets.
training_data = datasets.FashionMNIST(
    root="data",
    train=True,
    download=True,
    transform=ToTensor(),
)

# Download test data from open datasets.
test_data = datasets.FashionMNIST(
    root="data",
    train=False,
    download=True,
    transform=ToTensor(),
)

Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/train-images-idx3-ubyte.gz
Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/train-images-idx3-ubyte.gz to data/FashionMNIST/raw/train-images-idx3-ubyte.gz


100%|██████████| 26421880/26421880 [00:03<00:00, 8569233.42it/s] 


Extracting data/FashionMNIST/raw/train-images-idx3-ubyte.gz to data/FashionMNIST/raw

Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/train-labels-idx1-ubyte.gz
Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/train-labels-idx1-ubyte.gz to data/FashionMNIST/raw/train-labels-idx1-ubyte.gz


100%|██████████| 29515/29515 [00:00<00:00, 142621.98it/s]


Extracting data/FashionMNIST/raw/train-labels-idx1-ubyte.gz to data/FashionMNIST/raw

Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/t10k-images-idx3-ubyte.gz
Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/t10k-images-idx3-ubyte.gz to data/FashionMNIST/raw/t10k-images-idx3-ubyte.gz


100%|██████████| 4422102/4422102 [00:01<00:00, 2639738.59it/s]


Extracting data/FashionMNIST/raw/t10k-images-idx3-ubyte.gz to data/FashionMNIST/raw

Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/t10k-labels-idx1-ubyte.gz
Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/t10k-labels-idx1-ubyte.gz to data/FashionMNIST/raw/t10k-labels-idx1-ubyte.gz


100%|██████████| 5148/5148 [00:00<00:00, 19244453.65it/s]

Extracting data/FashionMNIST/raw/t10k-labels-idx1-ubyte.gz to data/FashionMNIST/raw






In [10]:
batch_size = 64

# Create data loaders.
train_dataloader = DataLoader(training_data, batch_size=batch_size)
test_dataloader = DataLoader(test_data, batch_size=batch_size)

for X, y in test_dataloader:
    print(f"Shape of X [N, C, H, W]: {X.shape}")
    print(f"Shape of y: {y.shape} {y.dtype}")
    break

Shape of X [N, C, H, W]: torch.Size([64, 1, 28, 28])
Shape of y: torch.Size([64]) torch.int64


##Setup device

In [11]:
# Get cpu, gpu or mps device for training.
device = (
    "cuda"
    if torch.cuda.is_available()
    else "mps"
    if torch.backends.mps.is_available()
    else "cpu"
)
print(f"Using {device} device")

Using cuda device


#Training Session

##Model setting

In [26]:
model = Vit().to(device)
print(model)

Vit(
  (input_layer): VitInputLayer(
    (patch_emb_layer): Conv2d(1, 384, kernel_size=(14, 14), stride=(14, 14))
  )
  (encoder): Sequential(
    (0): VitEncoderBlock(
      (ln1): LayerNorm((384,), eps=1e-05, elementwise_affine=True)
      (msa): MultiHeadSelfAttention(
        (w_q): Linear(in_features=384, out_features=384, bias=False)
        (w_k): Linear(in_features=384, out_features=384, bias=False)
        (w_v): Linear(in_features=384, out_features=384, bias=False)
        (attn_drop): Dropout(p=0.0, inplace=False)
        (w_o): Sequential(
          (0): Linear(in_features=384, out_features=384, bias=True)
          (1): Dropout(p=0.0, inplace=False)
        )
      )
      (ln2): LayerNorm((384,), eps=1e-05, elementwise_affine=True)
      (mlp): Sequential(
        (0): Linear(in_features=384, out_features=1536, bias=True)
        (1): GELU(approximate='none')
        (2): Dropout(p=0.0, inplace=False)
        (3): Linear(in_features=1536, out_features=384, bias=True)
    

##Training setup

In [27]:
loss_fn = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=0.001, betas=(0.9, 0.999), eps=1e-08, weight_decay=0)
#torch.optim.RMSprop(model.parameters(), lr=0.01, alpha=0.99, eps=1e-08, weight_decay=0, momentum=0,

In [16]:
print(model.parameters())

<generator object Module.parameters at 0x7fa4ec1ad7e0>


##Train session & Test session

In [28]:
def train(dataloader, model, loss_fn, optimizer):
    size = len(dataloader.dataset)
    model.train()
    for batch, (X, y) in enumerate(dataloader):
        X, y = X.to(device), y.to(device)

        # Compute prediction error
        pred = model(X)
        loss = loss_fn(pred, y)

        # Backpropagation
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        if batch % 100 == 0:
            loss, current = loss.item(), (batch + 1) * len(X)
            print(f"loss: {loss:>7f}  [{current:>5d}/{size:>5d}]")


def test(dataloader, model, loss_fn):
    size = len(dataloader.dataset)
    num_batches = len(dataloader)
    model.eval()
    test_loss, correct = 0, 0
    with torch.no_grad():
        for X, y in dataloader:
            X, y = X.to(device), y.to(device)
            pred = model(X)
            test_loss += loss_fn(pred, y).item()
            correct += (pred.argmax(1) == y).type(torch.float).sum().item()
    test_loss /= num_batches
    correct /= size
    print(f"Test Error: \n Accuracy: {(100*correct):>0.1f}%, Avg loss: {test_loss:>8f} \n")

In [29]:
epochs = 5
for t in range(epochs):
    print(f"Epoch {t+1}\n-------------------------------")
    train(train_dataloader, model, loss_fn, optimizer)
    test(test_dataloader, model, loss_fn)
print("Done!")

Epoch 1
-------------------------------
loss: 2.331674  [   64/60000]
loss: 0.914306  [ 6464/60000]
loss: 0.511618  [12864/60000]
loss: 0.944841  [19264/60000]
loss: 0.728349  [25664/60000]
loss: 0.587158  [32064/60000]
loss: 0.670818  [38464/60000]
loss: 0.586845  [44864/60000]
loss: 0.696419  [51264/60000]
loss: 0.628529  [57664/60000]
Test Error: 
 Accuracy: 78.2%, Avg loss: 0.570324 

Epoch 2
-------------------------------
loss: 0.443964  [   64/60000]
loss: 0.513597  [ 6464/60000]
loss: 0.464759  [12864/60000]
loss: 0.491383  [19264/60000]
loss: 0.592807  [25664/60000]
loss: 0.557748  [32064/60000]
loss: 0.442315  [38464/60000]
loss: 0.555614  [44864/60000]
loss: 0.589858  [51264/60000]
loss: 0.551738  [57664/60000]
Test Error: 
 Accuracy: 81.5%, Avg loss: 0.524524 

Epoch 3
-------------------------------
loss: 0.443368  [   64/60000]
loss: 0.410286  [ 6464/60000]
loss: 0.350519  [12864/60000]
loss: 0.546039  [19264/60000]
loss: 0.525412  [25664/60000]
loss: 0.512137  [32064/600

#Saving model

In [None]:
import os
model_path = '/content/gdrive/My Drive/YAMAKAWA_LAB/Vit'
if not os.path.exists(model_path):
  os.makedirs(model_path)

torch.save(model.state_dict(), os.path.join(model_path,"model.pth"))
print("Saved PyTorch Model State to model.pth")

#Visualize hidden layer

##Visualize positional encoding

In [None]:
import math
import matplotlib.pyplot as plt
import torch
import torch.nn.functional as F

#model setting
model = vit()
#load pretrained model
checkpoint = torch.load('/content/gdrive/My Drive/YAMAKAWA_LAB/Vit/model.pth')
checkpoint_model = checkpoint["model"]
model.load_state_dict(checkpoint_model)

#load positional embeddings from model
#N : num of patches + 1(cls token), D:dimension
pos_embed = model.state_dict()["pos_embed"]#shape:(1,N,D)
H_and_W = int(math.sqrt(pos_embed.shape[-1]-1)) #subtract cls token, -> calculate num_patch_row

#visualize cosine similarity
fig = plt.figure(figsize=(10,10))
for i in range(1,pos_embed.shape[1]):
  sim = F.cosine_similarity(pos_embed[0,i:i+1],pos_embed[0,1:],dim=1)
  sim = sim.reshape((H_and_W,H_and_W)).detach().cpu().numpy() #numpy uses CPU, so change device to CPU when changing tensor to numpy
  ax = fig.add_subplot(H_and_W,H_and_W,i)
  ax.imshow(sim)
plt.savefig("/content/gdrive/My Drive/YAMAKAWA_LAB/Vit/position_embedding.pdf")

##Visualize with Attention Rollout

In [1]:
import cv2
import numpy as np
import matplotlib.pyplot as plt
import torch
from PIL import Image

def extract(pre_model,target,inputs):
  """Extract Attention Weight
  """
  feature = None
  def forward_hook(module,target,outputs):
    """save output of forward ias a global variables
    """
    global blocks
    blocks = outputs.detach()

  #enroll callback function
  handle = target.register_forward_hook(forward_hook)
  #inference
  pre_model.eval()
  pre_model(inputs)
  #remove callback
  handle.remove()
  return blocks

#model setting
model = vit()
#load pretrained model
checkpoint = torch.load('/content/gdrive/My Drive/YAMAKAWA_LAB/Vit/model.pth')
checkpoint_model = checkpoint["model"]
model.load_state_dict(checkpoint_model)

#get each Attention weight 
#L :num of layers, H:num of Heads, N:num of patches + cls token(1)
attention_weight = []

#resize the images and crop image ceter
#change the image shape to fit the model input
normalize = transforms.Normalize(mean=[0.5,0.5,0.5],std = [0.5,0.5,0.5])
transform = transforms.Compose([
    transforms.Resize(192,192), #64*3
    trainsforms.CenterCrop(192,192),
    transforms.ToTensor(),
    normalize,
])

#load image file
image = Image.open("/content/gdrive/My Drive/YAMAKAWA_LAB/技術補佐員/") #image name
x = transform(image) #shape:(1,3,224,224)

for i in range(len(model.blocks)):
  target_module = model.blocks[i].attn.attn_drop
  features = extract(model, target_module, x) #shape:(1,H,N,N)
  attention_weight.append([features.to("cpu").detach.numpy().copy()])
attention_weight = np.squeeze(np.concatenate(attention_weight),axis = 1) #shape: (L,H,N,N)

#Calculation
#Average in head axis
mean_head = np.mean(attention_weight,axis=1) #(L,N,N)

#add N*N eye matrix to mean_head
mean_head = mean_head + np.eye(mean_head.shape[1])

#Normalize
mean_head = mean_head / mean_head.sum(axis=(1,2))[:,np.newaxis,np.newaxis] #(L,N,N)

#multiply in layer axis
v = mean_head[-1]
for n in range(1,len(mean_head)):
  v = np.matmul(v,mean_head[-1-n])

#make attention map
mask = v[0,1;].reshape(3,3) # num of patches
attention_map = cv2.resize(mask/mask.max(), (ori_img.shape[2],ori_img.shape[3]))[...,np.newaxis]

#show attention map
plt.imshow(attention_map)
plt.savefig('/content/gdrive/My Drive/YAMAKAWA_LAB/Vit/attention_map.pdf')


SyntaxError: ignored

#Load model

In [None]:
model = Vit().to(device)
model.load_state_dict(torch.load("model.pth"))