In [1]:
import os

import torch
import torch.nn as nn
from torch.nn.utils.rnn import pack_padded_sequence, pad_packed_sequence

from transformer_definitions import *

import math
import pickle as pk

from torch.utils.data import Dataset, DataLoader

DATA_PATH = "problem_1_train_dfa.dat"
EPOCHS = 50
MODEL_NAME = "trained_model.pk" # how to save model

## Data loading and preprocessing

In [2]:
class SequenceDataset(Dataset):
    def __init__(self, datapath: str, maxlen: int, pad_sequences: bool=True, max_sequences: int=None):
        super().__init__()
        
        assert(os.path.isfile(datapath))
        self.symbol_dict = dict()
        self.label_dict = dict()
        self.sequences, self.labels, self.sequence_lengths = self._read_sequences(datapath, max_sequences)
        print("Sequences loaded. Some examples: \n{}".format(self.sequences[:3]))
        
        self.SOS = self.alphabet_size
        self.EOS = self.alphabet_size + 1
        self.PAD = self.alphabet_size + 2
        self.maxlen = maxlen + 2  # +2 for EOS/PAD and SOS 
        self.pad_sequences = pad_sequences
        
    def encode_sequences(self):
        self.ordinal_seq, self.ordinal_seq_sr = self._ordinal_encode_sequences(self.sequences)
        self.one_hot_seq, self.one_hot_seq_sr = self._one_hot_encode_sequences(self.sequences)
        
        del self.sequences
        self.sequences = None
        
        print("The symbol dictionary: {}".format(self.symbol_dict))
        
    def __len__(self):
        return len(self.labels)

    def __getitem__(self, idx):
        return self.ordinal_seq[idx], self.ordinal_seq_sr[idx], self.one_hot_seq[idx], \
               self.one_hot_seq_sr[idx], self.labels[idx], self.sequence_lengths[idx]
       
    def _read_sequences(self, datapath: str, max_sequences: int):
        sequences = list()
        labels = list()
        sequence_lengths = list()
        
        for i, line in enumerate(open(datapath)):
            if i == 0:
                line = line.split()
                self.alphabet_size = int(line[1])
                print("Alphabet size: ", self.alphabet_size)
                continue
            elif max_sequences and i-1 >= max_sequences:
                break
            
            line = line.split()
            label = line[0]
            if not label in self.label_dict:
                self.label_dict[label] = len(self.label_dict)
            label = self.label_dict[label]
            labels.append(label)
            
            sequences.append(line[2:])
            sequence_lengths.append(len(line) - 1)
        return sequences, labels, sequence_lengths
    
    def _pad_one_hot(self, sequences: list, do_eos: bool=False):
        for i in range(len(sequences)):
            seq = sequences[i]
            #print("Before one hot:\n{}".format(seq))
            current_size = len(seq)
            
            t = torch.zeros((self.maxlen - current_size, self.alphabet_size + 3), dtype=torch.float32)
            t[:, self.PAD] = 1
            if do_eos and self.maxlen > current_size:
                t[0, self.PAD] = 0
                t[0, self.EOS] = 1
            
            seq = torch.cat((seq, t), dim=0)
            sequences[i] = seq
            #print("After one hot:\n{}".format(seq))
        return sequences
    
    def _one_hot_encode_sequences(self, strings: list):
        res = list()
        res_sr = list()
        for string in strings:
            x1, x2 = self._one_hot_encode_string(string)
            res.append(x1)
            res_sr.append(x2)
            
        if self.pad_sequences:
            res = self._pad_one_hot(res)
            res_sr = self._pad_one_hot(res_sr)

        return res, res_sr
    
    def _one_hot_encode_string(self, string: list):
        encoded_string = torch.zeros((len(string)+2, self.alphabet_size + 3), dtype=torch.float32) # alphabet_size + 3 because SOS, EOS, padding token
        encoded_string[0][self.SOS] = 1
        encoded_string[-1][self.EOS] = 1

        encoded_string_sl = torch.zeros((len(string)+2, self.alphabet_size + 3), dtype=torch.float32)
        encoded_string_sl[-2][self.EOS] = 1
        encoded_string_sl[-1][self.PAD] = 1

        for i, symbol in enumerate(string):
            if not symbol in self.symbol_dict:
                self.symbol_dict[symbol] = len(self.symbol_dict)

            encoded_string[i+1][self.symbol_dict[symbol]] = 1
            encoded_string_sl[i][self.symbol_dict[symbol]] = 1
        encoded_string_sl.requires_grad_()
        return encoded_string, encoded_string_sl
    
    def _pad_ordinal(self, sequences: list, do_eos: bool=False):
        for i in range(len(sequences)):
            seq = sequences[i]
            #print("Before ordinal:{}".format(seq))
            current_size = len(seq)
            
            t = torch.ones((self.maxlen - current_size,), dtype=torch.long)
            t = t*self.PAD 
            if do_eos and self.maxlen > current_size:
                t[0] = self.EOS
            
            seq = torch.cat((seq, t), dim=0)
            sequences[i] = seq
            #print("After ordinal:{}".format(seq))
        return sequences
    
    def _ordinal_encode_sequences(self, strings: list):
        res = list()
        res_sr = list()
        for string in strings:
            x1, x2 = self._ordinal_encode_string(string)
            res.append(x1)
            res_sr.append(x2)
        
        if self.pad_sequences: 
            res = self._pad_ordinal(res)
            res_sr = self._pad_ordinal(res_sr)
        return res, res_sr
    
    def _ordinal_encode_string(self, string: list):
        encoded_string = torch.zeros((len(string)+2,), dtype=torch.long)
        encoded_string[0] = self.SOS
        encoded_string[-1] = self.EOS

        encoded_string_sl = torch.zeros((len(string)+2,), dtype=torch.long)
        encoded_string_sl[-2] = self.EOS
        encoded_string_sl[-1] = self.PAD

        for i, symbol in enumerate(string):
            if not symbol in self.symbol_dict:
                self.symbol_dict[symbol] = len(self.symbol_dict)

            encoded_string[i+1] = self.symbol_dict[symbol]
            encoded_string_sl[i] = self.symbol_dict[symbol]
        return encoded_string, encoded_string_sl
    
    def get_alphabet_size(self):
        return self.alphabet_size
    
    def initialize(self, path: str="dataset.pk"):
        data = pk.load(open(path, "rb"))
        self.alphabet_size = self.alphabet_size
        self.symbol_dict = self.symbol_dict
        self.label_dict = self.label_dict
        
    def save_state(self, path: str="dataset.pk"):
        data = dict()
        data["alphabet_size"] = self.alphabet_size
        data["symbol_dict"] = self.symbol_dict
        data["label_dict"] = self.label_dict
        pk.dump(data, open(path, "wb"))

In [3]:
dataset = SequenceDataset(DATA_PATH, maxlen=10)
dataset.encode_sequences()
dataset.save_state()

Alphabet size:  4
Sequences loaded. Some examples: 
[['a', 'b'], ['a', 'b'], ['c', 'd', 'a', 'b']]
The symbol dictionary: {'a': 0, 'b': 1, 'c': 2, 'd': 3}


In [4]:
train_dataloader = DataLoader(dataset, batch_size=1, shuffle=True)

## Encoder and Decoder

In [5]:
# using template from https://pytorch.org/tutorials/beginner/transformer_tutorial.html
# tutorial about positional encoding: https://kikaben.com/transformers-positional-encoding/
class PositionalEncoding(nn.Module):

    def __init__(self, d_model: int, dropout: float = 0.1, max_len: int = 5000):
        super().__init__()
        #self.dropout = nn.Dropout(p=dropout)

        position = torch.arange(max_len).unsqueeze(1)
        
        #div_term = torch.exp(torch.arange(0, d_model, 2) * (-math.log(10000.0) / d_model))
        div_term = 10000 ** ( (2 * torch.arange(0, d_model) ) / d_model)
        pe = torch.zeros(max_len, 1, d_model)
        for i in range(max_len):
            if i % 2 == 0:    
                pe[i, 0, :] = torch.sin(position[i] / div_term)
            else:
                pe[i, 0, :] = torch.cos(position[i] / div_term)
        self.register_buffer('pe', pe)
        
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Arguments:
            x: Tensor, shape ``[seq_len, batch_size, embedding_dim]``
        """
        x = x + self.pe[:x.size(0)]
        return x #self.dropout(x)


# sidenote: understanding skip-connections: https://theaisummer.com/skip-connections/
class Encoder(nn.Module):
    def __init__(self, alphabet_size: int, embedding_dim: int, max_len:int, embedding_layer=None):
        super().__init__()
        self.pos_encoding = PositionalEncoding(d_model=embedding_dim, max_len=max_len+2)
        
        self.mha = nn.MultiheadAttention(embed_dim=embedding_dim, num_heads=3)
        self.ln = nn.LayerNorm(embedding_dim, eps=1e-12, elementwise_affine=True)
        
    def forward(self, x: torch.Tensor):
        sequence_len = list(x.size())[0]
        x = self.pos_encoding(x)
        
        attn_output, attn_output_weights = self.mha(query=x, key=x, value=x, is_causal=True, \
                                                attn_mask=nn.Transformer.generate_square_subsequent_mask(sequence_len))

        x = x + attn_output # skip-connection
        x = self.ln(x)
                
        return x, attn_output, attn_output_weights
    
class Decoder(nn.Module):
    def __init__(self, alphabet_size: int, embedding_dim: int, max_len:int, embedding_layer=None): #must be same as encoder
        super().__init__()
        self.pos_encoding = PositionalEncoding(d_model=embedding_dim, max_len=max_len+2)
        
        self.masked_mha = nn.MultiheadAttention(embed_dim=embedding_dim, num_heads=3)
        self.ln = nn.LayerNorm(embedding_dim, eps=1e-12, elementwise_affine=True)

        self.mha = nn.MultiheadAttention(embed_dim=embedding_dim, num_heads=3)
        
        
    def forward(self, x: torch.Tensor, query: torch.Tensor=None, key: torch.Tensor=None):
        sequence_len = list(x.size())[0]
        x = self.pos_encoding(x)
        
        attn_output, attn_output_weights = self.masked_mha(query=x, key=x, value=x, is_causal=True, \
                                                attn_mask=nn.Transformer.generate_square_subsequent_mask(sequence_len))#, is_causal=True)

        x = x + attn_output # skip-connection
        x = self.ln(x)
        
        if query is None or key is None: # only for debugging
            attn_output, attn_output_weights = self.mha(query=x, key=x, value=x)
        else:
            attn_output, attn_output_weights = self.mha(query=query, key=key, value=x)
        
        x = x + attn_output # skip-connection
        x = self.ln(x)
        
        return x
    
# sidenote: understanding skip-connections: https://theaisummer.com/skip-connections/
class AuTransformer(nn.Module):
    def __init__(self, alphabet_size: int, embedding_dim: int, max_len:int):
        super().__init__()
        
        self.input_embedding = nn.Embedding(alphabet_size+3, embedding_dim) # +3 for start, stop, padding symbol
        self.encoder = Encoder(alphabet_size=alphabet_size, embedding_dim=embedding_dim, max_len=max_len, embedding_layer=self.input_embedding)
        self.decoder = Decoder(alphabet_size=alphabet_size, embedding_dim=embedding_dim, max_len=max_len, embedding_layer=self.input_embedding)
        
        self.output_fnn = nn.Linear(in_features=embedding_dim, out_features=alphabet_size+3) # +2 for start and stop
        self.gelu = torch.nn.GELU()
        
        self.dropout = nn.Dropout(0.2)
        self.softmax_output = nn.Softmax(dim=-1)
        
        self.attention_output_layer = nn.Identity() 
        self.attention_weight_layer = nn.Identity() 
        
    def forward(self, src: torch.Tensor, tgt: torch.Tensor):
        x_src = self.input_embedding(src)
        x, attention_output, attention_weights = self.encoder(x_src)
        attention_output = self.attention_output_layer(attention_output)
        attention_weights = self.attention_weight_layer(attention_weights)
        x = self.dropout(x)

        x_tgt = self.input_embedding(tgt)
        x = self.decoder(x=x_tgt, query=x, key=x)
        x = self.dropout(x)
        
        x = self.gelu(self.output_fnn(x))
        x = self.softmax_output(x)
        return x

In [6]:
model = AuTransformer(dataset.get_alphabet_size(), 3, max_len=10)
model.eval()

AuTransformer(
  (input_embedding): Embedding(7, 3)
  (encoder): Encoder(
    (pos_encoding): PositionalEncoding()
    (mha): MultiheadAttention(
      (out_proj): NonDynamicallyQuantizableLinear(in_features=3, out_features=3, bias=True)
    )
    (ln): LayerNorm((3,), eps=1e-12, elementwise_affine=True)
  )
  (decoder): Decoder(
    (pos_encoding): PositionalEncoding()
    (masked_mha): MultiheadAttention(
      (out_proj): NonDynamicallyQuantizableLinear(in_features=3, out_features=3, bias=True)
    )
    (ln): LayerNorm((3,), eps=1e-12, elementwise_affine=True)
    (mha): MultiheadAttention(
      (out_proj): NonDynamicallyQuantizableLinear(in_features=3, out_features=3, bias=True)
    )
  )
  (output_fnn): Linear(in_features=3, out_features=7, bias=True)
  (gelu): GELU(approximate='none')
  (dropout): Dropout(p=0.2, inplace=False)
  (softmax_output): Softmax(dim=-1)
)

## Train

In [7]:
# training loop from https://pytorch.org/tutorials/beginner/introyt/trainingyt.html

optimizer = torch.optim.Adam(model.parameters(), lr=0.00001)
loss_fn = nn.CrossEntropyLoss()

running_loss = 0.
last_loss = 0.
divisor = 0.

for i in range(EPOCHS):
    print("Epoch: ", i)
    for j, (test_string_ord, test_string_ord_sr, test_string_oh, test_string_oh_sr, _, sequence_length) in enumerate(train_dataloader):
        optimizer.zero_grad()
        test_string_ord = torch.permute(test_string_ord, dims=[1,0])
        test_string_ord_sr = torch.permute(test_string_ord_sr, dims=[1,0])
        test_string_oh_sr = torch.permute(test_string_oh_sr, dims=[1,0,2])
        
        outputs = model(test_string_ord, test_string_ord_sr)
        loss = loss_fn(torch.squeeze(outputs), torch.squeeze(test_string_oh_sr))
        loss.backward()
        
        #break

        # Adjust learning weights
        optimizer.step()

        # Gather data and report
        running_loss += loss.item()
        divisor += float( list(test_string_ord.size())[1] )
        if j % 1000 == 0:
            last_loss = running_loss / divisor # loss per batch
            print('  batch {} loss: {}'.format(j, last_loss))
            running_loss = 0.
            divisor = 0.
    #break

2023-11-08 16:14:50.348570: I tensorflow/core/platform/cpu_feature_guard.cc:182] This TensorFlow binary is optimized to use available CPU instructions in performance-critical operations.
To enable the following instructions: AVX2 FMA, in other operations, rebuild TensorFlow with the appropriate compiler flags.


Epoch:  0
  batch 0 loss: 1.8642925024032593
  batch 1000 loss: 1.849761075258255
  batch 2000 loss: 1.8440512104034423
  batch 3000 loss: 1.8343844048976898
  batch 4000 loss: 1.8237437331676483
  batch 5000 loss: 1.8163463504314423
  batch 6000 loss: 1.8120236954689026
  batch 7000 loss: 1.8036180588006974
  batch 8000 loss: 1.7928736736774444
  batch 9000 loss: 1.7853946917057038
  batch 10000 loss: 1.784455164551735
  batch 11000 loss: 1.7733713262081146
  batch 12000 loss: 1.7684401646852492
Epoch:  1
  batch 0 loss: 1.7541624094902064
  batch 1000 loss: 1.754656299829483
  batch 2000 loss: 1.7438372901678085
  batch 3000 loss: 1.7408930761814116
  batch 4000 loss: 1.728685742378235
  batch 5000 loss: 1.7247979179620743
  batch 6000 loss: 1.720530793428421
  batch 7000 loss: 1.7062780493497849
  batch 8000 loss: 1.6974281995296479
  batch 9000 loss: 1.690214879512787
  batch 10000 loss: 1.6887503609657288
  batch 11000 loss: 1.6801608951091767
  batch 12000 loss: 1.673946210622787

  batch 5000 loss: 1.2886809856891632
  batch 6000 loss: 1.2917154603004455
  batch 7000 loss: 1.288399415373802
  batch 8000 loss: 1.28885343337059
  batch 9000 loss: 1.2912294952869414
  batch 10000 loss: 1.29135799741745
  batch 11000 loss: 1.2871426991224288
  batch 12000 loss: 1.2899217044115066
Epoch:  17
  batch 0 loss: 1.2820767938890039
  batch 1000 loss: 1.290476791024208
  batch 2000 loss: 1.2899952150583267
  batch 3000 loss: 1.2899896819591523
  batch 4000 loss: 1.2897409137487412
  batch 5000 loss: 1.285532755613327
  batch 6000 loss: 1.2875641310214996
  batch 7000 loss: 1.2864458593130113
  batch 8000 loss: 1.2843213291168214
  batch 9000 loss: 1.2831848030090331
  batch 10000 loss: 1.291511397242546
  batch 11000 loss: 1.2905567313432693
  batch 12000 loss: 1.2865615142583846
Epoch:  18
  batch 0 loss: 1.2897901426662097
  batch 1000 loss: 1.2871549139022826
  batch 2000 loss: 1.2888737255334854
  batch 3000 loss: 1.2884727692604065
  batch 4000 loss: 1.288650905609130

In [8]:
torch.save(model, MODEL_NAME)