<a href="https://colab.research.google.com/github/tae-yeop/transformer-adventure/blob/main/attention.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

Reference implementation: https://github.com/The-AI-Summer/self-attention-cv/blob/8280009366b633921342db6cab08da17b46fdf1c/self_attention_cv/transformer_vanilla/transformer_block.py

In [None]:
!pip install einops

Collecting einops
  Downloading einops-0.3.2-py3-none-any.whl (25 kB)
Installing collected packages: einops
Successfully installed einops-0.3.2


In [None]:
import numpy as np
import torch

from einops import rearrange
from torch import nn

# Self Attention

In [None]:
class SelfAttention(nn.Module):
  """Single-head scaled dot-product self-attention.

  One fused, bias-free linear layer produces queries, keys and values;
  the attention scores are multiplied by ``dim ** -0.5``.
  """

  def __init__(self, dim):
    super().__init__()
    # A single projection emits q, k and v together (3 * dim features).
    self.to_qvk = nn.Linear(in_features=dim, out_features=3 * dim, bias=False)
    self.scale_factor = dim ** -0.5

  def forward(self, x, mask=None):
    """Attend every token of ``x`` to every other token.

    Args:
        x: tensor of shape [b, T, d].
        mask: optional [T, T] tensor; entries equal to 0 are filled with
            -inf before the softmax.
    Returns:
        Tensor of shape [b, T, d].
    """
    assert x.dim() == 3, '3D tensor must be provided'

    # fused projection: [b, T, 3*d], then split along a new leading axis
    fused = self.to_qvk(x)
    q, k, v = rearrange(fused, 'b t (d k) -> k b t d', k=3).unbind(0)

    # pairwise token similarities: [b, T, T]
    scores = torch.einsum('b i d, b j d -> b i j', q, k)
    scores = scores * self.scale_factor

    if mask is not None:
      # the mask is shared across the batch, hence [T, T]
      assert mask.shape == scores.shape[1:]
      scores = scores.masked_fill(mask == 0, float('-inf'))

    weights = scores.softmax(dim=-1)
    return torch.einsum('b i j, b j d -> b i d', weights, v)


In [None]:
# Smoke test: 8 sequences of 16 tokens with 32-dim embeddings.
# The output shape is expected to equal the input shape.
sa = SelfAttention(dim=32)
t = torch.randn(8, 16, 32)
sa(t).shape

torch.Size([8, 16, 32])

# Multi-Head Self-Attention (MHSA)

In [None]:
def compute_mhsa(q, k, v, scaled_factor=1, mask=None):
  """
  Scaled dot-product attention over arbitrary leading (batch/head) dims.

  Args:
      q, k, v: tensors whose last two dims are [tokens, dim]; any leading
          dims (typically [batch, heads]) are carried by the einsum ellipsis.
      scaled_factor: multiplier applied to the raw q.k scores.
      mask: optional [tokens, tokens] tensor; positions where ``mask == 0``
          are filled with -inf so they receive zero attention weight.
  Returns : [batch, heads, tokens, dim]
  """
  # scaled_dot_prod.shape = [b, h, token, token]
  scaled_dot_prod = torch.einsum('... i d, ... j d -> ... i j', q, k) * scaled_factor

  if mask is not None:
    # Compare against the trailing [token, token] dims so the check holds
    # for any number of leading dims, matching the ellipsis in the einsum.
    assert mask.shape == scaled_dot_prod.shape[-2:]
    scaled_dot_prod = scaled_dot_prod.masked_fill(mask == 0, -np.inf)

  attention = torch.softmax(scaled_dot_prod, dim=-1)
  return torch.einsum('... i j, ... j d -> ... i d', attention, v)

In [None]:
class MultiHeadSelfAttention(nn.Module):
  def __init__(self, dim, heads=8, dim_head=None):
    """
    Implementation of multi-head attention layer of the original transformer model.
    einsum and einops.rearrange is used whenever possible
    Args:
        dim: token's dimension, i.e. word embedding vector size
        heads: the number of distinct representations to learn
        dim_head: the dim of the head. In general dim_head<dim.
        However, it may not necessary be (dim/heads)
    """
    # Also covers the case where dim is not an exact multiple of heads:
    # the inner width is dim_head * heads and W_0 projects it back to dim.
    super().__init__()
    self.dim_head = dim // heads if dim_head is None else dim_head
    _dim = self.dim_head * heads

    self.heads = heads
    # Fused projection producing q, k and v for all heads at once.
    self.to_qvk = nn.Linear(dim, _dim*3, bias=False)

    self.W_0 = nn.Linear(_dim, dim, bias=False)
    # 1 / sqrt(dim_head): must be a power, not a product, otherwise the
    # scores are multiplied by a negative number and attention is inverted.
    self.scale_factor = self.dim_head ** -0.5

  def forward(self, x, mask=None):
    """
    Args:
        x: tensor of shape [batch, tokens, dim]
        mask: optional [tokens, tokens] tensor; positions equal to 0 are
            excluded from attention
    Returns : [batch, tokens, dim]
    """
    assert x.dim() == 3
    # [batch, tokens, dim_head * heads * 3]
    qkv = self.to_qvk(x)

    # split into q, k, v, each of shape [batch, heads, tokens, dim_head]
    q, k, v = tuple(rearrange(qkv, 'b t (d k h) -> k b h t d', k=3, h=self.heads))

    # forward the mask so that it actually takes effect
    out = compute_mhsa(q, k, v, self.scale_factor, mask=mask)

    # merge the heads back: [batch, tokens, heads * dim_head]
    out = rearrange(out, 'b h t d -> b t (h d)')
    return self.W_0(out)



In [None]:
# Smoke test: the multi-head layer should preserve the [8, 16, 32] input shape.
mhsa = MultiHeadSelfAttention(dim=32, heads=8)
t = torch.randn(8, 16, 32)
mhsa(t).shape

torch.Size([8, 16, 32])

# Vanilla Transformer

In [None]:
import os
import random
from typing import List, Tuple

import numpy as np
import torch
from einops import repeat
from torch import Tensor, nn

In [None]:
def expand_to_batch(tensor, desire_size):
  """
  Repeat ``tensor`` along its first axis until it has ``desire_size`` entries.

  Args:
      tensor: array/tensor whose first axis is the batch axis.
      desire_size: target size of the first axis. The repetition count is
          ``desire_size // tensor.shape[0]``, so a size that is not a
          multiple of the current batch is rounded down.
  Returns:
      ``tensor`` with every batch entry repeated consecutively.
  """
  # Use the actual parameter name; the body previously read an undefined
  # ``desired_size`` and raised NameError on every call.
  tile = desire_size // tensor.shape[0]
  return repeat(tensor, 'b ... -> (b tile) ...', tile=tile)

def init_random_seed(seed, gpu=False):
  """
  Seed every random number generator used in this notebook.

  Args:
      seed: integer seed applied to ``random``, NumPy and PyTorch (CPU and CUDA).
      gpu: when True, additionally ask cuDNN to use deterministic algorithms.
  """
  random.seed(seed)
  np.random.seed(seed)
  torch.manual_seed(seed)
  torch.cuda.manual_seed(seed)
  torch.cuda.manual_seed_all(seed)
  # Exported for child processes; it does not change hashing in the
  # interpreter that is already running.
  os.environ['PYTHONHASHSEED'] = str(seed)
  if gpu:
    torch.backends.cudnn.deterministic = True

# from https://huggingface.co/transformers/_modules/transformers/modeling_utils.html
def get_module_device(parameter : nn.Module):
  """
  Return the device of a module.

  The device of the first registered parameter is used. When the module
  has no parameters, the first plain tensor attribute found through
  ``_named_members`` is used instead.
  """
  # Fallback collector, needed for nn.DataParallel compatibility in PyTorch 1.5
  def find_tensor_attributes(module : nn.Module) -> List[Tuple[str, Tensor]]:
    return [(name, value) for name, value in module.__dict__.items() if torch.is_tensor(value)]

  try:
    first_param = next(parameter.parameters())
  except StopIteration:
    # no parameters registered: fall through to the tensor-attribute lookup
    pass
  else:
    return first_param.device

  members = parameter._named_members(get_members_fn=find_tensor_attributes)
  first_member = next(members)
  return first_member[1].device

In [None]:
class TransformerBlock(nn.Module):
  def __init__(self, dim, heads=8, dim_head=None, dim_linear_block=1024, dropout=0.1, activation=nn.GELU,
               mhsa=None, prenorm=False):
    """
    One transformer encoder block: self-attention followed by a
    position-wise feed-forward network, each with a residual connection.

    Args:
        dim: token's vector length
        heads: number of heads
        dim_head: if none dim/heads is used
        dim_linear_block: the inner projection dim
        dropout: probability of dropping values
        activation: activation class instantiated inside the feed-forward network
        mhsa: if provided you can change the vanilla self-attention block
        prenorm: if the layer norm will be applied before the mhsa or after
    """
    super().__init__()
    if mhsa is None:
      mhsa = MultiHeadSelfAttention(dim=dim, heads=heads, dim_head=dim_head)
    self.mhsa = mhsa
    self.prenorm = prenorm
    self.drop = nn.Dropout(dropout)
    self.norm_1 = nn.LayerNorm(dim)
    self.norm_2 = nn.LayerNorm(dim)

    self.linear = nn.Sequential(nn.Linear(dim, dim_linear_block),
                                activation(),
                                nn.Dropout(dropout),
                                nn.Linear(dim_linear_block, dim),
                                nn.Dropout(dropout))

  def forward(self, x, mask=None):
    """
    Args:
        x: input tokens; the output has the same shape
        mask: optional attention mask forwarded to the attention layer
    """
    if self.prenorm:
      # pre-norm: normalise the input of each sub-layer, keep the residual raw
      y = self.drop(self.mhsa(self.norm_1(x), mask)) + x
      out = self.linear(self.norm_2(y)) + y
    else:
      # post-norm: normalise the sum of sub-layer output and residual
      y = self.norm_1(self.drop(self.mhsa(x, mask)) + x)
      out = self.norm_2(self.linear(y) + y)

    # returned for both branches (the pre-norm path used to yield None)
    return out

In [None]:
class TransformerEncoder(nn.Module):
  def __init__(self, dim, blocks=6, heads=8, dim_head=None, dim_linear_block=1024, dropout=0, prenorm=False):
    """
    A stack of ``blocks`` TransformerBlock layers applied one after another.

    Args:
        dim: token's vector length
        blocks: number of stacked transformer blocks
        heads: number of attention heads per block
        dim_head: per-head dimension passed to every block
        dim_linear_block: inner dim of each block's feed-forward network
        dropout: dropout probability used inside each block
        prenorm: whether each block applies layer norm before its sub-layers
    """
    super().__init__()
    self.block_list = [
        TransformerBlock(dim, heads, dim_head, dim_linear_block, dropout, prenorm=prenorm)
        for _ in range(blocks)
    ]
    # ModuleList registers the blocks so their parameters are tracked.
    self.layers = nn.ModuleList(self.block_list)

  def forward(self, x, mask=None):
    """Run ``x`` through every block in order, passing the same ``mask`` to each."""
    for layer in self.layers:
      x = layer(x, mask)
    return x

In [None]:
import copy

def get_clones(module, N):
  """Return an nn.ModuleList holding N independent deep copies of ``module``."""
  replicas = nn.ModuleList()
  for _ in range(N):
    # deepcopy gives every replica its own parameters
    replicas.append(copy.deepcopy(module))
  return replicas

In [None]:
class Encoder(nn.Module):
  def __init__(self, vocab_size, d_model, N, heads, dropout):
    super().__init__()
    self.N = N
    self.embed = Embedder(vocab_size, d_model)
    self.pe = PositionalEncoder(d_model, dropout=dropout)
    self.layers = get_clones(Encoder)