Make adding buffers more like adding parameters (#104069)
Make creating a buffer behave like creating a parameter. This is done by introducing a new `Buffer` class that can be used for type disambiguation when assigning attributes on a module. The underlying registration machinery is unchanged: the `register_buffer` method has not been modified, and assigning a `Buffer` simply results in `register_buffer` being called. The `persistent` argument on `Buffer` indicates whether the buffer should be persistent, i.e. included in the module's `state_dict`. The remaining non-test changes teach inductor and dynamo to recognize the new `Buffer` type, and the test changes verify that `Buffer` works as a drop-in replacement for `register_buffer`. Plain tensors can still be used as buffers, so these changes are intended to be backwards compatible.
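For illustration, a minimal sketch of the new usage next to the existing `register_buffer` call (the module name and shapes below are invented for this example, not taken from the PR):

```python
import torch
import torch.nn as nn

class Example(nn.Module):  # hypothetical module, for illustration only
    def __init__(self):
        super().__init__()
        # Existing API: register a buffer explicitly by name.
        self.register_buffer("running_mean", torch.zeros(4))
        # New API from this PR: assign an nn.Buffer like an nn.Parameter;
        # the assignment calls register_buffer under the hood.
        self.running_var = nn.Buffer(torch.ones(4))
        # Non-persistent buffers are kept out of the state dict.
        self.scratch = nn.Buffer(torch.zeros(4), persistent=False)

m = Example()
print(sorted(name for name, _ in m.named_buffers()))
# ['running_mean', 'running_var', 'scratch']
print(sorted(m.state_dict().keys()))
# ['running_mean', 'running_var']
```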

Fixes #35735
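As a rough sanity check of the dynamo/inductor support mentioned above, a sketch along these lines (not one of the PR's tests; the module name and values are invented, and it assumes a build where `torch.compile` is usable) should compile and run:

```python
import torch
import torch.nn as nn

class Scale(nn.Module):  # hypothetical module, for illustration only
    def __init__(self):
        super().__init__()
        self.scale = nn.Buffer(torch.full((4,), 2.0))

    def forward(self, x):
        return x * self.scale

compiled = torch.compile(Scale())
print(compiled(torch.ones(4)))  # expected: tensor([2., 2., 2., 2.])
```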

Pull Request resolved: #104069
Approved by: https://github.com/mikaylagawarecki
ekamiti authored and pytorchmergebot committed Jul 17, 2023
1 parent 4fc47b4 commit 32d422f
Showing 41 changed files with 267 additions and 148 deletions.
1 change: 1 addition & 0 deletions docs/source/nn.rst
@@ -22,6 +22,7 @@ These are the basic building blocks for graphs:
:nosignatures:
:template: classtemplate.rst

+~parameter.Buffer
~parameter.Parameter
~parameter.UninitializedParameter
~parameter.UninitializedBuffer
2 changes: 1 addition & 1 deletion test/distributed/fsdp/test_fsdp_flatten_params.py
@@ -58,7 +58,7 @@ def _get_transformer(self, seed=0):
dim_feedforward=128,
dropout=0.1,
)
-module.register_buffer("dummy_buffer", torch.tensor(1.0))
+module.dummy_buffer = nn.Buffer(torch.tensor(1.0))

def get_input(device, dtype):
torch.manual_seed(1) # keep everything deterministic
10 changes: 5 additions & 5 deletions test/distributed/fsdp/test_fsdp_misc.py
@@ -453,11 +453,11 @@ def init_nested_wrapped_module():

# Check that `device_id` with `sync_module_states=True` works
nested_wrapped_module = init_nested_wrapped_module()
-nested_wrapped_module.register_buffer(
-"buf", torch.ones((2, 2), device="cpu") * self.rank
+nested_wrapped_module.buf = nn.Buffer(
+torch.ones((2, 2), device="cpu") * self.rank
)
-nested_wrapped_module.module[0].register_buffer(
-"buf", torch.ones((3, 2), device="cpu") * self.rank
+nested_wrapped_module.module[0].buf = nn.Buffer(
+torch.ones((3, 2), device="cpu") * self.rank
)
nested_wrapped_module = FSDP(
nested_wrapped_module,
@@ -705,7 +705,7 @@ def __init__(self, rank):
torch.manual_seed(rank)
torch.cuda.manual_seed(rank)
self.lin = nn.Linear(10, 10, bias=False)
self.register_buffer("buffer", torch.ones(1) * rank)
self.buffer = nn.Buffer(torch.ones(1) * rank)

m = MyModel(self.rank).cuda()
_assert_module_states(
4 changes: 2 additions & 2 deletions test/distributed/fsdp/test_fsdp_state_dict.py
@@ -87,7 +87,7 @@ def __init__(self, wrap_fsdp, register_buffers=False, ignore_inner=False):
super().__init__()
self.inner = Linear(*INNER_SHAPE)
if register_buffers:
-self.inner.register_buffer("buffer", torch.randn(BUFFER_SHAPE))
+self.inner.buffer = nn.Buffer(torch.randn(BUFFER_SHAPE))
self.inner.register_buffer(
"non_persistent_buffer", torch.randn(BUFFER_SHAPE), persistent=False
)
@@ -97,7 +97,7 @@ def __init__(self, wrap_fsdp, register_buffers=False, ignore_inner=False):
)
self.outer = Linear(*OUTER_SHAPE)
if register_buffers:
-self.outer.register_buffer("buffer", torch.randn(BUFFER_SHAPE))
+self.outer.buffer = nn.Buffer(torch.randn(BUFFER_SHAPE))
self.outer.register_buffer(
"non_persistent_buffer", torch.randn(BUFFER_SHAPE), persistent=False
)
4 changes: 2 additions & 2 deletions test/distributed/fsdp/test_fsdp_unshard_params.py
@@ -423,7 +423,7 @@ def _test_named_parameters_and_buffers(self, prefix: str, recurse: bool):
CUDAInitMode.CUDA_BEFORE,
deterministic=True,
)
-model.register_buffer("buffer", torch.ones(1))
+model.buffer = nn.Buffer(torch.ones(1))
# Wrap the top-level with FSDP since `named_parameters()` and
# `named_buffers` will contain FSDP prefixes if called on a non-FSDP
# root module
@@ -436,7 +436,7 @@ def _test_named_parameters_and_buffers(self, prefix: str, recurse: bool):
),
self.process_group,
)
-fsdp_model.register_buffer("buffer", torch.ones(1))
+fsdp_model.buffer = nn.Buffer(torch.ones(1))
with FSDP.summon_full_params(fsdp_model):
for call in ["named_parameters", "named_buffers"]:
for (n1, p1), (n2, p2) in itertools.zip_longest(
3 changes: 1 addition & 2 deletions test/distributed/optim/test_zero_redundancy_optimizer.py
@@ -855,8 +855,7 @@ def test_local_optimizer_parity(
torch.nn.Linear(HIDDEN_DIM, HIDDEN_DIM),
torch.nn.Linear(HIDDEN_DIM, OUTPUT_DIM),
).to(self.device)
-model.register_buffer(
-"test_buffer",
+model.test_buffer = torch.nn.Buffer(
torch.ones((1), device=self.device) * self.rank,
)
# Define models/optimizers for DDP with ZeRO and DDP with local
4 changes: 2 additions & 2 deletions test/distributed/test_data_parallel.py
@@ -34,8 +34,8 @@ def test_data_parallel_buffers_requiring_grad(self):
class TestModule(nn.Module):
def __init__(self, t):
super().__init__()
-self.register_buffer('t_rg', t)
-self.register_buffer('t_not_rg', t.clone().detach())
+self.t_rg = nn.Buffer(t, t.requires_grad)
+self.t_not_rg = nn.Buffer(t.clone().detach())

def forward(self, x):
return x * self.t_rg + self.t_not_rg
11 changes: 5 additions & 6 deletions test/distributed/test_dynamo_distributed.py
@@ -760,9 +760,8 @@ class DuplicateModule(nn.Module):
def __init__(self) -> None:
super().__init__()
self._param = torch.randn((3,), device="cuda")
-self.register_buffer(
-"_buf", torch.randn((3,), requires_grad=False, device="cuda")
-)
+self._buf = torch.nn.Buffer(
+torch.randn((3,), requires_grad=False, device="cuda"))

def forward(self, x: torch.Tensor) -> torch.Tensor:
# Use `_param` and `_buf` each twice in this compiled forward
@@ -789,8 +788,8 @@ def test_fsdp_dup_tensors_diff_source(self):
class BufModule(nn.Module):
def __init__(self) -> None:
super().__init__()
-self.register_buffer(
-"_buf", torch.randn((3,), requires_grad=False, device="cuda")
+self._buf = nn.Buffer(
+torch.randn((3,), requires_grad=False, device="cuda")
)

def forward(self, x: torch.Tensor) -> torch.Tensor:
@@ -802,7 +801,7 @@ def __init__(self) -> None:
self._param = nn.Parameter(torch.randn((1,), device="cuda"))
self._buf_module = BufModule()
# Share the buffer, meaning same tensor but different source
self.register_buffer("_buf", self._buf_module._buf)
self._buf = self._buf_module._buf

def forward(self, x: torch.Tensor) -> torch.Tensor:
# Use the same buffer tensor twice in the compiled forward,
4 changes: 2 additions & 2 deletions test/dynamo/test_export.py
@@ -973,7 +973,7 @@ class MyBlock(torch.nn.Module):
def __init__(self):
super().__init__()
self.weight = torch.nn.Parameter(torch.ones(1, 1))
self.register_buffer("buffer", torch.ones(1, 1))
self.buffer = torch.nn.Buffer(torch.ones(1, 1))

def forward(self, x):
x = torch.nn.functional.linear(x, torch.randn(4, 4))
@@ -2668,7 +2668,7 @@ def test_not_functionalize(self):
class Foo(torch.nn.Module):
def __init__(self):
super().__init__()
self.register_buffer("buffer1", torch.ones(6, 2))
self.buffer1 = torch.nn.Buffer(torch.ones(6, 2))

def forward(self, x):
x.add_(2)
6 changes: 3 additions & 3 deletions test/dynamo/test_modules.py
@@ -1451,7 +1451,7 @@ def __init__(self):
super().__init__()
self.relu = torch.nn.ReLU()
self.linear = torch.nn.Linear(10, 10)
self.register_buffer("buf0", torch.randn(10, 10))
self.buf0 = torch.nn.Buffer(torch.randn(10, 10))

def forward(self, x):
return self.relu(self.linear(x) + self.buf0)
@@ -1500,7 +1500,7 @@ class MockModule(torch.nn.Module):
def __init__(self):
super().__init__()
self.linear = torch.nn.Linear(10, 10)
self.register_buffer("buf0", torch.randn(10, 10))
self.buf0 = torch.nn.Buffer(torch.randn(10, 10))

def forward(self, x):
return self.r(torch.sin(x)) + self.buf0
@@ -1527,7 +1527,7 @@ class MockModule(torch.nn.Module):
def __init__(self):
super().__init__()
self.linear = torch.nn.Linear(10, 10)
self.register_buffer("buf0", torch.randn(10, 10))
self.register_buffer("buf0", torch.nn.Buffer(torch.randn(10, 10)))
self.register_parameter(
name="param0", param=torch.nn.Parameter(torch.randn(10, 10))
)
10 changes: 5 additions & 5 deletions test/dynamo/test_repros.py
@@ -1859,8 +1859,8 @@ def test_sort_out2(self):
class MyModule(torch.nn.Module):
def __init__(self):
super().__init__()
self.register_buffer("sorted", torch.ones(4, 4))
self.register_buffer("indices", torch.ones(4, 4, dtype=torch.long))
self.sorted = torch.nn.Buffer(torch.ones(4, 4))
self.indices = torch.nn.Buffer(torch.ones(4, 4, dtype=torch.long))

def forward(self, x):
torch.sort(x, out=(self.sorted, self.indices))
@@ -1891,7 +1891,7 @@ def test_sigmoid_out2(self):
class MyModule(torch.nn.Module):
def __init__(self):
super().__init__()
self.register_buffer("base", torch.ones(4, 4))
self.base = torch.nn.Buffer(torch.ones(4, 4))

def forward(self, x):
torch.sigmoid(x, out=self.base)
@@ -2174,8 +2174,8 @@ def test_named_buffers(self):
class Foo(torch.nn.Module):
def __init__(self):
super().__init__()
self.register_buffer("x", torch.ones(3))
self.register_buffer("y", torch.ones(3))
self.x = torch.nn.Buffer(torch.ones(3))
self.y = torch.nn.Buffer(torch.ones(3))

def forward(self, inp):
res = 0
2 changes: 1 addition & 1 deletion test/export/test_export.py
@@ -50,7 +50,7 @@ def test_export_simple_model_buffer_mutation(self):
class Foo(torch.nn.Module):
def __init__(self, float_val):
super().__init__()
self.register_buffer("buffer1", torch.ones(6, 1))
self.buffer1 = torch.nn.Buffer(torch.ones(6, 1))

def forward(self, x):
self.buffer1.add_(2)
3 changes: 1 addition & 2 deletions test/export/test_verifier.py
@@ -178,8 +178,7 @@ def test_aten_wrong_mem_format_buffer(self) -> None:
class TestModel(torch.nn.Module):
def __init__(self):
super().__init__()
-self.register_buffer(
-"a",
+self.a = torch.nn.Buffer(
torch.randn(1, 3, 100, 100).to(memory_format=torch.channels_last),
)

6 changes: 3 additions & 3 deletions test/functorch/test_aotdispatch.py
@@ -1945,7 +1945,7 @@ def test_real_weights_in_symbolic_mode_with_inplace_ops(self):
class M(torch.nn.Module):
def __init__(self):
super().__init__()
self.register_buffer("buffer", torch.ones(4, 5))
self.buffer = torch.nn.Buffer(torch.ones(4, 5))

def forward(self, x):
y = self.buffer.add_(3)
@@ -2140,7 +2140,7 @@ def test_aot_export_forward_mutation_no_buffer_mut_banned(self):
class M(torch.nn.Module):
def __init__(self):
super().__init__()
self.register_buffer("buffer1", torch.ones(6, 4))
self.buffer1 = torch.nn.Buffer(torch.ones(6, 4))

def forward(self, x):
x.add_(4)
@@ -2153,7 +2153,7 @@ def test_aot_export_forward_mutation_multiple_mut_banned(self):
class M(torch.nn.Module):
def __init__(self):
super().__init__()
self.register_buffer("buffer1", torch.ones(6, 4))
self.buffer1 = torch.nn.Buffer(torch.ones(6, 4))

def forward(self, x, y):
y.add_(4)
12 changes: 6 additions & 6 deletions test/functorch/test_eager_transforms.py
@@ -3466,8 +3466,8 @@ def __init__(self):
super().__init__()
self.bias = nn.Parameter(torch.randn(3))
self.linear = nn.Linear(3, 3)
-self.register_buffer('buffer', torch.randn(3))
-self.register_buffer('buffer_tied', self.buffer)
+self.buffer = nn.Buffer(torch.randn(3))
+self.buffer_tied = self.buffer

def forward(self, x):
x = self.linear(x)
@@ -3497,7 +3497,7 @@ class Foo(nn.Module):
def __init__(self):
super().__init__()
self.linear = nn.Linear(3, 3)
-self.register_buffer('buffer', torch.randn(3))
+self.buffer = nn.Buffer(torch.randn(3))

def forward(self, x):
x = self.linear(x)
@@ -3517,7 +3517,7 @@ class Foo(nn.Module):
def __init__(self):
super().__init__()
self.linear = nn.Linear(3, 3)
-self.register_buffer('buffer', torch.randn(3))
+self.buffer = nn.Buffer(torch.randn(3))

def forward(self, x):
x = self.linear(x)
@@ -3573,8 +3573,8 @@ def __init__(self):
self.linear = nn.Linear(3, 3)
self.weight = self.linear.weight
self.bias = self.linear.bias
-self.register_buffer('buffer', torch.randn(3))
-self.register_buffer('buffer_tied', self.buffer)
+self.buffer = nn.Buffer(torch.randn(3))
+self.buffer_tied = self.buffer

def forward(self, x):
x = self.linear(x)
3 changes: 1 addition & 2 deletions test/inductor/test_cuda_repro.py
@@ -565,8 +565,7 @@ def __init__(self):
start = math.log2(0.5)
end = math.log2(1 / (2**8))

-self.register_buffer(
-"scales",
+self.scales = nn.Buffer(
2
** torch.arange(
start,
4 changes: 1 addition & 3 deletions test/inductor/test_torchinductor.py
@@ -6184,9 +6184,7 @@ def fn(x, p1, p0):
class Repro(torch.nn.Module):
def __init__(self):
super().__init__()
-self.register_buffer(
-"_tensor_constant0", torch.randn([], dtype=torch.float32)
-)
+self._tensor_constant0 = nn.Buffer(torch.randn([], dtype=torch.float32))

def forward(self, arg0_1, arg1_1):
convert_element_type = torch.ops.prims.convert_element_type.default(
4 changes: 3 additions & 1 deletion test/jit/test_save_load.py
@@ -487,6 +487,7 @@ def __init__(self):

self.parameter_b = torch.nn.Parameter(torch.randn(4))
self.submodule_b = Submodule()
+self.buffer_b = torch.nn.Buffer(torch.randn(4))

m = TestModule()
m_loaded = self.getExportImportCopy(torch.jit.script(m))
@@ -526,7 +527,7 @@ def __init__(self):
super().__init__()
self.foo = torch.nn.Linear(2, 3, device="meta")
self.bar = torch.nn.Linear(3, 4)
self.register_buffer("buffer", torch.randn(4, device="meta"))
self.buffer = torch.nn.Buffer(torch.randn(4, device="meta"))

def forward(self, x):
x = self.foo(x)
@@ -1150,6 +1151,7 @@ def __init__(self):

self.parameter_b = torch.nn.Parameter(torch.randn(4))
self.submodule_b = Submodule()
+self.buffer_b = torch.nn.Buffer(torch.randn(4))

m = TestModule()
m_loaded = self.getExportImportCopy(torch.jit.script(m))