diff --git a/pytorch_translate/ensemble_export.py b/pytorch_translate/ensemble_export.py
index eb779891..ea56f6a2 100644
--- a/pytorch_translate/ensemble_export.py
+++ b/pytorch_translate/ensemble_export.py
@@ -19,6 +19,7 @@
 from caffe2.python.predictor import predictor_exporter
 from fairseq import tasks, utils
 from pytorch_translate.tasks.pytorch_translate_task import DictionaryHolderTask
+from pytorch_translate.transformer import TransformerEncoder
 from pytorch_translate.word_prediction import word_prediction_model
 from torch.onnx import ExportTypes, OperatorExportTypes
@@ -249,15 +250,30 @@ def forward(self, src_tokens, src_lengths):
         # (seq_length, batch_size) for compatibility with Caffe2
         src_tokens_seq_first = src_tokens.t()

-        for i, model in enumerate(self.models):
+        futures = []
+        for model in self.models:
             # evaluation mode
             model.eval()

-            encoder_out = model.encoder(src_tokens_seq_first, src_lengths)
+            # TODO(jamesreed): transformer encoder returns a None output, and
+            # the fork/join API doesn't handle that well. We should figure out
+            # a way to annotate outputs as Optional and record that in fork/join
+            # traces.
+            if isinstance(model.encoder, TransformerEncoder):
+                futures.append(model.encoder(src_tokens_seq_first, src_lengths))
+            else:
+                futures.append(
+                    torch.jit._fork(model.encoder, src_tokens_seq_first, src_lengths)
+                )
+
+        for i, (model, future) in enumerate(zip(self.models, futures)):
+            # evaluation mode
+            model.eval()
+            if isinstance(model.encoder, TransformerEncoder):
+                encoder_out = future
+            else:
+                encoder_out = torch.jit._wait(future)

             # "primary" encoder output (vector representations per source token)
             encoder_outputs = encoder_out[0]
             outputs.append(encoder_outputs)
@@ -392,6 +408,8 @@ def forward(self, input_tokens, prev_scores, timestep, *inputs):
         else:
             possible_translation_tokens = None

+        futures = []
+
         for i, model in enumerate(self.models):
             if (
                 isinstance(model, rnn.RNNModel)
@@ -433,54 +451,73 @@ def forward(self, input_tokens, prev_scores, timestep, *inputs):
                     src_embeddings,
                 )

-                # store cached states, use evaluation mode
-                model.decoder._is_incremental_eval = True
-                model.eval()
-
-                # placeholder
-                incremental_state = {}
-
-                # cache previous state inputs
-                utils.set_incremental_state(
-                    model.decoder,
-                    incremental_state,
-                    "cached_state",
-                    (prev_hiddens, prev_cells, prev_input_feed),
-                )
-
-                decoder_output = model.decoder(
+                def forked_section(
                     input_tokens,
                     encoder_out,
-                    incremental_state=incremental_state,
-                    possible_translation_tokens=possible_translation_tokens,
-                )
-                logits, attn_scores, _ = decoder_output
-
-                log_probs = F.log_softmax(logits, dim=2)
-
-                log_probs_per_model.append(log_probs)
-                attn_weights_per_model.append(attn_scores)
-
-                (
-                    next_hiddens,
-                    next_cells,
-                    next_input_feed,
-                ) = utils.get_incremental_state(
-                    model.decoder, incremental_state, "cached_state"
+                    possible_translation_tokens,
+                    prev_hiddens,
+                    prev_cells,
+                    prev_input_feed,
+                ):
+                    # store cached states, use evaluation mode
+                    model.decoder._is_incremental_eval = True
+                    model.eval()
+
+                    # placeholder
+                    incremental_state = {}
+
+                    # cache previous state inputs
+                    utils.set_incremental_state(
+                        model.decoder,
+                        incremental_state,
+                        "cached_state",
+                        (prev_hiddens, prev_cells, prev_input_feed),
+                    )
+
+                    decoder_output = model.decoder(
+                        input_tokens,
+                        encoder_out,
+                        incremental_state=incremental_state,
+                        possible_translation_tokens=possible_translation_tokens,
+                    )
+                    logits, attn_scores, _ = decoder_output
+
+                    log_probs = F.log_softmax(logits, dim=2)
+
+                    log_probs_per_model.append(log_probs)
+                    attn_weights_per_model.append(attn_scores)
+
+                    (
+                        next_hiddens,
+                        next_cells,
+                        next_input_feed,
+                    ) = utils.get_incremental_state(
+                        model.decoder, incremental_state, "cached_state"
+                    )
+
+                    return (
+                        log_probs,
+                        attn_scores,
+                        tuple(next_hiddens),
+                        tuple(next_cells),
+                        next_input_feed,
+                    )
+
+                fut = torch.jit._fork(
+                    forked_section,
+                    input_tokens,
+                    encoder_out,
+                    possible_translation_tokens,
+                    prev_hiddens,
+                    prev_cells,
+                    prev_input_feed,
                 )
-                for h, c in zip(next_hiddens, next_cells):
-                    state_outputs.extend([h, c])
-                    beam_axis_per_state.extend([0, 0])
-
-                state_outputs.append(next_input_feed)
-                beam_axis_per_state.append(0)
-
+                futures.append(fut)
             elif isinstance(model, transformer.TransformerModel) or isinstance(
                 model, char_source_transformer_model.CharSourceTransformerModel
             ):
                 encoder_output = inputs[i]
-                # store cached states, use evaluation mode
                 model.decoder._is_incremental_eval = True
                 model.eval()
@@ -493,21 +530,36 @@ def forward(self, input_tokens, prev_scores, timestep, *inputs):

                 encoder_out = (encoder_output, None, None)

-                decoder_output = model.decoder(
+                def forked_section(
+                    input_tokens,
+                    encoder_out,
+                    state_inputs,
+                    possible_translation_tokens,
+                    timestep,
+                ):
+                    decoder_output = model.decoder(
+                        input_tokens,
+                        encoder_out,
+                        incremental_state=state_inputs,
+                        possible_translation_tokens=possible_translation_tokens,
+                        timestep=timestep,
+                    )
+                    logits, attn_scores, _, attention_states = decoder_output
+
+                    log_probs = F.log_softmax(logits, dim=2)
+
+                    return log_probs, attn_scores, tuple(attention_states)
+
+                fut = torch.jit._fork(
+                    forked_section,
                     input_tokens,
                     encoder_out,
-                    incremental_state=state_inputs,
-                    possible_translation_tokens=possible_translation_tokens,
-                    timestep=timestep,
+                    state_inputs,
+                    possible_translation_tokens,
+                    timestep,
                 )
-                logits, attn_scores, _, attention_states = decoder_output
-
-                log_probs = F.log_softmax(logits, dim=2)
-                log_probs_per_model.append(log_probs)
-                attn_weights_per_model.append(attn_scores)
-                state_outputs.extend(attention_states)
-                beam_axis_per_state.extend([0 for _ in attention_states])
+                futures.append(fut)
             elif isinstance(
                 model, hybrid_transformer_rnn.HybridTransformerRNNModel
             ) or isinstance(model, char_source_hybrid.CharSourceHybridModel):
@@ -519,30 +571,91 @@ def forward(self, input_tokens, prev_scores, timestep, *inputs):

                 encoder_out = (encoder_output, None, None)

-                incremental_state = {}
                 num_states = (1 + model.decoder.num_layers) * 2
                 state_inputs = inputs[next_state_input : next_state_input + num_states]
                 next_state_input += num_states
-                utils.set_incremental_state(
-                    model.decoder, incremental_state, "cached_state", state_inputs
-                )
-                decoder_output = model.decoder(
+                def forked_section(
+                    input_tokens,
+                    encoder_out,
+                    state_inputs,
+                    possible_translation_tokens,
+                    timestep,
+                ):
+                    incremental_state = {}
+                    utils.set_incremental_state(
+                        model.decoder, incremental_state, "cached_state", state_inputs
+                    )
+
+                    decoder_output = model.decoder(
+                        input_tokens,
+                        encoder_out,
+                        incremental_state=incremental_state,
+                        possible_translation_tokens=possible_translation_tokens,
+                        timestep=timestep,
+                    )
+                    logits, attn_scores, _ = decoder_output
+
+                    log_probs = F.log_softmax(logits, dim=2)
+
+                    next_states = utils.get_incremental_state(
+                        model.decoder, incremental_state, "cached_state"
+                    )
+
+                    return log_probs, attn_scores, tuple(next_states)
+
+                fut = torch.jit._fork(
+                    forked_section,
                     input_tokens,
                     encoder_out,
-                    incremental_state=incremental_state,
-                    possible_translation_tokens=possible_translation_tokens,
-                    timestep=timestep,
+                    state_inputs,
+                    possible_translation_tokens,
+                    timestep,
                 )
-                logits, attn_scores, _ = decoder_output
-                log_probs = F.log_softmax(logits, dim=2)
+                futures.append(fut)
+            else:
+                raise RuntimeError(f"Not a supported model: {type(model)}")
+
+        for (model, fut) in zip(self.models, futures):
+            if (
+                isinstance(model, rnn.RNNModel)
+                or isinstance(model, char_source_model.CharSourceModel)
+                or isinstance(model, word_prediction_model.WordPredictionModel)
+            ):
+                (
+                    log_probs,
+                    attn_scores,
+                    next_hiddens,
+                    next_cells,
+                    next_input_feed,
+                ) = torch.jit._wait(fut)
+
+                for h, c in zip(next_hiddens, next_cells):
+                    state_outputs.extend([h, c])
+                    beam_axis_per_state.extend([0, 0])
+
+                state_outputs.append(next_input_feed)
+                beam_axis_per_state.append(0)
+
+            elif isinstance(model, transformer.TransformerModel) or isinstance(
+                model, char_source_transformer_model.CharSourceTransformerModel
+            ):
+                log_probs, attn_scores, attention_states = torch.jit._wait(fut)
+
+                log_probs_per_model.append(log_probs)
+                attn_weights_per_model.append(attn_scores)
+
+                state_outputs.extend(attention_states)
+                beam_axis_per_state.extend([0 for _ in attention_states])
+            elif isinstance(
+                model, hybrid_transformer_rnn.HybridTransformerRNNModel
+            ) or isinstance(model, char_source_hybrid.CharSourceHybridModel):
+                log_probs, attn_scores, next_states = torch.jit._wait(fut)
+
                 log_probs_per_model.append(log_probs)
                 attn_weights_per_model.append(attn_scores)
-                next_states = utils.get_incremental_state(
-                    model.decoder, incremental_state, "cached_state"
-                )
                 state_outputs.extend(next_states)
                 # sequence RNN states have beam along axis 1
                 beam_axis_per_state.extend([1 for _ in next_states[:-2]])
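For readers unfamiliar with the fork/join calls used throughout this patch: torch.jit._fork schedules a callable asynchronously and immediately returns a future, and torch.jit._wait blocks on that future and hands back its result (the same pair is exposed publicly as torch.jit.fork / torch.jit.wait in later PyTorch releases). Below is a minimal, self-contained sketch of the pattern the patch applies per ensemble member -- fork every task first, then wait in a second loop -- using toy modules and shapes chosen purely for illustration; it is not part of the diff.

import torch
import torch.nn as nn

# Toy stand-ins for self.models; the modules and tensor shapes are illustrative only.
models = nn.ModuleList([nn.Linear(8, 4), nn.Linear(8, 4)])
src = torch.randn(3, 8)

# Fork phase: launch one asynchronous task per model. torch.jit._fork returns a
# future right away instead of the model's output.
futures = [torch.jit._fork(m, src) for m in models]

# Join phase: torch.jit._wait blocks until the corresponding task finishes and
# returns its result, so outputs[i] is what models[i](src) would have returned.
outputs = [torch.jit._wait(f) for f in futures]
print([o.shape for o in outputs])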