Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Merge CompilationUnit from torch._C and torch.jit (#50614)
Summary: This simplifies our handling and allows passing CompilationUnits from Python to C++ defined functions via PyBind easily. Discussed on Slack with SplitInfinity Pull Request resolved: #50614 Reviewed By: anjali411 Differential Revision: D25938005 Pulled By: SplitInfinity fbshipit-source-id: 94aadf0c063ddfef7ca9ea17bfa998d8e7b367ad
- Loading branch information
1 parent
5e79b8e
commit ac0a3cc
Showing
5 changed files
with
121 additions
and
28 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,37 @@ | ||
import torch | ||
from torch.testing._internal.jit_utils import JitTestCase | ||
|
||
if __name__ == "__main__": | ||
raise RuntimeError( | ||
"This test file is not meant to be run directly, use:\n\n" | ||
"\tpython test/test_jit.py TestPythonBindings\n\n" | ||
"instead." | ||
) | ||
|
||
|
||
class TestPythonBindings(JitTestCase): | ||
def test_cu_get_functions(self): | ||
@torch.jit.script | ||
def test_get_python_cu_fn(x: torch.Tensor): | ||
return 2 * x | ||
|
||
cu = torch.jit._state._python_cu | ||
self.assertTrue( | ||
"test_get_python_cu_fn" in (str(fn.name) for fn in cu.get_functions()) | ||
) | ||
|
||
def test_cu_create_function(self): | ||
@torch.jit.script | ||
def fn(x: torch.Tensor): | ||
return 2 * x | ||
|
||
cu = torch._C.CompilationUnit() | ||
cu.create_function("test_fn", fn.graph) | ||
|
||
inp = torch.randn(5) | ||
|
||
self.assertEqual(inp * 2, cu.find_function("test_fn")(inp)) | ||
self.assertEqual(cu.find_function("doesnt_exist"), None) | ||
self.assertEqual(inp * 2, cu.test_fn(inp)) | ||
with self.assertRaises(AttributeError): | ||
cu.doesnt_exist(inp) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters