From 8c9292d594af0378f3be90e641aebb815d9e0c26 Mon Sep 17 00:00:00 2001 From: Benedikt Wilbertz Date: Thu, 22 Feb 2018 22:08:45 +0100 Subject: [PATCH] fix reverse/copy for packed problems --- tensor2tensor/data_generators/problem.py | 11 +++++++++++ 1 file changed, 11 insertions(+) diff --git a/tensor2tensor/data_generators/problem.py b/tensor2tensor/data_generators/problem.py index ebcc0697d..b2fcecc5e 100644 --- a/tensor2tensor/data_generators/problem.py +++ b/tensor2tensor/data_generators/problem.py @@ -425,11 +425,22 @@ def maybe_reverse_features(self, feature_map): return inputs, targets = feature_map["inputs"], feature_map["targets"] feature_map["inputs"], feature_map["targets"] = targets, inputs + if "inputs_segmentation" in feature_map: + inputs, targets = feature_map["inputs_segmentation"], feature_map["targets_segmentation"] + feature_map["inputs_segmentation"], feature_map["targets_segmentation"] = targets, inputs + if "inputs_position" in feature_map: + inputs, targets = feature_map["inputs_position"], feature_map["targets_position"] + feature_map["inputs_position"], feature_map["targets_position"] = targets, inputs + def maybe_copy_features(self, feature_map): if not self._was_copy: return feature_map["targets"] = feature_map["inputs"] + if "inputs_segmentation" in feature_map: + feature_map["targets_segmentation"] = feature_map["inputs_segmentation"] + if "inputs_position" in feature_map: + feature_map["targets_position"] = feature_map["inputs_position"] def dataset(self, mode,