Fix for sparsity-preserving clustering #702
Conversation
* small fix to stop centroids from drifting
* unit test coverage for fix
```python
# Re-discover the sparsity masks to avoid drifting
self.sparsity_masks[weight_name] = (
    tf.cast(tf.math.not_equal(clustered_weights, 0), dtype=tf.float32)
)
```
It looks to me that this could be wrong: we form the mask here from clustered_weights and then multiply the weights by this mask.
In clustering we have two sets of trainables: weights and centroids. Both can move away from zero during training. To keep sparsity, we need to 1. apply the original sparsity_mask that is built in the build function to the original_weights; 2. assign the smallest centroid back to zero, if there is no zero centroid. There is a variation: we could let the original_weights train as they want, because we use them to compute gradients in backpropagation, and give them more freedom, but then we need to apply the sparsity mask before we compute the pulling indices. However, I would update the original_weights with the original sparsity_mask during training, as sketched below.
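A minimal sketch of this suggestion, assuming the layer keeps the mask built in build() as original_sparsity_mask and exposes a clustering algorithm with a get_pulling_indices method; these names are taken from the discussion above, not verified against the actual layer code:

```python
import tensorflow as tf

def apply_original_sparsity(original_weights, original_sparsity_mask,
                            clustering_algo):
    # Step 1: re-apply the mask built once in build(), so the zero
    # positions can never drift, however the weights train.
    masked_weights = original_weights * original_sparsity_mask

    # Variation: compute the pulling indices from the masked weights,
    # so every zeroed weight is associated with the zero centroid
    # rather than whichever centroid it drifted towards.
    pulling_indices = clustering_algo.get_pulling_indices(masked_weights)
    return masked_weights, pulling_indices
```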
I tend to agree with Elena. With the currently proposed change, what prevents the mask from drifting away from the original sparsity? To prevent the drift of the zero centroid and keep the mask intact, why not apply the mask to the original weights before updating pulling_indices above, as Elena proposes? Alternatively (or in addition!?), we could find a way to force the "zero" centroid, currently hidden behind the ClusteringAlgorithm abstraction, to stay at zero; see the sketch below.
Any thoughts?
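A hedged sketch of that alternative, assuming the centroids live in a 1-D float32 tensor; pin_zero_centroid is a hypothetical helper, not part of the library:

```python
import tensorflow as tf

def pin_zero_centroid(cluster_centroids):
    # Snap the centroid nearest to zero back to exactly 0.0, so the
    # masked weights always have an exact-zero cluster to pull to.
    idx = tf.reshape(tf.argmin(tf.abs(cluster_centroids)), [1, 1])
    return tf.tensor_scatter_nd_update(cluster_centroids, idx, [0.0])
```

If the centroids are stored in a tf.Variable, this could run after each optimizer step, e.g. centroids.assign(pin_zero_centroid(centroids)).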
* modified unit tests to check the number of unique weights more extensively after sparsity-preserving clustering
…s, extended mnist test for additional checks (force-pushed from 4ca62bd to 45b15be)
Thank you @TamasArm for fixing this.
@daverim Hi David, could you please take a look at this and retrigger the pull request labeler if possible? Thank you in advance!
Summary of change:
Ensure the sparsity mask is updated during training in sparsity-preserving clustering. Because cluster-weight associations are updated, the locations of zero weights may change during training. This patch updates the sparsity mask to track those changes. It also adds extra tests to ensure sparsity-preserving clustering works during distributed training. The sketch below illustrates the kind of invariant such tests can check.
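A minimal sketch of that invariant, assuming kernel_before and kernel_after are NumPy arrays of a layer's kernel before and after clustered training; the helper name is hypothetical:

```python
import numpy as np

def check_sparsity_preserving_clustering(kernel_before, kernel_after,
                                         number_of_clusters):
    # Zeros must not move: every weight that was zero before training
    # has to remain zero after clustering finishes.
    zero_mask = (kernel_before == 0)
    assert np.all(kernel_after[zero_mask] == 0)

    # Clustering must hold: no more unique values than centroids
    # (zero itself counts as one of the clusters).
    assert len(np.unique(kernel_after)) <= number_of_clusters
```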