[sparsity] Sparsity parametrization #58705

Conversation
💊 CI failures summary and remediations

As of commit 9a69b41 (more details on the Dr. CI page and at hud.pytorch.org/pr/58705):

🕵️ 1 new failure recognized by patterns

The following CI failures do not appear to be due to upstream breakages.
@zafartahirov has imported this pull request. If you are a Facebook employee, you can view this diff on Phabricator.
torch/ao/nn/sparse/linear.py (Outdated)

import torch
from torch.nn import functional as F

class Linear(torch.nn.Linear):
We should have the masking done by a module that can operate on the weights. This would be a FakeSparseModule, whose forward method masks the weights. Can you rewrite this implementation such that:
- We have a base FakeSparseModule class that contains a mask attribute.
- The FakeSparseModule is applied to the weight as a parameterization (see the sketch after this list).
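For illustration only (not something in this PR), a minimal sketch of what applying such a module to the weight as a parametrization could look like, using torch.nn.utils.parametrize; the FakeSparseModule name and its interface are assumptions here:

```
import torch
import torch.nn as nn
from torch.nn.utils import parametrize

class FakeSparseModule(nn.Module):
    """Parametrization that masks the tensor it is attached to."""
    def __init__(self, shape):
        super().__init__()
        # Default mask is all ones, i.e. a no-op until a sparsifier updates it
        self.register_buffer('mask', torch.ones(shape))

    def forward(self, weight):
        return self.mask * weight

linear = nn.Linear(16, 8)
# After registration, linear.weight transparently returns the masked weight
parametrize.register_parametrization(linear, 'weight', FakeSparseModule(linear.weight.shape))
```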
As discussed, the parametrization is going to be part of a separate PR
Sure, in this PR let's do the following:
- Define a FakeSparse module (similar to the mulby module):
class FakeSparse(nn.Module):
    def __init__(self, shape, sparsity_config):
        super().__init__()
        # Initialize mask to be all ones
        self.mask = torch.ones(shape)
        self.sparsity_config = sparsity_config

    def forward(self, x):
        assert self.mask.shape == x.shape
        return self.mask * x
If you are doing the reparameterization in a later PR, use this module manually, similar to how it is used in nn.qat.linear.
class Linear(torch.nn.Linear):
    def __init__(self, in_features, out_features, bias=True, device=None, dtype=None) -> None:
        factory_kwargs = {'device': device, 'dtype': dtype}
        super().__init__(in_features, out_features, bias, **factory_kwargs)
        # NOTE: FakeSparse above also expects a sparsity_config; it is omitted in this sketch
        self.fake_sparse = FakeSparse(self.weight.size())

    def forward(self, input):
        return F.linear(input, self.fake_sparse(self.weight), self.bias)
torch/ao/nn/sparse/linear.py (Outdated)

    @classmethod
    def from_dense(cls, dense):
        sparse = cls(dense.in_features, dense.out_features, (dense.bias is not None))
Can we handle this instead by default-initializing the mask attribute to all ones? That way from_dense would not be needed.
Can you elaborate on this? We need a transformation function for the prepare step.
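For readers outside the review, a rough sketch (an assumption, not the PR's actual code) of what such a from_dense transformation would presumably do, namely copy the dense module's parameters into the sparse wrapper:

```
@classmethod
def from_dense(cls, dense):
    sparse = cls(dense.in_features, dense.out_features, (dense.bias is not None))
    # Reuse the dense module's parameters; the mask starts out as all ones
    sparse.weight = dense.weight
    if dense.bias is not None:
        sparse.bias = dense.bias
    return sparse
```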
""" | ||
_version = 1 | ||
_FLOAT_MODULE = torch.nn.Linear | ||
_FLOAT_MODULE = (torch.nn.Linear, SparseLinear) |
I am assuming that nn.Linear is needed for the PTQ flow and SparseLinear for the QAT flow. Is that correct?
Correct
Let's remove the torch.nn.Linear as a supported option. In both flows, we will need to create a SparseLinearModule. In the post-training sparsity case, one could do:
m = nn.Linear(3, 5)
# Sparsifier.prepare would essentially do this:
m_sparse = torch.nn.sparse.linear.from_dense(m)
# m_sparse now has the same weights, but the mask is set to identity
m_sparse.set_mask(mask, pattern)  # Function to override the mask and define the sparsity pattern
# Now we can convert only the sparse linear module.
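Purely as an illustration (not from the PR) of the kind of mask set_mask might receive, a block-structured mask with 1x4 zero blocks, matching the row_block_size=1, col_block_size=4 defaults discussed below, could be built like this:

```
import torch

out_features, in_features = 8, 16
row_block, col_block = 1, 4

mask = torch.ones(out_features, in_features)
# Zero out every other 1x4 block in each row, as an arbitrary example pattern
for r in range(0, out_features, row_block):
    for c in range(0, in_features, 2 * col_block):
        mask[r:r + row_block, c:c + col_block] = 0
```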
weight = mod.weight
if getattr(mod.qconfig, 'mask', False):
    weight = mod.qconfig.mask * mod.weight
weight = mod.weight * mod.mask
Let us make mask an attribute of the FakeSparse module and do this via reparametrization.
The reparametrization will be a separate PR, as there are several things to iron out there.
    @classmethod
    def from_float(cls, mod):
    def from_float(cls, mod, row_block_size=1, col_block_size=4):
We should not need the row/col block size here. This block just applies the mask; the mask should enforce the required sparsity pattern.
I am not sure how we could avoid that -- once we call the from_float, we convert the float model into the sparse quantized model. In the latter we have to pack the weight, and we need the block shape there. The sparsifier has this information, but the current model does not.
I see. From the point of view of keeping the API clean, should we do the following:
- When we create a SparseLinear module, the FakeSparse module should have all the information needed to specify the mask.
- At convert, we rely only on the FakeSparse module:
def from_float(cls, mod):
    sparsity_pattern = mod.fake_sparse.sparsity_config.sparsity_pattern  # This is the shape of the sparse blocks
Is this being addressed in a later PR?
This will be addressed by the torch.ao.utils.convert -- because it will take the sparse_config, the rows/cols can be passed around by that utility. Currently, it will be an argument here to make sure the tests pass. I am going to keep the current implementation as a stopgap measure; this is addressed in the later PRs.
Please add tests so that functionality is verified.
The basic demo for this particular implementation can be found here: https://gist.github.com/z-a-f/1d06ae8d5a509d3c9c1596dcb924afe0

Test Plan: `python test/test_ao_sparsity.py`

Differential Revision: [D28970959](https://our.internmc.facebook.com/intern/diff/D28970959)
        self.assertEqual(model_save.seq[1].parametrizations['weight'][0].mask,
                         model_load.seq[1].parametrizations['weight'][0].mask)

    def test_jit_trace(self):
Can you also add a test to check if parameterized models are scriptable?
The parametrizations are not scriptable
Please mark a TODO to add tests for symbolic tracing as well. We need parametrization to compose with these APIs.
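A rough sketch of what such a TODO test could look like (an assumption, not part of this PR; whether fx tracing composes with parametrized modules is exactly what the test would need to verify):

```
import torch
import torch.nn as nn
from torch.nn.utils import parametrize

class _Mask(nn.Module):
    def __init__(self, shape):
        super().__init__()
        self.register_buffer('mask', torch.ones(shape))

    def forward(self, w):
        return self.mask * w

def test_fx_symbolic_trace_sketch():
    # TODO: move into test_ao_sparsity once parametrization composes
    # with scripting / symbolic tracing.
    linear = nn.Linear(4, 4)
    parametrize.register_parametrization(linear, 'weight', _Mask(linear.weight.shape))
    traced = torch.fx.symbolic_trace(linear)  # may bake the masked weight into the graph
    x = torch.randn(2, 4)
    assert torch.allclose(traced(x), linear(x))
```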
    not.
    """
    def __init__(self, mask):
        super().__init__()
Should we init with a shape instead of a mask? The mask could be initialized to torch.ones(shape):

def __init__(self, shape):
    super().__init__()
    self.register_buffer('mask', torch.ones(shape))
If the user wants to override the mask, they can directly set the attribute.
What if the user already has a mask?
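One way to support both cases, purely as a sketch and not something settled in this thread, would be to accept either a shape or an existing mask:

```
def __init__(self, shape=None, mask=None):
    super().__init__()
    if mask is None:
        # Default to a dense (all-ones) mask of the requested shape
        mask = torch.ones(shape)
    self.register_buffer('mask', mask)
```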
Looks good, a few comments
self.seq[0].weight = nn.Parameter(torch.zeros_like(self.seq[0].weight) + 2.0)
self.seq[1].weight = nn.Parameter(torch.zeros_like(self.seq[1].weight) + 3.0)
if bias:
    self.linear = nn.Parameter(torch.zeros_like(self.linear.bias) + 10.0)
self.seq[0] should be self.seq[1]
Looks good. Some previous comments were missed; please take a look before landing.
Also, please add TODOs for scripting and symbolic tracing test coverage.
Stack from ghstack:

The basic demo for this particular implementation can be found here:
https://gist.github.com/z-a-f/1d06ae8d5a509d3c9c1596dcb924afe0

Test Plan: python test/test_ao_sparsity.py

Differential Revision: D28970959