In [1]:
import sys
sys.path.append('..')

In [2]:
import os
import time
import math
import scipy.io as sio
import torch
from torch import nn, optim
from torch.utils.data import DataLoader
from multiprocessing import cpu_count

In [3]:
from ptb import PTB
from model import STCN



In [4]:
# device configuration
# Use the first CUDA device when one is available, otherwise fall back to the CPU.
device = torch.device('cuda:0') if torch.cuda.is_available() else torch.device('cpu')
# Report which backend was selected.
print('GPU' if device.type == 'cuda' else 'CPU')

GPU


In [5]:
# Penn TreeBank (PTB) dataset
data_path = '../data'                    # root folder handed to PTB
max_len = 96                             # not referenced by the visible cells
splits = ['train', 'valid', 'test']      # also the iteration order of the training loop
# One PTB dataset object per split, keyed by split name.
datasets = dict((name, PTB(root=data_path, split=name)) for name in splits)

In [6]:
# data loader
batch_size = 20 #32
# Options shared by every split's loader: one worker per CPU core, and pinned
# host memory only when a CUDA device is present.
loader_kwargs = dict(batch_size=batch_size,
                     num_workers=cpu_count(),
                     pin_memory=torch.cuda.is_available())
# Only the training split is shuffled; valid/test keep the dataset order.
dataloaders = {name: DataLoader(datasets[name], shuffle=(name == 'train'), **loader_kwargs)
               for name in splits}
# Special-symbol lookup table exposed by the training dataset.
symbols = datasets['train'].symbols

In [7]:
# STCN model
embedding_size = 300      # dimension of the token embeddings
dropout_rate = 0.4        # passed to STCN as `dropout`
emb_dropout_rate = 0.1    # passed to STCN as `emb_dropout`
levels = 3                # number of levels
nhid = 450                # number of hidden units per layer
# Every level uses the same width.
# (alternative kept from the original: [nhid] * (levels - 1) + [embedding_size])
num_chans = [nhid for _ in range(levels)]

# Constructor arguments gathered in one place; the special-symbol indices come
# from the training dataset's symbol table.
stcn_kwargs = dict(
    vocab_size=datasets['train'].vocab_size,
    embed_size=embedding_size,
    num_channels=num_chans,
    bos_idx=symbols['<bos>'],
    eos_idx=symbols['<eos>'],
    pad_idx=symbols['<pad>'],
    dropout=dropout_rate,
    emb_dropout=emb_dropout_rate,
)
model = STCN(**stcn_kwargs)
model = model.to(device)
print(model)

STCN(
  (encoder): Embedding(10002, 300, padding_idx=0)
  (tcn): TemporalConvNet(
    (network): Sequential(
      (0): TemporalBlock(
        (conv1): Conv1d(300, 450, kernel_size=(2,), stride=(1,), padding=(1,))
        (chomp1): Chomp1d()
        (relu1): ReLU()
        (dropout1): Dropout(p=0.4, inplace=False)
        (conv2): Conv1d(450, 450, kernel_size=(2,), stride=(1,), padding=(1,))
        (chomp2): Chomp1d()
        (relu2): ReLU()
        (dropout2): Dropout(p=0.4, inplace=False)
        (net): Sequential(
          (0): Conv1d(300, 450, kernel_size=(2,), stride=(1,), padding=(1,))
          (1): Chomp1d()
          (2): ReLU()
          (3): Dropout(p=0.4, inplace=False)
          (4): Conv1d(450, 450, kernel_size=(2,), stride=(1,), padding=(1,))
          (5): Chomp1d()
          (6): ReLU()
          (7): Dropout(p=0.4, inplace=False)
        )
        (downsample): Conv1d(300, 450, kernel_size=(1,), stride=(1,))
        (relu): ReLU()
      )
      (1): TemporalBlock(
 

In [8]:
# folder to save model
save_path = 'model'
# exist_ok=True makes the cell idempotent on re-run and avoids the
# check-then-create race of a separate os.path.exists() test.
os.makedirs(save_path, exist_ok=True)

In [9]:
# objective function
learning_rate = 4  # initial SGD step size; the training loop halves it later
# Summed (not averaged) cross-entropy that skips <pad> targets.
# reduction='sum' is the supported spelling of the deprecated size_average=False.
criterion = nn.CrossEntropyLoss(reduction='sum', ignore_index=symbols['<pad>'])
optimizer = optim.SGD(model.parameters(), lr=learning_rate) #Adam
#optimizer = optim.Adam(model.parameters(), lr=learning_rate) #Adam


# negative log likelihood
def NLL(logp, target, length):
    """Apply the module-level `criterion` to the first max(length) time steps.

    Args:
        logp: 3-D tensor of per-step scores; the last axis is passed to
            `criterion` as the class axis.
        target: 2-D tensor of target indices, sliced along its second axis.
        length: tensor of sequence lengths; only its maximum is used.

    Returns:
        Whatever `criterion` returns for the flattened scores and targets.
    """
    # Longest sequence in the batch; positions past it are dropped.
    horizon = length.max().item()
    num_classes = logp.size(-1)
    # Collapse the two leading axes so each row is one time step.
    flat_scores = logp[:, :horizon, :].contiguous().view(-1, num_classes)
    flat_target = target[:, :horizon].contiguous().view(-1)
    return criterion(flat_scores, flat_target)



In [10]:
# training setting
# epoch: number of passes the training loop makes over the splits
# print_every: number of training batches between progress lines
epoch, print_every = 20, 50

In [11]:
# training interface
# Each epoch visits the train, valid and test splits in order. Weights are
# updated on the train split only, and a checkpoint is written after every epoch.
step = 0                      # number of optimizer updates performed so far
NLL_tracker = {'NLL': []}     # batch NLL divided by batch size, one entry per training batch
KL_tracker = {'KL': []}       # KL term, one entry per training batch
start_time = time.time()
beta = 1 #1.0/epoch           # weight of the KL term in the loss
beta_increase = 0             # only referenced by the disabled KL annealing below
for ep in range(epoch):
    # learning rate decay: from epoch 10 on, halve the rate on even epochs
    # for as long as it is still above 0.1
    if (ep % 2 == 0) and (learning_rate > 0.1) and (ep >= 10):
        learning_rate = learning_rate * 0.5
        for param_group in optimizer.param_groups:
            param_group['lr'] = learning_rate

    for split in splits:
        dataloader = dataloaders[split]
        is_train = split == 'train'
        if is_train:
            model.train()
        else:
            model.eval()
        # 'KL' is accumulated here but is not part of the epoch summary line.
        totals = {'NLL': 0., 'KL': 0., 'words': 0}

        for itr, (_, dec_inputs, targets, lengths) in enumerate(dataloader):
            bsize = dec_inputs.size(0)
            dec_inputs = dec_inputs.to(device)
            targets = targets.to(device)
            lengths = lengths.to(device)

            # forward + loss; autograd only records operations on the train
            # split, so valid/test passes do not build a graph
            with torch.set_grad_enabled(is_train):
                logp, NLL_loss, KL_loss = model(dec_inputs, lengths, targets) #, lengths
                #NLL_loss = NLL(logp, targets, lengths + 1)
                loss = NLL_loss + beta * KL_loss

            # cumulate
            totals['NLL'] += NLL_loss.item()
            totals['KL'] += KL_loss.item()
            totals['words'] += torch.sum(lengths).item()

            # backward and optimize
            if is_train:
                step += 1

                # track
                NLL_tracker['NLL'].append((NLL_loss / bsize).item())
                KL_tracker['KL'].append(KL_loss.item())

                # print statistics
                if itr % print_every == 0 or itr + 1 == len(dataloader):
                    print("%s Batch %04d/%04d, NLL-Loss %.4f, KL-Loss %.4f, "
                          % (split.upper(), itr, len(dataloader),
                             NLL_tracker['NLL'][-1], KL_tracker['KL'][-1]))
                optimizer.zero_grad()
                loss.backward()
                # clip the global gradient norm before the SGD update
                nn.utils.clip_grad_norm_(model.parameters(), 0.25) #5
                optimizer.step()

        # epoch summary: NLL per dataset item, and perplexity from NLL per
        # unit of sum(lengths)
        samples = len(datasets[split])
        print("%s Epoch %02d/%02d, NLL %.4f, PPL %.4f"
              % (split.upper(), ep, epoch, totals['NLL'] / samples,
                 math.exp(totals['NLL'] / totals['words'])))
    # KL annealing
    #if beta<=1:
    #    beta+=beta_increase

    # save checkpoint
    checkpoint_path = os.path.join(save_path, "E%02d.pkl" % ep)
    torch.save(model.state_dict(), checkpoint_path)
    print("Model saved at %s\n" % checkpoint_path)
end_time = time.time()
print('Total cost time',
      time.strftime("%H hr %M min %S sec", time.gmtime(end_time - start_time)))

TRAIN Batch 0000/2104, NLL-Loss 213.6874, KL-Loss 0.0023, 
TRAIN Batch 0050/2104, NLL-Loss 177.9694, KL-Loss 1.2142, 
TRAIN Batch 0100/2104, NLL-Loss 209.5447, KL-Loss 4.1308, 
TRAIN Batch 0150/2104, NLL-Loss 153.7311, KL-Loss 6.5192, 
TRAIN Batch 0200/2104, NLL-Loss 158.4470, KL-Loss 9.0045, 
TRAIN Batch 0250/2104, NLL-Loss 192.6690, KL-Loss 9.6218, 
TRAIN Batch 0300/2104, NLL-Loss 148.7655, KL-Loss 11.4688, 
TRAIN Batch 0350/2104, NLL-Loss 191.5117, KL-Loss 14.3553, 
TRAIN Batch 0400/2104, NLL-Loss 158.5504, KL-Loss 13.1247, 
TRAIN Batch 0450/2104, NLL-Loss 162.8992, KL-Loss 13.5499, 
TRAIN Batch 0500/2104, NLL-Loss 166.5098, KL-Loss 13.2145, 
TRAIN Batch 0550/2104, NLL-Loss 158.5920, KL-Loss 13.1861, 
TRAIN Batch 0600/2104, NLL-Loss 134.4210, KL-Loss 16.0349, 
TRAIN Batch 0650/2104, NLL-Loss 178.1622, KL-Loss 16.1133, 
TRAIN Batch 0700/2104, NLL-Loss 162.0838, KL-Loss 16.2961, 
TRAIN Batch 0750/2104, NLL-Loss 137.1474, KL-Loss 14.9918, 
TRAIN Batch 0800/2104, NLL-Loss 141.2554, KL-L

VALID Epoch 02/20, NLL 118.0831, PPL 285.2490
TEST Epoch 02/20, NLL 117.5684, PPL 276.0828
Model saved at model/E02.pkl

TRAIN Batch 0000/2104, NLL-Loss 100.4766, KL-Loss 10.9670, 
TRAIN Batch 0050/2104, NLL-Loss 132.9491, KL-Loss 14.1902, 
TRAIN Batch 0100/2104, NLL-Loss 105.1902, KL-Loss 11.6031, 
TRAIN Batch 0150/2104, NLL-Loss 123.3560, KL-Loss 12.2202, 
TRAIN Batch 0200/2104, NLL-Loss 139.1994, KL-Loss 13.2129, 
TRAIN Batch 0250/2104, NLL-Loss 134.4068, KL-Loss 11.1589, 
TRAIN Batch 0300/2104, NLL-Loss 132.3535, KL-Loss 12.2258, 
TRAIN Batch 0350/2104, NLL-Loss 126.4219, KL-Loss 12.0494, 
TRAIN Batch 0400/2104, NLL-Loss 118.0922, KL-Loss 12.4528, 
TRAIN Batch 0450/2104, NLL-Loss 124.1716, KL-Loss 12.4612, 
TRAIN Batch 0500/2104, NLL-Loss 140.1402, KL-Loss 12.9144, 
TRAIN Batch 0550/2104, NLL-Loss 135.3824, KL-Loss 12.1504, 
TRAIN Batch 0600/2104, NLL-Loss 135.9822, KL-Loss 12.8087, 
TRAIN Batch 0650/2104, NLL-Loss 107.5965, KL-Loss 11.9204, 
TRAIN Batch 0700/2104, NLL-Loss 111.857

TRAIN Batch 2100/2104, NLL-Loss 124.1607, KL-Loss 13.9838, 
TRAIN Batch 2103/2104, NLL-Loss 151.5496, KL-Loss 5.1151, 
TRAIN Epoch 05/20, NLL 118.7888, PPL 278.8084
VALID Epoch 05/20, NLL 111.0065, PPL 203.2751
TEST Epoch 05/20, NLL 110.4232, PPL 196.1931
Model saved at model/E05.pkl

TRAIN Batch 0000/2104, NLL-Loss 101.8682, KL-Loss 14.0994, 
TRAIN Batch 0050/2104, NLL-Loss 107.6651, KL-Loss 13.6556, 
TRAIN Batch 0100/2104, NLL-Loss 97.7072, KL-Loss 15.1031, 
TRAIN Batch 0150/2104, NLL-Loss 117.1303, KL-Loss 15.1650, 
TRAIN Batch 0200/2104, NLL-Loss 110.2420, KL-Loss 14.2601, 
TRAIN Batch 0250/2104, NLL-Loss 119.8236, KL-Loss 15.1619, 
TRAIN Batch 0300/2104, NLL-Loss 121.3206, KL-Loss 14.2007, 
TRAIN Batch 0350/2104, NLL-Loss 105.9515, KL-Loss 14.1247, 
TRAIN Batch 0400/2104, NLL-Loss 97.9397, KL-Loss 15.5557, 
TRAIN Batch 0450/2104, NLL-Loss 89.2013, KL-Loss 14.4863, 
TRAIN Batch 0500/2104, NLL-Loss 129.0923, KL-Loss 15.4910, 
TRAIN Batch 0550/2104, NLL-Loss 130.0247, KL-Loss 15.5536

TRAIN Batch 1950/2104, NLL-Loss 100.6757, KL-Loss 16.0575, 
TRAIN Batch 2000/2104, NLL-Loss 110.5975, KL-Loss 16.0856, 
TRAIN Batch 2050/2104, NLL-Loss 112.3228, KL-Loss 14.8676, 
TRAIN Batch 2100/2104, NLL-Loss 109.3163, KL-Loss 19.2919, 
TRAIN Batch 2103/2104, NLL-Loss 125.7295, KL-Loss 6.9511, 
TRAIN Epoch 08/20, NLL 113.1405, PPL 213.3205
VALID Epoch 08/20, NLL 108.7406, PPL 182.3775
TEST Epoch 08/20, NLL 108.2353, PPL 176.7091
Model saved at model/E08.pkl

TRAIN Batch 0000/2104, NLL-Loss 118.0530, KL-Loss 16.7398, 
TRAIN Batch 0050/2104, NLL-Loss 114.7296, KL-Loss 18.2685, 
TRAIN Batch 0100/2104, NLL-Loss 112.4033, KL-Loss 16.3792, 
TRAIN Batch 0150/2104, NLL-Loss 105.4946, KL-Loss 17.8274, 
TRAIN Batch 0200/2104, NLL-Loss 119.4603, KL-Loss 17.4756, 
TRAIN Batch 0250/2104, NLL-Loss 95.0120, KL-Loss 18.9807, 
TRAIN Batch 0300/2104, NLL-Loss 120.4565, KL-Loss 17.4881, 
TRAIN Batch 0350/2104, NLL-Loss 103.1412, KL-Loss 17.6078, 
TRAIN Batch 0400/2104, NLL-Loss 89.5349, KL-Loss 14.715

TRAIN Batch 1800/2104, NLL-Loss 106.5377, KL-Loss 21.5133, 
TRAIN Batch 1850/2104, NLL-Loss 121.7836, KL-Loss 21.3907, 
TRAIN Batch 1900/2104, NLL-Loss 105.2835, KL-Loss 21.7708, 
TRAIN Batch 1950/2104, NLL-Loss 109.2510, KL-Loss 20.2747, 
TRAIN Batch 2000/2104, NLL-Loss 106.3537, KL-Loss 21.0532, 
TRAIN Batch 2050/2104, NLL-Loss 114.1665, KL-Loss 23.5616, 
TRAIN Batch 2100/2104, NLL-Loss 92.1216, KL-Loss 21.9809, 
TRAIN Batch 2103/2104, NLL-Loss 84.1608, KL-Loss 8.7483, 
TRAIN Epoch 11/20, NLL 106.7919, PPL 157.8861
VALID Epoch 11/20, NLL 105.5202, PPL 156.3187
TEST Epoch 11/20, NLL 104.8050, PPL 149.9810
Model saved at model/E11.pkl

TRAIN Batch 0000/2104, NLL-Loss 95.6496, KL-Loss 22.3237, 
TRAIN Batch 0050/2104, NLL-Loss 86.7402, KL-Loss 21.2024, 
TRAIN Batch 0100/2104, NLL-Loss 110.7623, KL-Loss 23.4717, 
TRAIN Batch 0150/2104, NLL-Loss 114.0110, KL-Loss 25.1889, 
TRAIN Batch 0200/2104, NLL-Loss 109.5431, KL-Loss 24.1470, 
TRAIN Batch 0250/2104, NLL-Loss 118.6796, KL-Loss 22.6646,

TRAIN Batch 1700/2104, NLL-Loss 113.9020, KL-Loss 26.5436, 
TRAIN Batch 1750/2104, NLL-Loss 85.0693, KL-Loss 24.7703, 
TRAIN Batch 1800/2104, NLL-Loss 110.9129, KL-Loss 27.1475, 
TRAIN Batch 1850/2104, NLL-Loss 125.4638, KL-Loss 27.1542, 
TRAIN Batch 1900/2104, NLL-Loss 107.5491, KL-Loss 26.3185, 
TRAIN Batch 1950/2104, NLL-Loss 95.0682, KL-Loss 25.2271, 
TRAIN Batch 2000/2104, NLL-Loss 138.5541, KL-Loss 26.1550, 
TRAIN Batch 2050/2104, NLL-Loss 101.4465, KL-Loss 27.1641, 
TRAIN Batch 2100/2104, NLL-Loss 114.5428, KL-Loss 25.0872, 
TRAIN Batch 2103/2104, NLL-Loss 109.6765, KL-Loss 10.5585, 
TRAIN Epoch 14/20, NLL 102.8154, PPL 130.7639
VALID Epoch 14/20, NLL 104.3530, PPL 147.8233
TEST Epoch 14/20, NLL 103.6138, PPL 141.6781
Model saved at model/E14.pkl

TRAIN Batch 0000/2104, NLL-Loss 84.9481, KL-Loss 23.3886, 
TRAIN Batch 0050/2104, NLL-Loss 98.7855, KL-Loss 25.3492, 
TRAIN Batch 0100/2104, NLL-Loss 103.3992, KL-Loss 27.9790, 
TRAIN Batch 0150/2104, NLL-Loss 132.2521, KL-Loss 26.9834

TRAIN Batch 1600/2104, NLL-Loss 71.0324, KL-Loss 24.3764, 
TRAIN Batch 1650/2104, NLL-Loss 90.0773, KL-Loss 26.1654, 
TRAIN Batch 1700/2104, NLL-Loss 95.5534, KL-Loss 27.0190, 
TRAIN Batch 1750/2104, NLL-Loss 116.6957, KL-Loss 27.7648, 
TRAIN Batch 1800/2104, NLL-Loss 100.8306, KL-Loss 27.8169, 
TRAIN Batch 1850/2104, NLL-Loss 120.3904, KL-Loss 30.1184, 
TRAIN Batch 1900/2104, NLL-Loss 108.0135, KL-Loss 26.1729, 
TRAIN Batch 1950/2104, NLL-Loss 112.2885, KL-Loss 27.6686, 
TRAIN Batch 2000/2104, NLL-Loss 95.5196, KL-Loss 26.0156, 
TRAIN Batch 2050/2104, NLL-Loss 75.3605, KL-Loss 24.3454, 
TRAIN Batch 2100/2104, NLL-Loss 101.8739, KL-Loss 26.6925, 
TRAIN Batch 2103/2104, NLL-Loss 84.4714, KL-Loss 10.7157, 
TRAIN Epoch 17/20, NLL 101.4257, PPL 122.4279
VALID Epoch 17/20, NLL 104.2535, PPL 147.1209
TEST Epoch 17/20, NLL 103.4711, PPL 140.7147
Model saved at model/E17.pkl

TRAIN Batch 0000/2104, NLL-Loss 116.7305, KL-Loss 27.7590, 
TRAIN Batch 0050/2104, NLL-Loss 139.0717, KL-Loss 28.0671, 

In [12]:
# Element count of every tensor returned by model.parameters(), then the total.
param_sizes = [weights.numel() for weights in model.parameters()]
print('# of parameters:', sum(param_sizes))

# of parameters: 23414052


In [12]:
# save learning results
# Persist both per-batch training curves. The KL values are tracked during
# training but were previously dropped; the 'NLL' key is written as before
# and a 'KL' key is added next to it.
sio.savemat("results.mat", {**NLL_tracker, **KL_tracker})