Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -22,3 +22,4 @@
from tensorflow_model_optimization.python.core.internal.tensor_encoding.encoders.common_encoders import hadamard_quantization
from tensorflow_model_optimization.python.core.internal.tensor_encoding.encoders.common_encoders import identity
from tensorflow_model_optimization.python.core.internal.tensor_encoding.encoders.common_encoders import uniform_quantization
from tensorflow_model_optimization.python.core.internal.tensor_encoding.encoders.common_encoders import drive
Original file line number Diff line number Diff line change
Expand Up @@ -134,3 +134,52 @@ def hadamard_quantization(bits):
stages_impl.HadamardEncodingStage.ENCODED_VALUES_KEY).add_parent(
stages_impl.FlattenEncodingStage(),
stages_impl.FlattenEncodingStage.ENCODED_VALUES_KEY).make()


def drive(bias_correction=True):
  """Returns DRIVE `Encoder`.

  First, the `Encoder` reshapes the input to a rank-1 `Tensor` and applies a
  randomized Hadamard transform (rotation). It then applies a rotation-aware
  sign, and, finally, the quantized values are bit-packed into an integer type.

  This encoder is derived from the source published with "DRIVE: One-bit
  Distributed Mean Estimation" (NeurIPS '21;
  https://arxiv.org/pdf/2105.08339.pdf), and the algorithm presented therein.

  Limitations:
  (1) In the implementation of HadamardEncodingStage a single seed is shared
  among senders; as described in the paper, this should be used only when the
  number of senders is at most logarithmic in the dimension of the input
  tensor.
  (2) This encoder works better on larger tensors. An ideal preprocessing
  stage would concatenate the input model into a single tensor, and the
  ability to skip a few tensors (e.g., normalization layers) would also help.
  Currently, this is not always possible with the tensor encoders API.

  Despite these limitations, this implementation achieves accuracy similar
  to sending the full tensors for many distributed learning scenarios.

  The `Encoder` composes the following encoding stages:
  * `FlattenEncodingStage` - reshaping the input tensor into a vector.
  * `HadamardEncodingStage` - applying the Hadamard transform.
  * `RotationAwareSignEncodingStage` - applying a rotation-aware sign.
  * `BitpackingEncodingStage` - bit-packing the result into integer values.

  Args:
    bias_correction: A Python bool, whether to use bias correcting or
      MSE minimizing scale.
      If `True`, the encoding is unbiased on expectation.
      If `False`, the encoding minimizes the MSE.

  Returns:
    The DRIVE `Encoder`.
  """
  # Build the stage chain innermost-first: bitpacking is the leaf, then each
  # add_parent call wraps the composer in the stage that runs before it.
  composer = core_encoder.EncoderComposer(
      stages_impl.BitpackingEncodingStage(1))
  composer = composer.add_parent(
      stages_impl.RotationAwareSignEncodingStage(bias_correction),
      stages_impl.RotationAwareSignEncodingStage.ENCODED_VALUES_KEY)
  composer = composer.add_parent(
      stages_impl.HadamardEncodingStage(),
      stages_impl.HadamardEncodingStage.ENCODED_VALUES_KEY)
  composer = composer.add_parent(
      stages_impl.FlattenEncodingStage(),
      stages_impl.FlattenEncodingStage.ENCODED_VALUES_KEY)
  return composer.make()
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@
common_encoders.identity,
lambda: common_encoders.uniform_quantization(8),
lambda: common_encoders.hadamard_quantization(8),
lambda: common_encoders.drive(),
]


Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -24,3 +24,4 @@
from tensorflow_model_optimization.python.core.internal.tensor_encoding.stages.stages_impl import HadamardEncodingStage
from tensorflow_model_optimization.python.core.internal.tensor_encoding.stages.stages_impl import IdentityEncodingStage
from tensorflow_model_optimization.python.core.internal.tensor_encoding.stages.stages_impl import UniformQuantizationEncodingStage
from tensorflow_model_optimization.python.core.internal.tensor_encoding.stages.stages_impl import RotationAwareSignEncodingStage
Original file line number Diff line number Diff line change
Expand Up @@ -486,3 +486,93 @@ def decode(self,
return tf.cast(unpacked_x, dummy_type_value.dtype)
else:
return tf.cast(unpacked_x, tf.float32)


@encoding_stage.tf_style_encoding_stage
class RotationAwareSignEncodingStage(encoding_stage.EncodingStageInterface):
  """Encoding stage performing a rotation-aware sign.

  This class is adapted from the source published with "DRIVE: One-bit
  Distributed Mean Estimation" (NeurIPS '21;
  https://arxiv.org/pdf/2105.08339.pdf), and the algorithm presented therein.

  The encoding stage encodes vectors into a single bit per coordinate. It is
  designed to execute after a random rotation (such as the result of the
  `HadamardEncodingStage`). It calculates an appropriate scale that ensures
  that the dequantized tensor will be (a) unbiased, or (b) with a minimal MSE
  (after the inverse rotation).
  """

  # Key for the per-coordinate one-bit sign tensor (0.0 / 1.0 floats).
  ENCODED_VALUES_KEY = 'bit_values'
  # Key for the scalar scale applied on decode.
  SCALE_VALUES_KEY = 'scale'

  def __init__(self, bias_correction=True):
    """Initializer for the RotationAwareSignEncodingStage.

    Args:
      bias_correction: A Python bool, whether to use bias correcting or
        MSE minimizing scale.
        If `True`, the encoding (post-rotation) is on expectation unbiased.
        If `False`, the encoding (post-rotation) minimizes the MSE.
    """
    self._bias_correction = bias_correction

  @property
  def name(self):
    """See base class."""
    return 'rotation_aware_sign'

  @property
  def compressible_tensors_keys(self):
    """See base class."""
    return [self.ENCODED_VALUES_KEY]

  @property
  def commutes_with_sum(self):
    """See base class."""
    return False

  @property
  def decode_needs_input_shape(self):
    """See base class."""
    return False

  def get_params(self):
    """See base class."""
    return {}, {}

  def encode(self, x, encode_params):
    """See base class."""
    del encode_params  # Unused.

    if self._bias_correction:
      # The bias correcting scale is (||x||_2)^2 / (||x||_1).
      scale = tf.reduce_sum(x**2) / tf.reduce_sum(tf.abs(x))
    else:
      # The MSE minimizing scale is (||x||_1) / dim(x).
      # Note: tf.size returns an int32 tensor; TensorFlow's truediv does not
      # auto-promote mixed dtypes, so the size must be cast to x's dtype
      # before dividing or this op raises at runtime.
      scale = tf.reduce_sum(tf.abs(x)) / tf.cast(tf.size(x), x.dtype)

    # Send one-bit sign.
    # We note that zeros are unlikely after running the high-dimensional
    # Hadamard rotation that precedes this stage, so the bias derived from
    # always rounding 0 to 1 can be safely ignored.
    onebit_signs = tf.cast(tf.greater(x, 0), x.dtype)

    return {self.ENCODED_VALUES_KEY: onebit_signs,
            self.SCALE_VALUES_KEY: scale}

  def decode(self,
             encoded_tensors,
             decode_params,
             num_summands=None,
             shape=None):
    """See base class."""
    del decode_params, num_summands, shape  # Unused.

    onebit_signs = encoded_tensors[self.ENCODED_VALUES_KEY]
    scale = encoded_tensors[self.SCALE_VALUES_KEY]

    # Map {0, 1} back to {-1, +1} and rescale.
    signs = onebit_signs * 2 - 1

    return scale * signs
Original file line number Diff line number Diff line change
Expand Up @@ -450,5 +450,51 @@ def test_bad_input_bits_raises(self):
stages_impl.BitpackingEncodingStage(17)


class RotationAwareSignEncodingStageStageTest(test_utils.BaseEncodingStageTest):
  """Tests for the RotationAwareSignEncodingStage."""

  def default_encoding_stage(self):
    """See base class."""
    return stages_impl.RotationAwareSignEncodingStage()

  def default_input(self):
    """See base class."""
    return tf.random.uniform([50], minval=-1.0, maxval=1.0)

  @property
  def is_lossless(self):
    """See base class."""
    return False

  def common_asserts_for_test_data(self, data):
    """See base class."""
    key = stages_impl.RotationAwareSignEncodingStage.ENCODED_VALUES_KEY
    encoded_bits = data.encoded_x[key]
    # The one-bit signs are carried in a float tensor but must hold
    # integer-valued entries.
    assert encoded_bits.dtype == np.float32
    rounded = tf.cast(tf.cast(encoded_bits, np.int32), np.float32)
    self.assertAllClose(encoded_bits, rounded)

  def test_scaled_sign(self):
    # Decoding must reproduce sign(x) multiplied by the encoded scale.
    stage = stages_impl.RotationAwareSignEncodingStage()
    test_data = self.run_one_to_many_encode_decode(stage, self.default_input)

    scale = test_data.encoded_x[
        stages_impl.RotationAwareSignEncodingStage.SCALE_VALUES_KEY]
    expected = tf.sign(test_data.x) * scale
    self.assertAllClose(expected, test_data.decoded_x)

  def test_one_bit_encoding(self):
    # The encoded values must survive a 1-bit packing round trip, i.e. each
    # coordinate really occupies a single bit.
    stage = stages_impl.RotationAwareSignEncodingStage()
    test_data = self.run_one_to_many_encode_decode(stage, self.default_input)

    encoded_bits = test_data.encoded_x[
        stages_impl.RotationAwareSignEncodingStage.ENCODED_VALUES_KEY]
    packing_stage = stages_impl.BitpackingEncodingStage(1)
    packing_test_data = self.run_one_to_many_encode_decode(
        packing_stage, lambda: encoded_bits)

    self.assertAllClose(encoded_bits, packing_test_data.decoded_x)


if __name__ == '__main__':
  # Discover and run the test cases defined in this module.
  tf.test.main()