# Train

In [3]:
import os
import json
import time
import torch
import argparse
import numpy as np
from multiprocessing import cpu_count
from tensorboardX import SummaryWriter
from torch.utils.data import DataLoader
from collections import OrderedDict, defaultdict

from ptb import PTB
from utils import to_var, idx2word, expierment_name, AttributeDict
from model import SentenceVAE

In [9]:
# Dataset directory. Alternative that was used before: 'data/simple-examples'
data_dir = 'data/eccos'

# All run settings live in one attribute-accessible mapping.
args = AttributeDict(dict(
    # --- data ---
    data_dir=data_dir,
    create_data=False,
    max_sequence_length=50,
    min_occ=1,
    test=False,            # when True the 'test' split is loaded as well

    # --- optimisation ---
    epochs=10,
    batch_size=32,
    learning_rate=0.001,

    # --- model ---
    embedding_size=300,
    rnn_type='gru',
    hidden_size=256,
    num_layers=1,
    bidirectional=False,
    latent_size=16,
    word_dropout=0,
    embedding_dropout=0.5,

    # --- KL annealing schedule ---
    anneal_function='logistic',
    k=0.0025,
    x0=2500,

    # --- logging / output ---
    print_every=50,
    tensorboard_logging=False,
    logdir='/root/user/work/logs/',
    save_model_path='bin',
))

# Normalise the case of the two string options before validating them.
for _option in ('rnn_type', 'anneal_function'):
    setattr(args, _option, getattr(args, _option).lower())

# Sanity checks on the values the rest of the notebook branches on.
assert args.rnn_type in ('rnn', 'lstm', 'gru')
assert args.anneal_function in ('logistic', 'linear')
assert 0 <= args.word_dropout <= 1
args

<AttrDict{'data_dir': 'data/eccos', 'create_data': False, 'max_sequence_length': 50, 'min_occ': 1, 'test': False, 'epochs': 10, 'batch_size': 32, 'learning_rate': 0.001, 'embedding_size': 300, 'rnn_type': 'gru', 'hidden_size': 256, 'num_layers': 1, 'bidirectional': False, 'latent_size': 16, 'word_dropout': 0, 'embedding_dropout': 0.5, 'anneal_function': 'logistic', 'k': 0.0025, 'x0': 2500, 'print_every': 50, 'tensorboard_logging': False, 'logdir': '/root/user/work/logs/', 'save_model_path': 'bin'}>

## load data

In [10]:
%%time
# 'train' and 'valid' are always loaded; 'test' only when requested in args.
splits = ['train', 'valid']
if args.test:
    splits.append('test')

print(f'loading {args.data_dir}')

# One PTB dataset per split, kept in the order of `splits`.
datasets = OrderedDict()
for split in splits:
    dataset = PTB(
        data_dir=args.data_dir,
        split=split,
        create_data=args.create_data,
        max_sequence_length=args.max_sequence_length,
        min_occ=args.min_occ
    )
    datasets[split] = dataset
    print(f'vocab: {dataset.vocab_size}, records: {len(dataset.data)}')

loading data/eccos
vocab: 12037, records: 30726
vocab: 12037, records: 7682
CPU times: user 860 ms, sys: 65.2 ms, total: 925 ms
Wall time: 917 ms


In [11]:
# Inspect an actual data record
def ids2text(id_list, ptb):
    """Convert a sequence of token ids into a space-separated string.

    Each id is formatted as a string and looked up in ``ptb.i2w``; the
    resulting tokens are joined with single spaces.

    Args:
        id_list: iterable of token ids.
        ptb: object exposing an ``i2w`` mapping keyed by the id's string form.

    Returns:
        The joined token string (empty string for an empty ``id_list``).
    """
    tokens = (ptb.i2w[format(token_id)] for token_id in id_list)
    return ' '.join(tokens)

# Decode one training record (key '100') to eyeball the input/target pair.
_ptb = datasets['train']
_sample = _ptb.data['100']

_input_text = ids2text(_sample["input"], _ptb)
print(f'■ Input\n{_input_text}')

_target_text = ids2text(_sample["target"], _ptb)
print(f'■ Target\n{_target_text}')

■ Input
<sos> 美しい 姿勢 も 手 に 入る ! 「 コルセット ダイエット 」 で 目指す 美ボディ <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad>
■ Target
美しい 姿勢 も 手 に 入る ! 「 コルセット ダイエット 」 で 目指す 美ボディ <eos> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad>


## build model

In [12]:
# Vocabulary size and the special-token indices are taken from the training split.
_train_set = datasets['train']

# Architecture hyper-parameters come straight from the run configuration.
_model_kwargs = dict(
    max_sequence_length=args.max_sequence_length,
    embedding_size=args.embedding_size,
    rnn_type=args.rnn_type,
    hidden_size=args.hidden_size,
    word_dropout=args.word_dropout,
    embedding_dropout=args.embedding_dropout,
    latent_size=args.latent_size,
    num_layers=args.num_layers,
    bidirectional=args.bidirectional,
)

model = SentenceVAE(
    vocab_size=_train_set.vocab_size,
    sos_idx=_train_set.sos_idx,
    eos_idx=_train_set.eos_idx,
    pad_idx=_train_set.pad_idx,
    unk_idx=_train_set.unk_idx,
    **_model_kwargs
)

# Move the model to the GPU when one is available.
if torch.cuda.is_available():
    model = model.cuda()

In [13]:
# Show the model's repr (its registered sub-modules) as the cell output.
model

SentenceVAE(
  (embedding): Embedding(12037, 300)
  (embedding_dropout): Dropout(p=0.5, inplace=False)
  (encoder_rnn): GRU(300, 256, batch_first=True)
  (decoder_rnn): GRU(300, 256, batch_first=True)
  (hidden2mean): Linear(in_features=256, out_features=16, bias=True)
  (hidden2logv): Linear(in_features=256, out_features=16, bias=True)
  (latent2hidden): Linear(in_features=16, out_features=256, bias=True)
  (outputs2vocab): Linear(in_features=256, out_features=12037, bias=True)
)

## log

In [14]:
print(f'tensorboard logging: {args.tensorboard_logging}')

# UTC timestamp identifying this run; it names the TensorBoard run, the
# checkpoint directory below, and the dump directory used by the training loop.
ts = time.strftime('%Y-%b-%d-%H:%M:%S', time.gmtime())

# `writer` is only defined when TensorBoard logging is switched on.
if args.tensorboard_logging:
    writer = SummaryWriter(os.path.join(args.logdir, expierment_name(args, ts)))
    writer.add_text("model", str(model))
    writer.add_text("args", str(args))
    writer.add_text("ts", ts)

# Checkpoints for this run are written to <save_model_path>/<ts>.
save_model_path = os.path.join(args.save_model_path, ts)
# exist_ok=True keeps the cell re-runnable: executing it twice within the same
# second yields the same `ts`, which previously raised FileExistsError.
os.makedirs(save_model_path, exist_ok=True)

tensorboard logging: False


## loss, optimizer

In [15]:
def kl_anneal_function(anneal_function, step, k, x0):
    """Return the weight applied to the KL term at a given training step.

    Args:
        anneal_function: schedule name, either 'logistic' or 'linear'.
        step: current step counter.
        k: slope of the logistic curve (ignored by the linear schedule).
        x0: midpoint of the logistic curve, or the number of steps over which
            the linear schedule ramps from 0 to 1.

    Returns:
        'logistic': ``1 / (1 + exp(-k * (step - x0)))`` as a Python float.
        'linear': ``min(1, step / x0)``; a zero-length ramp (``x0 == 0``)
        yields 1 instead of dividing by zero.

    Raises:
        ValueError: if ``anneal_function`` is not a known schedule. Previously
            the function fell through and returned ``None``, which only
            surfaced later as a TypeError when the weight was used.
    """
    if anneal_function == 'logistic':
        return float(1/(1+np.exp(-k*(step-x0))))

    if anneal_function == 'linear':
        if x0 == 0:
            # No ramp at all: the KL term is fully weighted from the start.
            return 1
        return min(1, step/x0)

    raise ValueError(
        "unknown anneal_function %r (expected 'logistic' or 'linear')" % (anneal_function,)
    )

# Summed negative log-likelihood; positions whose target equals the training
# set's pad index are ignored.
NLL = torch.nn.NLLLoss(reduction='sum', ignore_index=datasets['train'].pad_idx)
def loss_fn(logp, target, length, mean, logv, anneal_function, step, k, x0):
    """Compute the loss components for one batch.

    Args:
        logp: log-probabilities; flattened to 2-D using its third dimension
            as the row width (presumably (batch, seq, vocab) -- TODO confirm).
        target: target token ids, one row per sequence.
        length: tensor of sequence lengths; only its maximum is used here.
        mean, logv: tensors entering the KL term.
        anneal_function, step, k, x0: forwarded to ``kl_anneal_function``.

    Returns:
        Tuple ``(NLL_loss, KL_loss, KL_weight)``; both losses are sums, not
        averages, and the weight is not applied here.
    """
    # Drop the padding columns beyond the longest sequence, then flatten.
    longest = torch.max(length).data
    flat_target = target[:, :longest].contiguous().view(-1)
    flat_logp = logp.view(-1, logp.size(2))

    # Reconstruction term (negative log-likelihood).
    nll = NLL(flat_logp, flat_target)

    # KL divergence term and its current annealing weight.
    kl_terms = 1 + logv - mean.pow(2) - logv.exp()
    kl = -0.5 * torch.sum(kl_terms)
    kl_weight = kl_anneal_function(anneal_function, step, k, x0)

    return nll, kl, kl_weight

# Adam over all model parameters with the configured learning rate.
optimizer = torch.optim.Adam(model.parameters(), lr=args.learning_rate)

# Tensor constructor matching the device in use; the training loop uses it as
# the default factory for its per-split tracker.
if torch.cuda.is_available():
    tensor = torch.cuda.FloatTensor
else:
    tensor = torch.Tensor

# Number of optimisation steps taken so far; it is passed to the KL annealing
# schedule and incremented by the training loop.
step = 0

In [16]:
# Train and evaluate the model. Every epoch walks through the entries of
# `splits` in order; parameters are only updated on the 'train' split, a
# checkpoint is saved after it, and the 'valid' split additionally dumps the
# target sentences together with their latent codes.
for epoch in range(args.epochs):

    for split in splits:

        is_train = split == 'train'

        data_loader = DataLoader(
            dataset=datasets[split],
            batch_size=args.batch_size,
            shuffle=is_train,
            num_workers=cpu_count(),
            pin_memory=torch.cuda.is_available()
        )
        num_batches = len(data_loader)

        # Missing keys start out as an empty tensor so that per-batch results
        # can simply be appended with torch.cat.
        tracker = defaultdict(tensor)

        # Enable/Disable Dropout
        if is_train:
            model.train()
        else:
            model.eval()

        for iteration, batch in enumerate(data_loader):

            batch_size = batch['input'].size(0)

            for key, value in batch.items():
                if torch.is_tensor(value):
                    batch[key] = to_var(value)

            # Only record the autograd graph when backward() will be called.
            # Evaluation batches previously built a graph that was never used,
            # holding on to activations for no reason.
            with torch.set_grad_enabled(is_train):
                # Forward pass
                logp, mean, logv, z = model(batch['input'], batch['length'])

                # loss calculation
                NLL_loss, KL_loss, KL_weight = loss_fn(logp, batch['target'],
                    batch['length'], mean, logv, args.anneal_function, step, args.k, args.x0)

                # Per-sample objective: summed losses divided by the batch size.
                loss = (NLL_loss + KL_weight * KL_loss)/batch_size

            # backward + optimization
            if is_train:
                optimizer.zero_grad()
                loss.backward()
                optimizer.step()
                step += 1

            # bookkeeping
            tracker['ELBO'] = torch.cat((tracker['ELBO'], loss.detach().view(1)))

            if args.tensorboard_logging:
                global_step = epoch*num_batches + iteration
                writer.add_scalar("%s/ELBO"%split.upper(), loss.item(), global_step)
                writer.add_scalar("%s/NLL Loss"%split.upper(), NLL_loss.item()/batch_size, global_step)
                writer.add_scalar("%s/KL Loss"%split.upper(), KL_loss.item()/batch_size, global_step)
                writer.add_scalar("%s/KL Weight"%split.upper(), KL_weight, global_step)

            # Report every `print_every` batches and always on the last batch.
            if iteration % args.print_every == 0 or iteration+1 == num_batches:
                print("%s Batch %04d/%i, Loss %9.4f, NLL-Loss %9.4f, KL-Loss %9.4f, KL-Weight %6.3f"
                    %(split.upper(), iteration, num_batches-1, loss.item(), NLL_loss.item()/batch_size, KL_loss.item()/batch_size, KL_weight))

            # Collect the decoded targets and latent codes for the validation dump.
            if split == 'valid':
                if 'target_sents' not in tracker:
                    tracker['target_sents'] = list()
                tracker['target_sents'] += idx2word(batch['target'].data, i2w=datasets['train'].get_i2w(), pad_idx=datasets['train'].pad_idx)
                tracker['z'] = torch.cat((tracker['z'], z.detach()), dim=0)

        print("%s Epoch %02d/%i, Mean ELBO %9.4f"%(split.upper(), epoch, args.epochs, torch.mean(tracker['ELBO'])))

        if args.tensorboard_logging:
            writer.add_scalar("%s-Epoch/ELBO"%split.upper(), torch.mean(tracker['ELBO']), epoch)

        # save a dump of all sentences and the encoded latent space
        if split == 'valid':
            dump = {'target_sents':tracker['target_sents'], 'z':tracker['z'].tolist()}
            dump_dir = os.path.join('dumps', ts)
            os.makedirs(dump_dir, exist_ok=True)
            with open(os.path.join(dump_dir, 'valid_E%i.json'%epoch), 'w') as dump_file:
                json.dump(dump,dump_file)

        # save checkpoint
        if is_train:
            checkpoint_path = os.path.join(save_model_path, "E%i.pytorch"%(epoch))
            torch.save(model.state_dict(), checkpoint_path)
            print("Model saved at %s"%checkpoint_path)

TRAIN Batch 0000/960, Loss  188.5276, NLL-Loss  188.5269, KL-Loss    0.3596, KL-Weight  0.002


KeyboardInterrupt: 