-
Notifications
You must be signed in to change notification settings - Fork 25.4k
Closed
Labels
module: mps (Related to Apple Metal Performance Shaders framework), oncall: export, oncall: pt2
Description
🐛 Describe the bug
import torch
from torch import nn
from torch.export import export
class ResidualBlock(nn.Module):
    """Residual block: conv-BN-ReLU-conv-BN on the main path, plus a skip path, then ReLU.

    Args:
        in_channels: Number of channels in the input tensor.
        out_channels: Number of channels produced by the block.

    The skip path is a 1x1 ``nn.Conv2d`` when ``in_channels != out_channels``
    (so the shortcut has the same channel count as the main path) and
    ``nn.Identity`` otherwise.
    """

    def __init__(self, in_channels, out_channels):
        super().__init__()
        # A 3x3 kernel with padding=1 (default stride 1) keeps the spatial size
        # unchanged, so the main path and the skip path can be added element-wise.
        self.conv1 = nn.Conv2d(in_channels, out_channels, 3, padding=1)
        self.bn1 = nn.BatchNorm2d(out_channels)
        self.relu = nn.ReLU()  # stateless, so it is safely reused twice in forward()
        self.conv2 = nn.Conv2d(out_channels, out_channels, 3, padding=1)
        self.bn2 = nn.BatchNorm2d(out_channels)
        self.skip = (
            nn.Conv2d(in_channels, out_channels, 1)
            if in_channels != out_channels
            else nn.Identity()
        )

    def forward(self, x):
        """Return ``relu(bn2(conv2(relu(bn1(conv1(x))))) + skip(x))``."""
        identity = self.skip(x)
        out = self.conv1(x)
        out = self.bn1(out)
        out = self.relu(out)
        out = self.conv2(out)
        out = self.bn2(out)
        out += identity  # in-place add of the shortcut onto the main-path output
        out = self.relu(out)
        return out
if __name__ == "__main__":
    # Build the module and the example input inside a torch.device("mps")
    # context so they are created on the MPS device by default.
    with torch.device("mps"):
        model = ResidualBlock(3, 64)
        # On torch 2.7.1 this call fails with "Placeholder storage has not been
        # allocated on MPS device!" (see the traceback below). export() lazily
        # imports torch._dynamo, and that import's module-level setup creates
        # tensors while the MPS default-device context is still active.
        exp = export(model, (torch.zeros(64, 3, 1, 1),))
        print(exp.graph)
Running this script fails with the following traceback:
Traceback (most recent call last):
File "<frozen runpy>", line 198, in _run_module_as_main
File "<frozen runpy>", line 88, in _run_code
File "__main__.py", line 44, in <module>
exp = export(model, (torch.zeros(64, 3, 1, 1),))
File ".venv/lib/python3.13/site-packages/torch/export/__init__.py", line 348, in export
from ._trace import _export
File ".venv/lib/python3.13/site-packages/torch/export/_trace.py", line 15, in <module>
import torch._dynamo
File ".venv/lib/python3.13/site-packages/torch/_dynamo/__init__.py", line 13, in <module>
from . import config, convert_frame, eval_frame, resume_execution
File ".venv/lib/python3.13/site-packages/torch/_dynamo/convert_frame.py", line 52, in <module>
from torch._dynamo.symbolic_convert import TensorifyState
File ".venv/lib/python3.13/site-packages/torch/_dynamo/symbolic_convert.py", line 57, in <module>
from . import (
...<6 lines>...
)
File ".venv/lib/python3.13/site-packages/torch/_dynamo/trace_rules.py", line 32, in <module>
from .variables import (
...<11 lines>...
)
File ".venv/lib/python3.13/site-packages/torch/_dynamo/variables/__init__.py", line 19, in <module>
from .base import VariableTracker
File ".venv/lib/python3.13/site-packages/torch/_dynamo/variables/base.py", line 581, in <module>
from . import builder
File ".venv/lib/python3.13/site-packages/torch/_dynamo/variables/builder.py", line 86, in <module>
from ..side_effects import SideEffects
File ".venv/lib/python3.13/site-packages/torch/_dynamo/side_effects.py", line 21, in <module>
from .codegen import PyCodegen
File ".venv/lib/python3.13/site-packages/torch/_dynamo/codegen.py", line 54, in <module>
from .variables.torch_function import TensorWithTFOverrideVariable
File ".venv/lib/python3.13/site-packages/torch/_dynamo/variables/torch_function.py", line 193, in <module>
populate_builtin_to_tensor_fn_map()
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~^^
File ".venv/lib/python3.13/site-packages/torch/_dynamo/variables/torch_function.py", line 187, in populate_builtin_to_tensor_fn_map
setup_fn(op)
~~~~~~~~^^^^
File ".venv/lib/python3.13/site-packages/torch/_dynamo/variables/torch_function.py", line 175, in <lambda>
lambda o: o(1, inp1),
~^^^^^^^^^
File ".venv/lib/python3.13/site-packages/torch/_tensor.py", line 38, in wrapped
return handle_torch_function(wrapped, args, *args, **kwargs)
File ".venv/lib/python3.13/site-packages/torch/overrides.py", line 1721, in handle_torch_function
result = mode.__torch_function__(public_api, types, args, kwargs)
File ".venv/lib/python3.13/site-packages/torch/_dynamo/variables/torch_function.py", line 152, in __torch_function__
return func(*args, **kwargs)
File ".venv/lib/python3.13/site-packages/torch/_tensor.py", line 38, in wrapped
return handle_torch_function(wrapped, args, *args, **kwargs)
File ".venv/lib/python3.13/site-packages/torch/overrides.py", line 1721, in handle_torch_function
result = mode.__torch_function__(public_api, types, args, kwargs)
File ".venv/lib/python3.13/site-packages/torch/utils/_device.py", line 104, in __torch_function__
return func(*args, **kwargs)
File ".venv/lib/python3.13/site-packages/torch/_tensor.py", line 39, in wrapped
return f(*args, **kwargs)
File ".venv/lib/python3.13/site-packages/torch/_tensor.py", line 1141, in __rfloordiv__
return torch.floor_divide(other, self)
~~~~~~~~~~~~~~~~~~^^^^^^^^^^^^^
RuntimeError: Placeholder storage has not been allocated on MPS device!
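As a workaround, importing `torch._dynamo` explicitly before calling `torch.export.export` avoids the error. As the traceback shows, `export` otherwise imports `torch._dynamo` lazily on first use. That import's module-level setup (`populate_builtin_to_tensor_fn_map`) then creates tensors while the `torch.device("mps")` context is still active: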
import torch
import torch._dynamo
from torch import nn
from torch.export import export
class ResidualBlock(nn.Module):
    """Residual block: conv-BN-ReLU-conv-BN on the main path, plus a skip path, then ReLU.

    This is the same definition as in the first snippet, repeated in full so that
    this workaround snippet is runnable on its own.

    Args:
        in_channels: Number of channels in the input tensor.
        out_channels: Number of channels produced by the block.
    """

    def __init__(self, in_channels, out_channels):
        super().__init__()
        # A 3x3 kernel with padding=1 (default stride 1) keeps the spatial size
        # unchanged, so the main path and the skip path can be added element-wise.
        self.conv1 = nn.Conv2d(in_channels, out_channels, 3, padding=1)
        self.bn1 = nn.BatchNorm2d(out_channels)
        self.relu = nn.ReLU()  # stateless, so it is safely reused twice in forward()
        self.conv2 = nn.Conv2d(out_channels, out_channels, 3, padding=1)
        self.bn2 = nn.BatchNorm2d(out_channels)
        # A 1x1 projection is needed only when the channel counts differ.
        self.skip = (
            nn.Conv2d(in_channels, out_channels, 1)
            if in_channels != out_channels
            else nn.Identity()
        )

    def forward(self, x):
        """Return ``relu(bn2(conv2(relu(bn1(conv1(x))))) + skip(x))``."""
        identity = self.skip(x)
        out = self.conv1(x)
        out = self.bn1(out)
        out = self.relu(out)
        out = self.conv2(out)
        out = self.bn2(out)
        out += identity  # in-place add of the shortcut onto the main-path output
        out = self.relu(out)
        return out
if __name__ == "__main__":
    with torch.device("mps"):
        model = ResidualBlock(3, 64)
        # torch._dynamo was already imported at the top of this file, outside
        # the torch.device("mps") context. With that import in place, the
        # reporter observes this export call succeeding.
        exp = export(model, (torch.zeros(64, 3, 1, 1),))
        print(exp.graph)  # prints fine
Versions
Collecting environment information...
PyTorch version: 2.7.1
Is debug build: False
CUDA used to build PyTorch: None
ROCM used to build PyTorch: N/A
OS: macOS 15.5 (arm64)
GCC version: Could not collect
Clang version: 20.1.6
CMake version: version 4.0.2
Libc version: N/A
Python version: 3.13.4 (main, Jun 3 2025, 15:34:24) [Clang 17.0.0 (clang-1700.0.13.3)] (64-bit runtime)
Python platform: macOS-15.5-arm64-arm-64bit-Mach-O
Is CUDA available: False
CUDA runtime version: No CUDA
CUDA_MODULE_LOADING set to: N/A
GPU models and configuration: No CUDA
Nvidia driver version: No CUDA
cuDNN version: No CUDA
Is XPU available: False
HIP runtime version: N/A
MIOpen runtime version: N/A
Is XNNPACK available: True
CPU:
Apple M3 Max
Versions of relevant libraries:
[pip3] fastrtc-moonshine-onnx==20241016
[pip3] numpy==2.2.0
[pip3] onnxruntime==1.22.1
[pip3] torch==2.7.1
[pip3] torchaudio==2.7.1
[pip3] torchvision==0.22.1
[conda] Could not collect
cc @kulinseth @albanD @malfet @DenisVieriu97 @jhavukainen @chauhang @penguinwu @avikchaudhuri @gmagogsfm @zhxchen17 @tugsbayasgalan @angelayi @suo @ydwu4
Metadata
Metadata
Assignees
Labels
module: mps (Related to Apple Metal Performance Shaders framework), oncall: export, oncall: pt2