
Remove the beam search generator #38

Closed
geroale opened this issue Mar 1, 2018 · 2 comments
geroale commented Mar 1, 2018

Hi @pender, first of all, thank you so much for your repo. I found it really helpful for learning about RNNs: the code is clear enough, the Reddit-based model is well trained, and everything is cool.

I have only one question: is it possible to generate the answer instantly from the model, without the character-by-character generation?
I think I've understood that this effect comes from the beam search generator, but I can't fully grasp how it works.
It would be great if the model could write the answer in one instant, or if there were a parameter letting the user decide which type of generation to use.

Thanks for your work and everything.
Alessandro.


pender commented Mar 2, 2018

Hi Alessandro, generation for this type of model has to occur one character at a time, because each character is chosen based on all of the previous characters. You can disable beam search by typing "--beam_width 1" during chat, which will make generation faster (and much worse), but it will still choose one character at a time.
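The constraint pender describes is inherent to autoregressive models: each character's distribution is conditioned on everything generated before it, so there is no way to emit the whole answer in a single step. A minimal sketch of that loop, using a toy stand-in for the model (the real repo uses an RNN; `toy_model` here is purely illustrative):

```python
import random

def sample_autoregressive(next_char_probs, seed, max_length, stop_char="\n"):
    """Generate text one character at a time: each new character is
    sampled from a distribution conditioned on the full history so far."""
    text = seed
    for _ in range(max_length):
        probs = next_char_probs(text)  # conditioned on everything generated
        chars, weights = zip(*probs.items())
        ch = random.choices(chars, weights=weights)[0]
        if ch == stop_char:
            break
        text += ch
    return text

# Toy stand-in "model": deterministically predicts the next letter of "hello".
def toy_model(history):
    target = "hello"
    return {target[len(history) % len(target)]: 1.0}

print(sample_autoregressive(toy_model, "", 5))  # -> hello
```

Because each call to `next_char_probs` needs the updated history, the steps cannot be parallelized away; beam search only changes *which* character is picked at each step, not the one-at-a-time structure.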

@pender pender closed this as completed Mar 2, 2018
@shubhank008
@geroale just wanted to share something similar I wanted and how I did it. It does not return the output instantly, but it does return the whole sentence at once instead of streaming it character by character. Useful for creating APIs or integrating the output into your own code.

# Requires the helpers from chatbot.py in this repo (initial_state_with_relevance_masking,
# process_user_command, forward_text, sanitize_text, beam_search_generator,
# forward_with_mask, possibly_escaped_char) plus these stdlib imports:
import copy
import time

def chatbot(net, sess, chars, vocab, max_length, beam_width, relevance, temperature, topn, input_text):
    states = initial_state_with_relevance_masking(net, sess, relevance)
    while True:
        user_input = input_text  # fixed argument instead of input('\n> ')
        start_time = time.time()
        user_command_entered, reset, states, relevance, temperature, topn, beam_width = process_user_command(
            user_input, states, relevance, temperature, topn, beam_width)
        if reset: states = initial_state_with_relevance_masking(net, sess, relevance)
        if not user_command_entered:
            beam_width = 1  # disable beam search for speed
            states = forward_text(net, sess, states, relevance, vocab, sanitize_text(vocab, "> " + user_input + "\n>"))
            computer_response_generator = beam_search_generator(sess=sess, net=net,
                initial_state=copy.deepcopy(states), initial_sample=vocab[' '],
                early_term_token=vocab['\n'], beam_width=beam_width, forward_model_fn=forward_with_mask,
                forward_args={'relevance': relevance, 'mask_reset_token': vocab['\n'], 'forbidden_token': vocab['>'],
                              'temperature': temperature, 'topn': topn})
            out_chars = []
            out = []
            for i, char_token in enumerate(computer_response_generator):
                out_chars.append(chars[char_token])
                out.append(possibly_escaped_char(out_chars))
                #print(possibly_escaped_char(out_chars), end='', flush=True)
                states = forward_text(net, sess, states, relevance, vocab, chars[char_token])
                if i >= max_length: break
            #print("".join(out))
            #print("--- %s seconds ---" % (time.time() - start_time))
            return "".join(out)  # buffered response (unreachable dead code after this line removed)

You just store the characters in a list and, once the loop ends, join them into a single string.
PS: my modification returns the output; if you want to just print it, comment out the return line and uncomment the print.
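The core pattern here, buffering a character generator into a string instead of printing each character as it arrives, can be isolated like this (`char_generator` is just a placeholder for the repo's `beam_search_generator`):

```python
def char_generator(sentence):
    # Placeholder for the model's character-by-character generator.
    for ch in sentence:
        yield ch

def collect_response(gen, max_length):
    # Buffer characters instead of streaming them, then join once at the end,
    # mirroring the loop in the modified chatbot() above.
    out = []
    for i, ch in enumerate(gen):
        out.append(ch)
        if i >= max_length:
            break
    return "".join(out)

print(collect_response(char_generator("hello world"), max_length=100))
```

The model still produces one character at a time; the only change is that the caller sees a single string, which is what makes the function usable behind an API.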
