Labels
high priority, module: activation checkpointing, oncall: distributed, triage review
Description
🐛 Describe the bug
The reentrant (default) version of torch.utils.checkpoint does not preserve NamedTuple outputs: it returns a plain tuple instead, so the outputs cannot be accessed via their original field names. For a repro, see:
```python
import torch
from torch.utils.checkpoint import checkpoint
from collections import namedtuple

Tup = namedtuple("Tup", "a b c")
tup = Tup(torch.ones(10, requires_grad=True),
          torch.ones(10, requires_grad=True),
          torch.ones(10, requires_grad=True))

def foo(tup):
    return Tup(tup.a + tup.b, tup.b, tup.a + tup.b + tup.c)

out = checkpoint(foo, tup, use_reentrant=False)
# out is a plain tuple, so out.a raises AttributeError
```
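Until this is fixed, one possible workaround is to re-wrap the flattened result in the original NamedTuple type after the call. The sketch below is torch-free: `fake_checkpoint` is a hypothetical stand-in that only mimics the flattening behavior this issue describes, so the re-wrapping pattern can be shown in isolation.

```python
from collections import namedtuple

Tup = namedtuple("Tup", "a b c")

def fake_checkpoint(fn, *args):
    # Hypothetical stand-in for torch.utils.checkpoint.checkpoint:
    # per this issue, the NamedTuple output is flattened to a plain tuple.
    return tuple(fn(*args))

def foo(t):
    return Tup(t.a + t.b, t.b, t.a + t.b + t.c)

tup = Tup(1, 2, 3)
out = fake_checkpoint(foo, tup)
assert type(out) is tuple          # NamedTuple type was lost

restored = Tup(*out)               # workaround: re-wrap in the original type
assert restored.a == 3             # field access works again
assert restored.c == 6
```

The same `Tup(*out)` re-wrapping can be applied to the real `checkpoint` output, as long as the caller knows which NamedTuple type the wrapped function returns.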
Versions
main
cc @ezyang @gchanan @zou3519 @pietern @mrshenli @pritamdamania87 @zhaojuanmao @satgera @gqchen @aazzolini @osalpekar @jiayisuse @SciPioneer @H-Huang @kwen2501