
reentrant torch.utils.checkpoint does not work with NamedTuple outputs #85088

@rohan-varma

🐛 Describe the bug

The reentrant (default) version of torch.utils.checkpoint does not preserve NamedTuple outputs: it returns a plain tuple instead, so the outputs can no longer be accessed by their original field names. To reproduce:

import torch
from torch.utils.checkpoint import checkpoint
from collections import namedtuple

Tup = namedtuple("Tup", "a b c")

tup = Tup(torch.ones(10, requires_grad=True), torch.ones(10, requires_grad=True), torch.ones(10, requires_grad=True))

def foo(tup):
    return Tup(tup.a + tup.b, tup.b, tup.a + tup.b + tup.c)

# Reentrant checkpointing is the code path this issue reports as broken.
out = checkpoint(foo, tup, use_reentrant=True)
# Expected a Tup (so out.a, out.b, out.c work); per this issue, a plain tuple comes back instead.
print(type(out))
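Until this is fixed, one possible workaround is to record the output type inside the checkpointed function and rebuild the NamedTuple afterwards. Below is a minimal sketch under that assumption. `checkpoint_keep_namedtuple` is a hypothetical helper, not a PyTorch API; it takes the checkpoint function as a parameter. `fake_reentrant_checkpoint` is a stand-in that only mimics the behavior reported here, so the example runs without torch. The sketch also assumes the returned tuple's elements are in the NamedTuple's field order.

```python
from collections import namedtuple
from typing import Any, Callable, Dict


def checkpoint_keep_namedtuple(checkpoint_fn: Callable[..., Any],
                               fn: Callable[..., Any],
                               *args: Any, **kwargs: Any) -> Any:
    """Hypothetical workaround: run fn through checkpoint_fn and, if fn returned a
    NamedTuple that came back as a plain tuple, rebuild the NamedTuple."""
    seen: Dict[str, type] = {}

    def recording_fn(*inner_args: Any) -> Any:
        result = fn(*inner_args)
        seen["type"] = type(result)  # remember the original output type
        return result

    out = checkpoint_fn(recording_fn, *args, **kwargs)
    out_type = seen.get("type")
    # NamedTuple classes are tuple subclasses that expose a `_fields` attribute.
    is_namedtuple = (out_type is not None and issubclass(out_type, tuple)
                     and hasattr(out_type, "_fields"))
    if is_namedtuple and not isinstance(out, out_type):
        return out_type(*out)  # assumes elements are in field order
    return out


def fake_reentrant_checkpoint(function: Callable[..., Any], *args: Any, **kwargs: Any) -> tuple:
    # Stand-in (not torch): mimics the reported behavior by flattening the output to a plain tuple.
    return tuple(function(*args))


Tup = namedtuple("Tup", "a b c")


def foo(tup):
    return Tup(tup.a + tup.b, tup.b, tup.a + tup.b + tup.c)


print(type(fake_reentrant_checkpoint(foo, Tup(1, 2, 3))))  # <class 'tuple'>
print(checkpoint_keep_namedtuple(fake_reentrant_checkpoint, foo, Tup(1, 2, 3)))  # Tup(a=3, b=2, c=6)
```

With PyTorch, the real function would be passed instead, e.g. `checkpoint_keep_namedtuple(checkpoint, foo, tup, use_reentrant=True)`. Assuming the flattening described in this issue is the only difference, that call should yield a `Tup` again. The recomputation during backward simply re-records the same type, so it is harmless.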

Versions

main

cc @ezyang @gchanan @zou3519 @pietern @mrshenli @pritamdamania87 @zhaojuanmao @satgera @gqchen @aazzolini @osalpekar @jiayisuse @SciPioneer @H-Huang @kwen2501
