In [1]:
import os
import logging

import torch
import torch.nn.functional as F

from modules import VAE
from pytorch_transformers import (WEIGHTS_NAME, AdamW, WarmupLinearSchedule,
                                  BertConfig, BertForLatentConnector, BertTokenizer,
                                  GPT2Config, GPT2ForLatentConnector, GPT2Tokenizer,
                                  OpenAIGPTConfig, OpenAIGPTLMHeadModel, OpenAIGPTTokenizer,
                                  RobertaConfig, RobertaForMaskedLM, RobertaTokenizer)


# LOADING MODELS

In [2]:
class Args():
    encoder_model_type = "bert"
    decoder_model_type = "gpt2"
    device = "cuda"
    checkpoint_dir = "/home/patryk/Studia/PracaMagisterska/optimus/Optimus/output/checkpoint-508523"
    encoder_model_name_or_path="bert-base-cased"
    decoder_model_name_or_path="gpt2"
    encoder_tokenizer_name = ""
    decoder_tokenizer_name = ""
    do_lower_case = True
    global_step = 508523
    latent_size = 768
    block_size = 100
    fb_mode = 1
    dim_target_kl = 0.5
    temperature = 1.0
    top_k = 0.0
    top_p = 0.9

args = Args()

MODEL_CLASSES = {
    'gpt2': (GPT2Config, GPT2ForLatentConnector, GPT2Tokenizer),
    'openai-gpt': (OpenAIGPTConfig, OpenAIGPTLMHeadModel, OpenAIGPTTokenizer),
    'bert': (BertConfig, BertForLatentConnector, BertTokenizer),
    'roberta': (RobertaConfig, RobertaForMaskedLM, RobertaTokenizer)
}



output_encoder_dir = os.path.join(args.checkpoint_dir, 'checkpoint-encoder-{}'.format(args.global_step))
output_decoder_dir = os.path.join(args.checkpoint_dir, 'checkpoint-decoder-{}'.format(args.global_step)) 
output_full_dir    = os.path.join(args.checkpoint_dir, 'checkpoint-full-{}'.format(args.global_step)) 

checkpoints = [ [output_encoder_dir, output_decoder_dir] ]
logging.info("Evaluate the following checkpoints: %s", checkpoints)

# Load a trained Encoder model and vocabulary
encoder_config_class, encoder_model_class, encoder_tokenizer_class = MODEL_CLASSES[args.encoder_model_type]
model_encoder = encoder_model_class.from_pretrained(output_encoder_dir, latent_size=args.latent_size)
tokenizer_encoder = encoder_tokenizer_class.from_pretrained(args.encoder_tokenizer_name if args.encoder_tokenizer_name else args.encoder_model_name_or_path, do_lower_case=args.do_lower_case)

model_encoder.to(args.device)
if args.block_size <= 0:
    args.block_size = tokenizer_encoder.max_len_single_sentence  # Our input block size will be the max possible for the model
args.block_size = min(args.block_size, tokenizer_encoder.max_len_single_sentence)

# Load a trained Decoder model and vocabulary
decoder_config_class, decoder_model_class, decoder_tokenizer_class = MODEL_CLASSES[args.decoder_model_type]
model_decoder = decoder_model_class.from_pretrained(output_decoder_dir, latent_size=args.latent_size)
tokenizer_decoder = decoder_tokenizer_class.from_pretrained(args.decoder_tokenizer_name if args.decoder_tokenizer_name else args.decoder_model_name_or_path, do_lower_case=args.do_lower_case)
model_decoder.to(args.device)
if args.block_size <= 0:
    args.block_size = tokenizer_decoder.max_len_single_sentence  # Our input block size will be the max possible for the model
args.block_size = min(args.block_size, tokenizer_decoder.max_len_single_sentence)

# Load full model
checkpoint = torch.load(os.path.join(output_full_dir, 'training.bin'))

model_vae = VAE(
    model_encoder, model_decoder, tokenizer_encoder, tokenizer_decoder, args
)

model_vae.load_state_dict(checkpoint["model_state_dict"])


ERROR:pytorch_transformers.tokenization_utils:Using pad_token, but it is not set yet.


<All keys matched successfully>

In [3]:
model_vae.load_state_dict(checkpoint["model_state_dict"])

<All keys matched successfully>

In [37]:
def top_k_top_p_filtering(logits, top_k=0, top_p=0.0, filter_value=-float('Inf')):
    """ Filter a distribution of logits using top-k and/or nucleus (top-p) filtering
        Args:
            logits: logits distribution shape (vocabulary size)
            top_k > 0: keep only top k tokens with highest probability (top-k filtering).
            top_p > 0.0: keep the top tokens with cumulative probability >= top_p (nucleus filtering).
                Nucleus filtering is described in Holtzman et al. (http://arxiv.org/abs/1904.09751)
        From: https://gist.github.com/thomwolf/1a5a29f6962089e871b94cbd09daf317
    """
    assert logits.dim() == 1  # batch size 1 for now - could be updated for more but the code would be less clear
    top_k = min(top_k, logits.size(-1))  # Safety check
    if top_k > 0:
        # Remove all tokens with a probability less than the last token of the top-k
        indices_to_remove = logits < torch.topk(logits, top_k)[0][..., -1, None]
        logits[indices_to_remove] = filter_value

    if top_p > 0.0:
        sorted_logits, sorted_indices = torch.sort(logits, descending=True)

        cumulative_probs = torch.cumsum(F.softmax(sorted_logits, dim=-1), dim=-1)
        # Remove tokens with cumulative probability above the threshold
        sorted_indices_to_remove = cumulative_probs > top_p
        # Shift the indices to the right to keep also the first token above the threshold
        sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[..., :-1].clone()
        sorted_indices_to_remove[..., 0] = 0

        indices_to_remove = sorted_indices[sorted_indices_to_remove]
        logits[indices_to_remove] = filter_value

    return logits

def sample_sequence_conditional(model, length, context, past=None, num_samples=1, temperature=1, top_k=0, top_p=0.0, device='cpu', decoder_tokenizer=None):
    
    context = torch.tensor(context, dtype=torch.long, device=device)
    context = context.unsqueeze(0).repeat(num_samples, 1)
    generated = context
    with torch.no_grad():
        while True:
        # for _ in trange(length):
            inputs = {'input_ids': generated, 'past': past}
            outputs = model(**inputs)  # Note: we could also use 'past' with GPT-2/Transfo-XL/XLNet (cached hidden-states)
            next_token_logits = outputs[0][0, -1, :] / temperature
            filtered_logits = top_k_top_p_filtering(next_token_logits, top_k=top_k, top_p=top_p)
            print(F.softmax(filtered_logits, dim=-1))
            next_token = torch.multinomial(F.softmax(filtered_logits, dim=-1), num_samples=1)
            generated = torch.cat((generated, next_token.unsqueeze(0)), dim=1)

            # pdb.set_trace()
            if next_token.unsqueeze(0)[0,0].item() == decoder_tokenizer.encode('<EOS>')[0]:
                break

    return generated

def latent_code_from_text(text, tokenizer_encoder, model_vae, args):
    tokenized1 = tokenizer_encoder.encode(text)
    tokenized1 = [101] + tokenized1 + [102]
    coded1 = torch.Tensor([tokenized1])
    coded1 =torch.Tensor.long(coded1)
    with torch.no_grad():
        x0 = coded1
        x0 = x0.to(args.device)
        pooled_hidden_fea = model_vae.encoder(x0, attention_mask=(x0 > 0).float())[1]
        mean, logvar = model_vae.encoder.linear(pooled_hidden_fea).chunk(2, -1)
        # latent_z = model_vae.reparameterize(mean, logvar, 1).squeeze(1)
        latent_z = mean.squeeze(1)  

        coded_length = len(tokenized1)
        return latent_z, coded_length

def text_from_latent_code(latent_z, model_vae, args, tokenizer_decoder):
    past = latent_z
    context_tokens = tokenizer_decoder.encode('<BOS>')

    length = 128 # maximum length, but not used 
    out = sample_sequence_conditional(
        model=model_vae.decoder,
        context=context_tokens,
        past=past,
        length= length, # Chunyuan: Fix length; or use <EOS> to complete a sentence
        temperature=args.temperature,
        top_k=args.top_k,
        top_p=args.top_p,
        device=args.device,
        decoder_tokenizer = tokenizer_decoder
    )
    text_x1 = tokenizer_decoder.decode(out[0,:].tolist(), clean_up_tokenization_spaces=True)
    text_x1 = text_x1.split()[1:-1]
    text_x1 = ' '.join(text_x1)
    return text_x1


In [57]:
input_text = "place is a place where you can."

In [58]:
latent_z = latent_code_from_text(input_text, tokenizer_encoder, model_vae, args)[0]

In [59]:
text_from_latent_code(latent_z, model_vae, args, tokenizer_decoder)

tensor([0., 0., 0.,  ..., 0., 0., 0.], device='cuda:0')


''