
ORTModule Custom Autograd Function Support

What are autograd Functions?

PyTorch allows users to define custom operators, with their own forward and backward implementations, by subclassing torch.autograd.Function (see PyTorch: Defining New autograd Functions).

Such use cases keep growing as more deep learning projects adopt optimized custom operators. To name a few:

These operators are used heavily in training and evaluation, which is exactly where ORTModule operates. To fully realize ORTModule's acceleration, we need to tolerate and handle these custom operators across their full lifecycle: conversion to ONNX, backward graph building, and execution at runtime.

How does ORTModule support autograd.Function?

ORTModule supports it through two operators introduced in the MS (com.microsoft) domain of ONNX Runtime, PythonOp and PythonOpGrad:

  • During model export, map each autograd Function (prim::PythonOp in PyTorch) to a PythonOp in ONNX Runtime by registering a custom export function:
      import torch
      class ScalarAndTupleFunction(torch.autograd.Function):
          @staticmethod
          def forward(ctx, input, alpha, beta, gamma):
              # The tensor input is saved for backward; non-tensor inputs are kept on ctx.
              ctx.save_for_backward(input)
              ctx.alpha = alpha
              ctx.beta = beta
              ctx.gamma = gamma
              return alpha * beta[0] * beta[1] * gamma * input.clamp(min=0)
          @staticmethod
          def backward(ctx, grad_output):
              input, = ctx.saved_tensors
              alpha = ctx.alpha
              beta = ctx.beta
              gamma = ctx.gamma
              grad_input = grad_output.clone()
              grad_input[input < 0] = 0
              # Return one gradient per forward input; the non-tensor inputs get None.
              return alpha * beta[0] * beta[1] * gamma * grad_input, None, None, None
    The example above shows a custom function taking 4 inputs (not counting ctx). The first input is a tensor, which the exporter treats as an input of PythonOp. The others are non-tensor values (scalars and a tuple); the export function converts all such inputs to constants and stores them in PythonOp's attributes. Note: if a non-tensor input is of one of the types "bool scalar, int scalar, float scalar, bool tuple, int tuple, float tuple", it is stored in the corresponding attribute; otherwise it is treated as an object, and the object's address is stored in input_pointer_scalars (its reference count is also increased to make sure the object stays alive during the model run).
  • The PythonOp kernel runs the user-defined forward function through the forward runner. Similarly, the PythonOpGrad kernel runs the user-defined backward function through the backward runner.
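To make the attribute mapping concrete, here is a minimal Python sketch of how non-tensor inputs could be sorted into the categories listed above. It is illustrative only and is not ORTModule's export code: the category strings and the helpers classify_non_tensor_input and pack_non_tensor_inputs are hypothetical, and only input_pointer_scalars is an attribute name taken from this document. For simplicity, positions are counted within the list passed in rather than among all of the function's inputs.

```python
from typing import Any, Dict, List, Tuple


# Hypothetical category names; only `input_pointer_scalars` appears in this document.
def classify_non_tensor_input(value: Any) -> str:
    # bool must be checked before int, because bool is a subclass of int in Python.
    if isinstance(value, bool):
        return "bool_scalar"
    if isinstance(value, int):
        return "int_scalar"
    if isinstance(value, float):
        return "float_scalar"
    if isinstance(value, tuple) and value:
        if all(isinstance(v, bool) for v in value):
            return "bool_tuple"
        if all(isinstance(v, int) and not isinstance(v, bool) for v in value):
            return "int_tuple"
        if all(isinstance(v, float) for v in value):
            return "float_tuple"
    # Anything else (including mixed or empty tuples here) is passed by object address.
    return "pointer_scalar"


def pack_non_tensor_inputs(
    values: List[Any],
) -> Tuple[Dict[str, List[Tuple[int, Any]]], List[Any]]:
    """Group non-tensor inputs by category, remembering each input's position.

    Objects are recorded by id() and also appended to `keep_alive`, which holds a
    strong reference to them, mimicking the reference-count increase described above.
    """
    attributes: Dict[str, List[Tuple[int, Any]]] = {}
    keep_alive: List[Any] = []
    for position, value in enumerate(values):
        category = classify_non_tensor_input(value)
        if category == "pointer_scalar":
            keep_alive.append(value)
            value = id(value)
        attributes.setdefault(category, []).append((position, value))
    return attributes, keep_alive
```

For the non-tensor arguments of ScalarAndTupleFunction, pack_non_tensor_inputs([2.0, (3.0, 0.5), 1.5]) files 2.0 and 1.5 as float scalars and (3.0, 0.5) as a float tuple, with nothing passed by address. The order of the checks matters: testing int before bool would file True as an int scalar.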

Currently, PythonOp support is enabled by default in the training Python wheel, so users don't need to be aware of it. As long as a torch.autograd.Function works in a plain PyTorch run, it should also run under ORTModule. To enable or disable it explicitly, refer to the wiki.
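Because the rule of thumb is that a function which works in plain PyTorch should also work under ORTModule, a cheap first step is to verify its backward in PyTorch alone. The sketch below does this with torch.autograd.gradcheck on a self-contained copy of ScalarAndTupleFunction (with the @staticmethod decorators and save_for_backward call it needs); the helper check_function_in_pytorch and the argument values are illustrative assumptions.

```python
import torch
from torch.autograd import gradcheck


class ScalarAndTupleFunction(torch.autograd.Function):
    @staticmethod
    def forward(ctx, input, alpha, beta, gamma):
        ctx.save_for_backward(input)
        ctx.alpha, ctx.beta, ctx.gamma = alpha, beta, gamma
        return alpha * beta[0] * beta[1] * gamma * input.clamp(min=0)

    @staticmethod
    def backward(ctx, grad_output):
        input, = ctx.saved_tensors
        grad_input = grad_output.clone()
        grad_input[input < 0] = 0
        scale = ctx.alpha * ctx.beta[0] * ctx.beta[1] * ctx.gamma
        # One gradient per forward input; the three non-tensor inputs get None.
        return scale * grad_input, None, None, None


def check_function_in_pytorch() -> bool:
    # Hypothetical helper. Inputs are fixed and kept away from 0, where
    # clamp(min=0) has no derivative and the numerical check would be unreliable.
    x = torch.tensor([-1.5, -0.5, 0.5, 2.0], dtype=torch.float64, requires_grad=True)
    # Bind the non-tensor arguments in a lambda so gradcheck only sees the tensor.
    return gradcheck(lambda t: ScalarAndTupleFunction.apply(t, 2.0, (3.0, 0.5), 1.5), (x,))
```

gradcheck compares the analytical gradient from backward against finite differences and raises if they disagree. Once this passes, the model that calls the function can be wrapped with ORTModule (from onnxruntime.training.ortmodule import ORTModule) as usual.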

Known Issues and Workarounds

PyTorch Versions

  • Minimum version: 1.9, which introduced "Support registering custom export for `prim::PythonOp` from torch.autograd.Function (#55630) (#57600)".
  • If the static forward function has only one output, any PyTorch 1.9 release works. Otherwise, a PyTorch version containing this commit is required.
  • Export throws a _Map_base::at exception, with errors like this:
      RuntimeError: There was an error while exporting the PyTorch model to ONNX:
      Traceback (most recent call last):
      File "/opt/conda/envs/ptca/lib/python3.8/site-packages/onnxruntime/training/ortmodule/", line 316, in get_exception_as_string
      	raise exception
      File "/opt/conda/envs/ptca/lib/python3.8/site-packages/onnxruntime/training/ortmodule/", line 425, in _get_exported_model
      File "/opt/conda/envs/ptca/lib/python3.8/site-packages/torch/onnx/", line 506, in export
      File "/opt/conda/envs/ptca/lib/python3.8/site-packages/torch/onnx/", line 1548, in _export
      	graph, params_dict, torch_out = _model_to_graph(
      File "/opt/conda/envs/ptca/lib/python3.8/site-packages/torch/onnx/", line 1113, in _model_to_graph
      	graph, params, torch_out, module = _create_jit_graph(model, args)
      File "/opt/conda/envs/ptca/lib/python3.8/site-packages/torch/onnx/", line 989, in _create_jit_graph
      	graph, torch_out = _trace_and_get_graph_from_model(model, args)
      File "/opt/conda/envs/ptca/lib/python3.8/site-packages/torch/onnx/", line 893, in _trace_and_get_graph_from_model
      	trace_graph, torch_out, inputs_states = torch.jit._get_trace_graph(
      File "/opt/conda/envs/ptca/lib/python3.8/site-packages/torch/jit/", line 1268, in _get_trace_graph
      	outs = ONNXTracedModule(f, strict, _force_outplace, return_inputs, _return_inputs_states)(*args, **kwargs)
      File "/opt/conda/envs/ptca/lib/python3.8/site-packages/deepspeed-0.9.5+95680ca-py3.8.egg/deepspeed/runtime/zero/", line 632, in _ort_post_forward_module_hook
      	a = ORTPostForwardwardFunction.apply(module, _post_forward_module_hook, _ort_run_before_backward_function, len(input), len(output), *input_and_output)
      File "/opt/conda/envs/ptca/lib/python3.8/site-packages/torch/autograd/", line 506, in apply
      	return super().apply(*args, **kwargs)  # type: ignore[misc]
      RuntimeError: _Map_base::at
    Resolution: upgrade PyTorch to a newer version containing this commit; with it, the export parameter autograd_inlining can be set to False, which avoids this error.
  • "Tried to trace <__torch__.torch.classes.c10d.ProcessGroup object at 0x2969c520> but it is not part of the active trace": this usually happens when the forward function of a torch.autograd.Function makes PyTorch collective calls and passes the process group explicitly.
      RuntimeError: There was an error while exporting the PyTorch model to ONNX:
      Traceback (most recent call last):
      File "/bert_ort/pengwa/py3.8/lib/python3.8/site-packages/onnxruntime/training/ortmodule/", line 324, in get_exception_as_string
      	raise exception
      File "/bert_ort/pengwa/py3.8/lib/python3.8/site-packages/onnxruntime/training/ortmodule/", line 342, in _get_exported_model
      File "/bert_ort/pengwa/py3.8/lib/python3.8/site-packages/torch/onnx/", line 507, in export
      File "/bert_ort/pengwa/py3.8/lib/python3.8/site-packages/torch/onnx/", line 1567, in _export
      	graph, params_dict, torch_out = _model_to_graph(
      File "/bert_ort/pengwa/py3.8/lib/python3.8/site-packages/torch/onnx/", line 1124, in _model_to_graph
      	graph, params, torch_out, module = _create_jit_graph(model, args)
      File "/bert_ort/pengwa/py3.8/lib/python3.8/site-packages/torch/onnx/", line 1000, in _create_jit_graph
      	graph, torch_out = _trace_and_get_graph_from_model(model, args)
      File "/bert_ort/pengwa/py3.8/lib/python3.8/site-packages/torch/onnx/", line 904, in _trace_and_get_graph_from_model
      	trace_graph, torch_out, inputs_states = torch.jit._get_trace_graph(
      File "/bert_ort/pengwa/py3.8/lib/python3.8/site-packages/torch/jit/", line 1269, in _get_trace_graph
      	outs = ONNXTracedModule(f, strict, _force_outplace, return_inputs, _return_inputs_states)(*args, **kwargs)
      File "/bert_ort/pengwa/py3.8/lib/python3.8/site-packages/torch/jit/", line 128, in forward
      	graph, out = torch._C._create_graph_by_tracing(
      File "/bert_ort/pengwa/deepspeed/deepspeed/runtime/zero/", line 640, in _ort_pre_forward_module_hook
      	rets = ORTPreForwardwardFunction.apply(self, module, _ort_run_after_backward_function, *inputs)
      File "/bert_ort/pengwa/deepspeed/deepspeed/runtime/zero/", line 823, in pre_sub_module_forward_function
      	param_coordinator.fetch_sub_module(sub_module, forward=True)
      File "/bert_ort/pengwa/py3.8/lib/python3.8/site-packages/torch/distributed/", line 2841, in all_gather_into_tensor
      	work = group._allgather_base(output_tensor, input_tensor)
      RuntimeError: Tried to trace <__torch__.torch.classes.c10d.ProcessGroup object at 0x56250ad114a0> but it is not part of the active trace. Modules that are called during a trace must be registered as submodules of the thing being traced.
    Resolution: modify the autograd.Function so that it skips running the collective operator during ONNX export, for example:
      import torch
      import torch.distributed
      # Before
      def allgather_fn(output_tensor, input_tensor, group=None, async_op=False):
          return torch.distributed.all_gather_into_tensor(output_tensor, input_tensor, group=group, async_op=async_op)
      # Workaround
      from datetime import timedelta
      from typing import Any, List
      class DummyWork(torch.distributed.distributed_c10d.Work):
          # A no-op Work handle returned instead of running the collective during ONNX export.
          def is_completed(self) -> bool:
              return True
          def is_success(self) -> bool:
              return True
          def exception(self) -> Any:
              return None
          def wait(self, timeout: timedelta = timedelta(0)) -> bool:
              return True
          def source_rank(self) -> int:
              return 0
          def _source_rank(self) -> int:
              return 0
          def result(self) -> List[torch.Tensor]:
              return []
          def synchronize(self):
              pass
      def allgather_fn(output_tensor, input_tensor, group=None, async_op=False):
          # Skip the real collective while torch.onnx.export is tracing the model.
          if torch.onnx.is_in_onnx_export():
              return DummyWork()
          return torch.distributed.all_gather_into_tensor(output_tensor, input_tensor, group=group, async_op=async_op)
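The version requirements listed under PyTorch Versions can be partly checked at startup. This is a minimal sketch assuming only the (major, minor) numbers matter; the helpers major_minor and meets_minimum are hypothetical. Whether a given build contains the specific commit mentioned above cannot be told from the version string, so that still needs the PyTorch release notes.

```python
import re
from typing import Tuple


def major_minor(version: str) -> Tuple[int, int]:
    """Extract (major, minor) from strings such as "1.9.0", "2.1.0+cu118" or "1.13.0a0+git..."."""
    match = re.match(r"(\d+)\.(\d+)", version)
    if match is None:
        raise ValueError(f"unrecognized version string: {version!r}")
    return int(match.group(1)), int(match.group(2))


def meets_minimum(version: str, minimum: Tuple[int, int] = (1, 9)) -> bool:
    # Compare as integers; plain string comparison would rank "1.10" below "1.9".
    return major_minor(version) >= minimum
```

In practice this would be called as meets_minimum(torch.__version__).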
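For the _Map_base::at error, if you call torch.onnx.export yourself rather than through ORTModule, you may want to pass autograd_inlining=False only when the installed exporter accepts that keyword. A minimal sketch using inspect.signature; the helper inlining_kwargs is hypothetical, and it assumes older exporters simply lack the parameter.

```python
import inspect
from typing import Any, Callable, Dict


def inlining_kwargs(export_fn: Callable[..., Any]) -> Dict[str, Any]:
    """Return {"autograd_inlining": False} if export_fn accepts that keyword, else {}."""
    try:
        params = inspect.signature(export_fn).parameters
    except (TypeError, ValueError):
        # Some callables expose no introspectable signature; pass nothing extra.
        return {}
    if "autograd_inlining" in params:
        return {"autograd_inlining": False}
    return {}
```

A call would then look like torch.onnx.export(model, args, "model.onnx", **inlining_kwargs(torch.onnx.export)), so the same script keeps working on PyTorch versions that predate the parameter.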
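If several collective wrappers need the same export-time guard used in the workaround, the check can be factored into a decorator. This is a sketch: skip_during_onnx_export is a hypothetical helper, while torch.onnx.is_in_onnx_export is the same PyTorch call the workaround already uses.

```python
import functools
from typing import Any, Callable

import torch


def skip_during_onnx_export(placeholder: Callable[[], Any]):
    """Hypothetical decorator: while torch.onnx.export is running, return
    placeholder() instead of calling the wrapped function."""

    def decorator(fn: Callable[..., Any]) -> Callable[..., Any]:
        @functools.wraps(fn)
        def wrapper(*args, **kwargs):
            if torch.onnx.is_in_onnx_export():
                return placeholder()
            # Outside export, behave exactly like the undecorated function.
            return fn(*args, **kwargs)

        return wrapper

    return decorator
```

For example, decorating allgather_fn with @skip_during_onnx_export(DummyWork) would reproduce the workaround's behavior, since calling the DummyWork class with no arguments returns a fresh no-op handle.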