In [1]:
import os

import torch
import torch.nn as nn

import pickle as pk

from transformer_definitions import AuTransformer
from torch.utils.data import Dataset, DataLoader

# File handed to torch.load() below to obtain the model under inspection.
MODEL_NAME = "trained_model.pk"
# Data file handed to SequenceDataset below.
DATA_PATH = "problem_1_train_dfa.dat"

# Model and data definitions

In [2]:
class PositionalEncoding(nn.Module):
    """Adds a fixed, precomputed sinusoidal table to a sequence of embeddings.

    The table is stored as the buffer ``pe`` with shape ``[max_len, 1, d_model]``.
    Row ``p`` holds ``sin(p / w)`` for even ``p`` and ``cos(p / w)`` for odd ``p``,
    where ``w = 10000 ** (2 * k / d_model)`` for every dimension index ``k``.

    NOTE(review): sine and cosine alternate over *positions* here, not over
    embedding dimensions as in the usual sinusoidal encoding. The formulation
    is deliberately left unchanged.
    """

    def __init__(self, d_model: int, dropout: float = 0.1, max_len: int = 5000):
        """
        Arguments:
            d_model: size of the embedding dimension.
            dropout: accepted for signature compatibility; no dropout is applied.
            max_len: number of positions for which the table is precomputed.
        """
        super().__init__()

        positions = torch.arange(max_len).unsqueeze(1)
        exponents = (2 * torch.arange(0, d_model)) / d_model
        wavelengths = 10000 ** exponents

        table = torch.zeros(max_len, 1, d_model)
        for row in range(max_len):
            # Even rows use the sine, odd rows the cosine.
            wave = torch.sin if row % 2 == 0 else torch.cos
            table[row, 0, :] = wave(positions[row] / wavelengths)
        self.register_buffer('pe', table)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Arguments:
            x: Tensor, shape ``[seq_len, batch_size, embedding_dim]``

        Returns:
            ``x`` plus the first ``seq_len`` rows of the table, broadcast over
            the batch dimension.
        """
        return x + self.pe[:x.size(0)]


# sidenote: understanding skip-connections: https://theaisummer.com/skip-connections/
class Encoder(nn.Module):
    """Single encoder layer: positional encoding, self-attention, residual, LayerNorm.

    ``alphabet_size`` and ``embedding_layer`` are accepted so the signature
    matches the decoder, but neither is used by this class.

    NOTE(review): the attention weights printed further down in this notebook
    have non-zero entries above the diagonal. Verify that the installed
    PyTorch version really applies the causal mask when ``is_causal=True`` and
    attention weights are requested.
    """

    def __init__(self, alphabet_size: int, embedding_dim: int, max_len:int, embedding_layer=None):
        """
        Arguments:
            alphabet_size: unused.
            embedding_dim: embedding size; the attention layer uses 3 heads.
            max_len: longest sequence length; the positional table gets two
                extra positions.
            embedding_layer: unused.
        """
        super().__init__()
        self.pos_encoding = PositionalEncoding(d_model=embedding_dim, max_len=max_len+2)

        self.mha = nn.MultiheadAttention(embed_dim=embedding_dim, num_heads=3)
        self.ln = nn.LayerNorm(embedding_dim, eps=1e-12, elementwise_affine=True)

    def forward(self, x: torch.Tensor):
        """
        Arguments:
            x: embedded input; the sequence dimension is dimension 0.

        Returns:
            Tuple ``(normalised output, raw attention output, attention weights)``.
        """
        sequence_len = x.size(0)
        x = self.pos_encoding(x)

        # Build the mask on the input's device so the layer also works after
        # the module has been moved off the CPU.
        causal_mask = nn.Transformer.generate_square_subsequent_mask(sequence_len, device=x.device)
        attn_output, attn_output_weights = self.mha(query=x, key=x, value=x,
                                                    attn_mask=causal_mask, is_causal=True)

        x = x + attn_output # skip-connection
        x = self.ln(x)

        return x, attn_output, attn_output_weights
    
class Decoder(nn.Module):
    """Single decoder layer: masked self-attention followed by a second attention step.

    The constructor signature mirrors the encoder's. ``alphabet_size`` and
    ``embedding_layer`` are not used. The same ``LayerNorm`` instance is
    applied after both attention steps.
    """

    def __init__(self, alphabet_size: int, embedding_dim: int, max_len:int, embedding_layer=None): #must be same as encoder
        """
        Arguments:
            alphabet_size: unused.
            embedding_dim: embedding size; both attention layers use 3 heads.
            max_len: longest sequence length; the positional table gets two
                extra positions.
            embedding_layer: unused.
        """
        super().__init__()
        self.pos_encoding = PositionalEncoding(d_model=embedding_dim, max_len=max_len+2)

        self.masked_mha = nn.MultiheadAttention(embed_dim=embedding_dim, num_heads=3)
        self.ln = nn.LayerNorm(embedding_dim, eps=1e-12, elementwise_affine=True)

        self.mha = nn.MultiheadAttention(embed_dim=embedding_dim, num_heads=3)

    def forward(self, x: torch.Tensor, query: torch.Tensor=None, key: torch.Tensor=None):
        """
        Arguments:
            x: embedded target sequence; the sequence dimension is dimension 0.
            query: query tensor for the second attention step.
            key: key tensor for the second attention step.

        Returns:
            The layer output, same shape as ``x``.

        When ``query`` or ``key`` is missing, the second attention step falls
        back to self-attention on ``x`` (debugging aid).
        """
        sequence_len = x.size(0)
        x = self.pos_encoding(x)

        # Build the mask on the input's device so the layer also works after
        # the module has been moved off the CPU.
        causal_mask = nn.Transformer.generate_square_subsequent_mask(sequence_len, device=x.device)
        attn_output, attn_output_weights = self.masked_mha(query=x, key=x, value=x,
                                                           attn_mask=causal_mask, is_causal=True)

        x = x + attn_output # skip-connection
        x = self.ln(x)

        if query is None or key is None: # only for debugging
            attn_output, attn_output_weights = self.mha(query=x, key=x, value=x)
        else:
            # Query and key come from the caller; the values are the decoder's own stream.
            attn_output, attn_output_weights = self.mha(query=query, key=key, value=x)

        x = x + attn_output # skip-connection
        x = self.ln(x)

        return x
    
# sidenote: understanding skip-connections: https://theaisummer.com/skip-connections/
class AuTransformer(nn.Module):
    """Encoder/decoder model over symbol sequences with a shared input embedding.

    The vocabulary has ``alphabet_size + 3`` entries: the alphabet plus the
    start, stop and padding symbols. The output layer has the same width.

    ``attention_output_layer`` and ``attention_weight_layer`` are identity
    modules. The encoder's attention output and weights are passed through
    them so that forward hooks can capture those tensors.
    """

    def __init__(self, alphabet_size: int, embedding_dim: int, max_len:int):
        """
        Arguments:
            alphabet_size: number of regular symbols (special tokens excluded).
            embedding_dim: embedding size used throughout the model.
            max_len: longest sequence length, forwarded to encoder and decoder.
        """
        super().__init__()

        vocabulary_size = alphabet_size + 3  # alphabet + start, stop, padding symbol

        self.input_embedding = nn.Embedding(vocabulary_size, embedding_dim)
        self.encoder = Encoder(
            alphabet_size=alphabet_size,
            embedding_dim=embedding_dim,
            max_len=max_len,
            embedding_layer=self.input_embedding,
        )
        self.decoder = Decoder(
            alphabet_size=alphabet_size,
            embedding_dim=embedding_dim,
            max_len=max_len,
            embedding_layer=self.input_embedding,
        )

        self.output_fnn = nn.Linear(in_features=embedding_dim, out_features=vocabulary_size)
        self.gelu = torch.nn.GELU()

        self.dropout = nn.Dropout(0.2)
        self.softmax_output = nn.Softmax(dim=-1)

        # Pass-through modules that exist only as hook attachment points.
        self.attention_output_layer = nn.Identity()
        self.attention_weight_layer = nn.Identity()

    def forward(self, src: torch.Tensor, tgt: torch.Tensor):
        """
        Arguments:
            src: ordinal-encoded source sequence fed to the encoder.
            tgt: ordinal-encoded target sequence fed to the decoder.

        Returns:
            Softmax over the ``alphabet_size + 3`` output symbols along the
            last dimension.
        """
        src_embedded = self.input_embedding(src)
        memory, attention_output, attention_weights = self.encoder(src_embedded)

        # The identity modules do not change the tensors; calling them lets
        # registered forward hooks observe the encoder's attention.
        self.attention_output_layer(attention_output)
        self.attention_weight_layer(attention_weights)

        memory = self.dropout(memory)

        tgt_embedded = self.input_embedding(tgt)
        decoded = self.decoder(x=tgt_embedded, query=memory, key=memory)
        decoded = self.dropout(decoded)

        scores = self.gelu(self.output_fnn(decoded))
        return self.softmax_output(scores)

In [3]:
class SequenceDataset(Dataset):
    """Labelled symbol sequences with ordinal and one-hot encodings.

    The data file is read as whitespace-separated tokens. The second token of
    the first line is the alphabet size. On every following line the first
    token is the label, the second token is skipped (presumably the sequence
    length -- TODO confirm against the data format) and the remaining tokens
    are the symbols.

    Three extra token ids follow the alphabet: ``SOS = alphabet_size``,
    ``EOS = alphabet_size + 1`` and ``PAD = alphabet_size + 2``.

    Call :meth:`encode_sequences` before indexing the dataset; it creates the
    tensors that ``__getitem__`` returns.
    """

    def __init__(self, datapath: str, maxlen: int, pad_sequences: bool=True, max_sequences: int=None):
        """
        Arguments:
            datapath: path of the data file; must exist.
            maxlen: longest sequence length in symbols; two extra positions
                are reserved for SOS and EOS/PAD.
            pad_sequences: pad every encoded sequence to ``maxlen + 2`` entries.
            max_sequences: read at most this many sequences; ``None`` (or 0)
                reads the whole file.
        """
        super().__init__()

        assert(os.path.isfile(datapath))
        self.symbol_dict = dict()
        self.label_dict = dict()
        self.sequences, self.labels, self.sequence_lengths = self._read_sequences(datapath, max_sequences)
        print("Sequences loaded. Some examples: \n{}".format(self.sequences[:3]))

        self._assign_special_tokens()
        self.maxlen = maxlen + 2  # +2 for EOS/PAD and SOS
        self.pad_sequences = pad_sequences

    def _assign_special_tokens(self):
        """Derive the SOS, EOS and PAD token ids from the current alphabet size."""
        self.SOS = self.alphabet_size
        self.EOS = self.alphabet_size + 1
        self.PAD = self.alphabet_size + 2

    def encode_sequences(self):
        """Build the ordinal and one-hot tensors and release the raw sequences.

        Symbols get their ids in order of first appearance unless
        ``symbol_dict`` already contains them. After this call
        ``self.sequences`` is ``None``.
        """
        self.ordinal_seq, self.ordinal_seq_sr = self._ordinal_encode_sequences(self.sequences)
        self.one_hot_seq, self.one_hot_seq_sr = self._one_hot_encode_sequences(self.sequences)

        # The raw token lists are no longer needed once everything is encoded.
        self.sequences = None

        print("The symbol dictionary: {}".format(self.symbol_dict))

    def __len__(self):
        """Number of sequences read from the data file."""
        return len(self.labels)

    def __getitem__(self, idx):
        """Return ``(ordinal, ordinal shifted, one-hot, one-hot shifted, label, length)``."""
        return self.ordinal_seq[idx], self.ordinal_seq_sr[idx], self.one_hot_seq[idx], \
               self.one_hot_seq_sr[idx], self.labels[idx], self.sequence_lengths[idx]

    def _read_sequences(self, datapath: str, max_sequences: int):
        """Parse the data file into token lists, integer labels and lengths.

        Sets ``self.alphabet_size`` from the header line and fills
        ``self.label_dict`` in order of first appearance.

        Returns:
            ``(sequences, labels, sequence_lengths)``. The stored length is
            the number of symbols plus one.
        """
        sequences = list()
        labels = list()
        sequence_lengths = list()

        # The context manager closes the file even if parsing fails.
        with open(datapath) as data_file:
            for i, line in enumerate(data_file):
                if i == 0:
                    line = line.split()
                    self.alphabet_size = int(line[1])
                    print("Alphabet size: ", self.alphabet_size)
                    continue
                elif max_sequences and i-1 >= max_sequences:
                    break

                line = line.split()
                label = line[0]
                if not label in self.label_dict:
                    self.label_dict[label] = len(self.label_dict)
                label = self.label_dict[label]
                labels.append(label)

                sequences.append(line[2:])
                sequence_lengths.append(len(line) - 1)
        return sequences, labels, sequence_lengths

    def _pad_one_hot(self, sequences: list, do_eos: bool=False):
        """Pad one-hot tensors in place with PAD rows up to ``self.maxlen`` rows.

        With ``do_eos`` the first padding row encodes EOS instead of PAD.
        Sequences must not be longer than ``self.maxlen``.
        """
        for i in range(len(sequences)):
            seq = sequences[i]
            current_size = len(seq)

            t = torch.zeros((self.maxlen - current_size, self.alphabet_size + 3), dtype=torch.float32)
            t[:, self.PAD] = 1
            if do_eos and self.maxlen > current_size:
                t[0, self.PAD] = 0
                t[0, self.EOS] = 1

            seq = torch.cat((seq, t), dim=0)
            sequences[i] = seq
        return sequences

    def _one_hot_encode_sequences(self, strings: list):
        """One-hot encode every sequence; returns ``(encoded, shifted)`` lists."""
        res = list()
        res_sr = list()
        for string in strings:
            x1, x2 = self._one_hot_encode_string(string)
            res.append(x1)
            res_sr.append(x2)

        if self.pad_sequences:
            res = self._pad_one_hot(res)
            res_sr = self._pad_one_hot(res_sr)

        return res, res_sr

    def _one_hot_encode_string(self, string: list):
        """One-hot encode one sequence.

        Returns:
            ``(encoded, shifted)``, both of shape ``[len(string) + 2,
            alphabet_size + 3]``. ``encoded`` is ``SOS, symbols..., EOS``;
            ``shifted`` is ``symbols..., EOS, PAD``. ``shifted`` is marked as
            requiring gradients.
        """
        encoded_string = torch.zeros((len(string)+2, self.alphabet_size + 3), dtype=torch.float32) # alphabet_size + 3 because SOS, EOS, padding token
        encoded_string[0][self.SOS] = 1
        encoded_string[-1][self.EOS] = 1

        encoded_string_sl = torch.zeros((len(string)+2, self.alphabet_size + 3), dtype=torch.float32)
        encoded_string_sl[-2][self.EOS] = 1
        encoded_string_sl[-1][self.PAD] = 1

        for i, symbol in enumerate(string):
            if not symbol in self.symbol_dict:
                self.symbol_dict[symbol] = len(self.symbol_dict)

            encoded_string[i+1][self.symbol_dict[symbol]] = 1
            encoded_string_sl[i][self.symbol_dict[symbol]] = 1
        encoded_string_sl.requires_grad_()
        return encoded_string, encoded_string_sl

    def _pad_ordinal(self, sequences: list, do_eos: bool=False):
        """Pad ordinal tensors in place with PAD ids up to ``self.maxlen`` entries.

        With ``do_eos`` the first padding entry is EOS instead of PAD.
        Sequences must not be longer than ``self.maxlen``.
        """
        for i in range(len(sequences)):
            seq = sequences[i]
            current_size = len(seq)

            t = torch.ones((self.maxlen - current_size,), dtype=torch.long)
            t = t*self.PAD
            if do_eos and self.maxlen > current_size:
                t[0] = self.EOS

            seq = torch.cat((seq, t), dim=0)
            sequences[i] = seq
        return sequences

    def _ordinal_encode_sequences(self, strings: list):
        """Ordinal-encode every sequence; returns ``(encoded, shifted)`` lists."""
        res = list()
        res_sr = list()
        for string in strings:
            x1, x2 = self._ordinal_encode_string(string)
            res.append(x1)
            res_sr.append(x2)

        if self.pad_sequences:
            res = self._pad_ordinal(res)
            res_sr = self._pad_ordinal(res_sr)
        return res, res_sr

    def _ordinal_encode_string(self, string: list):
        """Ordinal-encode one sequence.

        Returns:
            ``(encoded, shifted)``, both long tensors of length
            ``len(string) + 2``. ``encoded`` is ``SOS, symbols..., EOS``;
            ``shifted`` is ``symbols..., EOS, PAD``.
        """
        encoded_string = torch.zeros((len(string)+2,), dtype=torch.long)
        encoded_string[0] = self.SOS
        encoded_string[-1] = self.EOS

        encoded_string_sl = torch.zeros((len(string)+2,), dtype=torch.long)
        encoded_string_sl[-2] = self.EOS
        encoded_string_sl[-1] = self.PAD

        for i, symbol in enumerate(string):
            if not symbol in self.symbol_dict:
                self.symbol_dict[symbol] = len(self.symbol_dict)

            encoded_string[i+1] = self.symbol_dict[symbol]
            encoded_string_sl[i] = self.symbol_dict[symbol]
        return encoded_string, encoded_string_sl

    def get_alphabet_size(self):
        """Return the alphabet size (special tokens not included)."""
        return self.alphabet_size

    def initialize(self, path: str="dataset.pk"):
        """Restore the state written by :meth:`save_state`.

        Loads ``alphabet_size``, ``symbol_dict`` and ``label_dict`` from
        ``path`` and re-derives the SOS/EOS/PAD ids from the loaded alphabet
        size. Call this before :meth:`encode_sequences` so that symbols are
        encoded with the restored mapping.

        NOTE(review): ``self.labels`` was already encoded with the label
        mapping built while reading the data file. If the stored mapping
        differs, the integer labels are not re-encoded here.
        """
        # SECURITY: unpickling can execute arbitrary code; only load state
        # files that come from a trusted source.
        with open(path, "rb") as state_file:
            data = pk.load(state_file)
        self.alphabet_size = data["alphabet_size"]
        self.symbol_dict = data["symbol_dict"]
        self.label_dict = data["label_dict"]
        self._assign_special_tokens()

    def save_state(self, path: str="dataset.pk"):
        """Pickle alphabet size, symbol mapping and label mapping to ``path``."""
        data = dict()
        data["alphabet_size"] = self.alphabet_size
        data["symbol_dict"] = self.symbol_dict
        data["label_dict"] = self.label_dict
        with open(path, "wb") as state_file:
            pk.dump(data, state_file)

# Check if the model learned the sequences properly

In [4]:
# The file holds a whole pickled module (not just a state_dict), so it has to
# be unpickled with weights_only=False. Unpickling can execute arbitrary code:
# only load checkpoints you trust.
# map_location="cpu" keeps the model on the CPU, where the inputs built in the
# cells below live.
model = torch.load(MODEL_NAME, map_location="cpu", weights_only=False)
model.requires_grad_(False)  # inspection only: freeze every parameter
model.eval()

AuTransformer(
  (input_embedding): Embedding(7, 3)
  (encoder): Encoder(
    (pos_encoding): PositionalEncoding()
    (mha): MultiheadAttention(
      (out_proj): NonDynamicallyQuantizableLinear(in_features=3, out_features=3, bias=True)
    )
    (ln): LayerNorm((3,), eps=1e-12, elementwise_affine=True)
  )
  (decoder): Decoder(
    (pos_encoding): PositionalEncoding()
    (masked_mha): MultiheadAttention(
      (out_proj): NonDynamicallyQuantizableLinear(in_features=3, out_features=3, bias=True)
    )
    (ln): LayerNorm((3,), eps=1e-12, elementwise_affine=True)
    (mha): MultiheadAttention(
      (out_proj): NonDynamicallyQuantizableLinear(in_features=3, out_features=3, bias=True)
    )
  )
  (output_fnn): Linear(in_features=3, out_features=7, bias=True)
  (gelu): GELU(approximate='none')
  (dropout): Dropout(p=0.2, inplace=False)
  (softmax_output): Softmax(dim=-1)
  (attention_output_layer): Identity()
  (attention_weight_layer): Identity()
)

In [5]:
# Read all sequences from the training file. initialize() is called with its
# default state file "dataset.pk", after which the encoded tensors are built.
dataset = SequenceDataset(datapath=DATA_PATH, maxlen=10)
dataset.initialize()
dataset.encode_sequences()

Alphabet size:  4
Sequences loaded. Some examples: 
[['a', 'b'], ['a', 'b'], ['c', 'd', 'a', 'b']]
The symbol dictionary: {'a': 0, 'b': 1, 'c': 2, 'd': 3}


In [6]:
# Index of the dataset item that the following cells inspect.
test_idx = 5000

In [7]:
# Unpack the six fields of one dataset item.
(
    ordinal_seq,
    ordinal_seq_sr,
    one_hot_seq,
    one_hot_seq_sr,
    label,
    sequence_length,
) = dataset[test_idx]

In [8]:
# Append a batch dimension of size 1 to both sequences before calling the model.
res = model(ordinal_seq.unsqueeze(-1), ordinal_seq_sr.unsqueeze(-1))

In [9]:
# Shapes of the model output and of the shifted target sequence.
list(res.shape), list(ordinal_seq_sr.shape)

([12, 1, 7], [12])

In [10]:
# Most probable symbol per position next to the expected shifted sequence.
res.squeeze().argmax(dim=1), ordinal_seq_sr

(tensor([2, 3, 2, 3, 2, 3, 5, 6, 6, 6, 6, 6]),
 tensor([2, 3, 2, 3, 2, 3, 5, 6, 6, 6, 6, 6]))

In [11]:
# Full output distribution with the singleton batch dimension removed.
res.squeeze()

tensor([[1.1780e-12, 1.4222e-06, 1.0000e+00, 1.1780e-12, 9.9387e-13, 4.5924e-07,
         1.1780e-12],
        [4.7549e-12, 4.7549e-12, 4.7549e-12, 1.0000e+00, 4.0117e-12, 1.6505e-06,
         1.8015e-07],
        [1.1664e-12, 1.8209e-06, 1.0000e+00, 1.1664e-12, 9.8408e-13, 3.7062e-07,
         1.1664e-12],
        [4.8074e-12, 4.8074e-12, 4.8074e-12, 1.0000e+00, 4.0560e-12, 1.3310e-06,
         2.3238e-07],
        [1.2104e-12, 8.4219e-07, 1.0000e+00, 1.2104e-12, 1.0212e-12, 7.2722e-07,
         1.2104e-12],
        [5.4800e-12, 5.4800e-12, 5.4800e-12, 1.0000e+00, 4.6234e-12, 3.1223e-07,
         1.3728e-06],
        [1.1181e-12, 1.1181e-12, 2.2798e-07, 1.9881e-06, 9.4330e-13, 1.0000e+00,
         1.1181e-12],
        [6.6800e-08, 7.7910e-13, 7.7910e-13, 3.1592e-07, 6.5732e-13, 7.7910e-13,
         1.0000e+00],
        [1.9566e-08, 9.2453e-13, 9.2453e-13, 9.2743e-07, 7.8002e-13, 9.2453e-13,
         1.0000e+00],
        [2.5560e-07, 6.8659e-13, 6.8659e-13, 1.0349e-07, 5.7927e-13, 6.86

# Inspect the internal representations of the model

In [12]:
# Tutorial on hooks: https://web.stanford.edu/~nanbhas/blog/forward-hooks-pytorch/

# Most recent output of every hooked module, keyed by the name passed to
# getActivation.
activation = {}


def getActivation(name):
    """Return a forward hook that stores the module's detached output in ``activation[name]``."""
    # the hook signature
    def hook(model, input, output):
        activation[name] = output.detach()
    return hook


# Keep the handles so the hooks can be removed again. Hooks registered by an
# earlier run of this cell are dropped first, so re-running the cell does not
# stack a second set of hooks on the model.
for _stale_handle in globals().get("hook_handles", {}).values():
    _stale_handle.remove()

# The encoder itself is not hooked: it returns a tuple, which this hook cannot
# detach. If the loaded model matches the AuTransformer defined above, the
# embedding is called for the source and then the target sequence, so
# activation["embedding"] ends up holding the target embedding.
hook_handles = {
    name: module.register_forward_hook(getActivation(name))
    for name, module in (
        ("embedding", model.input_embedding),
        ("attention_output", model.attention_output_layer),
        ("attention_weights", model.attention_weight_layer),
    )
}

<torch.utils.hooks.RemovableHandle at 0x7f938575e200>

In [13]:
# Run one forward pass purely for its side effect: the registered hooks fill
# `activation`. The model output itself is discarded.
with torch.no_grad():
    model(ordinal_seq.unsqueeze(-1), ordinal_seq_sr.unsqueeze(-1))

In [14]:
# List the shape of every captured activation.
for name, t in activation.items():
    shape_as_list = list(t.shape)
    print("{}: {}".format(name, shape_as_list))

embedding: [12, 1, 3]
attention_output: [12, 1, 3]
attention_weights: [1, 12, 12]


In [16]:
# Captured attention weights with the singleton batch dimension removed.
attention_weights = activation["attention_weights"].squeeze()
list(attention_weights.shape), attention_weights

([12, 12],
 tensor([[4.2063e-03, 3.8363e-02, 4.4841e-02, 2.5555e-01, 3.5213e-01, 5.2729e-02,
          1.9417e-01, 1.5719e-03, 4.4500e-03, 1.6906e-02, 2.9645e-02, 5.4369e-03],
         [2.4204e-03, 2.3007e-02, 4.7571e-02, 1.6831e-01, 4.1408e-01, 3.2123e-02,
          2.2197e-01, 1.9804e-03, 5.0910e-03, 3.5400e-02, 3.7290e-02, 1.0764e-02],
         [5.2721e-02, 1.0247e-02, 2.3416e-02, 2.6651e-03, 5.4248e-03, 8.1750e-03,
          8.2890e-03, 3.9164e-01, 1.2639e-01, 1.0442e-01, 3.2908e-02, 2.3371e-01],
         [3.1594e-03, 1.9194e-02, 6.1075e-02, 9.9865e-02, 3.6742e-01, 2.5309e-02,
          2.1938e-01, 6.0317e-03, 1.1222e-02, 9.3844e-02, 5.8515e-02, 3.4986e-02],
         [2.3977e-02, 2.8375e-03, 1.0347e-02, 5.0535e-04, 1.5900e-03, 2.1244e-03,
          2.7387e-03, 4.4890e-01, 1.0146e-01, 1.0176e-01, 1.8106e-02, 2.8565e-01],
         [2.5558e-03, 2.2539e-02, 5.0098e-02, 1.5574e-01, 4.0992e-01, 3.1169e-02,
          2.2374e-01, 2.4103e-03, 5.8696e-03, 4.2100e-02, 4.0614e-02, 1.3246e-02],

# Cluster the representations and express sequences in terms of their clusters. Can you see something?

In [31]:
# Collect attention-weight rows from randomly drawn sequences for clustering.
MAX_CLUSTER_SEQUENCES = 1000  # number of sequences to sample
SHUFFLE_SEED = 0              # fixed seed so every run draws the same sample

weight_chunks = []
train_dataloader = DataLoader(
    dataset,
    batch_size=1,
    shuffle=True,
    generator=torch.Generator().manual_seed(SHUFFLE_SEED),
)
with torch.no_grad():
    for j, (test_string_ord, test_string_ord_sr, _, _, label, sequence_length) in enumerate(train_dataloader):
        if j == MAX_CLUSTER_SEQUENCES:
            break

        # The DataLoader yields [batch, seq]; the model is fed [seq, batch].
        test_string_ord = torch.permute(test_string_ord, dims=[1,0])
        test_string_ord_sr = torch.permute(test_string_ord_sr, dims=[1,0])

        # The forward pass refreshes `activation` through the registered hooks.
        res = torch.squeeze(model(test_string_ord, test_string_ord_sr))

        attn_weights = torch.squeeze(activation["attention_weights"])

        # Keep rows 1..sequence_length of the attention matrix; row 0 (the
        # first input position) and the rows after the sequence are skipped.
        n_rows = int(sequence_length)
        weight_chunks.append(attn_weights[1:n_rows + 1])

# Concatenate once at the end instead of growing a tensor inside the loop.
# `weights` stays None when no sequence was processed.
weights = torch.cat(weight_chunks, dim=0) if weight_chunks else None

In [32]:
# torch.as_tensor accepts both a tensor and an ndarray, so this cell can be
# re-run after `weights` has already been converted without raising.
weights = torch.as_tensor(weights).numpy()
weights.shape

(6856, 12)

In [33]:
from sklearn.cluster import KMeans

# A fixed seed and an explicit number of restarts make the clustering
# reproducible across runs and independent of the scikit-learn default for
# n_init, which differs between versions.
N_CLUSTERS = 4
KMEANS_SEED = 0

kmeans = KMeans(n_clusters=N_CLUSTERS, n_init=10, random_state=KMEANS_SEED)
kmeans.fit(weights)

