# CONSTANTS

In [1]:
import math

GLOBAL = {
    'embedded_length': 512,
    'continuity_length': 64,
    'n_vocab': 29
}

# Transformer (Top level)
TRANS_CONST = {
    'n_attention_layers': 8,
    'n_attention_heads': 8,    

    'max_output_length': 6,
    
    'embedding_dic_size': GLOBAL['n_vocab'], 
    'embedded_vec_size': GLOBAL['embedded_length'],
    
    # 'pos_encoding_input': GLOBAL['embedded_length'],
    # 'pos_encoding_output': GLOBAL['embedded_length'],
    
    'linear_input': GLOBAL['embedded_length'],
    'linear_output': GLOBAL['n_vocab'] # output vocab size
}

# Encoder, EncoderLayer
ENCODER_CONST = {
    'norm1_size': GLOBAL['embedded_length'], # same as input matrix width
    'norm2_size': GLOBAL['embedded_length'],

    # maybe rename these two, it's just for knowing the input dim and the dim that the FF layer will work with
    'ff1': GLOBAL['embedded_length'], 
    'ff2': GLOBAL['embedded_length'] * 4
}

# Decoder, DecoderLayer
DECODER_CONST = {
    'norm1_size': GLOBAL['embedded_length'], # same as input matrix width
    'norm2_size': GLOBAL['embedded_length'],
    'norm3_size': GLOBAL['embedded_length'],

    'ff1': GLOBAL['embedded_length'],#TODO RENAME
    'ff2': GLOBAL['embedded_length'] * 4#TODO RENAME
}

# MultiHeadAttention, SingleHeadAttention
ATTENTION_CONST = {
    'mh_concat_width': GLOBAL['continuity_length']*TRANS_CONST['n_attention_heads'], # single head attention width * number of heads
    'mh_output_width': GLOBAL['embedded_length'], #TODO - I'm just guessing this. Didn't see in illustrated transformer. Since we have to use this for the add & norm layer though it has to be the same as the input width (I think)

    # W_q weight matrix 
    'sh_linear1_input': GLOBAL['embedded_length'], # same as embedded length to end up with n_words x 64
    'sh_linear1_output': GLOBAL['continuity_length'], # specified in the paper
    # W_k weight matrix 
    'sh_linear2_input': GLOBAL['embedded_length'], # same as embedded length to end up with n_words x 64
    'sh_linear2_output': GLOBAL['continuity_length'], # specified in the paper
    # W_v weight matrix 
    'sh_linear3_input': GLOBAL['embedded_length'], # same as embedded length to end up with n_words x 64
    'sh_linear3_output': GLOBAL['continuity_length'], # specified in the paper
    
    'sh_scale_factor': math.sqrt(GLOBAL['continuity_length']) # specified in the paper, square root of dimension of key vector/matrix (64)
}

# FeedForward
FEEDFORWARD_CONST = {
    'dropout': 0.1
}



# FEEDFORWARD

In [2]:

import torch.nn as nn

class FeedForward(nn.Module):
    def __init__(self, dim_model, dim_ff, dropout=FEEDFORWARD_CONST['dropout']):
        super(FeedForward, self).__init__()
        self.linear1 = nn.Linear(dim_model, dim_ff)
        self.dropout = nn.Dropout(dropout)
        self.linear2 = nn.Linear(dim_ff, dim_model)

    def forward(self, inputs):
        x = inputs
        x = self.linear1(x)
        x = nn.functional.relu(x) 
        x = self.dropout(x) 
        x = self.linear2(x) 
        return x


# MULTIHEADATTENTION

In [3]:

import torch
import torch.nn as nn

class MultiHeadAttention(nn.Module):
    def __init__(self, n_heads, masked=False):
        super(MultiHeadAttention, self).__init__()

        self.masked = masked
        self.attentionHeads = nn.ModuleList([SingleHeadAttention(masked) for _ in range(n_heads)])
        self.linear = nn.Linear(ATTENTION_CONST['mh_concat_width'], ATTENTION_CONST['mh_output_width'])
        self.lastHeadKV = None

    def forward(self, inputs, encoderKV=None):
        x = []
        for head in self.attentionHeads:
            sh_attention, k, v = head(inputs, encoderKV=encoderKV) 
            x.append(sh_attention)
        self.lastHeadKV = {'K': k,'V': v}
        x = torch.cat(x, 1) # concatinate all single head attention outputs
        x = self.linear(x) # matmul with weight matrix (linear layer) to get 10x64 shape
        return x

class SingleHeadAttention(nn.Module):
    def __init__(self, masked):
        super(SingleHeadAttention, self).__init__()
        self.masked = masked
        self.linear1 = nn.Linear(ATTENTION_CONST['sh_linear1_input'], ATTENTION_CONST['sh_linear1_output'])
        self.linear2 = nn.Linear(ATTENTION_CONST['sh_linear2_input'], ATTENTION_CONST['sh_linear2_output'])
        self.linear3 = nn.Linear(ATTENTION_CONST['sh_linear3_input'], ATTENTION_CONST['sh_linear3_output'])
        self.scale = nn.Parameter(torch.FloatTensor([ATTENTION_CONST['sh_scale_factor']]))
        self.softmax = nn.Softmax(dim=1)

    def forward(self, inputs, encoderKV=None):        
        q = self.linear1(inputs)
        k = self.linear2(inputs) if encoderKV == None else encoderKV['K']
        v = self.linear3(inputs) if encoderKV == None else encoderKV['V']
        x = torch.matmul(q, k.permute(1, 0)) 
        x = x * self.scale
        # if self.masked:
        #     # TODO "future positions" have to be set to -inf. this is for the decoder to only allow self attention to consider earlier positions.
        x = self.softmax(x) 
        x = torch.matmul(x, v)
        return x if encoderKV != None else x, k, v




# DECODER

In [4]:

import torch.nn as nn

class Decoder(nn.Module):
    def __init__(self, n_layers, n_attention_heads):
        super(Decoder, self).__init__()
        
        self.decoderLayers = nn.ModuleList([DecoderLayer(n_attention_heads) for _ in range(n_layers)])

    def forward(self, inputs, encoderKV):
        x = inputs
        for layer in self.decoderLayers:
            x = layer(x, encoderKV) 
        return x

class DecoderLayer(nn.Module):
    def __init__(self, n_attention_heads):
        super(DecoderLayer, self).__init__()

        self.mhattention_masked = MultiHeadAttention(n_attention_heads, masked=True)
        self.mhattention = MultiHeadAttention(n_attention_heads)
        self.feedforward = FeedForward(DECODER_CONST['ff1'], DECODER_CONST['ff2'])
        self.norm1 = nn.LayerNorm(DECODER_CONST['norm1_size'])
        self.norm2 = nn.LayerNorm(DECODER_CONST['norm2_size'])
        self.norm3 = nn.LayerNorm(DECODER_CONST['norm3_size'])

    def forward(self, inputs, encoderKV):
        x = inputs
        z = x
        x = self.mhattention_masked(x) #TODO masking not implemented
        x = z + x
        x = self.norm1(x)
        z = x
        x = self.mhattention(x, encoderKV=encoderKV) 
        x = z + x
        x = self.norm2(x)
        z = x
        x = self.feedforward(x)
        x = z + x
        x = self.norm3(x)
        return x



# ENCODER

In [5]:

import torch.nn as nn

class Encoder(nn.Module):
    def __init__(self, n_layer, n_attention_heads):
        super(Encoder, self).__init__()
        
        self.encoderLayers = nn.ModuleList([EncoderLayer(n_attention_heads) for _ in range(n_layer)])

    def forward(self, inputs):
        x = inputs
        lastLayerKV = None
        for layer in self.encoderLayers:
            x = layer(x)
            lastLayerKV = layer.lastLayerKV
        return x, lastLayerKV

class EncoderLayer(nn.Module):
    def __init__(self, n_attention_heads):
        super(EncoderLayer, self).__init__()

        self.mhattention = MultiHeadAttention(n_attention_heads)
        self.feedforward = FeedForward(ENCODER_CONST['ff1'], ENCODER_CONST['ff2'])
        self.norm1 = nn.LayerNorm(ENCODER_CONST['norm1_size'])
        self.norm2 = nn.LayerNorm(ENCODER_CONST['norm2_size'])
        self.lastLayerKV = None

    def forward(self, inputs):
        x = inputs 
        z = x
        x = self.mhattention(x)
        self.lastLayerKV = self.mhattention.lastHeadKV
        x = z + x
        x = self.norm1(x)
        z = x
        x = self.feedforward(x)
        x = z + x
        x = self.norm2(x) 
        return x



# TRANSFORMER

In [6]:
import torch
import torch.nn as nn
import numpy

class Transformer(nn.Module):
    def __init__(self, n_layers=TRANS_CONST['n_attention_layers'], n_attention_heads=TRANS_CONST['n_attention_heads']):
        super(Transformer, self).__init__()

        self.encoder = Encoder(n_layers, n_attention_heads)
        self.decoder = Decoder(n_layers, n_attention_heads)
        self.embedding = nn.Embedding(TRANS_CONST['embedding_dic_size'], TRANS_CONST['embedded_vec_size'])
        # self.posEncoding = #TODO
        self.linear = nn.Linear(TRANS_CONST['linear_input'], TRANS_CONST['linear_output'])
        self.softmax = nn.Softmax(dim=1)

    def __call__(self, inputs=None):
        if inputs != None: 
            raise NotImplementedError

        import random
        inputs = []
        for _ in range(13): inputs.append(numpy.zeros(26)) # 26 is vocab size, should be constant; 13 is just a random amount of words in the sequence
        inputs = torch.Tensor(inputs)
        for i in inputs: i[random.randint(0, len(i) - 1)] = 1
        return self.forward(inputs.long())

    def forward(self, inputs):
        # TODO FIRST PRIO
            # get all inputs for embedding (real example for translation tasks etc, noise, decoder input) on the same format, which should be NxV
        #### ENCODING ####
        x = self.doEmbedding(inputs)
        # x = self.posEncoding(x)
        _, encoderKV = self.encoder(x) #TODO try running the encoder output trough 2 additional linear layers to make the KV matrices

        #### DECODING ####
        sos = numpy.zeros(GLOBAL['n_vocab'])
        sos[0] = 1
        x = torch.Tensor([sos])
        ## Embedding
        x = self.doEmbedding(x)
        # x = self.posEncoding(x) #TODO
        ## Decoding
        while len(x) < TRANS_CONST['max_output_length']: #TODO add eos token
            new_word = self.decoder(x, encoderKV)[0, :].unsqueeze(dim=0)
            x = torch.cat([x, new_word], dim=0)

        x = self.linear(x)
        x = self.softmax(x)
        
        return x

    def doEmbedding(self, inputs):
        inputs = inputs.nonzero()[:, 1] # this gets all indices of nonzero values from the inputs matrix
        return self.embedding(inputs)

# TRAIN

In [7]:
import torch 
import torch.nn as nn
import random
import numpy as np

EPOCHS = 200

transformer = Transformer()
optimizer = torch.optim.Adam(transformer.parameters(), lr=0.0002)
loss = torch.nn.BCELoss()

real_sample = torch.Tensor([
        [1,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0, 0, 0,0],
        [0,0,0,0,0,0,0,1,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0, 0, 0,0],
        [0,0,0,0,1,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0, 0, 0,0],
        [0,0,0,0,0,0,0,0,0,0,0,1,0,0,0,0,0,0,0,0,0,0,0,0,0,0, 0, 0,0],
        [0,0,0,0,0,0,0,0,0,0,0,1,0,0,0,0,0,0,0,0,0,0,0,0,0,0, 0, 0,0],
        [0,0,0,0,0,0,0,0,0,0,0,0,0,0,1,0,0,0,0,0,0,0,0,0,0,0, 0, 0,0]
])

for _ in range(EPOCHS):
    sample = transformer()        
    # target = torch.ones(sample.shape[0], sample.shape[1])
    error = loss(sample, real_sample)
    if _ % 10 == 0: print(error, sample)
    error.backward()
    optimizer.step()
    if(error > 8): break

tensor(0.1707, grad_fn=<BinaryCrossEntropyBackward>) tensor([[0.0118, 0.0626, 0.0102, 0.0486, 0.0378, 0.0515, 0.0258, 0.0468, 0.0447,
         0.0118, 0.0204, 0.0265, 0.0610, 0.0135, 0.0146, 0.0135, 0.0378, 0.0340,
         0.0438, 0.0410, 0.0609, 0.0580, 0.0265, 0.0343, 0.0340, 0.0637, 0.0295,
         0.0208, 0.0150],
        [0.0147, 0.0584, 0.0134, 0.0416, 0.0584, 0.0568, 0.0355, 0.0312, 0.0374,
         0.0214, 0.0105, 0.0078, 0.0984, 0.0256, 0.0185, 0.0169, 0.0394, 0.0255,
         0.0287, 0.0524, 0.0221, 0.0641, 0.0141, 0.0365, 0.0066, 0.0768, 0.0394,
         0.0199, 0.0281],
        [0.0232, 0.0434, 0.0179, 0.0626, 0.0617, 0.0347, 0.0377, 0.0201, 0.0474,
         0.0113, 0.0084, 0.0142, 0.0807, 0.0176, 0.0503, 0.0288, 0.0436, 0.0214,
         0.0144, 0.0286, 0.0391, 0.0300, 0.0300, 0.0316, 0.0098, 0.0641, 0.0597,
         0.0193, 0.0485],
        [0.0254, 0.0640, 0.0158, 0.0486, 0.0885, 0.0431, 0.0307, 0.0180, 0.0580,
         0.0230, 0.0179, 0.0110, 0.0939, 0.0125, 0.0282, 0.

tensor(0.0873, grad_fn=<BinaryCrossEntropyBackward>) tensor([[8.7451e-01, 2.5409e-03, 1.1243e-03, 1.3373e-03, 1.7430e-02, 9.6225e-04,
         1.0473e-03, 2.1078e-02, 2.8653e-03, 2.1577e-03, 2.5629e-03, 2.7009e-02,
         2.1927e-03, 2.2779e-03, 1.2005e-02, 1.4055e-03, 1.0650e-03, 2.1383e-03,
         2.1420e-03, 1.0852e-03, 3.3577e-03, 7.3908e-04, 1.5342e-03, 5.6726e-04,
         3.5073e-03, 3.6539e-03, 4.1558e-03, 1.5381e-03, 2.0094e-03],
        [1.1812e-04, 2.1182e-05, 8.9074e-06, 1.7073e-05, 6.0815e-02, 1.7166e-05,
         1.9178e-05, 5.0153e-02, 2.3949e-05, 1.8548e-05, 2.0692e-05, 8.2141e-01,
         1.4403e-05, 1.0579e-05, 6.7086e-02, 9.5734e-06, 1.5336e-05, 2.2610e-05,
         2.6098e-05, 3.4454e-05, 7.1902e-06, 1.2554e-05, 1.3592e-05, 2.3386e-05,
         2.6472e-05, 2.9304e-05, 7.5860e-06, 6.4949e-06, 1.3946e-05],
        [1.1652e-04, 2.0946e-05, 8.5940e-06, 1.6992e-05, 5.9126e-02, 1.7461e-05,
         1.8885e-05, 4.3832e-02, 2.4602e-05, 1.9038e-05, 1.9929e-05, 8.2686e-0

tensor(0.1217, grad_fn=<BinaryCrossEntropyBackward>) tensor([[9.9919e-01, 2.3933e-06, 2.7270e-06, 1.0744e-06, 1.2154e-04, 6.0953e-07,
         1.0880e-06, 1.3307e-04, 4.2593e-06, 9.6625e-06, 6.7053e-06, 2.9415e-04,
         1.9789e-06, 9.0255e-06, 1.7155e-04, 3.3746e-06, 8.4258e-07, 3.1012e-06,
         2.3063e-06, 8.6245e-07, 3.6835e-06, 3.7309e-07, 1.8516e-06, 3.3783e-07,
         7.5678e-06, 4.9567e-06, 1.3242e-05, 2.3389e-06, 6.2084e-06],
        [5.1319e-05, 1.6064e-06, 2.7838e-07, 6.5169e-07, 2.6068e-01, 1.2995e-06,
         4.6008e-07, 4.0768e-01, 5.6755e-07, 3.3657e-07, 4.0277e-07, 1.7061e-03,
         9.8713e-07, 4.0920e-07, 3.2986e-01, 2.4079e-07, 6.4926e-07, 5.0148e-07,
         7.1880e-07, 1.4080e-06, 4.2871e-07, 1.7273e-06, 3.3349e-07, 1.3085e-06,
         3.8778e-07, 1.0576e-06, 2.9067e-07, 3.1824e-07, 3.0290e-07],
        [5.1148e-05, 1.5834e-06, 2.8208e-07, 6.4680e-07, 2.5168e-01, 1.3194e-06,
         4.6224e-07, 4.2230e-01, 5.5749e-07, 3.3640e-07, 4.0999e-07, 1.6103e-0

tensor(0.0750, grad_fn=<BinaryCrossEntropyBackward>) tensor([[9.9999e-01, 3.5580e-09, 9.6656e-09, 1.3593e-09, 9.9490e-07, 6.1663e-10,
         1.7517e-09, 1.4577e-06, 9.8180e-09, 6.2084e-08, 2.5821e-08, 4.3765e-06,
         2.8074e-09, 5.2097e-08, 3.0605e-06, 1.1973e-08, 1.0677e-09, 6.9779e-09,
         3.8795e-09, 1.0901e-09, 6.2959e-09, 3.0164e-10, 3.4404e-09, 3.1973e-10,
         2.4605e-08, 1.0377e-08, 6.2444e-08, 5.3908e-09, 2.7734e-08],
        [3.1775e-04, 2.0629e-06, 2.3250e-07, 1.1865e-06, 3.7402e-01, 3.1866e-06,
         4.5734e-07, 1.4555e-02, 5.9530e-07, 1.7531e-07, 1.9743e-07, 4.4927e-01,
         2.3000e-06, 3.6740e-07, 1.6181e-01, 2.1091e-07, 1.3789e-06, 4.4719e-07,
         1.1066e-06, 1.4916e-06, 6.8182e-07, 5.9365e-06, 3.3213e-07, 3.1232e-06,
         2.0268e-07, 7.9015e-07, 3.4479e-07, 4.3064e-07, 2.0461e-07],
        [3.0811e-04, 1.9875e-06, 2.1853e-07, 1.1292e-06, 3.7982e-01, 2.9498e-06,
         4.3912e-07, 1.4434e-02, 5.6743e-07, 1.6715e-07, 1.8878e-07, 4.4440e-0

tensor(0.1211, grad_fn=<BinaryCrossEntropyBackward>) tensor([[1.0000e+00, 7.1486e-12, 4.5484e-11, 2.3029e-12, 1.3382e-08, 8.3874e-13,
         3.7873e-12, 2.0263e-08, 3.0621e-11, 5.2777e-10, 1.3254e-10, 8.5689e-08,
         5.3472e-12, 4.0009e-10, 6.8996e-08, 5.6618e-11, 1.8290e-12, 2.1255e-11,
         8.8100e-12, 1.8576e-12, 1.4507e-11, 3.2710e-13, 8.5851e-12, 4.0561e-13,
         1.0785e-10, 2.9297e-11, 3.9328e-10, 1.6663e-11, 1.6395e-10],
        [1.8345e-06, 9.4782e-09, 1.0303e-09, 7.6394e-09, 4.5517e-01, 1.6830e-08,
         2.3125e-09, 1.2143e-02, 1.5097e-09, 3.9343e-10, 1.3789e-09, 9.6968e-03,
         7.8102e-09, 4.3377e-10, 5.2298e-01, 7.6386e-10, 6.4118e-09, 1.3384e-09,
         4.6559e-09, 1.2402e-08, 2.2694e-09, 2.7477e-08, 2.5251e-09, 2.0377e-08,
         1.0037e-09, 4.5909e-09, 6.1158e-10, 1.4399e-09, 5.6008e-10],
        [1.7977e-06, 1.0724e-08, 1.1755e-09, 8.7620e-09, 4.2436e-01, 1.9393e-08,
         2.5680e-09, 1.4531e-02, 1.7137e-09, 4.6568e-10, 1.6574e-09, 7.4048e-0

In [13]:
print(torch.argmax(real_sample, dim=1))
print(torch.argmax(sample, dim=1))

tensor([ 0,  7,  4, 11, 11, 14])
tensor([0, 4, 4, 4, 4, 4])
