In [1]:
import torch
from torch import nn

## Create Encoder

`Embedding` and `Positional Encoding`

In [134]:
def create_positional_encoding(max_length, d_model, base=10000.0):
    """Build the fixed sinusoidal positional-encoding table.

    Column 2i holds sin(pos / base**(2i / d_model)) and column 2i+1 holds the
    matching cos, for every position pos in [0, max_length).

    Args:
        max_length: number of positions (rows) in the table.
        d_model: encoding width; must be even so sin/cos columns pair up.
        base: wavelength base of the geometric progression (10000 by default).

    Returns:
        float32 tensor of shape (max_length, d_model).
    """
    assert d_model % 2 == 0, "Dimension model must be even"

    pos = torch.arange(max_length, dtype=torch.float32).unsqueeze(1)  # (max_length, 1)
    power = torch.arange(0, d_model, 2, dtype=torch.float32) / d_model  # (d_model // 2,)
    div_term = torch.pow(base, power)  # (d_model // 2,)
    # Broadcasting (max_length, 1) / (d_model // 2,) -> (max_length, d_model // 2);
    # no need to materialise repeated copies of either operand.
    angles = pos / div_term

    pe = torch.zeros(max_length, d_model)  # (max_length, d_model)
    pe[:, 0::2] = torch.sin(angles)
    pe[:, 1::2] = torch.cos(angles)
    return pe

class Embedding(nn.Module):
  """Token embedding followed by additive sinusoidal positional encoding.

  Args:
    vocab_size: number of rows in the token-embedding table.
    max_length: number of positions covered by the positional table; longer
      inputs are rejected in forward.
    d_model: embedding width (must be even for create_positional_encoding).
  """

  def __init__(self, vocab_size, max_length, d_model):
    super(Embedding, self).__init__()
    self.embedding = nn.Embedding(num_embeddings=vocab_size, embedding_dim=d_model)
    # Registered as a buffer so `.to(device)` / `.half()` move it together with the
    # parameters (a plain attribute stays on CPU and breaks the add on GPU).
    # persistent=False keeps it out of the state_dict, so the saved keys are unchanged.
    self.register_buffer(
        "pos_encoding",
        create_positional_encoding(max_length, d_model),  # (max_length, d_model)
        persistent=False,
    )

  def forward(self, x):
    """ Apply embedding and positional encoding to the input

    Input:
      x: (N, seq_length) token ids, seq_length <= max_length
    Output:
      x: (N, seq_length, d_model)
    Raises:
      ValueError: if seq_length exceeds the positional table's max_length.
    """
    seq_length = x.size(1)
    max_length = self.pos_encoding.size(0)
    if seq_length > max_length:
      raise ValueError(
          f"sequence length {seq_length} exceeds max_length {max_length}")
    # Out-of-place add of the first seq_length rows, broadcast over the batch.
    return self.embedding(x) + self.pos_encoding[:seq_length]

`SelfMultiHeadAttention` and `Add&Norm`

In [135]:
class SelfAttention(nn.Module):
  """Multi-head self-attention followed by a residual connection and LayerNorm.

  Computes LayerNorm(x + MHA(x, x, x)).

  Args:
    d_model: model width (embed_dim of the attention); must be divisible by num_heads.
    num_heads: number of attention heads.
    dropout: dropout probability on the attention weights (nn.MultiheadAttention).
  """

  def __init__(self, d_model, num_heads, dropout):
    super(SelfAttention, self).__init__()
    self.mha = nn.MultiheadAttention(embed_dim=d_model, num_heads=num_heads, dropout=dropout, batch_first=True)
    self.layer_norm = nn.LayerNorm(d_model)

  def forward(self, x, key_padding_mask=None):
    """
    Input:
      x: (N, seq_length, d_model)
      key_padding_mask: optional bool tensor (N, seq_length); True marks key
        positions (e.g. padding) that must not be attended to.
    Output:
      (N, seq_length, d_model)
    """
    # need_weights=False: the attention map is discarded anyway; skipping it avoids
    # computing/averaging it and lets PyTorch use its fused attention kernels.
    attn_output, _ = self.mha(query=x, key=x, value=x,
                              key_padding_mask=key_padding_mask, need_weights=False)
    return self.layer_norm(x + attn_output)

`FeedForward` and `Add&Norm`

In [136]:
class FeedForward(nn.Module):
  """Position-wise feed-forward sub-layer with residual connection and LayerNorm.

  Computes LayerNorm(x + Dropout(Linear(ReLU(Linear(x))))).

  Args:
    d_model: input/output width.
    dropout: dropout probability applied to the sub-layer output before the residual add.
    d_ff: hidden width; defaults to 2 * d_model (the previously hard-coded value).
  """

  def __init__(self, d_model, dropout, d_ff=None):
    super(FeedForward, self).__init__()
    if d_ff is None:
      d_ff = 2 * d_model
    # nn.Sequential registers children under the same "0".."3" names a ModuleList
    # does, so state_dict keys (seq.0.*, seq.2.*) are unchanged.
    self.seq = nn.Sequential(
        nn.Linear(d_model, d_ff),
        nn.ReLU(),
        nn.Linear(d_ff, d_model),
        nn.Dropout(dropout),
    )

    self.layernorm = nn.LayerNorm(d_model)

  def forward(self, x):
    """x: (..., d_model) -> (..., d_model)"""
    return self.layernorm(self.seq(x) + x)


In [137]:
class Encoder_Layer(nn.Module):
  """One encoder block: a self-attention sub-layer, then a feed-forward sub-layer.

  Both sub-layers are built with the same d_model / dropout; num_heads only
  affects the attention.
  """

  def __init__(self, d_model, num_heads, dropout):
    super().__init__()
    self.self_attention = SelfAttention(d_model, num_heads, dropout)
    self.ff = FeedForward(d_model, dropout)

  def forward(self, x):
    # attention first, feed-forward on its result
    return self.ff(self.self_attention(x))

In [138]:
class Encoder(nn.Module):
  """Token + positional embedding followed by `num_layers` stacked encoder layers.

  forward: (N, src_length) token ids -> (N, src_length, d_model)
  """

  def __init__(self, vocab_size, max_length, d_model, num_heads, num_layers, dropout):
    super().__init__()
    self.embedding = Embedding(vocab_size, max_length, d_model)
    stack = [Encoder_Layer(d_model, num_heads, dropout) for _ in range(num_layers)]
    self.encoder_layers = nn.ModuleList(stack)

  def forward(self, x):
    hidden = self.embedding(x)
    for layer in self.encoder_layers:
      hidden = layer(hidden)
    return hidden

## Create Decoder

`CausalMultiHeadAttention` and `Add&Norm`

In [149]:
# Additive causal mask of size mask_size x mask_size: -inf strictly above the
# diagonal (row i may not attend to column j > i), 0 on and below it.
# Built with one vectorised call instead of a Python double loop.
mask_size = 128
causal_mask = torch.triu(torch.full((mask_size, mask_size), float('-inf')), diagonal=1)


class CausalSelfAttention(nn.Module):
  """Causal (masked) multi-head self-attention followed by residual add and LayerNorm.

  Position i may only attend to positions <= i. The mask is built on every call
  for the actual sequence length and on the input's device, so any sequence
  length and any device work (a fixed global 128x128 CPU mask did not).

  Args:
    d_model: model width; must be divisible by num_heads.
    num_heads: number of attention heads.
    dropout: dropout probability on the attention weights.
  """

  def __init__(self, d_model, num_heads, dropout):
    super(CausalSelfAttention, self).__init__()
    self.mha = nn.MultiheadAttention(embed_dim=d_model, num_heads=num_heads, dropout=dropout, batch_first=True)
    self.layer_norm = nn.LayerNorm(d_model)

  def forward(self, x, key_padding_mask=None):
    """
    Input:
      x: (N, seq_length, d_model)
      key_padding_mask: optional bool tensor (N, seq_length); True marks key
        positions (e.g. padding) that must not be attended to.
    Output:
      (N, seq_length, d_model)
    """
    seq_length = x.size(1)
    # Boolean mask: True strictly above the diagonal = "future position, not allowed".
    attn_mask = torch.triu(
        torch.ones(seq_length, seq_length, dtype=torch.bool, device=x.device),
        diagonal=1,
    )
    attn_output, _ = self.mha(
        query=x, key=x, value=x,
        attn_mask=attn_mask, is_causal=True,
        key_padding_mask=key_padding_mask, need_weights=False,
    )
    return self.layer_norm(x + attn_output)

`CrossMultiHeadAttention` and `Add&Norm`

In [140]:
class CrossAttention(nn.Module):
  """Multi-head cross-attention followed by a residual connection and LayerNorm.

  Queries come from `x`; keys and values come from `image_embedding` (in this
  notebook DecoderLayer passes the encoder output here).

  Args:
    d_model: model width; must be divisible by num_heads.
    num_heads: number of attention heads.
    dropout: dropout probability on the attention weights.
  """

  def __init__(self, d_model, num_heads, dropout):
    super(CrossAttention, self).__init__()
    self.mha = nn.MultiheadAttention(embed_dim=d_model, num_heads=num_heads, dropout=dropout, batch_first=True)
    self.layer_norm = nn.LayerNorm(d_model)

  def forward(self, x, image_embedding, key_padding_mask=None):
    """
    Input:
      x: (N, tgt_length, d_model) queries
      image_embedding: (N, src_length, d_model) keys/values
      key_padding_mask: optional bool tensor (N, src_length); True marks key
        positions (e.g. source padding) that must not be attended to.
    Output:
      (N, tgt_length, d_model)
    """
    # need_weights=False: the attention map is discarded, so don't compute it.
    attn_output, _ = self.mha(query=x, key=image_embedding, value=image_embedding,
                              key_padding_mask=key_padding_mask, need_weights=False)
    return self.layer_norm(x + attn_output)


`FeedForward` and `Add&Norm` (same definition as in the encoder section; this cell redefines it)

In [141]:
class FeedForward(nn.Module):
  """Position-wise feed-forward sub-layer with residual connection and LayerNorm.

  Computes LayerNorm(x + Dropout(Linear(ReLU(Linear(x))))).

  NOTE(review): this cell redefines FeedForward from the encoder section; the later
  definition shadows the earlier one, so the two must stay identical (or this cell
  can be removed).

  Args:
    d_model: input/output width.
    dropout: dropout probability applied to the sub-layer output before the residual add.
    d_ff: hidden width; defaults to 2 * d_model (the previously hard-coded value).
  """

  def __init__(self, d_model, dropout, d_ff=None):
    super(FeedForward, self).__init__()
    if d_ff is None:
      d_ff = 2 * d_model
    # nn.Sequential registers children under the same "0".."3" names a ModuleList
    # does, so state_dict keys (seq.0.*, seq.2.*) are unchanged.
    self.seq = nn.Sequential(
        nn.Linear(d_model, d_ff),
        nn.ReLU(),
        nn.Linear(d_ff, d_model),
        nn.Dropout(dropout),
    )

    self.layernorm = nn.LayerNorm(d_model)

  def forward(self, x):
    """x: (..., d_model) -> (..., d_model)"""
    return self.layernorm(self.seq(x) + x)


In [142]:
class DecoderLayer(nn.Module):
  """One decoder block: causal self-attention, cross-attention, then feed-forward.

  forward(x, encoder_output): x is the decoder stream, encoder_output is what
  the cross-attention attends over.
  """

  def __init__(self, d_model, num_heads, dropout):
    super().__init__()
    # The attribute keeps the original "casual" spelling: it is part of the
    # state_dict keys (casual_attention.*), so renaming it would break checkpoints.
    self.casual_attention = CausalSelfAttention(d_model, num_heads, dropout)
    self.cross_attention = CrossAttention(d_model, num_heads, dropout)
    self.ff = FeedForward(d_model, dropout)

  def forward(self, x, encoder_output):
    hidden = self.casual_attention(x)
    hidden = self.cross_attention(hidden, encoder_output)
    return self.ff(hidden)

In [143]:
class Decoder(nn.Module):
  """Decoder: embeddings, a stack of DecoderLayers, and a projection to vocabulary logits."""

  def __init__(self, vocab_size, max_length, d_model, num_heads, num_layers, dropout):
    super(Decoder, self).__init__()
    self.decoder_embedding = Embedding(vocab_size, max_length, d_model)
    self.decoder_layers = nn.ModuleList(
        [DecoderLayer(d_model, num_heads, dropout) for _ in range(num_layers)]
    )
    self.last_layer = nn.Linear(d_model, vocab_size)

  def forward(self, deocder_input, encoder_output):
    '''Decode and return unnormalised logits (no softmax is applied).

    deocder_input: (N, tgt_length) token ids; tgt_length must not exceed max_length.
      (The misspelled parameter name is kept for keyword callers.)
    encoder_output: (N, src_length, d_model)
    returns: (N, tgt_length, vocab_size)
    '''
    hidden = self.decoder_embedding(deocder_input)  # (N, tgt_length, d_model)
    for layer in self.decoder_layers:
      hidden = layer(hidden, encoder_output)
    return self.last_layer(hidden)

## Create Transformer

In [145]:
class TransFormer(nn.Module):
  """Encoder-decoder Transformer; source and target share one vocabulary size.

  forward(encoder_input, decoder_input):
    encoder_input: (N, src_length) source token ids
    decoder_input: (N, tgt_length) target token ids
    returns: (N, tgt_length, vocab_size) logits

  NOTE(review): the default dropout=0.5 is unusually high for a Transformer;
  confirm it is intentional.
  """

  def __init__(self, vocab_size=5000, max_length=128, d_model=512, num_heads=8, num_layers=4, dropout=0.5):
    super().__init__()
    shared = dict(vocab_size=vocab_size, max_length=max_length, d_model=d_model,
                  num_heads=num_heads, num_layers=num_layers, dropout=dropout)
    self.encoder = Encoder(**shared)
    self.decoder = Decoder(**shared)

  def forward(self, encoder_input, decoder_input):
    memory = self.encoder(encoder_input)
    return self.decoder(decoder_input, memory)

In [150]:
# Smoke test: push random token ids through an untrained model and check the logits' shape.
torch.manual_seed(0)  # make the random weights and inputs reproducible

BATCH_SIZE = 32
SEQ_LENGTH = 128   # must not exceed the model's max_length
VOCAB_SIZE = 5000

decoder_input = torch.randint(0, VOCAB_SIZE, size=(BATCH_SIZE, SEQ_LENGTH))
encoder_input = torch.randint(0, VOCAB_SIZE, size=(BATCH_SIZE, SEQ_LENGTH))
transformer = TransFormer(vocab_size=VOCAB_SIZE, max_length=SEQ_LENGTH)
with torch.no_grad():  # forward-only check; no autograd graph needed
    output = transformer(encoder_input, decoder_input)

assert output.shape == (BATCH_SIZE, SEQ_LENGTH, VOCAB_SIZE)
print(output.shape)

torch.Size([32, 128, 5000])
