
Pytorch -> ExecuTorch conversion fails with BatchNorm2d layer #126291

Open

dino-illumix opened this issue May 15, 2024 · 1 comment


dino-illumix commented May 15, 2024

🐛 Describe the bug

When attempting to export any PyTorch model containing a BatchNorm2d layer with ExecuTorch, the following error is encountered:

torch._export.verifier.SpecViolationError: Operator torch._ops.aten._native_batch_norm_legit_functional.default is not Aten Canonical.

I am running the following code:

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch._export import capture_pre_autograd_graph
from torch.export import export, ExportedProgram

from executorch.exir import ExecutorchBackendConfig, ExecutorchProgramManager
from executorch import exir

class Model(nn.Module):
    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(1, 20, 5)
        self.norm = nn.BatchNorm2d(20)
        self.conv2 = nn.Conv2d(20, 20, 5)

    def forward(self, x):
        x = F.relu(self.conv1(x))
        x = self.norm(x)
        return F.relu(self.conv2(x))

example_args = (torch.randn(1, 1, 224, 224),)

# Enable eager execution
torch.set_grad_enabled(False)

pre_autograd_aten_dialect = capture_pre_autograd_graph(Model(), example_args)

# Optionally do quantization:
# pre_autograd_aten_dialect = convert_pt2e(prepare_pt2e(pre_autograd_aten_dialect, CustomBackendQuantizer))

aten_dialect: ExportedProgram = export(pre_autograd_aten_dialect, example_args)

print(aten_dialect)

edge_program: exir.EdgeProgramManager = exir.to_edge(aten_dialect)

# Optionally do delegation:
# edge_program = edge_program.to_backend(CustomBackendPartitioner)

executorch_program: exir.ExecutorchProgramManager = edge_program.to_executorch(
    ExecutorchBackendConfig(
        passes=[],  # User-defined passes
    )
)

with open("model.pte", "wb") as file:
    file.write(executorch_program.buffer)
    print("Model saved to model.pte")

which produces the following output:

ExportedProgram:
class GraphModule(torch.nn.Module):
    def forward(self, arg0_1: "f32[20, 1, 5, 5]", arg1_1: "f32[20]", arg2_1: "f32[20]", arg3_1: "f32[20]", arg4_1: "f32[20, 20, 5, 5]", arg5_1: "f32[20]", arg6_1: "i64[]", arg7_1: "f32[20]", arg8_1: "f32[20]", arg9_1: "f32[1, 1, 224, 224]"):
        # File: /Users/dino-illumix/Documents/Projects/LightingPredict/executorch_export.py:19 in forward, code: x = F.relu(self.conv1(x))
        convolution: "f32[1, 20, 220, 220]" = torch.ops.aten.convolution.default(arg9_1, arg0_1, arg1_1, [1, 1], [0, 0], [1, 1], False, [0, 0], 1); arg9_1 = arg0_1 = arg1_1 = None
        relu: "f32[1, 20, 220, 220]" = torch.ops.aten.relu.default(convolution); convolution = None

        # File: /Users/dino-illumix/Documents/Projects/LightingPredict/executorch_export.py:20 in forward, code: x = self.norm(x)
        add: "i64[]" = torch.ops.aten.add.Tensor(arg6_1, 1);  arg6_1 = None
        _native_batch_norm_legit_functional = torch.ops.aten._native_batch_norm_legit_functional.default(relu, arg2_1, arg3_1, arg7_1, arg8_1, True, 0.1, 1e-05);  relu = arg2_1 = arg3_1 = arg7_1 = arg8_1 = None
        getitem: "f32[1, 20, 220, 220]" = _native_batch_norm_legit_functional[0]
        getitem_3: "f32[20]" = _native_batch_norm_legit_functional[3]
        getitem_4: "f32[20]" = _native_batch_norm_legit_functional[4];  _native_batch_norm_legit_functional = None
        
        # File: /Users/dino-illumix/Documents/Projects/LightingPredict/executorch_export.py:21 in forward, code: return F.relu(self.conv2(x))
        convolution_1: "f32[1, 20, 216, 216]" = torch.ops.aten.convolution.default(getitem, arg4_1, arg5_1, [1, 1], [0, 0], [1, 1], False, [0, 0], 1);  getitem = arg4_1 = arg5_1 = None
        relu_1: "f32[1, 20, 216, 216]" = torch.ops.aten.relu.default(convolution_1);  convolution_1 = None
        return (add, getitem_3, getitem_4, relu_1)

Graph signature: ExportGraphSignature(input_specs=[InputSpec(kind=<InputKind.PARAMETER: 2>, arg=TensorArgument(name='arg0_1'), target='conv1_weight', persistent=None), InputSpec(kind=<InputKind.PARAMETER: 2>, arg=TensorArgument(name='arg1_1'), target='conv1_bias', persistent=None), InputSpec(kind=<InputKind.PARAMETER: 2>, arg=TensorArgument(name='arg2_1'), target='norm_weight', persistent=None), InputSpec(kind=<InputKind.PARAMETER: 2>, arg=TensorArgument(name='arg3_1'), target='norm_bias', persistent=None), InputSpec(kind=<InputKind.PARAMETER: 2>, arg=TensorArgument(name='arg4_1'), target='conv2_weight', persistent=None), InputSpec(kind=<InputKind.PARAMETER: 2>, arg=TensorArgument(name='arg5_1'), target='conv2_bias', persistent=None), InputSpec(kind=<InputKind.BUFFER: 3>, arg=TensorArgument(name='arg6_1'), target='norm_num_batches_tracked', persistent=True), InputSpec(kind=<InputKind.BUFFER: 3>, arg=TensorArgument(name='arg7_1'), target='norm_running_mean', persistent=True), InputSpec(kind=<InputKind.BUFFER: 3>, arg=TensorArgument(name='arg8_1'), target='norm_running_var', persistent=True), InputSpec(kind=<InputKind.USER_INPUT: 1>, arg=TensorArgument(name='arg9_1'), target=None, persistent=None)], output_specs=[OutputSpec(kind=<OutputKind.BUFFER_MUTATION: 3>, arg=TensorArgument(name='add'), target='norm_num_batches_tracked'), OutputSpec(kind=<OutputKind.BUFFER_MUTATION: 3>, arg=TensorArgument(name='getitem_3'), target='norm_running_mean'), OutputSpec(kind=<OutputKind.BUFFER_MUTATION: 3>, arg=TensorArgument(name='getitem_4'), target='norm_running_var'), OutputSpec(kind=<OutputKind.USER_OUTPUT: 1>, arg=TensorArgument(name='relu_1'), target=None)])
Range constraints: {}

Traceback (most recent call last):
  File "/Users/dino-illumix/Documents/Projects/LightingPredict/executorch_export.py", line 35, in <module>
    edge_program: exir.EdgeProgramManager = exir.to_edge(aten_dialect)
  File "/Users/dino-illumix/anaconda3/envs/executorch/lib/python3.10/site-packages/executorch/exir/program/_program.py", line 635, in to_edge
    EXIRATenDialectVerifier()(program.graph_module)
  File "/Users/dino-illumix/anaconda3/envs/executorch/lib/python3.10/site-packages/executorch/exir/verification/verifier.py", line 57, in __call__
    return self._check_graph_module(*args, **kwargs)
  File "/Users/dino-illumix/anaconda3/envs/executorch/lib/python3.10/site-packages/torch/_export/verifier.py", line 208, in _check_graph_module
    _check_valid_op(node.target)
  File "/Users/dino-illumix/anaconda3/envs/executorch/lib/python3.10/site-packages/torch/_export/verifier.py", line 191, in _check_valid_op
    self.check_valid_op(op)
  File "/Users/dino-illumix/anaconda3/envs/executorch/lib/python3.10/site-packages/executorch/exir/verification/verifier.py", line 87, in check_valid_op
    raise SpecViolationError(
torch._export.verifier.SpecViolationError: Operator torch._ops.aten._native_batch_norm_legit_functional.default is not Aten Canonical.

OS: macOS 14.2.1
torch==2.3.0
executorch==0.2.0
python==3.10.14

I can suppress the error by running:

edge_program: EdgeProgramManager = to_edge(aten_dialect, compile_config=EdgeCompileConfig(_check_ir_validity=False))

however, I then get a further error:

RuntimeError: Missing out variants: {'aten::_native_batch_norm_legit_functional'}
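
As far as I can tell (my own reading, not from the docs), both errors come from the graph being captured in training mode: aten._native_batch_norm_legit_functional is the training-mode batch-norm op, it is not in the ATen-canonical set the verifier checks, and the runtime apparently has no out-variant kernel for it either. A quick, untested sketch to check which batch-norm op actually ends up in the exported graph (reusing the Model class above):

import torch
from torch.export import export

# Untested sketch: inspect which batch-norm op an eval-mode export contains.
mod = Model().eval()
example_args = (torch.randn(1, 1, 224, 224),)

ep = export(mod, example_args)
bn_ops = sorted(
    str(node.target)
    for node in ep.graph.nodes
    if node.op == "call_function" and "batch_norm" in str(node.target)
)
print(bn_ops)
# In eval mode I would expect aten._native_batch_norm_legit_no_training.default,
# which (as far as I can tell) is ATen-canonical, instead of the
# _native_batch_norm_legit_functional.default op from the training-mode trace above.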

Versions

Collecting environment information...
PyTorch version: 2.3.0
Is debug build: False
CUDA used to build PyTorch: None
ROCM used to build PyTorch: N/A

OS: macOS 14.2.1 (arm64)
GCC version: Could not collect
Clang version: 15.0.0 (clang-1500.0.40.1)
CMake version: version 3.29.3
Libc version: N/A

Python version: 3.10.14 (main, May 6 2024, 14:42:37) [Clang 14.0.6 ] (64-bit runtime)
Python platform: macOS-14.2.1-arm64-arm-64bit
Is CUDA available: False
CUDA runtime version: No CUDA
CUDA_MODULE_LOADING set to: N/A
GPU models and configuration: No CUDA
Nvidia driver version: No CUDA
cuDNN version: No CUDA
HIP runtime version: N/A
MIOpen runtime version: N/A
Is XNNPACK available: True

CPU:
Apple M2 Pro

Versions of relevant libraries:
[pip3] executorch==0.2.0
[pip3] numpy==1.26.4
[pip3] onnx==1.16.0
[pip3] torch==2.3.0
[pip3] torchaudio==2.3.0
[pip3] torchsr==1.0.4
[pip3] torchvision==0.18.0
[conda] executorch 0.2.0 pypi_0 pypi
[conda] numpy 1.26.4 pypi_0 pypi
[conda] torch 2.3.0 pypi_0 pypi
[conda] torchaudio 2.3.0 pypi_0 pypi
[conda] torchsr 1.0.4 pypi_0 pypi
[conda] torchvision 0.18.0 pypi_0 pypi

cc @avikchaudhuri @gmagogsfm @zhxchen17 @tugsbayasgalan @angelayi @suo @ydwu4

@angelayi
Contributor

Can you try adding this before you export:

mod = Model()
mod.eval()
torch.export.export(mod, example_args)

By the way, you don't need to call capture_pre_autograd_graph if you are not doing quantization.
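
Roughly, the non-quantized flow would then look like this (untested sketch, adapt names to your script; quantization and delegation stay optional and are omitted):

import torch
from executorch import exir
from executorch.exir import ExecutorchBackendConfig

# Untested sketch of the non-quantized flow: export in eval mode so BatchNorm
# is lowered to its inference variant, then go to edge / executorch dialects.
mod = Model()
mod.eval()  # key change: capture the inference graph, not the training graph
example_args = (torch.randn(1, 1, 224, 224),)

aten_dialect = torch.export.export(mod, example_args)
edge_program = exir.to_edge(aten_dialect)
executorch_program = edge_program.to_executorch(ExecutorchBackendConfig(passes=[]))

with open("model.pte", "wb") as file:
    file.write(executorch_program.buffer)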
