Closed
Description
Currently, passing proxied tensor shapes (e.g. `x.shape[0]`) as arguments to `torch.zeros` is not symbolically traceable:
```python
# simple example
import torch

class M(torch.nn.Module):
    def forward(self, x):
        x = torch.zeros(x.shape[0], x.shape[1])
        return x

m = M()
gm = torch.fx.symbolic_trace(m)
print(gm)
```
```
...
<ipython-input-9-de0156858ee1> in forward(self, x)
      3 class M(torch.nn.Module):
      4     def forward(self, x):
----> 5         x = torch.zeros(x.shape[0], x.shape[1])
      6         return x
      7 

TypeError: zeros() received an invalid combination of arguments - got (Proxy, Proxy), but expected one of:
 * (tuple of ints size, *, tuple of names names, torch.dtype dtype, torch.layout layout, torch.device device, bool pin_memory, bool requires_grad)
 * (tuple of ints size, *, Tensor out, torch.dtype dtype, torch.layout layout, torch.device device, bool pin_memory, bool requires_grad)
```
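During symbolic tracing, `x` is a `torch.fx.Proxy`, so `x.shape[0]` and `x.shape[1]` are also `Proxy` objects rather than `int`s. In the version used here, `torch.zeros` gets no tensor argument through which FX could intercept the call, so its argument parser sees two `Proxy` values where it expects a size and raises the `TypeError` above. One possible workaround, sketched under the assumption that the zeros should share `x`'s device and dtype, is to call `new_zeros` on the traced input: a method call on a `Proxy` is recorded as a `call_method` node instead of being executed while tracing.

```python
import torch
import torch.fx

class M(torch.nn.Module):
    def forward(self, x):
        # x is a Proxy while tracing, so x.new_zeros(...) becomes a
        # call_method node; the sizes are recorded as getitem nodes on x.shape.
        return x.new_zeros(x.shape[0], x.shape[1])

gm = torch.fx.symbolic_trace(M())
print(gm.code)

# At run time the recorded shape lookups produce plain ints again.
out = gm(torch.randn(3, 4))
```

The generated `forward` reads the sizes from `x.shape` on every call, so the traced module is not specialized to the shape of any example input.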
```python
# real model code
def forward(self, x1, x2):
    assert x1.shape == x2.shape
    d = self.d
    s1 = self.s1
    s2 = self.s2
    n, c, h, w = x1.shape
    out_h = (h - 1) // s1 + 1
    out_w = (w - 1) // s1 + 1
    out_k = 2 * d // s2 + 1
    result = torch.zeros(n, out_k ** 2, out_h, out_w, device=x1.device)
    ...
```
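The same idea should carry over to the real model. The sketch below wraps the snippet in a hypothetical `Correlation` module; its constructor and the default values for `d`, `s1` and `s2` are assumptions. Two changes are made: `torch.zeros(..., device=x1.device)` becomes `x1.new_zeros(...)`, and the `assert` is dropped, because comparing traced shapes yields a `Proxy`, and a plain `assert` calls `bool()` on it, which symbolic tracing rejects as data-dependent control flow. Note that `new_zeros` also takes its dtype from `x1`, whereas `torch.zeros` uses the default dtype; the two agree when `x1` has the default dtype.

```python
import torch
import torch.fx

class Correlation(torch.nn.Module):
    # Hypothetical module: d, s1, s2 mirror the attributes in the snippet,
    # but this constructor and its default values are assumptions.
    def __init__(self, d=4, s1=1, s2=2):
        super().__init__()
        self.d, self.s1, self.s2 = d, s1, s2

    def forward(self, x1, x2):
        # `assert x1.shape == x2.shape` is omitted: bool() on a Proxy
        # raises a TraceError during symbolic tracing.
        n, c, h, w = x1.shape
        # Arithmetic on the shape Proxies is recorded as operator nodes.
        out_h = (h - 1) // self.s1 + 1
        out_w = (w - 1) // self.s1 + 1
        out_k = 2 * self.d // self.s2 + 1  # plain ints, evaluated while tracing
        # new_zeros inherits device (and dtype) from x1 and traces as a method call.
        result = x1.new_zeros(n, out_k ** 2, out_h, out_w)
        return result

gm = torch.fx.symbolic_trace(Correlation())
out = gm(torch.randn(2, 3, 8, 8), torch.randn(2, 3, 8, 8))
```

With these defaults `out_k` is `2 * 4 // 2 + 1 = 5` and `out_h = out_w = (8 - 1) // 1 + 1 = 8`, so `out` has shape `(2, 25, 8, 8)`.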
Is this expected to work?
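If a call to `torch.zeros` itself is preferred (for example to keep its default-dtype behaviour), another option is `torch.fx.wrap`, which registers a module-level function as a leaf: when any of its arguments is a `Proxy`, the tracer records one `call_function` node for it instead of running its body on `Proxy` values. `zeros_from_sizes` is a hypothetical helper name used only for illustration, and `torch.fx.wrap` has to be called at the top level of the module that defines it.

```python
import torch
import torch.fx

def zeros_from_sizes(*sizes, device=None):
    # Hypothetical helper. Once registered with torch.fx.wrap, the tracer
    # records a call to it instead of executing torch.zeros on Proxies.
    return torch.zeros(*sizes, device=device)

# Must run at module top level, before tracing.
torch.fx.wrap("zeros_from_sizes")

class M(torch.nn.Module):
    def forward(self, x):
        return zeros_from_sizes(x.shape[0], x.shape[1], device=x.device)

gm = torch.fx.symbolic_trace(M())
out = gm(torch.randn(3, 4))
```

At run time the recorded node calls the original `zeros_from_sizes` with concrete ints and a concrete device, so `torch.zeros` never sees a `Proxy`.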