Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[ONNX] Add export of prim::data #45747

Closed
wants to merge 5 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
10 changes: 10 additions & 0 deletions test/onnx/test_pytorch_onnx_onnxruntime.py
Original file line number Diff line number Diff line change
Expand Up @@ -430,6 +430,16 @@ def forward(self, input):
m1 = torch.randn(3, 4, 5, 6, 7)
self.run_test(MyModel(), m1)

@skipIfUnsupportedMinOpsetVersion(9)
def test_data(self):
class Data(torch.jit.ScriptModule):
@torch.jit.script_method
def forward(self, x):
return x.new_zeros(x.data.size())

x = torch.randn(3, 4)
self.run_test(Data(), x)

@skipIfUnsupportedMinOpsetVersion(11)
@disableScriptTest() # Need type inference
def test_index_mask_nd(self):
Expand Down
2 changes: 2 additions & 0 deletions torch/onnx/symbolic_opset9.py
Original file line number Diff line number Diff line change
Expand Up @@ -2193,6 +2193,8 @@ def log2(g, self):
def prim_shape(g, self):
return g.op('Shape', self)

def prim_data(g, self):
return self

@parse_args('v', 'i')
def one_hot(g, self, num_classes):
Expand Down