## Embedding

In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F

In [2]:
class BERTEmbedding(nn.Module):
    """BERT input embedding: token + position + segment embeddings.

    The three embeddings are summed, layer-normalized and passed through
    dropout.

    Args:
        vocab_size: Number of entries in the token embedding table. Index 0
            is registered as the padding index.
        embed_size: Dimensionality of every embedding vector.
        max_len: Number of learned position embeddings, i.e. the longest
            sequence this layer can embed.
        dropout: Dropout probability applied after layer normalization.
    """

    def __init__(self, vocab_size, embed_size, max_len=512, dropout=0.1):
        super().__init__()

        self.token_embeddings = nn.Embedding(vocab_size, embed_size, padding_idx=0)
        self.position_embeddings = nn.Embedding(max_len, embed_size)
        self.segment_embeddings = nn.Embedding(2, embed_size)  # 2: segments A and B

        # elementwise_affine=True adds the learnable scale/shift: y = gamma * x_hat + beta
        self.layernorm = nn.LayerNorm(embed_size, eps=1e-12, elementwise_affine=True)
        # inplace=False leaves the input tensor untouched; inplace=True would
        # overwrite it to save memory.
        self.dropout = nn.Dropout(p=dropout, inplace=False)

    def forward(self, tokens, segments):
        """Embed a batch of token ids.

        Args:
            tokens: Integer tensor of token ids, shape (batch, seq_len).
            segments: Integer tensor of segment ids (0 or 1), same shape as
                ``tokens``.

        Returns:
            Tensor of shape (batch, seq_len, embed_size).

        Raises:
            IndexError: If ``seq_len`` exceeds the number of position
                embeddings (``max_len``).
        """
        seq_len = tokens.size(1)
        max_len = self.position_embeddings.num_embeddings
        if seq_len > max_len:
            # Fail early with a readable message instead of an opaque
            # out-of-range lookup inside the position embedding.
            raise IndexError(
                f"sequence length {seq_len} exceeds max_len {max_len} of the position embeddings"
            )

        token_embeds = self.token_embeddings(tokens)

        # [[0, 1, ..., seq_len - 1]], created directly on the input's device
        # so no host-to-device copy is needed.
        positions = torch.arange(seq_len, device=tokens.device).unsqueeze(0)
        position_embeds = self.position_embeddings(positions)

        segment_embeds = self.segment_embeddings(segments)

        embeddings = token_embeds + position_embeds + segment_embeds
        embeddings = self.dropout(self.layernorm(embeddings))

        return embeddings

# Smoke test: embed one random sequence of 10 tokens and check the output shape.
vocab_size = 28996
embed_size = 768

model = BERTEmbedding(vocab_size, embed_size)

# One batch entry, ten positions: random token ids and random A/B segment ids.
tokens = torch.randint(low=0, high=vocab_size, size=(1, 10))
segments = torch.randint(low=0, high=2, size=(1, 10))

embeddings = model(tokens, segments)
print(embeddings.shape) # torch.Size([1, 10, 768])

torch.Size([1, 10, 768])


In [3]:
from transformers import BertModel

# Load the pretrained Hugging Face checkpoint and print its embedding module
# so its layer layout can be compared with the hand-written BERTEmbedding.
# Note: this rebinds the notebook-level name `model`.
model = BertModel.from_pretrained('bert-base-cased')
print(model.embeddings)  # shows the sub-modules of the embedding layer

Some weights of the model checkpoint at bert-base-cased were not used when initializing BertModel: ['cls.predictions.transform.LayerNorm.weight', 'cls.predictions.transform.dense.weight', 'cls.predictions.bias', 'cls.predictions.decoder.weight', 'cls.predictions.transform.LayerNorm.bias', 'cls.predictions.transform.dense.bias', 'cls.seq_relationship.weight', 'cls.seq_relationship.bias']
- This IS expected if you are initializing BertModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BertModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).


BertEmbeddings(
  (word_embeddings): Embedding(28996, 768, padding_idx=0)
  (position_embeddings): Embedding(512, 768)
  (token_type_embeddings): Embedding(2, 768)
  (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
  (dropout): Dropout(p=0.1, inplace=False)
)


## Model Architecture

In [5]:
class MultiheadAttention(nn.Module):
    """Multi-head scaled dot-product attention.

    The inputs are split into ``heads`` chunks of size
    ``head_dim = embed_size // heads``. Note that the query/key/value
    projections here are ``head_dim x head_dim`` linear maps applied to each
    head's slice, so the same projection weights are shared by all heads.
    The parameter shapes are kept as they are so existing state dicts
    continue to load.

    Args:
        embed_size: Size of the last dimension of the inputs and the output.
        heads: Number of attention heads; must divide ``embed_size``.

    Raises:
        ValueError: If ``embed_size`` is not divisible by ``heads``.
    """

    def __init__(self, embed_size, heads):
        super().__init__()
        if embed_size % heads != 0:
            # Otherwise the reshape into (heads, head_dim) in forward() would
            # fail later with a much less helpful error.
            raise ValueError(
                f"embed_size ({embed_size}) must be divisible by heads ({heads})"
            )
        self.embed_size = embed_size
        self.heads = heads
        self.head_dim = embed_size // heads

        self.values = nn.Linear(self.head_dim, self.head_dim, bias=False)
        self.keys = nn.Linear(self.head_dim, self.head_dim, bias=False)
        self.queries = nn.Linear(self.head_dim, self.head_dim, bias=False)
        self.fc_out = nn.Linear(heads * self.head_dim, embed_size)

    def _split_heads(self, x):
        """Reshape (N, length, embed_size) into (N, heads, length, head_dim)."""
        batch, length = x.shape[0], x.shape[1]
        return x.reshape(batch, length, self.heads, self.head_dim).permute(0, 2, 1, 3)

    def forward(self, values, keys, queries, mask):
        """Compute attention of ``queries`` over ``keys``/``values``.

        Args:
            values: Tensor of shape (N, value_len, embed_size).
            keys: Tensor of shape (N, key_len, embed_size).
            queries: Tensor of shape (N, query_len, embed_size).
            mask: ``None`` or a tensor broadcastable to
                (N, heads, query_len, key_len). Positions where
                ``mask == 0`` are excluded from attention.

        Returns:
            Tensor of shape (N, query_len, embed_size).
        """
        N, query_len = queries.shape[0], queries.shape[1]

        # (N, len, embed_size) -> (N, heads, len, head_dim), then project per head.
        values = self.values(self._split_heads(values))
        keys = self.keys(self._split_heads(keys))
        queries = self.queries(self._split_heads(queries))

        # Scaled dot-product: divide by sqrt(d_k), where d_k is the per-head
        # key dimension (head_dim), not the full embedding size.
        # (N, heads, query_len, head_dim) @ (N, heads, head_dim, key_len)
        #   -> (N, heads, query_len, key_len)
        scores = torch.matmul(queries, keys.transpose(-2, -1)) / (self.head_dim ** 0.5)

        if mask is not None:
            # Use the most negative finite value rather than -inf: a row that
            # is masked everywhere would otherwise turn into NaN in softmax.
            scores = scores.masked_fill(mask == 0, torch.finfo(scores.dtype).min)

        # Softmax over the key axis.
        attention = F.softmax(scores, dim=-1)

        # (N, heads, query_len, key_len) @ (N, heads, value_len, head_dim)
        #   -> (N, heads, query_len, head_dim) -> (N, query_len, heads * head_dim)
        out = torch.matmul(attention, values)
        out = out.permute(0, 2, 1, 3).reshape(N, query_len, self.heads * self.head_dim)

        return self.fc_out(out)
    
class TransformerBlock(nn.Module):
    """One post-norm transformer encoder block: attention + feed-forward.

    Each sub-layer is wrapped as ``LayerNorm(x + Dropout(sublayer(x)))``.

    Args:
        embed_size: Size of the last dimension of the inputs and the output.
        heads: Number of attention heads passed to ``MultiheadAttention``.
        dropout: Dropout probability applied to each sub-layer output.
        forward_expansion: The feed-forward hidden layer has
            ``forward_expansion * embed_size`` units.
    """

    def __init__(self, embed_size, heads, dropout, forward_expansion):
        super().__init__()
        self.attention = MultiheadAttention(embed_size, heads)
        self.norm1 = nn.LayerNorm(embed_size)
        self.norm2 = nn.LayerNorm(embed_size)

        self.feed_forward = nn.Sequential(
            nn.Linear(embed_size, forward_expansion * embed_size),
            nn.GELU(),
            nn.Linear(forward_expansion * embed_size, embed_size),
        )
        self.dropout = nn.Dropout(dropout)

    def forward(self, value, key, query, mask):
        """Apply attention and the feed-forward network with residuals.

        Args:
            value, key, query: Tensors passed to the attention layer in that
                order; ``query`` is also used as the residual input.
            mask: Attention mask forwarded unchanged to the attention layer
                (may be ``None``).

        Returns:
            Tensor with the same shape as ``query``.
        """
        attention = self.attention(value, key, query, mask)
        # Residual connection: `query` carries the block's original input.
        # Dropout is applied to the sub-layer output *before* the residual add
        # and LayerNorm, so the normalized activations (and the block's final
        # output) are never zeroed out during training.
        x = self.norm1(query + self.dropout(attention))
        forward = self.feed_forward(x)
        out = self.norm2(x + self.dropout(forward))
        return out
    
class BERT(nn.Module):
    """BERT-style encoder: input embeddings followed by transformer blocks.

    Args:
        vocab_size: Size of the token vocabulary.
        embed_size: Hidden size used throughout the model.
        num_layers: Number of stacked ``TransformerBlock`` layers.
        heads: Number of attention heads per block.
        dropout: Dropout probability used in the embedding layer and in
            every block.
        forward_expansion: Feed-forward expansion factor of each block.
        max_length: Maximum sequence length (number of position embeddings).
    """

    def __init__(self, vocab_size, embed_size, num_layers, heads, dropout, forward_expansion, max_length):
        super().__init__()
        # Forward max_length and dropout so the embedding layer follows the
        # model configuration instead of silently using its own defaults.
        self.embedding = BERTEmbedding(vocab_size, embed_size, max_len=max_length, dropout=dropout)
        self.transformer_blocks = nn.ModuleList(
            [
                TransformerBlock(embed_size, heads, dropout, forward_expansion)
                for _ in range(num_layers)
            ]
        )

    def forward(self, x, segments):
        """Encode a batch of token ids.

        Args:
            x: Integer tensor of token ids, shape (batch, seq_len). Id 0 is
                treated as padding and is masked out as an attention key.
            segments: Integer tensor of segment ids, same shape as ``x``.

        Returns:
            Tensor of shape (batch, seq_len, embed_size).
        """
        embeddings = self.embedding(x, segments)
        # Key-padding mask of shape (batch, 1, 1, seq_len); it broadcasts over
        # heads and query positions, so there is no need to materialise a
        # (batch, 1, seq_len, seq_len) copy.
        mask = (x != 0)[:, None, None, :]
        for transformer in self.transformer_blocks:
            # Self-attention: value, key and query are all the current hidden states.
            embeddings = transformer(embeddings, embeddings, embeddings, mask)
        return embeddings

# Smoke test: build the full encoder stack and check the output shape
# for one random sequence of 10 tokens.
vocab_size = 28996
embed_size = 768
num_layers = 12
heads = 12
dropout = 0.1
forward_expansion = 4
max_length = 512

model = BERT(
    vocab_size=vocab_size,
    embed_size=embed_size,
    num_layers=num_layers,
    heads=heads,
    dropout=dropout,
    forward_expansion=forward_expansion,
    max_length=max_length,
)

# Random token ids and random A/B segment ids for a single sequence.
tokens = torch.randint(low=0, high=vocab_size, size=(1, 10))
segments = torch.randint(low=0, high=2, size=(1, 10))

output = model(tokens, segments)
print(output.shape)  # torch.Size([1, 10, 768])

torch.Size([1, 10, 768])
