In [1]:
%load_ext autoreload
%autoreload 2
import treelstm
import torch
import torch.nn as nn
import torch.nn.functional as F
import json
import pandas as pd
from anytree.importer import JsonImporter
from anytree.exporter import JsonExporter
from anytree import PreOrderIter
from treelstm import TreeLSTM, calculate_evaluation_orders
from tqdm import tqdm
import csv
import time
from torch.utils.data import DataLoader
from itertools import cycle
import sys
from torch import optim
sys.path.append("utils/")
from datasets.AstDataset import AstDataset

from utils.TreeConverter import TreeConverter
from utils.TreePlotter import TreePlotter
from utils.TreePredictionNode import Node




device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
csv.field_size_limit(sys.maxsize)

131072

In [26]:
reserved_tokens_path = '../data/ast_trees_200k/reserved_tokens.json'

# Load the reserved tokens dictionary
with open(reserved_tokens_path, 'r') as json_f:
    json_data = json_f.read()
    
# To JSON format (dictionary)
reserved_tokens = json.loads(json_data)

ast_dataset = AstDataset('../data/ast_trees_200k/asts.csv', vocab_size=len(reserved_tokens), max_tree_size=-1)
# ast_dataset = AstDataset('test_dataset.csv', vocab_size=len(reserved_tokens))

num_workers = 0
loader = DataLoader(ast_dataset, batch_size=1, collate_fn=treelstm.batch_tree_input, num_workers=num_workers)

In [27]:
class TreeLstmEncoder(nn.Module):
    def __init__(self, vocab_size, embedding_dim, hidden_size, latent_size):
        super().__init__() 
        
        self.hidden_size = hidden_size
        self.embedding_dim = embedding_dim
        
        self.embedding = nn.Embedding(vocab_size, embedding_dim)
        self.tree_lstm = TreeLSTM(embedding_dim, hidden_size)
        self.z_mean = nn.Linear(hidden_size, latent_size)
        self.z_log_var = nn.Linear(hidden_size, latent_size)
        
    def forward(self, inp):
        batch_size = len(inp['tree_sizes'])
        
        features = self.embedding(inp['features'].type(torch.LongTensor).to(device)).view(-1, self.embedding_dim)
        hidden, cell = self.tree_lstm(features,
                                      inp['node_order'],
                                      inp['adjacency_list'],
                                      inp['edge_order'])
        
        
        # Take hidden states of roots of trees only -> tree lstm produces hidden states for all nodes in all trees as list
        # hidden roots: (batch_size, hidden_size)
        hidden_roots = torch.zeros(batch_size, self.hidden_size, device=device)
        
        # Offset to check in hidden state, start at zero, increase by tree size each time
        # Example: hidden  [1, 3, 5, 1, 5, 2] and tree_sizes = [4, 2] we want hidden[0] and hidden[4] -> 1, 5
        offset = 0
        for i in range(len(inp['tree_sizes'])):
            hidden_roots[i] = hidden[offset]
            offset += inp['tree_sizes'][i]

        # Get z_mean and z_log_var from hidden (parent roots only)
        z_mean = self.z_mean(hidden_roots)
        z_log_var = self.z_log_var(hidden_roots)
        
        # Parameterization trick
        z = self.reparameterize(z_mean, z_log_var)
        
        # Return latent vector, and mean and variance
        return z, z_mean, z_log_var
        
        
    def reparameterize(self, mu, log_var):
        if self.training:
            std = torch.exp(0.5 * log_var)
            eps = std.data.new(std.size()).normal_()
            return eps.mul(std).add_(mu)
        else:
            return mu
        
        
class TreeLstmDecoder(nn.Module):
    def __init__(self, vocab_size, hidden_size, latent_size):
        super().__init__()
        self.vocab_size = vocab_size
        self.latent_size = latent_size
        
        self.lstm_parent = nn.LSTMCell(vocab_size, latent_size)
        self.U_parent = nn.Linear(latent_size, latent_size)
        self.depth_pred = nn.Linear(latent_size, 1)
        
        self.lstm_sibling = nn.LSTMCell(vocab_size, latent_size)
        self.U_sibling = nn.Linear(latent_size, latent_size)
        self.width_pred = nn.Linear(latent_size, 1)
        
        self.label_prediction = nn.Linear(latent_size, vocab_size)
        
        self.sigmoid = nn.Sigmoid()
        self.softmax = nn.LogSoftmax(dim=1)
        
        self.offset_parent = nn.Linear(1, 1)
        self.offset_sibling = nn.Linear(1, 1)
        
        
    def forward(self, inp_batch, target=None):
        # Initalize output
        output={'predicted_labels': torch.zeros(sum(target['tree_sizes']), self.vocab_size, device=device),
                'labels': torch.zeros(sum(target['tree_sizes']), device=device),
                'predicted_has_siblings': torch.zeros(sum(target['tree_sizes']),  device=device),
                'has_siblings': torch.zeros(sum(target['tree_sizes']), device=device),
                'predicted_is_parent': torch.zeros(sum(target['tree_sizes']), device=device),
                'is_parent': torch.zeros(sum(target['tree_sizes']), device=device)}
        
        trees = []
        offset = 0
        counter = TreeNodeCounter()
                
        # Build tree by tree in batch
        for index, z in enumerate(inp_batch):
            # Initialize hidden_parent values
            hidden_parent = (z.unsqueeze(0), torch.zeros(self.latent_size).unsqueeze(0).to(device))
          
            if self.training:
                tree_size = target['tree_sizes'][index]
                
                adjacency_list_tree = batch['adjacency_list'][(batch['adjacency_list'][:,0] >= offset) & (batch['adjacency_list'][:,0] < offset + tree_size)]
                                
                tree, counter = self.build_tree(hidden_parent, None, output, None, counter, target['features'], adjacency_list_tree, node_index=offset)
                offset += tree_size
            else:
                tree, counter = self.build_tree(hidden_parent, None, output, None, counter)
                
            trees.append(tree)
            counter.increase()
                            
        return trees, output
            
        
    def build_tree(self, hidden_parent, hidden_sibling, output, parent_node=None, index=0,
                   features=None, adjacency_list=None, siblings=[], node_index=0):
                
        # Split hidden parent, into state and cell
        hidden_state_parent, hidden_cell_parent = hidden_parent      
        # Run hidden parent state through U_parent
        U_parent = self.U_parent(hidden_state_parent)
        
        # If there was a previous sibling, calculate U_sibling, otherwise use 0 to not include this value
        if hidden_sibling is not None:
            hidden_state_sibling, hidden_cell_sibling = hidden_sibling
            U_sibling = self.U_sibling(hidden_state_sibling)
        else:
            U_sibling = 0
        
        # tanh(U_parent + U_sibling)
        h_pred = torch.tanh(U_parent + U_sibling)
        label_pred = self.label_prediction(h_pred)
        
        # Probability of the node having children
        p_parent = self.sigmoid(self.depth_pred(h_pred))
        # Probability of the node having successor children
        p_sibling = self.sigmoid(self.width_pred(h_pred))
        
        # Teacher forcing on is_parent, has_sibling
        if self.training:
            label, child_indices = self.get_truth_values(features, adjacency_list, node_index)
            is_parent = torch.tensor([1], device=device, dtype=torch.float32) if len(child_indices) > 0 else torch.tensor([0], device=device, dtype=torch.float32)
            has_sibling = torch.tensor([1], device=device, dtype=torch.float32) if len(siblings) > 1 else torch.tensor([0], device=device, dtype=torch.float32)
        else:
            # Sample is_parent and has_sibling from predicted probability of parent/sibling
            is_parent = torch.distributions.bernoulli.Bernoulli(p_parent).sample()
            has_sibling = torch.distributions.bernoulli.Bernoulli(p_sibling).sample()
            
            # Could also simply use > 0.5 instead OR TODO BEAM SEARCH
            # is_parent = torch.tensor(1) if p_parent > 0.5 else torch.tensor(0)
            # has_sibling = torch.tensor(1) if p_sibling > 0.5 else torch.tensor(0)
            
        
        # Node label prediction
        predicted_label = self.softmax(label_pred + self.offset_parent(is_parent) + self.offset_sibling(has_sibling))
        
        # Build tree: Add node to tree
        if parent_node is None:
            node = Node(torch.argmax(predicted_label, dim=-1), parent=None)
        else:
            node = Node(torch.argmax(predicted_label, dim=-1), parent=parent_node)
            
        # For computing loss, save output (predictions and true values)
        output['predicted_labels'][index.get()] = predicted_label
        output['labels'][index.get()] = label
        output['predicted_has_siblings'][index.get()] = p_sibling
        output['has_siblings'][index.get()] = has_sibling
        output['predicted_is_parent'][index.get()] = p_parent
        output['is_parent'][index.get()] = is_parent
            
        
        # Teacher forcing on label
        if self.training:
            label = F.one_hot(label.long(), self.vocab_size).float()
        else:
            label = predicted_label
            
        
        if has_sibling:
            index.increase()
            if hidden_sibling is not None:
                hidden_sibling = self.lstm_sibling(label, hidden_sibling)
            else:
                hidden_sibling = self.lstm_sibling(label)
            
            if self.training:
                siblings.pop(0)
                self.build_tree(hidden_parent, hidden_sibling, output, parent_node, index, features, adjacency_list, siblings, siblings[0])
                
            else:
                self.build_tree(hidden_parent, hidden_sibling, output, parent_node, index)
            
        if is_parent:
            index.increase()
            hidden_parent = self.lstm_parent(label, hidden_parent)
            parent_node = node
            
            if self.training:
                siblings = list(child_indices)
                self.build_tree(hidden_parent, None, output, parent_node, index, features, adjacency_list, siblings, siblings[0])
                
            else:
                self.build_tree(hidden_parent, None, output, parent_node, index)
                
        return parent_node, index
        
        
    def get_truth_values(self, features, adjacency_list, index):
        adjacency_list_current = adjacency_list[adjacency_list[:, 0] == index]

        x = features[index, :]
        child_indices = adjacency_list_current[:, 1]

        return x, child_indices
    
    
class TreeNodeCounter:
    def __init__(self):
        self.counter = 0
        
    def increase(self):
        self.counter += 1
    
    def get(self):
        return self.counter
    
    def reset(self):
        self.counter = 0
            
        
        
class VaeLoss(nn.Module):
    
    def __init__(self):
        super().__init__()

    def forward(self, output, mu, log_var, vocab_size): 
        predicted_labels = output['predicted_labels']
        labels = output['labels'].view(-1).long().to(device)

        # Negative log likelihood loss (categorical cross entropy)
        label_prediction_loss = F.nll_loss(output['predicted_labels'], output['labels'].long())

        # Binary cross entropy loss for parent and sibling predictions (topology)
        parent_loss = nn.BCELoss()(output['predicted_is_parent'], output['is_parent'])
        sibling_loss = nn.BCELoss()(output['predicted_has_siblings'], output['has_siblings'])
        
        reconstruction_loss =  label_prediction_loss + parent_loss + sibling_loss
        
        kl_loss = 0.5 * torch.sum(log_var.exp() - log_var - 1 + mu.pow(2))
        
        total_loss = kl_loss + reconstruction_loss
        
        losses = {
            'total_loss': total_loss,
            'label_prediction_loss': label_prediction_loss.item(),
            'parent_loss': parent_loss.item(),
            'sibling_loss': sibling_loss.item(),
            'kl_loss': kl_loss.item()
        }
                        
        return losses
    
    def flatten_tree(self, tree):
        nodes = []
        for child in PreOrderIter(reconstructed_tree):
            nodes.append(child.pred.tolist())
        
        return torch.tensor(nodes)
        

In [28]:
encoder = TreeLstmEncoder(len(reserved_tokens), 30, 256, 128).to(device)
decoder = TreeLstmDecoder(len(reserved_tokens), 256, 128).to(device)

In [29]:
encoder.train()
decoder.train()

encoder_optimizer = optim.Adam(encoder.parameters(), lr=0.001)
decoder_optimizer = optim.Adam(decoder.parameters(), lr=0.001)
vae_loss = VaeLoss()

loss_types = ['total_loss', 'label_prediction_loss', 'parent_loss', 'sibling_loss', 'kl_loss']
losses = {}

for loss_type in loss_types:
    losses[loss_type] = []

pbar = tqdm(unit='batch', position=0)
for batch in loader:
    encoder_optimizer.zero_grad()
    decoder_optimizer.zero_grad()
    
    for key in batch.keys():
        if key != 'tree_sizes':
            batch[key] = batch[key].to(device)
        
    start = time.time()
    z, z_mean, z_log_var = encoder(batch)
    print(f'encoding time: {time.time() - start}')
    start = time.time()
    reconstructed_tree, output = decoder(z, batch)
    print(f'decoding time: {time.time() - start}')
    
    curr_losses = vae_loss(output, z_mean, z_log_var, len(reserved_tokens))
    start = time.time()
    curr_losses['total_loss'].backward() 
    print(f'backwards time: {time.time() - start}')
    encoder_optimizer.step()
    decoder_optimizer.step()
    
    for loss_type in loss_types:
        losses[loss_type].append(curr_losses[loss_type])
    
    pbar.set_postfix(loss=round(curr_losses['total_loss'].item(), 3))
    pbar.update()
#     tree = retrieve_tree(batch['features'], batch['adjacency_list'], 0)

0batch [00:00, ?batch/s]

encoding time: 0.012890815734863281
decoding time: 0.0882575511932373


1batch [00:00,  2.70batch/s, loss=10.2]

backwards time: 0.09966659545898438
encoding time: 0.014879941940307617
decoding time: 0.12757658958435059


2batch [00:00,  3.29batch/s, loss=8]   

backwards time: 0.09364438056945801
encoding time: 0.011805057525634766
decoding time: 0.09884285926818848


3batch [00:00,  3.82batch/s, loss=6.93]

backwards time: 0.0897071361541748
encoding time: 0.02576470375061035
decoding time: 0.24670934677124023
backwards time: 0.19557952880859375


4batch [00:01,  2.87batch/s, loss=7.51]

encoding time: 0.0453944206237793
decoding time: 0.5534889698028564


5batch [00:02,  1.61batch/s, loss=8.38]

backwards time: 0.4824962615966797
encoding time: 0.02780771255493164
decoding time: 0.2793283462524414


6batch [00:02,  1.66batch/s, loss=8.82]

backwards time: 0.23687338829040527
encoding time: 0.03934955596923828
decoding time: 0.6524362564086914


KeyboardInterrupt: 

In [120]:
for batch in tqdm(loader, position=0):
    if batch['tree_sizes'][0] < 30:
        data = batch
        break
        
for key in data.keys():
        if key != 'tree_sizes':
            data[key] = data[key].to(device)

encoder.train()
decoder.train()

encoder_optimizer = optim.Adam(encoder.parameters(), lr=0.001)
decoder_optimizer = optim.Adam(decoder.parameters(), lr=0.001)
vae_loss = VaeLoss()

pbar = tqdm(unit='batch', position=0)
for i in range(1000):
    encoder_optimizer.zero_grad()
    decoder_optimizer.zero_grad()
            
    z, z_mean, z_log_var = encoder(data)
    reconstructed_tree, output = decoder(z, data)
    
    loss = vae_loss(output, z_mean, z_log_var, len(reserved_tokens))
    loss.backward()    
    encoder_optimizer.step()
    decoder_optimizer.step()
    pbar.set_postfix(loss=round(loss.item(), 3))
    pbar.update()
#     tree = retrieve_tree(batch['features'], batch['adjacency_list'], 0)

{'total_loss': [tensor(161.2562, device='cuda:0', grad_fn=<AddBackward0>),
  tensor(79.3005, device='cuda:0', grad_fn=<AddBackward0>),
  tensor(47.3698, device='cuda:0', grad_fn=<AddBackward0>),
  tensor(40.6093, device='cuda:0', grad_fn=<AddBackward0>),
  tensor(26.8318, device='cuda:0', grad_fn=<AddBackward0>),
  tensor(25.8960, device='cuda:0', grad_fn=<AddBackward0>),
  tensor(20.6566, device='cuda:0', grad_fn=<AddBackward0>),
  tensor(17.4204, device='cuda:0', grad_fn=<AddBackward0>),
  tensor(17.9077, device='cuda:0', grad_fn=<AddBackward0>),
  tensor(15.7321, device='cuda:0', grad_fn=<AddBackward0>)],
 'label_prediction_loss': [tensor(4.9460, device='cuda:0', grad_fn=<NllLossBackward>),
  tensor(4.9242, device='cuda:0', grad_fn=<NllLossBackward>),
  tensor(4.9006, device='cuda:0', grad_fn=<NllLossBackward>),
  tensor(4.8806, device='cuda:0', grad_fn=<NllLossBackward>),
  tensor(4.8529, device='cuda:0', grad_fn=<NllLossBackward>),
  tensor(4.8278, device='cuda:0', grad_fn=<NllLos

In [None]:
TreePlotter.plot_predicted_tree(reconstructed_tree[0], 'predicted_tree.png')

In [475]:
reader = csv.reader(open('../data/ast_trees_200k/asts.csv'))

next(reader)

TreePlotter.plot_tree(JsonImporter().import_(next(reader)[1]), 'first_tree.png')


(<utils.TreePredictionNode.Node at 0x7fb669acf2e0>,
 <utils.TreePredictionNode.Node at 0x7fb669acf490>,
 <utils.TreePredictionNode.Node at 0x7fb669acf730>)