In [2]:
import treelstm
import torch
import torch.nn as nn
import torch.nn.functional as F
import json
import pandas as pd
from anytree.importer import JsonImporter
from anytree.exporter import JsonExporter
from anytree import PreOrderIter
from treelstm import TreeLSTM, calculate_evaluation_orders
from tqdm import tqdm
import csv
import time
from torch.utils.data import DataLoader
from itertools import cycle
import sys
from torch import optim
sys.path.append("utils/")
from datasets.AstDataset import AstDataset

from utils.TreeConverter import TreeConverter
from utils.TreePlotter import TreePlotter
from utils.TreePredictionNode import Node




device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
csv.field_size_limit(sys.maxsize)

131072

In [102]:
reserved_tokens_path = '../data/ast_trees_200k/reserved_tokens.json'

# Load the reserved tokens dictionary
with open(reserved_tokens_path, 'r') as json_f:
    json_data = json_f.read()
    
# To JSON format (dictionary)
reserved_tokens = json.loads(json_data)

ast_dataset = AstDataset('../data/ast_trees_200k/asts.csv', vocab_size=len(reserved_tokens), max_tree_size=-1)
# ast_dataset = AstDataset('test_dataset.csv', vocab_size=len(reserved_tokens))

num_workers = 8
loader = DataLoader(ast_dataset, batch_size=32, collate_fn=treelstm.batch_tree_input, num_workers=num_workers)

In [103]:
class TreeLstmEncoder(nn.Module):
    def __init__(self, vocab_size, embedding_dim, hidden_size, latent_size):
        super().__init__() 
        
        self.hidden_size = hidden_size
        self.embedding_dim = embedding_dim
        
        self.embedding = nn.Embedding(vocab_size, embedding_dim)
        self.tree_lstm = TreeLSTM(embedding_dim, hidden_size)
        self.z_mean = nn.Linear(hidden_size, latent_size)
        self.z_log_var = nn.Linear(hidden_size, latent_size)
        
    def forward(self, inp):
        batch_size = len(inp['tree_sizes'])
        
        features = self.embedding(inp['features'].type(torch.LongTensor).to(device)).view(-1, self.embedding_dim)
        hidden, cell = self.tree_lstm(features,
                                      inp['node_order'],
                                      inp['adjacency_list'],
                                      inp['edge_order'])
        
        
        # Take hidden states of roots of trees only -> tree lstm produces hidden states for all nodes in all trees as list
        # hidden roots: (batch_size, hidden_size)
        hidden_roots = torch.zeros(batch_size, self.hidden_size, device=device)
        
        # Offset to check in hidden state, start at zero, increase by tree size each time
        # Example: hidden  [1, 3, 5, 1, 5, 2] and tree_sizes = [4, 2] we want hidden[0] and hidden[4] -> 1, 5
        offset = 0
        for i in range(len(inp['tree_sizes'])):
            hidden_roots[i] = hidden[offset]
            offset += inp['tree_sizes'][i]

        # Get z_mean and z_log_var from hidden (parent roots only)
        z_mean = self.z_mean(hidden_roots)
        z_log_var = self.z_log_var(hidden_roots)
        
        # Parameterization trick
        z = self.reparameterize(z_mean, z_log_var)
        
        # Return latent vector, and mean and variance
        return z, z_mean, z_log_var
        
        
    def reparameterize(self, mu, log_var):
        if self.training:
            std = torch.exp(0.5 * log_var)
            eps = std.data.new(std.size()).normal_()
            return eps.mul(std).add_(mu)
        else:
            return mu
        
        
class TreeLstmDecoder(nn.Module):
    def __init__(self, vocab_size, hidden_size, latent_size):
        super().__init__()
        self.vocab_size = vocab_size
        self.latent_size = latent_size
        
        self.lstm_parent = nn.LSTMCell(vocab_size, latent_size)
        self.U_parent = nn.Linear(latent_size, latent_size)
        self.depth_pred = nn.Linear(latent_size, 1)
        
        self.lstm_sibling = nn.LSTMCell(vocab_size, latent_size)
        self.U_sibling = nn.Linear(latent_size, latent_size)
        self.width_pred = nn.Linear(latent_size, 1)
        
        self.label_prediction = nn.Linear(latent_size, vocab_size)
        
        self.sigmoid = nn.Sigmoid()
        self.softmax = nn.LogSoftmax(dim=1)
        
        self.offset_parent = nn.Linear(1, 1)
        self.offset_sibling = nn.Linear(1, 1)
        
        
    def forward(self, inp_batch, target=None):
        # Initalize output
        output={'predicted_labels': [], 'labels': [], 'predicted_has_siblings': [],
                'has_siblings': [], 'predicted_is_parent': [], 'is_parent': []}
        
        trees = []
        
        # Build tree by tree in batch
        for index, z in enumerate(inp_batch):
            # Initialize hidden_parent values
            hidden_parent = (z.unsqueeze(0), torch.zeros(self.latent_size).unsqueeze(0).to(device))
          
            if self.training:
                start, end = target['tree_sizes'][index - 1] if index -1 > 0 else 0, target['tree_sizes'][index]  
                trees.append(self.build_tree(hidden_parent, None, output, None, target['features'], target['adjacency_list'][start: end]))
            else:
                trees.append(self.build_tree(hidden_parent, None, output))
            
        return trees, output
            
        
    def build_tree(self, hidden_parent, hidden_sibling, output, parent_node=None,
                   features=None, adjacency_list=None, siblings=[], index=0):
        
        
        # Split hidden parent, into state and cell
        hidden_state_parent, hidden_cell_parent = hidden_parent      
        # Run hidden parent state through U_parent
        U_parent = self.U_parent(hidden_state_parent)
        
        # If there was a previous sibling, calculate U_sibling, otherwise use 0 to not include this value
        if hidden_sibling is not None:
            hidden_state_sibling, hidden_cell_sibling = hidden_sibling
            U_sibling = self.U_sibling(hidden_state_sibling)
        else:
            U_sibling = 0
        
        # tanh(U_parent + U_sibling)
        h_pred = torch.tanh(U_parent + U_sibling)
        label_pred = self.label_prediction(h_pred)
        
        # Probability of the node having children
        p_parent = self.sigmoid(self.depth_pred(h_pred))
        # Probability of the node having successor children
        p_sibling = self.sigmoid(self.width_pred(h_pred))
        
        # Teacher forcing on is_parent, has_sibling
        if self.training:
            label, child_indices = self.get_truth_values(features, adjacency_list, index)
            is_parent = torch.tensor([1], device=device, dtype=torch.float32) if len(child_indices) > 0 else torch.tensor([0], device=device, dtype=torch.float32)
            has_sibling = torch.tensor([1], device=device, dtype=torch.float32) if len(siblings) > 1 else torch.tensor([0], device=device, dtype=torch.float32)
        else:
            # Sample is_parent and has_sibling from predicted probability of parent/sibling
            is_parent = torch.distributions.bernoulli.Bernoulli(p_parent).sample()
            has_sibling = torch.distributions.bernoulli.Bernoulli(p_sibling).sample()
            
            # Could also simply use > 0.5 instead OR TODO BEAM SEARCH
            # is_parent = torch.tensor(1) if p_parent > 0.5 else torch.tensor(0)
            # has_sibling = torch.tensor(1) if p_sibling > 0.5 else torch.tensor(0)
            
        
        # Node label prediction
        predicted_label = self.softmax(label_pred + self.offset_parent(is_parent) + self.offset_sibling(has_sibling))
        
        # Build tree: Add node to tree
        if parent_node is None:
            node = Node(predicted_label, parent=None)
        else:
            node = Node(predicted_label, parent=parent_node)
            
        # For computing loss, save output (predictions and true values)
        output['predicted_labels'].append(predicted_label.tolist())
        output['labels'].append(label.tolist())
        output['predicted_has_siblings'].append(p_sibling.tolist())
        output['has_siblings'].append(has_sibling.tolist())
        output['predicted_is_parent'].append(p_parent.tolist())
        output['is_parent'].append(is_parent.tolist())
            
        
        # Teacher forcing on label
        if self.training:
            label = F.one_hot(label.long(), self.vocab_size).float()
        else:
            label = predicted_label
            
        
        if has_sibling:
            if hidden_sibling is not None:
                hidden_sibling = self.lstm_sibling(label, hidden_sibling)
            else:
                hidden_sibling = self.lstm_sibling(label)
            
            if self.training:
                siblings.pop(0)
                self.build_tree(hidden_parent, hidden_sibling, output, parent_node, features, adjacency_list, siblings, siblings[0])
                
            else:
                self.build_tree(hidden_parent, hidden_sibling, output, parent_node)
            
        if is_parent:
            hidden_parent = self.lstm_parent(label, hidden_parent)
            parent_node = node
            
            if self.training:
                siblings = list(child_indices)
                self.build_tree(hidden_parent, None, output, parent_node, features, adjacency_list, siblings, siblings[0])
                
            else:
                self.build_tree(hidden_parent, None, output, parent_node)
                
        return parent_node
        
        
    def get_truth_values(self, features, adjacency_list, index):
        adjacency_list_current = adjacency_list[adjacency_list[:, 0] == index]

        x = features[index, :]
        child_indices = adjacency_list_current[:, 1]

        return x, child_indices
            
        
        
class VaeLoss(nn.Module):
    
    def __init__(self):
        super().__init__()

    def forward(self, output, mu, log_var, vocab_size):  
        predicted_labels = torch.tensor(output['predicted_labels']).view(-1, vocab_size).to(device)
        labels = torch.tensor(output['labels']).view(-1).long().to(device)
#         reconstructed_tree_labels = self.flatten_tree(reconstructed_tree).view(-1, vocab_size).to(device)
#         tree_labels = tree_labels.view(-1).long()

        # Negative log likelihood loss (categorical cross entropy)
        label_prediction_loss = F.nll_loss(predicted_labels, labels)

        # Binary cross entropy loss for parent and sibling predictions (topology)
        parent_loss = nn.BCELoss()(torch.tensor(output['predicted_is_parent']).view(-1), torch.tensor(output['is_parent']).view(-1))
        sibling_loss = nn.BCELoss()(torch.tensor(output['predicted_has_siblings']).view(-1), torch.tensor(output['has_siblings']).view(-1))
        
        reconstruction_loss =  label_prediction_loss + parent_loss + sibling_loss
        
        kl_loss = 0.5 * torch.sum(log_var.exp() - log_var - 1 + mu.pow(2))
                
        return reconstruction_loss + kl_loss
    
    def flatten_tree(self, tree):
        nodes = []
        for child in PreOrderIter(reconstructed_tree):
            nodes.append(child.pred.tolist())
        
        return torch.tensor(nodes)
        

In [104]:
encoder = TreeLstmEncoder(len(reserved_tokens), 30, 256, 128).to(device)
decoder = TreeLstmDecoder(len(reserved_tokens), 256, 128).to(device)

In [None]:
encoder.train()
decoder.train()

encoder_optimizer = optim.Adam(encoder.parameters(), lr=0.001)
decoder_optimizer = optim.Adam(decoder.parameters(), lr=0.001)
vae_loss = VaeLoss()

pbar = tqdm(unit='batch')
for batch in loader:
    encoder_optimizer.zero_grad()
    decoder_optimizer.zero_grad()
    
    for key in batch.keys():
        if key != 'tree_sizes':
            batch[key] = batch[key].to(device)
            
    z, z_mean, z_log_var = encoder(batch)
    reconstructed_tree, output = decoder(z, batch)
    
    loss = vae_loss(output, z_mean, z_log_var, len(reserved_tokens))
    loss.backward()    
    encoder_optimizer.step()
    decoder_optimizer.step()
    pbar.set_postfix(loss=round(loss.item(), 3))
    pbar.update()
#     tree = retrieve_tree(batch['features'], batch['adjacency_list'], 0)

58batch [01:17,  1.34s/batch, loss=6.77]
58batch [01:45,  1.81s/batch, loss=6.77]
58batch [01:46,  1.84s/batch, loss=6.77]
58batch [01:50,  1.90s/batch, loss=6.77]
58batch [01:53,  1.96s/batch, loss=6.77]
58batch [01:55,  1.99s/batch, loss=6.77]
58batch [01:56,  2.01s/batch, loss=6.77]
58batch [02:20,  2.43s/batch, loss=6.77]
58batch [02:22,  2.46s/batch, loss=6.77]
271batch [04:49,  1.43s/batch, loss=6.39]

In [36]:
for batch in loader:
    for key in batch.keys():
        if key != 'tree_sizes':
            batch[key] = batch[key].to(device)
            
    z, z_mean, z_log_var = encoder(batch)
    reconstructed_tree = decoder(z, batch)
    break

In [94]:
for index, batch in enumerate(loader):
    if index == 2:
        print(batch['adjacency_list'])
        print(batch['features'][602:])
        print(batch['tree_sizes'])
        break

tensor([[  0,   1],
        [  1,   2],
        [  1,   3],
        ...,
        [860, 861],
        [845, 862],
        [862, 863]])
tensor([[ 0.],
        [ 7.],
        [ 8.],
        [ 9.],
        [ 7.],
        [ 8.],
        [ 9.],
        [ 7.],
        [ 8.],
        [ 9.],
        [15.],
        [ 5.],
        [16.],
        [26.],
        [ 3.],
        [ 4.],
        [ 5.],
        [26.],
        [ 3.],
        [ 4.],
        [ 5.],
        [11.],
        [27.],
        [49.],
        [17.],
        [ 6.],
        [11.],
        [23.],
        [ 6.],
        [19.],
        [ 2.],
        [ 3.],
        [ 4.],
        [ 5.],
        [25.],
        [ 5.],
        [30.],
        [17.],
        [59.],
        [17.],
        [ 6.],
        [27.],
        [60.],
        [17.],
        [ 6.],
        [11.],
        [23.],
        [56.],
        [56.],
        [17.],
        [17.],
        [17.],
        [11.],
        [23.],
        [56.],
        [17.],
        [17.],
        [15

In [774]:
reconstructed_tree.pred

tensor([[-5.0833, -5.1485, -4.9925, -4.8512, -4.8439, -5.1410, -4.8637, -5.3190,
         -5.0486, -4.9787, -4.2735, -4.6612, -5.0867, -4.9258, -5.1109, -5.1695,
         -4.6145, -5.0211, -5.2556, -5.0707, -4.8873, -4.6784, -4.7922, -5.0320,
         -5.0267, -5.0366, -5.1499, -4.8833, -5.2060, -4.8906, -4.8833, -4.7884,
         -5.0989, -5.1590, -5.6312, -5.0200, -5.0279, -5.2376, -4.4423, -5.0881,
         -5.2620, -5.0520, -5.3334, -5.0406, -5.0080, -4.9668, -5.3711, -5.0183,
         -5.1444, -5.1024, -5.2567, -4.6066, -5.6135, -4.9749, -4.5158, -5.7996,
         -5.0591, -4.9957, -4.4560, -4.8711, -4.8767, -4.9055, -5.2220, -4.6838,
         -5.0249, -4.6379, -4.5016, -4.7282, -4.9555, -5.1179, -4.6750, -4.7107,
         -5.0069, -4.9236, -4.8821, -4.9246, -4.5893, -5.2496, -4.8997, -4.5019,
         -5.0527, -5.1799, -4.7794, -4.6719, -5.1771, -5.1550, -5.3076, -4.9797,
         -5.0670, -4.6437, -5.0867, -5.4347, -5.2504, -5.3512, -4.5376, -4.8877,
         -4.9932, -5.0787, -

In [636]:
def retrieve_tree(features, adjacency_list, index, parent_node=None):
    adjacency_list_current = adjacency_list[adjacency_list[:, 0] == index]

    
    parent_node = Node(features[index, :], parent=parent_node)
    child_indices = adjacency_list_current[:, 1]

    for child_index in child_indices:
        retrieve_tree(features, adjacency_list, child_index, parent_node)
        
    return parent_node

In [475]:
reader = csv.reader(open('../data/ast_trees_200k/asts.csv'))

next(reader)

TreePlotter.plot_tree(JsonImporter().import_(next(reader)[1]), 'first_tree.png')


(<utils.TreePredictionNode.Node at 0x7fb669acf2e0>,
 <utils.TreePredictionNode.Node at 0x7fb669acf490>,
 <utils.TreePredictionNode.Node at 0x7fb669acf730>)