# Random Masking (MADE style)

In [1]:
import os
import pdb
import numpy as np
import torch
from torchtext import data, datasets
from collections import OrderedDict as OD
import matplotlib.pyplot as plt

from transformer import * 
from utils       import * 
from custom_ds   import CustomDataset

import seaborn as sns
sns.set(font_scale=1.5)  
sns.set_style("whitegrid")

VERBOSE = True


  return f(*args, **kwds)
  return f(*args, **kwds)


### Build Dataset & Model

In [2]:
def load_data(train_file='train.txt', valid_file='valid.txt', test_file='test.txt', path=None, **kwargs):
    if path is not None:
        train_file, valid_file, test_file = [os.path.join(path, ext) for ext in [train_file, valid_file, test_file]]

    # create required field for language modeling
    input_field = data.Field(lower=True, batch_first=True)
    fields = [("text", input_field)]

    train_set, valid_set, test_set  = CustomDataset.splits(fields, train_file, valid_file, test_file)
    input_field.build_vocab(train_set)

    train_loader, val_loader, test_loader = data.Iterator.splits(
      (train_set, valid_set, test_set),
      sort_key=lambda x : len(x.text),
      batch_sizes=(256, 512, 512),
      **kwargs)

    return input_field, train_loader, val_loader, test_loader

input_field, train_iter, val_iter, test_iter = load_data(path='data/news', device=0, repeat=False)
iterators = {'train': train_iter, 'valid': val_iter, 'test': test_iter}

The `device` argument should be set by using `torch.device` or passing a string as an argument. This behavior will be deprecated soon and currently defaults to cpu.
The `device` argument should be set by using `torch.device` or passing a string as an argument. This behavior will be deprecated soon and currently defaults to cpu.
The `device` argument should be set by using `torch.device` or passing a string as an argument. This behavior will be deprecated soon and currently defaults to cpu.


In [3]:
# create model and ship to GPU
gen  = make_model(len(input_field.vocab.itos), N=2, h=4).cuda()
# print(gen)
print('number of params', sum([np.prod([int(y) for y in x.shape]) for x in gen.parameters()]))

# build optimizer
optimizer_gen = torch.optim.Adam(gen.parameters())

number of params 11706517


### Loop over the whole dataset 

In [4]:
def full_epoch(epoch_no, split, mask_type='left to right'):
    loader = iterators[split]

    # create logging containers
    logs = OD()
    for name in ['nll', 'ppl']:
        logs[name] = []

    gen.train() if split == 'train' else gen.eval()

    # Training loop
    for i, minibatch in enumerate(loader):
        
        text = minibatch.text.cuda()
        input, target = text[:, :-1], text[:, 1:]
        
        bs, seq_len = input.size()
        
        if mask_type == 'left to right':
            masks = make_std_mask(target, 0)
        elif mask_type == 'random':
            masks = torch.from_numpy(build_ar_masks([seq_len] * bs)).long().cuda()
        else:
            raise ValueError('%s is an invalid mask type' % mask_type)

        logits = gen(input, masks)
        recon_loss = F.cross_entropy(logits.view(bs * seq_len, -1), target.flatten())

        if gen.training:
             optimizer_gen.zero_grad()
             recon_loss.backward()
             params = optimizer_gen.param_groups[0]['params']
             torch.nn.utils.clip_grad_norm_(params, 10, norm_type=2)
             optimizer_gen.step()
         
        logs['nll']  += [recon_loss.data]
        logs['ppl']  += [recon_loss.exp().data]

    return logs

### Plot perplexity graphs 

In [5]:
def plot(train_ppl, valid_ppl, title):
    train, valid = [], []
    for tt, vv in zip(train_ppl, valid_ppl):
        train += [torch.stack(tt).mean().item()]
        valid += [torch.stack(vv).mean().item()]
        
    plt.scatter(np.arange(len(train)), train, label='train ppl')
    plt.scatter(np.arange(len(valid)), valid, label='valid ppl')
    plt.legend()
    plt.hlines(min(valid), 0, len(valid), linestyles='dashed')
    plt.ylim(bottom=0)
    plt.xlim(-0.1, len(valid))
    plt.yticks([min(valid)] + [x for x in np.linspace(0, max(train + valid), 5)][1:])
    plt.title(title)
    plt.show()   

### Exp1 : Baseline Model (left to right masking for training and eval)

In [6]:
train_ppl, valid_ppl, writes = [], [], 0

for epoch in range(0):
    train_log  = full_epoch(epoch, 'train', mask_type = 'left to right')
    train_ppl += [train_log['ppl']]

    if VERBOSE: 
        for key, value in train_log.items():
            print_scalar('train/%s' % key, value, writes)
        print('')
    
    with torch.no_grad():
        valid_log  = full_epoch(epoch, 'valid')
        valid_ppl += [valid_log['ppl']]

        if VERBOSE: 
            for key, value in valid_log.items():
                print_scalar('valid/%s' % key, value, writes)
            print('')
        
    writes += 1

In [7]:
# plot(train_ppl, valid_ppl, 'BASELINE (left-to-right ordering)')

### Sample from model

In [8]:
def greedy_decode(model, max_len, start_symbol=2, take_max=True):
    ys = torch.ones(1, 1).fill_(start_symbol).long().cuda()
    for i in range(max_len-1):
        out = model.decode(
                           ys,
                           (subsequent_mask(ys.size(1)).type_as(ys)))
        prob = model.generator(out[:, -1])
        if take_max:
            _, next_word = torch.max(prob, dim = 1)
        else:
            dist = torch.distributions.Categorical(prob.exp())
            next_word = dist.sample()
        next_word = next_word.data[0]
        ys = torch.cat([ys,
                        torch.ones(1, 1).long().fill_(next_word).cuda()], dim=1)
        ys = ys.cuda()
    return ys

def low_entropy_decoding(model, max_len, sos_token, pad_token):
    import pdb; pdb.set_trace()
    ys = torch.ones(1, max_len).fill_(pad_token).long().cuda()
    ys[0, 0] = sos_token

    mask = torch.zeros(1, max_len, max_len).byte().cuda()
    
    # all tokens can look at themselves
    # mask += torch.eye(max_len).unsqueeze(0).byte().cuda()  
    
    # TODO remove this
    # mask = make_std_mask(torch.zeros_like(ys), pad_token)
    # mask = mask - torch.eye(mask.size(-1)).unsqueeze(0).byte().cuda()
    # test = subsequent_mask(ys.size(1))
    
    # all tokens can look at the sos token
    mask[:, :, 0] = 1
  

    for t in range(max_len):
        out  = model.decode(ys, mask)
        prob = model.generator(out)
        dist = torch.distributions.Categorical(prob.squeeze().exp())
        
        # zero-out words that have already been generated
        # TODO
        position = dist.entropy().argmin()
        
        # update the mask to take into account this new word
        # TODO
        sample = dist.sample()[position]
    
        
def to_readable(vocab, matrix):
    if isinstance(vocab, torchtext.data.field.Field):
        vocab = vocab.vocab.itos

    sentences = []
    for line in matrix:
        sentence = ''
        for token in line:
            sentence += vocab[token] + ' '
        sentence = sentence.replace('<pad>', '')
        sentences += [sentence]
    return sentences

In [9]:
out = greedy_decode(gen, 51, take_max=False)
print(out)
print(to_readable(input_field, out))
# out = low_entropy_decoding(gen, 51, 0, 1)

tensor([[   2, 1567, 4251, 4859, 2434, 2952,  950, 4538, 3310, 3730,  118,  573,
          647, 3112, 1803, 4194,  181, 1182, 5018, 4367, 3787, 1343, 3186, 4486,
         1041,  358, 2644, 4258, 1991, 3548, 4625,  943,  363,  622, 3036, 5023,
         1924, 3680, 4832, 4861, 4643, 1275, 3487, 4963,  513, 4079, 4469,  836,
          206, 1648,  274]], device='cuda:0')
['the hotel boom 20th convicted organisations isis courses bars emerging take 9 stage pursue defend racial end charge romantic heritage closure network funded satellite august ever minds landed ongoing deployed simpson increased report outside prosecutor brian buying hanging bennett adam blocks daily lay orange living passes tourist girl man bringing keep ']


### Exp2 : Proposed Model (random ordering masking for training and eval)

In [10]:
VERBOSE = True
gen  = make_model(len(input_field.vocab.itos), N=2, h=4).cuda()
optimizer_gen = torch.optim.Adam(gen.parameters())

train_ppl, valid_ppl, writes = [], [], 0

for epoch in range(30):
    train_log  = full_epoch(epoch, 'train', mask_type = 'random')
    train_ppl += [train_log['ppl']]

    if VERBOSE: 
        for key, value in train_log.items():
            print_scalar('train/%s' % key, value, writes)
        print('')
    
    with torch.no_grad():
        valid_log  = full_epoch(epoch, 'valid', mask_type='random')
        valid_ppl += [valid_log['ppl']]

        if VERBOSE: 
            for key, value in valid_log.items():
                print_scalar('valid/%s' % key, value, writes)
            print('')
        
    writes += 1

train/nll                                @ write 0 = 1.6476
train/ppl                                @ write 0 = 13.7395

valid/nll                                @ write 0 = 2.2066
valid/ppl                                @ write 0 = 9.1447

train/nll                                @ write 1 = 1.2710
train/ppl                                @ write 1 = 3.5674

valid/nll                                @ write 1 = 2.0846
valid/ppl                                @ write 1 = 8.0797

train/nll                                @ write 2 = 1.2214
train/ppl                                @ write 2 = 3.3946

valid/nll                                @ write 2 = 2.0463
valid/ppl                                @ write 2 = 7.7770

train/nll                                @ write 3 = 1.1943
train/ppl                                @ write 3 = 3.3035

valid/nll                                @ write 3 = 2.0221
valid/ppl                                @ write 3 = 7.5993

train/nll                      

KeyboardInterrupt: 

In [None]:
plot(train_ppl[2:], valid_ppl[2:], 'LM-MADE (random train/test masks)')

In [None]:
out = greedy_decode(gen, 51, take_max=False)
print(to_readable(input_field, out))
out = low_entropy_decoding(gen, 51, 0, 16)

### Comments on results 

Perplexity wise, these results are very promising. Let's see if this gain simply comes from evaluating with random orderings, or that training with random masks actually helps. To do so, we train using the regular ordering, and evaluate with random masks

In [None]:
VERBOSE = False
gen  = make_model(10000 + 1, N=2, h=4).cuda()
optimizer_gen = torch.optim.Adam(gen.parameters())

train_ppl, valid_ppl, writes = [], [], 0

for epoch in range(10):
    train_log  = full_epoch(epoch, 'train', mask_type = 'left to right')
    train_ppl += [train_log['ppl']]

    if VERBOSE: 
        for key, value in train_log.items():
            print_scalar('train/%s' % key, value, writes)
        print('')
    
    with torch.no_grad():
        valid_log  = full_epoch(epoch, 'valid', mask_type='random')
        valid_ppl += [valid_log['ppl']]

        if VERBOSE: 
            for key, value in valid_log.items():
                print_scalar('valid/%s' % key, value, writes)
            print('')
        
    writes += 1

In [None]:
plot(train_ppl, valid_ppl, 'regular training, random mask for testing')

Last thing to try : Train on random orderings but evaluate only on left-to-right orderings

In [None]:
VERBOSE = False
gen  = make_model(10000 + 1, N=2, h=4).cuda()
optimizer_gen = torch.optim.Adam(gen.parameters())

train_ppl, valid_ppl, writes = [], [], 0

for epoch in range(10):
    train_log  = full_epoch(epoch, 'train', mask_type = 'random')
    train_ppl += [train_log['ppl']]

    if VERBOSE: 
        for key, value in train_log.items():
            print_scalar('train/%s' % key, value, writes)
        print('')
    
    with torch.no_grad():
        valid_log  = full_epoch(epoch, 'valid', mask_type='left to right')
        valid_ppl += [valid_log['ppl']]

        if VERBOSE: 
            for key, value in valid_log.items():
                print_scalar('valid/%s' % key, value, writes)
            print('')
        
    writes += 1

In [None]:
plot(train_ppl, valid_ppl, 'random training, left-to-right testing')