❓ Questions and Help
I've got a fairly large model that uses gradient checkpointing (via torch.utils.checkpoint). On GPU it peaks at about 4 GB of memory and fits within my 8 GB of GPU memory with no problem.
However, when I try to run the same model on a TPU using XLA, it uses more than 8 GB of memory and hits an OOM.
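For context, the pattern looks roughly like the minimal sketch below; `Block` and `Model` are hypothetical stand-ins for the real model, which is much larger.

```python
import torch
import torch.nn as nn
from torch.utils.checkpoint import checkpoint

class Block(nn.Module):
    # Hypothetical stand-in for one checkpointed segment of the real model.
    def __init__(self, dim):
        super().__init__()
        self.net = nn.Sequential(nn.Linear(dim, dim), nn.ReLU())

    def forward(self, x):
        return self.net(x)

class Model(nn.Module):
    def __init__(self, dim=1024, depth=8):
        super().__init__()
        self.blocks = nn.ModuleList(Block(dim) for _ in range(depth))

    def forward(self, x):
        for block in self.blocks:
            # Activations inside each block are discarded in the forward
            # pass and recomputed during backward, trading compute for memory.
            x = checkpoint(block, x)
        return x
```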
Reading the XLA API guide, it seems to me that the lazy evaluation of XLA tensors might actually defeat the checkpointing: checkpointing deliberately recomputes activations during the backward pass, and the XLA compiler might treat that recomputation as wasted work and optimize it away, restoring the full activation memory footprint (I can explain this in more detail if necessary).
However, I also understand there could be other issues, like something to do with tensor shapes being padded/aligned to multiples of 128 on the TPU, which I also don't fully understand.
Just hoping to get some guidance to aid my further debugging.
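In case it helps, this is roughly how I've been inspecting what XLA is doing. It reuses the `Model` sketch above and assumes `xm.get_memory_info` is available in the installed torch_xla version.

```python
import torch
import torch_xla.core.xla_model as xm
import torch_xla.debug.metrics as met

device = xm.xla_device()
model = Model().to(device)  # Model from the sketch above
# At least one input must require grad for the checkpoint backward to fire.
x = torch.randn(32, 1024, device=device, requires_grad=True)

loss = model(x).sum()
loss.backward()
xm.mark_step()  # force the lazily-traced graph to compile and execute

print(met.metrics_report())        # compile/execute counters and timings
print(xm.get_memory_info(device))  # device memory stats, e.g. kb_free/kb_total
```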