In [1]:
%load_ext autoreload
%autoreload 2

In [3]:
import json
import json_lines
import preprocess_temp as P
import model.parsers as M

In [None]:
# pip install json-lines if you don't have json_lines

### Load Data

In [4]:
directory = './conala-corpus/'
train_file = directory + 'train.json'
test_file = directory + 'test.json'

with open(train_file) as f:
    train_data = json.load(f)
    
with open(test_file) as f:
    test_data = json.load(f)

In [12]:
mine_file = directory + 'mined.jsonl'
mine_data = []
with open(mine_file) as f:
    mine_data = [line for line in json_lines.reader(f)]

### Let's preprocess the data. Everything is in Preprocess.py
### Adding mined data

In [23]:
# intent processing includes lowercase, remove punctuation'?'
train_intent, train_codes = P.process_data(train_data)
test_intent, test_codes = P.process_data(test_data)

In [24]:
mine_intent, mine_codes = P.process_data(mine_data, mine=True)

In [25]:
# this class is used for code2actions and actions2code
ast_action = P.Ast_Action()

In [26]:
train_actions = []

for code in train_codes:
    train_actions.append(ast_action.code2actions(code))

In [27]:
word_lst = P.vocab_list(train_intent, cut_freq=5)
act_lst, token_lst = P.action_list(train_actions, cut_freq=5)

In [42]:
act_lst

[ApplyRule[cmpop -> Is()],
 ApplyRule[cmpop -> NotEq()],
 ApplyRule[expr -> SetComp(expr elt, comprehension* generators)],
 ApplyRule[operator -> Add()],
 ApplyRule[operator -> FloorDiv()],
 ApplyRule[keyword -> keyword(identifier? arg, expr value)],
 ApplyRule[expr -> ListComp(expr elt, comprehension* generators)],
 ApplyRule[stmt -> Return(expr? value)],
 ApplyRule[expr -> BinOp(expr left, operator op, expr right)],
 ApplyRule[cmpop -> Lt()],
 ApplyRule[expr -> Str(string s)],
 ApplyRule[unaryop -> USub()],
 ApplyRule[expr -> Name(identifier id)],
 ApplyRule[expr -> Dict(expr* keys, expr* values)],
 ApplyRule[slice -> Index(expr value)],
 ApplyRule[expr -> NameConstant(singleton value)],
 ApplyRule[operator -> BitOr()],
 GenToken[token],
 ApplyRule[cmpop -> IsNot()],
 ApplyRule[boolop -> And()],
 ApplyRule[stmt -> Try(stmt* body, excepthandler* handlers, stmt* orelse, stmt* finalbody)],
 ApplyRule[stmt -> If(expr test, stmt* body, stmt* orelse)],
 ApplyRule[expr -> Lambda(arguments a

In [28]:
word2num = dict(zip(word_lst, range(0,len(word_lst))))
act2num = dict(zip(act_lst, range(0,len(act_lst))))
token2num = dict(zip(token_lst, range(0,len(token_lst))))

In [29]:
train_loader = P.get_train_loader(train_intent, train_actions, word2num, act2num, token2num)

In [30]:
test_loader = P.get_test_loader(test_intent, word2num)

In [31]:
action_index_copy = act2num[P.GenTokenAction('copy')]
action_index_gen = act2num[P.GenTokenAction('token')]

### Model

In [None]:
import torch
import time

In [36]:
from collections import namedtuple
hyperParamMap = {
    #### General configuration ####
    'cuda': True,      # Use gpu
    'mode': 'train',   # train or test

    #### Embedding sizes ####
    'embed_size': 128,         # Size of word embeddings
    'action_embed_size': 128,  # Size of ApplyRule/GenToken action embeddings
    'field_embed_size': 64,    # Embedding size of ASDL fields
    'type_embed_size': 64,     # Embeddings ASDL types

    #### Decoding sizes ####
    'hidden_size': 256,        # Size of LSTM hidden states

    #### training schedule details ####
    'valid_metric': 'acc',                # Metric used for validation
    'valid_every_epoch': 1,               # Perform validation every x epoch
    'log_every': 30,                      # Log training statistics every n iterations
    'save_to': 'model',                   # Save trained model to
    'clip_grad': 5.,                      # Clip gradients
    'max_epoch': 10,                      # Maximum number of training epoches
    'optimizer': 'Adam',                  # optimizer
    'lr': 0.001,                          # Learning rate
    'lr_decay': 0.5,                      # decay learning rate if the validation performance drops
    'verbose': False,                     # Verbose mode

    #### decoding/validation/testing ####
    'load_model': None,                   # Load a pre-trained model
    'beam_size': 5,                       # Beam size for beam search
    'decode_max_time_step': 100,          # Maximum number of time steps used in decoding and sampling
    'sample_size': 5,                     # Sample size
    'test_file': '',                      # Path to the test file
    'save_decode_to': None,               # Save decoding results to file
}

HyperParams = namedtuple('HyperParams', list(hyperParamMap.keys()), verbose=False)
hyperParams = HyperParams(**hyperParamMap)

In [37]:
model = M.Model(hyperParams, action_size=len(act_lst), token_size=len(token_lst), word_size=len(word_lst), 
                      action_index_copy=action_index_copy, action_index_gen=action_index_gen)

In [40]:
optimizer = torch.optim.Adam(model.parameters(), lr = 1e-3)
lossFunc = torch.nn.CrossEntropyLoss()

In [41]:
epoch_begin = time.time()
for e in range(20):
    for batch_ind, x in enumerate(train_loader):
        optimizer.zero_grad()

        (action_logits, action_labels), (copy_logits, copy_labels), (token_logits, token_labels) = model(x)

        loss1 = lossFunc(action_logits, action_labels)
        loss2 = torch.DoubleTensor([0.0])
        if len(copy_logits) > 0:
            loss2 = lossFunc(copy_logits, copy_labels)
        loss3 = torch.DoubleTensor([0.0])
        if len(token_logits) > 0:
            loss3 = lossFunc(token_logits, token_labels)

        total_loss = loss1 + loss2.double() + loss3.double()
        total_loss.backward()

        # clip gradient
        if hyperParams.clip_grad > 0.:
            grad_norm = torch.nn.utils.clip_grad_norm_(model.parameters(), hyperParams.clip_grad)

        optimizer.step()

        if batch_ind % hyperParams.log_every == hyperParams.log_every - 1:
            print("Action loss: {}".format(loss1.data))
            print("Copy loss: {}".format(loss2.data))
            print("Token loss: {}".format(loss3.data))
            print('-------------------------------------------------------')
            report_loss = report_examples = 0.

    print('epoch elapsed %ds' % (time.time() - epoch_begin))

Action loss: 0.7167027586575205
Copy loss: 2.3950161933898926
Token loss: 2.809378147125244
-------------------------------------------------------
Action loss: 0.42719784815837963
Copy loss: 2.468154191970825
Token loss: 2.8486595153808594
-------------------------------------------------------
Action loss: 0.6902822944227888
Copy loss: 2.298081159591675
Token loss: 2.528050184249878
-------------------------------------------------------
Action loss: 0.7904933715868996
Copy loss: 2.854611396789551
Token loss: 2.4207215309143066
-------------------------------------------------------
epoch elapsed 85s
Action loss: 0.6646689647056626
Copy loss: 2.5164413452148438
Token loss: 1.9572882652282715
-------------------------------------------------------
Action loss: 0.6740780842317303
Copy loss: 2.474066972732544
Token loss: 2.2832720279693604
-------------------------------------------------------
Action loss: 0.515992467706028
Copy loss: 2.490504264831543
Token loss: 2.3165106773376465
--

Action loss: 0.41577124559641504
Copy loss: 1.6518689393997192
Token loss: 1.20344078540802
-------------------------------------------------------
Action loss: 0.2549308473095643
Copy loss: 1.8392798900604248
Token loss: 1.0195257663726807
-------------------------------------------------------
epoch elapsed 1220s
Action loss: 0.3016550911427151
Copy loss: 1.5269074440002441
Token loss: 0.9525985717773438
-------------------------------------------------------
Action loss: 0.3523757330724296
Copy loss: 1.7236050367355347
Token loss: 1.0393800735473633
-------------------------------------------------------
Action loss: 0.36739460570292143
Copy loss: 1.6697598695755005
Token loss: 1.332204818725586
-------------------------------------------------------
Action loss: 0.2840887575819544
Copy loss: 1.5513217449188232
Token loss: 1.0433471202850342
-------------------------------------------------------
epoch elapsed 1317s
Action loss: 0.3265489389519612
Copy loss: 1.7287156581878662
Token

In [43]:
torch.save((model).state_dict(), 'Parameters/frist.t7')

In [44]:
model.load_state_dict(torch.load('Parameters/frist.t7'))