Fix meta registration for aten._cudnn_rnn (#91333)
Found this issue via the [weekly run of 7k GitHub models](https://github.com/pytorch/torchdynamo/issues/1884). It caused a regression in the pass rate: 25 models failed because of it.
The root cause is that the `cx` argument of `aten._cudnn_rnn` can be `None`, but the meta registration does not handle that case and throws the following error:
```
Traceback (most recent call last):
  File "/scratch/ybliang/work/repos/pytorch/torch/_dynamo/utils.py", line 1059, in run_node
    return nnmodule(*args, **kwargs)
  File "/scratch/ybliang/work/repos/pytorch/torch/nn/modules/module.py", line 1482, in _call_impl
    return forward_call(*args, **kwargs)
  File "/scratch/ybliang/work/repos/pytorch/torch/nn/modules/rnn.py", line 477, in forward
    result = _VF.rnn_tanh(input, hx, self._flat_weights, self.bias, self.num_layers,
  File "/scratch/ybliang/work/repos/pytorch/torch/_subclasses/fake_tensor.py", line 916, in __torch_dispatch__
    r = func(*args, **kwargs)
  File "/scratch/ybliang/work/repos/pytorch/torch/_ops.py", line 284, in __call__
    return self._op(*args, **kwargs or {})
  File "/scratch/ybliang/work/repos/pytorch/torch/_meta_registrations.py", line 2108, in _cudnn_rnn
    cy = cx.new_empty(0 if cx is None else cell_shape)
AttributeError: 'NoneType' object has no attribute 'new_empty'
```
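The failing line is an evaluation-order pitfall: the attribute lookup `cx.new_empty` runs before the `0 if cx is None else cell_shape` argument can guard anything, so it dereferences `None`. Below is a minimal, self-contained sketch of the broken pattern next to the fixed one (runnable on CPU; the helper names `buggy_cy`/`fixed_cy` are illustrative only, not part of the PR):

```python
import torch

def buggy_cy(cx, cell_shape):
    # The attribute access `cx.new_empty` is evaluated unconditionally,
    # so this raises AttributeError whenever cx is None -- the guard in
    # the argument expression comes too late.
    return cx.new_empty(0 if cx is None else cell_shape)

def fixed_cy(input, cx, cell_shape):
    # Branch on cx itself and fall back to the input tensor's device,
    # mirroring the change to torch/_meta_registrations.py below.
    if cx is None:
        return torch.empty(0, device=input.device)
    return cx.new_empty(cell_shape)

inp = torch.randn(2, 3)
print(fixed_cy(inp, None, [4, 2, 3]).shape)  # torch.Size([0])
print(fixed_cy(inp, inp, [4, 2, 3]).shape)   # torch.Size([4, 2, 3])
# buggy_cy(None, [4, 2, 3])  # AttributeError: 'NoneType' object has no attribute 'new_empty'
```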

Pull Request resolved: #91333
Approved by: https://github.com/ezyang
yanboliang authored and pytorchmergebot committed Dec 23, 2022
1 parent df46ba4 commit 789b143
Showing 2 changed files with 16 additions and 9 deletions.
20 changes: 12 additions & 8 deletions test/test_fake_tensor.py
```diff
@@ -345,7 +345,7 @@ def fn(
         mode = FakeTensorMode(allow_fallback_kernels=allow_fallback_kernels)
         for i, context in enumerate([contextlib.nullcontext, lambda: mode]):
             with context():
-                inps = (
+                inps1 = [
                     torch.randn([92, 8, 2048]).cuda(),
                     torch.randn([8192, 2048]).cuda(),
                     torch.randn([8192, 2048]).cuda(),
@@ -366,13 +366,17 @@ def fn(
                     torch.randn([167837696]).cuda(),
                     torch.randn([4, 8, 2048]).cuda(),
                     torch.randn([4, 8, 2048]).cuda(),
-                )
-                out = fn(*inps)
-                self.assertIs(out[4], inps[-3])
-                for ten in out:
-                    if i == 1:
-                        self.assertTrue(isinstance(ten, FakeTensor))
-                    self.assertEqual(ten.device.type, 'cuda')
+                ]
+                inps2 = inps1
+                inps2[len(inps2) - 1] = None  # argument `cx` can be None
+
+                for inps in [inps1, inps2]:
+                    out = fn(*inps)
+                    self.assertIs(out[4], inps[-3])
+                    for ten in out:
+                        if i == 1:
+                            self.assertTrue(isinstance(ten, FakeTensor))
+                        self.assertEqual(ten.device.type, 'cuda')
 
     @unittest.skipIf(not RUN_CUDA, "requires cuda")
     def test_cuda_lstm(self):
```
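One detail of the new test worth noting: `inps2 = inps1` binds a second name to the same list, so assigning `None` through `inps2` also mutates `inps1`; if two independent input sets were intended, a shallow copy would keep them distinct. A small sketch of the difference (placeholder list contents, not the real test inputs):

```python
# Aliasing: both names refer to one list, so the mutation is shared.
inps1 = ["h", "w", "cx"]
inps2 = inps1
inps2[-1] = None
print(inps1)        # ['h', 'w', None] -- inps1 changed too

# Shallow copy: the top-level lists are independent.
inps1 = ["h", "w", "cx"]
inps2 = list(inps1)
inps2[-1] = None
print(inps1)        # ['h', 'w', 'cx'] -- original preserved
```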
5 changes: 4 additions & 1 deletion torch/_meta_registrations.py
```diff
@@ -2105,7 +2105,10 @@ def _cudnn_rnn(
     output = input.new_empty(out_shape)
 
     cell_shape = [num_layers * num_directions, mini_batch, hidden_size]
-    cy = cx.new_empty(0 if cx is None else cell_shape)
+    if cx is None:
+        cy = torch.empty(0, device=input.device)
+    else:
+        cy = cx.new_empty(cell_shape)
 
     hy = hx.new_empty([num_layers * num_directions, mini_batch, out_size])
```
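For background, meta registrations like `_cudnn_rnn` compute output shapes, dtypes, and devices without allocating storage or running kernels; fake tensors dispatch to them during tracing. A minimal illustration of the underlying "meta" device (not the actual `_cudnn_rnn` registration):

```python
import torch

# Meta tensors carry only metadata (shape/dtype/device); ops on them
# propagate shapes without doing any real computation.
x = torch.empty(4, 8, 2048, device="meta")
w = torch.empty(2048, 2048, device="meta")
y = x @ w
print(y.shape)   # torch.Size([4, 8, 2048])
print(y.device)  # meta
```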
