condition some logic on PyTorch 2.2.x for bc
speediedan committed Oct 21, 2023
1 parent b343ba9 commit fbe11f5
Showing 1 changed file with 9 additions and 7 deletions.
tests/test_fsdp.py: 16 changes (9 additions & 7 deletions)

@@ -21,6 +21,7 @@
 from lightning.fabric.utilities.imports import (
     _TORCH_GREATER_EQUAL_2_0,
     _TORCH_GREATER_EQUAL_2_1,
+    _TORCH_GREATER_EQUAL_2_2
 )
 from lightning.pytorch import seed_everything, Trainer
 from lightning.pytorch.plugins.precision.fsdp import FSDPPrecisionPlugin
@@ -344,13 +345,14 @@ def configure_model(self) -> None:
                 self.layer[i] = wrap(layer)
         self.layer = wrap(self.layer)

-        # starting with https://github.com/pytorch/pytorch/pull/108033, FSDP no longer moves ignored parameters
-        # (or buffers) to device. We need to manually move them to device in versions > 2.1.x (precise version TBD)
-        for param in self.layer._ignored_params:
-            with torch.no_grad():
-                param.data = param.to(self.device)
-                if param.grad is not None:
-                    param.grad.data = param.grad.to(self.device)
+        if _TORCH_GREATER_EQUAL_2_2:
+            # starting with https://github.com/pytorch/pytorch/pull/108033, FSDP no longer moves ignored parameters
+            # (or buffers) to device. We need to manually move them to device in versions > 2.1.x (precise version TBD)
+            for param in self.layer._ignored_params:
+                with torch.no_grad():
+                    param.data = param.to(self.device)
+                    if param.grad is not None:
+                        param.grad.data = param.grad.to(self.device)

         # verify activation checkpointing can be manually applied
         check_fn = lambda submodule: isinstance(submodule, tuple([torch.nn.Linear]))  # noqa E731
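For reference, the same version gate can be sketched outside the Lightning test suite. The snippet below is a minimal, illustrative stand-in (assuming only torch and packaging are installed): the local _TORCH_GREATER_EQUAL_2_2 flag and the move_ignored_params_to_device helper are hypothetical names rather than Lightning APIs, while _ignored_params is the same private FSDP attribute the test touches.

import torch
from packaging.version import Version

# Illustrative stand-in for Lightning's _TORCH_GREATER_EQUAL_2_2 flag.
_TORCH_GREATER_EQUAL_2_2 = Version(torch.__version__.split("+")[0]) >= Version("2.2.0")


def move_ignored_params_to_device(fsdp_module: torch.nn.Module, device: torch.device) -> None:
    # On PyTorch >= 2.2 (after pytorch/pytorch#108033), FSDP no longer moves ignored
    # parameters (or buffers) to device, so callers have to do it themselves.
    if not _TORCH_GREATER_EQUAL_2_2:
        return  # older FSDP versions still move ignored params automatically
    for param in getattr(fsdp_module, "_ignored_params", ()):  # private FSDP attribute
        with torch.no_grad():
            param.data = param.to(device)
            if param.grad is not None:
                param.grad.data = param.grad.to(device)

Gating on the version flag keeps the manual device move a no-op on PyTorch 2.1.x and earlier, where FSDP still moves ignored parameters itself, which is the backward-compatibility point of this commit.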
