245 changes: 179 additions & 66 deletions pytorch_translate/ensemble_export.py
@@ -19,6 +19,7 @@
from caffe2.python.predictor import predictor_exporter
from fairseq import tasks, utils
from pytorch_translate.tasks.pytorch_translate_task import DictionaryHolderTask
from pytorch_translate.transformer import TransformerEncoder
from pytorch_translate.word_prediction import word_prediction_model
from torch.onnx import ExportTypes, OperatorExportTypes

@@ -249,15 +250,30 @@ def forward(self, src_tokens, src_lengths):
# (seq_length, batch_size) for compatibility with Caffe2
src_tokens_seq_first = src_tokens.t()

for i, model in enumerate(self.models):
futures = []
for model in self.models:
# evaluation mode
model.eval()

encoder_out = model.encoder(src_tokens_seq_first, src_lengths)
# TODO(jamesreed): transformer encoder returns a None output, and
# the fork/join API doesn't handle that well. We should figure out
# a way to annotate outputs as Optional and record that in fork/join
# traces.
if isinstance(model.encoder, TransformerEncoder):
futures.append(model.encoder(src_tokens_seq_first, src_lengths))
else:
futures.append(
torch.jit._fork(model.encoder, src_tokens_seq_first, src_lengths)
)

# evaluation mode
model.eval()

for i, (model, future) in enumerate(zip(self.models, futures)):
if isinstance(model.encoder, TransformerEncoder):
encoder_out = future
else:
encoder_out = torch.jit._wait(future)
# "primary" encoder output (vector representations per source token)
encoder_outputs = encoder_out[0]
outputs.append(encoder_outputs)
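
A note on the pattern this hunk introduces: each non-transformer encoder call is launched with `torch.jit._fork`, which returns a future immediately, and the results are collected with `torch.jit._wait` in a second loop, so the ensemble's encoders can overlap. The transformer encoder is called directly because, per the TODO above, fork/join does not handle its `None` output. The sketch below is illustrative only; `ToyEncoder`, its shapes, and `encode_ensemble` are invented to show the fork/wait pairing, not the PR's actual encoders (the diff uses the private `_fork`/`_wait` names, which later releases expose as `torch.jit.fork`/`torch.jit.wait`).

```python
import torch


class ToyEncoder(torch.nn.Module):
    """Stand-in encoder; returns a tuple so `out[0]` mirrors the real encoders."""

    def __init__(self, dim=4):
        super().__init__()
        self.linear = torch.nn.Linear(dim, dim)

    def forward(self, src_tokens, src_lengths):
        return (self.linear(src_tokens), src_lengths)


def encode_ensemble(encoders, src_tokens, src_lengths):
    # launch every encoder asynchronously; _fork returns a future right away
    futures = [torch.jit._fork(enc, src_tokens, src_lengths) for enc in encoders]
    # join in a second pass; _wait blocks until that forked call has finished
    return [torch.jit._wait(fut)[0] for fut in futures]


encoders = [ToyEncoder() for _ in range(3)]
src_features = torch.randn(5, 2, 4)                 # (seq_length, batch_size, dim)
src_lengths = torch.full((2,), 5, dtype=torch.long)
outputs = encode_ensemble(encoders, src_features, src_lengths)
```
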
@@ -392,6 +408,8 @@ def forward(self, input_tokens, prev_scores, timestep, *inputs):
else:
possible_translation_tokens = None

futures = []

for i, model in enumerate(self.models):
if (
isinstance(model, rnn.RNNModel)
@@ -433,54 +451,73 @@ def forward(self, input_tokens, prev_scores, timestep, *inputs):
src_embeddings,
)

# store cached states, use evaluation mode
model.decoder._is_incremental_eval = True
model.eval()

# placeholder
incremental_state = {}

# cache previous state inputs
utils.set_incremental_state(
model.decoder,
incremental_state,
"cached_state",
(prev_hiddens, prev_cells, prev_input_feed),
)

decoder_output = model.decoder(
def forked_section(
input_tokens,
encoder_out,
incremental_state=incremental_state,
possible_translation_tokens=possible_translation_tokens,
)
logits, attn_scores, _ = decoder_output

log_probs = F.log_softmax(logits, dim=2)

log_probs_per_model.append(log_probs)
attn_weights_per_model.append(attn_scores)

(
next_hiddens,
next_cells,
next_input_feed,
) = utils.get_incremental_state(
model.decoder, incremental_state, "cached_state"
possible_translation_tokens,
prev_hiddens,
prev_cells,
prev_input_feed,
):
# store cached states, use evaluation mode
model.decoder._is_incremental_eval = True
model.eval()

# placeholder
incremental_state = {}

# cache previous state inputs
utils.set_incremental_state(
model.decoder,
incremental_state,
"cached_state",
(prev_hiddens, prev_cells, prev_input_feed),
)

decoder_output = model.decoder(
input_tokens,
encoder_out,
incremental_state=incremental_state,
possible_translation_tokens=possible_translation_tokens,
)
logits, attn_scores, _ = decoder_output

log_probs = F.log_softmax(logits, dim=2)

log_probs_per_model.append(log_probs)
attn_weights_per_model.append(attn_scores)

(
next_hiddens,
next_cells,
next_input_feed,
) = utils.get_incremental_state(
model.decoder, incremental_state, "cached_state"
)

return (
log_probs,
attn_scores,
tuple(next_hiddens),
tuple(next_cells),
next_input_feed,
)

fut = torch.jit._fork(
forked_section,
input_tokens,
encoder_out,
possible_translation_tokens,
prev_hiddens,
prev_cells,
prev_input_feed,
)

for h, c in zip(next_hiddens, next_cells):
state_outputs.extend([h, c])
beam_axis_per_state.extend([0, 0])

state_outputs.append(next_input_feed)
beam_axis_per_state.append(0)

futures.append(fut)
elif isinstance(model, transformer.TransformerModel) or isinstance(
model, char_source_transformer_model.CharSourceTransformerModel
):
encoder_output = inputs[i]

# store cached states, use evaluation mode
model.decoder._is_incremental_eval = True
model.eval()
@@ -493,21 +530,36 @@ def forward(self, input_tokens, prev_scores, timestep, *inputs):

encoder_out = (encoder_output, None, None)

decoder_output = model.decoder(
def forked_section(
input_tokens,
encoder_out,
state_inputs,
possible_translation_tokens,
timestep,
):
decoder_output = model.decoder(
input_tokens,
encoder_out,
incremental_state=state_inputs,
possible_translation_tokens=possible_translation_tokens,
timestep=timestep,
)
logits, attn_scores, _, attention_states = decoder_output

log_probs = F.log_softmax(logits, dim=2)

return log_probs, attn_scores, tuple(attention_states)

fut = torch.jit._fork(
forked_section,
input_tokens,
encoder_out,
incremental_state=state_inputs,
possible_translation_tokens=possible_translation_tokens,
timestep=timestep,
state_inputs,
possible_translation_tokens,
timestep,
)
logits, attn_scores, _, attention_states = decoder_output

log_probs = F.log_softmax(logits, dim=2)
log_probs_per_model.append(log_probs)
attn_weights_per_model.append(attn_scores)

state_outputs.extend(attention_states)
beam_axis_per_state.extend([0 for _ in attention_states])
futures.append(fut)
elif isinstance(
model, hybrid_transformer_rnn.HybridTransformerRNNModel
) or isinstance(model, char_source_hybrid.CharSourceHybridModel):
@@ -519,30 +571,91 @@ def forward(self, input_tokens, prev_scores, timestep, *inputs):

encoder_out = (encoder_output, None, None)

incremental_state = {}
num_states = (1 + model.decoder.num_layers) * 2
state_inputs = inputs[next_state_input : next_state_input + num_states]
next_state_input += num_states
utils.set_incremental_state(
model.decoder, incremental_state, "cached_state", state_inputs
)

decoder_output = model.decoder(
def forked_section(
input_tokens,
encoder_out,
state_inputs,
possible_translation_tokens,
timestep,
):
incremental_state = {}
utils.set_incremental_state(
model.decoder, incremental_state, "cached_state", state_inputs
)

decoder_output = model.decoder(
input_tokens,
encoder_out,
incremental_state=incremental_state,
possible_translation_tokens=possible_translation_tokens,
timestep=timestep,
)
logits, attn_scores, _ = decoder_output

log_probs = F.log_softmax(logits, dim=2)

next_states = utils.get_incremental_state(
model.decoder, incremental_state, "cached_state"
)

return log_probs, attn_scores, tuple(next_states)

fut = torch.jit._fork(
forked_section,
input_tokens,
encoder_out,
incremental_state=incremental_state,
possible_translation_tokens=possible_translation_tokens,
timestep=timestep,
state_inputs,
possible_translation_tokens,
timestep,
)
logits, attn_scores, _ = decoder_output

log_probs = F.log_softmax(logits, dim=2)
futures.append(fut)
else:
raise RuntimeError(f"Not a supported model: {type(model)}")

for (model, fut) in zip(self.models, futures):
if (
isinstance(model, rnn.RNNModel)
or isinstance(model, char_source_model.CharSourceModel)
or isinstance(model, word_prediction_model.WordPredictionModel)
):
(
log_probs,
attn_scores,
next_hiddens,
next_cells,
next_input_feed,
) = torch.jit._wait(fut)

for h, c in zip(next_hiddens, next_cells):
state_outputs.extend([h, c])
beam_axis_per_state.extend([0, 0])

state_outputs.append(next_input_feed)
beam_axis_per_state.append(0)

elif isinstance(model, transformer.TransformerModel) or isinstance(
model, char_source_transformer_model.CharSourceTransformerModel
):
log_probs, attn_scores, attention_states = torch.jit._wait(fut)

log_probs_per_model.append(log_probs)
attn_weights_per_model.append(attn_scores)

state_outputs.extend(attention_states)
beam_axis_per_state.extend([0 for _ in attention_states])
elif isinstance(
model, hybrid_transformer_rnn.HybridTransformerRNNModel
) or isinstance(model, char_source_hybrid.CharSourceHybridModel):
log_probs, attn_scores, next_states = torch.jit._wait(fut)

log_probs_per_model.append(log_probs)
attn_weights_per_model.append(attn_scores)

next_states = utils.get_incremental_state(
model.decoder, incremental_state, "cached_state"
)
state_outputs.extend(next_states)
# sequence RNN states have beam along axis 1
beam_axis_per_state.extend([1 for _ in next_states[:-2]])
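
For the decoder step, the diff applies the same idea by packaging each model's per-step work into a local `forked_section` closure: the incremental-state setup, the decoder call, the log-softmax, and the state read-back all happen inside the closure, which returns plain tensors and tuples so a later `torch.jit._wait` can recover them in the join loop at the end. Below is a minimal, hypothetical sketch of that shape; `decode_step_parallel` and the one-argument decoder callables are invented for illustration and are not the PR's decoder code.

```python
import torch
import torch.nn.functional as F


def decode_step_parallel(decoders, input_tokens, encoder_outs):
    """Illustrative fork/join over one decode step per ensemble member."""
    futures = []
    for decoder, encoder_out in zip(decoders, encoder_outs):
        # bind the current decoder via a default argument so each closure
        # captures its own model, then fork it before the next iteration
        def forked_section(input_tokens, encoder_out, decoder=decoder):
            logits = decoder(input_tokens, encoder_out)
            # return plain tensors; _wait hands them back to the caller
            return F.log_softmax(logits, dim=-1)

        futures.append(torch.jit._fork(forked_section, input_tokens, encoder_out))

    # join in launch order so results line up with the model list
    return [torch.jit._wait(fut) for fut in futures]
```
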