Skip to content
This repository was archived by the owner on Aug 1, 2023. It is now read-only.
Closed
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
62 changes: 39 additions & 23 deletions pytorch_translate/ensemble_export.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,8 +19,10 @@
from caffe2.python.onnx import backend as caffe2_backend
from caffe2.python.predictor import predictor_exporter
from fairseq import tasks, utils
from fairseq.iterative_refinement_generator import DecoderOut
from fairseq.models import ARCH_MODEL_REGISTRY
from fairseq.models.model_utils import script_skip_tensor, script_skip_tensor_list
from fairseq.models.model_utils import script_skip_tensor
from fairseq.models.transformer import EncoderOut
from pytorch_translate.beam_decode import BeamDecode
from pytorch_translate.data import dictionary
from pytorch_translate.research.knowledge_distillation import (
Expand Down Expand Up @@ -2194,42 +2196,43 @@ def generate(self, models, src_tokens, src_lengths, prefix_tokens=None):

# initialize buffers (very model specific, with length prediction or not)
prev_decoder_out = model.initialize_output_tokens(encoder_out, src_tokens)
prev_output_tokens = prev_decoder_out[0].clone()
prev_output_tokens = prev_decoder_out.output_tokens.clone()

finalized_tokens_list = [torch.tensor(0) for _ in range(bsz)]
finalized_scores_list = [torch.tensor(0) for _ in range(bsz)]
finalized_attns_list = [torch.tensor(0) for _ in range(bsz)]
finalized_alignments_list = [torch.tensor(0) for _ in range(bsz)]
prev_decoder_out[4] = self.max_iter + 1
prev_decoder_out._replace(max_step=self.max_iter + 1)

for step in range(self.max_iter + 1):
prev_decoder_out[3] = step
prev_decoder_out._replace(step=step)
decoder_out = model.forward_decoder(
prev_decoder_out,
encoder_out,
eos_penalty=self.eos_penalty,
max_ratio=self.max_ratio if step == 0 else None,
decoding_format=self.decoding_format,
)

terminated, output_tokens, output_scores, output_attn = is_a_loop(
self.pad,
prev_output_tokens,
decoder_out[0],
decoder_out[1],
decoder_out[2],
decoder_out.output_tokens,
decoder_out.output_scores,
decoder_out.attn,
)
decoder_out._replace(
output_tokens=output_tokens,
output_scores=output_scores,
attn=output_attn,
)
decoder_out[0] = output_tokens
decoder_out[1] = output_scores
decoder_out[2] = output_attn

terminated = last_step(step, self.max_iter, terminated)
# collect finalized sentences
finalized_idxs = sent_idxs[terminated]
finalized_tokens = decoder_out[0][terminated]
finalized_scores = decoder_out[1][terminated]
finalized_tokens = decoder_out.output_tokens[terminated]
finalized_scores = decoder_out.output_scores[terminated]
finalized_attn = (
None if decoder_out[2] is None else decoder_out[2][terminated]
None if decoder_out.attn is None else decoder_out.attn[terminated]
)
finalized_tokens_list = finalize_hypos_loop_tokens(
finalized_tokens_list,
Expand All @@ -2256,17 +2259,30 @@ def generate(self, models, src_tokens, src_lengths, prefix_tokens=None):
)

# for next step
prev_decoder_out = [
script_skip_tensor(decoder_out[0], ~terminated),
script_skip_tensor(decoder_out[1], ~terminated),
decoder_out[2],
decoder_out[3],
decoder_out[4],
]
encoder_out = script_skip_tensor_list(encoder_out, ~terminated)
prev_decoder_out = DecoderOut(
output_tokens=script_skip_tensor(
decoder_out.output_tokens, ~terminated
),
output_scores=script_skip_tensor(
decoder_out.output_scores, ~terminated
),
attn=decoder_out.attn,
step=decoder_out.step,
max_step=decoder_out.max_step,
history=None,
)

encoder_out = EncoderOut(
encoder_out=script_skip_tensor(encoder_out.encoder_out, ~terminated),
encoder_padding_mask=None,
encoder_embedding=script_skip_tensor(
encoder_out.encoder_embedding, ~terminated
),
encoder_states=None,
)
sent_idxs = script_skip_tensor(sent_idxs, ~terminated)

prev_output_tokens = prev_decoder_out[0].clone()
prev_output_tokens = prev_decoder_out.output_tokens.clone()

return (
finalized_tokens_list,
Expand Down