🐛 Describe the bug
Summary
When using `torch.export` to export a model that reads and modifies internal buffers (e.g., `BatchNorm.running_mean.zero_()`) during `forward()`, the process fails with:

```
KeyError: 'L__self___bn.running_mean'
```

However, if the line `output = self.bn(t1)` is commented out, the export proceeds normally, even though the exact same parameter and buffer accesses are present.
This suggests that:

- Modifying `running_mean` during `forward` is not explicitly disallowed, but silently triggers an export failure
- The error message is confusing and unhelpful
- State tracking of internal buffer mutations inside `forward` may be incomplete or inconsistent
Repro Steps (Minimal Script)
```python
import torch
import torch.nn as nn
from typing import List


class MyModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.conv = nn.Conv2d(3, 8, 1, stride=1, padding=1)
        self.bn = nn.BatchNorm2d(8)
        self.bn.eval()
        self.init_components: List[nn.Module] = [self.conv, self.bn]

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        t1 = self.conv(x)
        print("Before BN:")
        for name, buf in self.bn.named_buffers():
            print(f"{name}: {buf.shape}, {buf}")
        output = self.bn(t1)  # ← Commenting this out avoids the KeyError
        print("After BN:")
        for name, buf in self.bn.named_buffers():
            print(f"{name}: {buf.shape}, {buf}")
        # Modifying buffer in-place (no shape change)
        self.bn.running_mean.zero_()
        return output


model = MyModel()
input_tensor = torch.randn(1, 3, 32, 32)
model.eval()

with torch.no_grad():
    output = model(input_tensor)

# This line triggers KeyError
torch.export.export(model, (input_tensor,))
```
Fails with:

```
KeyError: 'L__self___bn.running_mean'
```
Technical Observations
- The buffer exists and is printed both before and after the call to `self.bn(t1)`
- Export tracks the buffer name `'L__self___bn.running_mean'`, but fails to find it during `_rewrite_dynamo_tensor_constants`
- Commenting out `output = self.bn(t1)` prevents the crash, even though the same buffer is still accessed, suggesting non-robust trace-path resolution
- This implies export's symbolic tracing is sensitive to usage context and may not detect side effects that occur after graph capture
Expected Behavior
- If modifying `.running_mean` in `forward()` is not allowed during export, a clear and early error message should be provided.
- If PyTorch allows it (e.g., because the shape/type doesn't change), then export should succeed or handle it gracefully.
- The buffer is not dynamically created, and its value is only zeroed out. There is no structural change, so the failure is unexpected.
Versions
Click to expand envs
PyTorch version: 2.7.1a0+gite2d141d
Is debug build: True
CUDA used to build PyTorch: 12.6
ROCM used to build PyTorch: N/A
OS: Ubuntu 22.04.4 LTS (x86_64)
GCC version: (Ubuntu 12.3.0-1ubuntu1~22.04) 12.3.0
Clang version: Could not collect
CMake version: version 3.22.1
Libc version: glibc-2.35
Python version: 3.10.16 (main, Dec 11 2024, 16:24:50) [GCC 11.2.0] (64-bit runtime)
Python platform: Linux-6.8.0-59-generic-x86_64-with-glibc2.35
Is CUDA available: True
cc @chauhang @penguinwu @avikchaudhuri @gmagogsfm @zhxchen17 @tugsbayasgalan @angelayi @suo @ydwu4