<a href="https://colab.research.google.com/github/tae-yeop/transformer-adventure/blob/main/attention.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

Reference implementation (self-attention-cv): https://github.com/The-AI-Summer/self-attention-cv/blob/8280009366b633921342db6cab08da17b46fdf1c/self_attention_cv/transformer_vanilla/transformer_block.py

In [None]:
!pip install einops

Collecting einops
  Downloading einops-0.3.2-py3-none-any.whl (25 kB)
Installing collected packages: einops
Successfully installed einops-0.3.2


In [None]:
import numpy as np
import torch

from einops import rearrange
from torch import nn

### Einsum

In [None]:
x = torch.rand(2, 3)  # 2x3 matrix with entries drawn uniformly from [0, 1)

In [None]:
# Permutation of axes: "i j -> j i" swaps rows and columns (same as x.t())
torch.einsum("i j -> j i", x)

In [None]:
# Summation: no output indices, so every entry is summed (same as x.sum())
torch.einsum("i j ->", x)

In [None]:
# Column sum: i is summed out, leaving one value per column j (same as x.sum(dim=0))
torch.einsum("i j -> j", x)

In [None]:
# Row sum: j is summed out, leaving one value per row i (same as x.sum(dim=1))
torch.einsum("i j -> i", x)

In [None]:
# Matrix-"vector" product: v is stored as a 1x3 row matrix, so x (2x3) times v^T gives a 2x1 result
# (same as x @ v.t()); k is the contracted index.
v = torch.rand((1,3))
torch.einsum("ik,jk->ij", x, v)

In [None]:
# Matrix-matrix multiplication: x @ x^T, (2x3) @ (3x2) -> 2x2
x.mm(x.t())

In [None]:
torch.einsum('i j, k j -> i k', x, x)  # same product via einsum, j is contracted: 2x2 = 2x3 x 3x2

In [None]:
# Dot product of the first row with itself (same as x[0] @ x[0])
torch.einsum("i, i ->", x[0], x[0])

In [None]:
# Frobenius inner product of two matrices: sum of element-wise products (same as (x * x).sum())
torch.einsum("ij,ij->", x, x)

In [None]:
# Hadamard product (element-wise multiplication, same as x * x)
torch.einsum("ij,ij->ij", x, x)

In [None]:
# Outer product: result[i, j] = a[i] * b[j] (same as torch.outer(a, b))
a, b = torch.rand(3), torch.rand(5)
torch.einsum("i, j -> i j", a, b)

In [None]:
# Batched matrix multiplication (same as torch.bmm(a, b)): batch index i is kept, k is contracted
a = torch.rand(3, 2, 5)
b = torch.rand(3, 5, 3)
torch.einsum("i j k, i k l -> i j l", a, b)

In [None]:
# Matrix diagonal: repeating an index on one operand selects the diagonal (same as torch.diagonal(x))
x = torch.rand(3, 3)
print(x)
print(torch.einsum("i i -> i", x))

In [None]:
# Matrix trace: the diagonal summed to a scalar (same as torch.trace(x))
torch.einsum("i i ->", x)

# Self Attention

In [None]:
class SelfAttention(nn.Module):
  """Single-head scaled dot-product self-attention.

  Queries, keys and values come from one bias-free Linear(dim, 3 * dim);
  attention scores are scaled by dim ** -0.5.
  """

  def __init__(self, dim):
    super().__init__()
    self.to_qvk = nn.Linear(dim, dim * 3, bias=False)
    self.scale_factor = dim ** -0.5

  def forward(self, x, mask=None):
    """
    x : [b, T, d]
    mask (optional) : [T, T]; entries equal to 0 are excluded from attention
    Returns : [b, T, d]
    """
    assert x.dim() == 3, '3D tensor must be provided'
    projected = self.to_qvk(x)  # [b, T, 3 * d]

    # The last axis is laid out as (d, 3) with the q/k/v index varying fastest, so q, k and v
    # are the interleaved features 0::3, 1::3 and 2::3 (the einops split 'b t (d k) -> k b t d').
    q, k, v = projected.reshape(*projected.shape[:-1], -1, 3).unbind(dim=-1)

    scores = torch.einsum('b i d, b j d -> b i j', q, k) * self.scale_factor  # [b, T, T]
    if mask is not None:
      assert mask.shape == scores.shape[1:]
      scores = scores.masked_fill(mask == 0, -np.inf)

    weights = torch.softmax(scores, dim=-1)
    return torch.einsum('b i j, b j d -> b i d', weights, v)


In [None]:
# Smoke test: output keeps the input shape [batch=8, tokens=16, dim=32]
sa = SelfAttention(dim=32)
t = torch.randn(8, 16, 32)
sa(t).shape

torch.Size([8, 16, 32])

# Multi-Head Self-Attention (MHSA)

In [None]:
def compute_mhsa(q, k, v, scaled_factor=1, mask=None):
  """
  Scaled dot-product attention over the last two axes, batched over any leading axes.

  Args:
    q, k, v: [..., tokens, dim] (e.g. [batch, heads, tokens, dim])
    scaled_factor: multiplier applied to q.k^T before the softmax (typically dim_head ** -0.5)
    mask: optional [tokens, tokens]; positions where mask == 0 get -inf before the softmax
  Returns : [..., tokens, dim], i.e. [batch, heads, tokens, dim] for 4D inputs
  """
  # scaled_dot_prod.shape = [..., token, token]
  scaled_dot_prod = torch.einsum('... i d, ... j d -> ... i j', q, k) * scaled_factor

  if mask is not None:
    # Compare against the trailing two axes so any number of leading (batch/head) axes works.
    assert mask.shape == scaled_dot_prod.shape[-2:]
    scaled_dot_prod = scaled_dot_prod.masked_fill(mask == 0, -np.inf)

  attention = torch.softmax(scaled_dot_prod, dim=-1)
  return torch.einsum('... i j, ... j d -> ... i d', attention, v)

In [None]:
class MultiHeadSelfAttention(nn.Module):
  def __init__(self, dim, heads=8, dim_head=None):
    """
    Implementation of multi-head attention layer of the original transformer model.
    einsum and einops.rearrange is used whenever possible
    Args:
        dim: token's dimension, i.e. word embedding vector size
        heads: the number of distinct representations to learn
        dim_head: the dim of the head. In general dim_head<dim.
        However, it may not necessary be (dim/heads)
    """
    # Also handles dim not being a multiple of heads: the inner width _dim = dim_head * heads
    # may differ from dim, and W_0 projects it back to dim.
    super().__init__()
    self.dim_head = (int(dim / heads)) if dim_head is None else dim_head
    _dim = self.dim_head * heads

    self.heads = heads
    self.to_qvk = nn.Linear(dim, _dim * 3, bias=False)

    self.W_0 = nn.Linear(_dim, dim, bias=False)
    # 1 / sqrt(dim_head) scaling of the attention logits (exponent, not a multiplier).
    self.scale_factor = self.dim_head ** -0.5

  def forward(self, x, mask=None):
    """
    x : [batch, tokens, dim]
    mask (optional) : [tokens, tokens]; positions where mask == 0 are not attended
    Returns : [batch, tokens, dim]
    """
    assert x.dim() == 3
    qkv = self.to_qvk(x)  # [b, t, 3 * heads * dim_head]

    q, k, v = tuple(rearrange(qkv, 'b t (d k h) -> k b h t d', k=3, h=self.heads))

    out = compute_mhsa(q, k, v, scaled_factor=self.scale_factor, mask=mask)

    out = rearrange(out, 'b h t d -> b t (h d)')
    return self.W_0(out)



In [None]:
# Smoke test: output keeps the input shape [batch=8, tokens=16, dim=32]
mhsa = MultiHeadSelfAttention(dim=32, heads=8)
t = torch.randn(8, 16, 32)
mhsa(t).shape

torch.Size([8, 16, 32])

# Vanilla Transformer

In [None]:
import os
import random
from typing import List, Tuple

import numpy as np
import torch
from einops import repeat
from torch import Tensor, nn

In [None]:
def expand_to_batch(tensor, desire_size):
  """Tile `tensor` along dim 0 so its leading size becomes `desire_size`.

  Each slice along dim 0 is repeated `desire_size // tensor.shape[0]` times consecutively
  (the einops pattern 'b ... -> (b tile) ...'). If desire_size is not a multiple of
  tensor.shape[0], the result has (desire_size // tensor.shape[0]) * tensor.shape[0] rows.
  """
  tile = desire_size // tensor.shape[0]
  # repeat_interleave instead of einops.repeat: later in this notebook the global name `repeat`
  # is rebound to itertools.repeat, which would break an einops-style call.
  return tensor.repeat_interleave(tile, dim=0)

def init_random_seed(seed, gpu=False):
  """Seed the Python, NumPy and PyTorch (CPU and CUDA) random generators.

  Args:
    seed: integer seed applied to every generator.
    gpu: if True, also request deterministic cuDNN kernels and disable cuDNN
         benchmark-mode autotuning, which may otherwise pick different algorithms per run.

  Note: PYTHONHASHSEED is only read at interpreter start-up, so setting it here affects
  subprocesses launched afterwards, not the current process.
  """
  random.seed(seed)
  np.random.seed(seed)
  torch.manual_seed(seed)
  torch.cuda.manual_seed(seed)
  torch.cuda.manual_seed_all(seed)
  os.environ['PYTHONHASHSEED'] = str(seed)
  if gpu:
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False

# from https://huggingface.co/transformers/_modules/transformers/modeling_utils.html
def get_module_device(parameter : nn.Module):
  """Return the device of the first parameter of `parameter`.

  If the module has no parameters, fall back to the first tensor stored directly as a plain
  attribute (in `__dict__`) on the module or any submodule; the upstream comment cites
  nn.DataParallel compatibility in PyTorch 1.5 as the reason.
  Raises StopIteration if neither a parameter nor such a tensor attribute exists.
  """
  try:
    first_param = next(parameter.parameters())
  except StopIteration:
    pass
  else:
    return first_param.device

  def find_tensor_attributes(module : nn.Module) -> List[Tuple[str, Tensor]]:
    # Plain tensor attributes live in module.__dict__ rather than in _parameters/_buffers.
    return [(name, value) for name, value in module.__dict__.items() if torch.is_tensor(value)]

  _, first_tensor = next(parameter._named_members(get_members_fn=find_tensor_attributes))
  return first_tensor.device

In [None]:
class TransformerBlock(nn.Module):
  def __init__(self, dim, heads=8, dim_head=None, dim_linear_block=1024, dropout=0.1, activation=nn.GELU,
               mhsa=None, prenorm=False):
    """
    One transformer block: multi-head self-attention followed by a two-layer MLP, each
    sub-layer wrapped in a residual connection and a LayerNorm.

    Args:
        dim: token's vector length
        heads: number of heads
        dim_head: per-head dimension; if None, MultiHeadSelfAttention derives it from dim / heads
        dim_linear_block: the inner (hidden) dimension of the MLP
        dropout: probability of dropping values (after attention and inside the MLP)
        activation: activation class, instantiated with no arguments, placed between the MLP layers
        mhsa: if provided, replaces the default MultiHeadSelfAttention; it is called as mhsa(x, mask)
        prenorm: if True, LayerNorm is applied to each sub-layer's input (pre-norm);
                 if False, it is applied after each residual addition (post-norm)
    """
    super().__init__()
    self.mhsa = mhsa if mhsa is not None else MultiHeadSelfAttention(dim=dim, heads=heads, dim_head=dim_head)
    self.prenorm = prenorm
    self.drop = nn.Dropout(dropout)
    self.norm_1 = nn.LayerNorm(dim)
    self.norm_2 = nn.LayerNorm(dim)

    self.linear = nn.Sequential(nn.Linear(dim, dim_linear_block),
                                activation(),
                                nn.Dropout(dropout),
                                nn.Linear(dim_linear_block, dim),
                                nn.Dropout(dropout))

  def forward(self, x, mask=None):
    """x: [batch, tokens, dim]; mask is passed unchanged to self.mhsa. Returns a tensor shaped like x."""
    if self.prenorm:
      y = self.drop(self.mhsa(self.norm_1(x), mask)) + x
      out = self.linear(self.norm_2(y)) + y
    else:
      y = self.norm_1(self.drop(self.mhsa(x, mask)) + x)
      out = self.norm_2(self.linear(y) + y)

    # Shared return for both branches (pre-norm must not fall through to None).
    return out

In [None]:
class TransformerEncoder(nn.Module):
  def __init__(self, dim, blocks=6, heads=8, dim_head=None, dim_linear_block=1024, dropout=0, prenorm=False):
    """
    Stack of `blocks` independently initialized TransformerBlock layers applied in sequence.

    Args:
        dim: token's vector length
        blocks: number of stacked TransformerBlock layers
        heads, dim_head, dim_linear_block, dropout, prenorm: forwarded to every TransformerBlock
    """
    super().__init__()
    self.block_list = [TransformerBlock(dim, heads, dim_head, dim_linear_block, dropout, prenorm=prenorm)
                       for _ in range(blocks)]
    # ModuleList registers the blocks as submodules (parameters, .to(), state_dict).
    self.layers = nn.ModuleList(self.block_list)

  def forward(self, x, mask=None):
    """Run x through every block in order, passing the same mask to each."""
    for layer in self.layers:
      x = layer(x, mask)
    return x

In [None]:
import copy

def get_clones(module, N):
  """Return an nn.ModuleList of N independent deep copies of `module` (no shared parameters)."""
  clones = (copy.deepcopy(module) for _ in range(N))
  return nn.ModuleList(clones)

In [None]:
class Encoder(nn.Module):
  def __init__(self, vocab_size, d_model, N, heads, dropout):
    """
    Token encoder: embedding, positional encoding and N stacked encoder layers.

    NOTE(review): work in progress. `Embedder` and `PositionalEncoder` are not defined in the
    visible cells, and no `forward` is implemented yet. TODO: define them / add forward.
    """
    super().__init__()
    self.N = N
    self.embed = Embedder(vocab_size, d_model)
    self.pe = PositionalEncoder(d_model, dropout=dropout)
    # get_clones needs a layer *instance* and a count; the notebook's TransformerBlock serves
    # as the encoder layer here (passing the Encoder class itself and no N raised TypeError).
    self.layers = get_clones(TransformerBlock(d_model, heads=heads, dropout=dropout), N)

### ViT
- https://github.com/lucidrains/x-transformers/blob/main/x_transformers/x_transformers.py
- https://github.com/lucidrains/vit-pytorch/blob/main/vit_pytorch/vit.py
- https://github.com/huggingface/transformers

In [None]:
import torch
from torch import nn

from einops import rearrange, repeat
from einops.layers.torch import Rearrange

In [None]:
import logging
from typing import Callable, List, Optional, Tuple, Union

import torch
from torch import nn as nn
import torch.nn.functional as F

from itertools import repeat
import collections.abc


# From PyTorch internals
def _ntuple(n):
    """Return a parser that normalizes a scalar-or-iterable argument to a tuple.

    Non-string iterables are converted with tuple() as-is (their length is not checked
    against n); any other value, including a string, is repeated n times.
    """
    def parse(x):
        if isinstance(x, collections.abc.Iterable) and not isinstance(x, str):
            return tuple(x)
        # (x,) * n rather than itertools.repeat: the global name `repeat` is also bound to
        # einops.repeat in this notebook, so depending on it is fragile.
        return (x,) * n
    return parse


to_1tuple = _ntuple(1)
to_2tuple = _ntuple(2)
to_3tuple = _ntuple(3)
to_4tuple = _ntuple(4)
to_ntuple = _ntuple


class PatchEmbed(nn.Module):
    """ 2D Image to Patch Embedding

    Cuts a [B, C, H, W] image into non-overlapping patches with a Conv2d whose kernel size
    and stride both equal patch_size, projecting every patch to embed_dim channels.

    Args:
        img_size: expected input size (int or (H, W)). None disables grid bookkeeping
            (grid_size / num_patches become None) and the strict size check.
        patch_size: patch size (int or (ph, pw)).
        in_chans: number of input channels C.
        embed_dim: number of output channels per patch.
        norm_layer: optional factory called as norm_layer(embed_dim); identity if None.
        flatten: if True, return [B, num_patches, embed_dim] (NLC); otherwise keep the
            conv output layout [B, embed_dim, H // ph, W // pw].
        bias: whether the projection conv has a bias.
        strict_img_size: if True and img_size is set, inputs whose (H, W) differ from
            img_size raise ValueError.
        dynamic_img_pad: if True, zero-pad bottom/right so H and W become multiples of patch_size.
    """
    dynamic_img_pad: torch.jit.Final[bool]

    def __init__(
            self,
            img_size: Optional[int] = 224,
            patch_size: int = 16,
            in_chans: int = 3,
            embed_dim: int = 768,
            norm_layer: Optional[Callable] = None,
            flatten: bool = True,
            bias: bool = True,
            strict_img_size: bool = True,
            dynamic_img_pad: bool = False,
    ):
        super().__init__()
        self.patch_size = to_2tuple(patch_size)
        if img_size is not None:
            self.img_size = to_2tuple(img_size)
            self.grid_size = tuple([s // p for s, p in zip(self.img_size, self.patch_size)])
            self.num_patches = self.grid_size[0] * self.grid_size[1]
        else:
            self.img_size = None
            self.grid_size = None
            self.num_patches = None

        # Store the requested behaviour flags; forward() reads all three.
        self.strict_img_size = strict_img_size
        self.dynamic_img_pad = dynamic_img_pad

        self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size, bias=bias)
        self.norm = norm_layer(embed_dim) if norm_layer else nn.Identity()
        self.flatten = flatten

    def forward(self, x):
        _, _, H, W = x.shape
        if self.img_size is not None and self.strict_img_size:
            if H != self.img_size[0] or W != self.img_size[1]:
                raise ValueError(
                    f"Input size ({H}x{W}) doesn't match model ({self.img_size[0]}x{self.img_size[1]})."
                )
        if self.dynamic_img_pad:
            # Smallest non-negative padding that makes H and W divisible by the patch size.
            pad_h = (self.patch_size[0] - H % self.patch_size[0]) % self.patch_size[0]
            pad_w = (self.patch_size[1] - W % self.patch_size[1]) % self.patch_size[1]
            x = F.pad(x, (0, pad_w, 0, pad_h))
        x = self.proj(x)
        if self.flatten:
            x = x.flatten(2).transpose(1, 2)  # NCHW -> NLC
        x = self.norm(x)
        return x

In [None]:
class ViTPooler(nn.Module):
    """Summarize a token sequence as one vector.

    Takes the first token, or the mean over the token axis when config.mean_pool is truthy,
    then applies Linear(hidden_size, hidden_size) followed by tanh.
    """

    def __init__(self, config):
        super().__init__()
        self.dense = nn.Linear(config.hidden_size, config.hidden_size)
        self.activation = nn.Tanh()
        self.mean_pool = config.mean_pool

    def forward(self, hidden_states):
        if self.mean_pool:
            summary = hidden_states.mean(dim=1)
        else:
            summary = hidden_states[:, 0]
        return self.activation(self.dense(summary))

In [None]:
class ViTEncoder(nn.Module):
    """Stack of config.num_hidden_layers ViTLayer blocks applied in sequence.

    NOTE(review): ViTLayer and ViTEncoderOutput are not defined in the visible cells. This
    assumes each ViTLayer returns a sequence whose first element is the new hidden states.
    """

    def __init__(self, config):
        super().__init__()
        self.hidden_size = config.hidden_size
        self.num_hidden_layers = config.num_hidden_layers
        # NOTE(review): not used in forward(); ViTModel applies its own final LayerNorm.
        self.norm = nn.LayerNorm(self.hidden_size)
        self.layers = nn.ModuleList([ViTLayer(config) for _ in range(self.num_hidden_layers)])

    def forward(self, hidden_states, output_hidden_states: bool = False):
        """
        Returns:
            ViTEncoderOutput whose last_hidden_state is the final layer's output. When
            output_hidden_states is True, hidden_states holds the input to every layer plus
            the final output (num_hidden_layers + 1 entries); otherwise it is None.
        """
        all_hidden_states = () if output_hidden_states else None

        for layer_module in self.layers:
            if output_hidden_states:
                all_hidden_states = all_hidden_states + (hidden_states,)
            layer_outputs = layer_module(hidden_states)
            hidden_states = layer_outputs[0]

        if output_hidden_states:
            all_hidden_states = all_hidden_states + (hidden_states,)

        return ViTEncoderOutput(last_hidden_state=hidden_states,
                                hidden_states=all_hidden_states)

In [None]:
class ViTModel(nn.Module):
    """Vision Transformer backbone.

    Image -> flattened patches -> linear patch embedding -> [CLS] token + learned positional
    embedding -> ViTEncoder -> final LayerNorm -> optional ViTPooler.
    """

    def __init__(self, config):
        super().__init__()
        image_height, image_width = config.image_size, config.image_size
        self.patch_size = config.patch_size
        self.num_channels = config.num_channels
        self.hidden_size = config.hidden_size
        self.drop_rate = config.drop_rate

        assert image_height % self.patch_size == 0 and image_width % self.patch_size == 0, 'Image dimensions must be divisible by the patch size'
        num_patches = (image_height // self.patch_size) * (image_width // self.patch_size)
        patch_dim = self.num_channels * self.patch_size * self.patch_size

        self.patch_embedding = nn.Sequential(
            Rearrange('b c (h p1) (w p2) -> b (h w) (p1 p2 c)', p1=self.patch_size, p2=self.patch_size),
            nn.LayerNorm(patch_dim), # pre-norm
            nn.Linear(patch_dim, self.hidden_size),
            nn.LayerNorm(self.hidden_size)
        )

        # One extra position for the [CLS] token prepended in forward().
        self.pos_embedding = nn.Parameter(torch.randn(1, num_patches + 1, self.hidden_size))
        self.cls_token = nn.Parameter(torch.randn(1, 1, self.hidden_size))
        self.dropout = nn.Dropout(self.drop_rate)

        self.encoder = ViTEncoder(config)
        self.layernorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
        self.pooler = ViTPooler(config) if config.add_pooling_layer else None

    def forward(self, pixel_values):
        """
        Args:
            pixel_values: [batch, num_channels, H, W], H and W divisible by patch_size.
        Returns:
            (sequence_output [batch, 1 + num_patches, hidden_size], pooled_output or None)
        """
        embedding_output = self.patch_embedding(pixel_values)  # [b, num_patches, hidden_size]
        batch, seq_len, _ = embedding_output.shape
        # expand() instead of einops.repeat: in this notebook the global name `repeat` is
        # rebound to itertools.repeat by `from itertools import repeat`.
        cls_tokens = self.cls_token.expand(batch, -1, -1)
        embedding_output = torch.cat((cls_tokens, embedding_output), dim=1)
        # Smaller inputs give fewer patches, so only the first (seq_len + 1) positions are used.
        embedding_output = embedding_output + self.pos_embedding[:, :(seq_len + 1)]
        embedding_output = self.dropout(embedding_output)

        encoder_output = self.encoder(embedding_output)
        # NOTE(review): ViTEncoder returns a ViTEncoderOutput built with last_hidden_state=...;
        # assumes that object exposes it as an attribute — confirm against its definition.
        sequence_output = self.layernorm(encoder_output.last_hidden_state)

        pooled_output = self.pooler(sequence_output) if self.pooler is not None else None

        return sequence_output, pooled_output

    # Backward-compatible alias for the originally misspelled method name.
    forwrad = forward