diff --git a/orttraining/orttraining/python/training/artifacts.py b/orttraining/orttraining/python/training/artifacts.py
index 3d6a8e8248b7..549614de496a 100644
--- a/orttraining/orttraining/python/training/artifacts.py
+++ b/orttraining/orttraining/python/training/artifacts.py
@@ -65,6 +65,7 @@ def generate_artifacts(
         ort_format (bool): Whether to save the generated artifacts in ORT format or not. Default is False.
         custom_op_library (str | os.PathLike): The path to the custom op library.
             If not specified, no custom op library is used.
+        additional_output_names (List[str]): List of additional output names to be added to the training/eval model.
 
     Raises:
         RuntimeError: If the loss provided is neither one of the supported losses nor an instance of `onnxblock.Block`
@@ -104,6 +105,20 @@ def __init__(self, _loss):
             self._loss = _loss
 
         def build(self, *inputs_to_loss):
+            if "additional_output_names" in extra_options:
+                # If additional output names is not a list, raise an error
+                if not isinstance(extra_options["additional_output_names"], list):
+                    raise RuntimeError(
+                        f"Unknown type provided for additional output names {type(extra_options['additional_output_names'])}. "
+                        "Expected additional output names to be a list of strings."
+                    )
+
+                loss_output = self._loss(*inputs_to_loss)
+                if isinstance(loss_output, tuple):
+                    return (*loss_output, *tuple(extra_options["additional_output_names"]))
+                else:
+                    return (loss_output, *tuple(extra_options["additional_output_names"]))
+
             return self._loss(*inputs_to_loss)
 
     training_block = _TrainingBlock(loss_block)
diff --git a/orttraining/orttraining/test/python/orttraining_test_onnxblock.py b/orttraining/orttraining/test/python/orttraining_test_onnxblock.py
index c6e8b98d3516..f7a7220dd66e 100644
--- a/orttraining/orttraining/test/python/orttraining_test_onnxblock.py
+++ b/orttraining/orttraining/test/python/orttraining_test_onnxblock.py
@@ -847,6 +847,39 @@ def mse_loss(prediction, target):
     assert np.allclose(ort_grad, _to_numpy(pt_param.grad))
 
 
+def test_additional_output_names():
+    class DropoutModel(torch.nn.Module):
+        def __init__(self):
+            super().__init__()
+            self.dropout = torch.nn.Dropout(p=0.5)
+
+        def forward(self, x):
+            return self.dropout(x)
+
+    model = DropoutModel()
+    onnx_model = _get_onnx_model(model, (torch.randn(1, 3, 224, 224),))
+
+    with tempfile.TemporaryDirectory() as temp_dir:
+        artifacts.generate_artifacts(onnx_model, loss=artifacts.LossType.CrossEntropyLoss, artifact_directory=temp_dir)
+
+        eval_model = onnx.load(os.path.join(temp_dir, "eval_model.onnx"))
+
+        # Make sure only loss is the output
+        assert len(eval_model.graph.output) == 1
+
+        # Re-generate artifacts with additional output names
+        artifacts.generate_artifacts(
+            onnx_model,
+            loss=artifacts.LossType.CrossEntropyLoss,
+            artifact_directory=temp_dir,
+            additional_output_names=["output-0"],
+        )
+
+        # Make sure the eval model has two outputs
+        eval_model = onnx.load(os.path.join(temp_dir, "eval_model.onnx"))
+        assert len(eval_model.graph.output) == 2
+
+
 def test_eval_model_has_no_training_mode_dropout():
     class DropoutModel(torch.nn.Module):
         def __init__(self):
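
Usage note (not part of the patch): a minimal sketch of how the new additional_output_names option is exercised, mirroring the added test. The model path, the graph-output name "output-0", and the artifact directory are assumptions for illustration only.

    # Sketch only: assumes "model.onnx" has a graph output named "output-0".
    import os
    import onnx
    from onnxruntime.training import artifacts

    onnx_model = onnx.load("model.onnx")  # hypothetical exported model
    artifacts.generate_artifacts(
        onnx_model,
        loss=artifacts.LossType.CrossEntropyLoss,
        artifact_directory="artifacts",
        # New option from this patch: forward the named graph output(s) through
        # the generated training/eval models in addition to the loss output.
        additional_output_names=["output-0"],
    )
    eval_model = onnx.load(os.path.join("artifacts", "eval_model.onnx"))
    assert len(eval_model.graph.output) == 2  # loss + "output-0"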