-
Notifications
You must be signed in to change notification settings - Fork 332
Add visualization output via tensorboard to the clustering example #508
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Closed
benkli01
wants to merge
1
commit into
tensorflow:master
from
benkli01:toupstream/clustering-visualization
Closed
Changes from all commits
Commits
File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
81 changes: 81 additions & 0 deletions
81
tensorflow_model_optimization/python/core/clustering/keras/clustering_callbacks.py
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,81 @@ | ||
|
||
import tensorflow as tf | ||
|
||
from tensorflow import keras | ||
from tensorflow_model_optimization.python.core.keras import compat | ||
|
||
class ClusteringSummaries(keras.callbacks.TensorBoard): | ||
"""Helper class to create tensorboard summaries for the clustering progress. | ||
|
||
This class is derived from tf.keras.callbacks.TensorBoard and just adds | ||
functionality to write histograms with batch-wise frequency. | ||
|
||
Arguments: | ||
log_dir: The path to the directory where the log files are saved | ||
cluster_update_freq: determines the frequency of updates of the | ||
clustering histograms. Same behaviour as parameter update_freq of | ||
the base class, i.e. it accepts `'batch'`, `'epoch'` or integer. | ||
""" | ||
|
||
def __init__(self, | ||
log_dir='logs', | ||
cluster_update_freq='epoch', | ||
**kwargs): | ||
super(ClusteringSummaries, self).__init__( | ||
log_dir=log_dir, **kwargs) | ||
|
||
if not isinstance(log_dir, str) or not log_dir: | ||
raise ValueError( | ||
'`log_dir` must be a non-empty string. You passed `log_dir`=' | ||
'{input}.'.format(input=log_dir)) | ||
|
||
self.cluster_update_freq = \ | ||
1 if cluster_update_freq == 'batch' else cluster_update_freq | ||
|
||
if compat.is_v1_apis(): # TF 1.X | ||
self.writer = tf.compat.v1.summary.FileWriter(log_dir) | ||
else: # TF 2.X | ||
self.writer = tf.summary.create_file_writer(log_dir) | ||
|
||
self.continuous_batch = 0 | ||
|
||
def on_train_batch_begin(self, batch, logs=None): | ||
super().on_train_batch_begin(batch, logs) | ||
# Count batches manually to get a continuous batch count spanning | ||
# epochs, because the function parameter 'batch' is reset to zero | ||
# every epoch. | ||
self.continuous_batch += 1 | ||
|
||
def on_train_batch_end(self, batch, logs=None): | ||
assert self.continuous_batch >= batch, \ | ||
"Continuous batch count must always be greater or equal than the" \ | ||
"batch count from the parameter in the current epoch." | ||
|
||
super().on_train_batch_end(batch, logs) | ||
|
||
if self.cluster_update_freq == 'epoch': | ||
return | ||
elif self.continuous_batch % self.cluster_update_freq != 0: | ||
return # skip this batch | ||
|
||
self._write_summary() | ||
|
||
def on_epoch_end(self, epoch, logs=None): | ||
super().on_epoch_end(epoch, logs) | ||
if self.cluster_update_freq == 'epoch': | ||
self._write_summary() | ||
|
||
def _write_summary(self): | ||
with self.writer.as_default(): | ||
for layer in self.model.layers: | ||
if not hasattr(layer, 'layer') or not hasattr(layer.layer, 'get_clusterable_weights'): | ||
continue # skip layer | ||
clusterable_weights = layer.layer.get_clusterable_weights() | ||
if len(clusterable_weights) < 1: | ||
continue # skip layers without clusterable weights | ||
prefix = 'clustering/' | ||
# Log variables | ||
for var in layer.variables: | ||
success = tf.summary.histogram( | ||
prefix + var.name, var, step=self.continuous_batch) | ||
assert success |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
Uh oh!
There was an error while loading. Please reload this page.