In [1]:
%load_ext autoreload
%autoreload 2

In [2]:
import json
import json_lines
import preprocess_temp as P
import model.parsers as M

### Load Data

In [18]:
directory = './conala-corpus/'
train_file = directory + 'train.json'
test_file = directory + 'test.json'

with open(train_file) as f:
    train_data = json.load(f)
    
with open(test_file) as f:
    test_data = json.load(f)

In [19]:
mine_file = directory + 'mined.jsonl'
mine_data = []
with open(mine_file) as f:
    mine_data = [line for line in json_lines.reader(f) if line['prob'] > 0.5]

### Let's preprocess the data. Everything is in Preprocess.py
### Adding mined data

In [20]:
# intent processing includes lowercase, remove punctuation'?'
train_intent, train_codes = P.process_data(train_data)
test_intent, test_codes = P.process_data(test_data)
mine_intent, mine_codes = P.process_data(mine_data, mine=True)
train_intent.extend(mine_intent)
train_codes.extend(mine_codes)

In [21]:
# this class is used for code2actions and actions2code
ast_action = P.Ast_Action()

In [22]:
train_actions = []

for code in train_codes:
    train_actions.append(ast_action.code2actions(code))

In [23]:
word_lst = P.vocab_list(train_intent, cut_freq=2)
act_lst, token_lst = P.action_list(train_actions, cut_freq=0)

In [24]:
word2num = dict(zip(word_lst, range(0,len(word_lst))))
act2num = dict(zip(act_lst, range(0,len(act_lst))))
token2num = dict(zip(token_lst, range(0,len(token_lst))))

In [25]:
train_loader = P.get_train_loader(train_intent, train_actions, word2num, act2num, token2num, batch_size=32)

In [26]:
test_loader = P.get_test_loader(test_intent, word2num, batch_size=1)

In [27]:
action_index_copy = act2num[P.GenTokenAction('copy')]
action_index_gen = act2num[P.GenTokenAction('token')]

### Model

In [28]:
import torch
import time

In [29]:
from collections import namedtuple
hyperParamMap = {
    #### General configuration ####
    'cuda': True,      # Use gpu
    'mode': 'train',   # train or test

    #### Embedding sizes ####
    'embed_size': 128,         # Size of word embeddings
    'action_embed_size': 128,  # Size of ApplyRule/GenToken action embeddings
    'field_embed_size': 64,    # Embedding size of ASDL fields
    'type_embed_size': 64,     # Embeddings ASDL types

    #### Decoding sizes ####
    'hidden_size': 256,        # Size of LSTM hidden states

    #### training schedule details ####
    'valid_metric': 'acc',                # Metric used for validation
    'valid_every_epoch': 1,               # Perform validation every x epoch
    'log_every': 10,                      # Log training statistics every n iterations
    'save_to': 'model',                   # Save trained model to
    'clip_grad': 5.,                      # Clip gradients
    'max_epoch': 10,                      # Maximum number of training epoches
    'optimizer': 'Adam',                  # optimizer
    'lr': 0.001,                          # Learning rate
    'lr_decay': 0.5,                      # decay learning rate if the validation performance drops
    'verbose': False,                     # Verbose mode

    #### decoding/validation/testing ####
    'load_model': None,                   # Load a pre-trained model
    'beam_size': 1,                       # Beam size for beam search
    'decode_max_time_step': 100,          # Maximum number of time steps used in decoding and sampling
    'sample_size': 5,                     # Sample size
    'test_file': '',                      # Path to the test file
    'save_decode_to': None,               # Save decoding results to file
}

HyperParams = namedtuple('HyperParams', list(hyperParamMap.keys()), verbose=False)
hyperParams = HyperParams(**hyperParamMap)

In [30]:
model = M.Model(hyperParams, action_size=len(act_lst), token_size=len(token_lst), word_size=len(word_lst), 
                      action_index_copy=action_index_copy, action_index_gen=action_index_gen)

In [31]:
optimizer = torch.optim.Adam(model.parameters(), lr = 1e-3)
lossFunc = torch.nn.CrossEntropyLoss()

In [32]:
model.train()
epoch_begin = time.time()
for e in range(20):
    for batch_ind, x in enumerate(train_loader):
        optimizer.zero_grad()

        (action_logits, action_labels), (copy_logits, copy_labels), (token_logits, token_labels) = model(x)

        loss1 = lossFunc(action_logits, action_labels)
        loss2 = torch.DoubleTensor([0.0])
        if len(copy_logits) > 0:
            loss2 = lossFunc(copy_logits, copy_labels)
        loss3 = torch.DoubleTensor([0.0])
        if len(token_logits) > 0:
            loss3 = lossFunc(token_logits, token_labels)

        total_loss = loss1 + loss2.double() + loss3.double()
        total_loss.backward()

        # clip gradient
        if hyperParams.clip_grad > 0.:
            grad_norm = torch.nn.utils.clip_grad_norm_(model.parameters(), hyperParams.clip_grad)

        optimizer.step()

        if batch_ind % hyperParams.log_every == hyperParams.log_every - 1:
            print('--------------------------epoch {} batch {}-----------------------------'.format(e, batch_ind))
            print("Action loss: {}".format(loss1.data))
            print("Copy loss: {}".format(loss2.data))
            print("Token loss: {}".format(loss3.data))
            report_loss = report_examples = 0.

    print('epoch elapsed %ds' % (time.time() - epoch_begin))

--------------------------epoch 0 batch 9-----------------------------
Action loss: 1.8443992640952438
Copy loss: 2.6568801403045654
Token loss: 6.890439510345459
--------------------------epoch 0 batch 19-----------------------------
Action loss: 1.4655785435146331
Copy loss: 2.672477960586548
Token loss: 6.080912113189697
--------------------------epoch 0 batch 29-----------------------------
Action loss: 1.3251404033556944
Copy loss: 2.1825685501098633
Token loss: 6.080993175506592
--------------------------epoch 0 batch 39-----------------------------
Action loss: 1.387730495106048
Copy loss: 2.527622699737549
Token loss: 6.132447242736816
--------------------------epoch 0 batch 49-----------------------------
Action loss: 1.4603556854743915
Copy loss: 2.5043530464172363
Token loss: 6.223443984985352
--------------------------epoch 0 batch 59-----------------------------
Action loss: 1.1725416518113112
Copy loss: 2.4335925579071045
Token loss: 6.008927345275879
--------------------

--------------------------epoch 3 batch 59-----------------------------
Action loss: 0.5010012641765161
Copy loss: 1.9128273725509644
Token loss: 3.956510066986084
--------------------------epoch 3 batch 69-----------------------------
Action loss: 0.5656788169271464
Copy loss: 1.8701472282409668
Token loss: 3.7626655101776123
--------------------------epoch 3 batch 79-----------------------------
Action loss: 0.5741970905588857
Copy loss: 1.9166284799575806
Token loss: 4.339343547821045
--------------------------epoch 3 batch 89-----------------------------
Action loss: 0.5120145068049531
Copy loss: 1.7828840017318726
Token loss: 4.155537128448486
--------------------------epoch 3 batch 99-----------------------------
Action loss: 0.5491406587679667
Copy loss: 1.554297924041748
Token loss: 4.465452194213867
--------------------------epoch 3 batch 109-----------------------------
Action loss: 0.5226903305408511
Copy loss: 1.9856289625167847
Token loss: 3.973316192626953
---------------

--------------------------epoch 6 batch 109-----------------------------
Action loss: 0.42125199817983766
Copy loss: 1.5742124319076538
Token loss: 3.315479278564453
--------------------------epoch 6 batch 119-----------------------------
Action loss: 0.6294878539810345
Copy loss: 1.6641607284545898
Token loss: 2.970217704772949
--------------------------epoch 6 batch 129-----------------------------
Action loss: 0.4331703285719059
Copy loss: 1.5983737707138062
Token loss: 3.4850645065307617
--------------------------epoch 6 batch 139-----------------------------
Action loss: 0.40034234549109426
Copy loss: 1.5500848293304443
Token loss: 3.5025980472564697
--------------------------epoch 6 batch 149-----------------------------
Action loss: 0.5064771819906946
Copy loss: 1.938191294670105
Token loss: 3.549039602279663
epoch elapsed 2498s
--------------------------epoch 7 batch 9-----------------------------
Action loss: 0.3955441176971877
Copy loss: 1.6661779880523682
Token loss: 2.99090

epoch elapsed 3598s
--------------------------epoch 10 batch 9-----------------------------
Action loss: 0.416733834029793
Copy loss: 1.6871261596679688
Token loss: 2.3548285961151123
--------------------------epoch 10 batch 19-----------------------------
Action loss: 0.47623544746469143
Copy loss: 1.7102437019348145
Token loss: 2.5351016521453857
--------------------------epoch 10 batch 29-----------------------------
Action loss: 0.4244802492466977
Copy loss: 1.7706047296524048
Token loss: 2.585294723510742
--------------------------epoch 10 batch 39-----------------------------
Action loss: 0.35098312124400005
Copy loss: 1.7347958087921143
Token loss: 2.226491928100586
--------------------------epoch 10 batch 49-----------------------------
Action loss: 0.5529000873877636
Copy loss: 2.013653516769409
Token loss: 2.2179598808288574
--------------------------epoch 10 batch 59-----------------------------
Action loss: 0.2946308706341243
Copy loss: 1.6628254652023315
Token loss: 2.5219

KeyboardInterrupt: 

In [33]:
torch.save((model).state_dict(), 'Parameters/frist.t7')

In [34]:
model.load_state_dict(torch.load('Parameters/frist.t7'))

In [41]:
import codecs
model.eval()
test_loader = P.get_test_loader(test_intent, word2num, batch_size=1)
testFile = codecs.open('./test_predict.txt','w','utf-8')
for i,(sample_sent, sample_sent_txt) in enumerate(test_loader):
    sample_hypothesis = model.parse(sample_sent, sample_sent_txt, act_lst, token_lst, ast_action)
    try:
        code = ast_action.actions2code(sample_hypothesis.actions)
    except:
        for i in range(len(sample_hypothesis.actions)):
            action = sample_hypothesis.actions[i]
            if isinstance(action, P.GenTokenAction):
                if action.token in ['list', 'reverse', 'range', 'address', 'in', 'and', 'range']:
                    action.token = '.'
        try:
            code = ast_action.actions2code(sample_hypothesis.actions)
        except:
            code = ''
#     print(test_intent[i])
    testFile.write(' '.join(test_intent[i])+':'+code+'\n')
testFile.close()

KeyboardInterrupt: 

In [42]:
testFile.close()