diff --git a/pytorch_translate/ensemble_export.py b/pytorch_translate/ensemble_export.py
index 7d691131..9af75354 100644
--- a/pytorch_translate/ensemble_export.py
+++ b/pytorch_translate/ensemble_export.py
@@ -19,8 +19,10 @@
 from caffe2.python.onnx import backend as caffe2_backend
 from caffe2.python.predictor import predictor_exporter
 from fairseq import tasks, utils
+from fairseq.iterative_refinement_generator import DecoderOut
 from fairseq.models import ARCH_MODEL_REGISTRY
-from fairseq.models.model_utils import script_skip_tensor, script_skip_tensor_list
+from fairseq.models.model_utils import script_skip_tensor
+from fairseq.models.transformer import EncoderOut
 from pytorch_translate.beam_decode import BeamDecode
 from pytorch_translate.data import dictionary
 from pytorch_translate.research.knowledge_distillation import (
@@ -2194,16 +2196,16 @@ def generate(self, models, src_tokens, src_lengths, prefix_tokens=None):
 
         # initialize buffers (very model specific, with length prediction or not)
         prev_decoder_out = model.initialize_output_tokens(encoder_out, src_tokens)
-        prev_output_tokens = prev_decoder_out[0].clone()
+        prev_output_tokens = prev_decoder_out.output_tokens.clone()
 
         finalized_tokens_list = [torch.tensor(0) for _ in range(bsz)]
         finalized_scores_list = [torch.tensor(0) for _ in range(bsz)]
         finalized_attns_list = [torch.tensor(0) for _ in range(bsz)]
         finalized_alignments_list = [torch.tensor(0) for _ in range(bsz)]
-        prev_decoder_out[4] = self.max_iter + 1
+        prev_decoder_out = prev_decoder_out._replace(max_step=self.max_iter + 1)
 
         for step in range(self.max_iter + 1):
-            prev_decoder_out[3] = step
+            prev_decoder_out = prev_decoder_out._replace(step=step)
             decoder_out = model.forward_decoder(
                 prev_decoder_out,
                 encoder_out,
@@ -2211,25 +2213,26 @@ def generate(self, models, src_tokens, src_lengths, prefix_tokens=None):
                 max_ratio=self.max_ratio if step == 0 else None,
                 decoding_format=self.decoding_format,
             )
-
             terminated, output_tokens, output_scores, output_attn = is_a_loop(
                 self.pad,
                 prev_output_tokens,
-                decoder_out[0],
-                decoder_out[1],
-                decoder_out[2],
+                decoder_out.output_tokens,
+                decoder_out.output_scores,
+                decoder_out.attn,
+            )
+            decoder_out = decoder_out._replace(
+                output_tokens=output_tokens,
+                output_scores=output_scores,
+                attn=output_attn,
             )
-            decoder_out[0] = output_tokens
-            decoder_out[1] = output_scores
-            decoder_out[2] = output_attn
             terminated = last_step(step, self.max_iter, terminated)
 
             # collect finalized sentences
             finalized_idxs = sent_idxs[terminated]
-            finalized_tokens = decoder_out[0][terminated]
-            finalized_scores = decoder_out[1][terminated]
+            finalized_tokens = decoder_out.output_tokens[terminated]
+            finalized_scores = decoder_out.output_scores[terminated]
             finalized_attn = (
-                None if decoder_out[2] is None else decoder_out[2][terminated]
+                None if decoder_out.attn is None else decoder_out.attn[terminated]
             )
             finalized_tokens_list = finalize_hypos_loop_tokens(
                 finalized_tokens_list,
@@ -2256,17 +2259,30 @@ def generate(self, models, src_tokens, src_lengths, prefix_tokens=None):
             )
 
             # for next step
-            prev_decoder_out = [
-                script_skip_tensor(decoder_out[0], ~terminated),
-                script_skip_tensor(decoder_out[1], ~terminated),
-                decoder_out[2],
-                decoder_out[3],
-                decoder_out[4],
-            ]
-            encoder_out = script_skip_tensor_list(encoder_out, ~terminated)
+            prev_decoder_out = DecoderOut(
+                output_tokens=script_skip_tensor(
+                    decoder_out.output_tokens, ~terminated
+                ),
+                output_scores=script_skip_tensor(
+                    decoder_out.output_scores, ~terminated
+                ),
+                attn=decoder_out.attn,
+                step=decoder_out.step,
+                max_step=decoder_out.max_step,
+                history=None,
+            )
+
+            encoder_out = EncoderOut(
+                encoder_out=script_skip_tensor(encoder_out.encoder_out, ~terminated),
+                encoder_padding_mask=None,
+                encoder_embedding=script_skip_tensor(
+                    encoder_out.encoder_embedding, ~terminated
+                ),
+                encoder_states=None,
+            )
             sent_idxs = script_skip_tensor(sent_idxs, ~terminated)
 
-            prev_output_tokens = prev_decoder_out[0].clone()
+            prev_output_tokens = prev_decoder_out.output_tokens.clone()
 
         return (
             finalized_tokens_list,
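
Note on the `_replace` calls above: `DecoderOut` and `EncoderOut` are namedtuples, so `_replace` does not mutate in place; it returns a new tuple, which is why each result is reassigned (`prev_decoder_out = prev_decoder_out._replace(...)`). A minimal sketch of that behavior, using a stand-in namedtuple with the same field names as fairseq's `DecoderOut` rather than the real import:

```python
from collections import namedtuple

import torch

# Stand-in with the same fields as fairseq's DecoderOut (the real one is
# imported in the diff from fairseq.iterative_refinement_generator).
DecoderOut = namedtuple(
    "DecoderOut",
    ["output_tokens", "output_scores", "attn", "step", "max_step", "history"],
)

out = DecoderOut(
    output_tokens=torch.zeros(2, 5, dtype=torch.long),
    output_scores=torch.zeros(2, 5),
    attn=None,
    step=0,
    max_step=0,
    history=None,
)

# _replace builds a NEW tuple; the original binding is untouched ...
out._replace(max_step=10)
assert out.max_step == 0

# ... so the result must be reassigned for the update to take effect.
out = out._replace(max_step=10)
assert out.max_step == 10
```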