In [None]:
import torch
import torch.nn as nn

### Token Embedding

In [None]:
# Token Embedding

class Token_Embedding(nn.Embedding):
  """Lookup table mapping token ids to d_model-dimensional vectors.

  Args:
    vocab_size: number of rows in the embedding table.
    d_model: embedding width.
    padding_idx: token id whose embedding is kept at zero and receives no
      gradient. Defaults to 1 (the value previously hard-coded); pass the
      dataset's pad id so it agrees with the attention pad masks.
  """
  def __init__(self, vocab_size, d_model, padding_idx=1):
    super(Token_Embedding, self).__init__(vocab_size, d_model, padding_idx=padding_idx)

### Positional Encoding

In [None]:
# Positional Encoding

class Positional_Encoding(nn.Module):
  """Fixed sinusoidal positional encoding.

  Precomputes a (max_len, d_model) table in which even columns hold
  sin(pos / 10000^(2*(i//2)/d_model)) and odd columns hold the matching cos,
  and returns its first ``seq_len`` rows for a given input.
  """
  def __init__(self, max_len, d_model):
    super(Positional_Encoding, self).__init__()
    self.max_len = max_len
    # Registered as a non-persistent buffer so it moves with .to(device)/.cuda()
    # together with the module, without adding a key to state_dict.
    self.register_buffer('pos_encoding',
                         self.positional_encoding(max_len, d_model),
                         persistent=False)

  def get_angle(self, pos, i, d_model):
    """Return pos / 10000^(2*(i//2)/d_model); pos and i broadcast against each other."""
    angle = 1 / (10000.0 ** (2 * (i // 2) / d_model))
    return pos * angle

  def positional_encoding(self, max_len, d_model):
    """Build the (max_len, d_model) float32 sinusoid table."""
    pos = self.get_angle(pos=torch.arange(0, max_len, dtype=torch.float32).view(-1, 1),
                         i=torch.arange(0, d_model, dtype=torch.float32).view(1, -1),
                         d_model=d_model)
    # The right-hand sides are computed before assignment, and the two column
    # sets are disjoint, so the in-place writes do not interfere.
    pos[:, 0::2] = torch.sin(pos[:, 0::2])
    pos[:, 1::2] = torch.cos(pos[:, 1::2])
    return pos

  def forward(self, input):
    """Return the (seq_len, d_model) encoding for ``input``.

    Only ``input.size(1)`` (the sequence length) is used, so both token ids
    of shape (batch, seq_len) and embeddings of shape (batch, seq_len, d_model)
    are accepted. The result broadcasts over the batch dimension.

    Raises:
      ValueError: if the sequence is longer than ``max_len``.
    """
    seq_len = input.size(1)
    if seq_len > self.pos_encoding.size(0):
      raise ValueError(f"sequence length {seq_len} exceeds max_len {self.pos_encoding.size(0)}")
    return self.pos_encoding[:seq_len, :]


### Embedding Layer

In [None]:
# Transformer Embedding

class Transformer_Embedding(nn.Module):
  """Token embedding plus positional encoding, followed by dropout.

  Args:
    vocab_size: size of the token vocabulary.
    max_len: longest sequence the positional table supports.
    d_model: embedding width.
    drop_rate: dropout probability applied to the summed embedding.
  """
  def __init__(self, vocab_size, max_len, d_model, drop_rate):
    super(Transformer_Embedding, self).__init__()
    self.tok_embedding = Token_Embedding(vocab_size, d_model)
    self.pos_embedding = Positional_Encoding(max_len, d_model)
    self.dropout = nn.Dropout(drop_rate)

  def forward(self, x):
    """Embed token ids ``x`` and return dropout(tok_emb + pos_emb).

    Assumes x is (batch, seq_len) integer ids; TODO confirm at call sites.
    """
    tok_emb = self.tok_embedding(x)
    # Pass the 3-D embedding rather than the raw ids so the positional module
    # can read the sequence length regardless of how it unpacks the shape.
    pos_emb = self.pos_embedding(tok_emb)
    # Keep the positional table on the same device as the embeddings.
    return self.dropout(tok_emb + pos_emb.to(tok_emb.device))

### Multi-head Attention

In [None]:
# Self_attention

class Self_Attention(nn.Module):
  """Scaled dot-product attention: softmax(Q K^T / sqrt(d_k)) V."""
  def __init__(self):
    super(Self_Attention, self).__init__()
    # dim=-1: normalize over the last dimension (the key positions), so each
    # query's attention weights sum to 1 along that dimension.
    self.softmax = nn.Softmax(dim=-1)

  def forward(self, Q, K, V, mask=None):
    """Attend from Q over K and mix V.

    Args:
      Q, K, V: tensors whose last two dims are (seq_len, d_k), typically
        (batch_size, n_head, seq_len, d_k).
      mask: optional tensor broadcastable to the score shape; positions where
        mask == 0 are suppressed by setting their score to -10000.
    Returns:
      (output, weights): output has V's trailing shape with the query length;
      weights has shape (..., len_q, len_k).
    """
    d_k = K.size(-1)
    # Scale by sqrt(d_k) to keep logits from growing with the head width.
    score = Q @ K.transpose(-2, -1) / (d_k ** 0.5)
    if mask is not None:
      score = score.masked_fill(mask == 0, -10000)
    score = self.softmax(score)
    return score @ V, score

class MultiHead_Attention(nn.Module):
  """Multi-head attention: project Q/K/V, split into heads, attend, merge, project.

  Args:
    d_model: model width; must be divisible by n_head.
    n_head: number of attention heads.
  Raises:
    ValueError: if n_head is not positive or does not divide d_model.
  """
  def __init__(self, d_model, n_head):
    super(MultiHead_Attention, self).__init__()
    # Fail fast on an invalid config instead of an opaque view() error in forward.
    if n_head <= 0 or d_model % n_head != 0:
      raise ValueError(f"d_model ({d_model}) must be divisible by n_head ({n_head})")
    self.n_head = n_head
    self.attention = Self_Attention()
    self.W_Q = nn.Linear(d_model, d_model)
    self.W_K = nn.Linear(d_model, d_model)
    self.W_V = nn.Linear(d_model, d_model)
    self.W_concat = nn.Linear(d_model, d_model)

  def split(self, tensor):
    """(batch_size, seq_len, d_model) -> (batch_size, n_head, seq_len, d_k)."""
    batch_size, seq_len, d_model = tensor.size()

    d_k = d_model // self.n_head
    tensor = tensor.view(batch_size, seq_len, self.n_head, d_k).transpose(1, 2)
    return tensor

  def concat(self, tensor):
    """(batch_size, n_head, seq_len, d_k) -> (batch_size, seq_len, d_model)."""
    batch_size, n_head, seq_len, d_k = tensor.size()

    d_model = n_head * d_k
    # transpose() returns a view with permuted strides over the same storage,
    # so the result is not contiguous and view() cannot reshape it directly;
    # contiguous() copies the data into a fresh row-major buffer first.
    tensor = tensor.transpose(1, 2).contiguous().view(batch_size, seq_len, d_model)
    return tensor

  def forward(self, Q, K, V, mask=None):
    """Return the attended output of shape (batch_size, seq_len_q, d_model)."""
    # linear projections
    Q, K, V = self.W_Q(Q), self.W_K(K), self.W_V(V)
    # split into n_head heads
    Q, K, V = self.split(Q), self.split(K), self.split(V)
    # scaled dot-product attention per head
    out, attention = self.attention(Q, K, V, mask=mask)
    # out: (batch_size, n_head, seq_len, d_k); attention: (batch_size, n_head, seq_len, seq_len)
    # The attention weights are not returned.
    out = self.concat(out)
    out = self.W_concat(out)
    return out


### Layer Normalization

In [None]:
# Layer Norm
class Layer_Normalization(nn.Module):
  """Layer normalization over the last dimension with a learnable scale and shift.

  Computes gamma * (x - mean) / sqrt(var + eps) + beta, where mean and the
  biased (population) variance are taken along the last axis.

  Args:
    d_model: size of the last dimension (length of gamma and beta).
    eps: small constant added to the variance for numerical stability.
  """
  def __init__(self, d_model, eps = 1e-12):
    super(Layer_Normalization, self).__init__()
    self.gamma = nn.Parameter(torch.ones(d_model))   # per-feature scale
    self.beta = nn.Parameter(torch.zeros(d_model))   # per-feature shift
    self.eps = eps

  def forward(self, x):
    centered = x - x.mean(dim=-1, keepdim=True)
    variance = x.var(dim=-1, keepdim=True, unbiased=False)
    normalized = centered / (variance + self.eps).sqrt()
    return normalized * self.gamma + self.beta

### Feed-Forward Network

In [None]:
# FFN

class Feed_Forward_Network(nn.Module):
  """Position-wise feed-forward block: Linear -> ReLU -> Dropout -> Linear.

  Args:
    d_model: input and output width.
    hidden: width of the inner projection.
    drop_rate: dropout probability applied after the ReLU.
  """
  def __init__(self, d_model, hidden, drop_rate = 0.1):
    super(Feed_Forward_Network, self).__init__()
    # Construction order is kept (linear1 before linear2) so random init and
    # state_dict ordering are unchanged.
    self.linear1 = nn.Linear(d_model, hidden)
    self.linear2 = nn.Linear(hidden, d_model)
    self.relu = nn.ReLU()
    self.dropout = nn.Dropout(p = drop_rate)

  def forward(self, x):
    activated = self.dropout(self.relu(self.linear1(x)))
    return self.linear2(activated)

### Encoder Layer

In [None]:
# Encoder Layer

class Encoder_Layer(nn.Module):
  """One post-norm Transformer encoder layer.

  Two sub-layers, each followed by its own dropout, a residual connection and
  layer normalization: multi-head self-attention, then a position-wise FFN.

  Args:
    d_model: model width.
    ffn_hidden: inner width of the feed-forward network.
    n_head: number of attention heads.
    drop_rate: dropout probability.
  """
  def __init__(self, d_model, ffn_hidden, n_head, drop_rate):
    super(Encoder_Layer, self).__init__()
    self.attention = MultiHead_Attention(d_model, n_head)
    self.layer_norm1 = Layer_Normalization(d_model)
    self.dropout1 = nn.Dropout(drop_rate)
    self.ffn = Feed_Forward_Network(d_model, ffn_hidden, drop_rate)
    self.layer_norm2 = Layer_Normalization(d_model)
    self.dropout2 = nn.Dropout(drop_rate)

  def forward(self, x, src_mask):
    """Apply self-attention and FFN sub-layers; output has the same shape as x."""
    # self-attention sub-layer
    residual = x
    x = self.attention(x, x, x, src_mask)
    x = self.layer_norm1(self.dropout1(x) + residual)

    # feed-forward sub-layer, with its own dropout module
    residual = x
    x = self.ffn(x)
    x = self.layer_norm2(self.dropout2(x) + residual)
    return x

In [None]:
# Encoder

class Encoder(nn.Module):
  """Embedding layer followed by a stack of ``n_layers`` encoder layers.

  Args:
    enc_voc_size: source vocabulary size.
    max_len: longest supported sequence length.
    d_model: model width.
    ffn_hidden: inner width of each feed-forward network.
    n_head: number of attention heads.
    n_layers: number of stacked encoder layers.
    drop_rate: dropout probability.
    device: accepted for interface compatibility; not used here.
  """
  def __init__(self, enc_voc_size, max_len, d_model,
               ffn_hidden, n_head, n_layers, drop_rate, device):
    # nn.Module.__init__ must run before submodules are assigned as attributes.
    super(Encoder, self).__init__()
    self.embedding = Transformer_Embedding(enc_voc_size, max_len, d_model, drop_rate)
    self.layers = nn.ModuleList([Encoder_Layer(d_model, ffn_hidden,
                                               n_head, drop_rate) for _ in range(n_layers)])

  def forward(self, x, src_mask):
    """Embed source ids ``x`` and run every encoder layer with ``src_mask``."""
    x = self.embedding(x)
    for layer in self.layers:
      x = layer(x, src_mask)
    return x

### Decoder Layer

In [None]:
# Decoder Layer

class Decoder_Layer(nn.Module):
  """One post-norm Transformer decoder layer.

  Three sub-layers, each followed by its own dropout, a residual connection and
  layer normalization: masked self-attention, encoder-decoder attention (only
  when ``x_enc`` is given), then a position-wise FFN.

  Args:
    d_model: model width.
    ffn_hidden: inner width of the feed-forward network.
    n_head: number of attention heads.
    drop_rate: dropout probability.
  """
  def __init__(self, d_model, ffn_hidden, n_head, drop_rate):
    super(Decoder_Layer, self).__init__()
    self.self_attention = MultiHead_Attention(d_model, n_head)
    self.layer_norm1 = Layer_Normalization(d_model)
    self.dropout1 = nn.Dropout(drop_rate)

    self.enc_dec_attention = MultiHead_Attention(d_model, n_head)
    self.layer_norm2 = Layer_Normalization(d_model)
    self.dropout2 = nn.Dropout(drop_rate)

    self.ffn = Feed_Forward_Network(d_model, ffn_hidden, drop_rate)
    self.layer_norm3 = Layer_Normalization(d_model)
    self.dropout3 = nn.Dropout(drop_rate)

  def forward(self, x_dec, x_enc, trg_mask, src_mask):
    """Run the decoder sub-layers; output has the same shape as x_dec.

    Args:
      x_dec: decoder input representations.
      x_enc: encoder output, or None to skip encoder-decoder attention.
      trg_mask: mask for decoder self-attention.
      src_mask: mask for encoder-decoder attention.
    """
    # masked self-attention sub-layer
    residual = x_dec
    x = self.self_attention(x_dec, x_dec, x_dec, trg_mask)
    x = self.layer_norm1(self.dropout1(x) + residual)

    # encoder-decoder attention sub-layer: residual is this sub-layer's own input
    if x_enc is not None:
      residual = x
      x = self.enc_dec_attention(Q=x, K=x_enc, V=x_enc, mask=src_mask)
      x = self.layer_norm2(self.dropout2(x) + residual)

    # feed-forward sub-layer
    residual = x
    x = self.ffn(x)
    x = self.layer_norm3(self.dropout3(x) + residual)
    return x

In [None]:
# Decoder

class Decoder(nn.Module):
  """Embedding, a stack of ``n_layers`` decoder layers, and a vocabulary projection.

  Args:
    dec_voc_size: target vocabulary size (also the output logit width).
    max_len: longest supported sequence length.
    d_model: model width.
    ffn_hidden: inner width of each feed-forward network.
    n_head: number of attention heads.
    n_layers: number of stacked decoder layers.
    drop_rate: dropout probability.
    device: accepted for interface compatibility; not used here.
  """
  def __init__(self, dec_voc_size, max_len, d_model,
               ffn_hidden, n_head, n_layers, drop_rate, device):
    super(Decoder, self).__init__()
    self.embedding = Transformer_Embedding(dec_voc_size, max_len, d_model, drop_rate)
    self.layers = nn.ModuleList([Decoder_Layer(d_model,
                                               ffn_hidden, n_head, drop_rate) for _ in range(n_layers)])
    self.linear = nn.Linear(d_model, dec_voc_size)

  def forward(self, x_dec, x_enc, trg_mask, src_mask):
    """Return unnormalized logits over the target vocabulary (last dim = dec_voc_size)."""
    x_dec = self.embedding(x_dec)
    for layer in self.layers:
      x_dec = layer(x_dec, x_enc, trg_mask, src_mask)

    # Project the final decoder states (x_dec) to vocabulary logits.
    out = self.linear(x_dec)
    return out

### Transformer

In [None]:
# Transformer

class Transformer(nn.Module):
    """Encoder-decoder Transformer that builds its own padding / causal masks.

    Args:
        src_pad_idx, trg_pad_idx: padding token ids of source / target.
        trg_sos_idx: start-of-sequence id of the target (stored; unused here).
        enc_voc_size, dec_voc_size: source / target vocabulary sizes.
        d_model, n_head, max_len, ffn_hidden, n_layers: model hyper-parameters.
        drop_prob: dropout probability (forwarded as ``drop_rate``).
        device: stored on the module and forwarded to Encoder/Decoder.
    """

    def __init__(self, src_pad_idx, trg_pad_idx, trg_sos_idx, enc_voc_size, dec_voc_size, d_model, n_head, max_len,
                 ffn_hidden, n_layers, drop_prob, device):
        super().__init__()
        self.src_pad_idx = src_pad_idx
        self.trg_pad_idx = trg_pad_idx
        self.trg_sos_idx = trg_sos_idx
        self.device = device
        # Encoder/Decoder name their dropout argument ``drop_rate``.
        self.encoder = Encoder(d_model=d_model,
                               n_head=n_head,
                               max_len=max_len,
                               ffn_hidden=ffn_hidden,
                               enc_voc_size=enc_voc_size,
                               drop_rate=drop_prob,
                               n_layers=n_layers,
                               device=device)

        self.decoder = Decoder(d_model=d_model,
                               n_head=n_head,
                               max_len=max_len,
                               ffn_hidden=ffn_hidden,
                               dec_voc_size=dec_voc_size,
                               drop_rate=drop_prob,
                               n_layers=n_layers,
                               device=device)

    def forward(self, src, trg):
        """Encode ``src`` and decode ``trg``; returns the decoder's output logits."""
        src_mask = self.make_pad_mask(src, src, self.src_pad_idx, self.src_pad_idx)

        src_trg_mask = self.make_pad_mask(trg, src, self.trg_pad_idx, self.src_pad_idx)

        # Target self-attention may only see non-pad, non-future positions.
        trg_mask = self.make_pad_mask(trg, trg, self.trg_pad_idx, self.trg_pad_idx) & \
                   self.make_no_peak_mask(trg, trg)

        enc_src = self.encoder(src, src_mask)
        output = self.decoder(trg, enc_src, trg_mask, src_trg_mask)
        return output

    def make_pad_mask(self, q, k, q_pad_idx, k_pad_idx):
        """Return a bool mask of shape (batch, 1, len_q, len_k).

        An entry is True only when both the query token and the key token are
        not padding. ``q`` and ``k`` are (batch, len) token-id tensors.
        """
        # (batch, 1, 1, len_k): True where the key is a real token
        k_valid = k.ne(k_pad_idx).unsqueeze(1).unsqueeze(2)
        # (batch, 1, len_q, 1): True where the query is a real token
        q_valid = q.ne(q_pad_idx).unsqueeze(1).unsqueeze(3)
        # Broadcasting yields (batch, 1, len_q, len_k) without materializing repeats.
        return k_valid & q_valid

    def make_no_peak_mask(self, q, k):
        """Return a (len_q, len_k) lower-triangular bool mask on ``q``'s device.

        Entry [i, j] is True when j <= i, i.e. a position cannot attend to
        later positions.
        """
        len_q, len_k = q.size(1), k.size(1)
        return torch.tril(torch.ones(len_q, len_k, device=q.device)).bool()

In [None]:
# Quick check of Tensor.ne as used in make_pad_mask: the result is True
# wherever an element differs from the pad id (1 here) and False where it matches.
k = torch.arange(1, 5).view(2, 2)
idx = torch.ones_like(k)
k != idx

tensor([[False,  True],
        [ True,  True]])