In [1]:
from model_sampler import *

import numpy as np
import torch
import torch.nn.functional as F
import torch.nn as nn
from tqdm import trange

from pytorch_pretrained_bert import GPT2LMHeadModel, GPT2Tokenizer

In [2]:
messages = ["Hello my name is Ivan, nice to meet you", 
            "Hello, Ivan, nive to meet you too. I'm a YetAnotherChatbot.", 
            "Oh, hi, Mark. What a story you just told me..."]


def wrap_message_list(m_list, insert_intro=True, wrap_type='name', check_end_punct=True):
    '''
    Parameters:
    ----------
    m_list : list
        list of messages in chatbot log 
    insert_intro : bool, optional
        whether should it insert the intro about the conversation
    wrap_type : string, optional
        type of conditioning to use ('name', 'name-in-par', 'dash', 'number') 
    check_end_punct : bool, optional
        whether should it check the last symbol of message to have the period etc
    '''
    output = ""
    types = {'name': ('Alice: ', 'Bob: '),
            'name-in-par': ('[Alice]: ', '[Bob]: '),
            'dash': ('-', '-'),
            'number': ('1: ', '2: ')}
    valid_ending = ['.', '!', '?', '\'']
    
    assert wrap_type in types, "Unknown wrapping"
    
    if(insert_intro):
        output += "<|endoftext|>"#This is the conversation between 2 people."
        
    for i, msg in enumerate(m_list):
        output += types[wrap_type][i%2]
        output += msg
        if((check_end_punct) and (msg[-1] not in valid_ending)):
            output += '.'
        output += '\n'        
            
    #output += '\n'
    output += types[wrap_type][(i+1)%2]
    return output

In [3]:
messages

['Hello my name is Ivan, nice to meet you',
 "Hello, Ivan, nive to meet you too. I'm a YetAnotherChatbot.",
 'Oh, hi, Mark. What a story you just told me...']

In [4]:
text4test = wrap_message_list(messages, wrap_type='name')
print(text4test)

<|endoftext|>Alice: Hello my name is Ivan, nice to meet you.
Bob: Hello, Ivan, nive to meet you too. I'm a YetAnotherChatbot.
Alice: Oh, hi, Mark. What a story you just told me...
Bob: 


In [5]:
def init_model(seed=0, model_name_or_path='gpt2'):
    np.random.seed(seed)
    torch.random.manual_seed(seed)
    torch.cuda.manual_seed(seed)

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    enc = GPT2Tokenizer.from_pretrained(model_name_or_path)
    model = GPT2LMHeadModel.from_pretrained(model_name_or_path)
    
    model = nn.DataParallel(model)
    model.load_state_dict(torch.load("../gpt2_model_3200.pth"))
    model = model.module
    
    model.to(device)
    model.eval()
    return model, enc, device

In [11]:
def model_forward(input_text, *model_params, length=128, top_k=10, temperature=1.0):
    model, enc, device = model_params
    if length == -1:
        length = model.config.n_ctx // 2
    elif length > model.config.n_ctx:
        raise ValueError("Can't get samples longer than window size: %s" % model.config.n_ctx)
        
    context_tokens = []
    context_tokens = enc.encode(input_text)
    context_tokens = [50256, 220] + context_tokens
    print("Input tokens")
    print(context_tokens)
    
    out = sample_sequence(
        model=model, length=length,
        context=context_tokens,
        start_token=None,
        batch_size=1,
        temperature=temperature, top_k=top_k, device=device)

    print("Out Tokens") 
    print(out)    
    out = out[:, len(context_tokens):].tolist()
    output_text = enc.decode(out[0])
    return output_text

In [7]:
def produce_answer(user_input, prev_msgs, *model_params, **wrap_params):
    prev_msgs.append(user_input)
    input_text = wrap_message_list(prev_msgs, **wrap_params)
    print("Model input:\n")
    print(input_text)
    sampled_answer = model_forward(input_text, *model_params)
    print("All sampled:\n")
    print(sampled_answer) 
    print("\n\n")
    answer = sampled_answer.split('\n')[0] ### If <end of text. -> send ...
    answer = answer.replace(u'\xa0', u'') ### FIX THIS
    prev_msgs.append(answer)
    return answer

In [8]:
model, enc, device = init_model(seed=42)

In [12]:
messages = []

In [13]:
produce_answer("Hi! Do you have any hobbies", messages, model, enc, device, insert_intro=False, wrap_type='name')

Model input:

Alice: Hi! Do you have any hobbies.
Bob: 
Input tokens
[50256, 220, 44484, 25, 15902, 0, 2141, 345, 423, 597, 45578, 13, 198, 18861, 25, 220]


100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 128/128 [00:07<00:00, 17.30it/s]


Out Tokens
tensor([[50256,   220, 44484,    25, 15902,     0,  2141,   345,   423,   597,
         45578,    13,   198, 18861,    25,   220,  9805,   198, 44484,    25,
         14373,   198, 44484,    25,  4162,   389,   345,  4737,   703, 12248,
          6530,    30,   198, 18861,    25,  3966,  1392,   340, 12248,   198,
         44484,    25,  9425,   198, 18861,    25,  1867,   338,   534,  1438,
          5633,   220, 13300,   530,   640,   198, 44484,    25, 15929,   198,
         18861,    25,   679,   258,   314,  1101,  4422, 14373,   198,    27,
            91,   437,  1659,  5239,    91,    29, 44484,    25, 23105,   198,
         44484,    25,   285,   393,   277,   198, 18861,    25, 14690,   198,
         18861,    25,   355,    30,   198,    27,    91,   437,  1659,  5239,
            91,    29, 44484,    25, 17207,   198, 44484,    25,   355,    75,
           198, 18861,    25,   355,    75,   198, 18861,    25,  2956,    30,
            77,   198, 18861,    25,  159

'????'

In [26]:
messages

['Hi! Do you have any hobbies', '???????????????']

In [14]:
enc.decode([9805])

'????'