CHECKPOINT_PREFIX is not stripped when non-root module is activation checkpointed #117399

Description

@mvpatel2000

🐛 Describe the bug

When training large models, we often want to activation checkpoint a submodule smaller than the FSDP wrap module. For example, we might want to activation checkpoint only the attention module inside a transformer block.

Unfortunately, when calling get_state_dict with the new distributed checkpoint interface, the _CHECKPOINT_PREFIX added by the checkpoint wrapper is not stripped when a submodule (rather than the FSDP root module) is activation checkpointed.
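
A minimal sketch of the setup (not taken from the issue; the `Block` module, its sizes, and the CPU/gloo setup are illustrative, and it assumes a process group launched via `torchrun`):

```python
import torch
import torch.distributed as dist
import torch.nn as nn
from torch.distributed.algorithms._checkpoint.checkpoint_wrapper import checkpoint_wrapper
from torch.distributed.checkpoint.state_dict import get_state_dict
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP

# Expects a rendezvous set up by torchrun (env:// init).
dist.init_process_group("gloo")


class Block(nn.Module):
    def __init__(self) -> None:
        super().__init__()
        self.attn = nn.MultiheadAttention(64, 4)
        self.mlp = nn.Linear(64, 64)

    def forward(self, x):
        out, _ = self.attn(x, x, x)
        return self.mlp(out)


block = Block()
# Activation checkpoint only the attention submodule, not the whole block.
block.attn = checkpoint_wrapper(block.attn)
# The FSDP wrap module is the block; the checkpointed module is a non-root child.
model = FSDP(block)
optim = torch.optim.SGD(model.parameters(), lr=0.1)

model_sd, optim_sd = get_state_dict(model, optim)
# Expected keys like "attn.in_proj_weight"; instead the checkpoint wrapper prefix
# leaks through, e.g. "attn._checkpoint_wrapped_module.in_proj_weight".
print(list(model_sd.keys()))
```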

We have to monkey-patch torch here to always strip this prefix.
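
The monkey-patch itself is not reproduced here; a minimal sketch of the same idea, fixing up keys after get_state_dict returns (the helper name `strip_checkpoint_prefix` is hypothetical, but `_CHECKPOINT_PREFIX` is the constant the checkpoint wrapper uses):

```python
from torch.distributed.algorithms._checkpoint.checkpoint_wrapper import _CHECKPOINT_PREFIX


def strip_checkpoint_prefix(state_dict):
    """Remove the activation-checkpoint wrapper prefix from every key."""
    return {key.replace(_CHECKPOINT_PREFIX, ""): value for key, value in state_dict.items()}


# Usage: model_sd = strip_checkpoint_prefix(model_sd)
```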

Versions

Torch 2.1.2 / Nightly for Torch 2.2

cc @mrshenli @pritamdamania87 @zhaojuanmao @satgera @rohan-varma @gqchen @aazzolini @osalpekar @jiayisuse @H-Huang @kwen2501 @awgu @penguinwu @fegin @XilunWu @wanchaol @fduwjj @wz337 @tianyu-l @wconstab @yf225
