[ONNX] Inline prim::PythonOp for Autograd Function Export #74765
Closed. shubhambhokare1 wants to merge 24 commits into pytorch:master from shubhambhokare1:sbhokare/autograd-subgraph.
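As the title says, this change teaches the ONNX exporter to inline the `prim::PythonOp` nodes produced by `torch.autograd.Function` calls, so a function's traced forward body becomes part of the exported graph instead of an opaque node. A function that defines its own `symbolic` is still exported through that symbolic, and unsupported ops fall back to `prim::PythonOp`, as the tests below exercise. A minimal sketch of the kind of model affected; the class names and output path here are illustrative, not from the PR:

```python
import torch

class MyRelu(torch.autograd.Function):
    @staticmethod
    def forward(ctx, x):
        ctx.save_for_backward(x)
        return x.clamp(min=0)

    @staticmethod
    def backward(ctx, grad_output):
        (x,) = ctx.saved_tensors
        return grad_output * (x > 0)

class Model(torch.nn.Module):
    def forward(self, x):
        return MyRelu.apply(x)

# Before this pass, MyRelu.apply surfaced in the traced graph as an opaque
# prim::PythonOp node; with inlining, its forward body is exported as
# regular ONNX ops instead.
torch.onnx.export(Model(), (torch.randn(2, 3),), "model.onnx")
```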
Commits (24):
b56bf12 Subgraph creation
b0fe094 Add pass
f6284f2 add blocks
7201bcb Complete inline task
58ee661 add support for multiple outputs
553cac6 Add API for inline_subgraph
48ab5e8 Refactor append function
8817720 Add logic for nested autograd functions
0bc6860 Add priority
4a9fba9 Redesign tests, remove parameter from export
539d45f Add priority tests
9f31441 Add docs
90a5e1e Change Node* logics to Value*
a4e0e3b subgraph->subblock
a7ace47 Add Fallthrough mechanism
b2cd021 Add check for is_in_onnx_export
d82aacc lint fix
3de1f54 lint fix
d5196c0 Revert fallthrough mode
7557cce lint fixes
abf1026 Add detailed warning
ff9c51e Fix indents
bb1687f Lint fix post rebase
4f9ce2c Refactor tests
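The commit trail ("Subgraph creation", "Complete inline task", "Add logic for nested autograd functions") outlines the core transform: build a subgraph from the function's traced forward, splice it into the parent graph in place of the `prim::PythonOp` node, and recurse when one autograd function calls another. A conceptual sketch of that transform, using toy data structures rather than the real TorchScript IR:

```python
# Conceptual sketch only: a toy IR, not the TorchScript C++ API.
class Node:
    def __init__(self, kind, inputs, subgraph=None):
        self.kind = kind          # e.g. "onnx::Exp" or "prim::PythonOp"
        self.inputs = inputs      # list of input value names
        self.subgraph = subgraph  # traced forward body for PythonOp nodes

def inline_python_ops(nodes):
    """Replace each prim::PythonOp with its traced forward body,
    recursing so that nested autograd functions are inlined as well."""
    out = []
    for node in nodes:
        if node.kind == "prim::PythonOp" and node.subgraph is not None:
            out.extend(inline_python_ops(node.subgraph))
        else:
            out.append(node)
    return out
```

The real pass also has to remap the node's inputs and outputs onto the subgraph's values, which this toy version glosses over.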
New test file (+181 lines):

```python
# Owner(s): ["module: onnx"]

import unittest

import torch

from onnx_test_common import run_model_test
from torch.onnx import OperatorExportTypes
from torch.onnx._globals import GLOBALS
from torch.onnx.utils import _model_to_graph


class TestAutogradFuns(unittest.TestCase):
    opset_version = GLOBALS.export_onnx_opset_version
    keep_initializers_as_inputs = False
    onnx_shape_inference = True

    def test_single_output(self):
        class SingleOut(torch.autograd.Function):
            @staticmethod
            def forward(ctx, i):
                result = i.exp()
                result = result.log()
                ctx.save_for_backward(result)
                return result

            @staticmethod
            def backward(ctx, grad_output):
                (result,) = ctx.saved_tensors
                return grad_output * result

        class Caller(torch.nn.Module):
            def forward(self, input):
                result = input + 5
                return SingleOut.apply(result) + 3

        model = Caller()
        input = torch.ones(1)
        run_model_test(self, model, input_args=(input,))

    def test_multi_output(self):
        class MultiOut(torch.autograd.Function):
            @staticmethod
            def forward(ctx, i):
                result_exp = i.exp()
                result_log = result_exp.log()
                ctx.save_for_backward(result_exp, result_log)
                return result_exp, result_log

            @staticmethod
            def backward(ctx, grad_output):
                (result,) = ctx.saved_tensors
                return grad_output * result

        class Caller(torch.nn.Module):
            def forward(self, input):
                return MultiOut.apply(input)

        model = Caller()
        input = torch.ones(1, 5)
        run_model_test(self, model, input_args=(input,))

    def test_partial_output(self):
        class PartialOut(torch.autograd.Function):
            @staticmethod
            def forward(ctx, input):
                ctx.save_for_backward(input)
                values, indices = torch.topk(input, 3)
                return values

        class Caller(torch.nn.Module):
            def forward(self, input):
                return PartialOut.apply(input)

        model = Caller()
        input = torch.ones(1, 5)
        run_model_test(self, model, input_args=(input,))

    def test_nested_autograd(self):
        class Child(torch.autograd.Function):
            @staticmethod
            def forward(ctx, i):
                result = i.log()
                result_log = result.log()
                ctx.save_for_backward(result_log)
                return result_log

            @staticmethod
            def backward(ctx, grad_output):
                (result,) = ctx.saved_tensors
                return grad_output * result

        class Parent(torch.autograd.Function):
            @staticmethod
            def forward(ctx, i):
                result_exp = i.exp()
                result_log = Child.apply(result_exp)
                ctx.save_for_backward(result_exp, result_log)
                return result_exp, result_log

            @staticmethod
            def backward(ctx, grad_output):
                (result,) = ctx.saved_tensors
                return grad_output * result

        class Caller(torch.nn.Module):
            def forward(self, input):
                return Parent.apply(input)

        model = Caller()
        input = torch.ones(1, 5)
        run_model_test(self, model, input_args=(input,))

    # Run export in fallthrough/fallback modes, as torch.special.erf()
    # is not supported by the exporter.
    def test_aten_unsupported(self):
        class Erf(torch.autograd.Function):
            @staticmethod
            def forward(ctx, x):
                erf_out = torch.special.erf(x)
                ctx.save_for_backward(erf_out)
                return erf_out

            @staticmethod
            def backward(ctx, grad_output):
                (result,) = ctx.saved_tensors
                return torch.special.erfinv(result)

        class Caller(torch.nn.Module):
            def forward(self, input):
                return Erf.apply(input)

        model = Caller()
        input = torch.ones(1, 5)

        # Test ONNX_FALLTHROUGH mode
        graph, _, _ = _model_to_graph(
            model,
            (input,),
            operator_export_type=OperatorExportTypes.ONNX_FALLTHROUGH,
        )
        node_iter = graph.nodes()
        self.assertEqual(next(node_iter).kind(), "prim::PythonOp")

        # Test ONNX_ATEN_FALLBACK mode
        graph, _, _ = _model_to_graph(
            model,
            (input,),
            operator_export_type=OperatorExportTypes.ONNX_ATEN_FALLBACK,
        )
        node_iter = graph.nodes()
        self.assertEqual(next(node_iter).kind(), "prim::PythonOp")

    def test_inline_and_symbolic(self):
        class Exp(torch.autograd.Function):
            @staticmethod
            def forward(ctx, i):
                ctx.save_for_backward(i)
                return i.exp()

            @staticmethod
            def symbolic(g, input):
                return g.op("Exp", input)

        class LogLog(torch.autograd.Function):
            @staticmethod
            def forward(ctx, i):
                ctx.save_for_backward(i)
                return i.log().log()

        class Caller(torch.nn.Module):
            def forward(self, input):
                exp_result = Exp.apply(input)
                return LogLog.apply(exp_result)

        model = Caller()
        input = torch.ones(1)
        run_model_test(self, model, input_args=(input,))


if __name__ == "__main__":
    unittest.main()
```
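For a quick standalone way to observe the effect of the pass (a sketch, not part of the PR; `_model_to_graph` is the same internal helper the tests above use, and `Square` is an illustrative function):

```python
import torch
from torch.onnx.utils import _model_to_graph

class Square(torch.autograd.Function):
    @staticmethod
    def forward(ctx, x):
        ctx.save_for_backward(x)
        return x * x

    @staticmethod
    def backward(ctx, grad_output):
        (x,) = ctx.saved_tensors
        return 2 * x * grad_output

class Caller(torch.nn.Module):
    def forward(self, x):
        return Square.apply(x)

# With the inlining pass, Square's body should show up as regular ONNX ops
# (e.g. onnx::Mul) rather than a single prim::PythonOp node.
graph, _, _ = _model_to_graph(Caller(), (torch.ones(1, 5),))
print(graph)
```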