add logits option to generate artifacts (microsoft#17276)
### Description

Adds the ability to export logits as an output of the training and eval graphs produced by `generate_artifacts`. The extra output remains optional and is requested through the new `additional_output_names` argument.
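
A minimal usage sketch of the new option (the model file, trainable-parameter names, optimizer choice, and the `"logits"` output name below are illustrative assumptions, not taken from this commit):

```python
# Hedged sketch: file paths, parameter names, and the "logits" output name
# are assumptions for illustration only.
import onnx
from onnxruntime.training import artifacts

# A forward-only ONNX model whose graph already produces an output named "logits".
onnx_model = onnx.load("model_with_logits.onnx")

artifacts.generate_artifacts(
    onnx_model,
    requires_grad=["fc.weight", "fc.bias"],  # assumed trainable parameter names
    loss=artifacts.LossType.CrossEntropyLoss,
    optimizer=artifacts.OptimType.AdamW,
    artifact_directory="training_artifacts",
    additional_output_names=["logits"],  # keep logits as an output of the train/eval graphs
)
```

With `additional_output_names` omitted, the generated training and eval graphs expose only the loss, as before.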
AdamLouly committed Aug 29, 2023
1 parent f3682ee commit 8224891
Showing 2 changed files with 48 additions and 0 deletions.
15 changes: 15 additions & 0 deletions orttraining/orttraining/python/training/artifacts.py
@@ -65,6 +65,7 @@ def generate_artifacts(
ort_format (bool): Whether to save the generated artifacts in ORT format or not. Default is False.
custom_op_library (str | os.PathLike): The path to the custom op library.
If not specified, no custom op library is used.
additional_output_names (List[str]): List of additional output names to be added to the training/eval model.
Raises:
RuntimeError: If the loss provided is neither one of the supported losses nor an instance of `onnxblock.Block`
@@ -104,6 +105,20 @@ def __init__(self, _loss):
            self._loss = _loss

        def build(self, *inputs_to_loss):
            if "additional_output_names" in extra_options:
                # If additional output names is not a list, raise an error
                if not isinstance(extra_options["additional_output_names"], list):
                    raise RuntimeError(
                        f"Unknown type provided for additional output names {type(extra_options['additional_output_names'])}. "
                        "Expected additional output names to be a list of strings."
                    )

                loss_output = self._loss(*inputs_to_loss)
                if isinstance(loss_output, tuple):
                    return (*loss_output, *tuple(extra_options["additional_output_names"]))
                else:
                    return (loss_output, *tuple(extra_options["additional_output_names"]))

            return self._loss(*inputs_to_loss)

    training_block = _TrainingBlock(loss_block)
33 changes: 33 additions & 0 deletions orttraining/orttraining/test/python/orttraining_test_onnxblock.py
@@ -847,6 +847,39 @@ def mse_loss(prediction, target):
        assert np.allclose(ort_grad, _to_numpy(pt_param.grad))


def test_additional_output_names():
    class DropoutModel(torch.nn.Module):
        def __init__(self):
            super().__init__()
            self.dropout = torch.nn.Dropout(p=0.5)

        def forward(self, x):
            return self.dropout(x)

    model = DropoutModel()
    onnx_model = _get_onnx_model(model, (torch.randn(1, 3, 224, 224),))

    with tempfile.TemporaryDirectory() as temp_dir:
        artifacts.generate_artifacts(onnx_model, loss=artifacts.LossType.CrossEntropyLoss, artifact_directory=temp_dir)

        eval_model = onnx.load(os.path.join(temp_dir, "eval_model.onnx"))

        # Make sure the loss is the only output
        assert len(eval_model.graph.output) == 1

        # Re-generate artifacts with additional output names
        artifacts.generate_artifacts(
            onnx_model,
            loss=artifacts.LossType.CrossEntropyLoss,
            artifact_directory=temp_dir,
            additional_output_names=["output-0"],
        )

        # Make sure the eval model has two outputs
        eval_model = onnx.load(os.path.join(temp_dir, "eval_model.onnx"))
        assert len(eval_model.graph.output) == 2


def test_eval_model_has_no_training_mode_dropout():
    class DropoutModel(torch.nn.Module):
        def __init__(self):
