Skip to content

Commit

Permalink
Merge TracingCompliantTransformer and regular Transformer, fix NAT tests
Browse files Browse the repository at this point in the history
Summary: Pull Request resolved: fairinternal/fairseq-py#899

Differential Revision: D18373060

Pulled By: myleott

fbshipit-source-id: bb5510ec15799a0a10a7c0669e76d8200e1ba479
  • Loading branch information
myleott authored and facebook-github-bot committed Nov 13, 2019
1 parent 2a9b4ec commit 27568a7
Show file tree
Hide file tree
Showing 14 changed files with 551 additions and 1,191 deletions.
2 changes: 1 addition & 1 deletion fairseq/criterions/nat_loss.py
Expand Up @@ -48,7 +48,7 @@ def mean_ds(x: Tensor, dim=None) -> Tensor:
if masks is not None:
outputs, targets = outputs[masks], targets[masks]

if not masks.any():
if masks is not None and not masks.any():
nll_loss = torch.tensor(0)
loss = nll_loss
else:
Expand Down
56 changes: 38 additions & 18 deletions fairseq/iterative_refinement_generator.py
Expand Up @@ -3,11 +3,20 @@
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

from collections import namedtuple

import torch

from fairseq import utils
from fairseq.models.levenshtein_transformer import LevenshteinTransformerModel
from fairseq.models.model_utils import script_skip_tensor_list, skip_tensors as _skip
from fairseq.models.nonautoregressive_ensembles import EnsembleLevT


# Immutable record of the decoder state that is threaded through the
# iterative-refinement loop; `_replace` is used to advance it each step.
DecoderOut = namedtuple(
    'IterativeRefinementDecoderOut',
    (
        'output_tokens',   # current token hypotheses for the batch
        'output_scores',   # per-token scores matching output_tokens
        'attn',            # attention weights from the last pass (may be None)
        'step',            # index of the current refinement iteration
        'max_step',        # iteration budget for the refinement loop
    ),
)


class IterativeRefinementGenerator(object):
Expand Down Expand Up @@ -88,6 +97,8 @@ def generate_batched_itr(

@torch.no_grad()
def generate(self, models, sample, prefix_tokens=None):
from fairseq.models.levenshtein_transformer import LevenshteinTransformerModel
from fairseq.models.nonautoregressive_ensembles import EnsembleLevT

if len(models) == 1:
# Keep this for other NAT models for which we have yet to implement ensemble wrappers. Later delete this.
Expand All @@ -110,7 +121,7 @@ def generate(self, models, sample, prefix_tokens=None):

# initialize buffers (very model specific, with length prediction or not)
prev_decoder_out = model.initialize_output_tokens(encoder_out, src_tokens)
prev_output_tokens = prev_decoder_out[0].clone()
prev_output_tokens = prev_decoder_out.output_tokens.clone()

finalized = [[] for _ in range(bsz)]

Expand Down Expand Up @@ -150,8 +161,10 @@ def finalized_hypos(step, prev_out_token, prev_out_score, prev_out_attn):
"max_ratio": self.max_ratio,
"decoding_format": self.decoding_format,
}
prev_decoder_out[3] = step
prev_decoder_out[4] = self.max_iter + 1
prev_decoder_out = prev_decoder_out._replace(
step=step,
max_step=self.max_iter + 1,
)

decoder_out = model.forward_decoder(
prev_decoder_out, encoder_out, **decoder_options
Expand All @@ -160,24 +173,26 @@ def finalized_hypos(step, prev_out_token, prev_out_score, prev_out_attn):
if self.adaptive:
# terminate if there is a loop
terminated, out_tokens, out_scores, out_attn = is_a_loop(
prev_output_tokens, decoder_out[0], decoder_out[1], decoder_out[2]
prev_output_tokens, decoder_out.output_tokens, decoder_out.output_scores, decoder_out.attn
)
decoder_out = decoder_out._replace(
output_tokens=out_tokens,
output_scores=out_scores,
attn=out_attn,
)
decoder_out[0] = out_tokens
decoder_out[1] = out_scores
decoder_out[2] = out_attn

else:
terminated = decoder_out[0].new_zeros(decoder_out[0].size(0)).bool()
terminated = decoder_out.output_tokens.new_zeros(decoder_out.output_tokens.size(0)).bool()

if step == self.max_iter: # reach last iteration, terminate
terminated.fill_(1)

# collect finalized sentences
finalized_idxs = sent_idxs[terminated]
finalized_tokens = decoder_out[0][terminated]
finalized_scores = decoder_out[1][terminated]
finalized_tokens = decoder_out.output_tokens[terminated]
finalized_scores = decoder_out.output_scores[terminated]
finalized_attn = (
None if decoder_out[2] is None else decoder_out[2][terminated]
None if decoder_out.attn is None else decoder_out.attn[terminated]
)

for i in range(finalized_idxs.size(0)):
Expand All @@ -194,10 +209,15 @@ def finalized_hypos(step, prev_out_token, prev_out_score, prev_out_attn):
break

# for next step
prev_decoder_out = _skip(decoder_out, ~terminated)
encoder_out = script_skip_tensor_list(encoder_out, ~terminated)
sent_idxs = _skip(sent_idxs, ~terminated)
not_terminated = ~terminated
prev_decoder_out = decoder_out._replace(
output_tokens=decoder_out.output_tokens[not_terminated],
output_scores=decoder_out.output_scores[not_terminated],
attn=decoder_out.attn[not_terminated] if decoder_out.attn is not None else None,
)
encoder_out = model.encoder.reorder_encoder_out(encoder_out, not_terminated.nonzero().squeeze())
sent_idxs = sent_idxs[not_terminated]

prev_output_tokens = prev_decoder_out[0].clone()
prev_output_tokens = prev_decoder_out.output_tokens.clone()

return finalized
16 changes: 10 additions & 6 deletions fairseq/models/cmlm_transformer.py
Expand Up @@ -10,9 +10,9 @@
arXiv preprint arXiv:1904.09324 (2019).
"""

from fairseq.utils import new_arange
from fairseq.models import register_model, register_model_architecture
from fairseq.models.nonautoregressive_transformer import NATransformerModel
from fairseq.utils import new_arange


def _skeptical_unmasking(output_scores, output_masks, p):
Expand Down Expand Up @@ -55,11 +55,11 @@ def forward(

def forward_decoder(self, decoder_out, encoder_out, decoding_format=None, **kwargs):

step = decoder_out["step"]
max_step = decoder_out["max_step"]
step = decoder_out.step
max_step = decoder_out.max_step

output_tokens = decoder_out["output_tokens"]
output_scores = decoder_out["output_scores"]
output_tokens = decoder_out.output_tokens
output_scores = decoder_out.output_scores

# execute the decoder
output_masks = output_tokens.eq(self.unk)
Expand All @@ -78,7 +78,11 @@ def forward_decoder(self, decoder_out, encoder_out, decoding_format=None, **kwar
output_tokens.masked_fill_(skeptical_mask, self.unk)
output_scores.masked_fill_(skeptical_mask, 0.0)

return {"output_tokens": output_tokens, "output_scores": output_scores}
return decoder_out._replace(
output_tokens=output_tokens,
output_scores=output_scores,
attn=None,
)


@register_model_architecture("cmlm_transformer", "cmlm_transformer")
Expand Down
19 changes: 12 additions & 7 deletions fairseq/models/insertion_transformer.py
Expand Up @@ -6,14 +6,15 @@
import numpy as np
import torch
import torch.nn.functional as F
from fairseq.utils import new_arange

from fairseq.models import register_model, register_model_architecture
from fairseq.models.levenshtein_transformer import (
LevenshteinTransformerDecoder,
LevenshteinTransformerModel,
)
from fairseq.models.transformer import Linear, TransformerModel
from fairseq.modules.transformer_sentence_encoder import init_bert_params
from fairseq.utils import new_arange


class NegativeDistanceScore(object):
Expand Down Expand Up @@ -116,8 +117,8 @@ def _apply_ins_words(in_tokens, in_scores, word_ins_pred, word_ins_scores, paddi

@register_model("insertion_transformer")
class InsertionTransformerModel(LevenshteinTransformerModel):
def __init__(self, encoder, decoder):
super().__init__(encoder, decoder)
def __init__(self, args, encoder, decoder):
super().__init__(args, encoder, decoder)

@staticmethod
def add_args(parser):
Expand Down Expand Up @@ -169,8 +170,8 @@ def forward_decoder(
self, decoder_out, encoder_out, eos_penalty=0.0, max_ratio=None, **kwargs
):

output_tokens = decoder_out["output_tokens"]
output_scores = decoder_out["output_scores"]
output_tokens = decoder_out.output_tokens
output_scores = decoder_out.output_scores
# TODO: decoding for InsertionTransformer
word_ins_out = self.decoder.forward_word_ins(
output_tokens, encoder_out=encoder_out
Expand All @@ -187,7 +188,11 @@ def forward_decoder(
cut_off = output_tokens.ne(self.pad).sum(1).max()
output_tokens = output_tokens[:, :cut_off]
output_scores = output_scores[:, :cut_off]
return {"output_tokens": output_tokens, "output_scores": output_scores, "attn": None}
return decoder_out._replace(
output_tokens=output_tokens,
output_scores=output_scores,
attn=None,
)


class InsertionTransformerDecoder(LevenshteinTransformerDecoder):
Expand All @@ -206,7 +211,7 @@ def __init__(self, args, dictionary, embed_tokens, no_encoder_attn=False):
self.label_tau = getattr(args, "label_tau", None)

def forward_word_ins(self, prev_output_tokens, encoder_out=None):
features, _ = self.extract_features(prev_output_tokens, encoder_out=encoder_out)
features = self.extract_features(prev_output_tokens, encoder_out=encoder_out)[0]
features = self.pool_out(
torch.cat([features[:, :-1, :], features[:, 1:, :]], 2)
)
Expand Down
1 change: 1 addition & 0 deletions fairseq/models/iterative_nonautoregressive_transformer.py
Expand Up @@ -4,6 +4,7 @@
# LICENSE file in the root directory of this source tree.

import torch

from fairseq.models import register_model, register_model_architecture
from fairseq.models.nonautoregressive_transformer import NATransformerModel

Expand Down

0 comments on commit 27568a7

Please sign in to comment.