In [22]:
from model_sampler import *

import numpy as np
import torch
import torch.nn.functional as F
from tqdm import trange

from pytorch_pretrained_bert import GPT2LMHeadModel, GPT2Tokenizer

In [17]:
# Sample three-turn chat log used to try out wrap_message_list in the cells below.
messages = ["Hello my name is Ivan, nice to meet you", 
            "Hello, Ivan, nive to meet you too. I'm a YetAnotherChatbot.", 
            "Oh, hi, Mark. What a story you just told me..."]


def wrap_message_list(m_list, insert_intro=True, wrap_type='name', check_end_punct=True):
    '''
    Turn a chat log into a text prompt that ends with the next speaker's prefix.

    Messages are attributed alternately to two speakers (the first speaker
    opens the conversation), one message per line. The returned text always
    ends with the prefix of whoever speaks next, so a language model can
    continue the conversation from there.

    Parameters:
    ----------
    m_list : list
        list of messages (strings) in chatbot log; may be empty, in which
        case the prompt ends with the first speaker's prefix
    insert_intro : bool, optional
        whether to prepend the intro sentence about the conversation
    wrap_type : string, optional
        type of conditioning to use ('name', 'name-in-par', 'dash', 'number')
    check_end_punct : bool, optional
        whether to append a period to a non-empty message that does not
        already end with one of . ! ? '

    Returns:
    -------
    str
        the assembled prompt

    Raises:
    ------
    AssertionError
        if wrap_type is not one of the known types
    '''
    types = {'name': ('Alice: ', 'Bob: '),
             'name-in-par': ('[Alice]: ', '[Bob]: '),
             'dash': ('-', '-'),
             'number': ('1: ', '2: ')}
    # A tuple so it can be handed straight to str.endswith.
    valid_ending = ('.', '!', '?', '\'')

    assert wrap_type in types, "Unknown wrapping"
    prefixes = types[wrap_type]

    parts = []
    if insert_intro:
        parts.append("This is the conversation between 2 people.\n")

    # Count messages explicitly: the loop variable is never bound for an
    # empty log, and the count decides whose turn comes next.
    n_msgs = 0
    for n_msgs, msg in enumerate(m_list, start=1):
        parts.append('\n' + prefixes[(n_msgs - 1) % 2] + msg)
        # An empty message has no last character to inspect; leave it bare
        # instead of failing or emitting a lone period.
        if check_end_punct and msg and not msg.endswith(valid_ending):
            parts.append('.')

    # Open the next turn so the model continues as the following speaker.
    parts.append('\n' + prefixes[n_msgs % 2])
    return ''.join(parts)

In [18]:
messages

['Hello my name is Ivan, nice to meet you',
 "Hello, Ivan, nive to meet you too. I'm a YetAnotherChatbot.",
 'Oh, hi, Mark. What a story you just told me...']

In [41]:
# Build a prompt from the sample log using the 'Alice: ' / 'Bob: ' prefixes and show it.
text4test = wrap_message_list(messages, wrap_type='name')
print(text4test)

This is the conversation between 2 people.

Alice: Hello my name is Ivan, nice to meet you.
Bob: Hello, Ivan, nive to meet you too. I'm a YetAnotherChatbot.
Alice: Oh, hi, Mark. What a story you just told me...
Bob: 


In [35]:
def init_model(seed=0, model_name_or_path='gpt2'):
    '''
    Seed the random number generators and load a GPT-2 model with its tokenizer.

    Parameters:
    ----------
    seed : int, optional
        seed applied to numpy, torch and torch.cuda
    model_name_or_path : string, optional
        identifier handed to the `from_pretrained` loaders of both the
        tokenizer and the model

    Returns:
    -------
    tuple
        (model, enc, device): the model moved to `device` and switched to
        eval mode, the tokenizer, and the torch device in use
    '''
    # Seed every generator the sampling code may draw from.
    for seed_fn in (np.random.seed, torch.random.manual_seed, torch.cuda.manual_seed):
        seed_fn(seed)

    # Prefer the GPU when one is available.
    device_name = "cuda" if torch.cuda.is_available() else "cpu"
    device = torch.device(device_name)

    enc = GPT2Tokenizer.from_pretrained(model_name_or_path)
    model = GPT2LMHeadModel.from_pretrained(model_name_or_path)
    # To use fine-tuned weights, uncomment and complete the line below:
    # model.load_state_dict(torch.load( ... ))
    model.to(device)
    model.eval()
    return model, enc, device

In [88]:
def model_forward(input_text, *model_params, length=-1, top_k=0, temperature=1.0):
    '''
    Sample a continuation of `input_text` and return it as decoded text.

    Parameters:
    ----------
    input_text : string
        prompt to condition the model on
    *model_params :
        exactly three positional values: model, enc (tokenizer), device
    length : int, optional
        number of tokens to sample; -1 means half of model.config.n_ctx
    top_k : int, optional
        forwarded to sample_sequence
    temperature : float, optional
        forwarded to sample_sequence

    Returns:
    -------
    str
        decoded text of the sampled tokens that follow the prompt

    Raises:
    ------
    ValueError
        if `length` is larger than model.config.n_ctx
    '''
    model, enc, device = model_params
    window_size = model.config.n_ctx

    if length == -1:
        length = window_size // 2
    elif length > window_size:
        raise ValueError("Can't get samples longer than window size: %s" % window_size)

    prompt_ids = enc.encode(input_text)

    # NOTE(review): len(prompt_ids) + length is not checked against the window
    # size here -- confirm that sample_sequence copes with long prompts.
    sampled = sample_sequence(
        model=model,
        context=prompt_ids,
        length=length,
        start_token=None,
        batch_size=1,
        top_k=top_k,
        temperature=temperature,
        device=device)

    # Drop the prompt tokens; only one sequence is sampled (batch_size=1).
    continuation_ids = sampled[:, len(prompt_ids):].tolist()[0]
    return enc.decode(continuation_ids)

In [97]:
def produce_answer(user_input, prev_msgs, *model_params, **wrap_params):
    '''
    Sample the chatbot's reply to `user_input` and record the exchange.

    Parameters:
    ----------
    user_input : string
        the new message from the user
    prev_msgs : list
        conversation history; on success it is extended in place with
        `user_input` followed by the reply, and it is left untouched if
        prompt building or sampling raises
    *model_params :
        positional values forwarded to model_forward (model, enc, device)
    **wrap_params :
        keyword options forwarded to wrap_message_list

    Returns:
    -------
    str
        the reply: the first line of the sampled text, cut at the
        '<|endoftext|>' marker if it occurs on that line
    '''
    # Work on a copy so a failure below does not leave an unanswered user
    # turn in the history (which would flip speaker alternation on retry).
    history = list(prev_msgs)
    history.append(user_input)
    input_text = wrap_message_list(history, **wrap_params)

    sampled_answer = model_forward(input_text, *model_params)
    print("All sampled:\n", sampled_answer, "\n\n")

    # Text after the end-of-text marker belongs to an unrelated document,
    # so cut there before taking the first line as the reply.
    answer = sampled_answer.split('<|endoftext|>')[0].split('\n')[0]

    prev_msgs.append(user_input)
    prev_msgs.append(answer)
    return answer

In [98]:
# Load the model and tokenizer with a fixed seed; the resulting triple is passed to produce_answer below.
model, enc, device = init_model(seed=42)

In [99]:
# Start the live conversation from an empty history (replaces the sample log defined earlier).
messages = []

In [102]:
# NOTE: produce_answer appends to `messages`, so this cell is not idempotent --
# each re-run adds another user/bot pair to the history.
produce_answer("Hello, my name is Ivan, nice to meet you!", messages, model, enc, device, wrap_type='name')

100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 512/512 [00:26<00:00, 23.12it/s]


All sampled:
 ????? Right now? ?????

Alice: ????????

Webcam Kodak will come out on Feb. 27. Counter's cameras will be shown, but it seems Congress is going to run the numbers.<|endoftext|>University of Texas President Dawn Canady: Woman kissed NHL star

University of Texas President Dawn Canady: Bureau of Stadium Services no longer includes cross references to nude photos in bus advertisements, commentator says

Election Consultancy group is now customers of one Miss San Antonio. The news burns right in the wrong wing of the computer somewhere. Please fix this. Democrat Naheed Nenshi, who owned an alpacas dealership in Houston Thursday night, fears a possible newspaper article listing street fighting sets off a news storm this week over "distracted drivers" using cross references in advertising.

Yehuda said physiotherapist David Serutter of Saint Mary's Athletic Club paid $5 a day for Lat chefs Evangeline Bellingham and John Sacrange. Serutter said Thursday that flights "remain no m

'????? Right now? ?????'

In [103]:
messages

['Hello, my name is Ivan, nice to meet you!',
 '????? Right now? ?????',
 'Hello, my name is Ivan, nice to meet you!',
 '????? Right now? ?????']