In [1]:
import sys

# Make modules in the parent directory importable (presumably where the local
# `ptb` and `model` modules imported below live).
parent_dir = '..'
sys.path.append(parent_dir)

In [2]:
import os
import time
import math
import scipy.io as sio
import torch
from torch import nn, optim
from torch.utils.data import DataLoader
from multiprocessing import cpu_count

In [3]:
from ptb import PTB
from model import TCN



In [4]:
# device configuration: first CUDA device when one is available, else the CPU
use_cuda = torch.cuda.is_available()
device = torch.device('cuda:0' if use_cuda else 'cpu')
print('GPU' if use_cuda else 'CPU')

GPU


In [5]:
# Penn TreeBank (PTB) dataset: one PTB object per split, keyed by split name
data_path = '../data'
max_len = 96  # NOTE(review): never passed to PTB in this notebook — confirm whether it should be
splits = ['train', 'valid', 'test']
datasets = dict((name, PTB(root=data_path, split=name)) for name in splits)

In [6]:
# data loader: one DataLoader per split; only the training split is shuffled
batch_size = 20  # previously 32
dataloaders = {
    name: DataLoader(
        datasets[name],
        batch_size=batch_size,
        shuffle=(name == 'train'),
        num_workers=cpu_count(),
        # page-locked host memory only helps when batches are copied to a GPU
        pin_memory=torch.cuda.is_available(),
    )
    for name in splits
}
# presumably the token -> index vocabulary (indexed with '<bos>'/'<eos>'/'<pad>' below)
symbols = datasets['train'].symbols

In [7]:
# TCN model hyper-parameters
embedding_size = 300    # token embedding width (PTB here is word-level: vocab 10002 per printed model)
dropout_rate = 0.1      # passed to TCN as `dropout`
emb_dropout_rate = 0.1  # passed to TCN as `emb_dropout`
levels = 3              # number of entries in num_chans (one per temporal level)
nhid = 450              # hidden channels for every level except the last
# The last level outputs `embedding_size` channels — presumably so the output width
# matches the embedding (e.g. for weight tying); confirm in model.TCN.
num_chans = [nhid for _ in range(levels - 1)]
num_chans.append(embedding_size)

model = TCN(
    vocab_size=datasets['train'].vocab_size,
    embed_size=embedding_size,
    num_channels=num_chans,
    bos_idx=symbols['<bos>'],
    eos_idx=symbols['<eos>'],
    pad_idx=symbols['<pad>'],
    dropout=dropout_rate,
    emb_dropout=emb_dropout_rate,
).to(device)
print(model)

TCN(
  (encoder): Embedding(10002, 300, padding_idx=0)
  (tcn): TemporalConvNet(
    (network): Sequential(
      (0): TemporalBlock(
        (conv1): Conv1d(300, 450, kernel_size=(2,), stride=(1,), padding=(1,))
        (chomp1): Chomp1d()
        (relu1): ReLU()
        (dropout1): Dropout(p=0.1)
        (conv2): Conv1d(450, 450, kernel_size=(2,), stride=(1,), padding=(1,))
        (chomp2): Chomp1d()
        (relu2): ReLU()
        (dropout2): Dropout(p=0.1)
        (net): Sequential(
          (0): Conv1d(300, 450, kernel_size=(2,), stride=(1,), padding=(1,))
          (1): Chomp1d()
          (2): ReLU()
          (3): Dropout(p=0.1)
          (4): Conv1d(450, 450, kernel_size=(2,), stride=(1,), padding=(1,))
          (5): Chomp1d()
          (6): ReLU()
          (7): Dropout(p=0.1)
        )
        (downsample): Conv1d(300, 450, kernel_size=(1,), stride=(1,))
        (relu): ReLU()
      )
      (1): TemporalBlock(
        (conv1): Conv1d(450, 450, kernel_size=(2,), stride=(1,

In [8]:
# folder to save model checkpoints.
# makedirs(exist_ok=True) is race-free and, unlike an os.path.exists() check,
# raises immediately if 'model' exists as a regular file instead of failing
# at the first torch.save() after a full training epoch.
save_path = 'model'
os.makedirs(save_path, exist_ok=True)

In [9]:
# objective function
learning_rate = 4  # initial SGD learning rate (the training loop may rescale it)
# Summed (not averaged) cross-entropy; targets equal to '<pad>' are ignored.
# reduction='sum' is the non-deprecated equivalent of size_average=False.
criterion = nn.CrossEntropyLoss(reduction='sum', ignore_index=symbols['<pad>'])
optimizer = optim.SGD(model.parameters(), lr=learning_rate)  # original note: Adam was an alternative

# negative log likelihood
def NLL(logp, target, length, loss_fn=None):
    """Loss over the valid prefix of a padded batch.

    Both tensors are truncated to the longest sequence in the batch
    (``max(length)``) and flattened before the loss is applied, so padding
    columns past that point never reach ``loss_fn``.

    Args:
        logp: rank-3 scores ``(batch, time, vocab)`` — presumably model
            log-probabilities (CrossEntropyLoss re-applies log-softmax, which
            leaves already-normalised log-probs unchanged).
        target: token ids, ``(batch, time)``.
        length: per-sequence lengths; only its maximum is used.
        loss_fn: callable applied to the flattened ``(N, vocab)`` scores and
            ``(N,)`` targets. Defaults to the notebook-level ``criterion``.

    Returns:
        The result of ``loss_fn`` (a summed scalar loss for ``criterion``).
    """
    if loss_fn is None:
        loss_fn = criterion
    t_max = int(torch.max(length).item())  # computed once, used for both tensors
    target = target[:, :t_max].contiguous().view(-1)
    logp = logp[:, :t_max, :].contiguous().view(-1, logp.size(-1))
    return loss_fn(logp, target)



In [10]:
# training settings: number of epochs, and how often (in training batches) to log
epoch, print_every = 20, 50

In [11]:
# training interface
# Each epoch runs every split in `splits`; only 'train' updates the weights.
# 'valid'/'test' run with autograd disabled so no graph is kept for them.
lr_decay = 1              # LR multiplier applied every other epoch; 1 keeps it fixed (0.5 was tried)
step = 0                  # global count of optimizer steps
tracker = {'NLL': []}     # per-step training loss (model NLL / batch size)
start_time = time.time()
for ep in range(epoch):
    # learning-rate schedule: every other epoch while lr > 0.1 (no-op when lr_decay == 1)
    if (ep % 2 == 0) and (learning_rate > 0.1):
        learning_rate = learning_rate * lr_decay
        for param_group in optimizer.param_groups:
            param_group['lr'] = learning_rate

    for split in splits:
        dataloader = dataloaders[split]
        is_train = split == 'train'
        if is_train:
            model.train()
        else:
            model.eval()
        totals = {'NLL': 0., 'words': 0}

        # Eval splits never call backward(): building the autograd graph there
        # only wastes memory and time.
        with torch.set_grad_enabled(is_train):
            for itr, (_, dec_inputs, targets, lengths) in enumerate(dataloader):
                bsize = dec_inputs.size(0)
                dec_inputs = dec_inputs.to(device)
                targets = targets.to(device)
                lengths = lengths.to(device)

                # forward: the model computes its own NLL
                # (presumably summed over tokens — confirm in model.TCN)
                logp, NLL_loss = model(dec_inputs, lengths, targets)
                loss = NLL_loss / bsize  # per-sentence loss used for the update

                # accumulate split totals for the epoch summary
                totals['NLL'] += NLL_loss.item()
                totals['words'] += torch.sum(lengths).item()

                # backward and optimize
                if is_train:
                    step += 1
                    optimizer.zero_grad()
                    loss.backward()
                    nn.utils.clip_grad_norm_(model.parameters(), 0.25)  # previously 5
                    optimizer.step()

                    tracker['NLL'].append(loss.item())

                    if itr % print_every == 0 or itr + 1 == len(dataloader):
                        print("%s Batch %04d/%04d, NLL-Loss %.4f, "
                              % (split.upper(), itr, len(dataloader),
                                 tracker['NLL'][-1]))

        samples = len(datasets[split])
        print("%s Epoch %02d/%02d, NLL %.4f, PPL %.4f"
              % (split.upper(), ep, epoch, totals['NLL'] / samples,
                 math.exp(totals['NLL'] / totals['words'])))

    # save checkpoint after every epoch
    checkpoint_path = os.path.join(save_path, "E%02d.pkl" % ep)
    torch.save(model.state_dict(), checkpoint_path)
    print("Model saved at %s\n" % checkpoint_path)
end_time = time.time()
# divmod instead of time.gmtime(): gmtime's %H wraps back to 00 after 24 hours
elapsed_hr, elapsed_rem = divmod(int(end_time - start_time), 3600)
elapsed_min, elapsed_sec = divmod(elapsed_rem, 60)
print('Total cost time',
      "%02d hr %02d min %02d sec" % (elapsed_hr, elapsed_min, elapsed_sec))

TRAIN Batch 0000/2104, NLL-Loss 225.1025, 
TRAIN Batch 0050/2104, NLL-Loss 145.5520, 
TRAIN Batch 0100/2104, NLL-Loss 125.8556, 
TRAIN Batch 0150/2104, NLL-Loss 162.6817, 
TRAIN Batch 0200/2104, NLL-Loss 142.8749, 
TRAIN Batch 0250/2104, NLL-Loss 145.2492, 
TRAIN Batch 0300/2104, NLL-Loss 177.6509, 
TRAIN Batch 0350/2104, NLL-Loss 154.1612, 
TRAIN Batch 0400/2104, NLL-Loss 137.9166, 
TRAIN Batch 0450/2104, NLL-Loss 131.1361, 
TRAIN Batch 0500/2104, NLL-Loss 155.6833, 
TRAIN Batch 0550/2104, NLL-Loss 132.4079, 
TRAIN Batch 0600/2104, NLL-Loss 131.7718, 
TRAIN Batch 0650/2104, NLL-Loss 128.9045, 
TRAIN Batch 0700/2104, NLL-Loss 148.2869, 
TRAIN Batch 0750/2104, NLL-Loss 160.0674, 
TRAIN Batch 0800/2104, NLL-Loss 146.0038, 
TRAIN Batch 0850/2104, NLL-Loss 159.9478, 
TRAIN Batch 0900/2104, NLL-Loss 128.5474, 
TRAIN Batch 0950/2104, NLL-Loss 143.5488, 
TRAIN Batch 1000/2104, NLL-Loss 140.3116, 
TRAIN Batch 1050/2104, NLL-Loss 139.0043, 
TRAIN Batch 1100/2104, NLL-Loss 126.3233, 
TRAIN Batch

TRAIN Batch 0000/2104, NLL-Loss 141.1685, 
TRAIN Batch 0050/2104, NLL-Loss 101.1680, 
TRAIN Batch 0100/2104, NLL-Loss 105.3169, 
TRAIN Batch 0150/2104, NLL-Loss 104.3819, 
TRAIN Batch 0200/2104, NLL-Loss 83.6900, 
TRAIN Batch 0250/2104, NLL-Loss 113.3966, 
TRAIN Batch 0300/2104, NLL-Loss 110.9214, 
TRAIN Batch 0350/2104, NLL-Loss 127.3607, 
TRAIN Batch 0400/2104, NLL-Loss 102.2555, 
TRAIN Batch 0450/2104, NLL-Loss 124.4204, 
TRAIN Batch 0500/2104, NLL-Loss 116.2626, 
TRAIN Batch 0550/2104, NLL-Loss 132.6036, 
TRAIN Batch 0600/2104, NLL-Loss 130.3017, 
TRAIN Batch 0650/2104, NLL-Loss 104.3255, 
TRAIN Batch 0700/2104, NLL-Loss 106.1581, 
TRAIN Batch 0750/2104, NLL-Loss 99.7063, 
TRAIN Batch 0800/2104, NLL-Loss 107.3756, 
TRAIN Batch 0850/2104, NLL-Loss 107.6711, 
TRAIN Batch 0900/2104, NLL-Loss 114.3303, 
TRAIN Batch 0950/2104, NLL-Loss 114.1710, 
TRAIN Batch 1000/2104, NLL-Loss 120.5488, 
TRAIN Batch 1050/2104, NLL-Loss 135.2506, 
TRAIN Batch 1100/2104, NLL-Loss 110.8913, 
TRAIN Batch 1

TRAIN Batch 0000/2104, NLL-Loss 93.5490, 
TRAIN Batch 0050/2104, NLL-Loss 121.8998, 
TRAIN Batch 0100/2104, NLL-Loss 102.9167, 
TRAIN Batch 0150/2104, NLL-Loss 136.4403, 
TRAIN Batch 0200/2104, NLL-Loss 104.3290, 
TRAIN Batch 0250/2104, NLL-Loss 114.3716, 
TRAIN Batch 0300/2104, NLL-Loss 117.4881, 
TRAIN Batch 0350/2104, NLL-Loss 121.2460, 
TRAIN Batch 0400/2104, NLL-Loss 122.7627, 
TRAIN Batch 0450/2104, NLL-Loss 126.3535, 
TRAIN Batch 0500/2104, NLL-Loss 93.2873, 
TRAIN Batch 0550/2104, NLL-Loss 104.9869, 
TRAIN Batch 0600/2104, NLL-Loss 126.8158, 
TRAIN Batch 0650/2104, NLL-Loss 105.4971, 
TRAIN Batch 0700/2104, NLL-Loss 109.1281, 
TRAIN Batch 0750/2104, NLL-Loss 122.5322, 
TRAIN Batch 0800/2104, NLL-Loss 135.8782, 
TRAIN Batch 0850/2104, NLL-Loss 99.9621, 
TRAIN Batch 0900/2104, NLL-Loss 93.3773, 
TRAIN Batch 0950/2104, NLL-Loss 128.5780, 
TRAIN Batch 1000/2104, NLL-Loss 113.0051, 
TRAIN Batch 1050/2104, NLL-Loss 112.6229, 
TRAIN Batch 1100/2104, NLL-Loss 121.7864, 
TRAIN Batch 115

TRAIN Batch 0050/2104, NLL-Loss 109.2526, 
TRAIN Batch 0100/2104, NLL-Loss 111.0066, 
TRAIN Batch 0150/2104, NLL-Loss 105.9933, 
TRAIN Batch 0200/2104, NLL-Loss 94.1797, 
TRAIN Batch 0250/2104, NLL-Loss 118.8873, 
TRAIN Batch 0300/2104, NLL-Loss 128.0262, 
TRAIN Batch 0350/2104, NLL-Loss 97.7539, 
TRAIN Batch 0400/2104, NLL-Loss 112.0490, 
TRAIN Batch 0450/2104, NLL-Loss 100.4920, 
TRAIN Batch 0500/2104, NLL-Loss 107.0889, 
TRAIN Batch 0550/2104, NLL-Loss 105.0523, 
TRAIN Batch 0600/2104, NLL-Loss 115.5734, 
TRAIN Batch 0650/2104, NLL-Loss 92.6262, 
TRAIN Batch 0700/2104, NLL-Loss 100.9333, 
TRAIN Batch 0750/2104, NLL-Loss 108.5869, 
TRAIN Batch 0800/2104, NLL-Loss 85.3847, 
TRAIN Batch 0850/2104, NLL-Loss 114.2403, 
TRAIN Batch 0900/2104, NLL-Loss 97.5944, 
TRAIN Batch 0950/2104, NLL-Loss 85.3093, 
TRAIN Batch 1000/2104, NLL-Loss 92.8106, 
TRAIN Batch 1050/2104, NLL-Loss 99.7781, 
TRAIN Batch 1100/2104, NLL-Loss 109.9166, 
TRAIN Batch 1150/2104, NLL-Loss 88.0113, 
TRAIN Batch 1200/210

TRAIN Batch 0100/2104, NLL-Loss 97.3695, 
TRAIN Batch 0150/2104, NLL-Loss 82.9301, 
TRAIN Batch 0200/2104, NLL-Loss 89.3774, 
TRAIN Batch 0250/2104, NLL-Loss 111.2700, 
TRAIN Batch 0300/2104, NLL-Loss 95.8578, 
TRAIN Batch 0350/2104, NLL-Loss 101.8761, 
TRAIN Batch 0400/2104, NLL-Loss 100.5721, 
TRAIN Batch 0450/2104, NLL-Loss 106.8291, 
TRAIN Batch 0500/2104, NLL-Loss 110.6896, 
TRAIN Batch 0550/2104, NLL-Loss 102.1321, 
TRAIN Batch 0600/2104, NLL-Loss 86.3093, 
TRAIN Batch 0650/2104, NLL-Loss 115.7534, 
TRAIN Batch 0700/2104, NLL-Loss 93.5727, 
TRAIN Batch 0750/2104, NLL-Loss 102.9182, 
TRAIN Batch 0800/2104, NLL-Loss 107.0337, 
TRAIN Batch 0850/2104, NLL-Loss 120.5082, 
TRAIN Batch 0900/2104, NLL-Loss 87.2144, 
TRAIN Batch 0950/2104, NLL-Loss 76.7290, 
TRAIN Batch 1000/2104, NLL-Loss 95.4930, 
TRAIN Batch 1050/2104, NLL-Loss 96.4280, 
TRAIN Batch 1100/2104, NLL-Loss 86.7851, 
TRAIN Batch 1150/2104, NLL-Loss 85.0187, 
TRAIN Batch 1200/2104, NLL-Loss 96.0527, 
TRAIN Batch 1250/2104, N

In [12]:
n_params = sum(p.numel() for p in model.parameters())  # every scalar parameter, trainable or not
print('# of parameters:', n_params)

# of parameters: 5221152


In [13]:
# save learning results: the per-step training-loss history in `tracker`, as a .mat file
results_file = "results.mat"
sio.savemat(results_file, tracker)