Add non-recursive module.to_empty option
ghstack-source-id: c3ce9b2281a8b659f8a59c459f622b38ece572c1
Pull Request resolved: #104197
mikaylagawarecki committed Jun 26, 2023
1 parent edc9c0d commit 444ebe7
Showing 2 changed files with 41 additions and 5 deletions.
test/test_nn.py (33 additions, 0 deletions)
@@ -12314,6 +12314,39 @@ def forward(self, x):
         m.to_empty(device='meta')
         m(input)
 
+    def test_module_to_empty_non_recursive(self, device):
+        class Layer(nn.Module):
+            def __init__(self, in_features, out_features):
+                super().__init__()
+                self.weight = nn.Parameter(torch.randn(in_features, out_features))
+                self.register_buffer('buf', torch.randn(out_features))
+
+            def forward(self, x):
+                return x @ self.weight + self.buf
+
+        class MyModule(nn.Module):
+            def __init__(self, in_features, out_features):
+                super().__init__()
+                self.weight = nn.Parameter(torch.randn(in_features, out_features))
+                self.register_buffer('buf1', torch.randn(out_features))
+                self.layer = Layer(out_features, out_features)
+
+            def forward(self, x):
+                return self.layer(x @ self.weight + self.buf1)
+
+        with torch.device('meta'):
+            m = MyModule(3, 5)
+
+        m.to_empty(device=device, recurse=False)
+
+        # params/buffers of parent should have been materialized on device
+        self.assertTrue(not m.weight.is_meta)
+        self.assertTrue(not m.buf1.is_meta)
+
+        # parameters/buffers of children submodules should still be on meta
+        for p in (*m.layer.parameters(), *m.layer.buffers()):
+            self.assertTrue(p.is_meta)
+
     @skipMeta
     def test_skip_init(self, device):
         torch.manual_seed(1)
torch/nn/modules/module.py (8 additions, 5 deletions)
@@ -796,9 +796,10 @@ def set_extra_state(self, state: Any):
             "Please file an issue at https://github.com/pytorch/pytorch/issues/new?template=bug-report.yml "
             "to report this bug.")
 
-    def _apply(self, fn):
-        for module in self.children():
-            module._apply(fn)
+    def _apply(self, fn, recurse=True):
+        if recurse:
+            for module in self.children():
+                module._apply(fn)
 
         def compute_should_use_set_data(tensor, tensor_applied):
             if torch._has_compatible_shallow_copy_type(tensor, tensor_applied):
@@ -1015,17 +1016,19 @@ def bfloat16(self: T) -> T:
         """
         return self._apply(lambda t: t.bfloat16() if t.is_floating_point() else t)
 
-    def to_empty(self: T, *, device: Union[str, device]) -> T:
+    def to_empty(self: T, *, device: Union[str, device], recurse: bool = True) -> T:
         r"""Moves the parameters and buffers to the specified device without copying storage.
 
         Args:
             device (:class:`torch.device`): The desired device of the parameters
                 and buffers in this module.
+            recurse (bool): Whether parameters and buffers of submodules should
+                be recursively moved to the specified device.
 
         Returns:
             Module: self
         """
-        return self._apply(lambda t: torch.empty_like(t, device=device))
+        return self._apply(lambda t: torch.empty_like(t, device=device), recurse=recurse)
 
     @overload
     def to(self: T, device: Optional[Union[int, device]] = ..., dtype: Optional[Union[dtype, str]] = ...,
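Not part of the commit — a minimal usage sketch of the new recurse flag, assuming made-up module and attribute names (Root, scale, backbone) and CPU as the target device:

import torch
import torch.nn as nn

class Root(nn.Module):
    def __init__(self):
        super().__init__()
        self.scale = nn.Parameter(torch.ones(5))   # owned by the root itself
        self.backbone = nn.Linear(5, 5)            # child submodule

# Construct on the meta device: tensors have shapes/dtypes but no storage.
with torch.device('meta'):
    m = Root()

# Materialize only the root's own parameters/buffers on CPU; the backbone's
# tensors stay on meta because recurse=False.
m.to_empty(device='cpu', recurse=False)
assert not m.scale.is_meta
assert m.backbone.weight.is_meta

# The child can be materialized separately later, e.g. right before loading
# its shard of a checkpoint.
m.backbone.to_empty(device='cpu')
assert not m.backbone.weight.is_meta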
