
Conversation

Ruomei
Contributor

@Ruomei Ruomei commented Aug 17, 2020

Motivation:

In the current clustering implementation, the original weights of the clustered layers and the cluster indices are not updated during each training step. Although training alters the values of the cluster centroids (the gradients reflect changes in the other, non-clustered layers), the original weights are never updated, so they drift out of sync with the constantly updated centroids and cause problems during training. To fix this, we update the original weights after backpropagation; the indices are then re-generated from the updated centroids in the next training step. This PR makes these changes and adds unit tests for them.

Details of the implementation:
As shown in the figure below, the forward pass of the current clustering implementation first initializes the centroids (c) for the weights of each layer using a density-based or linear method. The original weights (W) are then grouped into clusters using the centroid values, and the association between the weights and the centroids is computed from c and W as indices. Finally, within a single cluster, the centroid value is shared among all the weights and used in the forward pass in place of the original weights.
[Figure: the clustering forward pass, from centroid initialization through index lookup to the gathered clustered weights]
In the current backpropagation, the clustered weights receive the gradients from the wrapped layer. These gradients are fed into the gather node, which groups them by index and accumulates them as the gradients of the centroids. However, because tf.math.argmin is non-differentiable, TensorFlow's automatic differentiation computes no gradients for the original weights W.
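The forward pass and the centroid-gradient accumulation described above can be sketched in NumPy. This is a hedged illustration, not the actual tfmot code: the weight values, cluster count, and variable names are invented for the example.

```python
import numpy as np

# Hypothetical example values; the real layer weights come from the model.
original_weights = np.array([0.1, 0.9, 0.4, 0.85, 0.15])

# Linear centroid initialization: evenly spaced between min and max weight.
number_of_clusters = 2
cluster_centroids = np.linspace(original_weights.min(),
                                original_weights.max(),
                                number_of_clusters)  # [0.1, 0.9]

# Indices: for each weight, the index of the nearest centroid (the argmin step).
distances = np.abs(original_weights[:, None] - cluster_centroids[None, :])
indices = np.argmin(distances, axis=1)

# Forward pass: replace each weight with its centroid (the gather step).
clustered_weights = cluster_centroids[indices]

# Backward pass for the centroids: gather's gradient groups the incoming
# gradients by index and accumulates them per centroid.
upstream_grad = np.array([0.5, -0.3, 0.2, 0.1, -0.4])
grad_centroids = np.zeros_like(cluster_centroids)
np.add.at(grad_centroids, indices, upstream_grad)
```

Note that in this sketch, as in the implementation described above, no gradient reaches original_weights: argmin breaks the chain, which is exactly the problem point 1 below addresses.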

  1. How to update the original weights?
    A small modification of the training graph, a gradient approximation using the straight-through estimator [1], is used to override the gradient during backpropagation:
    clustered_weights = tf.gather(cluster_centroids, indices) * tf.sign(original_weights + 1e+6)
    In the forward pass, the multiplication does not change the output: for weights of realistic magnitude, tf.sign(original_weights + 1e+6) is a tensor of ones. In the backward pass, via tf.custom_gradient, the multiplication is treated as an addition and tf.sign as an identity, so the graph effectively becomes:
    clustered_weights = tf.gather(cluster_centroids, indices) + tf.identity(original_weights + 1e+6)
    In this way, the original weights can be updated by TensorFlow's automatic differentiation.
  2. How to update the cluster indices?
    The indices are not themselves differentiable and are computed only in the forward pass during training. They are therefore updated with tf.assign in the forward pass of the call function. This requires some extra changes to support tf.distribute, which are not covered in this PR.
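The straight-through trick in point 1 can be illustrated with a small NumPy sketch. This is an assumption-laden illustration with invented values, not the tfmot implementation; in the real code the backward override is done with tf.custom_gradient.

```python
import numpy as np

# Hypothetical weights and centroids for illustration.
original_weights = np.array([0.1, 0.9, 0.4])
cluster_centroids = np.array([0.2, 0.8])
indices = np.argmin(np.abs(original_weights[:, None]
                           - cluster_centroids[None, :]), axis=1)

# Forward pass: sign(original_weights + 1e+6) is all ones for weights of
# realistic magnitude, so the multiplication leaves the output unchanged.
ste_factor = np.sign(original_weights + 1e+6)
forward = cluster_centroids[indices] * ste_factor

# Backward pass, as rewritten by the custom gradient: the multiply behaves
# like an add and sign like an identity, so the derivative of the output
# with respect to the original weights is 1 (straight-through), and the
# incoming gradient flows to the original weights unchanged.
upstream_grad = np.array([0.5, -0.3, 0.2])
grad_original_weights = upstream_grad * 1.0
```

The design point is that the forward value stays exactly the gathered centroid, while the backward pass pretends the quantization step was an identity, which is the straight-through estimator of [1].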

Result table:
As shown in the table below, the changes in this PR significantly improve accuracy when the number of clusters is small, and have limited effect on the other configurations.

| Model | Number of clusters | tfmot | tfmot + this PR | Delta |
|---|---|---|---|---|
| Mobilenet_v1 | full model (all 64) | 65.03% | 66.65% | +1.62% |
| | | 3.11 MB | 3.06 MB | -0.05 MB |
| | selective clustering (32 32 32) | 49.72% | 68.00% | +18.28% |
| | | 7.17 MB | 6.99 MB | -0.18 MB |
| | selective clustering (256 256 32) | 70.16% | 69.32% | -0.84% |
| | | 8.32 MB | 7.68 MB | -0.64 MB |
| Mobilenet_v2 | full model (all 32) | 68.26% | 69.09% | +0.83% |
| | | 2.65 MB | 2.64 MB | -0.01 MB |
| | selective clustering (8 8 8) | 35.05% | 67.28% | +32.23% |
| | | 6.25 MB | 6.23 MB | -0.02 MB |
| | selective clustering (16 16 16) | 67.10% | 70.94% | +3.84% |
| | | 6.59 MB | 6.42 MB | -0.17 MB |
| | selective clustering (256 256 32) | 72.30% | 72.30% | 0.00% |
| | | 7.31 MB | 7.18 MB | -0.13 MB |
| DS-CNN-L | full model (all 32) | 94.77% | 94.86% | +0.09% |
| | | 0.33 MB | 0.33 MB | 0.00 MB |
| | full model (all 8) | 73.51% | 86.83% | +13.32% |
| | | 0.19 MB | 0.19 MB | 0.00 MB |

Reference:
[1] Y. Bengio, N. Léonard, and A. Courville. Estimating or propagating gradients through stochastic neurons for conditional computation. arXiv preprint arXiv:1308.3432, 2013.

@googlebot googlebot added the cla: yes PR contributor has signed CLA label Aug 17, 2020
@github-actions github-actions bot added the technique:clustering Regarding tfmot.clustering.keras APIs and docs label Aug 17, 2020
Contributor

@benkli01 benkli01 left a comment


Looks good to me.

@Ruomei
Contributor Author

Ruomei commented Aug 19, 2020

Hi @alanchiao and @akarmi, I have just filled in all the results in the description. Could you please take a look at the PR and let me know your thoughts? Also, I am not sure how long the description should be.
Thanks, @benkli01, for reviewing.

@Ruomei Ruomei force-pushed the toupstream/enable_differentiable_training branch from 16ad7a9 to a248f89 Compare August 20, 2020 12:41
Contributor

@akarmi akarmi left a comment


Thank you. Looks good to me.

@akarmi akarmi added the ready to pull Working to get PR submitted to internal repository, after which merging to Github happens. label Aug 26, 2020
@alanchiao

As noted in the call, I'll take a look at this, but after it merges when I have the time.

@akarmi
Contributor

akarmi commented Sep 17, 2020

@alanchiao, could you check what is holding up this PR please?

@alanchiao

Yes I am.

copybara-service bot pushed a commit that referenced this pull request Sep 22, 2020
--
a248f89 by Ruomei Yan <ruomei.yan@arm.com>:

Enable differentiable training and update cluster indices

COPYBARA_INTEGRATE_REVIEW=#519 from Ruomei:toupstream/enable_differentiable_training a248f89
PiperOrigin-RevId: 333108062
@Ruomei
Contributor Author

Ruomei commented Sep 23, 2020

Closed this PR since it had already been merged by a different commit.

@Ruomei Ruomei closed this Sep 23, 2020