In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch import optim
from tqdm import tqdm
import json
from anytree.importer import JsonImporter
from anytree import PreOrderIter
import pandas as pd
import sys
sys.path.append("utils/")
from utils.TreePlotter import TreePlotter
from utils.TreeConverter import TreeConverter
import math
# from torchtext.data import Field, BucketIterator

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# device = 'cpu'

In [224]:
def prepare_data():
    reserved_tokens_path = '../data/ast_trees/reserved_tokens.json'
    asts_path = '../data/ast_trees/asts.csv.bz2'

    with open(reserved_tokens_path, 'r') as json_f:
        json_data = json_f.read()

    reserved_tokens = json.loads(json_data)

    asts = pd.read_csv(asts_path, chunksize=1e4)

    importer = JsonImporter()
    
#     reserved_tokens['<sos>'] = len(reserved_tokens)
#     reserved_tokens['<eos>'] = len(reserved_tokens)
    reserved_tokens['<pad>'] = len(reserved_tokens)


    data = []
    try:
        for ast_chunk in asts:
            for ast in ast_chunk['AST']:
                tree = importer.import_(ast)
                tree_repr = torch.tensor([node.token for node in PreOrderIter(tree) if node.res])
                if len(tree_repr) < 300 and len(tree_repr) > 1:
                    data.append(tree_repr)
    except Exception:
        pass
    
    # Sort trees such that during batching we have to pad the batches minimally
    data.sort(key=len)

#     print(torch.stack(trees))
#     data = torch.tensor(trees)
    
    return data, reserved_tokens

data, reserved_tokens = prepare_data()    

                                                  

In [13]:
def preprocess_batch(batch):
#     batch = [torch.tensor([reserved_tokens['<sos>']] + item + [reserved_tokens['<eos>']]) for item in batch]
    batch = nn.utils.rnn.pad_sequence(batch, batch_first=True, padding_value=reserved_tokens['<pad>'])
    
    return batch.to(device)


In [628]:
class EncoderVAE(nn.Module):
    def __init__(self, vocab_size, embedding_size, hidden_size, latent_size, dropout, bidirectional=False):
        super().__init__()       
        
        # Variables
        self.latent_size = latent_size
        self.bidirectional = bidirectional
        self.hidden_size = hidden_size

        # Layers
#         self.embedding = nn.Embedding(vocab_size, embedding_size, padding_idx=reserved_tokens['<pad>'])
#         self.dropout = nn.Dropout(dropout)
        self.lstm = nn.LSTM(embedding_size, hidden_size, bidirectional=self.bidirectional, batch_first=True)
        self.z_mean = nn.Linear(hidden_size, self.latent_size)
        self.z_log_var = nn.Linear(hidden_size, self.latent_size)
        
        
    def reparameterize(self, mu, log_var):
        if self.training:
            std = torch.exp(0.5 * log_var)
            eps = std.data.new(std.size()).normal_()
            return eps.mul(std).add_(mu)
        else:
            return mu
        
    def forward(self, inp):
        # inp shape: (batch_size, seq_len, emb_size)
        
        [batch_size, seq_len, emb_size] = inp.size()
        
#         embedding = self.dropout(self.embedding(inp))

        inp = torch.nn.utils.rnn.pack_padded_sequence(inp, [len(el) for el in inp], batch_first=True)
        
        # Get final hidden state from the last layer
        outputs, (hidden, cell) = self.lstm(inp)
        
        outputs, _ = torch.nn.utils.rnn.pad_packed_sequence(outputs, batch_first=True)
        
        outputs = outputs[:, -1, :].view(1, batch_size, -1)
#         print(outputs.shape)
#         print(hidden.shape)
#         print(cell.shape)
        
#         print(cell.view(1, batch_size, self.hidden_size).shape)
        
#         last_hidden_state = hidden[:, -1]
        
#         z_mean = self.z_mean(last_hidden_state)
#         z_log_var = self.z_log_var(last_hidden_state)

        z_mean = self.z_mean(outputs)
        z_log_var = self.z_log_var(outputs)
        
        z = self.reparameterize(z_mean, z_log_var)
        
        return z, z_mean, z_log_var

In [629]:
class DecoderVAE(nn.Module):
    def __init__(self, latent_size, hidden_size, embed_dim, output_size):
        super().__init__()
        
        self.latent_size = latent_size
        self.hidden_size = hidden_size
        
        self.lstm = nn.LSTM(latent_size + embed_dim, hidden_size, batch_first=True)
        self.fc = nn.Linear(hidden_size, output_size)
        self.softmax = nn.LogSoftmax(dim=2)
        
    def forward(self, z, seq_len, target, embedding):
        # z shape: (1, batch_size, latent_variable_size)
        
        [_, batch_size, latent_variable_size] = z.size()
                
        # Replicate z to append the same z at each time step
#         z = torch.cat([z] * seq_len, 1).view(batch_size, seq_len, self.latent_size) 

#         z.squeeze(0)

#         If we are not training
        if not self.training:
#             z = torch.cat([z] * seq_len, 1).view(batch_size, seq_len, self.latent_size)
            z = z.view(batch_size, 1, self.latent_size)
            out = torch.zeros(seq_len, batch_size)
        
            hidden = None
            cell = None

            
            for i in range(seq_len):
                decoder_inp = torch.cat([target, z], 2)
                output, (hidden, cell) = self.lstm(decoder_inp, None if hidden is None else (hidden, cell))
                out_seq = self.softmax(self.fc(output))
                prediction = torch.argmax(out_seq, dim=2)
                            
                out[i] = prediction.view(batch_size)
            
                target = embedding(prediction)
            
            
#             print(torch.tensor(outputs).shape)
            print(out.view(batch_size, -1))
#             outputs = outputs.view(batch_size, seq_len, -1)
            
        else:
            z = torch.cat([z] * seq_len, 1).view(batch_size, seq_len, self.latent_size)
#             print(z.shape)
#             print(target.shape)
            decoder_inp = torch.cat([target, z], 2)
            
#             print(decoder_inp.shape)
            
            outputs, _ = self.lstm(decoder_inp)

#         outputs, _ = self.lstm(z)


            out = self.softmax(self.fc(outputs))
            # shape of out: (batch_size, seq_len, output_size)
        
        return out
        
            

In [630]:
class VAE(nn.Module):
    def __init__(self, encoder, decoder, vocab_size, embed_dim, dropout):
        super().__init__()   
        self.embed_dim = embed_dim
        self.embedding = nn.Embedding(vocab_size, embed_dim, padding_idx=reserved_tokens['<pad>'])
        self.dropout = nn.Dropout(dropout)
        self.encoder = encoder
        self.decoder = decoder
            
    def forward(self, inp, train=False):
        [batch_size, seq_len] = inp.size()
                
        embedded = self.dropout(self.embedding(inp))
        # embedded shape: (batch_size, seq_length, embedding_size)
               
        
        z, z_mean, z_log_var = self.encoder(embedded)
        
        # Remove the last token of the embedded as we do not need to predict after this
        embedded = embedded[:, :-1,:].view(batch_size, seq_len -1, self.embed_dim)
        
        if not train:
            # Get only first token of input which is always root, when not training
            embedded = embedded[:, 0,:].view(batch_size, 1, self.embed_dim) 
                                    
        dec_output = self.decoder(z, seq_len - 1, embedded, self.embedding)
        
        return dec_output, z_mean, z_log_var
        
            
    def loss(self, dec_output, target, mu, log_var):
        # Remove first token as this is always the root and given
        
        target = target[:, 1:].contiguous()
        
        # Calculate latent loss/reconstruction loss
        # How far is the reconstruction from the original input
        # reconstruction shape: (batch_size, seq_len, vocab_len) -> (batch_size * seq_len, vocab_len)
        # target shape: (batch_size, seq_len) -> (batch_size * seq_len)
        reconstruction_loss = F.nll_loss(
            m(dec_output).view(-1, len(reserved_tokens)),
            target.view(-1),
            ignore_index=reserved_tokens['<pad>']
        )
        
        # Get kullback leichner loss
        kl_loss = 0.5 * torch.sum(log_var.exp() - log_var - 1 + mu.pow(2))
        
#         print(reconstruction_loss, kl_loss)
        
        return reconstruction_loss + kl_loss
    
    def train(self, data, epochs, batch_size):
        self.encoder.train()
        self.decoder.train()
        
        encoder_optimizer = optim.Adam(self.encoder.parameters(), lr=0.001)
        decoder_optimizer = optim.Adam(self.decoder.parameters(), lr=0.001)
        
        
        for epoch in range(epochs):
            pbar = tqdm(range(math.ceil(len(data)/batch_size)), unit='batch', desc=f'Epoch: {epoch}', position=0)
            for i in range(0, len(data), batch_size):
                encoder_optimizer.zero_grad()
                decoder_optimizer.zero_grad()

                batch = preprocess_batch(data[i: i + batch_size])

                dec_output, z_mean, z_log_var = self(batch, train=True)
                loss = self.loss(dec_output, batch, z_mean, z_log_var)
                loss.backward()
                encoder_optimizer.step()
                decoder_optimizer.step()
                pbar.set_postfix(loss=round(loss.item(), 3))
                pbar.update()
            print(f'Epoch {epoch} loss: {loss}')
        
        
    def generate(self, data, batch_size):
        self.encoder.eval()
        self.decoder.eval()
        
        pbar = tqdm(range(math.ceil(len(data)/batch_size)), unit='batch')
        for i in range(0, len(data), batch_size):
            batch = preprocess_batch(data[i: i + batch_size])
            self(batch)
            pbar.update()
        
        

In [631]:
vocab_size = len(reserved_tokens)
embed_dim = 300
lstm_hidden_size = 256
latent_size = 128
dropout = 0.5

encoder = EncoderVAE(vocab_size, embed_dim, lstm_hidden_size, latent_size, dropout).to(device)
decoder = DecoderVAE(latent_size, lstm_hidden_size, embed_dim, vocab_size).to(device)

vae = VAE(encoder, decoder, vocab_size, embed_dim, dropout).to(device)

In [None]:
vae.train(data, 20, 64)

In [634]:
vae.generate(data[:16], 16)


  0%|          | 0/1 [00:00<?, ?batch/s][A

tensor([[11., 11., 11., 11., 11., 11., 11., 11., 11., 11., 11., 11., 11., 11.,
         11., 11.],
        [ 8.,  8.,  8.,  8.,  8.,  8.,  8.,  8.,  8.,  8.,  8.,  8.,  8.,  8.,
          8.,  8.],
        [12., 12., 12., 12., 12., 12., 12., 12., 12., 12., 12., 12., 12., 12.,
         12., 12.],
        [13., 13., 13., 13., 13., 13., 13., 13., 13., 13., 13., 13., 13., 13.,
         13., 13.],
        [18., 18., 18., 18., 18., 18., 18., 18., 18., 18., 18., 18., 18., 18.,
         18., 18.],
        [ 5.,  5.,  5.,  5.,  5.,  5.,  5.,  5.,  5.,  5.,  5.,  5.,  5.,  5.,
          5.,  5.],
        [ 6.,  6.,  6.,  6.,  6.,  6.,  6.,  6.,  6.,  6.,  6.,  6.,  6.,  6.,
          6.,  6.],
        [ 7.,  7.,  7.,  7.,  7.,  7.,  7.,  7.,  7.,  7.,  7.,  7.,  7.,  7.,
          7.,  7.],
        [ 8.,  8.,  8.,  8.,  8.,  8.,  8.,  8.,  8.,  8.,  8.,  8.,  8.,  8.,
          8.,  8.],
        [ 5.,  5.,  5.,  5.,  5.,  5.,  5.,  5.,  5.,  5.,  5.,  5.,  5.,  5.,
          5.,  5.],
        [ 

In [607]:
data[:16]

[tensor([ 0, 11,  8, 12, 13, 41,  9]),
 tensor([ 0, 11,  8, 12, 13, 14,  8, 16,  4]),
 tensor([ 0, 11,  8, 12, 13, 14,  8, 16,  4]),
 tensor([ 0, 11,  8, 12, 13, 41, 14,  8, 16,  4]),
 tensor([ 0, 11,  8, 12, 13, 18,  5,  6,  7,  8]),
 tensor([ 0, 11,  8, 12, 13, 14,  8, 16,  4, 41,  9]),
 tensor([ 0, 11,  8, 12, 13, 14,  8, 16,  4, 41,  9]),
 tensor([ 0, 11,  8, 12, 13, 14,  8, 16,  4, 41,  9]),
 tensor([ 0, 11,  8, 12, 13, 14,  8, 16,  4, 41,  9]),
 tensor([ 0, 11,  8, 12, 13, 18,  5,  6,  7,  8, 41,  9]),
 tensor([ 0, 11,  8, 12, 13, 14,  8, 16,  4, 41, 33,  9]),
 tensor([ 0,  5,  6,  7,  8, 11,  8, 12, 13, 18,  5,  6,  7,  8]),
 tensor([ 0,  5,  6,  7,  8,  5,  6,  7,  8, 11,  8, 12, 13, 41,  9]),
 tensor([ 0, 11,  8, 12, 13, 18,  5,  6,  7,  8,  5,  6,  7,  8, 41,  9]),
 tensor([ 0, 11,  8, 12, 13, 14,  8, 15, 63, 16, 14,  8, 16,  4, 41,  9]),
 tensor([ 0, 11,  8, 12, 13, 18,  5,  6,  7,  8,  9,  5,  6,  7,  8, 41,  9])]

In [448]:
idx = 4

encoder.eval()
decoder.eval()

z, _, _ = encoder(data[idx].view(1, -1).to(device))
reconstruction = torch.argmax(decoder(z, len(data[idx], data[idx])), dim=2)

print(data[idx])
print(reconstruction)

ValueError: not enough values to unpack (expected 3, got 2)

In [353]:
len(data[idx])

35

In [281]:
decoder(z, len(data[-1])).shape

Epoch: 0: 100%|██████████| 213/213 [04:16<00:00,  1.20s/batch, loss=3.91]
Epoch: 1: 100%|██████████| 213/213 [03:54<00:00,  1.10s/batch, loss=3.27]
Epoch: 2: 100%|██████████| 213/213 [03:33<00:00,  1.00s/batch, loss=3.1]
Epoch: 3: 100%|██████████| 213/213 [03:12<00:00,  1.11batch/s, loss=3.02]
Epoch: 4: 100%|██████████| 213/213 [02:51<00:00,  1.25batch/s, loss=2.97]
Epoch: 5: 100%|██████████| 213/213 [02:29<00:00,  1.42batch/s, loss=2.96]
Epoch: 6: 100%|██████████| 213/213 [02:08<00:00,  1.66batch/s, loss=2.95]
Epoch: 7: 100%|██████████| 213/213 [01:47<00:00,  1.99batch/s, loss=2.95]
Epoch: 8: 100%|██████████| 213/213 [01:25<00:00,  2.48batch/s, loss=2.96]
Epoch: 9: 100%|██████████| 213/213 [01:04<00:00,  3.30batch/s, loss=2.97]


torch.Size([1, 299, 136])

In [290]:
reserved_tokens

{'root': 0,
 'TYPEDEF_DECL': 1,
 'TYPE_DEF': 2,
 'IDENTIFIER': 3,
 'NAMESPACE_REF': 4,
 'VAR_DECL': 5,
 'TYPE': 6,
 'DECLARATOR': 7,
 'NAME': 8,
 'INTEGER_LITERAL': 9,
 'UNARY_OPERATOR_-': 10,
 'FUNCTION_DECL': 11,
 'RETURN_TYPE': 12,
 'COMPOUND_STMT': 13,
 'CALL_EXPR': 14,
 'ARGUMENTS': 15,
 'TYPE_REF': 16,
 'DECL_REF_EXPR': 17,
 'DECL_STMT': 18,
 'BINARY_OPERATOR_+': 19,
 'FOR_STMT': 20,
 'BINARY_OPERATOR_=': 21,
 'BINARY_OPERATOR_<': 22,
 'UNARY_OPERATOR_POST_++': 23,
 'ARRAY_SUBSCRIPT_EXPR': 24,
 'BINARY_OPERATOR_!=': 25,
 'CHARACTER_LITERAL': 26,
 'NULL_STMT': 27,
 'STRING_LITERAL': 28,
 'WHILE_STMT': 29,
 'UNARY_OPERATOR_!': 30,
 'MEMBER_REF_EXPR': 31,
 'BINARY_OPERATOR_%': 32,
 'PAREN_EXPR': 33,
 'BINARY_OPERATOR_*': 34,
 'IF_STMT': 35,
 'BINARY_OPERATOR_&&': 36,
 'BINARY_OPERATOR_<=': 37,
 'CSTYLE_CAST_EXPR': 38,
 'CXX_FOR_RANGE_STMT': 39,
 'PARM_DECL': 40,
 'RETURN_STMT': 41,
 'BINARY_OPERATOR_==': 42,
 'UNARY_OPERATOR_&': 43,
 'FLOATING_LITERAL': 44,
 'STRUCT_DECL': 45,
 'FIE

In [12]:
TreePlotter.plot_tree(root, 'tree.png')

In [13]:
binary_tree = TreeConverter.to_binary(root)

In [14]:
TreePlotter.plot_tree(binary_tree, 'binary_tree.png', binary=True)

In [18]:
tree_repr = torch.tensor([node.token for node in PreOrderIter(root) if node.res])

In [18]:
root.children[0].children[1].children[0].children[0]

AnyNode(res=False, token='')