
[BUG] Optimizer should validate betas parameter type to prevent unexpected serialization behavior with OmegaConf objects #167319

@PommesPeter

🐛 Describe the bug

PyTorch optimizers (e.g., AdamW, Adam) accept the betas parameter as any sequence-like object without type validation. When an OmegaConf ListConfig is passed, checkpoint serialization behaves unexpectedly: a ListConfig node keeps a reference to its parent config, so the entire configuration tree ends up in the optimizer state.
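
A minimal illustration of the mechanism (a sketch; the attribute layout is OmegaConf-internal, but the size difference shows the parent chain being pickled):

import pickle
from omegaconf import OmegaConf

cfg = OmegaConf.create({'model': {'betas': [0.9, 0.999]}, 'data': {'batch_size': 32}})
betas = cfg.model.betas  # an OmegaConf ListConfig node, not a plain list
print(type(betas))       # <class 'omegaconf.listconfig.ListConfig'>

# Pickling the node also pickles its parent chain (the whole cfg tree),
# so the payload is far larger than the equivalent plain list. torch.save
# uses pickle under the hood, hence the bloated checkpoints.
print(len(pickle.dumps(betas)), len(pickle.dumps([0.9, 0.999])))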

To Reproduce

import torch
from torch.optim import AdamW
from omegaconf import OmegaConf

# Create a config with nested structure
cfg = OmegaConf.create({
    'model': {'betas': [0.9, 0.999]},
    'data': {'batch_size': 32},
    # ... other configurations
})

# Create optimizer with OmegaConf ListConfig (no error raised)
model = torch.nn.Linear(10, 1)
optimizer = AdamW(model.parameters(), lr=1e-3, betas=cfg.model.betas)

# Save checkpoint
checkpoint = {
    'optimizer_state_dict': optimizer.state_dict()
}
torch.save(checkpoint, 'checkpoint.pt')

# The entire config tree is serialized due to ListConfig._parent reference!
# This causes:
# 1. Unexpectedly large checkpoint files
# 2. PyTorch 2.6+ loading failures with weights_only=True
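
To confirm that the ListConfig actually ends up in the optimizer state, the type can be checked directly (a quick check on top of the repro above):

# param_groups store hyperparameters by reference, so 'betas' is still an
# OmegaConf ListConfig rather than a plain tuple/list.
betas_in_state = optimizer.state_dict()['param_groups'][0]['betas']
print(type(betas_in_state))  # <class 'omegaconf.listconfig.ListConfig'>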

Expected behavior

The optimizer should validate that betas is a tuple or list, and raise a TypeError if other sequence-like objects are passed:

if not isinstance(betas, (tuple, list)):
    raise TypeError(f"betas should be a tuple or list, got {type(betas)}")

Alternatively, convert the input to a tuple:

betas = tuple(betas)  # Force conversion
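
Until such validation exists, a user-side workaround is to materialize the config value into a plain tuple before constructing the optimizer (a sketch reusing cfg and model from the repro above):

from omegaconf import OmegaConf

# Convert the ListConfig into a plain Python tuple so that no parent
# references leak into the optimizer's param_groups.
betas = tuple(OmegaConf.to_container(cfg.model.betas, resolve=True))
optimizer = AdamW(model.parameters(), lr=1e-3, betas=betas)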

Actual behavior

The script I used to inspect my checkpoint is shown below:

#!/usr/bin/env python3
"""
Script to inspect PyTorch checkpoint files and find all OmegaConf ListConfig objects.
"""

import sys
import torch
from pathlib import Path
from omegaconf import ListConfig, DictConfig


def inspect_value(obj, path="root", found_objects=None, visited=None):
    """
    Recursively inspect an object and find all OmegaConf ListConfig instances.
    
    Args:
        obj: Object to inspect
        path: Current path in the object tree
        found_objects: List to store found ListConfig objects
        visited: Set of visited object ids to avoid infinite recursion
    """
    if found_objects is None:
        found_objects = []
    if visited is None:
        visited = set()
    
    # Avoid infinite recursion by checking if we've seen this object
    obj_id = id(obj)
    if obj_id in visited:
        return found_objects
    visited.add(obj_id)
    
    # Check if current object is ListConfig
    if isinstance(obj, ListConfig):
        found_objects.append({
            'path': path,
            'type': 'ListConfig',
            'value': obj,
            'length': len(obj)
        })
        print(f"Found ListConfig at: {path}")
        print(f"  Length: {len(obj)}")
        print(f"  Content: {obj}")
        print()
    
    # Check if current object is DictConfig (also from OmegaConf)
    if isinstance(obj, DictConfig):
        found_objects.append({
            'path': path,
            'type': 'DictConfig',
            'value': obj,
            'keys': list(obj.keys())
        })
        print(f"Found DictConfig at: {path}")
        print(f"  Keys: {list(obj.keys())}")
        print()
    
    # Recursively check dictionaries
    if isinstance(obj, dict):
        for key, value in obj.items():
            new_path = f"{path}['{key}']"
            inspect_value(value, new_path, found_objects, visited)
    
    # Recursively check lists and tuples
    elif isinstance(obj, (list, tuple)):
        for idx, item in enumerate(obj):
            new_path = f"{path}[{idx}]"
            inspect_value(item, new_path, found_objects, visited)
    
    # Check object attributes if it has __dict__
    elif hasattr(obj, '__dict__') and not isinstance(obj, (str, bytes, int, float, bool, type(None))):
        try:
            for key, value in obj.__dict__.items():
                new_path = f"{path}.{key}"
                inspect_value(value, new_path, found_objects, visited)
        except Exception:
            pass
    
    return found_objects


def main():
    if len(sys.argv) < 2:
        print("Usage: python inspect_pt_omegaconf.py <path_to_pt_file>")
        print("Example: python inspect_pt_omegaconf.py checkpoints/step_5000.pt")
        sys.exit(1)
    
    pt_file = Path(sys.argv[1])
    
    if not pt_file.exists():
        print(f"Error: File not found: {pt_file}")
        sys.exit(1)
    
    print(f"Loading checkpoint: {pt_file}")
    print("=" * 80)
    
    # Load with weights_only=False to allow loading OmegaConf objects
    try:
        checkpoint = torch.load(pt_file, weights_only=False, map_location='cpu')
        print(f"✓ Successfully loaded checkpoint")
        print()
    except Exception as e:
        print(f"✗ Failed to load with weights_only=False: {e}")
        print()
        return
    
    # Inspect the checkpoint
    print("Inspecting checkpoint for OmegaConf objects...")
    print("=" * 80)
    print()
    
    found_objects = inspect_value(checkpoint)
    
    print("=" * 80)
    print(f"Summary: Found {len(found_objects)} OmegaConf objects")
    print()
    
    if found_objects:
        print("All OmegaConf objects found:")
        for idx, obj_info in enumerate(found_objects, 1):
            print(f"{idx}. Path: {obj_info['path']}")
            print(f"   Type: {obj_info['type']}")
            if obj_info['type'] == 'ListConfig':
                print(f"   Length: {obj_info['length']}")
            elif obj_info['type'] == 'DictConfig':
                print(f"   Keys: {obj_info['keys']}")
            print()
    else:
        print("No OmegaConf ListConfig or DictConfig objects found in the checkpoint.")
    
    # Also show top-level keys
    print("=" * 80)
    print("Top-level keys in checkpoint:")
    if isinstance(checkpoint, dict):
        for key in checkpoint.keys():
            value = checkpoint[key]
            value_type = type(value).__name__
            if isinstance(value, dict):
                print(f"  - {key}: dict with {len(value)} keys")
            elif isinstance(value, (list, tuple)):
                print(f"  - {key}: {value_type} with {len(value)} items")
            elif isinstance(value, (ListConfig, DictConfig)):
                print(f"  - {key}: {value_type} ← OmegaConf object!")
            else:
                print(f"  - {key}: {value_type}")


if __name__ == "__main__":
    main()

Running this inspection script on my checkpoint produces the following (excerpt):

Found DictConfig at: root['optimizer_state_dict']['param_groups'][0]['betas']._parent._content['model']._content['shape_meta']._content['images']._content[0]
  Keys: ['key', 'raw_shape', 'shape']

Found ListConfig at: root['optimizer_state_dict']['param_groups'][0]['betas']._parent._content['model']._content['shape_meta']._content['images']._content[0]._content['raw_shape']
  Length: 3
  Content: [3, 720, 1280]

Found ListConfig at: root['optimizer_state_dict']['param_groups'][0]['betas']._parent._content['model']._content['shape_meta']._content['images']._content[0]._content['shape']
  Length: 3
  Content: [3, 224, 224]

Found DictConfig at: root['optimizer_state_dict']['param_groups'][0]['betas']._parent._content['model']._content['shape_meta']._content['images']._content[1]
  Keys: ['key', 'raw_shape', 'shape']

Found ListConfig at: root['optimizer_state_dict']['param_groups'][0]['betas']._parent._content['model']._content['shape_meta']._content['images']._content[1]._content['raw_shape']
  Length: 3
  Content: [3, 360, 640]

Found ListConfig at: root['optimizer_state_dict']['param_groups'][0]['betas']._parent._content['model']._content['shape_meta']._content['images']._content[1]._content['shape']
  Length: 3
  Content: [3, 224, 224]

Found DictConfig at: root['optimizer_state_dict']['param_groups'][0]['betas']._parent._content['model']._content['shape_meta']._content['images']._content[2]
  Keys: ['key', 'raw_shape', 'shape']

Found ListConfig at: root['optimizer_state_dict']['param_groups'][0]['betas']._parent._content['model']._content['shape_meta']._content['images']._content[2]._content['raw_shape']
  Length: 3
  Content: [3, 360, 640]

Found ListConfig at: root['optimizer_state_dict']['param_groups'][0]['betas']._parent._content['model']._content['shape_meta']._content['images']._content[2]._content['shape']
  Length: 3
  Content: [3, 224, 224]

  • No error or warning is raised when an OmegaConf ListConfig object is passed
  • The entire configuration tree gets serialized into optimizer.state_dict()
  • Checkpoint loading fails in PyTorch 2.6+ with weights_only=True (see the minimal load call below)
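
For reference, a minimal load call that triggers the failure (sketch; the traceback is omitted):

# PyTorch >= 2.6 defaults torch.load to weights_only=True, and the restricted
# unpickler rejects the omegaconf classes embedded in optimizer_state_dict.
checkpoint = torch.load('checkpoint.pt', weights_only=True)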

Versions

PyTorch version: 2.7.1+cu128
Is debug build: False
CUDA used to build PyTorch: 12.8
ROCM used to build PyTorch: N/A

OS: Ubuntu 20.04.6 LTS (x86_64)
GCC version: (Ubuntu 9.4.0-1ubuntu1~20.04.3) 9.4.0
Clang version: Could not collect
CMake version: version 3.16.3
Libc version: glibc-2.31

Python version: 3.10.16 (main, Mar 17 2025, 21:01:46) [Clang 20.1.0 ] (64-bit runtime)
Python platform: Linux-5.4.250-9-velinux1u2-amd64-x86_64-with-glibc2.31
Is CUDA available: True
CUDA runtime version: 12.4.131
CUDA_MODULE_LOADING set to: LAZY
GPU models and configuration: 
GPU 0: NVIDIA H20

Nvidia driver version: 535.161.08
cuDNN version: Probably one of the following:
/usr/lib/x86_64-linux-gnu/libcudnn.so.9.1.0
/usr/lib/x86_64-linux-gnu/libcudnn_adv.so.9.1.0
/usr/lib/x86_64-linux-gnu/libcudnn_cnn.so.9.1.0
/usr/lib/x86_64-linux-gnu/libcudnn_engines_precompiled.so.9.1.0
/usr/lib/x86_64-linux-gnu/libcudnn_engines_runtime_compiled.so.9.1.0
/usr/lib/x86_64-linux-gnu/libcudnn_graph.so.9.1.0
/usr/lib/x86_64-linux-gnu/libcudnn_heuristic.so.9.1.0
/usr/lib/x86_64-linux-gnu/libcudnn_ops.so.9.1.0
Is XPU available: False
HIP runtime version: N/A
MIOpen runtime version: N/A
Is XNNPACK available: True

CPU:
Architecture:                    x86_64
CPU op-mode(s):                  32-bit, 64-bit
Byte Order:                      Little Endian
Address sizes:                   52 bits physical, 57 bits virtual
CPU(s):                          180
On-line CPU(s) list:             0-168
Off-line CPU(s) list:            169-179
Thread(s) per core:              1
Core(s) per socket:              45
Socket(s):                       2
NUMA node(s):                    2
Vendor ID:                       GenuineIntel
CPU family:                      6
Model:                           143
Model name:                      Intel(R) Xeon(R) Platinum 8457C
Stepping:                        8
CPU MHz:                         2600.000
BogoMIPS:                        5200.00
Hypervisor vendor:               KVM
Virtualization type:             full
L1d cache:                       2.1 MiB
L1i cache:                       1.4 MiB
L2 cache:                        90 MiB
L3 cache:                        97.5 MiB
NUMA node0 CPU(s):               0-89
NUMA node1 CPU(s):               90-179
Vulnerability Itlb multihit:     Not affected
Vulnerability L1tf:              Not affected
Vulnerability Mds:               Not affected
Vulnerability Meltdown:          Not affected
Vulnerability Mmio stale data:   Unknown: No mitigations
Vulnerability Retbleed:          Not affected
Vulnerability Spec store bypass: Vulnerable
Vulnerability Spectre v1:        Mitigation; usercopy/swapgs barriers and __user pointer sanitization
Vulnerability Spectre v2:        Vulnerable, IBPB: disabled, STIBP: disabled, PBRSB-eIBRS: Vulnerable
Vulnerability Srbds:             Not affected
Vulnerability Tsx async abort:   Mitigation; TSX disabled
Flags:                           fpu vme de pse tsc msr pae mce cx8 apic sep mtrr pge mca cmov pat pse36 clflush mmx fxsr sse sse2 ss ht syscall nx pdpe1gb rdtscp lm constant_tsc rep_good nopl xtopology nonstop_tsc cpuid tsc_known_freq pni pclmulqdq monitor ssse3 fma cx16 pcid sse4_1 sse4_2 x2apic movbe popcnt tsc_deadline_timer aes xsave avx f16c rdrand hypervisor lahf_lm abm 3dnowprefetch cpuid_fault invpcid_single ssbd ibrs ibpb stibp ibrs_enhanced fsgsbase tsc_adjust bmi1 avx2 smep bmi2 erms invpcid avx512f avx512dq rdseed adx smap avx512ifma clflushopt clwb avx512cd sha_ni avx512bw avx512vl xsaveopt xsavec xgetbv1 xsaves avx512_bf16 wbnoinvd arat avx512vbmi umip pku ospke waitpkg avx512_vbmi2 gfni vaes vpclmulqdq avx512_vnni avx512_bitalg avx512_vpopcntdq rdpid cldemote movdiri movdir64b md_clear arch_capabilities

Versions of relevant libraries:
[pip3] Could not collect
[conda] libopenvino-pytorch-frontend 2025.2.0             hecca717_1    conda-forge
[conda] numpy                     2.3.4                    pypi_0    pypi
[conda] nvidia-cublas-cu12        12.8.4.1                 pypi_0    pypi
[conda] nvidia-cuda-cupti-cu12    12.8.90                  pypi_0    pypi
[conda] nvidia-cuda-nvrtc-cu12    12.8.93                  pypi_0    pypi
[conda] nvidia-cuda-runtime-cu12  12.8.90                  pypi_0    pypi
[conda] nvidia-cudnn-cu12         9.10.2.21                pypi_0    pypi
[conda] nvidia-cufft-cu12         11.3.3.83                pypi_0    pypi
[conda] nvidia-curand-cu12        10.3.9.90                pypi_0    pypi
[conda] nvidia-cusolver-cu12      11.7.3.90                pypi_0    pypi
[conda] nvidia-cusparse-cu12      12.5.8.93                pypi_0    pypi
[conda] nvidia-cusparselt-cu12    0.7.1                    pypi_0    pypi
[conda] nvidia-nccl-cu12          2.27.5                   pypi_0    pypi
[conda] nvidia-nvjitlink-cu12     12.8.93                  pypi_0    pypi
[conda] nvidia-nvtx-cu12          12.8.90                  pypi_0    pypi
[conda] tbb                       2022.2.0             hb60516a_1    conda-forge
[conda] torch                     2.9.0                    pypi_0    pypi
[conda] triton                    3.5.0                    pypi_0    pypi

cc @vincentqb @jbschlosser @albanD @janeyx99 @crcrpar

Labels: actionable, module: optimizer, triaged
