# **Attention-based Tree-to-Sequence Code Summarization Model**

## Prepare the dataset


In [None]:
# Show the current working directory in long format, including hidden entries,
# with human-readable sizes.
!ls -hal

In [None]:
# WARNING: destructive. Removes the visible and hidden entries of the current
# working directory so the repository can be cloned into "." below. Only run
# this in a disposable workspace.
# NOTE(review): relies on the shell brace-expanding {*,.*} — confirm for the shell in use.
!rm -rf {*,.*}
# Clone the project sources directly into the current directory.
!git clone https://github.com/sh1doy/summarization_tf .
# Install the Python dependencies from requirements.txt
# (presumably provided by the clone above — confirm).
!pip install -r requirements.txt

### 1. Download the raw dataset from https://github.com/xing-hu/DeepCom


In [None]:
# Fetch the DeepCom repository into ./dataset.
!git clone https://github.com/xing-hu/DeepCom dataset
# Extract data.7z inside dataset/ (-y: answer yes to prompts; -bsp0/-bso0:
# silence the progress and standard-output streams).
!cd dataset && 7z -y -bsp0 -bso0 x data.7z
# Create the per-split directories passed to the parser via -d in the cells below.
!cd dataset && mkdir valid && mkdir train && mkdir test


### 2. Parse them with parser.jar

In [None]:
# Run parser.jar on data/test.json, passing dataset/test as the -d argument.
# NOTE(review): -f / -d are presumably input file / output directory — confirm against the parser.
!cd dataset && java -jar ../parser/parser.jar -f data/test.json -d test

In [None]:
# Run parser.jar on data/train.json, passing dataset/train as the -d argument.
# NOTE(review): -f / -d are presumably input file / output directory — confirm against the parser.
!cd dataset && java -jar ../parser/parser.jar -f data/train.json -d train

In [None]:
# Run parser.jar on data/valid.json, passing dataset/valid as the -d argument.
# NOTE(review): -f / -d are presumably input file / output directory — confirm against the parser.
!cd dataset && java -jar ../parser/parser.jar -f data/valid.json -d valid

## Run the Model

In [None]:

%matplotlib inline
import sys
sys.path.append("./")
import pickle
import numpy as np
from tqdm import tqdm_notebook
from prefetch_generator import BackgroundGenerator
from matplotlib import pylab as plt
from IPython.display import clear_output
import os
from joblib import Parallel, delayed
from tqdm import tqdm
import nltk
from glob import glob
from joblib import Parallel, delayed
from collections import Counter
from layers import *
from utils import *
from models import *
import dataset
import json
import tensorflow as tf

# Run the repository's dataset.py script on the dataset/ directory.
# NOTE(review): presumably this produces the .pkl files loaded in the cells below — confirm.
!python3 dataset.py dataset/

In [None]:
# Directory that receives the model checkpoints ("ckpt" prefix) and, at the end
# of the notebook, history.json.
checkpoint_dir = "models/checkpoints"

In [None]:
# Per-split data, later unpacked into (key, raw summary tokens) pairs.
trn_data, vld_data, tst_data = map(
    read_pickle,
    ("dataset/nl/train.pkl", "dataset/nl/valid.pkl", "dataset/nl/test.pkl"),
)

# Vocabulary lookup tables for the code side and the natural-language side.
# NOTE(review): "i2w" / "w2i" are presumably index->word / word->index — confirm.
code_i2w, code_w2i = map(
    read_pickle,
    ("dataset/code_i2w.pkl", "dataset/code_w2i.pkl"),
)
nl_i2w, nl_w2i = map(
    read_pickle,
    ("dataset/nl_i2w.pkl", "dataset/nl_w2i.pkl"),
)

In [None]:
def _split_items(mapping):
    """Return the keys and the values of *mapping* as two parallel tuples."""
    keys, values = zip(*mapping.items())
    return keys, values


# x: the dict keys of each split; y_raw: the matching dict values.
trn_x, trn_y_raw = _split_items(trn_data)
vld_x, vld_y_raw = _split_items(vld_data)
tst_x, tst_y_raw = _split_items(tst_data)

In [None]:
def _to_ids(tokens):
    """Map every token to its nl_w2i id, using the "<UNK>" id for unknown tokens."""
    return [nl_w2i[tok] if tok in nl_w2i else nl_w2i["<UNK>"] for tok in tokens]


# Id-encoded target sequences, one list of ids per raw token sequence.
trn_y = [_to_ids(seq) for seq in trn_y_raw]
vld_y = [_to_ids(seq) for seq in vld_y_raw]
tst_y = [_to_ids(seq) for seq in tst_y_raw]

In [None]:
# model definition
class Model(BaseModel):
    """BaseModel variant whose encoder is a tree embedding followed by a Child-Sum LSTM layer."""

    def __init__(self, dim_E, dim_F, dim_rep, in_vocab, out_vocab, layer=1, dropout=0.5, lr=1e-4):
        """Forward every argument to BaseModel, then attach the tree encoder layers."""
        super().__init__(dim_E, dim_F, dim_rep, in_vocab, out_vocab, layer, dropout, lr)
        self.E = TreeEmbeddingLayer(dim_E, in_vocab)
        self.encoder = ChildSumLSTMLayer(dim_E, dim_rep)

    def encode(self, trees):
        """Embed and encode *trees*.

        Returns a pair ``(ys, [hx, cx])``:
        ``ys`` is a list with one tensor per encoded tree, stacking the ``h``
        of every node yielded by ``traverse``; ``hx`` / ``cx`` stack the ``h``
        / ``c`` attributes of the encoded trees themselves.
        """
        embedded = self.E(trees)
        encoded = self.encoder(embedded)

        node_states = [
            tf.stack([node.h for node in traverse(tree)])
            for tree in encoded
        ]
        top_h = tf.stack([tree.h for tree in encoded])
        top_c = tf.stack([tree.c for tree in encoded])

        return node_states, [top_h, top_c]

### Define model settings

In [None]:

# Network sizes, vocabulary sizes and optimiser settings.
model = Model(
    dim_E=512,
    dim_F=512,
    dim_rep=512,
    in_vocab=len(code_w2i),
    out_vocab=len(nl_w2i),
    dropout=0.5,
    lr=1e-4,
)
epochs, batch_size = 15, 64

# Checkpointing: make sure the target directory exists before training starts.
os.makedirs(checkpoint_dir, exist_ok=True)
checkpoint_prefix = os.path.join(checkpoint_dir, "ckpt")
root = tf.train.Checkpoint(model=model)

# Per-epoch training / validation losses, appended to by the training loop.
history = {key: [] for key in ("loss", "loss_val")}

### Define data generator settings

In [None]:

def _make_generator(xs, ys, is_train):
    """Build a Datagen_tree over (xs, ys) with the shared batch size and vocabularies."""
    return Datagen_tree(xs, ys, batch_size, code_w2i, nl_i2w, train=is_train)


# Only the training generator is created with train=True.
trn_gen = _make_generator(trn_x, trn_y, True)
vld_gen = _make_generator(vld_x, vld_y, False)
tst_gen = _make_generator(tst_x, tst_y, False)

### Start training

In [None]:
# training
# Each epoch: run every training batch, run every validation batch, save a
# checkpoint when the validation loss is the best seen so far, then redraw the
# loss curves.
for epoch in range(epochs):

    # train: one train_on_batch call per batch; the progress bar shows the
    # running mean of the batch losses
    loss_tmp = []
    t = tqdm(trn_gen(epoch))
    for x, y, _, _ in t:
        loss_tmp.append(model.train_on_batch(x, y))
        t.set_description("epoch:{:03d}, loss = {}".format(epoch + 1, np.mean(loss_tmp)))
    # Stored as a built-in float: `history` is written with json.dump at the
    # end of the notebook, and NumPy scalars such as float32 are not JSON
    # serializable.
    history["loss"].append(float(np.sum(loss_tmp) / len(t)))

    # validate: same bookkeeping, using evaluate_on_batch on the validation batches
    loss_tmp = []
    t = tqdm(vld_gen(epoch))
    for x, y, _, _ in t:
        loss_tmp.append(model.evaluate_on_batch(x, y))
        t.set_description("epoch:{:03d}, loss_val = {}".format(epoch + 1, np.mean(loss_tmp)))
    history["loss_val"].append(float(np.sum(loss_tmp) / len(t)))

    # checkpoint: save only when this epoch's validation loss equals the
    # minimum so far, so the most recently written checkpoint is the best one
    if history["loss_val"][-1] == min(history["loss_val"]):
        checkpoint_prefix = os.path.join(checkpoint_dir, "ckpt")
        root.save(file_prefix=checkpoint_prefix)

    # plot: replace the previous cell output with the updated loss curves
    clear_output()
    for key, val in history.items():
        if "loss" in key:
            plt.plot(val, label=key)
    plt.legend()
    plt.show()

In [None]:
# Reload the most recently written checkpoint from checkpoint_dir into `root`.
latest_ckpt_path = tf.train.latest_checkpoint(checkpoint_dir)
root.restore(latest_ckpt_path)

In [None]:
# Translate every test batch and collect predictions next to the references.
preds, trues = [], []
for x, y, _, y_raw in tqdm(tst_gen(0)):
    preds.extend(model.translate(x, nl_i2w, nl_w2i))
    # Drop the first and last element of each raw reference sequence
    # (presumably the start/end markers — confirm).
    trues.extend([ref[1:-1] for ref in y_raw])

In [None]:
# Score every (reference, prediction) pair with bleu4, using all available cores.
scored_pairs = list(zip(trues, preds))
bleus = Parallel(n_jobs=-1)(
    delayed(bleu4)(ref, hyp) for ref, hyp in tqdm(scored_pairs)
)

In [None]:
# Attach the test-set results to the training history.
history.update(
    {
        "bleus": bleus,
        "preds": preds,
        "trues": trues,
        # Integer parsed from the last "/"-separated component of each test key.
        "numbers": [int(key.split("/")[-1]) for key in tst_x],
    }
)

In [None]:
# Persist the collected history next to the checkpoints.
history_path = os.path.join(checkpoint_dir, "history.json")
with open(history_path, "w") as history_file:
    json.dump(history, history_file)