In [44]:
import numpy as np
import torch

from brevitas.core.restrict_val import RestrictValueType
from brevitas.core.scaling import ScalingImplType
from brevitas.nn import QuantIdentity

from modules.models.common_imagenet import CommonIntActQuant
import modules.models.brevitas_example_common as brevitas_common

from brevitas.core.zero_point import (
    StatsFromParameterZeroPoint, 
    ParameterFromRuntimeZeroPoint, 
    ParameterZeroPoint,
    ParameterFromStatsFromParameterZeroPoint)

# Quant Input: [0, 1] -> [-1, 1] Q1.7

In [2]:
quant_inp = QuantIdentity( # for Q1.7 input format -> sign.7bits
                act_quant = CommonIntActQuant,
                bit_width = 8,
                min_val = -1.0,
                max_val = 1.0 - 2.0 ** (-7),
                narrow_range = False,
                return_quant_tensor=True,
                restrict_scaling_type = RestrictValueType.POWER_OF_TWO)

In [3]:
inp_sample = torch.tensor(np.array([-1, -0.5, 0, 0.5, 1]), dtype=torch.float32)
inp_sample

tensor([-1.0000, -0.5000,  0.0000,  0.5000,  1.0000])

In [4]:
out = quant_inp(inp_sample)

  return super().rename(names)


In [5]:
out

QuantTensor(value=tensor([-1.0000, -0.5000,  0.0000,  0.5000,  0.9922], grad_fn=<MulBackward0>), scale=tensor(0.0078, grad_fn=<DivBackward0>), zero_point=tensor(0.), bit_width=tensor(8.), signed_t=tensor(True), training_t=tensor(True))

# Quant Identity with Zero Point not Zero: from stats

In [46]:
class Act_NonZero(brevitas_common.CommonQuant, brevitas_common.ActQuantSolver):
    zero_point_impl = ParameterFromRuntimeZeroPoint
    scaling_impl_type = ScalingImplType.PARAMETER_FROM_STATS 
    quantize_zero_point = True

In [47]:
idtty_non_zero = QuantIdentity( # for Q1.7 input format -> sign.7bits
    act_quant = Act_NonZero,
    bit_width = 1,
    return_quant_tensor=True)

DependencyError: 'Act_NonZero' can not resolve attribute 'collect_stats_steps' while building 'scaling_impl'

In [39]:
zero_sample = torch.tensor(np.array([-1, -0.5, 0, 0.5]), dtype=torch.float32)
zero_sample

tensor([-1.0000, -0.5000,  0.0000,  0.5000])

In [40]:
out_zero = idtty_non_zero(zero_sample)
out_zero

QuantTensor(value=tensor([-1., -1.,  1.,  1.]), scale=tensor(1.), zero_point=tensor(0.), bit_width=tensor(1.), signed_t=tensor(True), training_t=tensor(True))