
Does pytorch xla work with gradient checkpointing? #1571

@david-alexander-white


❓ Questions and Help

I've got a decently big model that uses gradient checkpointing (via torch.utils.checkpoint). It seems to max out at around 4 GB of memory usage and fits within my 8 GB of GPU memory with no problem.
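For context, this is the checkpointing pattern I'm using — a minimal sketch with a hypothetical `Block` module standing in for my actual model segments:

```python
import torch
from torch.utils.checkpoint import checkpoint

# Hypothetical stand-in for one segment of the real model.
class Block(torch.nn.Module):
    def __init__(self, dim):
        super().__init__()
        self.lin = torch.nn.Linear(dim, dim)

    def forward(self, x):
        return torch.relu(self.lin(x))

block = Block(16)
x = torch.randn(4, 16, requires_grad=True)

# Checkpointed forward pass: only the segment's inputs are saved;
# intermediate activations are dropped and recomputed during backward,
# trading extra compute for lower peak memory.
y = checkpoint(block, x, use_reentrant=False)
y.sum().backward()
```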

However, when I try to run it on a TPU using XLA, it uses more than 8 GB of memory and hits an OOM.

Reading the XLA API guide, it seems to me that the lazy evaluation of XLA tensors might actually defeat the checkpointing: since XLA traces operations into a graph and compiles them, it may optimize away what it sees as redundant recomputation (I can explain this in more detail if necessary).

However, I also understand there could be other issues, such as something to do with alignment/padding to multiples of 128, which I also don't fully understand.

Just hoping to get some guidance to aid my further debugging.
