In [1]:
from model_sampler import *

import numpy as np
import torch
import torch.nn.functional as F
import torch.nn as nn
from tqdm import trange

from pytorch_pretrained_bert import GPT2LMHeadModel, GPT2Tokenizer

In [2]:
messages = ["Hello my name is Ivan, nice to meet you", 
            "Hello, Ivan, nive to meet you too. I'm a YetAnotherChatbot.", 
            "Oh, hi, Mark. What a story you just told me..."]


def wrap_message_list(m_list, insert_intro=True, wrap_type='name', check_end_punct=True):
    '''
    Parameters:
    ----------
    m_list : list
        list of messages in chatbot log 
    insert_intro : bool, optional
        whether should it insert the intro about the conversation
    wrap_type : string, optional
        type of conditioning to use ('name', 'name-in-par', 'dash', 'number') 
    check_end_punct : bool, optional
        whether should it check the last symbol of message to have the period etc
    '''
    output = ""
    types = {'name': ('Alice:', 'Bob:'),
            #'name-in-par': ('[Alice]:', '[Bob]:'),
            'dash': ('-', '-'),
            'number': ('1:', '2:')}
    valid_ending = ['.', '!', '?']
    
    assert wrap_type in types, "Unknown wrapping"
    
    if(insert_intro):
        output += "<|endoftext|>"
        output += "This is the conversation between 2 people.\n"
        
    for i, msg in enumerate(m_list):
        output += types[wrap_type][i%2]
        output += ' '
        output += msg
        if((check_end_punct) and (msg[-1] not in valid_ending)):
            output += '.'
        output += '\n'        
            
    #output += '\n'
    output += types[wrap_type][(i+1)%2]
    
    if(types[wrap_type][0][-1] == ':'):
        conditioning = types[wrap_type][0][:-1]
    else:
        conditioning = types[wrap_type][0]
    
    return output, [conditioning]

In [3]:
messages

['Hello my name is Ivan, nice to meet you',
 "Hello, Ivan, nive to meet you too. I'm a YetAnotherChatbot.",
 'Oh, hi, Mark. What a story you just told me...']

In [4]:
text4test, ctoken = wrap_message_list(messages, wrap_type='name')
print(text4test+'|')
print("Stop token:", ctoken)

<|endoftext|>This is the conversation between 2 people.
Alice: Hello my name is Ivan, nice to meet you.
Bob: Hello, Ivan, nive to meet you too. I'm a YetAnotherChatbot.
Alice: Oh, hi, Mark. What a story you just told me...
Bob:|
Stop token: ['Alice']


In [5]:
def init_model(seed=0, model_name_or_path='gpt2'):
    '''
    Parameters:
    ----------
    seed : int
        seed number for different ramdomizers
    model_name_or_path : string, optional
        either model name for existing model or path for trained model
    '''
    np.random.seed(seed)
    torch.random.manual_seed(seed)
    torch.cuda.manual_seed(seed)

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    enc = GPT2Tokenizer.from_pretrained('gpt2')
    model = GPT2LMHeadModel.from_pretrained('gpt2')
    
    model = nn.DataParallel(model)
    model.load_state_dict(torch.load(model_name_or_path))
    model = model.module
    
    model.to(device)
    model.eval()
    return model, enc, device

In [6]:
def model_forward(input_text, conditioning, verbose, *model_params, length=128, top_k=10, temperature=1.0):
    '''
    Parameters:
    ----------
    input_text : string
        input text for sampling
    *model_params : tuple
        (model, enc, device) output of 'init_model' function
    length : int, optional
        length of generated sample I guess (!!not sure!!)
    top_k : int, optional
        to generate k most probable samples (!!not sure!!)
    temperature: float, optional
        parameter of sampling algorithm
    '''
    model, enc, device = model_params
    if length == -1:
        length = model.config.n_ctx // 2
    elif length > model.config.n_ctx:
        raise ValueError("Can't get samples longer than window size: %s" % model.config.n_ctx)
        
    context_tokens = []
    context_tokens = enc.encode(input_text)
    context_tokens = [50256] + context_tokens
    
    cond_tokens = []
    for token in conditioning:
        cond_tokens += enc.encode(token)
    
    if(verbose):
        print('Detected conditioning tokens:')
        print(cond_tokens)
        print("Input tokens:")
        print(context_tokens)
    
    out = sample_sequence(
        model=model, length=length,
        context=context_tokens, cond_tokens=cond_tokens,
        start_token=None,
        batch_size=1,
        temperature=temperature, top_k=top_k, device=device)
    
    if(verbose):
        print("Out Tokens:") 
        print(out)    
    out = out[:, len(context_tokens):].tolist()
    output_text = enc.decode(out[0])
    return output_text

In [109]:
def split(string, delimeters):
    sentences = []
    prev_end = 0
    for i in range(1, len(string)):
        if((string[i-1] in delimeters) & (string[i] not in delimeters)):
            str_to_append = string[prev_end:i-1].strip() + string[i-1]
            sentences.append((str_to_append[0].upper() + str_to_append[1:]).replace('¿', ''))
            prev_end = i
    str_to_append = string[prev_end:-1].strip() + string[-1]
    #print("Strting to append:", '\''+str_to_append+'\'')
    sentences.append((str_to_append[0].upper() + str_to_append[1:]).replace('¿', ''))
    return sentences

In [110]:
def output_post_processing(input_quote, max_words):
    '''
    Parameters:
    ----------
    input_quote : string
        output of model
    max_words : integer
        maximal number of words (rounded to the end of sentence with last word) to concatenate to phrase
    '''
    valid_endings = ['.', '!', '?', '¿']
    input_quote = input_quote.replace(u'\xa0', '') # filter out "\xa0" 
    input_quote = input_quote.replace('\n', ' ') # filter out "\n" 
    
    first_endoftext = input_quote.find('<|endoftext|>') 
    if(first_endoftext != -1):
        input_quote = input_quote[:first_endoftext] # cut the string when '<|endoftext|>' found'
        
    first_Alice = input_quote.find('Alice:') 
    if(first_Alice != -1):
        input_quote = input_quote[:first_Alice] # cut the string when 'Alice:' found'

    input_quote = input_quote.replace("Bob:", '¿') # filter out "Bob: "
    print('Partially processed: ')
    print(input_quote+'|')
    
    sentences = split(input_quote.strip(), valid_endings) # sprit remaining string to sentences according to delimiters
    
    print(sentences)
    
    sentences = list(filter(None, sentences)) # filter out empty strings
    sentences = list(filter(lambda x: x != '¿', sentences))  # filter out empty strings
    
    print(sentences)

    for i, sentence in enumerate(sentences): # add periods where nessecary
#         for j in range(len(sentence)-1, 1, -1):
#             if((not sentence[j].isalnum()) & (sentence[j-1].isalnum())):
#                 left_part = sentence[:j]
#                 right_part = sentence[j:]
#                 print("\n\n!!!SPACE BETWEEN PUNCTUATION AND WORDS!!!")
#                 print("BEFORE:")
#                 print(left_part, right_part)
#                 right_part = right_part.replace(' ', '')
#                 print("AFTER:")
#                 print(left_part + right_part)
#                 sentences[i] = left_part + right_part
#                 break
        
        if(sentence[-1] not in valid_endings):
            sentences[i] += '.'
    
    word_counts = [len(s.split(' ')) for s in sentences]
    word_cum_counts = np.cumsum(np.array(word_counts)) / max_words
    sentences_to_pass = np.sum(word_cum_counts < 1.0) + 1
    
    return " ".join(sentences[:sentences_to_pass])

In [111]:
s = "Kek .. ?? Bob: Kek, definetely kek"
output_post_processing(s, 30)

Partially processed: 
Kek .. ?? ¿ Kek, definetely kek|
['Kek ..', '??', '', 'Kek, definetely kek']
['Kek ..', '??', 'Kek, definetely kek']


'Kek .. ?? Kek, definetely kek.'

In [112]:
def produce_answer(user_input, prev_msgs, max_words, verbose=False, *model_params, **wrap_params):
    '''
    Parameters:
    ----------
    user_input : string
        user's message
    prev_msgs : list
        list of previous messages in conversation
    max_words : integer
        number of words to generate (rounded to the end of last sentence)
    *model_params : tuple
        (model, enc, device) output of 'init_model' function
    **wrap_parameters : dict
        parametrs for 'wrap_message_list' function like `wrap_type`    
    '''
    prev_msgs.append(user_input)
    input_text, conditioning = wrap_message_list(prev_msgs, **wrap_params)
    if(verbose):
        print("Model input:\n")
        print(input_text)
    sampled_answer = model_forward(input_text, conditioning, verbose, *model_params)
    if(verbose):
        print("All sampled:\n")
        print(sampled_answer) 
        print("\n\n")
        
    answer = output_post_processing(sampled_answer, max_words)
    
    prev_msgs.append(answer)
    return answer

In [113]:
model, enc, device = init_model(seed=42)

In [114]:
messages = []

In [115]:
produce_answer("Hi! Do you have any hobbies?", messages, 30, True, 
               model, enc, device, 
               insert_intro=False, wrap_type='name')

Model input:

Alice: Hi! Do you have any hobbies?
Bob:
Detected conditioning tokens:
[44484]
Input tokens:
[50256, 220, 44484, 25, 15902, 0, 2141, 345, 423, 597, 45578, 30, 198, 18861, 25]
Out Tokens:
tensor([[50256,   220, 44484,    25, 15902,     0,  2141,   345,   423,   597,
         45578,    30,   198, 18861,    25,  1312, 17666,   466,   881,   290,
          1312,   588,   284,  1561,   546,   616, 45578,   198, 18861,    25,
          1312,   711, 10047,   198, 18861,    25,  1312,   423,   645, 45578,
           198]], device='cuda:0')
All sampled:

 i dont do much and i like to talk about my hobbies
Bob: i play guitar
Bob: i have no hobbies




Partially processed: 
 i dont do much and i like to talk about my hobbies ¿ i play guitar ¿ i have no hobbies |
['I dont do much and i like to talk about my hobbies', 'I play guitar', 'I have no hobbies']
['I dont do much and i like to talk about my hobbies', 'I play guitar', 'I have no hobbies']


'I dont do much and i like to talk about my hobbies. I play guitar. I have no hobbies.'

In [116]:
produce_answer("Me too. Do you like play computer games?", messages, 30, True, 
               model, enc, device, 
               insert_intro=False, wrap_type='name')

Model input:

Alice: Hi! Do you have any hobbies?
Bob: I dont do much and i like to talk about my hobbies. I play guitar. I have no hobbies.
Alice: Me too. Do you like play computer games?
Bob:
Detected conditioning tokens:
[44484]
Input tokens:
[50256, 220, 44484, 25, 15902, 0, 2141, 345, 423, 597, 45578, 30, 198, 18861, 25, 314, 17666, 466, 881, 290, 1312, 588, 284, 1561, 546, 616, 45578, 13, 314, 711, 10047, 13, 314, 423, 645, 45578, 13, 198, 44484, 25, 2185, 1165, 13, 2141, 345, 588, 711, 3644, 1830, 30, 198, 18861, 25]
Out Tokens:
tensor([[50256,   220, 44484,    25, 15902,     0,  2141,   345,   423,   597,
         45578,    30,   198, 18861,    25,   314, 17666,   466,   881,   290,
          1312,   588,   284,  1561,   546,   616, 45578,    13,   314,   711,
         10047,    13,   314,   423,   645, 45578,    13,   198, 44484,    25,
          2185,  1165,    13,  2141,   345,   588,   711,  3644,  1830,    30,
           198, 18861,    25,  9425,     0,   198]], device='cu

'Yeah!'

In [89]:
messages

['Hi! Do you have any hobbies?',
 'I dont do much and i like to talk about my hobbies. I play guitar. I have no hobbies.',
 'Me too. Do you like play computer games?',
 'Yeah!']

In [16]:
enc.encode(":")

[25]