In [1]:
import sys
sys.path.append('..')

In [2]:
import os
import time
import math
import scipy.io as sio
import torch
from torch import nn, optim
from torch.utils.data import DataLoader
from multiprocessing import cpu_count

In [3]:
from ptb import PTB
from model import RNN



In [4]:
# device configuration
# Use the first CUDA device when one is visible, otherwise fall back to the CPU,
# and report which one was picked.
device = torch.device('cuda:0') if torch.cuda.is_available() else torch.device('cpu')
print({'cuda': 'GPU', 'cpu': 'CPU'}[device.type])

GPU


In [5]:
# Penn TreeBank (PTB) dataset
data_path = '../data'
# Passed to the model as `time_step`.
# NOTE(review): it is not passed to PTB here.
max_len = 96
splits = ['train', 'valid', 'test']
# One PTB dataset object per split, keyed by split name.
datasets = {name: PTB(root=data_path, split=name) for name in splits}

In [6]:
# data loader
batch_size = 20  # previously 32
dataloaders = {
    split: DataLoader(
        datasets[split],
        batch_size=batch_size,
        shuffle=(split == 'train'),  # only the training split is shuffled
        num_workers=cpu_count(),
        # page-locked host buffers are only useful when copying to a GPU
        pin_memory=torch.cuda.is_available(),
    )
    for split in splits
}
# Symbol table (token -> index) of the training split; used for special tokens.
symbols = datasets['train'].symbols

In [7]:
# RNN model
embedding_size = 300
hidden_size = 450  # previously 256
dropout_rate = 0.5

# Vocabulary size and the <pad>/<bos>/<eos> indices all come from the
# training split, so every split is decoded with the same vocabulary.
model = RNN(
    vocab_size=datasets['train'].vocab_size,
    embed_size=embedding_size,
    hidden_size=hidden_size,
    time_step=max_len,
    dropout_rate=dropout_rate,
    pad_idx=symbols['<pad>'],
    bos_idx=symbols['<bos>'],
    eos_idx=symbols['<eos>'],
).to(device)
print(model)

RNN(
  (embedding): Embedding(10002, 300, padding_idx=0)
  (rnn): LSTM(300, 450, batch_first=True)
  (output): Linear(in_features=450, out_features=10002, bias=True)
)


In [8]:
# initialization
# Seed the global torch RNG so the weight init below, and the torch random
# draws made after it, are reproducible on Restart & Run All.
seed = 0
torch.manual_seed(seed)
# Re-initialise every parameter uniformly in [-0.1, 0.1]. Done via nn.init
# under no_grad rather than through `.data`, which silently bypasses autograd.
with torch.no_grad():
    for p in model.parameters():
        nn.init.uniform_(p, -0.1, 0.1)

In [9]:
# folder to save model checkpoints (created relative to the working directory)
save_path = 'model'
# exist_ok=True replaces the check-then-create pattern, which can race
os.makedirs(save_path, exist_ok=True)

In [10]:
# objective function
learning_rate = 40  # initial SGD learning rate (noted alternative: 0.001 with Adam)
# Summed (not averaged) NLL over all target tokens that are not <pad>.
# `reduction='sum'` is the non-deprecated spelling of `size_average=False`.
criterion = nn.NLLLoss(reduction='sum', ignore_index=symbols['<pad>'])
optimizer = optim.SGD(model.parameters(), lr=learning_rate)  # noted alternative: Adam

# negative log likelihood
def NLL(logp, target, length):
    """Negative log-likelihood of `target` under `logp`, via module-level `criterion`.

    Args:
        logp: log-probabilities whose last dimension is the vocabulary; it is
            flattened to (N, vocab) with `view`. It presumably has shape
            (batch, time, vocab) and is contiguous; TODO confirm against model.RNN.
        target: index tensor of at least two dims, with time on dim 1. It is
            truncated to the first max(length) steps before flattening.
        length: 1-D tensor whose maximum gives the number of steps to keep.

    Returns:
        Whatever `criterion` returns for the flattened inputs. The reduction
        mode and ignored index are configured on `criterion`.
    """
    steps = length.max().item()
    flat_target = target[:, :steps].reshape(-1)
    vocab = logp.size(-1)
    return criterion(logp.view(-1, vocab), flat_target)



In [11]:
# training setting
# epoch: number of passes over all splits
# print_every: a training batch is logged every `print_every` iterations
epoch, print_every = 20, 50

In [12]:
# training interface
# For each epoch: optionally decay the learning rate, then make one pass over
# every split (only 'train' updates weights). Report NLL / perplexity per split
# and save a checkpoint of the model weights.
step = 0                  # global count of optimizer updates
tracker = {'NLL': []}     # per-update training loss (summed NLL / batch size)
start_time = time.time()
for ep in range(epoch):
    # learning rate decay: halve the lr on every even epoch while it is > 0.1.
    # NOTE(review): ep == 0 also satisfies `ep % 2 == 0`, so training already
    # starts at half the configured learning_rate. Confirm that is intended.
    if (ep % 2 == 0) and (learning_rate > 0.1):
        learning_rate = learning_rate * 0.5
        for param_group in optimizer.param_groups:
            param_group['lr'] = learning_rate

    for split in splits:
        dataloader = dataloaders[split]
        is_train = split == 'train'
        # train(False) is equivalent to eval(). It toggles mode-dependent
        # behaviour (e.g. dropout, if the model applies it).
        model.train(is_train)
        totals = {'NLL': 0., 'words': 0}

        # Build the autograd graph only for the training split. Without this,
        # valid/test passes track gradients and keep activations alive for nothing.
        with torch.set_grad_enabled(is_train):
            for itr, (_, dec_inputs, targets, lengths) in enumerate(dataloader):
                bsize = dec_inputs.size(0)
                dec_inputs = dec_inputs.to(device)
                targets = targets.to(device)
                lengths = lengths.to(device)

                # forward
                logp = model(dec_inputs, lengths)

                # summed NLL over the batch; NLL() truncates targets to
                # max(lengths + 1) time steps
                NLL_loss = NLL(logp, targets, lengths + 1)
                loss = NLL_loss / bsize

                # cumulate
                totals['NLL'] += NLL_loss.item()
                totals['words'] += torch.sum(lengths).item()

                # backward and optimize
                if is_train:
                    step += 1
                    optimizer.zero_grad()
                    loss.backward()
                    # clip the global gradient norm (noted alternative: 5)
                    nn.utils.clip_grad_norm_(model.parameters(), 0.25)
                    optimizer.step()

                    # track
                    tracker['NLL'].append(loss.item())

                    # print statistics
                    if itr % print_every == 0 or itr + 1 == len(dataloader):
                        print("%s Batch %04d/%04d, NLL-Loss %.4f, "
                              % (split.upper(), itr, len(dataloader),
                                 tracker['NLL'][-1]))

        # Per-split summary: summed NLL divided by the number of dataset items,
        # plus perplexity from the summed NLL divided by the summed lengths.
        # NOTE(review): `words` sums `lengths`, while the loss covers targets up to
        # lengths + 1. Confirm the PPL denominator matches the tokens scored.
        samples = len(datasets[split])
        print("%s Epoch %02d/%02d, NLL %.4f, PPL %.4f"
              % (split.upper(), ep, epoch, totals['NLL'] / samples,
                 math.exp(totals['NLL'] / totals['words'])))

    # save checkpoint
    checkpoint_path = os.path.join(save_path, "E%02d.pkl" % ep)
    torch.save(model.state_dict(), checkpoint_path)
    print("Model saved at %s\n" % checkpoint_path)
end_time = time.time()
print('Total cost time',
      time.strftime("%H hr %M min %S sec", time.gmtime(end_time - start_time)))

TRAIN Batch 0000/2104, NLL-Loss 205.1002, 
TRAIN Batch 0050/2104, NLL-Loss 163.0776, 
TRAIN Batch 0100/2104, NLL-Loss 164.3962, 
TRAIN Batch 0150/2104, NLL-Loss 153.9433, 
TRAIN Batch 0200/2104, NLL-Loss 156.9258, 
TRAIN Batch 0250/2104, NLL-Loss 123.7360, 
TRAIN Batch 0300/2104, NLL-Loss 127.2849, 
TRAIN Batch 0350/2104, NLL-Loss 152.4349, 
TRAIN Batch 0400/2104, NLL-Loss 114.6360, 
TRAIN Batch 0450/2104, NLL-Loss 107.0449, 
TRAIN Batch 0500/2104, NLL-Loss 123.7056, 
TRAIN Batch 0550/2104, NLL-Loss 132.6186, 
TRAIN Batch 0600/2104, NLL-Loss 121.1500, 
TRAIN Batch 0650/2104, NLL-Loss 143.3746, 
TRAIN Batch 0700/2104, NLL-Loss 117.0036, 
TRAIN Batch 0750/2104, NLL-Loss 110.7385, 
TRAIN Batch 0800/2104, NLL-Loss 138.4691, 
TRAIN Batch 0850/2104, NLL-Loss 124.3361, 
TRAIN Batch 0900/2104, NLL-Loss 112.4763, 
TRAIN Batch 0950/2104, NLL-Loss 123.8357, 
TRAIN Batch 1000/2104, NLL-Loss 123.9144, 
TRAIN Batch 1050/2104, NLL-Loss 111.3361, 
TRAIN Batch 1100/2104, NLL-Loss 123.7838, 
TRAIN Batch

TRAIN Batch 0000/2104, NLL-Loss 101.1075, 
TRAIN Batch 0050/2104, NLL-Loss 111.0943, 
TRAIN Batch 0100/2104, NLL-Loss 81.1327, 
TRAIN Batch 0150/2104, NLL-Loss 110.8152, 
TRAIN Batch 0200/2104, NLL-Loss 111.5830, 
TRAIN Batch 0250/2104, NLL-Loss 124.3795, 
TRAIN Batch 0300/2104, NLL-Loss 116.5560, 
TRAIN Batch 0350/2104, NLL-Loss 96.0956, 
TRAIN Batch 0400/2104, NLL-Loss 94.4773, 
TRAIN Batch 0450/2104, NLL-Loss 88.3573, 
TRAIN Batch 0500/2104, NLL-Loss 99.0879, 
TRAIN Batch 0550/2104, NLL-Loss 108.8426, 
TRAIN Batch 0600/2104, NLL-Loss 109.3814, 
TRAIN Batch 0650/2104, NLL-Loss 100.9349, 
TRAIN Batch 0700/2104, NLL-Loss 80.0676, 
TRAIN Batch 0750/2104, NLL-Loss 105.0107, 
TRAIN Batch 0800/2104, NLL-Loss 121.5751, 
TRAIN Batch 0850/2104, NLL-Loss 110.7120, 
TRAIN Batch 0900/2104, NLL-Loss 83.6839, 
TRAIN Batch 0950/2104, NLL-Loss 97.8394, 
TRAIN Batch 1000/2104, NLL-Loss 115.1070, 
TRAIN Batch 1050/2104, NLL-Loss 82.5005, 
TRAIN Batch 1100/2104, NLL-Loss 83.8736, 
TRAIN Batch 1150/2104

TRAIN Batch 0050/2104, NLL-Loss 108.0003, 
TRAIN Batch 0100/2104, NLL-Loss 112.9936, 
TRAIN Batch 0150/2104, NLL-Loss 90.6010, 
TRAIN Batch 0200/2104, NLL-Loss 82.8587, 
TRAIN Batch 0250/2104, NLL-Loss 91.3092, 
TRAIN Batch 0300/2104, NLL-Loss 99.5268, 
TRAIN Batch 0350/2104, NLL-Loss 97.0355, 
TRAIN Batch 0400/2104, NLL-Loss 108.3455, 
TRAIN Batch 0450/2104, NLL-Loss 95.5621, 
TRAIN Batch 0500/2104, NLL-Loss 94.6084, 
TRAIN Batch 0550/2104, NLL-Loss 106.0120, 
TRAIN Batch 0600/2104, NLL-Loss 107.0856, 
TRAIN Batch 0650/2104, NLL-Loss 99.3507, 
TRAIN Batch 0700/2104, NLL-Loss 92.7055, 
TRAIN Batch 0750/2104, NLL-Loss 97.0157, 
TRAIN Batch 0800/2104, NLL-Loss 113.1460, 
TRAIN Batch 0850/2104, NLL-Loss 97.7768, 
TRAIN Batch 0900/2104, NLL-Loss 91.2451, 
TRAIN Batch 0950/2104, NLL-Loss 75.6502, 
TRAIN Batch 1000/2104, NLL-Loss 102.5477, 
TRAIN Batch 1050/2104, NLL-Loss 88.5876, 
TRAIN Batch 1100/2104, NLL-Loss 110.0151, 
TRAIN Batch 1150/2104, NLL-Loss 112.9438, 
TRAIN Batch 1200/2104, NL

TRAIN Batch 0150/2104, NLL-Loss 92.9893, 
TRAIN Batch 0200/2104, NLL-Loss 85.0956, 
TRAIN Batch 0250/2104, NLL-Loss 97.4712, 
TRAIN Batch 0300/2104, NLL-Loss 111.3325, 
TRAIN Batch 0350/2104, NLL-Loss 93.2810, 
TRAIN Batch 0400/2104, NLL-Loss 106.0800, 
TRAIN Batch 0450/2104, NLL-Loss 84.1591, 
TRAIN Batch 0500/2104, NLL-Loss 96.4337, 
TRAIN Batch 0550/2104, NLL-Loss 84.3714, 
TRAIN Batch 0600/2104, NLL-Loss 87.5873, 
TRAIN Batch 0650/2104, NLL-Loss 86.7672, 
TRAIN Batch 0700/2104, NLL-Loss 112.2869, 
TRAIN Batch 0750/2104, NLL-Loss 106.6448, 
TRAIN Batch 0800/2104, NLL-Loss 98.3761, 
TRAIN Batch 0850/2104, NLL-Loss 111.4753, 
TRAIN Batch 0900/2104, NLL-Loss 88.5718, 
TRAIN Batch 0950/2104, NLL-Loss 98.3462, 
TRAIN Batch 1000/2104, NLL-Loss 88.7632, 
TRAIN Batch 1050/2104, NLL-Loss 106.2341, 
TRAIN Batch 1100/2104, NLL-Loss 104.3982, 
TRAIN Batch 1150/2104, NLL-Loss 106.6657, 
TRAIN Batch 1200/2104, NLL-Loss 83.2334, 
TRAIN Batch 1250/2104, NLL-Loss 90.9297, 
TRAIN Batch 1300/2104, NLL

TRAIN Batch 0250/2104, NLL-Loss 94.4444, 
TRAIN Batch 0300/2104, NLL-Loss 97.8744, 
TRAIN Batch 0350/2104, NLL-Loss 99.6775, 
TRAIN Batch 0400/2104, NLL-Loss 87.3081, 
TRAIN Batch 0450/2104, NLL-Loss 88.5922, 
TRAIN Batch 0500/2104, NLL-Loss 97.7767, 
TRAIN Batch 0550/2104, NLL-Loss 101.6478, 
TRAIN Batch 0600/2104, NLL-Loss 107.7234, 
TRAIN Batch 0650/2104, NLL-Loss 105.3501, 
TRAIN Batch 0700/2104, NLL-Loss 96.5990, 
TRAIN Batch 0750/2104, NLL-Loss 110.4589, 
TRAIN Batch 0800/2104, NLL-Loss 118.1229, 
TRAIN Batch 0850/2104, NLL-Loss 89.0330, 
TRAIN Batch 0900/2104, NLL-Loss 96.0839, 
TRAIN Batch 0950/2104, NLL-Loss 134.1119, 
TRAIN Batch 1000/2104, NLL-Loss 92.5848, 
TRAIN Batch 1050/2104, NLL-Loss 113.3227, 
TRAIN Batch 1100/2104, NLL-Loss 106.9671, 
TRAIN Batch 1150/2104, NLL-Loss 105.2266, 
TRAIN Batch 1200/2104, NLL-Loss 110.5178, 
TRAIN Batch 1250/2104, NLL-Loss 100.5084, 
TRAIN Batch 1300/2104, NLL-Loss 92.6189, 
TRAIN Batch 1350/2104, NLL-Loss 103.4856, 
TRAIN Batch 1400/2104,

In [13]:
# save learning results
# Dump the tracked per-update training losses (key 'NLL') to a MATLAB .mat file
# in the current working directory.
sio.savemat(file_name="results.mat", mdict=tracker)

In [13]:
print('# of parameters:', sum(map(lambda t: t.numel(), model.parameters())))

# of parameters: 8865102
