In [5]:
%load_ext autoreload
%autoreload 2

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [6]:
from preprocessing.processor import Code_Intent_Pairs, sub_slotmap
from seq2seq2.model import Seq2Seq
from seq2seq2.data import get_train_loader, get_test_loader

### Define Hyperparameters

In [25]:
# Single mapping holding every tunable setting used by the notebook.
hyperP = dict(
    # -- training schedule --
    batch_size=32,
    lr=1e-4,
    teacher_force_rate=0.85,
    max_epochs=20,
    lr_keep_rate=0.95,  # lr multiplier applied per epoch; 1.0 disables the decay

    # -- encoder --
    encoder_layers=2,
    encoder_embed_size=128,
    encoder_hidden_size=384,
    encoder_dropout_rate=0.3,

    # -- decoder --
    decoder_layers=2,
    decoder_embed_size=128,
    decoder_hidden_size=384,
    decoder_dropout_rate=0.3,

    # -- attention --
    attn_hidden_size=384,

    # -- progress reporting: train() prints running stats every N batches --
    print_every=5,
)

### Load Data

In [8]:
# Helper object from preprocessing.processor; the following cells use it to
# load the dictionaries under 'vocab/' and the processed corpus entries, and
# the decoding cell uses it to map generated indices back to code tokens.
code_intent_pair = Code_Intent_Pairs()

In [9]:
# Load the dictionaries stored under `path`, then read back the values the
# later cells need.
path = 'vocab/'
code_intent_pair.load_dict(path)
# Symbol table; the decoding cell looks up 'code_sos' and 'code_eos' in it.
special_symbols = code_intent_pair.get_special_symbols()
# The two sizes handed to the Seq2Seq constructor
# (presumably the intent-word and code-token vocabulary sizes — confirm).
word_size = code_intent_pair.get_word_size()
code_size = code_intent_pair.get_code_size()

In [10]:
# Read the training entries from the processed corpus.
train_path = 'processed_corpus/train.json'
train_entries = code_intent_pair.load_entries(train_path)
# NOTE(review): pad() is called on the helper object, not on train_entries;
# presumably it pads the entries just loaded in place — confirm in
# preprocessing.processor before reordering these lines.
code_intent_pair.pad()

In [11]:
# Batch loader over the training entries; train() and valid() unpack each
# batch as (inp_seq, original_out_seq, padded_out_seq, out_lens).
# hyperP is passed through — presumably for 'batch_size'; confirm in seq2seq2.data.
trainloader = get_train_loader(train_entries, special_symbols, hyperP)

In [12]:
# Read the test entries. Unlike the training cell, pad() is not called here.
test_path = 'processed_corpus/test.json'
test_entries = code_intent_pair.load_entries(test_path)

In [13]:
# Loader over the test entries; the decoding cell unpacks each item as
# (src_seq, slot_map, code, intent).
testloader = get_test_loader(test_entries)

### Define Model

In [15]:
# Build the seq2seq model from the two sizes read from the dictionaries and
# the hyperparameter dict.
model = Seq2Seq(word_size, code_size, hyperP)

### Training

In [23]:
import torch
import torch.optim as optim
from torch.optim.lr_scheduler import LambdaLR
# Adam over all model parameters, starting at the configured learning rate.
optimizer = optim.Adam(model.parameters(), lr=hyperP['lr'])
# Cross-entropy criterion; train() and valid() apply it as
# loss_f(logits, original_out_seq).
loss_f = torch.nn.CrossEntropyLoss()

In [26]:
# Optional learning-rate decay: the base lr is scaled by lr_keep_rate ** epoch.
# A keep rate of exactly 1.0 means "no decay", in which case no scheduler is
# created at all (the training loop checks the same condition before stepping).
lr_keep_rate = hyperP['lr_keep_rate']
if lr_keep_rate != 1.0:
    def lr_reduce_f(epoch):
        """Return the multiplier applied to the base lr at `epoch`."""
        return lr_keep_rate ** epoch

    scheduler = LambdaLR(optimizer, lr_lambda=lr_reduce_f)

In [17]:
def train(model, trainloader, optimizer, loss_f, hyperP):
    """Run one pass over `trainloader`, updating `model` after every batch.

    Args:
        model: callable as ``model(inp_seq, padded_out_seq, out_lens)`` and
            returning logits; switched to training mode on entry.
        trainloader: iterable of
            ``(inp_seq, original_out_seq, padded_out_seq, out_lens)`` batches.
        optimizer: optimizer stepping the model's parameters.
        loss_f: criterion called as ``loss_f(logits, original_out_seq)``.
        hyperP: dict; only ``'print_every'`` is read here.

    Returns:
        None. Running loss and accuracy are printed every ``print_every``
        batches (on one line, via a trailing carriage return) and the
        counters are reset after each report.
    """
    model.train()
    running_loss, n_correct, n_seen = 0, 0, 0
    report_interval = hyperP['print_every']

    for step, batch in enumerate(trainloader, start=1):
        inp_seq, original_out_seq, padded_out_seq, out_lens = batch

        # forward pass, then one optimisation step
        logits = model(inp_seq, padded_out_seq, out_lens)
        loss = loss_f(logits, original_out_seq)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        # accumulate statistics for the current reporting window
        running_loss += loss.item()
        predictions = torch.max(logits, dim=1)[1]
        n_correct += (predictions == original_out_seq).sum()
        n_seen += len(original_out_seq)

        if step % report_interval != 0:
            continue

        mean_loss = running_loss/report_interval
        accuracy = float(n_correct)/n_seen
        print('Train: loss:{}\tacc:{}'.format(mean_loss, accuracy), end='\r')
        running_loss, n_correct, n_seen = 0, 0, 0

In [18]:
def valid(model, validloader, loss_f, hyperP):
    """Evaluate `model` on `validloader` without updating any parameters.

    Args:
        model: callable as ``model(inp_seq, padded_out_seq, out_lens)`` and
            returning logits; switched to evaluation mode on entry.
        validloader: sized iterable of
            ``(inp_seq, original_out_seq, padded_out_seq, out_lens)`` batches.
        loss_f: criterion called as ``loss_f(logits, original_out_seq)``.
        hyperP: accepted for signature parity with ``train``; no key is read.

    Returns:
        float: fraction of targets whose arg-max prediction matches.

    Raises:
        ZeroDivisionError: if `validloader` yields no targets.

    Also prints the mean per-batch loss and the accuracy on one line.
    """
    model.eval()
    loss_sum = 0
    total_correct = 0
    size = 0

    with torch.no_grad():
        for inp_seq, original_out_seq, padded_out_seq, out_lens in validloader:
            logits = model(inp_seq, padded_out_seq, out_lens)
            loss_sum += loss_f(logits, original_out_seq).item()

            _, predictions = torch.max(logits, dim=1)
            # .item() keeps the counter a plain Python int instead of a tensor
            total_correct += (predictions == original_out_seq).sum().item()
            size += len(original_out_seq)

    accuracy = float(total_correct)/size
    print('Valid: loss:{}\tacc:{}'.format(loss_sum/len(validloader), accuracy), end='\r')
    return accuracy

In [29]:
# Epoch loop: train, evaluate, checkpoint on improvement, then decay the lr.
# NOTE(review): valid() is fed `trainloader`, so the "Valid" numbers and the
# checkpoint selection below reflect training-set accuracy. This notebook has
# no held-out loader in the batch format valid() expects; swap one in here
# once it exists.
best_acc = 0.0
for e in range(hyperP['max_epochs']):  # epoch budget comes from the config cell
    train(model, trainloader, optimizer, loss_f, hyperP)
    acc = valid(model, trainloader, loss_f, hyperP)
    if acc > best_acc:
        best_acc = acc
        model.save()
        # valid() ends its line with '\r'; the bare print() moves to a new
        # line so the stats are not overwritten by the message
        print()
        print('model saved')
    # `scheduler` only exists when decay is enabled (keep rate != 1.0)
    if lr_keep_rate != 1.0:
        scheduler.step()

Valid: loss:2.6533851718902586	acc:0.39269836596496394
model saved
Valid: loss:2.5418078819910686	acc:0.40965699985278964
model saved
Valid: loss:2.4568390878041586	acc:0.42841160017665247
model saved
Valid: loss:2.3800559059778847	acc:0.44010010304725455
model saved
Valid: loss:2.2797077385584514	acc:0.45641101133519855
model saved
Valid: loss:2.2322295268376666	acc:0.45938466068011186
model saved
Valid: loss:2.1593027369181317	acc:0.47228028853231277
model saved
Valid: loss:2.097739086151123	acc:0.483586044457529856
model saved
Valid: loss:2.076807084083557	acc:0.484704843220962746
model saved
Valid: loss:2.022046955426534	acc:0.495627852200794974
model saved
Valid: loss:1.9812824646631877	acc:0.50557927278080373
model saved
Valid: loss:1.9510851414998371	acc:0.50755189165317246
model saved
Valid: loss:1.9111504173278808	acc:0.51391137936110786
model saved
Valid: loss:1.9007323853174845	acc:0.51591344030619763
model saved
Valid: loss:1.840481309890747	acc:0.527984690122184625
model s

KeyboardInterrupt: 

In [33]:
# Reload model state via Seq2Seq.load(); presumably this restores what the
# last model.save() call in the training loop wrote — confirm where
# save()/load() point in seq2seq2.model.
model.load()

### Decoding

In [54]:
# Greedy-decode the first test examples and print them next to the reference.
sos = special_symbols['code_sos']
eos = special_symbols['code_eos']

# Evaluation mode is loop-invariant, so set it once; no_grad() avoids building
# autograd state for what is pure inference.
model.eval()
with torch.no_grad():
    for i, (src_seq, slot_map, code, intent) in enumerate(testloader):
        seq = model.greedy_decode(src_seq, sos, eos)
        # indices -> code tokens, then substitute slot placeholders via slot_map
        gen_code_tokens = code_intent_pair.idx2code(seq)
        gen_code = sub_slotmap(gen_code_tokens, slot_map)
        print('intent:\t'+intent)
        print('predicted:\t'+gen_code+'\nground_truth:\t'+code)
        print()

        # stop after the example at index 50, i.e. 51 examples are shown
        if i == 50:
            break

intent:	send a signal `signal.SIGUSR1` to the current process
predicted:	os . <unk> ( 'signal.SIGUSR1' )
ground_truth:	os.kill(os.getpid(), signal.SIGUSR1)

intent:	decode a hex string '4a4b4c' to UTF-8.
predicted:	"""str_0""" . split ( )
ground_truth:	bytes.fromhex('4a4b4c').decode('utf-8')

intent:	check if all elements in list `myList` are identical
predicted:	[ ( i , myList ) for i in myList ]
ground_truth:	all(x == myList[0] for x in myList)

intent:	format number of spaces between strings `Python`, `:` and `Very Good` to be `20`
predicted:	Python . <unk> ( ':' , 'Very Good' )
ground_truth:	print('%*s : %*s' % (20, 'Python', 20, 'Very Good'))

intent:	How to convert a string from CP-1251 to UTF-8?
predicted:	<unk> . <unk> ( <unk> , <unk> )
ground_truth:	d.decode('cp1251').encode('utf8')

intent:	get rid of None values in dictionary `kwargs`
predicted:	kwargs . <unk> ( )
ground_truth:	res = {k: v for k, v in list(kwargs.items()) if v is not None}

intent:	get rid of None values in 