In [1]:
import os
import logging
from pprint import pformat
from argparse import ArgumentParser
from collections import defaultdict
from itertools import chain
import warnings
import random

import torch.nn.functional as F
import torch
# from torch.nn.parallel import DistributedDataParallel
# from torch.utils.data import DataLoader, TensorDataset, RandomSampler, SequentialSampler
# from ignite.engine import Engine, Events
# from ignite.handlers import ModelCheckpoint
# from ignite.metrics import Accuracy, Loss, MetricsLambda, RunningAverage
# from ignite.contrib.handlers import ProgressBar, PiecewiseLinear
# from ignite.contrib.handlers.tensorboard_logger import TensorboardLogger, OutputHandler, OptimizerParamsHandler
from transformers import (AdamW, OpenAIGPTDoubleHeadsModel, OpenAIGPTTokenizer,
                                  GPT2DoubleHeadsModel, GPT2Tokenizer, WEIGHTS_NAME, CONFIG_NAME)
from transformers import OpenAIGPTLMHeadModel, OpenAIGPTTokenizer, GPT2LMHeadModel, GPT2Tokenizer

from transfer_learning_conv_ai.utils import get_dataset, make_logdir
from tqdm import tqdm, trange

In [4]:
class args:
    dataset_path=''
    dataset_cache='./dataset_cache'
    model='openai-gpt'
    model_checkpoint='May24_07-02-15_ip-172-31-32-191_openai-gpt'
    output_dir = ''
    device="cuda" if torch.cuda.is_available() else "cpu"
    max_history=2
    no_sample=False
    max_length=20
    min_length=1
    temperature=0.7
    top_k=0
    top_p=0.9
    local_rank = -1

args=args 

In [5]:
SPECIAL_TOKENS = ["<bos>", "<eos>", "<speaker1>", "<speaker2>", "<pad>"]
ATTR_TO_SPECIAL_TOKEN = {"bos_token": "<bos>", "eos_token": "<eos>", 
                  "additional_special_tokens": ["<speaker1>", "<speaker2>"],
                  "pad_token": "<pad>"}
MODEL_INPUTS = ["input_ids", "mc_token_ids", "lm_labels", "mc_labels", "token_type_ids"]
PADDED_INPUTS = ["input_ids", "lm_labels", "token_type_ids"]

logger = logging.getLogger(__name__)
logging.basicConfig(level=logging.INFO if args.local_rank in [-1, 0] else logging.WARN)
logger.warning("Running process %d", args.local_rank)  # This is a logger.warning: it will be printed by all distributed processes
logger.info("Arguments: %s", pformat(args))
logger.info("Prepare tokenizer, pretrained model and optimizer.")

tokenizer_class = GPT2Tokenizer if "gpt2" in args.model_checkpoint else OpenAIGPTTokenizer # cant use Autotokenizer because checkpoint could be a Path
model_class = GPT2DoubleHeadsModel if "gpt2" in args.model_checkpoint else OpenAIGPTDoubleHeadsModel

# tokenizer = tokenizer_class.from_pretrained(args.model_checkpoint)
# model = model_class.from_pretrained(args.model_checkpoint)
# model.to(args.device)

eval_outputs_dirs = './runs/May24_07-02-15_ip-172-31-32-191_openai-gpt'

#reload checkpoints and evaluate on test dataset
model = model_class.from_pretrained(eval_outputs_dirs)
tokenizer = tokenizer_class.from_pretrained(eval_outputs_dirs, do_lower_case=True)
model.to(args.device)

INFO:__main__:Arguments: <class '__main__.args'>
INFO:__main__:Prepare tokenizer, pretrained model and optimizer.
INFO:transformers.configuration_utils:loading configuration file ./runs/May24_07-02-15_ip-172-31-32-191_openai-gpt/config.json
INFO:transformers.configuration_utils:Model config OpenAIGPTConfig {
  "afn": "gelu",
  "architectures": [
    "OpenAIGPTDoubleHeadsModel"
  ],
  "attn_pdrop": 0.1,
  "bos_token_id": null,
  "do_sample": false,
  "embd_pdrop": 0.1,
  "eos_token_ids": null,
  "finetuning_task": null,
  "id2label": {
    "0": "LABEL_0",
    "1": "LABEL_1"
  },
  "initializer_range": 0.02,
  "is_decoder": false,
  "label2id": {
    "LABEL_0": 0,
    "LABEL_1": 1
  },
  "layer_norm_epsilon": 1e-05,
  "length_penalty": 1.0,
  "max_length": 20,
  "model_type": "openai-gpt",
  "n_ctx": 512,
  "n_embd": 768,
  "n_head": 12,
  "n_layer": 12,
  "n_positions": 512,
  "n_special": 0,
  "num_beams": 1,
  "num_labels": 1,
  "num_return_sequences": 1,
  "output_attentions": false,

OpenAIGPTDoubleHeadsModel(
  (transformer): OpenAIGPTModel(
    (tokens_embed): Embedding(40483, 768)
    (positions_embed): Embedding(512, 768)
    (drop): Dropout(p=0.1, inplace=False)
    (h): ModuleList(
      (0): Block(
        (attn): Attention(
          (c_attn): Conv1D()
          (c_proj): Conv1D()
          (attn_dropout): Dropout(p=0.1, inplace=False)
          (resid_dropout): Dropout(p=0.1, inplace=False)
        )
        (ln_1): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
        (mlp): MLP(
          (c_fc): Conv1D()
          (c_proj): Conv1D()
          (dropout): Dropout(p=0.1, inplace=False)
        )
        (ln_2): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
      )
      (1): Block(
        (attn): Attention(
          (c_attn): Conv1D()
          (c_proj): Conv1D()
          (attn_dropout): Dropout(p=0.1, inplace=False)
          (resid_dropout): Dropout(p=0.1, inplace=False)
        )
        (ln_1): LayerNorm((768,), eps=1e-05, elementwis

In [6]:
def add_special_tokens_(model, tokenizer):
    """ Add special tokens to the tokenizer and the model if they have not already been added. """
    orig_num_tokens = len(tokenizer.encoder)
    num_added_tokens = tokenizer.add_special_tokens(ATTR_TO_SPECIAL_TOKEN) # doesn't add if they are already there
    if num_added_tokens > 0:
        model.resize_token_embeddings(new_num_tokens=orig_num_tokens + num_added_tokens)

def build_input_from_segments(persona, history, reply, tokenizer, lm_labels=False, with_eos=True):
    """ Build a sequence of input from 3 segments: persona, history and last reply. """
    bos, eos, speaker1, speaker2 = tokenizer.convert_tokens_to_ids(SPECIAL_TOKENS[:-1])
    sequence = [[bos] + list(chain(*persona))] + history + [reply + ([eos] if with_eos else [])]
    sequence = [sequence[0]] + [[speaker2 if (len(sequence)-i) % 2 else speaker1] + s for i, s in enumerate(sequence[1:])]
    #[[bos+persona], [history], [reply+eos]]
    instance = {}
    instance["input_ids"] = list(chain(*sequence))
    instance["token_type_ids"] = [speaker2 if i % 2 else speaker1 for i, s in enumerate(sequence) for _ in s]
    instance["mc_token_ids"] = len(instance["input_ids"]) - 1
    instance["lm_labels"] = [-100] * len(instance["input_ids"])
    if lm_labels: #if the current candidate is lm_labels, [-100]*[len(persona)+len(history)+1(speaker2)]+current candidate
        instance["lm_labels"] = ([-100] * sum(len(s) for s in sequence[:-1])) + [-100] + sequence[-1][1:]
    return instance

In [7]:
def top_filtering(logits, top_k=0., top_p=0.9, threshold=-float('Inf'), filter_value=-float('Inf')):
    """ Filter a distribution of logits using top-k, top-p (nucleus) and/or threshold filtering
        Args:
            logits: logits distribution shape (vocabulary size)
            top_k: <=0: no filtering, >0: keep only top k tokens with highest probability.
            top_p: <=0.0: no filtering, >0.0: keep only a subset S of candidates, where S is the smallest subset
                whose total probability mass is greater than or equal to the threshold top_p.
                In practice, we select the highest probability tokens whose cumulative probability mass exceeds
                the threshold top_p.
            threshold: a minimal threshold to keep logits
    """
    assert logits.dim() == 1  # Only work for batch size 1 for now - could update but it would obfuscate a bit the code
    top_k = min(top_k, logits.size(-1))
    if top_k > 0:
        # Remove all tokens with a probability less than the last token in the top-k tokens
        indices_to_remove = logits < torch.topk(logits, top_k)[0][..., -1, None]
        logits[indices_to_remove] = filter_value

    if top_p > 0.0:
        # Compute cumulative probabilities of sorted tokens
        sorted_logits, sorted_indices = torch.sort(logits, descending=True)
        cumulative_probabilities = torch.cumsum(F.softmax(sorted_logits, dim=-1), dim=-1)

        # Remove tokens with cumulative probability above the threshold
        sorted_indices_to_remove = cumulative_probabilities > top_p
        # Shift the indices to the right to keep also the first token above the threshold
        sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[..., :-1].clone()
        sorted_indices_to_remove[..., 0] = 0

        # Back to unsorted indices and set them to -infinity
        indices_to_remove = sorted_indices[sorted_indices_to_remove]
        logits[indices_to_remove] = filter_value

    indices_to_remove = logits < threshold
    logits[indices_to_remove] = filter_value

    return logits

In [8]:
def sample_sequence(personality, history, tokenizer, model, args, current_output=None):
    special_tokens_ids = tokenizer.convert_tokens_to_ids(SPECIAL_TOKENS)
    if current_output is None:
        current_output = []

    for i in range(args.max_length):
        instance = build_input_from_segments(personality, history, current_output, tokenizer, with_eos=False)

        input_ids = torch.tensor(instance["input_ids"], device=args.device).unsqueeze(0)
        token_type_ids = torch.tensor(instance["token_type_ids"], device=args.device).unsqueeze(0)

        logits = model(input_ids, token_type_ids=token_type_ids)
        if isinstance(logits, tuple):  # for gpt2 and maybe others
            logits = logits[0]
        logits = logits[0, -1, :] / args.temperature
        logits = top_filtering(logits, top_k=args.top_k, top_p=args.top_p)
        probs = F.softmax(logits, dim=-1)

        prev = torch.topk(probs, 1)[1] if args.no_sample else torch.multinomial(probs, 1)
        if i < args.min_length and prev.item() in special_tokens_ids:
            while prev.item() in special_tokens_ids:
                if probs.max().item() == 1:
                    warnings.warn("Warning: model generating special token with probability 1.")
                    break  # avoid infinitely looping over special token
                prev = torch.multinomial(probs, num_samples=1)

        if prev.item() in special_tokens_ids:
            break
        current_output.append(prev.item())

    return current_output

In [14]:
logger.info("Sample a personality")

dataset_cache = args.dataset_cache
dataset_cache = dataset_cache + '_' + type(tokenizer).__name__ 

dataset = torch.load(dataset_cache)#get_dataset(tokenizer, args.dataset_path, args.dataset_cache)
personalities = [dialog["personality"] for dataset in dataset.values() for dialog in dataset]
personality = random.choice(personalities)
logger.info("Selected personality: %s", tokenizer.decode(chain(*personality)))


INFO:__main__:Sample a personality
INFO:__main__:Selected personality: i've 2 cats. i love to go to the beach. my favorite food is strawberries. i work in a veterinary office. i'm vegan.


In [13]:
# history = []
# current_output = []

# raw_text = 'Hello, how are you?'
# history.append(tokenizer.encode(raw_text))
# instance = build_input_from_segments(personality, history, current_output, tokenizer, with_eos=False)
# with torch.no_grad():
#     out_ids = sample_sequence(personality, history, tokenizer, model, args)


In [16]:
history = []
while True:
    raw_text = input(">>> ")
    while not raw_text:
        print('Prompt should not be empty!')
        raw_text = input(">>> ")
    history.append(tokenizer.encode(raw_text))
    with torch.no_grad():
        out_ids = sample_sequence(personality, history, tokenizer, model, args)
    history.append(out_ids)
    history = history[-(2*args.max_history+1):]
    out_text = tokenizer.decode(out_ids, skip_special_tokens=True)
    print(out_text)

>>> hello, how are you doing
i am doing great. just got back from the beach
>>> I was writing a blog post about you :-)
really? i love cats. what type of blog?
>>> I write mostly tutorials
i'm a vet, i have 2 dogs and i love cats
>>> what beach did you go?
i go to the beach often
>>> sounds interesting, I love movies
i like to watch movies
>>> I love movies
what is your favorite
>>> zombie series
i love to eat strawberries
>>> 
Prompt should not be empty!
>>> what movie do you like
i love the ones with the stars
>>> do you like star wars?
star wars is my favorite.
>>> Which character in star war do you like?
i like batman
>>> I like batman too!
i also like twilight.
>>> I don't like twilight
i like to go to the beach sometimes
>>> can we go together?
we can go to the beach?
>>> can we go together?
i can bring my cats.
>>> can we go to the beach together?
i love cats!
>>> do you have kids?
i don't have any kids but i 'd like some
>>> I have two daughters
how many kids do you have?
>>> I

KeyboardInterrupt: 