In [1]:
%reload_ext autoreload
%autoreload 2

In [2]:
import torch
import numpy as np
import torch.optim as optim

from torch import nn
from corpus import SSTCorpus
from tqdm import tqdm_notebook
from vanila_model import RnnVae
from sklearn.datasets.lfw import Bunch
from torch.nn.utils import clip_grad_norm_

In [3]:
# Experiment configuration.
# `args.model` is splatted (**) into both SSTCorpus and RnnVae, `args.train`
# drives the optimiser, KL annealer and logging, and `device_code` selects the
# compute device.
args = Bunch(
    model=Bunch(
        d_h=64,  # presumably the encoder GRU hidden size — TODO confirm in vanila_model
        d_z=64,  # presumably the latent code size — TODO confirm
        d_c=2,  # presumably the size of the extra code appended to z — TODO confirm
        n_len=15,  # presumably the maximum sentence length in tokens — TODO confirm
        n_vocab=10000,  # presumably an upper bound on vocabulary size — TODO confirm
        d_emb=50,  # presumably the word-embedding dimension — TODO confirm
        p_word_dropout=0.3,  # presumably the decoder word-dropout probability — TODO confirm
        freeze_embeddings=False,
    ),
    train=Bunch(
        n_batch=32,  # passed to SSTCorpus as the batch size argument
        lr=1e-3,  # initial learning rate handed to Adam
        lr_decay=1000,  # not read by any active code in the visible cells
        n_iter=100000,  # number of iterations given to the batcher and the KL annealer
        log_interval=3000,  # loss-averaging window and logging / LR-scheduler period
        grad_clipping=5,  # max-norm argument of clip_grad_norm_
        # Keyword arguments of JointLossWithKLAnnealer.
        joint_loss=Bunch(
            start_inc=3000,  # iteration index at which the KL weight starts growing
            weight=0.01,  # initial KL weight
            w_max=0.15  # KL weight at which the ramp stops
        )
    ),
    device_code=3  # CUDA device index; a negative value selects the CPU
)

In [4]:
# Run on the configured GPU when `device_code` is non-negative and CUDA is
# usable; otherwise fall back to the CPU.
if args.device_code >= 0 and torch.cuda.is_available():
    device = torch.device('cuda', args.device_code)
else:
    device = torch.device('cpu')
device

device(type='cuda', index=3)

In [5]:
# Build the SST corpus wrapper; it receives the model hyper-parameters plus
# the batch size and the target device.
corpus = SSTCorpus(
    **args.model,
    n_batch=args.train.n_batch,
    device=device,
)

In [6]:
# Instantiate the VAE with the corpus' 'x' vocabulary and move it to the
# selected device; the trailing expression displays the module tree.
model = RnnVae(
    **args.model,
    x_vocab=corpus.vocab('x'),
).to(device)
model

RnnVae(
  (x_emb): Embedding(4847, 50, padding_idx=1)
  (encoder_rnn): GRU(50, 64)
  (q_mu): Linear(in_features=64, out_features=64, bias=True)
  (q_logvar): Linear(in_features=64, out_features=64, bias=True)
  (decoder_rnn): GRU(116, 66)
  (decoder_fc): Linear(in_features=66, out_features=4847, bias=True)
  (encoder): ModuleList(
    (0): GRU(50, 64)
    (1): Linear(in_features=64, out_features=64, bias=True)
    (2): Linear(in_features=64, out_features=64, bias=True)
  )
  (decoder): ModuleList(
    (0): GRU(116, 66)
    (1): Linear(in_features=66, out_features=4847, bias=True)
  )
  (vae): ModuleList(
    (0): Embedding(4847, 50, padding_idx=1)
    (1): ModuleList(
      (0): GRU(50, 64)
      (1): Linear(in_features=64, out_features=64, bias=True)
      (2): Linear(in_features=64, out_features=64, bias=True)
    )
    (2): ModuleList(
      (0): GRU(116, 66)
      (1): Linear(in_features=66, out_features=4847, bias=True)
    )
  )
)

In [7]:
class JointLossWithKLAnnealer:
    """Joint VAE loss ``weight * kl_loss + recon_loss`` with a linearly annealed KL weight.

    The KL weight starts at ``weight``. On every call whose iteration index is
    at least ``start_inc`` it grows by a fixed increment until it reaches
    ``w_max``, where it stays. The increment is sized so that the ramp takes
    ``n_iter - start_inc`` calls.

    Keyword Args:
        start_inc: First iteration index at which the weight starts growing.
        w_max: Upper bound of the KL weight.
        weight: Initial KL weight.
        n_iter: Total number of training iterations.

    Raises:
        KeyError: If any of the four keyword arguments is missing.
    """

    def __init__(self, **kwargs):
        self.start_inc = kwargs['start_inc']
        self.w_max = kwargs['w_max']

        self.weight = kwargs['weight']
        ramp_steps = kwargs['n_iter'] - self.start_inc
        # With no iterations left after `start_inc` there is nothing to ramp
        # over: keep the weight constant instead of dividing by zero (or
        # producing a negative increment).
        self.inc = (self.w_max - self.weight) / ramp_steps if ramp_steps > 0 else 0.0

    def __call__(self, i, kl_loss, recon_loss):
        """Advance the annealing schedule and return the weighted joint loss.

        Args:
            i: Current iteration index; the weight only grows once ``i >= start_inc``.
            kl_loss: KL-divergence term.
            recon_loss: Reconstruction term.

        Returns:
            ``self.weight * kl_loss + recon_loss`` using the updated weight.
        """
        if i >= self.start_inc and self.weight < self.w_max:
            # Clamp so that float accumulation error or extra calls can never
            # push the weight past `w_max`.
            self.weight = min(self.weight + self.inc, self.w_max)

        return self.weight * kl_loss + recon_loss

In [8]:
# Train the VAE: optimise only the parameters registered under `model.vae`,
# anneal the KL weight, and every `log_interval` iterations step the LR
# scheduler and record a sampled sentence.
get_params = lambda: model.vae.parameters()
trainer = optim.Adam(get_params(), lr=args.train.lr)
# Scales the learning rate by 0.01 when the metric passed to `step()` stops improving.
lr_scheduler = optim.lr_scheduler.ReduceLROnPlateau(trainer, factor=0.01)
# Disabled alternative (step decay via LambdaLR), kept for reference:
# lr_lambda = lambda e: args.train.lr * (0.5 ** (e // args.train.lr_decay))
# lr_scheduler = optim.lr_scheduler.LambdaLR(trainer, lr_lambda)
joint_loss = JointLossWithKLAnnealer(**args.train.joint_loss, n_iter=args.train.n_iter)

batcher = corpus.batcher('unlabeled', 'train', n_iter=args.train.n_iter)
# `enumerate` hides the length from tqdm, so pass the total explicitly to get
# a real progress bar. Assumes the batcher yields `n_iter` batches — TODO confirm.
t = tqdm_notebook(enumerate(batcher), total=args.train.n_iter)
losses = []
log = []
epoch = 0
for i, x in t:
    # Forward
    kl_loss, recon_loss = model(x)
    loss = joint_loss(i, kl_loss, recon_loss)

    # Backward
    loss.backward()
    clip_grad_norm_(get_params(), args.train.grad_clipping)
    trainer.step()
    trainer.zero_grad()

    # Calc metrics and update t
    losses.append(loss.item())
    # Mean loss over the most recent `log_interval` iterations.
    cur_loss = np.mean(losses[-args.train.log_interval:])
    kl_weight = joint_loss.weight
    lr = trainer.param_groups[0]['lr']
    t.set_postfix_str(f'loss={cur_loss:.5f} klw={kl_weight:.3f} lr={lr:.7f}')
    t.refresh()

    # Log
    if (i > 0 and i % args.train.log_interval == 0) or (i == args.train.n_iter - 1):
        # One "epoch" here is one logging interval, not a pass over the data.
        epoch += 1
        lr_scheduler.step(cur_loss, epoch=epoch)

        # Sample once and reuse it, so the logged sentence is exactly the one
        # printed and no second decode is spent per log step.
        sent = corpus.reverse(model.sample_sentence(device=device))
        print(sent)

        log.append({
            'iter': i,
            'loss': cur_loss,
            'sent': sent,
            'kl_weight': kl_weight,
            'lr': lr  # value read before this interval's scheduler step
        })

HBox(children=(IntProgress(value=1, bar_style='info', max=1), HTML(value='')))

more marshall comes is really match looks big and lives .
and maid for poetry that was only to be justice .
heavy , soggy spirited and funny and meandering .
fluffy and obnoxious and .
hopkins , rent from robert malkovich .
a cockeyed shot all the action characters .
terminally brain dead production .
a bad mannered , thoroughly long natured .
... one of this movie , you can be work as
the verdict : two bodies and hardly a laugh up .
the film 's hardly a director direction .
just another disjointed , fairly predictable and if concept .
too good , but a mess .
` blue crush ' swims away with better to be a heart .
brisk hack job .
nevertheless , i take .
makes 98 minutes better to trip to feel .
... an affecting power , dark and success .
the whole thing about the star trek movie in a way time .
cherish would 've life for first time when it 's navel
it 's fun , you felt and to vulgarity , it thinks it
contrived pastiche of caper clichés and pacing are .
brings of this made the film 's ne