
torch.export fails with KeyError when BatchNorm.running_mean is read and modified, even when shape/value is unchanged #156167

@tinywisdom

Description

🐛 Describe the bug

Summary

When using torch.export to export a model that reads and modifies an internal buffer during forward() (e.g., via BatchNorm.running_mean.zero_()), the export fails with:

KeyError: 'L__self___bn.running_mean'

However, if the line output = self.bn(t1) is commented out, the export succeeds, even though the same parameter and buffer accesses (including the in-place zero_()) are still present; a sketch of this variant follows the error output below.

This suggests that:

  1. Modifying running_mean during forward() is not explicitly disallowed, yet it triggers an export failure deep inside the exporter

  2. The error message is confusing and unhelpful

  3. State tracking of internal buffer mutations inside forward() may be incomplete or inconsistent

Repro Steps (Minimal Script)

import torch
import torch.nn as nn
from typing import List

class MyModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.conv = nn.Conv2d(3, 8, 1, stride=1, padding=1)
        self.bn = nn.BatchNorm2d(8)
        self.bn.eval()
        self.init_components: List[nn.Module] = [self.conv, self.bn]

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        t1 = self.conv(x)
        print("Before BN:")
        for name, buf in self.bn.named_buffers():
            print(f"{name}: {buf.shape}, {buf}")

        output = self.bn(t1)  # ← Commenting this out avoids the KeyError

        print("After BN:")
        for name, buf in self.bn.named_buffers():
            print(f"{name}: {buf.shape}, {buf}")
        
        # Modifying buffer in-place (no shape change)
        self.bn.running_mean.zero_()
        return output

model = MyModel()
input_tensor = torch.randn(1, 3, 32, 32)

model.eval()
with torch.no_grad():
    output = model(input_tensor)

# This line triggers KeyError
torch.export.export(model, (input_tensor,))

Fails with:

KeyError: 'L__self___bn.running_mean'
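
For comparison, here is a sketch of the variant described above: only the self.bn(t1) call is removed, and t1 is returned instead so the script still runs; the buffer reads via named_buffers() and the in-place zero_() are unchanged. Per the report, this variant reportedly exports without the KeyError.

import torch
import torch.nn as nn

class MyModelNoBNCall(nn.Module):
    def __init__(self):
        super().__init__()
        self.conv = nn.Conv2d(3, 8, 1, stride=1, padding=1)
        self.bn = nn.BatchNorm2d(8)
        self.bn.eval()

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        t1 = self.conv(x)
        for name, buf in self.bn.named_buffers():
            print(f"{name}: {buf.shape}, {buf}")  # same buffer reads as before

        # output = self.bn(t1)  # ← the only change: the BatchNorm call is removed

        self.bn.running_mean.zero_()  # same in-place mutation as before
        return t1                     # return the conv output instead

model = MyModelNoBNCall().eval()
input_tensor = torch.randn(1, 3, 32, 32)
# Per the report above, this export completes without the KeyError:
torch.export.export(model, (input_tensor,))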

Technical Observations

  • The buffer exists and is printed both before and after the call to self.bn(t1)

  • Export seems to track the buffer name 'L__self___bn.running_mean', but fails to find it during _rewrite_dynamo_tensor_constants

  • Commenting out output = self.bn(t1) prevents the crash even though the same buffer is still accessed, suggesting that resolution of traced buffer paths is not robust

  • This implies that export's symbolic tracing is sensitive to usage context and may not correctly handle buffer side effects that occur later in forward(), after the sub-module call; for reference, a sketch of how export normally records buffer mutations follows this list
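
For reference, the sketch below uses a hypothetical Counter module (not from the report) to show that torch.export can normally handle an in-place buffer mutation in forward() and records it in the exported program's graph_signature.buffers_to_mutate (assuming a recent torch.export release). That makes the failure above look like a bookkeeping gap in this particular code path rather than a fundamentally unsupported pattern.

import torch
import torch.nn as nn

class Counter(nn.Module):
    """Hypothetical module: mutates a registered buffer in-place each call."""
    def __init__(self):
        super().__init__()
        self.register_buffer("count", torch.zeros(1))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        self.count.add_(1)  # in-place buffer mutation, no shape change
        return x + self.count

ep = torch.export.export(Counter(), (torch.randn(2),))
# Maps a graph output name to the fully-qualified name of the mutated buffer,
# e.g. something like {'add': 'count'} (exact node names may differ).
print(ep.graph_signature.buffers_to_mutate)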

Expected Behavior

  • If modifying .running_mean in forward() is not allowed during export, a clear and early error message should be provided.
  • If PyTorch allows it (e.g. because shape/type doesn’t change), then export should succeed or handle it gracefully.
  • The buffer is not dynamically created and its value is only zeroed out; there is no structural change, so the failure is unexpected. A possible workaround sketch is included below.
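
As a possible interim workaround (a sketch, under the assumption that the reset does not need to be captured inside the exported graph), the buffer reset can be moved out of forward() and performed eagerly before export, so the traced program contains no in-place mutation of running_mean:

import torch
import torch.nn as nn

class MyModelNoMutation(nn.Module):
    def __init__(self):
        super().__init__()
        self.conv = nn.Conv2d(3, 8, 1, stride=1, padding=1)
        self.bn = nn.BatchNorm2d(8)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # No buffer mutation here; running_mean is reset eagerly before export.
        return self.bn(self.conv(x))

model = MyModelNoMutation().eval()
model.bn.running_mean.zero_()  # reset outside the traced graph
ep = torch.export.export(model, (torch.randn(1, 3, 32, 32),))

This sidesteps the issue rather than fixing it; if the reset genuinely has to happen inside forward(), export still needs to support the mutation.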

Versions

PyTorch version: 2.7.1a0+gite2d141d
Is debug build: True
CUDA used to build PyTorch: 12.6
ROCM used to build PyTorch: N/A

OS: Ubuntu 22.04.4 LTS (x86_64)
GCC version: (Ubuntu 12.3.0-1ubuntu1~22.04) 12.3.0
Clang version: Could not collect
CMake version: version 3.22.1
Libc version: glibc-2.35

Python version: 3.10.16 (main, Dec 11 2024, 16:24:50) [GCC 11.2.0] (64-bit runtime)
Python platform: Linux-6.8.0-59-generic-x86_64-with-glibc2.35
Is CUDA available: True

cc @chauhang @penguinwu @avikchaudhuri @gmagogsfm @zhxchen17 @tugsbayasgalan @angelayi @suo @ydwu4

Labels

export-triaged, oncall: export, oncall: pt2, triaged
