Meta PReLU failing when input shape is 1D #89560

@ahmadsarvmeily

🐛 Describe the bug

Calling PReLU with a 1D input (1 channel) fails with the meta backend, as shown below. It fails in both torch 1.13 and the most recent nightly, 1.14.0.dev20221123+cu117.

import torch
x = torch.randn(4).to('meta')
model = torch.nn.PReLU().to('meta')
model(x)

Traceback:

Traceback (most recent call last):
  File "<stdin>", line 1, in <module>
  File "/home/ahmads/.local/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1480, in _call_impl
    return forward_call(*args, **kwargs)
  File "/home/ahmads/.local/lib/python3.8/site-packages/torch/nn/modules/activation.py", line 1271, in forward
    return F.prelu(input, self.weight)
  File "/home/ahmads/.local/lib/python3.8/site-packages/torch/_prims_common/wrappers.py", line 119, in _fn
    result = fn(**bound.arguments)
  File "/home/ahmads/.local/lib/python3.8/site-packages/torch/_refs/nn/functional/__init__.py", line 1062, in prelu
    weight = prims.broadcast_in_dim(
  File "/home/ahmads/.local/lib/python3.8/site-packages/torch/_ops.py", line 285, in __call__
    return self._op(*args, **kwargs or {})
  File "/home/ahmads/.local/lib/python3.8/site-packages/torch/_prims/__init__.py", line 274, in _prim_impl
    meta(*args, **kwargs)
  File "/home/ahmads/.local/lib/python3.8/site-packages/torch/_prims/__init__.py", line 1217, in _broadcast_in_dim_meta
    reduce(lambda acc, x: _greater_than_reduce(acc, x), broadcast_dimensions, -1)
  File "/home/ahmads/.local/lib/python3.8/site-packages/torch/_prims/__init__.py", line 1217, in <lambda>
    reduce(lambda acc, x: _greater_than_reduce(acc, x), broadcast_dimensions, -1)
  File "/home/ahmads/.local/lib/python3.8/site-packages/torch/_prims/__init__.py", line 1213, in _greater_than_reduce
    assert x < len(shape)
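
The failing assertion can be reproduced without torch. The following is a minimal sketch, not the actual PyTorch source: `check_broadcast_dimensions` and `fails` are hypothetical helpers that mirror only the `assert x < len(shape)` check visible in the traceback, and it assumes that the `prelu` reference asks for the 1D weight to be broadcast along dimension 1 (the channel dimension). Under that assumption, a 1D input has no dimension 1, so the check fails, while a 2D input passes.

```python
from functools import reduce
from typing import Sequence


def check_broadcast_dimensions(
    shape: Sequence[int], broadcast_dimensions: Sequence[int]
) -> None:
    """Hypothetical stand-in for the check seen in the traceback."""

    def _greater_than_reduce(acc: int, x: int) -> int:
        # Same assertion as the last traceback frame: every broadcast
        # dimension must be a valid index into the target shape.
        assert x < len(shape)
        return x

    reduce(_greater_than_reduce, broadcast_dimensions, -1)


def fails(shape: Sequence[int], broadcast_dimensions: Sequence[int]) -> bool:
    """Return True if the check raises AssertionError."""
    try:
        check_broadcast_dimensions(shape, broadcast_dimensions)
    except AssertionError:
        return True
    return False


# Assumption: the 1D PReLU weight is mapped to dimension 1 of the input.
weight_dims = (1,)

print(fails((4,), weight_dims))    # 1D input, as in torch.randn(4)
print(fails((2, 4), weight_dims))  # 2D input with a channel dimension
```

If the assumption holds, this would explain why only inputs with fewer than two dimensions trigger the error.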

Versions

Collecting environment information...
PyTorch version: 1.14.0.dev20221123+cu117
Is debug build: False
CUDA used to build PyTorch: 11.7
ROCM used to build PyTorch: N/A

OS: Ubuntu 20.04.1 LTS (x86_64)
GCC version: (Ubuntu 9.3.0-10ubuntu2) 9.3.0
Clang version: Could not collect
CMake version: version 3.16.3
Libc version: glibc-2.31

Python version: 3.8.2 (default, Jul 16 2020, 14:00:26)  [GCC 9.3.0] (64-bit runtime)
Python platform: Linux-5.10.16.3-microsoft-standard-WSL2-x86_64-with-glibc2.29
Is CUDA available: True
CUDA runtime version: Could not collect
CUDA_MODULE_LOADING set to: LAZY
GPU models and configuration: GPU 0: Quadro T2000
Nvidia driver version: 517.00
cuDNN version: Could not collect
HIP runtime version: N/A
MIOpen runtime version: N/A
Is XNNPACK available: True

Versions of relevant libraries:
[pip3] numpy==1.19.2
[pip3] torch==1.14.0.dev20221123+cu117
[conda] Could not collect

cc @ezyang @mruberry @ngimel @lezcano @fdrocha @peterbell10

Labels: module: primTorch, module: structured kernels, triaged
