[easy] Dispatch torch.from_numpy to torch.as_tensor (#114609)
...rather than detaching the tensor

Pull Request resolved: #114609
Approved by: https://github.com/larryliu0820, https://github.com/voznesenskym
ghstack dependencies: #114608
lezcano authored and pytorchmergebot committed Nov 28, 2023
1 parent 0bb2600 commit 79ee99e
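For context, here is a minimal sketch of the user-visible effect, modeled on the new test below (the function name f and the doubling are illustrative, not from the diff). Under torch.compile with NumPy tracing enabled, a numpy() round-trip that ends in torch.from_numpy now keeps the autograd graph alive instead of detaching it:

    import torch

    @torch.compile
    def f(x):
        # Under Dynamo's NumPy tracing these run as traced tensor ops,
        # not real NumPy calls.
        y = x.numpy() * 2
        return torch.from_numpy(y)  # now lowered to torch.as_tensor

    x = torch.arange(8, dtype=torch.float32, requires_grad=True)
    out = f(x)
    out.sum().backward()  # gradients flow; the old lowering cut the graph here
    print(x.grad)         # tensor([2., 2., 2., 2., 2., 2., 2., 2.])

Note that this pattern only works compiled: in eager mode, calling .numpy() on a tensor that requires grad raises an error.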
Showing 2 changed files with 23 additions and 20 deletions.
14 changes: 13 additions & 1 deletion test/inductor/test_torchinductor.py
@@ -8181,17 +8181,29 @@ def wrapper(x):
             y = my_np(x)
             return torch.as_tensor(y)
 
+        @torch.compile
+        def wrapper2(x):
+            x = x.numpy()
+            y = my_np(x)
+            return torch.from_numpy(y)
+
         x_np = torch.arange(8, dtype=torch.float32, requires_grad=True)
         x = torch.arange(8, dtype=torch.float32, requires_grad=True)
 
         out_np = wrapper(x_np)
         out = my_torch(x)
         self.assertEqual(out, out_np)
 
+        x2_np = torch.arange(8, dtype=torch.float32, requires_grad=True)
+        out2_np = wrapper2(x2_np)
+        self.assertEqual(out, out2_np)
+
         out_np.backward()
         out.backward()
         self.assertEqual(x.grad, x_np.grad)
 
+        out2_np.backward()
+        self.assertEqual(x.grad, x2_np.grad)
+
     # Disable constant propagation, so we isolate value range analysis
     @patch.object(config, "constant_and_index_propagation", False)
     @patch.object(config, "joint_graph_constant_folding", False)
29 changes: 10 additions & 19 deletions torch/_dynamo/variables/torch.py
@@ -386,25 +386,16 @@ def call_function(
                 unimplemented("torch.from_numpy. config.trace_numpy is False")
             if not np:
                 unimplemented("torch.from_numpy. NumPy is not available")
-            assert len(args) == 1, f"Got arguments {args}"
-            assert not kwargs
-            t = args[0]
-            from .tensor import NumpyNdarrayVariable
-
-            if isinstance(t, NumpyNdarrayVariable):
-                # TODO: mark the tensor as non-resizable
-                return wrap_fx_proxy_cls(
-                    target_cls=TensorVariable,
-                    tx=tx,
-                    proxy=tx.output.create_proxy(
-                        "call_function",
-                        torch.detach,
-                        *proxy_args_kwargs(args, {}),
-                    ),
-                    example_value=None,
-                )
-            else:
-                unimplemented(f"torch.from_numpy(<{type(t)}>)")
+            return wrap_fx_proxy_cls(
+                target_cls=TensorVariable,
+                tx=tx,
+                proxy=tx.output.create_proxy(
+                    "call_function",
+                    torch.as_tensor,
+                    *proxy_args_kwargs(args, {}),
+                ),
+                example_value=None,
+            )
         elif can_dispatch_torch_function(tx, args, kwargs):
             return dispatch_torch_function(tx, self, args, kwargs)
         elif self.value is torch.autograd._profiler_enabled:
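The crux of the dispatch change, shown as an eager-mode illustration (not part of the diff): inside the traced graph, the "ndarray" being wrapped is really a torch.Tensor, so the op emitted for torch.from_numpy decides whether autograd survives. torch.as_tensor is a no-op on a tensor that already has the requested dtype and device, preserving requires_grad and grad_fn, whereas torch.detach severs the graph:

    import torch

    t = 2 * torch.ones(3, requires_grad=True)

    # Old lowering: torch.detach severs the autograd graph.
    print(torch.detach(t).requires_grad)  # False

    # New lowering: torch.as_tensor returns the input tensor unchanged.
    s = torch.as_tensor(t)
    print(s.requires_grad)  # True
    print(s is t)           # True: no copy, no detach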
