
fixing prefix_allowed_tokens_fn #3276

Conversation

nicola-decao
Contributor

Before submitting

  • Was this discussed/approved via a GitHub issue? (no need for typos, doc improvements)
  • Did you read the contributor guideline?
  • Did you make sure to update the docs?
  • Did you write any new necessary tests?

What does this PR do?

Fixes the use of `prefix_allowed_tokens_fn` in generation. It was working in `fairseq==0.9.0` (see https://github.com/facebookresearch/GENRE) but is broken in the current version.

PR review

Anyone in the community is free to review the PR once the tests have passed.

Did you have fun?

Make sure you had fun coding 🙃

@facebook-github-bot
Contributor

@myleott has imported this pull request. If you are a Facebook employee, you can view this diff on Phabricator.

@erip
Contributor

erip commented Mar 1, 2021

I'm interested in seeing this land, but I was curious if it would be possible to also include a bit of documentation about this prefix_allowed_tokens_fn callable. I can't seem to find anything which explains what it's supposed to do or what shape it's supposed to be (f: (Tensor, list[int]) -> int?)

@nicola-decao
Contributor Author

> I'm interested in seeing this land, but I was curious if it would be possible to also include a bit of documentation about this prefix_allowed_tokens_fn callable. I can't seem to find anything which explains what it's supposed to do or what shape it's supposed to be (f: (Tensor, list[int]) -> int?)

@erip you are right. Where do you think is the best place to document the function's signature? Here? https://github.com/pytorch/fairseq/blob/e5e8b3fee1e57a7abf35ad1a3ff223a2b7190c65/fairseq/search.py#L148

@erip
Contributor

erip commented Mar 1, 2021

@nicola-decao I think that makes good sense. There doesn't seem to be any documentation on the other search strategies, but this one is somewhat less straightforward since it's got the callback. Unless @myleott has other thoughts, I think throwing a docstring beneath the ctor would be great.

@myleott
Contributor

myleott commented Mar 1, 2021

@nicola-decao if you can share a docstring here, I can update the imported version before merging

@nicola-decao
Contributor Author

> @nicola-decao if you can share a docstring here, I can update the imported version before merging

@myleott Here you go:

`prefix_allowed_tokens_fn` (`Callable[[int, torch.Tensor], List[int]]`): If provided, this function constrains the beam search to allowed tokens only at each step. If not provided, no constraint is applied. The function takes two arguments: the batch ID `batch_id: int` and a one-dimensional tensor of token IDs `input_ids: torch.Tensor`. It has to return a `List[int]` with the allowed tokens for the next generation step, conditioned on the previously generated tokens `input_ids` and the batch ID `batch_id`. This argument is useful for constrained generation conditioned on a prefix, as described in Autoregressive Entity Retrieval (https://arxiv.org/abs/2010.00904) and implemented in https://github.com/facebookresearch/GENRE.
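For concreteness, here is a minimal sketch of a function matching that contract; it is not part of this PR, and the vocabulary size and forced prefix are made up for illustration. It allows only the next token of a fixed prefix until the prefix is consumed, then leaves generation unconstrained:

```python
from typing import List

import torch

VOCAB_SIZE = 32000             # hypothetical vocabulary size
FORCED_PREFIX = [42, 7, 1337]  # hypothetical token IDs the output must start with


def prefix_allowed_tokens_fn(batch_id: int, input_ids: torch.Tensor) -> List[int]:
    """Allow only the next forced-prefix token until the prefix is consumed,
    then allow the full vocabulary."""
    step = input_ids.shape[0]  # number of tokens generated so far for this hypothesis
    if step < len(FORCED_PREFIX):
        return [FORCED_PREFIX[step]]  # exactly one allowed continuation
    return list(range(VOCAB_SIZE))    # no constraint past the prefix
```

Note that, depending on the model, `input_ids` may begin with a BOS/EOS marker, so the step offset may need adjusting in practice. GENRE builds this function from a prefix trie of valid entity names rather than a single fixed prefix.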

@nicola-decao
Contributor Author

@myleott Any news on this? Is there something I should do?

@facebook-github-bot
Contributor

@myleott has imported this pull request. If you are a Facebook employee, you can view this diff on Phabricator.

@nicola-decao
Contributor Author

nicola-decao commented Apr 5, 2021

@myleott Any news on this? I have a Facebook AI project, https://github.com/facebookresearch/GENRE, that depends on this bug fix (for now I link people to my fork with the fix, which is not ideal).

@nicola-decao
Contributor Author

@myleott @sshleifer can we please proceed with the merge here? It is really a minor change.

There is a Facebook AI project, https://github.com/facebookresearch/GENRE, that depends on this bug fix (for now I link people to my fork with the fix, which is not ideal, and they may have trouble installing it).

@erip
Contributor

erip commented May 26, 2021

also cc @alexeib

facebook-github-bot pushed a commit that referenced this pull request May 27, 2021
Summary:
# Before submitting

- [ ] Was this discussed/approved via a GitHub issue? (no need for typos, doc improvements)
- [x] Did you read the [contributor guideline](https://github.com/pytorch/fairseq/blob/master/CONTRIBUTING.md)?
- [x] Did you make sure to update the docs?
- [x] Did you write any new necessary tests?

## What does this PR do?
Fixes the use of `prefix_allowed_tokens_fn` in generation. It was working in `fairseq==0.9.0` (see https://github.com/facebookresearch/GENRE) but is broken in the current version.

## PR review
Anyone in the community is free to review the PR once the tests have passed.

## Did you have fun?
Make sure you had fun coding 🙃

Pull Request resolved: #3276

Reviewed By: alexeib

Differential Revision: D26725494

Pulled By: myleott

fbshipit-source-id: ce3da725f36352687e5cb5d62a59b4c89ce0b0bc