In [12]:
from model_sampler import *

import numpy as np
import torch
import torch.nn.functional as F
import torch.nn as nn
from tqdm import trange

from pytorch_pretrained_bert import GPT2LMHeadModel, GPT2Tokenizer

In [13]:
# Sample three-turn dialogue used below to demonstrate wrap_message_list.
messages = [
    "Hello my name is Ivan, nice to meet you",
    "Hello, Ivan, nive to meet you too. I'm a YetAnotherChatbot.",
    "Oh, hi, Mark. What a story you just told me...",
]


def wrap_message_list(m_list, insert_intro=True, wrap_type='name', check_end_punct=True):
    '''
    Format a chat log as a model prompt with alternating speaker tags.

    Message 0 gets the first speaker tag of ``wrap_type``, message 1 the
    second, and so on, one message per line. The prompt ends with the tag of
    the speaker whose turn is next, so the model continues as that speaker.

    Parameters:
    ----------
    m_list : list of str
        list of messages in chatbot log, oldest first; may be empty
    insert_intro : bool, optional
        whether to prepend the "<|endoftext|>" marker to the prompt
    wrap_type : string, optional
        type of speaker tags to use: 'name' ("Alice:"/"Bob:"),
        'dash' ("-"/"-") or 'number' ("1:"/"2:")
    check_end_punct : bool, optional
        whether to append a period to non-empty messages that do not already
        end with '.', '!', '?', a single quote or a double quote

    Returns:
    -------
    output : str
        the formatted prompt
    conditioning : list of str
        the two speaker tags with any trailing ':' removed
        (e.g. ['Alice', 'Bob'] for 'name', ['-', '-'] for 'dash')

    Raises:
    ------
    AssertionError
        if ``wrap_type`` is not one of the supported types
    '''
    types = {'name': ('Alice:', 'Bob:'),
             # 'name-in-par': ('[Alice]:', '[Bob]:'),  # currently disabled
             'dash': ('-', '-'),
             'number': ('1:', '2:')}
    valid_ending = ('.', '!', '?', '\'', '\"')

    assert wrap_type in types, "Unknown wrapping"
    tags = types[wrap_type]

    parts = []
    if insert_intro:
        parts.append("<|endoftext|>")

    for i, msg in enumerate(m_list):
        # Guard on `msg` so an empty message does not raise IndexError.
        if check_end_punct and msg and msg[-1] not in valid_ending:
            msg += '.'
        parts.append(tags[i % 2] + ' ' + msg + '\n')

    # Tag of the speaker who talks next. len(m_list) % 2 equals
    # (last_index + 1) % 2 and also works when m_list is empty.
    parts.append(tags[len(m_list) % 2])
    output = ''.join(parts)

    # Conditioning strings are the tags without their trailing colon.
    if tags[0].endswith(':'):
        conditioning = [tags[0][:-1], tags[1][:-1]]
    else:
        conditioning = list(tags)

    return output, conditioning

In [14]:
# Display the sample dialogue (rendered as this cell's output).
messages

['Hello my name is Ivan, nice to meet you',
 "Hello, Ivan, nive to meet you too. I'm a YetAnotherChatbot.",
 'Oh, hi, Mark. What a story you just told me...']

In [15]:
# Wrap the sample dialogue with numbered speaker tags. The trailing '|'
# makes the exact end of the prompt (the next-speaker tag) visible.
text4test, ctokens = wrap_message_list(messages, wrap_type='number')
print(text4test+'|')

<|endoftext|>1: Hello my name is Ivan, nice to meet you.
2: Hello, Ivan, nive to meet you too. I'm a YetAnotherChatbot.
1: Oh, hi, Mark. What a story you just told me...
2:|


In [16]:
def init_model(seed=0, model_name_or_path='gpt2', checkpoint_path="../gpt2_model_3200.pth"):
    '''
    Seed the RNGs and load the GPT-2 tokenizer and the checkpointed LM-head model.

    Parameters:
    ----------
    seed : int, optional
        seed for the numpy, torch CPU and torch CUDA RNGs
    model_name_or_path : str, optional
        pretrained model name or directory passed to ``from_pretrained``
    checkpoint_path : str, optional
        path to a state dict whose weights replace the pretrained ones;
        it is loaded into an ``nn.DataParallel`` wrapper (see below)

    Returns:
    -------
    (model, enc, device)
        the model in eval mode on ``device``, the tokenizer, and the torch
        device ("cuda" if available, otherwise "cpu")
    '''
    np.random.seed(seed)
    torch.random.manual_seed(seed)
    torch.cuda.manual_seed(seed)

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    enc = GPT2Tokenizer.from_pretrained(model_name_or_path)
    model = GPT2LMHeadModel.from_pretrained(model_name_or_path)

    # Wrapping in DataParallel makes the expected state-dict keys carry the
    # "module." prefix, presumably matching how the checkpoint was saved;
    # the plain model is unwrapped again afterwards.
    # map_location='cpu' lets a checkpoint saved on GPU load on a CPU-only
    # host; the model is moved to `device` below.
    # NOTE(review): torch.load unpickles arbitrary objects - only load trusted files.
    model = nn.DataParallel(model)
    model.load_state_dict(torch.load(checkpoint_path, map_location='cpu'))
    model = model.module

    model.to(device)
    model.eval()
    return model, enc, device

In [17]:
def model_forward(input_text, conditioning, verbose, *model_params, length=128, top_k=10, temperature=1.0):
    '''
    Sample a continuation of ``input_text`` and return only the new text.

    Parameters:
    ----------
    input_text : str
        prompt to continue (e.g. the output of wrap_message_list)
    conditioning : list of str
        strings whose token ids are concatenated and passed to
        sample_sequence as ``cond_tokens``
    verbose : bool
        print the conditioning, input and output token ids
    *model_params
        exactly (model, enc, device), as returned by init_model
    length : int, optional
        number of tokens to sample; -1 means half the context window
    top_k : int, optional
        passed through to sample_sequence
    temperature : float, optional
        passed through to sample_sequence

    Returns:
    -------
    str
        decoded text of the sampled tokens, with the prompt tokens cut off

    Raises:
    ------
    ValueError
        if ``length`` exceeds the model's context window ``n_ctx``
    '''
    model, enc, device = model_params
    window = model.config.n_ctx
    if length == -1:
        length = window // 2
    elif length > window:
        raise ValueError("Can't get samples longer than window size: %s" % window)

    # 50256 and 220 are presumably GPT-2's <|endoftext|> and single-space
    # token ids; they are always prepended to the encoded prompt.
    prompt_ids = [50256, 220] + enc.encode(input_text)
    cond_ids = [tok for piece in conditioning for tok in enc.encode(piece)]

    if verbose:
        print('Detected conditioning tokens:')
        print(cond_ids)
        print("Input tokens:")
        print(prompt_ids)

    sampled = sample_sequence(
        model=model,
        length=length,
        context=prompt_ids,
        cond_tokens=cond_ids,
        start_token=None,
        batch_size=1,
        temperature=temperature,
        top_k=top_k,
        device=device,
    )

    if verbose:
        print("Out Tokens:")
        print(sampled)

    # Keep only the freshly generated ids of the single batch row.
    generated_ids = sampled[:, len(prompt_ids):].tolist()[0]
    return enc.decode(generated_ids)

In [18]:
def produce_answer(user_input, prev_msgs, verbose=False, *model_params, **wrap_params):
    '''
    Add a user message to the history, sample a reply, and add the reply too.

    Parameters:
    ----------
    user_input : str
        the new user message
    prev_msgs : list of str
        conversation history; modified in place (both the user message and
        the reply are appended)
    verbose : bool, optional
        print the prompt and the raw sampled text (model_forward also prints
        token ids)
    *model_params
        (model, enc, device), as returned by init_model
    **wrap_params
        keyword arguments forwarded to wrap_message_list
        (insert_intro, wrap_type, check_end_punct)

    Returns:
    -------
    str
        the first line of the sampled text, cut at any "<|endoftext|>" marker,
        with non-breaking spaces removed and one leading space stripped;
        may be an empty string
    '''
    prev_msgs.append(user_input)
    input_text, conditioning = wrap_message_list(prev_msgs, **wrap_params)
    if verbose:
        print("Model input:\n")
        print(input_text)
    sampled_answer = model_forward(input_text, conditioning, verbose, *model_params)
    if verbose:
        print("All sampled:\n")
        print(sampled_answer)
        print("\n\n")

    # The reply to this turn is the text up to the first newline. Also stop at
    # an end-of-text marker so it does not leak into the reply and history.
    answer = sampled_answer.split('\n')[0].split('<|endoftext|>')[0]
    # NOTE(review): deleting U+00A0 (rather than replacing it with ' ') can
    # glue adjacent words together; the original author flagged this "FIX THIS".
    answer = answer.replace(u'\xa0', '')

    # startswith() rather than answer[0], so an empty reply cannot raise IndexError.
    if answer.startswith(' '):
        answer = answer[1:]

    prev_msgs.append(answer)
    return answer

In [19]:
# Load the tokenizer, the checkpointed model and the device once; the
# dialogue cells below reuse them.
model, enc, device = init_model(seed=42)

In [30]:
# Start an empty conversation. produce_answer appends both the user turn
# and the model's reply to this list in place.
messages = []

In [31]:
# First turn: no "<|endoftext|>" intro, speakers tagged "Alice:"/"Bob:".
# The reply is sampled, so it may differ between runs.
produce_answer("Hi! Do you have any hobbies?", messages, False, model, enc, device, insert_intro=False, wrap_type='name')

'No'

In [32]:
# Second turn: `messages` already holds the previous exchange, so that
# exchange becomes part of the prompt.
produce_answer("Me too. Do you like play computer games?", messages, False, model, enc, device, insert_intro=False, wrap_type='name')

"I'm really just bored."

In [35]:
# Inspect the conversation history: user turns and model replies alternate.
messages

['Hi! Do you have any hobbies?',
 'No',
 'Me too. Do you like play computer games?',
 "I'm really just bored."]