In [1]:
# Load IPython's autoreload extension so edits to the zerogercrnn package are picked up without restarting the kernel.
%load_ext autoreload

In [2]:
# Mode 2: reload all (non-excluded) modules before executing each cell.
%autoreload 2

In [None]:
import os
import time

import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
from torch.optim.lr_scheduler import MultiStepLR
from tqdm import tqdm

from zerogercrnn.lib.data.general import DataGenerator
from zerogercrnn.lib.train.run import TrainEpochRunner

from zerogercrnn.experiments.js.ast_level.data import ASTDataGenerator, SourceFile, DataReader, MockDataReader
from zerogercrnn.experiments.js.ast_level.network_base_lstm import JSBaseModel
from zerogercrnn.experiments.js.ast_level.train import ASTRoutine
from zerogercrnn.experiments.js.ast_level.raw_data import ENCODING

In [4]:
# Location of the preprocessed JS AST dataset. Set the JS_DATASET_DIR environment
# variable to point elsewhere instead of editing this machine-specific default.
DIR_DATASET = os.environ.get(
    'JS_DATASET_DIR',
    '/Users/zerogerc/Documents/datasets/js_dataset.tar/processed'
)

FILE_TRAINING = os.path.join(DIR_DATASET, 'programs_training_one_hot.json')
FILE_EVAL = os.path.join(DIR_DATASET, 'programs_eval_one_hot.json')

# Training schedule
SEQ_LEN = 50
BATCH_SIZE = 80
# NOTE(review): 0.05 is unusually high for Adam; in the recorded run below the
# train loss jumped from ~4.18 to ~24.06 between the first two train points.
# Consider something around 1e-3.
LEARNING_RATE = 0.05
EPOCHS = 50
DECAY_AFTER_EPOCH = 0  # first epoch index used as a MultiStepLR milestone

# Model size
EMBEDDING_SIZE = 100
HIDDEN_SIZE = 500
NUM_LAYERS = 1
# NOTE(review): if JSBaseModel forwards this to nn.LSTM, it has no effect with
# NUM_LAYERS == 1 (nn.LSTM applies dropout only between layers) -- confirm.
DROPOUT = 0.01

# Vocabulary sizes
NON_TERMINALS_SIZE = 96 + 1  # 96 non-terminals + EOF
TERMINALS_SIZE = 50000 + 2  # 50000 terminals + EMPTY + UNKNOWN

In [5]:
# Data source. With USE_MOCK_DATA = True the notebook uses MockDataReader
# (presumably synthetic data for quick smoke runs -- confirm); set it to False
# to read the real one-hot encoded programs from FILE_TRAINING / FILE_EVAL.
USE_MOCK_DATA = True

if USE_MOCK_DATA:
    reader = MockDataReader()
else:
    reader = DataReader(
        file_training=FILE_TRAINING,
        file_eval=FILE_EVAL,
        encoding=ENCODING
    )

In [6]:
# Wrap the reader in the AST batch generator (sequence length SEQ_LEN,
# BATCH_SIZE sequences per batch).
data_generator = ASTDataGenerator(
    batch_size=BATCH_SIZE,
    seq_len=SEQ_LEN,
    data_reader=reader,
)

100%|██████████| 100/100 [00:00<00:00, 17652.05it/s]
100%|██████████| 100/100 [00:00<00:00, 19309.01it/s]


In [7]:
# Performance check: wall-clock time to iterate once over the whole training
# generator (data pipeline only, no model). perf_counter is monotonic and
# high-resolution, unlike time.time, so it is the right clock for elapsed time.
timing_start = time.perf_counter()
n_train_batches = 0
for _ in data_generator.get_train_generator():
    n_train_batches += 1
timing_elapsed = time.perf_counter() - timing_start
print('{:.4f}s for {} batches'.format(timing_elapsed, n_train_batches))

0.029071807861328125s


In [8]:
# Base LSTM model with separate non-terminal and terminal vocabularies,
# sized from the configuration constants.
network = JSBaseModel(
    non_terminal_vocab_size=NON_TERMINALS_SIZE,
    terminal_vocab_size=TERMINALS_SIZE,
    embedding_size=EMBEDDING_SIZE,
    hidden_size=HIDDEN_SIZE,
    num_layers=NUM_LAYERS,
    dropout=DROPOUT,
)

In [9]:
# Every epoch index in [DECAY_AFTER_EPOCH, EPOCHS] is a MultiStepLR milestone,
# so the learning rate is multiplied by 0.95 at each of them (per-epoch decay).
lr_decay_milestones = [epoch for epoch in range(DECAY_AFTER_EPOCH, EPOCHS + 1)]

optimizer = optim.Adam(lr=LEARNING_RATE, params=network.parameters())
scheduler = MultiStepLR(optimizer=optimizer, milestones=lr_decay_milestones, gamma=0.95)

# Per-head loss used by `criterion`; expects log-probabilities.
base_criterion = nn.NLLLoss()

In [10]:
def _flat_nll(log_probs, target, loss_fn):
    """Apply ``loss_fn`` to ``log_probs`` flattened to (-1, C) and ``target`` flattened to (-1,).

    ``.contiguous()`` is required because ``.view`` raises on non-contiguous
    tensors (e.g. ones produced by ``permute``/``transpose``).
    """
    n_classes = log_probs.size(-1)
    return loss_fn(
        log_probs.contiguous().view(-1, n_classes),
        target.contiguous().view(-1)
    )


def criterion(n_output, n_target, loss_fn=None):
    """Sum of the losses for the non-terminal (N) and terminal (T) heads.

    Args:
        n_output: pair ``(N, T)`` of prediction tensors whose last dimension is
            the vocabulary size of that head; expected to hold log-probabilities
            (as ``nn.NLLLoss`` requires).
        n_target: pair ``(N, T)`` of class-index tensors whose elements line up
            with the matching output once both are flattened.
        loss_fn: loss applied to each flattened (output, target) pair. Defaults
            to the notebook-level ``base_criterion``, looked up at call time.

    Returns:
        ``loss_fn(N_out, N_tgt) + loss_fn(T_out, T_tgt)``.
    """
    if loss_fn is None:
        loss_fn = base_criterion
    loss_non_terminal = _flat_nll(n_output[0], n_target[0], loss_fn)
    loss_terminal = _flat_nll(n_output[1], n_target[1], loss_fn)
    return loss_non_terminal + loss_terminal

In [11]:
# Training and validation routines share the same network, loss and optimizer.
# NOTE(review): validation_routine also receives the optimizer -- confirm that
# ASTRoutine does not step it when used for validation, otherwise validation
# batches would update the model.
_routine_kwargs = dict(network=network, criterion=criterion, optimizer=optimizer)
train_routine = ASTRoutine(**_routine_kwargs)
validation_routine = ASTRoutine(**_routine_kwargs)

In [12]:
# Epoch-level training driver; plotter='visdom' selects the visdom plotting
# backend (presumably needs a running visdom server -- confirm).
# To also save models, pass e.g. save_dir=os.path.join(os.getcwd(), 'saved_models').
runner = TrainEpochRunner(
    network=network,
    data_generator=data_generator,
    train_routine=train_routine,
    validation_routine=validation_routine,
    scheduler=scheduler,
    plotter='visdom',
)

In [13]:
# Run training for EPOCHS epochs. The recorded output below was stopped early
# ("Exiting from training early"). The repeated start/squeeze/embedded/... lines
# appear to be leftover debug prints from the model's forward pass -- TODO remove.
runner.run(number_of_epochs=EPOCHS)

start
squeeze
embedded
permuted
lstmed
non_terminal_converted
terminal_converted
finish


  0%|          | 0/50 [00:00<?, ?it/s]

New point to validation: 0 15.357542991638184
Epoch: -1, Average loss: 15.357542991638184
start
squeeze
embedded
permuted
lstmed
non_terminal_converted
terminal_converted
finish
Loss: 4.177548885345459
New point to train: 0 4.177548885345459
start
squeeze
embedded
permuted
lstmed
non_terminal_converted
terminal_converted
finish
Loss: 24.06134605407715
New point to train: 1 24.06134605407715
start
squeeze
embedded
permuted
lstmed
non_terminal_converted
-----------------------------------------------------------------------------------------
Exiting from training early
