Add support for non-persistent buffers. (#37191)
Summary:
Issue: #18056
Pull Request resolved: #37191

Differential Revision: D21428373

Pulled By: albanD

fbshipit-source-id: a7d367bafb95137e1bc380178b82b08eff5d5a5a
sharvil authored and facebook-github-bot committed May 7, 2020
1 parent 46ed334 commit 594b33e
Showing 2 changed files with 76 additions and 10 deletions.
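A minimal usage sketch of the behavior this commit introduces, based on the tests below (not part of the diff itself): a buffer registered with persistent=False is still returned by buffers() but never appears in state_dict().

import torch
import torch.nn as nn

m = nn.Module()
m.register_buffer('buf', torch.rand(5), persistent=False)  # new keyword added by this PR
print(len(list(m.buffers())))   # 1 -- still tracked as a buffer
print(len(m.state_dict()))      # 0 -- excluded from the state dict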
49 changes: 49 additions & 0 deletions test/test_nn.py
@@ -909,6 +909,55 @@ def test_register_buffer_allows_overwriting_with_same_name(self):
m.register_buffer('buffer_name', buffer3)
self.assertEqual(m.buffer_name, buffer3)

def test_buffer_not_persistent(self):
m = nn.Module()
m.register_buffer('buf', torch.rand(5), persistent=False)
self.assertTrue(len(list(m.buffers())) == 1)
self.assertTrue(len(m.state_dict()) == 0)

def test_buffer_not_persistent_del(self):
m = nn.Module()
m.register_buffer('buf', torch.rand(5), persistent=False)
del m.buf
self.assertTrue(len(list(m.buffers())) == 0)

def test_buffer_not_persistent_overwrite(self):
m = nn.Module()
m.register_buffer('buf', torch.rand(5), persistent=False)
m.register_buffer('buf', torch.rand(5))

# can we overwrite a non-persistent buffer with a persistent one?
self.assertTrue(len(list(m.buffers())) == 1)
self.assertTrue(len(m.state_dict()) == 1)

# can we overwrite a persistent buffer with a non-persistent one?
m.register_buffer('buf', torch.rand(5), persistent=False)
self.assertTrue(len(list(m.buffers())) == 1)
self.assertTrue(len(m.state_dict()) == 0)

def test_buffer_not_persistent_assign(self):
m = nn.Module()
m.register_buffer('buf', torch.rand(5), persistent=False)

# Assigning None removes the buffer, but if we then assign a new Tensor
# to the same property, it should still be marked as a buffer.
m.buf = None
self.assertTrue(len(list(m.buffers())) == 0)
self.assertTrue(len(m.state_dict()) == 0)
m.buf = torch.rand(5)
self.assertTrue(len(list(m.buffers())) == 1)
self.assertTrue(len(m.state_dict()) == 0)

# Assigning a Parameter removes the buffer.
m.buf = nn.Parameter(torch.rand(5))
self.assertTrue(len(list(m.buffers())) == 0)
self.assertTrue(len(m.state_dict()) == 1)

def test_buffer_not_persistent_load(self):
m = nn.Module()
m.register_buffer('buf', torch.rand(5), persistent=False)
m.load_state_dict({})

def test_register_parameter_raises_error_if_name_is_not_string(self):
m = nn.Module()
expected_error = 'parameter name should be a string. Got '
37 changes: 27 additions & 10 deletions torch/nn/modules/module.py
@@ -84,6 +84,7 @@ def __init__(self):
self.training = True
self._parameters = OrderedDict()
self._buffers = OrderedDict()
self._persistent_buffers_set = set()
self._backward_hooks = OrderedDict()
self._forward_hooks = OrderedDict()
self._forward_pre_hooks = OrderedDict()
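The new set only holds buffer names. A quick way to see which buffers will be serialized is to intersect it with _buffers; this relies on the private attributes added above, so treat it as an internal-inspection sketch rather than public API:

import torch
import torch.nn as nn

m = nn.Module()
m.register_buffer('a', torch.zeros(2))                     # persistent by default
m.register_buffer('b', torch.zeros(2), persistent=False)
print([n for n in m._buffers if n in m._persistent_buffers_set])  # ['a']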
@@ -104,19 +105,26 @@ def forward(self, *input):
"""
raise NotImplementedError

def register_buffer(self, name, tensor):
r"""Adds a persistent buffer to the module.
def register_buffer(self, name, tensor, persistent=True):
r"""Adds a buffer to the module.
This is typically used to register a buffer that should not be
considered a model parameter. For example, BatchNorm's ``running_mean``
is not a parameter, but is part of the persistent state.
is not a parameter, but is part of the module's state. Buffers, by
default, are persistent and will be saved alongside parameters. This
behavior can be changed by setting :attr:`persistent` to ``False``. The
only difference between a persistent buffer and a non-persistent buffer
is that the latter will not be a part of this module's
:attr:`state_dict`.
Buffers can be accessed as attributes using given names.
Args:
name (string): name of the buffer. The buffer can be accessed
from this module using the given name
tensor (Tensor): buffer to be registered.
persistent (bool): whether the buffer is part of this module's
:attr:`state_dict`.
Example::
@@ -141,6 +149,10 @@ def register_buffer(self, name, tensor):
.format(torch.typename(tensor), name))
else:
self._buffers[name] = tensor
if persistent:
self._persistent_buffers_set.add(name)
else:
self._persistent_buffers_set.discard(name)
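Because every call either adds or discards the name in _persistent_buffers_set, re-registering an existing buffer toggles its persistence. A rough illustration, mirroring test_buffer_not_persistent_overwrite above:

import torch
import torch.nn as nn

m = nn.Module()
m.register_buffer('buf', torch.rand(5), persistent=False)
m.register_buffer('buf', torch.rand(5))                     # re-register: persistent again
assert len(m.state_dict()) == 1
m.register_buffer('buf', torch.rand(5), persistent=False)   # and back to non-persistent
assert len(m.state_dict()) == 0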

def register_parameter(self, name, param):
r"""Adds a parameter to the module.
@@ -606,17 +618,20 @@ def __getattr__(self, name):
type(self).__name__, name))

def __setattr__(self, name, value):
def remove_from(*dicts):
for d in dicts:
def remove_from(*dicts_or_sets):
for d in dicts_or_sets:
if name in d:
del d[name]
if isinstance(d, dict):
del d[name]
else:
d.discard(name)

params = self.__dict__.get('_parameters')
if isinstance(value, Parameter):
if params is None:
raise AttributeError(
"cannot assign parameters before Module.__init__() call")
remove_from(self.__dict__, self._buffers, self._modules)
remove_from(self.__dict__, self._buffers, self._modules, self._persistent_buffers_set)
self.register_parameter(name, value)
elif params is not None and name in params:
if value is not None:
@@ -630,7 +645,7 @@ def remove_from(*dicts):
if modules is None:
raise AttributeError(
"cannot assign module before Module.__init__() call")
remove_from(self.__dict__, self._parameters, self._buffers)
remove_from(self.__dict__, self._parameters, self._buffers, self._persistent_buffers_set)
modules[name] = value
elif modules is not None and name in modules:
if value is not None:
@@ -654,6 +669,7 @@ def __delattr__(self, name):
del self._parameters[name]
elif name in self._buffers:
del self._buffers[name]
self._persistent_buffers_set.discard(name)
elif name in self._modules:
del self._modules[name]
else:
@@ -687,7 +703,7 @@ def _save_to_state_dict(self, destination, prefix, keep_vars):
if param is not None:
destination[prefix + name] = param if keep_vars else param.detach()
for name, buf in self._buffers.items():
if buf is not None:
if buf is not None and name in self._persistent_buffers_set:
destination[prefix + name] = buf if keep_vars else buf.detach()
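In a user module, the name check above means a derived or recomputable tensor can travel with the module as a buffer without bloating checkpoints. A hypothetical sketch (the Norm class and buffer names are illustrative, not from this PR):

import torch
import torch.nn as nn

class Norm(nn.Module):
    def __init__(self, num_features):
        super(Norm, self).__init__()
        self.register_buffer('running_mean', torch.zeros(num_features))               # saved
        self.register_buffer('scratch', torch.zeros(num_features), persistent=False)  # skipped when saving

    def forward(self, x):
        return x - self.running_mean

print(sorted(Norm(4).state_dict().keys()))  # ['running_mean']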

def state_dict(self, destination=None, prefix='', keep_vars=False):
@@ -766,7 +782,8 @@ def _load_from_state_dict(self, state_dict, prefix, local_metadata, strict,
for hook in self._load_state_dict_pre_hooks.values():
hook(state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs)

local_name_params = itertools.chain(self._parameters.items(), self._buffers.items())
persistent_buffers = {k: v for k, v in self._buffers.items() if k in self._persistent_buffers_set}
local_name_params = itertools.chain(self._parameters.items(), persistent_buffers.items())
local_state = {k: v for k, v in local_name_params if v is not None}

for name, param in local_state.items():
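Since non-persistent buffers are filtered out of local_state, a checkpoint that lacks them loads cleanly even with the default strict=True. A sketch matching test_buffer_not_persistent_load above:

import torch
import torch.nn as nn

m = nn.Module()
m.register_buffer('buf', torch.rand(5), persistent=False)
m.load_state_dict({})  # no missing-key error: 'buf' is not expected in the state dict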