Skip to content

Commit

Permalink
Add support for pruning summaries in 1.X fashion.
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 286103751
  • Loading branch information
alanchiao authored and tensorflower-gardener committed Dec 18, 2019
1 parent 21a1fde commit 105ec70
Show file tree
Hide file tree
Showing 2 changed files with 22 additions and 4 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -84,6 +84,12 @@ def __init__(self, log_dir, update_freq='epoch', **kwargs):
super(PruningSummaries, self).__init__(
log_dir=log_dir, update_freq=update_freq, **kwargs)

def _log_pruning_metrics(self, logs, prefix, step):
if tf.__version__[0] == '1':
self._write_custom_summaries(step, logs)
else:
self._log_metrics(logs, prefix, step)

def on_epoch_end(self, batch, logs=None):
super(PruningSummaries, self).on_epoch_end(batch, logs)

Expand Down Expand Up @@ -112,4 +118,4 @@ def on_epoch_end(self, batch, logs=None):
for threshold, threshold_value in param_value_pairs[1::2]:
pruning_logs.update({threshold.name + '/threshold': threshold_value})

self._log_metrics(pruning_logs, '', iteration)
self._log_pruning_metrics(pruning_logs, '', iteration)
Original file line number Diff line number Diff line change
Expand Up @@ -14,21 +14,29 @@
# ==============================================================================
"""Tests for Pruning callbacks."""

import os
import tempfile

from absl.testing import parameterized
import numpy as np
import tensorflow as tf

# TODO(b/139939526): move to public API.
from tensorflow.python.keras import keras_parameterized
from tensorflow_model_optimization.python.core.keras import test_utils as keras_test_utils
from tensorflow_model_optimization.python.core.sparsity.keras import prune
from tensorflow_model_optimization.python.core.sparsity.keras import pruning_callbacks

# TODO(b/139939526): move to public API.


@keras_parameterized.run_all_keras_modes
class PruneTest(tf.test.TestCase, parameterized.TestCase):

def testUpdatesPruningStep(self):
def _assertLogsExist(self, log_dir):
self.assertNotEmpty(os.listdir(log_dir))

def testUpdatePruningStepsAndLogsSummaries(self):
log_dir = tempfile.mkdtemp()
model = prune.prune_low_magnitude(
keras_test_utils.build_simple_dense_model())
model.compile(
Expand All @@ -38,13 +46,17 @@ def testUpdatesPruningStep(self):
tf.keras.utils.to_categorical(np.random.randint(5, size=(20, 1)), 5),
batch_size=20,
epochs=3,
callbacks=[pruning_callbacks.UpdatePruningStep()])
callbacks=[
pruning_callbacks.UpdatePruningStep(),
pruning_callbacks.PruningSummaries(log_dir=log_dir)
])

self.assertEqual(2,
tf.keras.backend.get_value(model.layers[0].pruning_step))
self.assertEqual(2,
tf.keras.backend.get_value(model.layers[1].pruning_step))

self._assertLogsExist(log_dir)

if __name__ == '__main__':
tf.test.main()

0 comments on commit 105ec70

Please sign in to comment.