diff --git a/tensor2tensor/data_generators/problem.py b/tensor2tensor/data_generators/problem.py index ebcc0697d..b2fcecc5e 100644 --- a/tensor2tensor/data_generators/problem.py +++ b/tensor2tensor/data_generators/problem.py @@ -425,11 +425,22 @@ def maybe_reverse_features(self, feature_map): return inputs, targets = feature_map["inputs"], feature_map["targets"] feature_map["inputs"], feature_map["targets"] = targets, inputs + if "inputs_segmentation" in feature_map: + inputs, targets = feature_map["inputs_segmentation"], feature_map["targets_segmentation"] + feature_map["inputs_segmentation"], feature_map["targets_segmentation"] = targets, inputs + if "inputs_position" in feature_map: + inputs, targets = feature_map["inputs_position"], feature_map["targets_position"] + feature_map["inputs_position"], feature_map["targets_position"] = targets, inputs + def maybe_copy_features(self, feature_map): if not self._was_copy: return feature_map["targets"] = feature_map["inputs"] + if "inputs_segmentation" in feature_map: + feature_map["targets_segmentation"] = feature_map["inputs_segmentation"] + if "inputs_position" in feature_map: + feature_map["targets_position"] = feature_map["inputs_position"] def dataset(self, mode,