diff --git a/docs/source/nn.rst b/docs/source/nn.rst index 6e1b0f1a03771..ca515ed1428f8 100644 --- a/docs/source/nn.rst +++ b/docs/source/nn.rst @@ -22,6 +22,7 @@ These are the basic building blocks for graphs: :nosignatures: :template: classtemplate.rst + ~parameter.Buffer ~parameter.Parameter ~parameter.UninitializedParameter ~parameter.UninitializedBuffer diff --git a/test/distributed/fsdp/test_fsdp_flatten_params.py b/test/distributed/fsdp/test_fsdp_flatten_params.py index 44c21388f7b42..82328d698dff4 100644 --- a/test/distributed/fsdp/test_fsdp_flatten_params.py +++ b/test/distributed/fsdp/test_fsdp_flatten_params.py @@ -58,7 +58,7 @@ def _get_transformer(self, seed=0): dim_feedforward=128, dropout=0.1, ) - module.register_buffer("dummy_buffer", torch.tensor(1.0)) + module.dummy_buffer = nn.Buffer(torch.tensor(1.0)) def get_input(device, dtype): torch.manual_seed(1) # keep everything deterministic diff --git a/test/distributed/fsdp/test_fsdp_misc.py b/test/distributed/fsdp/test_fsdp_misc.py index c5eb22581c543..bfc177a335fd6 100644 --- a/test/distributed/fsdp/test_fsdp_misc.py +++ b/test/distributed/fsdp/test_fsdp_misc.py @@ -453,11 +453,11 @@ def init_nested_wrapped_module(): # Check that `device_id` with `sync_module_states=True` works nested_wrapped_module = init_nested_wrapped_module() - nested_wrapped_module.register_buffer( - "buf", torch.ones((2, 2), device="cpu") * self.rank + nested_wrapped_module.buf = nn.Buffer( + torch.ones((2, 2), device="cpu") * self.rank ) - nested_wrapped_module.module[0].register_buffer( - "buf", torch.ones((3, 2), device="cpu") * self.rank + nested_wrapped_module.module[0].buf = nn.Buffer( + torch.ones((3, 2), device="cpu") * self.rank ) nested_wrapped_module = FSDP( nested_wrapped_module, @@ -705,7 +705,7 @@ def __init__(self, rank): torch.manual_seed(rank) torch.cuda.manual_seed(rank) self.lin = nn.Linear(10, 10, bias=False) - self.register_buffer("buffer", torch.ones(1) * rank) + self.buffer = nn.Buffer(torch.ones(1) * rank) m = MyModel(self.rank).cuda() _assert_module_states( diff --git a/test/distributed/fsdp/test_fsdp_state_dict.py b/test/distributed/fsdp/test_fsdp_state_dict.py index 117f202bbaae3..2d34010a613ca 100644 --- a/test/distributed/fsdp/test_fsdp_state_dict.py +++ b/test/distributed/fsdp/test_fsdp_state_dict.py @@ -87,7 +87,7 @@ def __init__(self, wrap_fsdp, register_buffers=False, ignore_inner=False): super().__init__() self.inner = Linear(*INNER_SHAPE) if register_buffers: - self.inner.register_buffer("buffer", torch.randn(BUFFER_SHAPE)) + self.inner.buffer = nn.Buffer(torch.randn(BUFFER_SHAPE)) self.inner.register_buffer( "non_persistent_buffer", torch.randn(BUFFER_SHAPE), persistent=False ) @@ -97,7 +97,7 @@ def __init__(self, wrap_fsdp, register_buffers=False, ignore_inner=False): ) self.outer = Linear(*OUTER_SHAPE) if register_buffers: - self.outer.register_buffer("buffer", torch.randn(BUFFER_SHAPE)) + self.outer.buffer = nn.Buffer(torch.randn(BUFFER_SHAPE)) self.outer.register_buffer( "non_persistent_buffer", torch.randn(BUFFER_SHAPE), persistent=False ) diff --git a/test/distributed/fsdp/test_fsdp_unshard_params.py b/test/distributed/fsdp/test_fsdp_unshard_params.py index 7d91340bc1c19..0f057e065356d 100644 --- a/test/distributed/fsdp/test_fsdp_unshard_params.py +++ b/test/distributed/fsdp/test_fsdp_unshard_params.py @@ -423,7 +423,7 @@ def _test_named_parameters_and_buffers(self, prefix: str, recurse: bool): CUDAInitMode.CUDA_BEFORE, deterministic=True, ) - model.register_buffer("buffer", torch.ones(1)) + model.buffer = 
nn.Buffer(torch.ones(1)) # Wrap the top-level with FSDP since `named_parameters()` and # `named_buffers` will contain FSDP prefixes if called on a non-FSDP # root module @@ -436,7 +436,7 @@ def _test_named_parameters_and_buffers(self, prefix: str, recurse: bool): ), self.process_group, ) - fsdp_model.register_buffer("buffer", torch.ones(1)) + fsdp_model.buffer = nn.Buffer(torch.ones(1)) with FSDP.summon_full_params(fsdp_model): for call in ["named_parameters", "named_buffers"]: for (n1, p1), (n2, p2) in itertools.zip_longest( diff --git a/test/distributed/optim/test_zero_redundancy_optimizer.py b/test/distributed/optim/test_zero_redundancy_optimizer.py index bf79d2bdc54ab..1954f588de52c 100644 --- a/test/distributed/optim/test_zero_redundancy_optimizer.py +++ b/test/distributed/optim/test_zero_redundancy_optimizer.py @@ -855,8 +855,7 @@ def test_local_optimizer_parity( torch.nn.Linear(HIDDEN_DIM, HIDDEN_DIM), torch.nn.Linear(HIDDEN_DIM, OUTPUT_DIM), ).to(self.device) - model.register_buffer( - "test_buffer", + model.test_buffer = torch.nn.Buffer( torch.ones((1), device=self.device) * self.rank, ) # Define models/optimizers for DDP with ZeRO and DDP with local diff --git a/test/distributed/test_data_parallel.py b/test/distributed/test_data_parallel.py index 317bd6f2c557d..43602d6b83cc0 100644 --- a/test/distributed/test_data_parallel.py +++ b/test/distributed/test_data_parallel.py @@ -34,8 +34,8 @@ def test_data_parallel_buffers_requiring_grad(self): class TestModule(nn.Module): def __init__(self, t): super().__init__() - self.register_buffer('t_rg', t) - self.register_buffer('t_not_rg', t.clone().detach()) + self.t_rg = nn.Buffer(t, t.requires_grad) + self.t_not_rg = nn.Buffer(t.clone().detach()) def forward(self, x): return x * self.t_rg + self.t_not_rg diff --git a/test/distributed/test_dynamo_distributed.py b/test/distributed/test_dynamo_distributed.py index 1284771730441..9329d5897695e 100644 --- a/test/distributed/test_dynamo_distributed.py +++ b/test/distributed/test_dynamo_distributed.py @@ -760,9 +760,8 @@ class DuplicateModule(nn.Module): def __init__(self) -> None: super().__init__() self._param = torch.randn((3,), device="cuda") - self.register_buffer( - "_buf", torch.randn((3,), requires_grad=False, device="cuda") - ) + self._buf = torch.nn.Buffer( + torch.randn((3,), requires_grad=False, device="cuda")) def forward(self, x: torch.Tensor) -> torch.Tensor: # Use `_param` and `_buf` each twice in this compiled forward @@ -789,8 +788,8 @@ def test_fsdp_dup_tensors_diff_source(self): class BufModule(nn.Module): def __init__(self) -> None: super().__init__() - self.register_buffer( - "_buf", torch.randn((3,), requires_grad=False, device="cuda") + self._buf = nn.Buffer( + torch.randn((3,), requires_grad=False, device="cuda") ) def forward(self, x: torch.Tensor) -> torch.Tensor: @@ -802,7 +801,7 @@ def __init__(self) -> None: self._param = nn.Parameter(torch.randn((1,), device="cuda")) self._buf_module = BufModule() # Share the buffer, meaning same tensor but different source - self.register_buffer("_buf", self._buf_module._buf) + self._buf = self._buf_module._buf def forward(self, x: torch.Tensor) -> torch.Tensor: # Use the same buffer tensor twice in the compiled forward, diff --git a/test/dynamo/test_export.py b/test/dynamo/test_export.py index af314c13b3737..c967a4e9b9b9e 100644 --- a/test/dynamo/test_export.py +++ b/test/dynamo/test_export.py @@ -973,7 +973,7 @@ class MyBlock(torch.nn.Module): def __init__(self): super().__init__() self.weight = torch.nn.Parameter(torch.ones(1, 
1)) - self.register_buffer("buffer", torch.ones(1, 1)) + self.buffer = torch.nn.Buffer(torch.ones(1, 1)) def forward(self, x): x = torch.nn.functional.linear(x, torch.randn(4, 4)) @@ -2668,7 +2668,7 @@ def test_not_functionalize(self): class Foo(torch.nn.Module): def __init__(self): super().__init__() - self.register_buffer("buffer1", torch.ones(6, 2)) + self.buffer1 = torch.nn.Buffer(torch.ones(6, 2)) def forward(self, x): x.add_(2) diff --git a/test/dynamo/test_modules.py b/test/dynamo/test_modules.py index a709c826290a9..03ef4f0730545 100644 --- a/test/dynamo/test_modules.py +++ b/test/dynamo/test_modules.py @@ -1451,7 +1451,7 @@ def __init__(self): super().__init__() self.relu = torch.nn.ReLU() self.linear = torch.nn.Linear(10, 10) - self.register_buffer("buf0", torch.randn(10, 10)) + self.buf0 = torch.nn.Buffer(torch.randn(10, 10)) def forward(self, x): return self.relu(self.linear(x) + self.buf0) @@ -1500,7 +1500,7 @@ class MockModule(torch.nn.Module): def __init__(self): super().__init__() self.linear = torch.nn.Linear(10, 10) - self.register_buffer("buf0", torch.randn(10, 10)) + self.buf0 = torch.nn.Buffer(torch.randn(10, 10)) def forward(self, x): return self.r(torch.sin(x)) + self.buf0 @@ -1527,7 +1527,7 @@ class MockModule(torch.nn.Module): def __init__(self): super().__init__() self.linear = torch.nn.Linear(10, 10) - self.register_buffer("buf0", torch.randn(10, 10)) + self.register_buffer("buf0", torch.nn.Buffer(torch.randn(10, 10))) self.register_parameter( name="param0", param=torch.nn.Parameter(torch.randn(10, 10)) ) diff --git a/test/dynamo/test_repros.py b/test/dynamo/test_repros.py index 659707a89a3e0..2e84776ea7658 100644 --- a/test/dynamo/test_repros.py +++ b/test/dynamo/test_repros.py @@ -1859,8 +1859,8 @@ def test_sort_out2(self): class MyModule(torch.nn.Module): def __init__(self): super().__init__() - self.register_buffer("sorted", torch.ones(4, 4)) - self.register_buffer("indices", torch.ones(4, 4, dtype=torch.long)) + self.sorted = torch.nn.Buffer(torch.ones(4, 4)) + self.indices = torch.nn.Buffer(torch.ones(4, 4, dtype=torch.long)) def forward(self, x): torch.sort(x, out=(self.sorted, self.indices)) @@ -1891,7 +1891,7 @@ def test_sigmoid_out2(self): class MyModule(torch.nn.Module): def __init__(self): super().__init__() - self.register_buffer("base", torch.ones(4, 4)) + self.base = torch.nn.Buffer(torch.ones(4, 4)) def forward(self, x): torch.sigmoid(x, out=self.base) @@ -2174,8 +2174,8 @@ def test_named_buffers(self): class Foo(torch.nn.Module): def __init__(self): super().__init__() - self.register_buffer("x", torch.ones(3)) - self.register_buffer("y", torch.ones(3)) + self.x = torch.nn.Buffer(torch.ones(3)) + self.y = torch.nn.Buffer(torch.ones(3)) def forward(self, inp): res = 0 diff --git a/test/export/test_export.py b/test/export/test_export.py index e368499bab001..b18d4ea796966 100644 --- a/test/export/test_export.py +++ b/test/export/test_export.py @@ -50,7 +50,7 @@ def test_export_simple_model_buffer_mutation(self): class Foo(torch.nn.Module): def __init__(self, float_val): super().__init__() - self.register_buffer("buffer1", torch.ones(6, 1)) + self.buffer1 = torch.nn.Buffer(torch.ones(6, 1)) def forward(self, x): self.buffer1.add_(2) diff --git a/test/export/test_verifier.py b/test/export/test_verifier.py index 56c110b391e06..a951ac423a3fd 100644 --- a/test/export/test_verifier.py +++ b/test/export/test_verifier.py @@ -178,8 +178,7 @@ def test_aten_wrong_mem_format_buffer(self) -> None: class TestModel(torch.nn.Module): def __init__(self): 
super().__init__() - self.register_buffer( - "a", + self.a = torch.nn.Buffer( torch.randn(1, 3, 100, 100).to(memory_format=torch.channels_last), ) diff --git a/test/functorch/test_aotdispatch.py b/test/functorch/test_aotdispatch.py index 8d8ce6d1c1e11..e4357ce0bcafd 100644 --- a/test/functorch/test_aotdispatch.py +++ b/test/functorch/test_aotdispatch.py @@ -1945,7 +1945,7 @@ def test_real_weights_in_symbolic_mode_with_inplace_ops(self): class M(torch.nn.Module): def __init__(self): super().__init__() - self.register_buffer("buffer", torch.ones(4, 5)) + self.buffer = torch.nn.Buffer(torch.ones(4, 5)) def forward(self, x): y = self.buffer.add_(3) @@ -2140,7 +2140,7 @@ def test_aot_export_forward_mutation_no_buffer_mut_banned(self): class M(torch.nn.Module): def __init__(self): super().__init__() - self.register_buffer("buffer1", torch.ones(6, 4)) + self.buffer1 = torch.nn.Buffer(torch.ones(6, 4)) def forward(self, x): x.add_(4) @@ -2153,7 +2153,7 @@ def test_aot_export_forward_mutation_multiple_mut_banned(self): class M(torch.nn.Module): def __init__(self): super().__init__() - self.register_buffer("buffer1", torch.ones(6, 4)) + self.buffer1 = torch.nn.Buffer(torch.ones(6, 4)) def forward(self, x, y): y.add_(4) diff --git a/test/functorch/test_eager_transforms.py b/test/functorch/test_eager_transforms.py index 217e76d45010b..719aa798e57ef 100644 --- a/test/functorch/test_eager_transforms.py +++ b/test/functorch/test_eager_transforms.py @@ -3466,8 +3466,8 @@ def __init__(self): super().__init__() self.bias = nn.Parameter(torch.randn(3)) self.linear = nn.Linear(3, 3) - self.register_buffer('buffer', torch.randn(3)) - self.register_buffer('buffer_tied', self.buffer) + self.buffer = nn.Buffer(torch.randn(3)) + self.buffer_tied = self.buffer def forward(self, x): x = self.linear(x) @@ -3497,7 +3497,7 @@ class Foo(nn.Module): def __init__(self): super().__init__() self.linear = nn.Linear(3, 3) - self.register_buffer('buffer', torch.randn(3)) + self.buffer = nn.Buffer(torch.randn(3)) def forward(self, x): x = self.linear(x) @@ -3517,7 +3517,7 @@ class Foo(nn.Module): def __init__(self): super().__init__() self.linear = nn.Linear(3, 3) - self.register_buffer('buffer', torch.randn(3)) + self.buffer = nn.Buffer(torch.randn(3)) def forward(self, x): x = self.linear(x) @@ -3573,8 +3573,8 @@ def __init__(self): self.linear = nn.Linear(3, 3) self.weight = self.linear.weight self.bias = self.linear.bias - self.register_buffer('buffer', torch.randn(3)) - self.register_buffer('buffer_tied', self.buffer) + self.buffer = nn.Buffer(torch.randn(3)) + self.buffer_tied = self.buffer def forward(self, x): x = self.linear(x) diff --git a/test/inductor/test_cuda_repro.py b/test/inductor/test_cuda_repro.py index 79161d8358fc2..30f9274875cf4 100644 --- a/test/inductor/test_cuda_repro.py +++ b/test/inductor/test_cuda_repro.py @@ -565,8 +565,7 @@ def __init__(self): start = math.log2(0.5) end = math.log2(1 / (2**8)) - self.register_buffer( - "scales", + self.scales = nn.Buffer( 2 ** torch.arange( start, diff --git a/test/inductor/test_torchinductor.py b/test/inductor/test_torchinductor.py index 7686f2acf2a07..2e65174a32100 100644 --- a/test/inductor/test_torchinductor.py +++ b/test/inductor/test_torchinductor.py @@ -6184,9 +6184,7 @@ def fn(x, p1, p0): class Repro(torch.nn.Module): def __init__(self): super().__init__() - self.register_buffer( - "_tensor_constant0", torch.randn([], dtype=torch.float32) - ) + self._tensor_constant0 = nn.Buffer(torch.randn([], dtype=torch.float32)) def forward(self, arg0_1, arg1_1): 
convert_element_type = torch.ops.prims.convert_element_type.default( diff --git a/test/jit/test_save_load.py b/test/jit/test_save_load.py index 5cbb5317b6612..079247200cb17 100644 --- a/test/jit/test_save_load.py +++ b/test/jit/test_save_load.py @@ -487,6 +487,7 @@ def __init__(self): self.parameter_b = torch.nn.Parameter(torch.randn(4)) self.submodule_b = Submodule() + self.buffer_b = torch.nn.Buffer(torch.randn(4)) m = TestModule() m_loaded = self.getExportImportCopy(torch.jit.script(m)) @@ -526,7 +527,7 @@ def __init__(self): super().__init__() self.foo = torch.nn.Linear(2, 3, device="meta") self.bar = torch.nn.Linear(3, 4) - self.register_buffer("buffer", torch.randn(4, device="meta")) + self.buffer = torch.nn.Buffer(torch.randn(4, device="meta")) def forward(self, x): x = self.foo(x) @@ -1150,6 +1151,7 @@ def __init__(self): self.parameter_b = torch.nn.Parameter(torch.randn(4)) self.submodule_b = Submodule() + self.buffer_b = torch.nn.Buffer(torch.randn(4)) m = TestModule() m_loaded = self.getExportImportCopy(torch.jit.script(m)) diff --git a/test/nn/test_lazy_modules.py b/test/nn/test_lazy_modules.py index d3b0d58c01300..a070911a1b528 100644 --- a/test/nn/test_lazy_modules.py +++ b/test/nn/test_lazy_modules.py @@ -5,7 +5,7 @@ import torch import torch.nn as nn from torch.nn.parameter import UninitializedParameter, UninitializedBuffer -from torch.nn import Parameter +from torch.nn import Buffer, Parameter from torch.testing._internal.common_utils import TestCase, run_tests, suppress_warnings from torch.testing._internal.common_cuda import TEST_CUDA @@ -47,29 +47,29 @@ def test_lazy_module_parameter(self): @suppress_warnings def test_lazy_module_buffer(self): module = LazyModule() - module.register_buffer('test_buffer', UninitializedBuffer()) + module.test_buffer = UninitializedBuffer() self.assertTrue(module.has_uninitialized_params()) state_dict = module.state_dict() self.assertIsInstance(state_dict['test_buffer'], UninitializedBuffer) new_module = LazyModule() # An error is raised when there is an attempt to replace an existing parameter # with an uninitialized one - new_module.register_buffer('test_buffer', torch.ones(5, 5)) + new_module.test_buffer = Buffer(torch.ones(5, 5)) with self.assertRaisesRegex(RuntimeError, 'shape of an uninitialized'): new_module.load_state_dict(state_dict) # Uninitialized parameters are overriden when the state dict to be loaded contains a valid one new_module = LazyModule() - new_module.register_buffer('test_buffer', torch.ones(5, 5)) + new_module.test_buffer = Buffer(torch.ones(5, 5)) module.load_state_dict(new_module.state_dict()) self.assertEqual(module.test_buffer, torch.ones((5, 5))) # Uninitialized parameters are left unchanged module = LazyModule() - module.register_buffer('test_buffer', UninitializedBuffer()) + module.test_buffer = UninitializedBuffer() self.assertTrue(module.has_uninitialized_params()) new_module = LazyModule() - new_module.register_buffer('test_buffer', UninitializedBuffer()) + new_module.test_buffer = UninitializedBuffer() module.load_state_dict(new_module.state_dict()) module.load_state_dict(new_module.state_dict()) self.assertTrue(module.has_uninitialized_params()) @@ -85,7 +85,7 @@ def test_lazy_module_jit_param(self): @suppress_warnings def test_lazy_module_jit_buffer(self): module = LazyModule() - module.register_buffer('test_buffer', UninitializedBuffer()) + module.test_buffer = UninitializedBuffer() self.assertTrue(module.has_uninitialized_params()) with self.assertRaisesRegex(RuntimeError, 'run a forward pass'): 
torch.jit.script(module) @@ -101,7 +101,7 @@ def test_lazy_share_memory_param(self): @suppress_warnings def test_lazy_share_memory_buffer(self): module = LazyModule() - module.register_buffer('test_buffer', UninitializedBuffer()) + module.test_buffer = UninitializedBuffer() self.assertTrue(module.has_uninitialized_params()) with self.assertRaisesRegex(RuntimeError, 'share memory on an uninitialized'): module.share_memory() diff --git a/test/nn/test_parametrization.py b/test/nn/test_parametrization.py index 1e7d8c77ebffc..fe413eebc8724 100644 --- a/test/nn/test_parametrization.py +++ b/test/nn/test_parametrization.py @@ -9,7 +9,7 @@ import torch.nn.functional as F import torch.nn.init as init import torch.nn.utils.parametrize as parametrize -from torch.nn import Parameter +from torch.nn import Buffer, Parameter from torch.testing._internal.common_utils import run_tests, skipIfNoLapack, \ TemporaryFileName, instantiate_parametrized_tests, set_default_dtype from torch.testing._internal.common_cuda import TEST_MULTIGPU @@ -305,7 +305,7 @@ def forward(self, x): # Instantiate parametrizations on buffers. It should work as expected delattr(model, "bias") - model.register_buffer("bias", torch.ones(8)) + model.bias = Buffer(torch.ones(8)) parametrize.register_parametrization(model, "bias", FirstZero()) parametrize.register_parametrization(model, "bias", LastZero()) self.assertTrue(parametrize.is_parametrized(model)) @@ -333,8 +333,8 @@ def test_serialization_parametrization(self): class Orthogonal(nn.Module): def __init__(self, n): super().__init__() - self.register_buffer("id", torch.eye(n)) - self.register_buffer("B", torch.empty(n, n)) + self.id = Buffer(torch.eye(n)) + self.B = Buffer(torch.empty(n, n)) init.orthogonal_(self.B) def forward(self, X): @@ -396,7 +396,7 @@ def right_inverse(self, X): class Orthogonal(nn.Module): def __init__(self, n): super().__init__() - self.register_buffer("B", torch.eye(n)) + self.B = Buffer(torch.eye(n)) def forward(self, X): Id = torch.eye(X.size(0)) diff --git a/test/onnx/test_pytorch_onnx_onnxruntime.py b/test/onnx/test_pytorch_onnx_onnxruntime.py index 89feb2a25ea7a..50c4842ac67e2 100644 --- a/test/onnx/test_pytorch_onnx_onnxruntime.py +++ b/test/onnx/test_pytorch_onnx_onnxruntime.py @@ -297,7 +297,7 @@ def __init__( self, ): super().__init__() - self.register_buffer("weight", torch.ones(5)) + self.weight = torch.nn.Buffer(torch.ones(5)) def forward(self, x): scale_1 = self.weight.reshape(1, -1, 1, 1) @@ -4214,7 +4214,7 @@ def test_gather_constant_fold(self): class GatherModule(torch.nn.Module): def __init__(self): super().__init__() - self.register_buffer("weight", torch.ones(5)) + self.weight = torch.nn.Buffer(torch.ones(5)) # torch.nn.Embedding is converted to ONNX::Gather. # Constant folding will be triggerred for constant inputs. # This pattern is common for constant mask inputs in transformer models. 
@@ -4233,7 +4233,7 @@ def forward(self, x): class GatherModule(torch.nn.Module): def __init__(self): super().__init__() - self.register_buffer("weight", torch.ones(2)) + self.weight = torch.nn.Buffer(torch.ones(2)) def forward(self, x): # shape is of rank 0 @@ -4248,7 +4248,7 @@ def forward(self, x): class GatherModule(torch.nn.Module): def __init__(self): super().__init__() - self.register_buffer("rb", torch.randn(1, 1, 3, 1, 1)) + self.rb = torch.nn.Buffer(torch.randn(1, 1, 3, 1, 1)) def forward(self, x): x += self.rb[0] @@ -9394,7 +9394,7 @@ def test_shape_constant_fold(self): class ShapeModule(torch.nn.Module): def __init__(self): super().__init__() - self.register_buffer("weight", torch.ones(5)) + self.weight = torch.nn.Buffer(torch.ones(5)) def forward(self, x): shape = self.weight.shape[0] @@ -10895,7 +10895,7 @@ class InnerModule2(torch.nn.Module): def __init__(self, embedding_dim): super().__init__() self.weights = InnerModule2.get_embedding(embedding_dim) - self.register_buffer("_float_tensor", torch.FloatTensor(1)) + self._float_tensor = torch.nn.Buffer(torch.FloatTensor(1)) self.const = 2 @staticmethod @@ -10957,7 +10957,7 @@ def __init__(self, embedding_dim): self.embedding_dim = embedding_dim self.const = 2.5 self.weights = InnerModule.get_embedding(self.embedding_dim) - self.register_buffer("_float_tensor", torch.FloatTensor(1)) + self._float_tensor = torch.nn.Buffer(torch.FloatTensor(1)) @staticmethod def get_embedding(embedding_dim: int): diff --git a/test/onnx/test_utility_funs.py b/test/onnx/test_utility_funs.py index a1e458af7359f..e1a9bbf6f7ef1 100644 --- a/test/onnx/test_utility_funs.py +++ b/test/onnx/test_utility_funs.py @@ -540,7 +540,7 @@ def __init__( self, ): super().__init__() - self.register_buffer("weight", torch.ones(5)) + self.weight = torch.nn.Buffer(torch.ones(5)) def forward(self, x): b = self.weight.reshape(1, -1, 1, 1) @@ -563,7 +563,7 @@ def __init__( self, ): super().__init__() - self.register_buffer("weight", torch.ones(5)) + self.weight = torch.nn.Buffer(torch.ones(5)) def forward(self, x): div = self.weight.div(torch.tensor([1, 2, 3, 4, 5])) @@ -586,7 +586,7 @@ def __init__( self, ): super().__init__() - self.register_buffer("weight", torch.ones(5)) + self.weight = torch.nn.Buffer(torch.ones(5)) def forward(self, x): mul = self.weight.mul(torch.tensor([1, 2, 3, 4, 5])) @@ -609,7 +609,7 @@ def __init__( self, ): super().__init__() - self.register_buffer("weight", torch.ones(5)) + self.weight = torch.nn.Buffer(torch.ones(5)) def forward(self, x): add = self.weight + torch.tensor([1, 2, 3, 4, 5]) @@ -640,7 +640,7 @@ def __init__( self, ): super().__init__() - self.register_buffer("weight", torch.ones(5)) + self.weight = torch.nn.Buffer(torch.ones(5)) def forward(self, x): sub = self.weight - torch.tensor([1, 2, 3, 4, 5]) @@ -671,7 +671,7 @@ def __init__( self, ): super().__init__() - self.register_buffer("weight", torch.ones(5)) + self.weight = torch.nn.Buffer(torch.ones(5)) def forward(self, x): sqrt = torch.sqrt(self.weight) @@ -691,7 +691,7 @@ def test_constant_fold_shape(self): class ShapeModule(torch.nn.Module): def __init__(self): super().__init__() - self.register_buffer("weight", torch.ones(5)) + self.weight = torch.nn.Buffer(torch.ones(5)) def forward(self, x): shape = self.weight.shape[0] diff --git a/test/quantization/core/test_quantized_tensor.py b/test/quantization/core/test_quantized_tensor.py index 96d5cea156af3..c11a57db17a28 100644 --- a/test/quantization/core/test_quantized_tensor.py +++ 
b/test/quantization/core/test_quantized_tensor.py @@ -1440,7 +1440,7 @@ def __init__(self, per_channel): s = torch.rand(5, dtype=torch.float64) + 0.1 zp = torch.randint(5, 15, (5,)) x_q = torch.quantize_per_channel(x, s, zp, 1, torch.quint8) - self.register_buffer('x', x_q) + self.x = torch.nn.Buffer(x_q) @torch.jit.script_method def forward(self): diff --git a/test/quantization/eager/test_quantize_eager_qat.py b/test/quantization/eager/test_quantize_eager_qat.py index d51fcbb999710..0441356501d81 100644 --- a/test/quantization/eager/test_quantize_eager_qat.py +++ b/test/quantization/eager/test_quantize_eager_qat.py @@ -94,9 +94,9 @@ def __init__(self, self.beta = nn.Parameter(torch.empty(out_channels)) self.affine = True self.track_running_stats = True - self.register_buffer('running_mean', torch.zeros(out_channels)) - self.register_buffer('running_var', torch.ones(out_channels)) - self.register_buffer('num_batches_tracked', torch.tensor(0, dtype=torch.long)) + self.running_mean = nn.Buffer(torch.zeros(out_channels)) + self.running_var = nn.Buffer(torch.ones(out_channels)) + self.num_batches_tracked = nn.Buffer(torch.tensor(0, dtype=torch.long)) self.activation_post_process = self.qconfig.activation() self.weight_fake_quant = self.qconfig.weight() if bias: diff --git a/test/test_fx.py b/test/test_fx.py index 15daf07ac2435..d72aafb128f1f 100644 --- a/test/test_fx.py +++ b/test/test_fx.py @@ -817,7 +817,7 @@ def __init__(self): self.mm_param = torch.nn.Parameter(torch.randn(d_hid, d_hid)) self.mm_param2 = torch.nn.Parameter(torch.randn(d_hid, d_hid)) self.lin = torch.nn.Linear(d_hid, d_hid) - self.register_buffer('buffer', torch.randn(bs + 100, d_hid)) + self.buffer = torch.nn.Buffer(torch.randn(bs + 100, d_hid)) def forward(self, x): x = torch.mm(x, self.mm_param) @@ -2660,7 +2660,7 @@ def getitem_inner(self): class GetItemBase(torch.nn.Module): def __init__(self): super().__init__() - self.register_buffer('pe', torch.randn(8, 8)) + self.pe = torch.nn.Buffer(torch.randn(8, 8)) class GetItem1(GetItemBase): def forward(self, x): @@ -3026,7 +3026,7 @@ class B(torch.nn.Module): def __init__(self): super().__init__() self.linear = torch.nn.Linear(100, 200) - self.register_buffer("buf", torch.randn(2, 3)) + self.buf = torch.nn.Buffer(torch.randn(2, 3)) self.net_c = C() def forward(self, x): @@ -3196,7 +3196,7 @@ class MockModule(torch.nn.Module): def __init__(self): super().__init__() self.l1 = torch.nn.Linear(1, 1) - self.register_buffer('buffer', torch.ones(1)) + self.buffer = torch.nn.Buffer(torch.ones(1)) def forward(self, x): return self.l1(x) + self.buffer diff --git a/test/test_fx_experimental.py b/test/test_fx_experimental.py index bf2419c8edc86..65cda3897fcde 100644 --- a/test/test_fx_experimental.py +++ b/test/test_fx_experimental.py @@ -1215,8 +1215,8 @@ def __init__(self): self.seq = torch.nn.Sequential(torch.nn.BatchNorm1d(2, 2)) self.linear = torch.nn.Linear(2, 2) self.attr = torch.randn(2) - self.register_buffer("attr2", torch.randn(2)) - self.register_buffer("attr3", torch.ones(2, dtype=torch.int32)) + self.attr2 = torch.nn.Buffer(torch.randn(2)) + self.attr3 = torch.nn.Buffer(torch.ones(2, dtype=torch.int32)) def forward(self, x): return self.linear(self.seq(self.W + self.attr + self.attr2 + self.attr3 + x)) diff --git a/test/test_jit.py b/test/test_jit.py index c0b6ad706e189..25e165f1e9316 100644 --- a/test/test_jit.py +++ b/test/test_jit.py @@ -482,7 +482,7 @@ def test_restore_device_cuda(self): class MyModule(torch.jit.ScriptModule): def __init__(self): super().__init__() - 
self.register_buffer('b0', torch.randn(1, 3)) + self.b0 = nn.Buffer(torch.randn(1, 3)) self.p0 = nn.Parameter(torch.randn(2, 3)) @torch.jit.script_method @@ -538,7 +538,7 @@ def __init__(self): super().__init__() whole_tensor = torch.randn(4, 5, dtype=torch.float, device='cpu') self.p0 = nn.Parameter(whole_tensor.narrow(0, 0, 1)) - self.register_buffer('b0', whole_tensor.narrow(0, 3, 1)) + self.b0 = nn.Buffer(whole_tensor.narrow(0, 3, 1)) m = Foo() m2 = self.getExportImportCopy(m, map_location=torch.device('cuda:0')) @@ -3989,7 +3989,7 @@ def test_cpp_module_iterator(self): a.p = nn.Parameter(torch.rand(3, 4)) a.foo = nn.Module() a.foo.name = 'foo' - a.foo.register_buffer('b', torch.rand(1, 1)) + a.foo.b = nn.Buffer(torch.rand(1, 1)) a.foo.bar = nn.Module() a.foo.bar.name = 'bar' a.foo.bar.an_int = 4 @@ -8957,7 +8957,7 @@ def test_script_module_param_buffer_mutation(self): class ModuleBufferMutate(torch.jit.ScriptModule): def __init__(self): super().__init__() - self.register_buffer('running_var', torch.tensor(0, dtype=torch.long)) + self.running_var = nn.Buffer(torch.tensor(0, dtype=torch.long)) @torch.jit.script_method def forward(self): @@ -9084,12 +9084,12 @@ class DerivedStateModule(torch.jit.ScriptModule): def __init__(self): super(TestScript.DerivedStateModule, self).__init__() self.param = torch.nn.Parameter(torch.ones(3, 4, dtype=torch.float)) - self.register_buffer('derived', torch.neg(self.param).detach().clone()) + self.derived = nn.Buffer(torch.neg(self.param).detach().clone()) # This is a flag so we can test that the pack method was called - self.register_buffer('pack_called', torch.zeros(1, dtype=torch.long)) + self.pack_called = nn.Buffer(torch.zeros(1, dtype=torch.long)) # This is a flag so we can test that the unpack method was called - self.register_buffer('unpack_called', torch.zeros(1, dtype=torch.long)) + self.unpack_called = nn.Buffer(torch.zeros(1, dtype=torch.long)) @torch.jit.script_method def _pack(self): @@ -9269,7 +9269,7 @@ def test_pack_unpack_nested(self): class SubSubMod(torch.jit.ScriptModule): def __init__(self): super().__init__() - self.register_buffer('buf', torch.ones(3, 4) * 3) + self.buf = nn.Buffer(torch.ones(3, 4) * 3) @torch.jit.script_method def _pack(self): @@ -9286,7 +9286,7 @@ def forward(self, x): class SubMod(torch.jit.ScriptModule): def __init__(self): super().__init__() - self.register_buffer('buf', torch.ones(3, 4) * 2) + self.buf = nn.Buffer(torch.ones(3, 4) * 2) self.ssm = SubSubMod() @torch.jit.script_method @@ -9305,7 +9305,7 @@ class Mod(torch.jit.ScriptModule): def __init__(self): super().__init__() self.submod = SubMod() - self.register_buffer('buf', torch.ones(3, 4) * 1) + self.buf = nn.Buffer(torch.ones(3, 4) * 1) @torch.jit.script_method def _pack(self): @@ -13111,7 +13111,7 @@ def __init__(self, in_features, out_features): self.out_features = out_features self.weight = torch.nn.Parameter(torch.empty(out_features, in_features)) self.bias = torch.nn.Parameter(torch.empty(out_features)) - self.register_buffer('counter', torch.ones(out_features)) + self.counter = nn.Buffer(torch.ones(out_features)) self.reset_parameters() def reset_parameters(self): @@ -13164,7 +13164,7 @@ def __init__(self, in_features, out_features): super().__init__() self.weight = torch.nn.Parameter(torch.ones(out_features, in_features)) self.bias = torch.nn.Parameter(torch.ones(out_features)) - self.register_buffer("buffer", torch.ones(out_features)) + self.buffer = nn.Buffer(torch.ones(out_features)) self.submodule = Submodule() def forward(self, x): @@ 
-13619,8 +13619,8 @@ class Root(torch.jit.ScriptModule): def __init__(self, number): super().__init__() - self.register_buffer('buffer1', torch.ones(2, 2)) - self.register_buffer('buffer2', torch.ones(2, 2)) + self.buffer1 = nn.Buffer(torch.ones(2, 2)) + self.buffer2 = nn.Buffer(torch.ones(2, 2)) self.number = number @torch.jit.script_method @@ -13638,8 +13638,8 @@ class M(torch.jit.ScriptModule): def __init__(self, number, submodule): super().__init__() - self.register_buffer('buffer1', torch.ones(2, 2)) - self.register_buffer('buffer2', torch.ones(2, 2)) + self.buffer1 = nn.Buffer(torch.ones(2, 2)) + self.buffer2 = nn.Buffer(torch.ones(2, 2)) self.number = number self.submodule = submodule @@ -13675,8 +13675,8 @@ def __setstate__(self, state): class NoArgState(torch.nn.Module): def __init__(self): super().__init__() - self.register_buffer('buffer1', torch.ones(2, 2)) - self.register_buffer('buffer2', torch.ones(2, 2)) + self.buffer1 = nn.Buffer(torch.ones(2, 2)) + self.buffer2 = nn.Buffer(torch.ones(2, 2)) def forward(self): pass @@ -15091,7 +15091,7 @@ class M(torch.jit.ScriptModule): def __init__(self): super().__init__() tensor = torch.zeros(1, requires_grad=False) - self.register_buffer('some_state', torch.nn.Parameter(tensor)) + self.some_state = nn.Buffer(torch.nn.Parameter(tensor)) @torch.jit.script_method def forward(self, x): @@ -15484,8 +15484,8 @@ def __init__(self): self.mod = (torch.nn.ReLU()) self.mod2 = (torch.nn.ReLU()) self.mod3 = torch.nn.Sequential(torch.nn.Sequential(torch.nn.ReLU())) - self.register_buffer('x', torch.zeros(3)) - self.register_buffer('y', torch.zeros(3)) + self.x = nn.Buffer(torch.zeros(3)) + self.y = nn.Buffer(torch.zeros(3)) self.z = torch.zeros(3) def bleh(self): diff --git a/test/test_mps.py b/test/test_mps.py index 7e73bf5d61244..8f16af44d378b 100644 --- a/test/test_mps.py +++ b/test/test_mps.py @@ -19,7 +19,7 @@ import itertools from collections import defaultdict from torch import inf -from torch.nn import Parameter +from torch.nn import Buffer, Parameter from torch.testing._internal import opinfo from torch.testing._internal.common_utils import \ (gradcheck, gradgradcheck, run_tests, TestCase, download_file, IS_CI, NoTest, @@ -7643,14 +7643,14 @@ class Layer(nn.Module): def __init__(self): super().__init__() self.layer_dummy_param = Parameter(torch.empty(3, 5)) - self.register_buffer('layer_dummy_buf', torch.zeros(1, 3, 3, 7)) + self.layer_dummy_buf = Buffer(torch.zeros(1, 3, 3, 7)) class Net(nn.Module): def __init__(self): super().__init__() self.l1 = Layer() self.dummy_param = Parameter(torch.empty(3, 5)) - self.register_buffer('dummy_buf', torch.zeros(7, 3, 3, 1)) + self.dummy_buf = Buffer(torch.zeros(7, 3, 3, 1)) l = Layer() n = Net() diff --git a/test/test_nn.py b/test/test_nn.py index d6a4512dfaa73..91eea01c43a74 100644 --- a/test/test_nn.py +++ b/test/test_nn.py @@ -31,7 +31,7 @@ from torch.nn.utils import parameters_to_vector, vector_to_parameters from torch.nn.utils.fusion import fuse_conv_bn_weights from torch.nn.utils.fusion import fuse_linear_bn_weights -from torch.nn import Parameter +from torch.nn import Buffer, Parameter from torch.nn.parallel._functions import Broadcast from torch.testing._internal.common_dtype import integral_types, get_all_math_dtypes, floating_types from torch.testing._internal.common_utils import freeze_rng_state, run_tests, TestCase, skipIfNoLapack, skipIfRocm, \ @@ -365,8 +365,8 @@ def names(named_buffers): class M(nn.Module): def __init__(self): super().__init__() - self.register_buffer("buffer1", 
torch.empty(3, 5)) - self.register_buffer("buffer2", self.buffer1) + self.buffer1 = Buffer(torch.empty(3, 5)) + self.buffer2 = self.buffer1 m = M() self.assertEqual(names(m.named_buffers()), @@ -425,7 +425,7 @@ def test_dir(self): linear = nn.Linear(2, 2) linear._test_submodule = nn.Linear(2, 2) linear._test_parameter = Parameter(torch.empty(2, 2)) - linear.register_buffer('_test_buffer', torch.empty(2, 2)) + linear._test_buffer = Buffer(torch.empty(2, 2)) keys = dir(linear) self.assertIn('_test_submodule', keys) self.assertIn('_test_parameter', keys) @@ -530,6 +530,9 @@ def test_register_buffer_raises_error_if_attr_exists(self): with self.assertRaises(KeyError): m.register_buffer('attribute_name', torch.rand(5)) + with self.assertRaises(KeyError): + m.attribute_name = Buffer(torch.rand(5)) + del m.attribute_name m.register_parameter('attribute_name', nn.Parameter()) with self.assertRaises(KeyError): @@ -556,12 +559,18 @@ def test_register_buffer_allows_overwriting_with_same_name(self): self.assertEqual(m.buffer_name, buffer2) m.register_buffer('buffer_name', buffer3) self.assertEqual(m.buffer_name, buffer3) + m.buffer_name = Buffer(buffer1) + self.assertEqual(m.buffer_name, Buffer(buffer1)) + m.buffer_name = Buffer(buffer2) + self.assertEqual(m.buffer_name, Buffer(buffer2)) + m.buffer_name = Buffer(buffer3) + self.assertEqual(m.buffer_name, Buffer(buffer3)) def test_get_buffer(self): m = nn.Module() buffer1 = torch.randn(2, 3) buffer2 = torch.randn(4, 5) - m.register_buffer('foo', buffer1) + m.foo = Buffer(buffer1) m.register_buffer('bar', buffer2) self.assertEqual(buffer1, m.get_buffer('foo')) self.assertEqual(buffer2, m.get_buffer('bar')) @@ -575,13 +584,13 @@ def __init__(self, foo, bar): class Sub(nn.Module): def __init__(self, foo, bar): super().__init__() - self.register_buffer('foo', foo) + self.foo = Buffer(foo) self.subsub = SubSub(bar) class SubSub(nn.Module): def __init__(self, bar): super().__init__() - self.register_buffer('bar', bar) + self.bar = Buffer(bar) foo = torch.randn(2, 3) bar = torch.randn(4, 5) @@ -591,33 +600,35 @@ def __init__(self, bar): def test_buffer_not_persistent(self): m = nn.Module() - m.register_buffer('buf', torch.rand(5), persistent=False) + m.buf = nn.Buffer(torch.rand(5), persistent=False) self.assertTrue(len(list(m.buffers())) == 1) self.assertTrue(len(m.state_dict()) == 0) def test_buffer_not_persistent_del(self): m = nn.Module() - m.register_buffer('buf', torch.rand(5), persistent=False) + m.buf = nn.Buffer(torch.rand(5), persistent=False) del m.buf self.assertTrue(len(list(m.buffers())) == 0) def test_buffer_not_persistent_overwrite(self): m = nn.Module() - m.register_buffer('buf', torch.rand(5), persistent=False) - m.register_buffer('buf', torch.rand(5)) + m.buf = nn.Buffer(torch.rand(5), persistent=False) + m.buf = nn.Buffer(torch.rand(5)) # can we overwrite a non-persistent buffer with a persistent one? self.assertTrue(len(list(m.buffers())) == 1) self.assertTrue(len(m.state_dict()) == 1) # can we overwrite a persistent buffer with a non-persistent one? 
- m.register_buffer('buf', torch.rand(5), persistent=False) + m.buf = nn.Buffer(torch.rand(5), persistent=False) self.assertTrue(len(list(m.buffers())) == 1) self.assertTrue(len(m.state_dict()) == 0) def test_buffer_not_persistent_assign(self): m = nn.Module() - m.register_buffer('buf', torch.rand(5), persistent=False) + m.buf = nn.Buffer(torch.rand(5), persistent=False) + self.assertTrue(len(list(m.buffers())) == 1) + self.assertTrue(len(m.state_dict()) == 0) # Assigning None removes the buffer but if we then assign a new Tensor # to the same property, it should still be marked as a buffer. @@ -659,7 +670,7 @@ def test_load_state_dict_type(self): def test_buffer_not_persistent_load(self): m = nn.Module() - m.register_buffer('buf', torch.rand(5), persistent=False) + m.buf = nn.Buffer(torch.rand(5), persistent=False) m.load_state_dict({}) def test_register_parameter_raises_error_if_name_is_not_string(self): @@ -681,6 +692,11 @@ def test_register_parameter_raises_error_if_attr_exists(self): with self.assertRaises(KeyError): m.register_parameter('attribute_name', nn.Parameter()) + del m.attribute_name + m.attribute_name = Buffer(torch.rand(5)) + with self.assertRaises(KeyError): + m.register_parameter('attribute_name', nn.Parameter()) + del m.attribute_name m.add_module('attribute_name', nn.Module()) with self.assertRaises(KeyError): @@ -1625,7 +1641,7 @@ def test_type(self): net.l = l net.l2 = l net.add_module('empty', None) - net.register_buffer('indices', torch.LongTensor(1)) + net.indices = Buffer(torch.LongTensor(1)) net.float() self.assertIsInstance(l.weight.data, torch.FloatTensor) self.assertIsInstance(l.bias.data, torch.FloatTensor) @@ -2811,8 +2827,8 @@ def test_assignments(get_list, a, b, c): del l.a, l.b self.assertEqual(list(l.children()), []) - buf = torch.randn(10) - l.register_buffer('buf', buf) + buf = Buffer(torch.randn(10)) + l.buf = buf self.assertIs(l.buf, buf) l.buf = None self.assertIs(l.buf, None) diff --git a/test/test_stateless.py b/test/test_stateless.py index dd38b35a927d5..8c3941de041ba 100644 --- a/test/test_stateless.py +++ b/test/test_stateless.py @@ -18,7 +18,7 @@ class MockModule(torch.nn.Module): def __init__(self): super().__init__() self.l1 = torch.nn.Linear(1, 1) - self.register_buffer('buffer', torch.ones(1)) + self.buffer = torch.nn.Buffer(torch.ones(1)) self.foo = 0.0 def forward(self, x): @@ -30,8 +30,8 @@ def __init__(self): super().__init__() self.l1 = torch.nn.Linear(1, 1) self.tied_bias = self.l1.bias - self.register_buffer('buffer', torch.ones(1)) - self.register_buffer('tied_buffer', self.buffer) + self.buffer = torch.nn.Buffer(torch.ones(1)) + self.tied_buffer = self.buffer def forward(self, x): return self.l1(x) + self.tied_bias + self.buffer + self.tied_buffer @@ -408,7 +408,7 @@ def __repr__(self): def test_tied_weights_warns(self, functional_call): module = MockModule() module.tied_bias = module.l1.bias - module.register_buffer("tied_buffer", module.buffer) + module.tied_buffer = torch.nn.Buffer(module.buffer) @parametrize("functional_call", [ subtest(torch.func.functional_call, "torch_func"), @@ -613,7 +613,7 @@ def test_setattr(self, functional_call): class Foo(torch.nn.Module): def __init__(self): super().__init__() - self.register_buffer('foo', torch.tensor([0.0])) + self.foo = torch.nn.Buffer(torch.tensor([0.0])) def forward(self, x): self.foo = self.foo + 1 @@ -637,7 +637,7 @@ def test_in_place_operator(self, functional_call): class Foo(torch.nn.Module): def __init__(self): super().__init__() - self.register_buffer('foo', 
torch.tensor([0.0])) + self.foo = torch.nn.Buffer(torch.tensor([0.0])) def forward(self, x): self.foo.add_(1) @@ -759,7 +759,7 @@ class Module(torch.nn.Module): def __init__(self): super().__init__() self.l1 = torch.nn.Linear(1, 1) - self.register_buffer('buffer', torch.ones(1)) + self.buffer = torch.nn.Buffer(torch.ones(1)) def forward(self, x): parameters = tuple(self.parameters()) diff --git a/torch/_dynamo/utils.py b/torch/_dynamo/utils.py index 1827b9c21b2f3..9bdf3a5ea687c 100644 --- a/torch/_dynamo/utils.py +++ b/torch/_dynamo/utils.py @@ -448,6 +448,7 @@ def istensor(obj): """Check of obj is a tensor""" tensor_list = ( torch.Tensor, + torch.nn.Buffer, torch.nn.Parameter, *config.traceable_tensor_subclasses, ) diff --git a/torch/_dynamo/variables/builder.py b/torch/_dynamo/variables/builder.py index 979788d2bd9f3..60d9f09f8e284 100644 --- a/torch/_dynamo/variables/builder.py +++ b/torch/_dynamo/variables/builder.py @@ -292,7 +292,12 @@ def _type_dispatch(cls): # NB: Careful not to close over self to avoid ref cycle from lru_cache entries = [ ( - (torch.Tensor, torch.nn.Parameter, torch._subclasses.FakeTensor), + ( + torch.Tensor, + torch.nn.Buffer, + torch.nn.Parameter, + torch._subclasses.FakeTensor, + ), cls.wrap_tensor, ), ((tuple, list, odict_values), cls.wrap_listlike), @@ -882,6 +887,7 @@ def wrap_tensor(self, value: torch.Tensor): else: assert type(value) in ( torch.Tensor, + torch.nn.Buffer, torch.nn.Parameter, torch._subclasses.fake_tensor.FakeTensor, ), type(value) @@ -1463,7 +1469,7 @@ def update_dim2constraint(dim, constraint_range): def wrap_to_fake_tensor_and_record( e, tx, ignore_subclass=False, *, source: Optional[Source], is_tensor: bool ): - if type(e) in (torch.Tensor, torch.nn.Parameter) or ( + if type(e) in (torch.Tensor, torch.nn.Buffer, torch.nn.Parameter) or ( ignore_subclass and isinstance(e, torch.Tensor) ): assert source is not None diff --git a/torch/_subclasses/fake_tensor.py b/torch/_subclasses/fake_tensor.py index 4803be41fcf14..7a2a4659735f5 100644 --- a/torch/_subclasses/fake_tensor.py +++ b/torch/_subclasses/fake_tensor.py @@ -1426,6 +1426,7 @@ def check(x): not isinstance(x, FakeTensor) and type(x) is not torch.Tensor and type(x) is not torch.nn.Parameter + and type(x) is not torch.nn.Buffer ) return [ diff --git a/torch/_subclasses/meta_utils.py b/torch/_subclasses/meta_utils.py index be1efecff32dd..0015efdd0c2eb 100644 --- a/torch/_subclasses/meta_utils.py +++ b/torch/_subclasses/meta_utils.py @@ -496,6 +496,7 @@ def __call__( if ( type(t) is torch.Tensor + or type(t) is torch.nn.Buffer or type(t) is torch.nn.Parameter or (ignore_subclass and isinstance(t, torch.Tensor)) or isinstance(t, FakeTensor) @@ -544,6 +545,9 @@ def __call__( # NB: Cannot directly use Parameter constructor # because that would force a detach, not desirable r._is_param = True + elif type(t) is torch.nn.Buffer: + # similar to above + r._is_buffer = True return r elif torch.overrides.is_tensor_like(t): # Blindly converting tensor subclasses to meta can cause diff --git a/torch/_utils.py b/torch/_utils.py index 02ea1449e289a..3c1ec83de6799 100644 --- a/torch/_utils.py +++ b/torch/_utils.py @@ -387,6 +387,17 @@ def _rebuild_qtensor( return tensor +def _rebuild_buffer(data, requires_grad, persistent): + buffer = torch.nn.Buffer(data, requires_grad, persistent) + return buffer + + +def _rebuild_buffer_with_state(data, requires_grad, persistent, state): + buffer = torch.nn.Buffer(data, requires_grad, persistent) + buffer = _set_obj_state(buffer, state) + return buffer + + def 
_rebuild_parameter(data, requires_grad, backward_hooks): param = torch.nn.Parameter(data, requires_grad) # NB: This line exists only for backwards compatibility; the diff --git a/torch/fx/experimental/proxy_tensor.py b/torch/fx/experimental/proxy_tensor.py index a471056dbe558..1672415f70bfe 100644 --- a/torch/fx/experimental/proxy_tensor.py +++ b/torch/fx/experimental/proxy_tensor.py @@ -238,7 +238,7 @@ def inner(e): def fetch_tensor_proxy(tracer): return lambda t: get_proxy_slot(t, tracer, t) -HANDLED_TYPES = (torch.Tensor, torch.nn.Parameter) +HANDLED_TYPES = (torch.Tensor, torch.nn.Parameter, torch.nn.Buffer) def proxy_call(proxy_mode, func, pre_dispatch, args, kwargs): unrecognized_types = [] diff --git a/torch/nn/__init__.py b/torch/nn/__init__.py index 9fca305daa254..f0915d4361d8b 100644 --- a/torch/nn/__init__.py +++ b/torch/nn/__init__.py @@ -1,5 +1,6 @@ from .modules import * # noqa: F403 from .parameter import ( + Buffer as Buffer, Parameter as Parameter, UninitializedParameter as UninitializedParameter, UninitializedBuffer as UninitializedBuffer, diff --git a/torch/nn/modules/module.py b/torch/nn/modules/module.py index 70f8a420d756d..10122433b57eb 100644 --- a/torch/nn/modules/module.py +++ b/torch/nn/modules/module.py @@ -5,7 +5,7 @@ import weakref import torch -from ..parameter import Parameter +from ..parameter import Parameter, Buffer import torch.utils.hooks as hooks from torch import Tensor, device, dtype @@ -1745,16 +1745,16 @@ def remove_from(*dicts_or_sets): modules[name] = value else: buffers = self.__dict__.get('_buffers') - if buffers is not None and name in buffers: + if isinstance(value, Buffer) or buffers is not None and name in buffers: if value is not None and not isinstance(value, torch.Tensor): raise TypeError("cannot assign '{}' as buffer '{}' " - "(torch.Tensor or None expected)" + "(torch.nn.Buffer, torch.Tensor or None expected)" .format(torch.typename(value), name)) - for hook in _global_buffer_registration_hooks.values(): - output = hook(self, name, value) - if output is not None: - value = output - buffers[name] = value + if isinstance(value, Buffer): + persistent = value.persistent + else: + persistent = name not in self._non_persistent_buffers_set + self.register_buffer(name, value, persistent) else: super().__setattr__(name, value) diff --git a/torch/nn/parameter.py b/torch/nn/parameter.py index c15ad0c863c94..cc0203c42d1c1 100644 --- a/torch/nn/parameter.py +++ b/torch/nn/parameter.py @@ -196,6 +196,74 @@ def __deepcopy__(self, memo): memo[id(self)] = result return result +# Metaclass to combine _TensorMeta and the instance check override for Buffer. +class _BufferMeta(torch._C._TensorMeta): + # Make `isinstance(t, Buffer)` return True for custom tensor instances that have the _is_buffer flag. + def __instancecheck__(self, instance): + return isinstance(instance, torch.Tensor) and getattr(instance, '_is_buffer', False) + + +class Buffer(torch.Tensor, metaclass=_BufferMeta): + r"""A kind of Tensor that should not be considered a model + parameter. For example, BatchNorm's ``running_mean`` is not a parameter, but is part of the module's state. + + Buffers are :class:`~torch.Tensor` subclasses, that have a + very special property when used with :class:`Module` s - when they're + assigned as Module attributes they are automatically added to the list of + its buffers, and will appear e.g. in :meth:`~Module.buffers` iterator. + Assigning a Tensor doesn't have such effect. 
One can still assign a Tensor as a buffer explicitly by using + the module's :meth:`~Module.register_buffer` function. + + Args: + data (Tensor): buffer tensor. + requires_grad (bool, optional): if the buffer requires gradient. + Default: `False` + persistent (bool, optional): whether the buffer is part of the module's + :attr:`state_dict`. Default: `True` + """ + def __new__(cls, data=None, requires_grad=False, persistent=True): + if data is None: + data = torch.empty(0) + + # Path for custom tensors: set a flag on the instance to indicate buffer-ness. + t = data.detach().requires_grad_(requires_grad) + if type(t) is not type(data) and not isinstance(data, Parameter): + raise RuntimeError(f"Creating a Buffer from an instance of type {type(data).__name__} " + "requires that detach() returns an instance of the same type, but return " + f"type {type(t).__name__} was found instead. To use the type as a " + "Buffer, please correct the detach() semantics defined by " + "its __torch_dispatch__() implementation.") + t.persistent = persistent + t._is_buffer = True + return t + + def __deepcopy__(self, memo): + if id(self) in memo: + return memo[id(self)] + else: + result = type(self)(self.data.clone(memory_format=torch.preserve_format), self.requires_grad, self.persistent) + memo[id(self)] = result + return result + + def __repr__(self): + return 'Buffer containing:\n' + super().__repr__() + + def __reduce_ex__(self, proto): + state = torch._utils._get_obj_state(self) + + if not state: + return ( + torch._utils._rebuild_buffer, + (self.data, self.requires_grad, self.persistent) + ) + + return ( + torch._utils._rebuild_buffer_with_state, + (self.data, self.requires_grad, self.persistent, state) + ) + + __torch_function__ = _disabled_torch_function_impl + class UninitializedBuffer(UninitializedTensorMixin, torch.Tensor): r"""A buffer that is not initialized. @@ -214,7 +282,10 @@ class UninitializedBuffer(UninitializedTensorMixin, torch.Tensor): cls_to_become = torch.Tensor - def __new__(cls, requires_grad=False, device=None, dtype=None) -> None: + def __new__(cls, requires_grad=False, device=None, dtype=None, persistent=True) -> None: factory_kwargs = {'device': device, 'dtype': dtype} data = torch.empty(0, **factory_kwargs) - return torch.Tensor._make_subclass(cls, data, requires_grad) + ret = torch.Tensor._make_subclass(cls, data, requires_grad) + ret.persistent = persistent + ret._is_buffer = True + return ret diff --git a/torch/nn/parameter.pyi b/torch/nn/parameter.pyi index 219bb6d4efa2a..9ef33149fadb3 100644 --- a/torch/nn/parameter.pyi +++ b/torch/nn/parameter.pyi @@ -26,11 +26,22 @@ class UninitializedParameter(Tensor): dtype: Optional[torch.dtype] = None, ): ... +class Buffer(Tensor): + persistent: builtins.bool + def __init__( + self, + data: Tensor = ..., + requires_grad: builtins.bool = ..., + persistent: builtins.bool = ..., + ): ... + class UninitializedBuffer(Tensor): + persistent: builtins.bool def __init__( self, data: Tensor = ..., requires_grad: builtins.bool = ..., + persistent: builtins.bool = ..., ): ...
def materialize( self, diff --git a/torch/testing/_internal/common_nn.py b/torch/testing/_internal/common_nn.py index 20d77632ef354..6bc41ab20f3fc 100644 --- a/torch/testing/_internal/common_nn.py +++ b/torch/testing/_internal/common_nn.py @@ -4736,14 +4736,14 @@ class Layer(nn.Module): def __init__(self): super().__init__() self.layer_dummy_param = nn.Parameter(torch.empty(3, 5)) - self.register_buffer('layer_dummy_buf', torch.zeros(1, 3, 3, 7)) + self.layer_dummy_buf = nn.Buffer(torch.zeros(1, 3, 3, 7)) class Net(nn.Module): def __init__(self): super().__init__() self.l1 = Layer() self.dummy_param = nn.Parameter(torch.empty(3, 5)) - self.register_buffer('dummy_buf', torch.zeros(7, 3, 3, 1)) + self.dummy_buf = nn.Buffer(torch.zeros(7, 3, 3, 1)) l = Layer() n = Net()
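For reviewers skimming the mechanical register_buffer(...) -> attribute-assignment churn above, here is a minimal usage sketch of the API this patch adds. It is illustrative only: the Demo module and its attribute names are invented, and it assumes a build that includes the torch.nn.Buffer class and the Module.__setattr__ change from this diff. It exercises the three behaviors the tests rely on: registration by plain assignment, the persistent flag, and the isinstance check provided by _BufferMeta.

import torch
import torch.nn as nn

class Demo(nn.Module):
    def __init__(self):
        super().__init__()
        self.linear = nn.Linear(4, 4)
        # Assigning an nn.Buffer registers it just as register_buffer() would.
        self.running_stat = nn.Buffer(torch.zeros(4))
        # persistent=False keeps the buffer out of state_dict(), matching
        # register_buffer(..., persistent=False).
        self.scratch = nn.Buffer(torch.ones(4), persistent=False)

    def forward(self, x):
        return self.linear(x) + self.running_stat

m = Demo()
print(sorted(name for name, _ in m.named_buffers()))  # ['running_stat', 'scratch']
print('scratch' in m.state_dict())                    # False: non-persistent
print(isinstance(m.running_stat, nn.Buffer))          # True via _BufferMeta.__instancecheck__

Assignment mirrors how nn.Parameter registration has always worked, which is why most of this diff is a one-for-one rewrite of register_buffer() calls in the tests rather than a behavior change.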