Full-context alignment is broken in transformer_align model #2673

Closed

senarvi opened this issue Sep 29, 2020 · 1 comment
senarvi commented Sep 29, 2020

🐛 Bug

The transformer_align model, which implements the "Jointly Learning to Align and Translate" paper, supports full-context alignment, meaning that the auto-regressive mask is not applied to decoder self-attention. This feature is broken in the current master. When the --full-context-alignment flag is given to the model, it produces the error message

TypeError: forward() got an unexpected keyword argument 'full_context_alignment'

in the forward_decoder() method of TransformerAlignModel. As far as I can see, this happens because the forward() method of TransformerDecoder no longer has the **extra_args parameter that earlier versions of the code (from the time the transformer_align model was merged) used to pass extra keyword arguments through to the extract_features() method.
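To illustrate the mismatch, here is a minimal standalone sketch of the failure pattern (not actual fairseq code; the class and variable names are only illustrative):

```python
# Minimal illustration of the failure mode: a TransformerAlignModel-style
# caller forwards the alignment options as keyword arguments, but the
# decoder's forward() signature no longer accepts them.
class Decoder:
    # Shape of TransformerDecoder.forward() on current master: there is no
    # 'full_context_alignment' parameter and no **extra_args catch-all.
    def forward(self, prev_output_tokens, encoder_out=None,
                alignment_layer=None, alignment_heads=None):
        return prev_output_tokens

decoder = Decoder()
extra_args = {
    "alignment_layer": 1,
    "alignment_heads": 1,
    "full_context_alignment": True,  # added when --full-context-alignment is set
}
decoder.forward([2, 3, 4], **extra_args)
# TypeError: forward() got an unexpected keyword argument 'full_context_alignment'
```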

Ideally the test_alignment() unit test would also be updated to test this feature.
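A minimal sketch of what such a test could look like, modeled on the existing test_alignment() in tests/test_binaries.py (the helper functions are the ones that test already uses; the exact flag list is my assumption):

```python
# Hypothetical addition to tests/test_binaries.py, inside the same test class
# as test_alignment(); it reuses the existing helpers and only adds the
# --full-context-alignment flag.
def test_alignment_full_context(self):
    with contextlib.redirect_stdout(StringIO()):
        with tempfile.TemporaryDirectory("test_alignment_full_context") as data_dir:
            create_dummy_data(data_dir, alignment=True)
            preprocess_translation_data(data_dir, ["--align-suffix", "align"])
            train_translation_model(
                data_dir,
                "transformer_align",
                [
                    "--encoder-layers", "2",
                    "--decoder-layers", "2",
                    "--encoder-embed-dim", "8",
                    "--decoder-embed-dim", "8",
                    "--load-alignments",
                    "--alignment-layer", "1",
                    "--criterion", "label_smoothed_cross_entropy_with_alignment",
                    "--full-context-alignment",
                ],
                run_validation=True,
            )
```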

To Reproduce

Follow the instructions for training the transformer_align model, but additionally pass the --full-context-alignment flag to fairseq-train. Training fails with the following error message and stack trace:

Traceback (most recent call last):
  File ".../bin/fairseq-train", line 11, in <module>
    load_entry_point('fairseq', 'console_scripts', 'fairseq-train')()
  File ".../fairseq_cli/train.py", line 351, in cli_main
    distributed_utils.call_main(args, main)
  File ".../fairseq/distributed_utils.py", line 254, in call_main
    main(args, **kwargs)
  File ".../fairseq_cli/train.py", line 125, in main
    valid_losses, should_stop = train(args, trainer, task, epoch_itr)
  File "/usr/lib64/python3.6/contextlib.py", line 52, in inner
    return func(*args, **kwds)
  File ".../fairseq_cli/train.py", line 207, in train
    log_output = trainer.train_step(samples)
  File "/usr/lib64/python3.6/contextlib.py", line 52, in inner
    return func(*args, **kwds)
  File ".../fairseq/trainer.py", line 479, in train_step
    ignore_grad=is_dummy_batch,
  File ".../fairseq/tasks/fairseq_task.py", line 408, in train_step
    loss, sample_size, logging_output = criterion(model, sample)
  File ".../lib64/python3.6/site-packages/torch/nn/modules/module.py", line 722, in _call_impl
    result = self.forward(*input, **kwargs)
  File ".../fairseq/criterions/label_smoothed_cross_entropy_with_alignment.py", line 36, in forward
    net_output = model(**sample['net_input'])
  File ".../lib64/python3.6/site-packages/torch/nn/modules/module.py", line 722, in _call_impl
    result = self.forward(*input, **kwargs)
  File ".../fairseq/models/transformer_align.py", line 51, in forward
    return self.forward_decoder(prev_output_tokens, encoder_out)
  File ".../fairseq/models/transformer_align.py", line 75, in forward_decoder
    **extra_args,
  File ".../lib64/python3.6/site-packages/torch/nn/modules/module.py", line 722, in _call_impl
    result = self.forward(*input, **kwargs)
TypeError: forward() got an unexpected keyword argument 'full_context_alignment'

Expected behavior

I expect fairseq-train to start training the model, as it does when I don't give the --full-context-alignment flag, but with the alignment supervision conditioned on the full target context.
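For context, the flag should only affect the decoder's self-attention mask. A runnable sketch of the intended logic in TransformerDecoder.extract_features() (tensor sizes are made up):

```python
import torch

def self_attn_mask(x, full_context_alignment: bool):
    # Sketch of the intended behavior: the usual auto-regressive (future)
    # mask is built unless full-context alignment is requested.
    if not full_context_alignment:
        seq_len = x.size(0)
        # upper-triangular -inf mask blocks attention to future positions
        return torch.triu(x.new_full((seq_len, seq_len), float("-inf")), 1)
    return None  # full context: every position may attend to every position

x = torch.zeros(4, 1, 8)  # (target_len, batch, embed_dim), hypothetical sizes
print(self_attn_mask(x, full_context_alignment=False))  # causal mask
print(self_attn_mask(x, full_context_alignment=True))   # None, no masking
```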

Environment

  • fairseq Version (e.g., 1.0 or master): master
  • PyTorch Version (e.g., 1.0): 1.6.0
  • OS (e.g., Linux): Linux
  • How you installed fairseq (pip, source): pip
  • Python version: 3.6.8
senarvi commented Sep 29, 2020

I assume this is how to fix it: #2675
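For reference, my understanding of the change in #2675: re-add the full_context_alignment keyword argument to TransformerDecoder.forward() and pass it through to extract_features(), roughly like this (a sketch of the shape of the fix, not the exact diff):

```python
# Sketch of the fixed TransformerDecoder.forward(): the alignment options
# are explicit parameters again and are handed through to extract_features().
def forward(
    self,
    prev_output_tokens,
    encoder_out=None,
    incremental_state=None,
    features_only=False,
    full_context_alignment=False,  # re-added
    alignment_layer=None,
    alignment_heads=None,
    src_lengths=None,
    return_all_hiddens=False,
):
    x, extra = self.extract_features(
        prev_output_tokens,
        encoder_out=encoder_out,
        incremental_state=incremental_state,
        full_context_alignment=full_context_alignment,  # passed through
        alignment_layer=alignment_layer,
        alignment_heads=alignment_heads,
    )
    if not features_only:
        x = self.output_layer(x)
    return x, extra
```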

jinyiyang-jhu pushed a commit to jinyiyang-jhu/fairseq-jyang that referenced this issue Feb 26, 2021
Summary:
Fixes facebookresearch/fairseq#2673.

# Before submitting

- [x] Was this discussed/approved via a Github issue? (no need for typos, doc improvements)
- [x] Did you read the [contributor guideline](https://github.com/pytorch/fairseq/blob/master/CONTRIBUTING.md)?
- [ ] Did you make sure to update the docs?
- [ ] Did you write any new necessary tests?

## What does this PR do?
Fixes facebookresearch/fairseq#2673 (issue).

## PR review
Anyone in the community is free to review the PR once the tests have passed.
If we didn't discuss your PR in Github issues there's a high chance it will not be merged.

## Did you have fun?
Make sure you had fun coding 🙃

Pull Request resolved: facebookresearch/fairseq#2675

Reviewed By: ngoyal2707

Differential Revision: D24001793

Pulled By: myleott

fbshipit-source-id: 6b4e9270e5f5a31ba1b65ae2ae717019108af913