In [1]:
import torch
import torch.nn as nn

import numpy as np
import math

## Some helpers for preprocessing

In [2]:
mapping_dict = dict()

def one_hot_encode_string(string: list, alphabet_size: int):
    global mapping_dict
    
    SOS = alphabet_size
    EOS = alphabet_size + 1
    PAD = alphabet_size + 2
    
    encoded_string = np.zeros((len(string)+2, alphabet_size + 3)) # alphabet_size + 3 because SOS, EOS, padding token
    encoded_string[0][SOS] = 1
    encoded_string[-1][EOS] = 1
    
    encoded_string_sr = np.zeros((len(string)+2, alphabet_size + 3))
    encoded_string_sr[1][SOS] = 1
    encoded_string_sr[0][PAD] = 1
    
    for i, symbol in enumerate(string):
        if not symbol in mapping_dict:
            mapping_dict[symbol] = len(mapping_dict)
        
        encoded_string[i+1][mapping_dict[symbol]] = 1
        encoded_string_sr[i+2][mapping_dict[symbol]] = 1
    return encoded_string, encoded_string_sr

def ordinal_encode_string(string: list, alphabet_size: int):
    global mapping_dict
    
    SOS = alphabet_size
    EOS = alphabet_size + 1
    PAD = alphabet_size + 2
    
    encoded_string = np.zeros((len(string)+2, 1))
    encoded_string[0] = SOS
    encoded_string[-1] = EOS
    
    encoded_string_sr = np.zeros((len(string)+2, 1))
    encoded_string_sr[1] = EOS
    encoded_string_sr[0] = PAD
    
    for i, symbol in enumerate(string):
        if not symbol in mapping_dict:
            mapping_dict[symbol] = len(mapping_dict)
        
        encoded_string[i+1] = mapping_dict[symbol]
        encoded_string_sr[i+2] = mapping_dict[symbol]
    return encoded_string, encoded_string_sr

## Encoder and Decoder

In [3]:
# using template from https://pytorch.org/tutorials/beginner/transformer_tutorial.html
# tutorial about positional encoding: https://kikaben.com/transformers-positional-encoding/
class PositionalEncoding(nn.Module):

    def __init__(self, d_model: int, dropout: float = 0.1, max_len: int = 5000):
        super().__init__()
        #self.dropout = nn.Dropout(p=dropout)

        position = torch.arange(max_len).unsqueeze(1)
        
        #div_term = torch.exp(torch.arange(0, d_model, 2) * (-math.log(10000.0) / d_model))
        div_term = 10000 ** ( (2 * torch.arange(0, d_model) ) / d_model)
        pe = torch.zeros(max_len, 1, d_model)
        for i in range(max_len):
            if i % 2 == 0:    
                pe[i, 0, :] = torch.sin(position[i] / div_term)
            else:
                pe[i, 0, :] = torch.cos(position[i] / div_term)
        self.register_buffer('pe', pe)
        
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Arguments:
            x: Tensor, shape ``[seq_len, batch_size, embedding_dim]``
        """
        x = x + self.pe[:x.size(0)]
        return x #self.dropout(x)


# sidenote: understanding skip-connections: https://theaisummer.com/skip-connections/
class Encoder(nn.Module):
    def __init__(self, alphabet_size: int, embedding_dim: int, max_len:int, embedding_layer=None):
        super().__init__()
        self.input_embedding = nn.Embedding(alphabet_size+3, embedding_dim) if embedding_layer is None else embedding_layer # +3 for start, stop, padding symbol
        self.pos_encoding = PositionalEncoding(d_model=embedding_dim, max_len=max_len+2)
        
        self.mha = nn.MultiheadAttention(embed_dim=embedding_dim, num_heads=3)
        self.ln = nn.LayerNorm(embedding_dim, eps=1e-12, elementwise_affine=True)

    def forward(self, x: torch.Tensor):
        sequence_len = list(x.size())[0]
        x = self.input_embedding(x)
        x = self.pos_encoding(x)
        
        attn_output, attn_output_weights = self.mha(query=x, key=x, value=x, is_causal=True, \
                                                attn_mask=nn.Transformer.generate_square_subsequent_mask(sequence_len))

        x = x + attn_output # skip-connection
        x = self.ln(x)
        
        return x
    
class Decoder(nn.Module):
    def __init__(self, alphabet_size: int, embedding_dim: int, max_len:int, embedding_layer=None): #must be same as encoder
        super().__init__()
        self.input_embedding = nn.Embedding(alphabet_size+3, embedding_dim) if embedding_layer is None else embedding_layer # +3 for start, stop, padding symbol
        self.pos_encoding = PositionalEncoding(d_model=embedding_dim, max_len=max_len+2)
        
        self.masked_mha = nn.MultiheadAttention(embed_dim=embedding_dim, num_heads=3)
        self.ln = nn.LayerNorm(embedding_dim, eps=1e-12, elementwise_affine=True)

        self.mha = nn.MultiheadAttention(embed_dim=embedding_dim, num_heads=3)
        
        
    def forward(self, x: torch.Tensor, query: torch.Tensor=None, key: torch.Tensor=None):
        sequence_len = list(x.size())[0]
        
        x = self.input_embedding(x)
        x = self.pos_encoding(x)
        
        attn_output, attn_output_weights = self.masked_mha(query=x, key=x, value=x, is_causal=True, \
                                                attn_mask=nn.Transformer.generate_square_subsequent_mask(sequence_len))#, is_causal=True)

        x = x + attn_output # skip-connection
        x = self.ln(x)
        
        if query is None or key is None: # only for debugging
            attn_output, attn_output_weights = self.mha(query=x, key=x, value=x)
        else:
            attn_output, attn_output_weights = self.mha(query=query, key=key, value=x)
        
        x = x + attn_output # skip-connection
        x = self.ln(x)
        
        return x

## Debugging encoder and decoder with a toy string

### Encoder

In [4]:
ALPHABET_SIZE = 2
test_string = ["0", "1", "1"]

test_string_ord, test_string_ord_sr = ordinal_encode_string(test_string, ALPHABET_SIZE)
test_string_ord, test_string_ord_sr, test_string_ord.shape

(array([[2.],
        [0.],
        [1.],
        [1.],
        [3.]]),
 array([[4.],
        [3.],
        [0.],
        [1.],
        [1.]]),
 (5, 1))

In [5]:
model = Encoder(ALPHABET_SIZE, 3, max_len=30)
test_res = model(torch.tensor(test_string_ord, dtype=torch.int32))
list(test_res.size()), test_res

([5, 1, 3],
 tensor([[[ 1.3217, -1.0965, -0.2252]],
 
         [[ 1.2487, -1.1993, -0.0494]],
 
         [[ 1.0641,  0.2746, -1.3387]],
 
         [[-1.1065,  1.3160, -0.2095]],
 
         [[-1.2807,  0.1208,  1.1599]]], grad_fn=<NativeLayerNormBackward0>))

### Decoder

In [6]:
model = Decoder(ALPHABET_SIZE, 3, max_len=30)
test_res = model(torch.tensor(test_string_ord, dtype=torch.int32))
list(test_res.size()), test_res

([5, 1, 3],
 tensor([[[ 1.3699, -0.3807, -0.9892]],
 
         [[ 1.3300, -0.2487, -1.0813]],
 
         [[ 0.4117,  0.9659, -1.3775]],
 
         [[-1.3600,  1.0159,  0.3441]],
 
         [[-1.1674, -0.1076,  1.2750]]], grad_fn=<NativeLayerNormBackward0>))

In [7]:
# sidenote: understanding skip-connections: https://theaisummer.com/skip-connections/
class AuTransformer(nn.Module):
    def __init__(self, alphabet_size: int, embedding_dim: int, max_len:int):
        super().__init__()
        
        self.input_embedding = nn.Embedding(alphabet_size+3, embedding_dim) # +3 for start, stop, padding symbol
        self.encoder = Encoder(alphabet_size=alphabet_size, embedding_dim=embedding_dim, max_len=max_len, embedding_layer=self.input_embedding)
        self.decoder = Decoder(alphabet_size=alphabet_size, embedding_dim=embedding_dim, max_len=max_len, embedding_layer=self.input_embedding)
        
        self.output_fnn = nn.Linear(in_features=embedding_dim, out_features=alphabet_size+3) # +2 for start and stop
        self.gelu = torch.nn.GELU()
        
        self.dropout = nn.Dropout(0.2)
        self.softmax_output = nn.Softmax(dim=-1)
        
    def forward(self, src: torch.Tensor, tgt: torch.Tensor):
        x = self.encoder(src)
        x = self.dropout(x)
        x = self.decoder(x=tgt, query=x, key=x)
        x = self.dropout(x)
        
        x = self.gelu(self.output_fnn(x))        
        x = self.softmax_output(x)
        return x

## Debugging the whole network

In [8]:
test_string_oh, test_string_oh_sr = one_hot_encode_string(test_string, ALPHABET_SIZE)
test_string_oh, test_string_oh_sr

(array([[0., 0., 1., 0., 0.],
        [1., 0., 0., 0., 0.],
        [0., 1., 0., 0., 0.],
        [0., 1., 0., 0., 0.],
        [0., 0., 0., 1., 0.]]),
 array([[0., 0., 0., 0., 1.],
        [0., 0., 1., 0., 0.],
        [1., 0., 0., 0., 0.],
        [0., 1., 0., 0., 0.],
        [0., 1., 0., 0., 0.]]))

In [9]:
model = AuTransformer(ALPHABET_SIZE, 3, max_len=30)
res = model(torch.tensor(test_string_ord, dtype=torch.int32), torch.tensor(test_string_ord_sr, dtype=torch.int32))
res, list(res.size())

(tensor([[[0.3690, 0.1370, 0.1465, 0.1751, 0.1723]],
 
         [[0.1910, 0.1959, 0.2067, 0.2085, 0.1979]],
 
         [[0.2938, 0.2124, 0.1581, 0.1540, 0.1817]],
 
         [[0.1343, 0.3252, 0.2672, 0.1322, 0.1412]],
 
         [[0.1061, 0.3750, 0.3009, 0.1064, 0.1116]]],
        grad_fn=<SoftmaxBackward0>),
 [5, 1, 5])

In [10]:
model.eval()

AuTransformer(
  (input_embedding): Embedding(5, 3)
  (encoder): Encoder(
    (input_embedding): Embedding(5, 3)
    (pos_encoding): PositionalEncoding()
    (mha): MultiheadAttention(
      (out_proj): NonDynamicallyQuantizableLinear(in_features=3, out_features=3, bias=True)
    )
    (ln): LayerNorm((3,), eps=1e-12, elementwise_affine=True)
  )
  (decoder): Decoder(
    (input_embedding): Embedding(5, 3)
    (pos_encoding): PositionalEncoding()
    (masked_mha): MultiheadAttention(
      (out_proj): NonDynamicallyQuantizableLinear(in_features=3, out_features=3, bias=True)
    )
    (ln): LayerNorm((3,), eps=1e-12, elementwise_affine=True)
    (mha): MultiheadAttention(
      (out_proj): NonDynamicallyQuantizableLinear(in_features=3, out_features=3, bias=True)
    )
  )
  (output_fnn): Linear(in_features=3, out_features=5, bias=True)
  (gelu): GELU(approximate='none')
  (dropout): Dropout(p=0.2, inplace=False)
  (softmax_output): Softmax(dim=-1)
)

## Overfit on a single sequence

In [11]:
# training loop from https://pytorch.org/tutorials/beginner/introyt/trainingyt.html

loss_fn = nn.CrossEntropyLoss()#nn.CrossEntropyLoss()

test_string_ord = torch.tensor(test_string_ord, dtype=torch.long)
test_string_ord_sr = torch.tensor(test_string_ord_sr, dtype=torch.long)

test_string_oh = torch.tensor(test_string_oh, dtype=torch.float32, requires_grad=True)
test_string_oh_sr = torch.tensor(test_string_oh_sr, dtype=torch.float32, requires_grad=True)

loss_fn(test_string_oh, torch.squeeze(test_string_ord_sr.type(torch.LongTensor))), \
loss_fn(test_string_oh, torch.squeeze(test_string_ord.type(torch.LongTensor) ))

(tensor(1.7048, grad_fn=<NllLossBackward0>),
 tensor(0.9048, grad_fn=<NllLossBackward0>))

In [13]:
# Optimizers specified in the torch.optim package
optimizer = torch.optim.Adam(model.parameters(), lr=0.00001)

running_loss = 0.
last_loss = 0.

for i in range(100000):
    optimizer.zero_grad()

    outputs = model(test_string_ord, test_string_ord_sr)
    loss = loss_fn(torch.squeeze(outputs), test_string_oh_sr)
    loss.backward()

    # Adjust learning weights
    optimizer.step()

    # Gather data and report
    running_loss += loss.item()
    if i % 1000 == 999:
        last_loss = running_loss / 1000 # loss per batch
        print('  batch {} loss: {}'.format(i + 1, last_loss))
        running_loss = 0.

  batch 1000 loss: 1.517309710264206
  batch 2000 loss: 1.5054040787220002
  batch 3000 loss: 1.4933911712169647
  batch 4000 loss: 1.481315294623375
  batch 5000 loss: 1.4691589014530182
  batch 6000 loss: 1.4567446029186248
  batch 7000 loss: 1.4439928393363952
  batch 8000 loss: 1.4311956275701523
  batch 9000 loss: 1.4187696551084519
  batch 10000 loss: 1.4067645512819291
  batch 11000 loss: 1.3950567507743836
  batch 12000 loss: 1.3835871900320054
  batch 13000 loss: 1.3723446624279023
  batch 14000 loss: 1.3613491530418396
  batch 15000 loss: 1.3506313000917434
  batch 16000 loss: 1.3402339998483659
  batch 17000 loss: 1.330149579644203
  batch 18000 loss: 1.3203274800777436
  batch 19000 loss: 1.3107287448644638
  batch 20000 loss: 1.3013671793937682
  batch 21000 loss: 1.292262720823288
  batch 22000 loss: 1.2834305545091629
  batch 23000 loss: 1.2748874014616012
  batch 24000 loss: 1.2666352112293244
  batch 25000 loss: 1.2586599950790405
  batch 26000 loss: 1.2509332666397095

In [14]:
model(test_string_ord, test_string_ord_sr), test_string_oh, test_string_oh_sr

(tensor([[[2.1130e-03, 2.4145e-03, 2.6499e-03, 2.0371e-03, 9.9079e-01]],
 
         [[7.0820e-03, 6.4381e-03, 9.7242e-01, 5.9750e-03, 8.0843e-03]],
 
         [[9.9079e-01, 3.1784e-03, 1.3962e-03, 1.1785e-03, 3.4584e-03]],
 
         [[6.9245e-04, 9.9771e-01, 6.9072e-04, 4.1504e-04, 4.9191e-04]],
 
         [[6.9256e-04, 9.9771e-01, 6.9064e-04, 4.1503e-04, 4.9190e-04]]],
        grad_fn=<SoftmaxBackward0>),
 tensor([[0., 0., 1., 0., 0.],
         [1., 0., 0., 0., 0.],
         [0., 1., 0., 0., 0.],
         [0., 1., 0., 0., 0.],
         [0., 0., 0., 1., 0.]], requires_grad=True),
 tensor([[0., 0., 0., 0., 1.],
         [0., 0., 1., 0., 0.],
         [1., 0., 0., 0., 0.],
         [0., 1., 0., 0., 0.],
         [0., 1., 0., 0., 0.]], requires_grad=True))