🐛 Describe the bug
🐛 Bug Description
PyTorch optimizers (e.g., AdamW, Adam) accept the betas parameter as any sequence-like object without type validation. When an OmegaConf ListConfig is passed, checkpoint serialization behaves unexpectedly because a ListConfig retains a reference to its entire parent configuration tree.
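A minimal sketch of why this happens. It relies on OmegaConf storing the parent reference in the node's __dict__ under _parent (the same field the inspection script below walks); this is an illustration of OmegaConf internals, not part of the reproduction:

from omegaconf import OmegaConf

cfg = OmegaConf.create({
    'model': {'betas': [0.9, 0.999]},
    'data': {'batch_size': 32},
})
betas = cfg.model.betas

print(type(betas).__name__)         # ListConfig, not a plain list
print('_parent' in betas.__dict__)  # True: the node keeps a back-reference into the
                                    # config tree, so pickling it drags the tree along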
To Reproduce
import torch
from torch.optim import AdamW
from omegaconf import OmegaConf
# Create a config with nested structure
cfg = OmegaConf.create({
    'model': {'betas': [0.9, 0.999]},
    'data': {'batch_size': 32},
    # ... other configurations
})
# Create optimizer with OmegaConf ListConfig (no error raised)
model = torch.nn.Linear(10, 1)
optimizer = AdamW(model.parameters(), lr=1e-3, betas=cfg.model.betas)
# Save checkpoint
checkpoint = {
    'optimizer_state_dict': optimizer.state_dict()
}
torch.save(checkpoint, 'checkpoint.pt')
# The entire config tree is serialized due to ListConfig._parent reference!
# This causes:
# 1. Unexpectedly large checkpoint files
# 2. PyTorch 2.6+ loading failures with weights_only=True
Expected behavior
The optimizer should validate that betas is a tuple or list, and raise a TypeError if another sequence-like object is passed:

if not isinstance(betas, (tuple, list)):
    raise TypeError(f"betas should be a tuple or list, got {type(betas)}")

Alternatively, convert the input to a tuple:

betas = tuple(betas)  # force conversion
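Until such validation exists, a user-side workaround (a sketch, assuming the betas entries in the config are plain numbers) is to resolve the ListConfig into an ordinary container before constructing the optimizer:

from omegaconf import OmegaConf

# Convert the ListConfig node into a plain Python list, dropping its
# back-reference into the config tree, then pass a regular tuple of floats.
betas = tuple(OmegaConf.to_container(cfg.model.betas, resolve=True))
optimizer = AdamW(model.parameters(), lr=1e-3, betas=betas)

With this, optimizer.state_dict() stores only the two float values and the checkpoint no longer embeds the configuration tree.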
Actual behavior
The script I used to inspect my checkpoint is below:
#!/usr/bin/env python3
"""
Script to inspect PyTorch checkpoint files and find all OmegaConf ListConfig objects.
"""
import sys
import torch
from pathlib import Path
from omegaconf import ListConfig, DictConfig
def inspect_value(obj, path="root", found_objects=None, visited=None):
"""
Recursively inspect an object and find all OmegaConf ListConfig instances.
Args:
obj: Object to inspect
path: Current path in the object tree
found_objects: List to store found ListConfig objects
visited: Set of visited object ids to avoid infinite recursion
"""
if found_objects is None:
found_objects = []
if visited is None:
visited = set()
# Avoid infinite recursion by checking if we've seen this object
obj_id = id(obj)
if obj_id in visited:
return found_objects
visited.add(obj_id)
# Check if current object is ListConfig
if isinstance(obj, ListConfig):
found_objects.append({
'path': path,
'type': 'ListConfig',
'value': obj,
'length': len(obj)
})
print(f"Found ListConfig at: {path}")
print(f" Length: {len(obj)}")
print(f" Content: {obj}")
print()
# Check if current object is DictConfig (also from OmegaConf)
if isinstance(obj, DictConfig):
found_objects.append({
'path': path,
'type': 'DictConfig',
'value': obj,
'keys': list(obj.keys())
})
print(f"Found DictConfig at: {path}")
print(f" Keys: {list(obj.keys())}")
print()
# Recursively check dictionaries
if isinstance(obj, dict):
for key, value in obj.items():
new_path = f"{path}['{key}']"
inspect_value(value, new_path, found_objects, visited)
# Recursively check lists and tuples
elif isinstance(obj, (list, tuple)):
for idx, item in enumerate(obj):
new_path = f"{path}[{idx}]"
inspect_value(item, new_path, found_objects, visited)
# Check object attributes if it has __dict__
elif hasattr(obj, '__dict__') and not isinstance(obj, (str, bytes, int, float, bool, type(None))):
try:
for key, value in obj.__dict__.items():
new_path = f"{path}.{key}"
inspect_value(value, new_path, found_objects, visited)
except:
pass
return found_objects
def main():
if len(sys.argv) < 2:
print("Usage: python inspect_pt_omegaconf.py <path_to_pt_file>")
print("Example: python inspect_pt_omegaconf.py checkpoints/step_5000.pt")
sys.exit(1)
pt_file = Path(sys.argv[1])
if not pt_file.exists():
print(f"Error: File not found: {pt_file}")
sys.exit(1)
print(f"Loading checkpoint: {pt_file}")
print("=" * 80)
# Load with weights_only=False to allow loading OmegaConf objects
try:
checkpoint = torch.load(pt_file, weights_only=False, map_location='cpu')
print(f"✓ Successfully loaded checkpoint")
print()
except Exception as e:
print(f"✗ Failed to load with weights_only=False: {e}")
print()
return
# Inspect the checkpoint
print("Inspecting checkpoint for OmegaConf objects...")
print("=" * 80)
print()
found_objects = inspect_value(checkpoint)
print("=" * 80)
print(f"Summary: Found {len(found_objects)} OmegaConf objects")
print()
if found_objects:
print("All OmegaConf objects found:")
for idx, obj_info in enumerate(found_objects, 1):
print(f"{idx}. Path: {obj_info['path']}")
print(f" Type: {obj_info['type']}")
if obj_info['type'] == 'ListConfig':
print(f" Length: {obj_info['length']}")
elif obj_info['type'] == 'DictConfig':
print(f" Keys: {obj_info['keys']}")
print()
else:
print("No OmegaConf ListConfig or DictConfig objects found in the checkpoint.")
# Also show top-level keys
print("=" * 80)
print("Top-level keys in checkpoint:")
if isinstance(checkpoint, dict):
for key in checkpoint.keys():
value = checkpoint[key]
value_type = type(value).__name__
if isinstance(value, dict):
print(f" - {key}: dict with {len(value)} keys")
elif isinstance(value, (list, tuple)):
print(f" - {key}: {value_type} with {len(value)} items")
elif isinstance(value, (ListConfig, DictConfig)):
print(f" - {key}: {value_type} ← OmegaConf object!")
else:
print(f" - {key}: {value_type}")
if __name__ == "__main__":
main()Found DictConfig at: root['optimizer_state_dict']['param_groups'][0]['betas']._parent._content['model']._content['shape_meta']._content['images']._content[0]
Keys: ['key', 'raw_shape', 'shape']
Found ListConfig at: root['optimizer_state_dict']['param_groups'][0]['betas']._parent._content['model']._content['shape_meta']._content['images']._content[0]._content['raw_shape']
Length: 3
Content: [3, 720, 1280]
Found ListConfig at: root['optimizer_state_dict']['param_groups'][0]['betas']._parent._content['model']._content['shape_meta']._content['images']._content[0]._content['shape']
Length: 3
Content: [3, 224, 224]
Found DictConfig at: root['optimizer_state_dict']['param_groups'][0]['betas']._parent._content['model']._content['shape_meta']._content['images']._content[1]
Keys: ['key', 'raw_shape', 'shape']
Found ListConfig at: root['optimizer_state_dict']['param_groups'][0]['betas']._parent._content['model']._content['shape_meta']._content['images']._content[1]._content['raw_shape']
Length: 3
Content: [3, 360, 640]
Found ListConfig at: root['optimizer_state_dict']['param_groups'][0]['betas']._parent._content['model']._content['shape_meta']._content['images']._content[1]._content['shape']
Length: 3
Content: [3, 224, 224]
Found DictConfig at: root['optimizer_state_dict']['param_groups'][0]['betas']._parent._content['model']._content['shape_meta']._content['images']._content[2]
Keys: ['key', 'raw_shape', 'shape']
Found ListConfig at: root['optimizer_state_dict']['param_groups'][0]['betas']._parent._content['model']._content['shape_meta']._content['images']._content[2]._content['raw_shape']
Length: 3
Content: [3, 360, 640]
Found ListConfig at: root['optimizer_state_dict']['param_groups'][0]['betas']._parent._content['model']._content['shape_meta']._content['images']._content[2]._content['shape']
Length: 3
Content: [3, 224, 224]
- No error or warning is raised when passing OmegaConf ListConfig objects
- The entire configuration tree gets serialized into optimizer.state_dict()
- Checkpoint loading fails in PyTorch 2.6+ with weights_only=True
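For reference, a sketch of the failing load in PyTorch 2.6+ (the exact error text varies by version; the point is that the weights_only unpickler rejects the OmegaConf classes embedded in the optimizer state):

import torch

try:
    torch.load('checkpoint.pt', weights_only=True, map_location='cpu')
except Exception as e:
    # Raises a "Weights only load failed"-style UnpicklingError because
    # omegaconf.listconfig.ListConfig is not an allowed global.
    print(type(e).__name__, e)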
Versions
PyTorch version: 2.7.1+cu128
Is debug build: False
CUDA used to build PyTorch: 12.8
ROCM used to build PyTorch: N/A
OS: Ubuntu 20.04.6 LTS (x86_64)
GCC version: (Ubuntu 9.4.0-1ubuntu1~20.04.3) 9.4.0
Clang version: Could not collect
CMake version: version 3.16.3
Libc version: glibc-2.31
Python version: 3.10.16 (main, Mar 17 2025, 21:01:46) [Clang 20.1.0 ] (64-bit runtime)
Python platform: Linux-5.4.250-9-velinux1u2-amd64-x86_64-with-glibc2.31
Is CUDA available: True
CUDA runtime version: 12.4.131
CUDA_MODULE_LOADING set to: LAZY
GPU models and configuration:
GPU 0: NVIDIA H20
Nvidia driver version: 535.161.08
cuDNN version: Probably one of the following:
/usr/lib/x86_64-linux-gnu/libcudnn.so.9.1.0
/usr/lib/x86_64-linux-gnu/libcudnn_adv.so.9.1.0
/usr/lib/x86_64-linux-gnu/libcudnn_cnn.so.9.1.0
/usr/lib/x86_64-linux-gnu/libcudnn_engines_precompiled.so.9.1.0
/usr/lib/x86_64-linux-gnu/libcudnn_engines_runtime_compiled.so.9.1.0
/usr/lib/x86_64-linux-gnu/libcudnn_graph.so.9.1.0
/usr/lib/x86_64-linux-gnu/libcudnn_heuristic.so.9.1.0
/usr/lib/x86_64-linux-gnu/libcudnn_ops.so.9.1.0
Is XPU available: False
HIP runtime version: N/A
MIOpen runtime version: N/A
Is XNNPACK available: True
CPU:
Architecture: x86_64
CPU op-mode(s): 32-bit, 64-bit
Byte Order: Little Endian
Address sizes: 52 bits physical, 57 bits virtual
CPU(s): 180
On-line CPU(s) list: 0-168
Off-line CPU(s) list: 169-179
Thread(s) per core: 1
Core(s) per socket: 45
Socket(s): 2
NUMA node(s): 2
Vendor ID: GenuineIntel
CPU family: 6
Model: 143
Model name: Intel(R) Xeon(R) Platinum 8457C
Stepping: 8
CPU MHz: 2600.000
BogoMIPS: 5200.00
Hypervisor vendor: KVM
Virtualization type: full
L1d cache: 2.1 MiB
L1i cache: 1.4 MiB
L2 cache: 90 MiB
L3 cache: 97.5 MiB
NUMA node0 CPU(s): 0-89
NUMA node1 CPU(s): 90-179
Vulnerability Itlb multihit: Not affected
Vulnerability L1tf: Not affected
Vulnerability Mds: Not affected
Vulnerability Meltdown: Not affected
Vulnerability Mmio stale data: Unknown: No mitigations
Vulnerability Retbleed: Not affected
Vulnerability Spec store bypass: Vulnerable
Vulnerability Spectre v1: Mitigation; usercopy/swapgs barriers and __user pointer sanitization
Vulnerability Spectre v2: Vulnerable, IBPB: disabled, STIBP: disabled, PBRSB-eIBRS: Vulnerable
Vulnerability Srbds: Not affected
Vulnerability Tsx async abort: Mitigation; TSX disabled
Flags: fpu vme de pse tsc msr pae mce cx8 apic sep mtrr pge mca cmov pat pse36 clflush mmx fxsr sse sse2 ss ht syscall nx pdpe1gb rdtscp lm constant_tsc rep_good nopl xtopology nonstop_tsc cpuid tsc_known_freq pni pclmulqdq monitor ssse3 fma cx16 pcid sse4_1 sse4_2 x2apic movbe popcnt tsc_deadline_timer aes xsave avx f16c rdrand hypervisor lahf_lm abm 3dnowprefetch cpuid_fault invpcid_single ssbd ibrs ibpb stibp ibrs_enhanced fsgsbase tsc_adjust bmi1 avx2 smep bmi2 erms invpcid avx512f avx512dq rdseed adx smap avx512ifma clflushopt clwb avx512cd sha_ni avx512bw avx512vl xsaveopt xsavec xgetbv1 xsaves avx512_bf16 wbnoinvd arat avx512vbmi umip pku ospke waitpkg avx512_vbmi2 gfni vaes vpclmulqdq avx512_vnni avx512_bitalg avx512_vpopcntdq rdpid cldemote movdiri movdir64b md_clear arch_capabilities
Versions of relevant libraries:
[pip3] Could not collect
[conda] libopenvino-pytorch-frontend 2025.2.0 hecca717_1 conda-forge
[conda] numpy 2.3.4 pypi_0 pypi
[conda] nvidia-cublas-cu12 12.8.4.1 pypi_0 pypi
[conda] nvidia-cuda-cupti-cu12 12.8.90 pypi_0 pypi
[conda] nvidia-cuda-nvrtc-cu12 12.8.93 pypi_0 pypi
[conda] nvidia-cuda-runtime-cu12 12.8.90 pypi_0 pypi
[conda] nvidia-cudnn-cu12 9.10.2.21 pypi_0 pypi
[conda] nvidia-cufft-cu12 11.3.3.83 pypi_0 pypi
[conda] nvidia-curand-cu12 10.3.9.90 pypi_0 pypi
[conda] nvidia-cusolver-cu12 11.7.3.90 pypi_0 pypi
[conda] nvidia-cusparse-cu12 12.5.8.93 pypi_0 pypi
[conda] nvidia-cusparselt-cu12 0.7.1 pypi_0 pypi
[conda] nvidia-nccl-cu12 2.27.5 pypi_0 pypi
[conda] nvidia-nvjitlink-cu12 12.8.93 pypi_0 pypi
[conda] nvidia-nvtx-cu12 12.8.90 pypi_0 pypi
[conda] tbb 2022.2.0 hb60516a_1 conda-forge
[conda] torch 2.9.0 pypi_0 pypi
[conda] triton 3.5.0 pypi_0 pypi