In [4]:
import torch
import torch.nn as nn
import torch.optim as optim
import pprint

In [8]:
from functions import *

# Torch Module

In [38]:
# self-attention module
class SelfAttention(nn.Module):
    """Causal (masked) multi-head self-attention.

    Every head owns full-width query/key/value projections
    (``embed_size -> embed_size``, no bias). The ``head_count`` head outputs
    are concatenated per token and mixed back down to ``embed_size`` by
    ``fc_out``.

    Args:
        embed_size: size of each token embedding (last dim of the input).
        head_count: number of attention heads.
    """

    def __init__(self, embed_size, head_count):
        super(SelfAttention, self).__init__()
        self.embed_size = embed_size
        self.head_count = head_count  # attention heads
        # one linear projection per head for queries, keys and values
        self.query_layers = nn.ModuleList(
            [nn.Linear(embed_size, embed_size, bias=False) for _ in range(head_count)]
        )
        self.key_layers = nn.ModuleList(
            [nn.Linear(embed_size, embed_size, bias=False) for _ in range(head_count)]
        )
        self.value_layers = nn.ModuleList(
            [nn.Linear(embed_size, embed_size, bias=False) for _ in range(head_count)]
        )
        # final layer that combines the concatenated head outputs
        self.fc_out = nn.Linear(head_count * embed_size, embed_size)

    def forward(self, embeddings):
        """Apply causal multi-head self-attention.

        Args:
            embeddings: float tensor of shape (batch, tokens, embed_size).

        Returns:
            Tensor of shape (batch, tokens, embed_size). Output position ``i``
            depends only on input positions ``0..i``.
        """
        batch_size, token_count = embeddings.shape[:2]

        # per-head projections, stacked to (heads, batch, tokens, embed)
        queries = torch.stack([layer(embeddings) for layer in self.query_layers])
        keys = torch.stack([layer(embeddings) for layer in self.key_layers])
        values = torch.stack([layer(embeddings) for layer in self.value_layers])

        # scaled dot-product energy for every head, batch and token pair:
        # (heads, batch, tokens, tokens). Scaling by sqrt(d) keeps the softmax
        # from saturating at large embedding sizes.
        energy = torch.matmul(queries, keys.transpose(-2, -1)) / (self.embed_size ** 0.5)

        # True strictly above the diagonal: token i must not attend to j > i.
        # Built on the input's device so no global `device` is needed.
        mask = torch.triu(
            torch.ones(token_count, token_count, dtype=torch.bool, device=embeddings.device),
            diagonal=1,
        )
        energy = energy.masked_fill(mask, float('-inf'))

        # masked attention weights over the key axis
        attention = torch.softmax(energy, dim=-1)

        # weighted sum of value vectors: (heads, batch, tokens, embed)
        out = torch.matmul(attention, values)
        # concatenate heads per token: (batch, tokens, heads * embed)
        out = out.permute(1, 2, 0, 3).reshape(
            batch_size, token_count, self.head_count * self.embed_size
        )
        return self.fc_out(out)

class TransformerBlock(nn.Module):
    """Post-norm transformer block.

    Causal self-attention is followed by a position-wise feed-forward network.
    Each sub-layer is wrapped in a residual connection followed by LayerNorm.
    The output keeps the input's (batch, tokens, embed_size) shape, so blocks
    can be stacked.

    Args:
        embed_size: size of each token embedding.
        head_count: number of attention heads in ``SelfAttention``.
    """

    def __init__(self, embed_size, head_count):
        super(TransformerBlock, self).__init__()
        self.attention = SelfAttention(embed_size, head_count)  # self attention
        self.norm1 = nn.LayerNorm(embed_size)
        self.norm2 = nn.LayerNorm(embed_size)

        # position-wise feed-forward network
        self.feed_forward = nn.Sequential(
            nn.Linear(embed_size, embed_size),
            nn.ReLU(),
            nn.Linear(embed_size, embed_size),
        )

    def forward(self, embeddings):
        """Apply attention and feed-forward sub-layers with residual + LayerNorm."""
        # residual around attention: x + Attention(x), then normalize
        out = self.norm1(embeddings + self.attention(embeddings))
        # residual around the feed-forward sub-layer must add the sub-layer's
        # own (normalized) input, not the raw attention output
        return self.norm2(out + self.feed_forward(out))

class Transformer(nn.Module):
    """Decoder-style model that scores the token following each input sequence.

    Token embeddings plus sinusoidal positional encodings are passed through
    ``num_layers`` stacked ``TransformerBlock``s. The hidden state at the last
    position is then projected to vocabulary logits.

    Args:
        vocab_size: number of distinct token ids.
        embed_size: size of the word embeddings / hidden states.
        num_layers: number of stacked transformer blocks.
        head_count: attention heads per block.
    """

    def __init__(self, vocab_size, embed_size, num_layers, head_count):
        super(Transformer, self).__init__()
        self.embed_size = embed_size  # size of your word embeddings
        self.vocab_size = vocab_size  # size of your vocab
        self.word_embedding = nn.Embedding(vocab_size, embed_size)

        # transformer blocks
        self.layers = nn.ModuleList(
            [TransformerBlock(embed_size, head_count) for _ in range(num_layers)]
        )
        self.fc_out = nn.Linear(embed_size, vocab_size)  # final linear layer to produce logits

    def forward(self, input_tokens, mask=None):
        """Return next-token logits of shape (batch, vocab_size).

        Args:
            input_tokens: integer token ids, expected shape (batch, tokens).
            mask: unused; kept for interface compatibility. Causal masking is
                done inside the attention blocks.

        Returns:
            Unnormalized logits for the final position. This is what
            ``nn.CrossEntropyLoss`` expects; apply ``torch.softmax(..., dim=1)``
            to obtain probabilities.
        """
        batch_size, token_count = input_tokens.shape[:2]
        out = self.word_embedding(input_tokens)  # word embedding

        # positional encoding, built on the same device as the input tokens
        positions = torch.arange(token_count, device=input_tokens.device).expand(
            batch_size, token_count
        )
        # out-of-place add: shapes already match (batch, tokens, embed_size)
        out = out + self.positional_encoding(positions, self.embed_size)

        # pass through transformer blocks
        for layer in self.layers:
            out = layer(out)

        # logits for the final token of each sequence: (batch, vocab_size)
        return self.fc_out(out[:, -1, :])

    def positional_encoding(self, positions, embed_size):
        """Sinusoidal positional encodings.

        Args:
            positions: integer positions of shape (batch, tokens).
            embed_size: encoding width.

        Returns:
            Float tensor of shape (batch, tokens, embed_size). Along the last
            dim it holds the sines of the even-indexed angles followed by the
            cosines of the odd-indexed angles.
        """
        angle_rads = self.get_angles(
            positions.unsqueeze(2).float(),
            torch.arange(embed_size, device=positions.device)[None, None, :].float(),
            embed_size,
        )
        sines = torch.sin(angle_rads[:, :, 0::2])
        cosines = torch.cos(angle_rads[:, :, 1::2])
        # concatenate along the embedding dim (dim=1 would append along tokens)
        return torch.cat([sines, cosines], dim=-1)

    def get_angles(self, pos, i, embed_size):
        """Return ``pos / 10000 ** (2 * (i // 2) / embed_size)``, broadcast over pos and i."""
        angle_rates = 1 / torch.pow(10000, (2 * (i // 2)) / embed_size)
        return pos * angle_rates

# Configuration

In [39]:
# Pick the GPU when PyTorch can see one; otherwise run on the CPU.
if torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')
device

device(type='cpu')

In [44]:
######################################################################
# Define the model hyper-parameters and instantiate the model. Then set up
# the cross-entropy loss and the Adam optimizer used for training.
#
torch.manual_seed(0)

VOCAB_SIZE = 151  # vocab.num_words + 1
EMB_SIZE = 512
NHEAD = 8
# NOTE(review): FFN_HID_DIM used to be passed *positionally* into Transformer's
# `num_layers` slot, so the model was built with 32 blocks (hence the ~285M
# trainable-parameter count printed below). NUM_LAYERS keeps that depth
# explicit — confirm the intended number of layers.
NUM_LAYERS = 32
FFN_HID_DIM = 32  # not passed to Transformer: its constructor takes no feed-forward width
BATCH_SIZE = 128

# keyword arguments so each hyper-parameter lands in the intended slot
model = Transformer(
    vocab_size=VOCAB_SIZE,
    embed_size=EMB_SIZE,
    num_layers=NUM_LAYERS,
    head_count=NHEAD,
)

# Xavier-uniform init for every parameter with more than one dimension
for p in model.parameters():
    if p.dim() > 1:
        nn.init.xavier_uniform_(p)

model = model.to(device)
# NOTE(review): ignore_index=0 assumes token id 0 is padding — confirm against the vocab
criterion = torch.nn.CrossEntropyLoss(ignore_index=0)

optimizer = torch.optim.Adam(model.parameters(), lr=0.0001, betas=(0.9, 0.98), eps=1e-9)

In [45]:
def count_parameters(model):
    """Return the total number of scalar values in `model`'s trainable parameters.

    Only parameters with ``requires_grad`` set are counted.
    """
    trainable = filter(lambda param: param.requires_grad, model.parameters())
    return sum(param.numel() for param in trainable)

# Report the trainable-parameter count of the `model` instantiated in the configuration cell.
print(f'The model has {count_parameters(model):,} trainable parameters')

The model has 285,482,135 trainable parameters
