Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
18 changes: 12 additions & 6 deletions test/ao/sparsity/test_data_sparsifier.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@
# Owner(s): ["module: unknown"]

import logging
import random
import torch
from torch.nn.utils.parametrize import is_parametrized
from torch.testing._internal.common_utils import TestCase
Expand Down Expand Up @@ -138,16 +137,23 @@ def check_add_data(self, data_list, data_with_config, defaults, **kwargs):
sparsifier = self._make_sparsifier(data_list, data_with_config, defaults=defaults, **kwargs)
all_data = data_list + data_with_config
for some_data in all_data:
name1, data1, _ = self._get_name_data_config(some_data)
name1, data1, config = self._get_name_data_config(some_data, defaults=defaults)
data1 = sparsifier._extract_weight(data1)
data1_old = copy.deepcopy(data1)
assert torch.all(data1 == sparsifier.get_data(name=name1))
# get some other data at random and with the same name
rand_idx = random.randint(0, len(all_data) - 1)
_, data2, _ = self._get_name_data_config(all_data[rand_idx])
data2 = sparsifier._extract_weight(data2)

sparsifier.step()
mask = sparsifier.get_mask(name1)

data2 = torch.randn(data1.shape) # add another data with the same shape as original data
sparsifier.add_data(name=name1, data=data2)
assert torch.all(data2 == sparsifier.get_data(name=name1))

assert torch.all(sparsifier.get_mask(name1) == mask) # mask should not change
assert torch.all(data1_old == data1)

assert sparsifier.data_groups[name1] == config # if replaced old_config should match new config

def check_state_dict(self, data_list, data_with_config, defaults, **kwargs):
sparsifier1 = self._make_sparsifier(data_list, data_with_config, defaults=defaults, **kwargs)
sparsifier2 = self._make_sparsifier(data_list=[data_list[0]], data_with_config=[], defaults=defaults, **kwargs)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -4,10 +4,15 @@
from ...sparsifier import base_sparsifier
from collections import defaultdict
from torch import nn
import warnings
import copy
from ...sparsifier import utils
from torch.nn.utils import parametrize
import sys
import warnings

if not sys.warnoptions:
# to suppress repeated warnings when being used in a training loop.
warnings.simplefilter("once")

__all__ = ['BaseDataSparsifier']

Expand Down Expand Up @@ -74,8 +79,15 @@ def _extract_weight(self, data):
elif type(data) in EMBEDDING_TYPES:
return data.weight

def add_data(self, name: str, data, **config):
r""" Configures and parametrizes the internal container model with name and data
def add_data(self, name: str, data, reuse_mask=True, **config):
r""" Configures and parametrizes the internal container model with name and data.

**Note**:
1. If the data with name already exists, it replaces the data.
2. While replacing, the old mask is reused when `reuse_mask=True`
    3. If `reuse_mask=True`, then the replacing data needs to have the same shape as that of the old data.
    4. By default, the config of the replaced data is reused as the config for the replacing data,
       except for any keys explicitly overridden in the passed-in config dictionary.
"""
assert type(data) in SUPPORTED_TYPES, \
"specified data type not supported at the moment"
Expand All @@ -90,7 +102,19 @@ def add_data(self, name: str, data, **config):
if name in self.state:
# If the named data already exists - replace
warnings.warn("Replacing existing data of the same name. - Did you mean a different name?")
self.__delete_data(name=name)

# reuse old config
old_args = self.data_groups[name]
local_args = copy.deepcopy(old_args)
local_args.update(config)

if reuse_mask:
current_data = self.get_data(name=name)
assert weight.shape == current_data.shape, \
"to retain the old mask, the shape of the new data must be the same as the previous one"
mask = self.get_mask(name=name) # reuse mask instead of creating a new one

self._delete_data(name=name)

# parameter creates a deepcopy of the weight inside, so create a buffer
self._container.register_buffer(name=name, tensor=weight)
Expand Down Expand Up @@ -254,7 +278,7 @@ def step(self):
def update_mask(self, name, data, **kwargs):
pass

def __delete_data(self, name):
def _delete_data(self, name):
"""Detaches some data from the sparsifier.

Args:
Expand All @@ -264,7 +288,7 @@ def __delete_data(self, name):
Note:
Currently private. Kind of used as a helper function when replacing data of the same name
"""
self.squash_mask(names=[name], leave_parametrized=True)
self.squash_mask(names=[name], leave_parametrized=False) # do not apply the mask while deleting
delattr(self._container, name)
self.state.pop(name)
self.data_groups.pop(name)