
Cannot convert a hybrid Conv2d/Conv3d model to channels-last format #77821

@SeedKunY


🐛 Describe the bug

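A model that contains both `nn.Conv2d` and `nn.Conv3d` layers cannot be converted to a channels-last memory format with `Module.to(memory_format=...)`. Minimal reproduction: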

```python
import torch.nn as nn
import torch

cpu_device = torch.device("cpu")


class Hybrid_model(nn.Module):
  def __init__(self):
    super().__init__()
    self.layer1 = nn.Conv2d(2, 4, kernel_size=3, stride=1, padding=1, bias=False)
    self.layer2 = nn.Conv3d(4, 8, kernel_size=3, stride=1, padding=1, bias=False)

  def forward(self, inputs):
    x = self.layer1(inputs)
    x = torch.reshape(x, (x.size(0), x.size(1), x.size(2), 16, 2))
    x = self.layer2(x)
    return x


if __name__ == "__main__":
  test_model = Hybrid_model().to(memory_format=torch.channels_last)
  x = torch.randn([3, 2, 32, 32], dtype=torch.float, requires_grad=True)
  y = test_model(x)
```
```
python test_conv2dconv3d_hybrid_model.py
Traceback (most recent call last):
  File "test_conv2dconv3d_hybrid_model.py", line 21, in <module>
    test_model = Hybrid_model().to(memory_format=torch.channels_last)
  File "/home/gta/miniconda3/envs/xxx/lib/python3.7/site-packages/torch/nn/modules/module.py", line 899, in to
    return self._apply(convert)
  File "/home/gta/miniconda3/envs/xxx/lib/python3.7/site-packages/torch/nn/modules/module.py", line 570, in _apply
    module._apply(fn)
  File "/home/gta/miniconda3/envs/xxx/lib/python3.7/site-packages/torch/nn/modules/module.py", line 593, in _apply
    param_applied = fn(param)
  File "/home/gta/miniconda3/envs/xxx/lib/python3.7/site-packages/torch/nn/modules/module.py", line 896, in convert
    non_blocking, memory_format=convert_to_format)
RuntimeError: required rank 4 tensor to use channels_last format
```
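The traceback ends in the `convert` helper of `Module.to`, which forwards the requested memory format to each parameter's `Tensor.to` call. The same failure can therefore be reproduced on the weight tensors alone. A minimal sketch, assuming a reasonably recent PyTorch build; the shapes are those of `layer1.weight` and `layer2.weight`, and `converts` is a hypothetical helper written only for this check. It also shows the mirror-image rank check for `torch.channels_last_3d`:

```python
import torch

# Weight shapes of the two layers in the reproduction script.
conv2d_weight = torch.randn(4, 2, 3, 3)      # nn.Conv2d(2, 4, 3) weight: rank 4
conv3d_weight = torch.randn(8, 4, 3, 3, 3)   # nn.Conv3d(4, 8, 3) weight: rank 5


def converts(t: torch.Tensor, fmt: torch.memory_format) -> bool:
    """Hypothetical helper: True if t.to(memory_format=fmt) succeeds,
    False if PyTorch raises a RuntimeError (e.g. the rank check)."""
    try:
        t.to(memory_format=fmt)
        return True
    except RuntimeError:
        return False


# channels_last only accepts rank-4 tensors, channels_last_3d only rank-5,
# so no single format works for both weights of the hybrid model.
for name, w in [("conv2d", conv2d_weight), ("conv3d", conv3d_weight)]:
    for fmt in (torch.channels_last, torch.channels_last_3d):
        print(name, fmt, converts(w, fmt))
```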

Even with `test_model = Hybrid_model().to(memory_format=torch.channels_last_3d)`, an error is still reported:

```
python test_conv2dconv3d_hybrid_model.py
Traceback (most recent call last):
  File "test_conv2dconv3d_hybrid_model.py", line 21, in <module>
    test_model = Hybrid_model().to(memory_format=torch.channels_last_3d)
  File "/home/gta/miniconda3/envs/xxx/lib/python3.7/site-packages/torch/nn/modules/module.py", line 899, in to
    return self._apply(convert)
  File "/home/gta/miniconda3/envs/xxx/lib/python3.7/site-packages/torch/nn/modules/module.py", line 570, in _apply
    module._apply(fn)
  File "/home/gta/miniconda3/envs/xxx/lib/python3.7/site-packages/torch/nn/modules/module.py", line 593, in _apply
    param_applied = fn(param)
  File "/home/gta/miniconda3/envs/xxx/lib/python3.7/site-packages/torch/nn/modules/module.py", line 896, in convert
    non_blocking, memory_format=convert_to_format)
RuntimeError: required rank 5 tensor to use channels_last_3d format
```
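One possible workaround, sketched here as an assumption rather than an official fix, is to skip the model-wide `.to(memory_format=...)` call and convert each convolution submodule with the format that matches its weight rank: `torch.channels_last` for `nn.Conv2d` and `torch.channels_last_3d` for `nn.Conv3d`. The helper name `to_channels_last_per_layer` is hypothetical, and the model class is copied from the reproduction above so the block runs on its own:

```python
import torch
import torch.nn as nn


class Hybrid_model(nn.Module):
    # Same architecture as the reproduction script in this issue.
    def __init__(self):
        super().__init__()
        self.layer1 = nn.Conv2d(2, 4, kernel_size=3, stride=1, padding=1, bias=False)
        self.layer2 = nn.Conv3d(4, 8, kernel_size=3, stride=1, padding=1, bias=False)

    def forward(self, inputs):
        x = self.layer1(inputs)
        x = torch.reshape(x, (x.size(0), x.size(1), x.size(2), 16, 2))
        return self.layer2(x)


def to_channels_last_per_layer(model: nn.Module) -> nn.Module:
    """Hypothetical helper: choose the channels-last variant per conv layer
    instead of applying one memory format to the whole model."""
    for module in model.modules():
        if isinstance(module, nn.Conv2d):
            # 4-D weight (out, in, kH, kW): channels_last requires rank 4.
            module.to(memory_format=torch.channels_last)
        elif isinstance(module, nn.Conv3d):
            # 5-D weight (out, in, kD, kH, kW): channels_last_3d requires rank 5.
            module.to(memory_format=torch.channels_last_3d)
    return model
```

Usage: `model = to_channels_last_per_layer(Hybrid_model())`. This only converts the weights; the 4-D input can additionally be converted with `x.to(memory_format=torch.channels_last)` before calling the model.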

Versions

Collecting environment information...
PyTorch version: 1.10.0a0+gitcb9f926
Is debug build: False
CUDA used to build PyTorch: None
ROCM used to build PyTorch: N/A

OS: Ubuntu 20.04.2 LTS (x86_64)
GCC version: (Ubuntu 9.4.0-1ubuntu1~20.04.1) 9.4.0
Clang version: Could not collect
CMake version: version 3.22.1
Libc version: glibc-2.17

Python version: 3.7.11 (default, Jul 27 2021, 14:32:16) [GCC 7.5.0] (64-bit runtime)
Python platform: Linux-5.10.54+prerelease2927-x86_64-with-debian-bullseye-sid
Is CUDA available: False
CUDA runtime version: No CUDA
GPU models and configuration: No CUDA
Nvidia driver version: No CUDA
cuDNN version: No CUDA
HIP runtime version: N/A
MIOpen runtime version: N/A
Is XNNPACK available: True

Versions of relevant libraries:
[pip3] numpy==1.21.2
[pip3] torch==1.10.0a0+gitcb9f926
[conda] mkl 2022.0.1 h06a4308_117 defaults
[conda] mkl-include 2022.0.1 h06a4308_117 defaults
[conda] numpy 1.21.2 py37hd8d4704_0 defaults
[conda] numpy-base 1.21.2 py37h2b8c604_0 defaults
[conda] torch 1.10.0a0+gitcb9f926 pypi_0 pypi

Labels: module: convolution, triaged