Implement TF QAT 2.0 backward pass for symmetric quantization
Signed-off-by: Geunho Lee <quic_geunlee@quicinc.com>
quic-geunlee committed Apr 10, 2023
1 parent 9b970d4 commit 615e9e2
Showing 3 changed files with 483 additions and 25 deletions.
@@ -2,7 +2,7 @@
# =============================================================================
# @@-COPYRIGHT-START-@@
#
# Copyright (c) 2020, Qualcomm Innovation Center, Inc. All rights reserved.
# Copyright (c) 2020-2023, Qualcomm Innovation Center, Inc. All rights reserved.
#
# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions are met:
@@ -35,13 +35,14 @@
# @@-COPYRIGHT-END-@@
# =============================================================================

""" implements straight through graident computation for Quantize Op"""
from dataclasses import dataclass
""" implements straight through gradient computation for Quantize Op """
from typing import Tuple
from dataclasses import dataclass

import tensorflow as tf
from tensorflow.python.framework import ops as tf_ops

from aimet_common import libpymo
from aimet_tensorflow.defs import AxisHandling
from aimet_tensorflow.utils.constants import QuantizeOpIndices

@@ -91,7 +92,7 @@ def _compute_dloss_by_dx(op, grad):
return dloss_by_dx


# by default we will have this registered for Qc Quantize op.
# by default, we will have this registered for Qc Quantize op.
@tf_ops.RegisterGradient("QcQuantize")
def _qc_straight_through_estimator_grad(op, grad):
# pylint: disable=unused-argument
@@ -320,6 +321,115 @@ def _compute_dloss_by_dmin_dmax_and_dx(inputs: tf.Tensor, bitwidth: tf.Tensor, o
return dloss_by_dmin, dloss_by_dmax, dloss_by_dx



def _compute_dloss_by_dmin_dmax_and_dx_symmetric(inputs: tf.Tensor,
bitwidth: tf.Tensor,
encoding_min: tf.Tensor,
encoding_max: tf.Tensor,
grad: tf.Tensor):
"""
Calculate dloss_by_dmin, dloss_by_dmax, and dloss_by_dx tensors for symmetric quantization
:param inputs: Inputs to op
:param bitwidth: Bitwidth used to quantize
:param encoding_min: Encoding min value(s), will be more than one if per channel is active
:param encoding_max: Encoding max value(s), will be more than one if per channel is active
:param grad: Gradient from child layer
:return: Tensors for dloss_by_dmin, dloss_by_dmax, and dloss_by_dx
"""
x = tf.cast(inputs, tf.float32)
bitwidth = tf.cast(bitwidth, tf.float32)
encoding_min = tf.cast(encoding_min, tf.float32)
encoding_max = tf.cast(encoding_max, tf.float32)

# handle min == max to avoid divide by zero
epsilon = tf.constant(1e-5, dtype=tf.float32)
encoding_max = tf.math.maximum(encoding_max, tf.add(encoding_min, epsilon))

num_steps = tf.cast(tf.pow(tf.cast(tf.constant(2), tf.float32), bitwidth) - 1, tf.float32)
half_num_steps = tf.divide(num_steps, tf.constant(2.0))
delta = encoding_max / tf.math.floor(half_num_steps)
offset = -tf.math.ceil(half_num_steps)

zero = tf.zeros_like(num_steps)
x_round = tf.round(inputs / delta) - offset
x_quant = tf.clip_by_value(x_round, zero, num_steps)

mask_tensor = tf.cast(tf.math.greater_equal(x_round, zero), tf.float32) * \
tf.cast(tf.math.less_equal(x_round, num_steps), tf.float32)
grad_tensor = mask_tensor * grad

axis = tf.cond(tf.equal(tf.rank(delta), 0),
lambda: tf.range(0, tf.rank(x)), # Per-tensor
lambda: tf.range(0, tf.rank(x) - 1)) # Per-channel

grad_encoding_max = tf.reduce_sum((x_quant + offset) * grad, axis=axis) - \
tf.reduce_sum(mask_tensor * (inputs / delta) * grad, axis=axis)

grad_encoding_max = grad_encoding_max / tf.math.floor(half_num_steps)
grad_encoding_max = tf.cast(grad_encoding_max, tf.float64)

return tf.negative(grad_encoding_max), grad_encoding_max, grad_tensor
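The symmetric branch derives the grid from encoding_max alone, which is why only one encoding gradient is computed and dloss_by_dmin is returned as its negation. Below is a minimal eager-mode sketch, not part of this commit, of the grid math the backward pass assumes, using an assumed 8-bit, encoding_max = 1.0 example:

import tensorflow as tf

# Assumed example values: 8-bit symmetric quantization with encoding_max = 1.0.
bitwidth = tf.constant(8.0)
encoding_max = tf.constant(1.0)

num_steps = 2.0 ** bitwidth - 1.0                      # 255
half_num_steps = num_steps / 2.0                       # 127.5
delta = encoding_max / tf.math.floor(half_num_steps)   # 1/127, the quantization step
offset = -tf.math.ceil(half_num_steps)                 # -128, the signed zero point

x = tf.constant([-1.5, -0.5, 0.0, 0.5, 1.5])
x_round = tf.round(x / delta) - offset
mask = tf.cast(x_round >= 0.0, tf.float32) * tf.cast(x_round <= num_steps, tf.float32)
print(mask.numpy())  # [0. 1. 1. 1. 0.]: gradient flows only for inputs inside the grid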


def _compute_dloss_by_dmin_dmax_and_dx_asymmetric(inputs: tf.Tensor,
bitwidth: tf.Tensor,
encoding_min: tf.Tensor,
encoding_max: tf.Tensor,
grad: tf.Tensor):
"""
Calculate dloss_by_dmin, dloss_by_dmax, and dloss_by_dx tensors for asymmetric quantization
:param inputs: Inputs to op
:param bitwidth: Bitwidth used to quantize
:param encoding_min: Encoding min value(s), will be more than one if per channel is active
:param encoding_max: Encoding max value(s), will be more than one if per channel is active
:param grad: Gradient from child layer
:return: Tensors for dloss_by_dmin, dloss_by_dmax, and dloss_by_dx
"""
x = tf.cast(inputs, tf.float32)
bitwidth = tf.cast(bitwidth, tf.float32)
encoding_min = tf.cast(encoding_min, tf.float32)
encoding_max = tf.cast(encoding_max, tf.float32)

# handle min == max to avoid divide by zero
epsilon = tf.constant(1e-5, dtype=tf.float32)
encoding_max = tf.math.maximum(encoding_max, tf.add(encoding_min, epsilon))

num_steps = tf.cast(tf.pow(tf.cast(tf.constant(2), tf.float32), bitwidth) - 1, tf.float32)
delta = (encoding_max - encoding_min) / num_steps
b_zero = tf.round(tf.negative(encoding_min) / delta)
b_zero = tf.minimum(num_steps, tf.maximum(tf.constant(0.0), b_zero))
offset = tf.negative(b_zero)

zero = tf.zeros_like(num_steps)
x_round = tf.round(inputs / delta) - offset
x_quant = tf.clip_by_value(x_round, zero, num_steps)

mask_tensor = tf.cast(tf.math.greater_equal(x_round, zero), tf.float32) * \
tf.cast(tf.math.less_equal(x_round, num_steps), tf.float32)
grad_tensor = mask_tensor * grad

grad_scale = (x_quant + offset - x * mask_tensor / delta) * grad
grad_xq = delta * grad
grad_offset = grad_xq * (1 - mask_tensor)

axis = tf.cond(tf.equal(tf.rank(delta), 0),
lambda: tf.range(0, tf.rank(x)), # Per-tensor
lambda: tf.range(0, tf.rank(x) - 1)) # Per-channel

intermediate_term1 = tf.reduce_sum(grad_scale, axis=axis) / num_steps
intermediate_term2 = num_steps / (encoding_max - encoding_min) ** 2 * tf.reduce_sum(grad_offset, axis=axis)

grad_encoding_min = -intermediate_term1 + encoding_max * intermediate_term2
grad_encoding_max = intermediate_term1 - encoding_min * intermediate_term2

grad_encoding_max = tf.cast(grad_encoding_max, tf.float64)
grad_encoding_min = tf.cast(grad_encoding_min, tf.float64)

return grad_encoding_min, grad_encoding_max, grad_tensor
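In the asymmetric branch the grid depends on both encodings through the scale (delta) and the rounded zero point, so dloss_by_dmin and dloss_by_dmax are both non-trivial. A small eager-mode sketch, not part of this commit, of the grid parameters being differentiated through, assuming an 8-bit [-0.4, 0.6] encoding range:

import tensorflow as tf

# Assumed example values: 8-bit asymmetric quantization over [-0.4, 0.6].
bitwidth = tf.constant(8.0)
encoding_min, encoding_max = tf.constant(-0.4), tf.constant(0.6)

num_steps = 2.0 ** bitwidth - 1.0                    # 255
delta = (encoding_max - encoding_min) / num_steps    # ~0.00392, the quantization step
b_zero = tf.round(-encoding_min / delta)             # 102, the zero point on the grid
b_zero = tf.minimum(num_steps, tf.maximum(0.0, b_zero))
offset = -b_zero

# Representable range reconstructed from (delta, offset); inputs outside it are clipped
# in the forward pass, so their dloss_by_dx is masked to zero while their signal is
# routed into dloss_by_dmin / dloss_by_dmax instead.
print((offset * delta).numpy(), ((num_steps + offset) * delta).numpy())  # ~-0.4, ~0.6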


# pylint: disable=too-many-locals
# pylint: disable=too-many-arguments
def _compute_dloss_by_dmin_dmax_and_dx_for_per_channel(inputs: tf.Tensor, bitwidth: tf.Tensor, op_mode: tf.Tensor,
@@ -392,38 +502,105 @@ def reshape_dloss_by_dx_for_axis_handling(inputs, dloss_by_dx, axis_handling) ->
return dloss_by_dx

reshaped_inputs, grad = reshape_input_and_grad_for_axis_handling(inputs, grad, axis_handling)
dloss_by_dmin, dloss_by_dmax, dloss_by_dx = \
_compute_dloss_by_dmin_dmax_and_dx(reshaped_inputs, bitwidth, op_mode, encoding_min, encoding_max, is_symmetric,
grad)
dloss_by_dmin, dloss_by_dmax, dloss_by_dx = _calculate_gradients(reshaped_inputs,
bitwidth,
encoding_min,
encoding_max,
is_symmetric,
op_mode,
grad)

dloss_by_dx = reshape_dloss_by_dx_for_axis_handling(inputs, dloss_by_dx, axis_handling)

# return grad in case of floating-point mode
dloss_by_dx = tf.cond(is_int_data_type, lambda: dloss_by_dx, lambda: grad)

return dloss_by_dmin, dloss_by_dmax, dloss_by_dx
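In the per-channel path the encodings, and therefore delta, carry one value per output channel, which is why the helpers above select their reduction axes from tf.rank(delta): a rank-0 delta sums the encoding gradients over every axis, a rank-1 delta keeps the last (channel) axis. A small sketch, not from this commit and using assumed shapes, of that axis selection:

import tensorflow as tf

grad = tf.ones([3, 3, 16, 32])             # assumed HWIO-style weight gradient
per_tensor_delta = tf.constant(0.01)       # rank 0: one encoding for the whole tensor
per_channel_delta = 0.01 * tf.ones([32])   # rank 1: one encoding per output channel

def _reduction_axes(delta, x):
    # Mirrors the tf.cond in the helpers above: a scalar delta reduces over every
    # axis, a per-channel delta keeps the last (channel) axis.
    return tf.cond(tf.equal(tf.rank(delta), 0),
                   lambda: tf.range(0, tf.rank(x)),      # per-tensor
                   lambda: tf.range(0, tf.rank(x) - 1))  # per-channel

print(tf.reduce_sum(grad, axis=_reduction_axes(per_tensor_delta, grad)).shape)   # ()
print(tf.reduce_sum(grad, axis=_reduction_axes(per_channel_delta, grad)).shape)  # (32,)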


def _calculate_gradients(input_tensor: tf.Tensor,
bit_width: tf.Tensor,
encoding_min: tf.Tensor,
encoding_max: tf.Tensor,
use_symmetric_encoding: tf.Tensor,
op_mode: tf.Tensor,
grad: tf.Tensor) -> Tuple[tf.Tensor, tf.Tensor, tf.Tensor]:
"""
Calculate dloss_by_dmin, dloss_by_dmax, and dloss_by_dx tensors
:param input_tensor: Input tensor
:param bit_width: Bitwidth used to quantize
:param encoding_min: Encoding min value(s), will be more than one if per channel is active
:param encoding_max: Encoding max value(s), will be more than one if per channel is active
:param use_symmetric_encoding: Symmetric encoding boolean tensor
:param op_mode: Op mode (if passthrough, gradient is returned as is)
:param grad: Gradient from child layer
:return: Tensors for dloss_by_dmin, dloss_by_dmax, and dloss_by_dx
"""

def _asymmetric_gradients():
return _compute_dloss_by_dmin_dmax_and_dx_asymmetric(input_tensor,
bit_width,
encoding_min,
encoding_max,
grad)

def _symmetric_gradients():
return _compute_dloss_by_dmin_dmax_and_dx_symmetric(input_tensor,
bit_width,
encoding_min,
encoding_max,
grad)

dloss_by_dmin, dloss_by_dmax, dloss_by_dx = tf.cond(use_symmetric_encoding,
_symmetric_gradients,
_asymmetric_gradients)

# Pass through gradient for skipped ops
op_mode = tf.cast(op_mode, tf.int8)
pass_through_mode = int(libpymo.TensorQuantizerOpMode.passThrough)
dloss_by_dx = tf.cond(tf.equal(op_mode, pass_through_mode), lambda: grad, lambda: dloss_by_dx)
dloss_by_dmin = tf.cond(tf.equal(op_mode, pass_through_mode),
lambda: tf.zeros_like(encoding_min, dtype=tf.float64),
lambda: dloss_by_dmin)
dloss_by_dmax = tf.cond(tf.equal(op_mode, pass_through_mode),
lambda: tf.zeros_like(encoding_max, dtype=tf.float64),
lambda: dloss_by_dmax)

return dloss_by_dmin, dloss_by_dmax, dloss_by_dx
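A usage sketch of _calculate_gradients with eager tensors; the import path below is an assumption (this diff does not name the file) and the call assumes an installed AIMET build:

import tensorflow as tf
from aimet_common import libpymo
# Module path is an assumption; the function itself is defined above.
from aimet_tensorflow.quantsim_straight_through_grad import _calculate_gradients

inputs = tf.random.uniform([4, 8], minval=-1.0, maxval=1.0)
grad = tf.ones_like(inputs)

dloss_by_dmin, dloss_by_dmax, dloss_by_dx = _calculate_gradients(
    input_tensor=inputs,
    bit_width=tf.constant(8.0),
    encoding_min=tf.constant(-0.5),
    encoding_max=tf.constant(0.5),
    use_symmetric_encoding=tf.constant(True),
    op_mode=tf.constant(int(libpymo.TensorQuantizerOpMode.quantizeDequantize)),
    grad=grad)

# dloss_by_dx equals grad inside the representable range and is zero outside it;
# in the symmetric branch dloss_by_dmin is exactly -dloss_by_dmax.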


@tf_ops.RegisterGradient("QcQuantizeRangeLearningCustomGradient")
def quantsim_custom_grad_learned_grid(op, grad):
"""
Performs custom gradient calculations for trained Quantize op
:param op: Tf operation for which gradients are to be computed
:param grad: Gradient flowing through
"""
dloss_by_dmin, dloss_by_dmax, dloss_by_dx = \
_compute_dloss_by_dmin_dmax_and_dx(op.inputs[0],
op.inputs[int(QuantizeOpIndices.bit_width)],
op.inputs[int(QuantizeOpIndices.op_mode)],
op.inputs[int(QuantizeOpIndices.encoding_min)],
op.inputs[int(QuantizeOpIndices.encoding_max)],
op.inputs[int(QuantizeOpIndices.use_symmetric_encoding)],
grad)
input_tensor = op.inputs[0]
bit_width = op.inputs[int(QuantizeOpIndices.bit_width)]
encoding_min = op.inputs[int(QuantizeOpIndices.encoding_min)]
encoding_max = op.inputs[int(QuantizeOpIndices.encoding_max)]
use_symmetric_encoding = op.inputs[int(QuantizeOpIndices.use_symmetric_encoding)]
op_mode = op.inputs[int(QuantizeOpIndices.op_mode)]

dloss_by_dmin, dloss_by_dmax, dloss_by_dx = _calculate_gradients(input_tensor,
bit_width,
encoding_min,
encoding_max,
use_symmetric_encoding,
op_mode,
grad)

return dloss_by_dx, None, None, dloss_by_dmin, dloss_by_dmax, None, None, None
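TensorFlow expects the registered gradient function to return one entry per op input, in op-input order; only the data tensor and the two encoding parameters receive gradients, and every other QcQuantize input gets None. A small sketch, not part of the commit, spelling out that mapping via the QuantizeOpIndices constants already used above:

from aimet_tensorflow.utils.constants import QuantizeOpIndices

# Slots in the returned tuple that carry a gradient; all remaining op inputs
# (op_mode, bit_width, use_symmetric_encoding, ...) are returned as None.
grad_slots = {
    0: "dloss_by_dx (quantizer input tensor)",
    int(QuantizeOpIndices.encoding_min): "dloss_by_dmin",
    int(QuantizeOpIndices.encoding_max): "dloss_by_dmax",
}
print(grad_slots)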


@tf_ops.RegisterGradient("QcQuantizePerChannelRangeLearningCustomGradient")
def quantsim_per_channel_custom_grad_learned_grid(op, grad):
"""
Performs custom gradient calculations for trained QcQuantizePerChannel op
:param op: Tf operation for which gradients are to be computed
:param grad: Gradient flowing through
"""
