# Transformer
## Tokenizer

In [49]:
# Tokenizer demo: load a pretrained tokenizer, encode a sample string,
# decode it back, and report the vocabulary size.
import os
# Send HTTP and HTTPS traffic through the proxy listening on 127.0.0.1:8889.
# NOTE(review): machine-specific setting; presumably needed to reach the model
# hub from the author's network — adjust or remove on other machines.
os.environ['http_proxy'] = 'http://127.0.0.1:8889'
os.environ['https_proxy'] = 'http://127.0.0.1:8889'

from transformers import AutoTokenizer
# `tokenizer` is read again by later cells (vocabulary size for the embedding).
tokenizer = AutoTokenizer.from_pretrained('FlagAlpha/Llama2-Chinese-7b-Chat')
text = 'Hello World!'
# Calling the tokenizer returns a mapping; the token ids live under 'input_ids'.
input_id = tokenizer(text)
print(input_id['input_ids'])
# Round-trip the ids back to text (the recorded output shows a leading `<s>`).
decoded_text = tokenizer.decode(input_id['input_ids'])
print(decoded_text)
print(len(tokenizer.vocab))
# Optional: print each token id next to its decoded piece.
# for token in input_id['input_ids']:
#     decoded_token = tokenizer.decode([token])
#     print(f'{token}: {decoded_token}')

[1, 15043, 2787, 29991]
<s> Hello World!
32000


## Embedding

In [50]:
from torch import nn
from torch.nn import functional as F
import torch
# Token ids of `text` as a 1-D integer tensor, ready for one-hot / embedding demos.
# The ids are re-tokenized here instead of read from the previous `input_id`:
# this statement rebinds `input_id` to a tensor, so indexing it with
# 'input_ids' would fail as soon as the cell is executed a second time.
input_id = torch.tensor(tokenizer(text)['input_ids'])
# print(input_id)
# input_id_onehot = F.one_hot(input_id, num_classes=tokenizer.vocab_size)
# print(input_id_onehot.shape)

# id_demo = torch.tensor([1, 4, 5, 2, 3, 0])
# print(F.one_hot(id_demo, num_classes=6))

# embedding = nn.Embedding(
#     num_embeddings=tokenizer.vocab_size,
#     embedding_dim=512
# )
# input_id_embedding = embedding(input_id)
# print(input_id_embedding.shape)

## Positional Encoding

In [51]:
import numpy as np

class PositionalEncoding(nn.Module):
    """Add fixed sinusoidal position information to a batch of embeddings.

    A table of shape ``(1, batch_size, d_model)`` is precomputed once and
    stored as a non-trainable buffer. Even feature columns hold
    ``sin(pos * w_i)`` and odd columns hold ``cos(pos * w_i)`` with
    ``w_i = 10000 ** (-2i / d_model)``.

    Args:
        d_model: Width of the embeddings. Odd widths are supported.
        batch_size: Number of positions to precompute, i.e. the longest
            sequence ``forward`` accepts. Despite its name (kept for
            backward compatibility) this is not the batch dimension.
    """

    def __init__(self, d_model: int, batch_size: int):
        super().__init__()
        self.d_model = d_model
        self.batch_size = batch_size
        position = torch.arange(0, batch_size, dtype=torch.float32).unsqueeze(1)
        div_term = torch.exp(torch.arange(0, d_model, 2).float() * (-np.log(10000.0) / d_model))
        angles = position * div_term  # (batch_size, ceil(d_model / 2))
        pe = torch.zeros(batch_size, d_model)
        pe[:, 0::2] = torch.sin(angles)
        # With an odd d_model there is one fewer cosine column than sine
        # columns, so trim the angles to the number of odd-indexed columns.
        pe[:, 1::2] = torch.cos(angles[:, : d_model // 2])
        # A buffer follows the module across devices but is never trained.
        self.register_buffer('pe', pe.unsqueeze(0))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return ``x`` plus the encoding of its first ``x.shape[1]`` positions.

        Args:
            x: Tensor of shape ``(batch, seq_len, d_model)``.

        Raises:
            ValueError: If ``seq_len`` exceeds the number of precomputed
                positions. Without this check a one-row table would be
                silently broadcast over every position.
        """
        seq_len = x.shape[1]
        max_len = self.pe.shape[1]
        if seq_len > max_len:
            raise ValueError(
                f'sequence length {seq_len} exceeds the {max_len} precomputed positions'
            )
        return x + self.pe[:, :seq_len, :]

# positional_encoding = PositionalEncoding(512, 1)
# input_id_embedding = positional_encoding(input_id_embedding.reshape([1, *input_id_embedding.shape])).reshape(input_id_embedding.shape)

## Encoder
### MultiHeadAttention

In [52]:
import einops

# wk = torch.ones(512, 512)
# wq = torch.ones(512, 512)
# wv = torch.ones(512, 512)
# wo = torch.ones(512, 512)
# def attention(x):
#     key = einops.rearrange(x @ wk, 'n (h d) -> h n d', h=16)
#     query = einops.rearrange(x @ wk, 'n (h d) -> h n d', h=16)
#     value = einops.rearrange(x @ wk, 'n (h d) -> h n d', h=16)
#     attn = (query @ key.transpose(1, 2)) / (512 ** 0.5)
#     attn = F.softmax(attn, dim=-1)
#     return einops.rearrange(attn @ value, 'h n d -> n (h d)') @ wo
# print(attention(input_id_embedding))

# nn.Linear computes linear(x) = x @ W.T + b; the projections below use bias=False, so b is omitted.

class Attention(nn.Module):
    """Multi-head scaled dot-product attention (self- or cross-attention).

    Args:
        d_model: Width of the input and output features. Must be divisible
            by ``num_heads``.
        num_heads: Number of attention heads the features are split into.
        *args, **kwargs: Forwarded unchanged to ``nn.Module.__init__``.

    Raises:
        ValueError: If ``d_model`` is not divisible by ``num_heads``.
    """

    def __init__(self, d_model: int, num_heads: int, *args, **kwargs):
        super().__init__(*args, **kwargs)
        if d_model % num_heads != 0:
            raise ValueError(
                f'd_model ({d_model}) must be divisible by num_heads ({num_heads})'
            )
        self.d_model = d_model
        self.num_heads = num_heads
        self.head_dim = d_model // num_heads
        self.wk = nn.Linear(d_model, d_model, bias=False)
        self.wq = nn.Linear(d_model, d_model, bias=False)
        self.wv = nn.Linear(d_model, d_model, bias=False)
        self.wo = nn.Linear(d_model, d_model, bias=False)

    def _split_heads(self, t: torch.Tensor) -> torch.Tensor:
        """Reshape ``(b, n, h * d)`` into per-head layout ``(b, h, n, d)``."""
        batch, length, _ = t.shape
        return t.reshape(batch, length, self.num_heads, self.head_dim).transpose(1, 2)

    def forward(self, x, k=None, mask=None):
        """Attend from ``x`` to ``k`` (or to ``x`` itself when ``k`` is None).

        Args:
            x: Query source of shape ``(batch, n_q, d_model)``.
            k: Optional key/value source of shape ``(batch, n_k, d_model)``.
                When omitted the layer performs self-attention.
            mask: Optional tensor broadcastable to
                ``(batch, num_heads, n_q, n_k)``. Non-zero / True entries
                mark positions that must not be attended to.

        Returns:
            Tensor of shape ``(batch, n_q, d_model)``.
        """
        context = x if k is None else k
        # Queries always come from x; keys and values come from the attended
        # sequence, so the output length follows x even when n_k != n_q.
        query = self._split_heads(self.wq(x))
        key = self._split_heads(self.wk(context))
        value = self._split_heads(self.wv(context))
        # Each head's dot product spans head_dim features, so that is the
        # dimension to normalise by (not the full d_model).
        scores = (query @ key.transpose(-2, -1)) / (self.head_dim ** 0.5)
        if mask is not None:
            # masked_fill needs a boolean mask; accept 0/1 masks as well.
            scores = scores.masked_fill(mask.to(torch.bool), -1e9)
        weights = F.softmax(scores, dim=-1)
        # (b, h, n_q, d) -> (b, n_q, h * d)
        merged = (weights @ value).transpose(1, 2).flatten(2)
        return self.wo(merged)

## Encoder / Decoder

In [53]:
class Encoder(nn.Module):
    """One Transformer layer, usable as an encoder or as a decoder layer.

    The layer always runs self-attention followed by a feed-forward network.
    When an attended sequence ``k`` is supplied it also runs cross-attention
    over ``k`` in between, which is what makes it act as a decoder layer.
    Each sub-layer output is layer-normalised and then added to its input.

    Args:
        d_model: Feature width of the inputs and outputs.
        num_heads: Number of attention heads in both attention modules.
        batch_size: Not used by this layer; accepted so existing callers
            that pass it keep working.
    """

    def __init__(self, d_model: int, num_heads: int, batch_size: int):
        super().__init__()
        self.num_heads = num_heads
        self.d_model = d_model
        self.attn = Attention(d_model, num_heads)
        self.norm_a = nn.LayerNorm(d_model)
        self.norm_b = nn.LayerNorm(d_model)
        self.norm_c = nn.LayerNorm(d_model)
        self.ffn = nn.Sequential(
            nn.Linear(d_model, d_model),
            nn.GELU(),
            nn.Linear(d_model, d_model)
        )
        self.attn_cross = Attention(d_model, num_heads)

    def forward(self, x: torch.Tensor, k=None, mask=None):
        """Run the layer on ``x``.

        Args:
            x: Tensor of shape ``(batch_size, seq_len, embed_dim)``.
            k: Optional sequence to cross-attend to (e.g. encoder output).
                Cross-attention runs exactly when this is given.
            mask: Optional self-attention mask (e.g. a causal mask for the
                target sequence); forwarded to the self-attention module.

        Returns:
            Tensor with the same shape as ``x``.
        """
        # The mask restricts which positions of x may see each other.
        x = self.norm_a(self.attn(x, mask=mask)) + x
        # Whether to cross-attend depends on having something to attend to,
        # not on whether a mask happens to be supplied.
        if k is not None:
            x = self.norm_b(self.attn_cross(x, k)) + x
        x = self.norm_c(self.ffn(x)) + x
        return x

## Transformer

In [54]:
class Transformer(nn.Module):
    """Encoder-decoder Transformer built from the layers defined in this file.

    Args:
        d_model: Embedding / feature width.
        num_heads: Number of attention heads per layer.
        batch_size: Forwarded to ``PositionalEncoding`` (where it sizes the
            position table, so it must cover the longest sequence) and to
            every layer.
        num_layers: Number of encoder layers and of decoder layers.
        vocab_size: Number of rows in the embedding table. Defaults to the
            module-level ``tokenizer.vocab_size`` when omitted.
    """

    def __init__(self, d_model: int, num_heads: int, batch_size: int, num_layers: int,
                 vocab_size: "int | None" = None):
        super().__init__()
        self.num_heads = num_heads
        self.d_model = d_model
        self.batch_size = batch_size
        self.num_layers = num_layers
        if vocab_size is None:
            # Fall back to the tokenizer loaded earlier in the notebook.
            vocab_size = tokenizer.vocab_size
        self.vocab_size = vocab_size
        self.embed = nn.Embedding(
            num_embeddings=vocab_size,
            embedding_dim=d_model
        )
        self.pe = PositionalEncoding(d_model, batch_size)
        self.encoders = nn.ModuleList([
            Encoder(d_model=d_model, num_heads=num_heads, batch_size=batch_size)
            for _ in range(num_layers)
        ])
        self.decoders = nn.ModuleList([
            Encoder(d_model=d_model, num_heads=num_heads, batch_size=batch_size)
            for _ in range(num_layers)
        ])
        # NOTE(review): this head projects to d_model, so the softmax is over
        # d_model channels rather than over the vocabulary. Left as is to keep
        # the output shape callers currently receive.
        self.mlp = nn.Sequential(
            nn.Linear(d_model, d_model),
            nn.SiLU(),
            nn.Linear(d_model, d_model),
            nn.Softmax(dim=-1)
        )

    def forward(self, src: torch.Tensor, tgt: torch.Tensor, mask=None):
        """Encode ``src``, decode ``tgt`` against it, and apply the output head.

        Args:
            src: Integer token ids of shape ``(batch, src_len)``.
            tgt: Integer token ids of shape ``(batch, tgt_len)``.
            mask: Optional attention mask handed unchanged to every decoder
                layer; ``None`` means no masking.

        Returns:
            Tensor of shape ``(batch, tgt_len, d_model)`` whose last axis
            sums to 1.
        """
        # Source and target share one embedding table and position table.
        encoder_outputs = self.pe(self.embed(src))
        decoder_outputs = self.pe(self.embed(tgt))
        for encoder in self.encoders:
            encoder_outputs = encoder(encoder_outputs)
        # The caller's mask is passed through as given; it must not be
        # replaced, otherwise the requested masking is lost.
        for decoder in self.decoders:
            decoder_outputs = decoder(decoder_outputs, encoder_outputs, mask)
        return self.mlp(decoder_outputs)
        
# Smoke test: a 5-layer model fed random token ids, 20 sequences of length 20.
# `batch_size` also sizes the positional-encoding table, so it has to cover
# the sequence length used here.
transformer = Transformer(d_model=512, num_heads=16, batch_size=20, num_layers=5)
# torch.randint's upper bound is exclusive: pass vocab_size itself so that
# the highest token id can be sampled too.
src = torch.randint(0, tokenizer.vocab_size, (20, 20))
tgt = torch.randint(0, tokenizer.vocab_size, (20, 20))
output = transformer(src, tgt)
print(output.size())

torch.Size([20, 20, 512])
