
Add top-p and top-k sampling to GenerationUtils #2137

Merged

Conversation

yohann-benchetrit
Contributor

@yohann-benchetrit yohann-benchetrit commented Apr 3, 2023

  1. Add the following token-decoding helpers to GenerationUtils.generate:
  • top-p
  • top-k
  • remove_invalid_values
  • temperature (see Issue #2138)
  2. Add corresponding unit tests and docstrings.
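For reference, the two restriction helpers named above can be sketched roughly as follows. This is an illustrative sketch with hypothetical names (top_k_restriction, top_p_restriction), not torchtext's actual implementation:

```python
import torch


def top_k_restriction(logits: torch.Tensor, k: int) -> torch.Tensor:
    """Keep the k highest-scoring tokens; set all others to -inf."""
    kth_value = torch.topk(logits, k, dim=-1).values[..., -1, None]
    return logits.masked_fill(logits < kth_value, float("-inf"))


def top_p_restriction(logits: torch.Tensor, p: float) -> torch.Tensor:
    """Keep the smallest set of tokens whose cumulative probability exceeds p."""
    sorted_logits, sorted_idx = torch.sort(logits, descending=True, dim=-1)
    cumulative = sorted_logits.softmax(dim=-1).cumsum(dim=-1)
    sorted_remove = cumulative > p
    # Shift the mask right so the first token crossing the threshold is kept,
    # and the most probable token always survives.
    sorted_remove[..., 1:] = sorted_remove[..., :-1].clone()
    sorted_remove[..., 0] = False
    # Map the mask from sorted order back to the original vocab positions.
    remove = sorted_remove.scatter(-1, sorted_idx, sorted_remove)
    return logits.masked_fill(remove, float("-inf"))
```

Masked positions get -inf so that a subsequent softmax assigns them zero probability.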

@yohann-benchetrit yohann-benchetrit force-pushed the ybenchetrit/add-top-pk-sampling branch 3 times, most recently from ac3439b to 37627f7 Compare April 3, 2023 10:11
@github-advanced-security

You have successfully added a new CodeQL configuration .github/workflows/codeql.yml:build. As part of the setup process, we have scanned this repository and found 3 existing alerts. Please check the repository Security tab to see all alerts.

@yohann-benchetrit yohann-benchetrit changed the title Add top-p and top-k sampling Add top-p and top-k sampling to GenerationUtils Apr 3, 2023
@yohann-benchetrit yohann-benchetrit force-pushed the ybenchetrit/add-top-pk-sampling branch 3 times, most recently from 2b904f2 to 8331b24 Compare April 3, 2023 10:48
@yohann-benchetrit yohann-benchetrit marked this pull request as ready for review April 3, 2023 11:43
@joecummings
Contributor

High level, I notice you implemented the sampling as part of the greedy_search function. Technically, with sampling this name is incorrect. I see a couple of ways of going about this.

  1. Rename the function to a more generic sample method and remove the do_sample parameter to avoid confusion. Then, if top_k or top_p was specified, a true sampling would be implemented. Otherwise, it would essentially be a greedy search. (in this case, as well, top_k should probably default to 1 instead of 0.) This would be the best way to reuse code IMO.

  2. Follow HF and have separate methods for sampling and greedy. From an adoption perspective, this might be beneficial as users of torchtext tend to also have experience with HF; however, there would certainly be some duplicated code.

Thoughts? @yohann-benchetrit
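Option 1 above could be sketched roughly like this (a hypothetical sample_next_token helper, not the merged API): with top_k=1 the multinomial draw degenerates to the argmax, so greedy search falls out as a special case.

```python
import torch


def sample_next_token(logits: torch.Tensor, top_k: int = 1) -> torch.Tensor:
    """With top_k=1 this is greedy decoding; larger top_k samples from the
    k most likely tokens. Illustrative sketch of option 1, not torchtext's API."""
    values, indices = torch.topk(logits, top_k, dim=-1)
    # Sample within the restricted set, then map back to vocab indices.
    choice = torch.multinomial(values.softmax(dim=-1), num_samples=1)
    return indices.gather(-1, choice)
```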

log_probs = F.log_softmax(decoder_output[:, -1], dim=-1)

if do_sample:
    probs = log_probs.softmax(dim=-1)
Contributor


Taking double softmax here? Probs are already softmax'd on L91

Contributor Author


Thanks for spotting this! It should indeed be a .exp to recover the original probabilities.
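For illustration, a minimal sketch of the fix being discussed, with a dummy decoder_output tensor standing in for the model's raw logits:

```python
import torch
import torch.nn.functional as F

decoder_output = torch.randn(2, 5, 10)  # dummy (batch, seq, vocab) logits

log_probs = F.log_softmax(decoder_output[:, -1], dim=-1)

# Buggy: applying softmax to log-probabilities re-normalizes them and
# yields a different (flatter) distribution, not the original one.
double_softmaxed = log_probs.softmax(dim=-1)

# Fixed: exponentiating log-probabilities recovers the softmax'd probs.
probs = log_probs.exp()
```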

@joecummings
Contributor

Will want to add temperature, but can do that in a follow-up PR. Tracking issue: #2138
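A temperature helper along the lines of this follow-up (hypothetical apply_temperature, sketched here rather than taken from the eventual PR) would just rescale the logits before softmax:

```python
import torch


def apply_temperature(logits: torch.Tensor, temperature: float) -> torch.Tensor:
    """Divide logits by temperature: values < 1 sharpen the distribution,
    values > 1 flatten it. Illustrative sketch, not torchtext's API."""
    if temperature <= 0:
        raise ValueError("temperature must be positive")
    return logits / temperature
```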

@Nayef211
Contributor

Nayef211 commented Apr 3, 2023

High level, I notice you implemented the sampling as part of the greedy_search function. Technically, with sampling this name is incorrect. I see a couple of ways of going about this.

  1. Rename the function to a more generic sample method and remove the do_sample parameter to avoid confusion. Then, if top_k or top_p was specified, a true sampling would be implemented. Otherwise, it would essentially be a greedy search. (in this case, as well, top_k should probably default to 1 instead of 0.) This would be the best way to reuse code IMO.
  2. Follow HF and have separate methods for sampling and greedy. From an adoption perspective, this might be beneficial as users of torchtext tend to also have experience with HF; however, there would certainly be some duplicated code.

Thoughts? @yohann-benchetrit

I agree with @joecummings here in that by allowing users to provide top_k or top_p, this method is no longer implementing a simple greedy search. In terms of which option to go with, I think that for reusability, option 1 makes sense. However, I also think it's more intuitive for users to be able to call a greedy_search method without having to worry about the top_k or top_p params or spend time understanding what sample means. Most NLP users will have heard of beam_search or greedy_search, but sample might not be as intuitive.

An alternative proposal here could be to keep the greedy_search method as is and then add the top_k and top_p params only to the beam_search method. This way if users want to do greedy search with those 2 params set, they can just set the beam_size to 1. For this approach, the greedy_search method will be kept simple and intuitive and all the configurability would come from the beam_search method.

@joecummings
Contributor

I agree with @joecummings here in that by allowing users to provide top_k or top_p, this method is no longer implementing a simple greedy search. In terms of which option to go with, I think that for reusability, option 1 makes sense. However, I also think it's more intuitive for users to be able to call a greedy_search method without having to worry about the top_k or top_p params or spend time understanding what sample means. Most NLP users will have heard of beam_search or greedy_search, but sample might not be as intuitive.

An alternative proposal here could be to keep the greedy_search method as is and then add the top_k and top_p params only to the beam_search method. This way if users want to do greedy search with those 2 params set, they can just set the beam_size to 1. For this approach, the greedy_search method will be kept simple and intuitive and all the configurability would come from the beam_search method.

@Nayef211 We abstract away the internal methods, so high-level users should never have to call greedy_search or beam_search directly. I see your point though; however, I don't think we should leave complexity to one function or the other. If we really want this to be completely composable and apply to all cases of sampling (which can happen in beam search, too), we may want to combine these methods into one. This seems beyond the scope of this PR though as beam_search is not implemented on main right now.

@yohann-benchetrit
Contributor Author

yohann-benchetrit commented Apr 3, 2023

Thoughts? @yohann-benchetrit

Agreed, and thanks @joecummings and @Nayef211 for your comments!

So with this additional information my understanding is:

  • generate is the "maximum generality" decoding-interface
  • greedy_search and beam_search are (or will be) implemented on their own as public subroutines of generate but they are not necessarily intended to be directly called by the end-user.

My thoughts on this:

  1. As @Nayef211 suggested, and closer to your proposition 2 (HF), I would tend to leave greedy_search sampling-free and delegate 'sampling with a single beam' (i.e. what I implemented here) to a beam_sample method that will build on beam_search once it is in main.
    However, since beam_search is not currently in main, does this mean that this PR should be restricted solely to adding the _get_top_{k,p}_restriction methods (with corresponding tests)?

  2. As a side note, although HF also recommends using generate exclusively, I like the idea of having fundamentals such as greedy_search and beam_search at the top of the API, for "no-extra-thought" usage.

@Nayef211
Contributor

Nayef211 commented Apr 5, 2023

  1. As @Nayef211 suggested, and closer to your proposition 2 (HF), I would tend to leave greedy_search sampling-free and delegate 'sampling with a single beam' (i.e. what I implemented here) to a beam_sample method that will build on beam_search once it is in main.
    However, since beam_search is not currently in main, does this mean that this PR should be restricted solely to adding the _get_top_{k,p}_restriction methods (with corresponding tests)?

I like this idea more. Unless there are customers asking for sampling to be implemented in greedy_search, I would hold off on adding these params to the method and instead just add the helper methods like you mentioned. That said, if customers are asking for this, I don't see any harm in adding the functionality to greedy_search and removing it later down the line. Ultimately, I'll leave the decision to @joecummings.

@yohann-benchetrit
Contributor Author

@joecummings @Nayef211

Thanks again for your comments. I addressed them as follows:

  • I kept only the helper methods _get_top_{k,p}_restriction and deleted their usage in decoding (except for unit tests).
  • I added an _apply_temperature method (and a unit test), as Joe suggested, addressing issue #2138.

@yohann-benchetrit yohann-benchetrit force-pushed the ybenchetrit/add-top-pk-sampling branch 2 times, most recently from 56434b7 to 8e1e907 Compare April 7, 2023 19:58
Address PR comments and add Temperature
@joecummings joecummings self-requested a review April 10, 2023 22:38
Contributor

@joecummings joecummings left a comment


Looks good. Thanks @yohann-benchetrit!

@joecummings joecummings merged commit dffe2cb into pytorch:main Apr 10, 2023
37 of 44 checks passed