In [1]:
%load_ext autoreload
%autoreload 2

In [2]:
import sys

# Put the project's ./preprocessing and ./seq2seq directories on the module
# search path so that modules stored there can be imported by bare name.
sys.path.extend(['./preprocessing', './seq2seq'])

In [3]:
from processor import Code_Intent_Pairs
from model import Seq2Seq
from data import get_train_loader, get_test_loader

### Define Hyperparameters

In [4]:
# Every tunable setting of the experiment, gathered in a single dictionary that
# is handed to the data loaders, the model and the training function.
hyperP = {
    # --- training ---
    'batch_size': 32,
    'lr': 1e-3,
    'teacher_force_rate': 0.90,
    'max_epochs': 50,
    'lr_keep_rate': 0.95,  # per-epoch lr multiplier; set to 1.0 to keep lr constant over time
    'load_pretrain_code_embed': False,  # load a saved state dict into model.decoder.embed[0]
    'freeze_embed': False,  # only consulted when load_pretrain_code_embed is True

    # --- encoder architecture ---
    'encoder_layers': 2,
    'encoder_embed_size': 128,
    'encoder_hidden_size': 384,
    'encoder_dropout_rate': 0.3,

    # --- decoder architecture ---
    'decoder_layers': 2,
    'decoder_embed_size': 128,
    'decoder_hidden_size': 384,
    'decoder_dropout_rate': 0.3,

    # --- attention architecture ---
    'attn_hidden_size': 384,

    # --- progress reporting ---
    'print_every': 10,  # train() prints running loss/accuracy every this many batches
}

### Load Data

In [5]:
# Helper object from the local `processor` module; the following cells use it to
# load the vocabulary and the train/valid/test entries, and later for idx2code.
code_intent_pair = Code_Intent_Pairs()

In [6]:
# Point the helper at the `vocab/` directory (presumably the saved vocabularies --
# confirm against Code_Intent_Pairs.load_dict), then read back the values the
# rest of the notebook needs.
path = 'vocab/'
code_intent_pair.load_dict(path)
# Mapping that provides at least 'code_sos', 'code_eos' and 'code_unk'.
special_symbols = code_intent_pair.get_special_symbols()
# The two sizes are passed to Seq2Seq as its first and second arguments.
word_size = code_intent_pair.get_word_size()
code_size = code_intent_pair.get_code_size()

In [7]:
# Read the processed training split.
train_path = 'processed_corpus/train.json'
train_entries = code_intent_pair.load_entries(train_path)
# NOTE(review): pad() takes no arguments, so it presumably operates on the entries
# most recently loaded into code_intent_pair -- keep it right after load_entries; confirm.
code_intent_pair.pad()

In [8]:
# Loader over the training entries; this is what train() iterates over.
# hyperP is passed through (presumably for batch_size -- confirm in data.py).
trainloader = get_train_loader(train_entries, special_symbols, hyperP)

In [9]:
# Read the processed validation split, mirroring the training-split cell.
valid_path = 'processed_corpus/valid.json'
valid_entries = code_intent_pair.load_entries(valid_path)
# NOTE(review): as with the training split, pad() is assumed to act on the entries
# that were just loaded -- confirm it does not touch train_entries.
code_intent_pair.pad()

In [10]:
# Validation loader, built with the same helper as the training loader.
# NOTE(review): validloader is not referenced by any later cell visible in this notebook.
validloader = get_train_loader(valid_entries, special_symbols, hyperP)

In [11]:
# Read the processed test split. Unlike the train/valid cells, pad() is not called here.
test_path = 'processed_corpus/test.json'
test_entries = code_intent_pair.load_entries(test_path)

In [12]:
# Test loader; it takes only the entries (no special symbols or hyperparameters).
# NOTE(review): testloader is not referenced by any later cell visible in this notebook.
testloader = get_test_loader(test_entries)

### Define Model

In [13]:
# Build the seq2seq model from the two vocabulary sizes and the hyperparameter dict.
model = Seq2Seq(word_size, code_size, hyperP)

In [14]:
import torch

# Optionally initialise model.decoder.embed[0] from a saved state dict and, if
# requested, exclude its parameters from gradient updates.
if hyperP['load_pretrain_code_embed']:
    # map_location='cpu' lets a checkpoint that was saved from a GPU be read on a
    # machine without one; load_state_dict then copies the values into the
    # module's existing parameters.
    # NOTE: torch.load unpickles the file -- only load checkpoints you trust.
    model.decoder.embed[0].load_state_dict(
        torch.load('./pretrain_code_lm/embedding-1556211835.t7', map_location='cpu'))
    if hyperP['freeze_embed']:
        # Frozen parameters receive no gradient and are therefore left untouched
        # by the optimizer.
        for param in model.decoder.embed[0].parameters():
            param.requires_grad = False

### Training

In [15]:
import torch
import torch.optim as optim
from torch.optim.lr_scheduler import LambdaLR

# Adam over every model parameter, with the base learning rate taken from
# hyperP and a fixed weight decay of 1e-4.
optimizer = optim.Adam(
    model.parameters(),
    lr=hyperP['lr'],
    weight_decay=1e-4,
)

# Cross-entropy with default settings; train() applies it to (logits, targets).
loss_f = torch.nn.CrossEntropyLoss()

In [16]:
lr_keep_rate = hyperP['lr_keep_rate']

# Exponential decay: the multiplier applied to the base learning rate at a given
# epoch is lr_keep_rate ** epoch. With a rate of exactly 1.0 no scheduler is
# created at all, so later code must apply the same guard before using it.
if lr_keep_rate != 1.0:
    def lr_reduce_f(epoch):
        return lr_keep_rate ** epoch

    scheduler = LambdaLR(optimizer, lr_lambda=lr_reduce_f)

In [17]:
def train(model, trainloader, optimizer, loss_f, hyperP):
    """Run one optimisation pass over ``trainloader`` and return the mean batch loss.

    Args:
        model: module called as ``model(inp_seq, padded_out_seq, out_lens)``; its
            output is fed to ``loss_f`` and arg-maxed along dim 1.
        trainloader: sized iterable of
            ``(inp_seq, original_out_seq, padded_out_seq, out_lens)`` batches.
        optimizer: optimizer stepped once per batch.
        loss_f: callable ``loss_f(logits, original_out_seq)`` returning a scalar loss.
        hyperP: dict; only ``'print_every'`` is read here.

    Returns:
        Sum of the per-batch losses divided by ``len(trainloader)``.

    Side effects:
        Puts the model in train mode, updates its parameters, and every
        ``print_every`` batches prints the running loss/accuracy of that window
        (terminated with a carriage return so the line is overwritten).
    """
    model.train()
    report_interval = hyperP['print_every']

    epoch_loss = 0
    # Running statistics, reset after every report.
    window_loss = 0
    window_hits = 0
    window_targets = 0

    for step, batch in enumerate(trainloader, start=1):
        inp_seq, original_out_seq, padded_out_seq, out_lens = batch

        logits = model(inp_seq, padded_out_seq, out_lens)
        loss = loss_f(logits, original_out_seq)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        batch_loss = loss.item()
        epoch_loss += batch_loss
        window_loss += batch_loss

        # Index of the highest-scoring class for every target position.
        predictions = logits.max(dim=1)[1]
        window_hits += (predictions == original_out_seq).sum().item()
        window_targets += len(original_out_seq)

        if step % report_interval != 0:
            continue

        print('Train: loss:{}\tacc:{}'.format(window_loss / report_interval,
                                              float(window_hits) / window_targets), end='\r')
        window_loss = 0
        window_hits = 0
        window_targets = 0

    # Move past the carriage-return line so later output starts on a fresh line.
    print()
    return epoch_loss / len(trainloader)

### Decoder

In [19]:
from decoder import Decoder, post_process_dummy, sub_slotmap, tokenize_for_bleu_eval
from evaluate import get_bleu_sent

In [20]:
# Decoder wrapper around the model that is being trained.
beam_decoder = Decoder(model)

# Special code-side symbols; decode_with_score reads these three names from
# module scope when it calls decoder.decode.
sos, eos, unk = (
    special_symbols[key] for key in ('code_sos', 'code_eos', 'code_unk')
)

In [21]:
def decode_with_score(train_entry, decoder, idx2code, get_score, sub_slotmap):
    """Beam-decode one entry and score every returned beam against its reference code.

    Args:
        train_entry: mapping providing 'intent_indx', 'code_indx_nocopy', 'code'
            and 'slot_map'.
        decoder: object whose ``decode(inp_seq, sos, eos, unk, beam_width=...)``
            returns beams exposing a sliceable ``path``.
        idx2code: callable turning a beam path into the tokens given to ``sub_slotmap``.
        get_score: callable ``get_score(candidate_str, reference_str)``.
        sub_slotmap: callable ``sub_slotmap(tokens, slot_map)`` returning an
            iterable of strings, which is joined with single spaces.

    Returns:
        One ``(intent_idx, candidate_path, slot_token_count, score)`` tuple per
        beam, where ``slot_token_count`` sums, over every slot value occurring as
        a substring of the candidate string, the number of tokens
        ``tokenize_for_bleu_eval`` produces for that value.

    Note:
        ``sos``, ``eos``, ``unk`` and ``tokenize_for_bleu_eval`` are read from
        module scope rather than passed in.
    """
    intent_idx = train_entry['intent_indx']
    # Not used below; the lookup is kept so an entry lacking this key still fails here.
    true_code_idx = train_entry['code_indx_nocopy'][1:-1]
    slot_map = train_entry['slot_map']
    reference = ' '.join(sub_slotmap(train_entry['code'], slot_map))

    # Batch of one: the decoder is handed a single intent sequence.
    inp_seq = torch.LongTensor([intent_idx])
    beams = decoder.decode(inp_seq, sos, eos, unk, beam_width = 3)

    # Drop the final element of every beam path (presumably the EOS symbol -- confirm).
    candidate_paths = [beam.path[:-1] for beam in beams]
    candidates = [
        ' '.join(sub_slotmap(idx2code(path), slot_map)) for path in candidate_paths
    ]

    slot_values = slot_map.values()
    tokens_per_value = {
        value: len(tokenize_for_bleu_eval(value)) for value in slot_values
    }

    # For each candidate, total token count of the slot values it contains.
    slot_token_counts = [
        sum(tokens_per_value[value] for value in slot_values if value in candidate)
        for candidate in candidates
    ]

    scores = [get_score(candidate, reference) for candidate in candidates]

    return [
        (intent_idx, path, used_tokens, score)
        for path, score, used_tokens in zip(candidate_paths, scores, slot_token_counts)
    ]

In [22]:
def decode_all(train_entries, beam_decoder, idx2code, get_bleu_sent, sub_slotmap):
    """Decode every entry and collect all scored candidates in one flat list.

    Args:
        train_entries: iterable of entries accepted by ``decode_with_score``.
        beam_decoder: decoder passed through to ``decode_with_score``.
        idx2code: callable mapping a path of indices to code tokens. It is now
            actually used; previously this argument was ignored in favour of
            the global ``code_intent_pair.idx2code``.
        get_bleu_sent: scoring callable ``(candidate_str, reference_str) -> score``.
        sub_slotmap: slot-substitution callable passed through unchanged.

    Returns:
        The concatenation, in entry order, of the lists that
        ``decode_with_score`` returns for each entry.
    """
    epoch_entries = []
    for sample in train_entries:
        epoch_entries.extend(
            decode_with_score(sample, beam_decoder, idx2code, get_bleu_sent, sub_slotmap))
    return epoch_entries

In [25]:
# Alternate training and decoding: after each training epoch, decode every
# training entry with the current model and append the scored candidates to
# out_entries, so the list ends up holding candidates from all epochs.
# NOTE(review): the epoch count is hard-coded to 20; hyperP['max_epochs'] (50) is not used here.
out_entries = []
for e in range(20):
    train(model, trainloader, optimizer, loss_f, hyperP)
    # NOTE(review): train() leaves the model in train mode; confirm that
    # Decoder.decode switches to eval mode if dropout should be off while decoding.
    epoch_entries = decode_all(train_entries, beam_decoder, 
                  code_intent_pair.idx2code, get_bleu_sent, sub_slotmap)
    out_entries.extend(epoch_entries)
    # `scheduler` only exists when lr_keep_rate != 1.0, hence the identical guard.
    if lr_keep_rate != 1.0:
        scheduler.step()

Train: loss:2.8608489274978637	acc:0.36051810985847926
Train: loss:2.300374436378479	acc:0.449988006716238966
Train: loss:2.0361998438835145	acc:0.49028544015351405
Train: loss:1.918964922428131	acc:0.509234828496042246
Train: loss:1.7197949051856996	acc:0.5476133365315423
Train: loss:1.7212832808494567	acc:0.5420964260014391
Train: loss:1.5254679203033448	acc:0.5797553370112737
Train: loss:1.407399046421051	acc:0.60278244183257384
Train: loss:1.265366530418396	acc:0.64260014391940519
Train: loss:1.1647071063518524	acc:0.6553130247061646
Train: loss:1.1043188750743866	acc:0.6701846965699209
Train: loss:0.9861076891422271	acc:0.6932118013912214
Train: loss:0.9801946878433228	acc:0.7076037419045335
Train: loss:0.8901587128639221	acc:0.7392660110338211
Train: loss:0.8605209529399872	acc:0.7421443991364836
Train: loss:0.7322976946830749	acc:0.7800431758215399
Train: loss:0.7723296701908111	acc:0.7608539218037899
Train: loss:0.7026496529579163	acc:0.7865195490525306
Train: loss:0.6815170437

In [26]:
import json

# Persist the candidates gathered over all epochs as a single JSON document.
with open('rerank_data.json', 'w') as out_file:
    json.dump(out_entries, out_file)