In [0]:
# --- What to generate ---------------------------------------------------
condition = "really happy"   # phrase inserted into every sentence and never resampled
sentence_len = 15            # number of token slots generated around the condition

# --- How much to generate -----------------------------------------------
n_samples = 20               # total number of sentences requested
batch_size = 2               # sentences generated in parallel per batch
max_iter = 500               # resampling iterations run for each batch

# --- Where the sentences are stored -------------------------------------
file_name = "./generated_language.txt"

In [0]:
# The generated sentences are written to, and later read back from, the same file.
out_file = in_file = file_name

In [0]:
import numpy as np
import torch
import math
import time

# !pip install pytorch_pretrained_bert
from pytorch_pretrained_bert import BertTokenizer, BertModel, BertForMaskedLM

In [0]:
# Name of the pre-trained checkpoint shared by the model and the tokenizer
model_version = 'bert-base-uncased'

# Masked-LM model (weights), put into eval mode since it is only used for inference here
model = BertForMaskedLM.from_pretrained(model_version)
model.eval()

# Use GPU 0 whenever CUDA is available; otherwise the model stays where it was loaded
cuda = torch.cuda.is_available()
model = model.cuda(0) if cuda else model

# Matching tokenizer (vocabulary); lower-case input only for "uncased" checkpoints
tokenizer = BertTokenizer.from_pretrained(
    model_version,
    do_lower_case=model_version.endswith("uncased"),
)

def tokenize_batch(batch):
    """ Convert every token list in `batch` into its list of vocabulary ids """
    id_batch = []
    for tokens in batch:
        id_batch.append(tokenizer.convert_tokens_to_ids(tokens))
    return id_batch

def untokenize_batch(batch):
    """ Convert every id list in `batch` back into its list of vocabulary tokens """
    token_batch = []
    for ids in batch:
        token_batch.append(tokenizer.convert_ids_to_tokens(ids))
    return token_batch

def detokenize(sent):
    """ Roughly detokenizes (mainly undoes wordpiece)

    A token starting with "##" is a continuation piece: its marker is removed
    and the remainder is glued onto the previous word. Any other token starts
    a new word.

    args:
        - sent (list of str): tokens to merge

    returns:
        - list of str: merged words. A continuation piece that has no previous
          word to attach to (e.g. it is the first token) starts a new word,
          with its "##" marker stripped, instead of raising IndexError.
    """
    new_sent = []
    for tok in sent:
        is_continuation = tok.startswith("##")
        if is_continuation and new_sent:
            new_sent[-1] += tok[2:]
        else:
            # Either a regular token, or a dangling piece with nothing before it.
            new_sent.append(tok[2:] if is_continuation else tok)
    return new_sent

# Special tokens used when building the model input, and their vocabulary ids
CLS, SEP, MASK = '[CLS]', '[SEP]', '[MASK]'
mask_id, sep_id, cls_id = tokenizer.convert_tokens_to_ids([MASK, SEP, CLS])

# Generations

In [0]:
def generate_step(out, gen_idx, temperature=None, top_k=0, sample=False, return_list=True):
    """ Generate a word from out[gen_idx]

    args:
        - out (torch.Tensor): tensor of logits of size batch_size x seq_len x vocab_size
        - gen_idx (int): location for which to generate for
        - temperature (float): if not None, the logits are divided by this value first
        - top_k (int): if >0, only sample from the top k most probable words
        - sample (Bool): if True, sample from full distribution. Overridden by top_k
        - return_list (Bool): if True, return a Python list, otherwise a tensor

    returns:
        one chosen vocabulary id per batch element (always batch_size entries,
        including when batch_size is 1)
    """
    logits = out[:, gen_idx]
    if temperature is not None:
        logits = logits / temperature
    if top_k > 0:
        kth_vals, kth_idx = logits.topk(top_k, dim=-1)
        dist = torch.distributions.categorical.Categorical(logits=kth_vals)
        # The sample indexes into the top-k list; map it back to vocabulary ids.
        idx = kth_idx.gather(dim=1, index=dist.sample().unsqueeze(-1)).squeeze(-1)
    elif sample:
        dist = torch.distributions.categorical.Categorical(logits=logits)
        # sample() already has shape (batch_size,). Squeezing it would collapse a
        # batch of one to a 0-d tensor, making tolist() return a bare int.
        idx = dist.sample()
    else:
        idx = torch.argmax(logits, dim=-1)
    return idx.tolist() if return_list else idx

In [0]:
# Generation modes as functions

def get_init_text(seed_text,inp_txt, max_len, batch_size = 1, rand_init=False):
    """ Get initial sentence by padding seed_text with condition and mask of length max_len

    Every sentence is laid out as
        seed_text + [MASK]*rand_ind + inp_txt + [MASK]*(rest) + [SEP]
    so the part between seed_text and [SEP] is exactly max_len tokens long.
    One random offset is drawn per call and shared by the whole batch.

    args:
        - seed_text (list of str): tokens placed at the start of every sentence
        - inp_txt (list of str): condition tokens inserted at the random offset
        - max_len (int): number of tokens between seed_text and [SEP], condition included
        - batch_size (int): number of identical sentences to create
        - rand_init: unused; kept so existing calls keep working

    returns:
        [batch, rand_ind]: the batch converted by tokenize_batch, and the offset
        (relative to the end of seed_text) at which the condition starts

    raises:
        ValueError: if the condition is longer than max_len
    """
    inp_len = len(inp_txt)
    if inp_len > max_len:
        raise ValueError(
            "condition has %d tokens but only %d positions are available" % (inp_len, max_len)
        )

    # randomly introduce condition on initial sequence
    rand_ind = np.random.randint(0, max_len - inp_len + 1)
    n_trailing = max_len - rand_ind - inp_len
    row = seed_text + [MASK] * rand_ind + inp_txt + [MASK] * n_trailing + [SEP]

    # One independent copy of the row per batch element.
    batch = [list(row) for _ in range(batch_size)]

    return [tokenize_batch(batch),rand_ind]

In [0]:
def getLeftInd(inp_ind,inp_len,max_len):
    """ Generate indices that we can mask

    args:
        - inp_ind (int): index at which the condition starts
        - inp_len (int): number of positions taken by the condition
        - max_len (int): total number of positions

    returns:
        list of int: every index in range(max_len) that is outside
        [inp_ind, inp_ind + inp_len), in increasing order
    """
    condition_span = range(inp_ind, inp_ind + inp_len)
    return [pos for pos in range(max_len) if pos not in condition_span]

In [0]:
def getRandInd(left_ind):
    """ Return one entry of left_ind (the indices outside the "condition"), chosen uniformly at random, to be masked next. """
    pick = np.random.randint(0, len(left_ind))
    return left_ind[pick]

In [0]:
def parallel_sequential_generation(seed_text, max_len=15, top_k=0, temperature=None, max_iter=300, burnin=200,
                                   cuda=False, print_every=10, verbose=True, batch_size=None):
    """ Generate for one random position at a timestep

    The first token of seed_text is the sentence prefix; the remaining tokens
    are the condition. The condition is placed at a random offset and is never
    resampled. At every iteration one of the other positions is masked in all
    sentences of the batch and refilled from the model output.

    args:
        - seed_text (list of str, or whitespace-separated str): prefix token followed by the condition tokens
        - max_len (int): number of positions to generate, not counting the condition
        - top_k (int): passed to generate_step once the burn-in period is over
        - temperature (float): passed to generate_step
        - max_iter (int): number of resampling iterations
        - burnin: during burn-in period, sample from full distribution; afterwards
          generate_step is called with top_k and sample=False
        - cuda (Bool): if True, move the input tensor to the GPU before the forward pass
        - print_every (int): with verbose, print the first sentence every this many iterations
        - verbose (Bool): enable the progress printing
        - batch_size (int): number of sentences generated in parallel; None (default)
          uses the notebook-level `batch_size` variable, as before

    returns:
        list of token lists, one per sentence in the batch
    """
    # The parameter shadows the notebook-level variable, hence the explicit lookup.
    n_seqs = globals()["batch_size"] if batch_size is None else batch_size

    # A plain string would otherwise be sliced character by character below.
    if isinstance(seed_text, str):
        seed_text = seed_text.split()

    # Split the seed into the prefix (first token) and the condition (the rest).
    inp_txt = seed_text[1:]
    inp_len = len(inp_txt)
    seed_text = seed_text[0:1]
    seed_len = 1

    # The condition occupies positions in addition to the max_len generated ones.
    max_len = max_len + inp_len

    batch, inp_ind = get_init_text(seed_text, inp_txt, max_len, n_seqs)

    left_ind = getLeftInd(inp_ind, inp_len, max_len)
    if not left_ind:
        # Nothing can be resampled; getRandInd cannot choose from an empty list.
        return untokenize_batch(batch)

    # Inference only: no_grad avoids recording autograd state on every forward pass.
    with torch.no_grad():
        for ii in range(max_iter):
            kk = getRandInd(left_ind)
            pos = seed_len + kk
            for jj in range(n_seqs):
                batch[jj][pos] = mask_id
            inp = torch.tensor(batch).cuda() if cuda else torch.tensor(batch)
            out = model(inp)
            topk = top_k if (ii >= burnin) else 0
            idxs = generate_step(out, gen_idx=pos, top_k=topk, temperature=temperature, sample=(ii < burnin))
            for jj in range(n_seqs):
                batch[jj][pos] = idxs[jj]

            if verbose and np.mod(ii+1, print_every) == 0:
                # '(*)' is printed right after the position that was just resampled.
                for_print = tokenizer.convert_ids_to_tokens(batch[0])
                for_print = for_print[:pos+1] + ['(*)'] + for_print[pos+1:]
                print("iter", ii+1, " ".join(for_print))

    return untokenize_batch(batch)

In [0]:
def generate(n_samples, seed_text="[CLS]", batch_size=10, max_len=25, 
             sample=True, top_k=100, temperature=1.0, burnin=200, max_iter=500,
             cuda=False, print_every=1):
    """ Main generation function to call

    Runs ceil(n_samples / batch_size) calls of parallel_sequential_generation
    and concatenates the sentences they return.

    args:
        - n_samples (int): number of sentences requested
        - seed_text (list of str, or whitespace-separated str): prefix token
          followed by optional condition tokens; a string is split on whitespace
        - batch_size (int): only used here to compute the number of batches; it is
          NOT forwarded to parallel_sequential_generation
          (NOTE(review): keep it equal to the batch size that function actually uses)
        - max_len, top_k, temperature, burnin, max_iter, cuda: forwarded unchanged
        - sample: currently unused
        - print_every (int): report timing every this many batches; 0 disables it

    returns:
        list of token lists, in generation order
    """
    # The default seed is a plain string, but the generation code slices a token list.
    if isinstance(seed_text, str):
        seed_text = seed_text.split()

    sentences = []
    n_batches = math.ceil(n_samples / batch_size)
    start_time = time.time()
    for batch_n in range(n_batches):
        batch = parallel_sequential_generation(seed_text, max_len=max_len, top_k=top_k,
                                               temperature=temperature, burnin=burnin, max_iter=max_iter, 
                                               cuda=cuda, verbose=False)

        # `print_every and ...` avoids a modulo-by-zero when reporting is disabled.
        if print_every and (batch_n + 1) % print_every == 0:
            print("Finished batch %d in %.3fs" % (batch_n + 1, time.time() - start_time))
            start_time = time.time()

        sentences += batch
    return sentences

In [0]:
# Utility functions

def printer(sent, should_detokenize=True):
    """ Print the tokens of `sent` joined by spaces.

    With should_detokenize, the tokens first go through detokenize and the
    first and last resulting items are dropped.
    """
    tokens = detokenize(sent)[1:-1] if should_detokenize else sent
    print(" ".join(tokens))
    
def read_sents(in_file, should_detokenize=False):
    """ Read one sentence per line from `in_file`.

    args:
        - in_file (str): path of the text file to read
        - should_detokenize (Bool): if True, run detokenize on every sentence

    returns:
        list of token lists; each line is stripped and split on whitespace
    """
    # The context manager closes the file handle even if reading fails.
    with open(in_file) as in_fh:
        sents = [line.strip().split() for line in in_fh]
    if should_detokenize:
        sents = [detokenize(sent) for sent in sents]
    return sents

def write_sents(out_file, sents, should_detokenize=False):
    """ Write every sentence of `sents` to `out_file`, one space-joined line each.

    With should_detokenize, the first and last tokens of a sentence are dropped
    and the remaining tokens go through detokenize before being written.
    The file is overwritten.
    """
    with open(out_file, "w") as out_fh:
        for tokens in sents:
            if should_detokenize:
                tokens = detokenize(tokens[1:-1])
            out_fh.write(" ".join(tokens) + "\n")

In [70]:
max_len = sentence_len

top_k = 100
temperature = 0.7  # NOTE(review): not used below; the loop generates with temp = 1.0

leed_out_len = 5 # max_len
burnin = 250
sample = True


# Choose the prefix context
in_text = "[CLS]"+" "+condition
seed_text = in_text.split()

for temp in [1.0]:
    # Pass the detected `cuda` flag instead of a hard-coded False: the model is moved
    # to the GPU whenever CUDA is available, so the input tensors must follow it.
    bert_sents = generate(n_samples, seed_text=seed_text, batch_size=batch_size, max_len=max_len,
                          sample=sample, top_k=top_k, temperature=temp, burnin=burnin, max_iter=max_iter,
                          cuda=cuda)

    # The output file is overwritten on every pass of the loop.
    write_sents(out_file, bert_sents, should_detokenize=True)

Finished batch 1 in 133.232s
Finished batch 2 in 132.599s
Finished batch 3 in 133.154s
Finished batch 4 in 132.877s
Finished batch 5 in 133.343s
Finished batch 6 in 133.594s
Finished batch 7 in 133.783s
Finished batch 8 in 134.045s
Finished batch 9 in 133.256s
Finished batch 10 in 133.574s


# Results

In [71]:

# Reload the stored sentences and print at most the first 50 of the requested samples.
n_to_show = min(n_samples, 50)
bert_sents = read_sents(in_file, should_detokenize=False)
for sent_idx in range(n_to_show):
    printer(bert_sents[sent_idx], should_detokenize=False)

i was really happy about the change . i had never been to this festival before .
they look really happy now , too , some really happy , but mostly not happy .
" oh hello , madaug ! i am really happy just having you back .
and if there was anything to ask , i was really happy to see him tonight .
that boy really loved me . then there were two . and i was really happy .
that was when my parents hugged me . oh god , i was just really happy .
the painting presents white silhouettes to symbolize how there will be really happy marriages .
pity i cannot keep my eyes off him , but he ... looks really happy inside .
" really happy to see you again . " sean loved the sound of the bear .
" really happy , " he muttered . " so happy . so damn happy . "
not really happy either , although i hear it a lot . he checks his watch .
never really happy with herself , she was married to a man from the far east .
opening the door i exclaim , " how are you looking ? " really happy .
by the minute , you can se