In [1]:
import brevitas.nn as qnn
from brevitas.quant import Int8WeightPerTensorFloat, Uint8ActPerTensorFloat, Int8ActPerTensorFloat
from brevitas.core.scaling import ConstScaling
from brevitas.inject.enum import ScalingImplType, RestrictValueType, BitWidthImplType
from brevitas.quant import Int8WeightPerTensorFixedPoint, Uint8ActPerTensorFixedPoint, Int8ActPerTensorFixedPoint
from brevitas.quant import Int8BiasPerTensorFixedPointInternalScaling
from brevitas.quant import Int8WeightPerChannelFixedPoint
from brevitas.quant import Int32Bias, Int16Bias, IntBias

import torch

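# Two QuantIdentity layers

- First: floating-point scale
- Second: fixed-point scale

The input consists of random integers drawn from [0, 255] (the upper bound of `torch.randint` is exclusive) divided by 256, so every value lies in [0, 255/256].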

In [None]:
# Fix the global RNG so that the random inputs generated below are reproducible.
torch.manual_seed(123)

# Two identity quantizers that differ only in the quantizer class:
# a Float-scale one, and a FixedPoint one that returns a QuantTensor
# (return_quant_tensor=True); the latter's output is reused later as the
# input of the QuantConv2d layers.
q_id1_fp = qnn.QuantIdentity(act_quant=Uint8ActPerTensorFloat)
q_id1_fixed = qnn.QuantIdentity(act_quant=Uint8ActPerTensorFixedPoint, return_quant_tensor=True)

In [None]:
# Integers in [0, 255] (randint's `high` is exclusive), rescaled by 1/256.
inp1 = torch.randint(low=0, high=256, size=(1, 3, 2, 2)) / 256.0
inp1

In [None]:
# Push the same input through both identity quantizers, then show each result.
q_out1_fp = q_id1_fp(inp1)
q_out1_fixed = q_id1_fixed(inp1)

for flavour, quant_out in (('Float', q_out1_fp), ('Fixed Point', q_out1_fixed)):
    print(f'Quant Identity {flavour}\n{quant_out}')

In [None]:
# Compare the quantization parameters of the float-scale and fixed-point-scale quantizers.
for tag, act_proxy in (('Float', q_id1_fp.act_quant), ('Fixed', q_id1_fixed.act_quant)):
    print(f'{tag} scale: {act_proxy.scale()}')
    print(f'{tag} zero: {act_proxy.zero_point()}')
    print(f'{tag} bits: {act_proxy.bit_width()}')

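# QuantIdentity with Constant Scale

Observation: `scaling_init` seems to set the dynamic range (the upper clipping threshold of the quantizer) rather than the scale factor itself, even though the docs describe it as a constant scale factor. **TODO:** confirm this against the Brevitas documentation/source.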

In [None]:
class MyQuantId(Uint8ActPerTensorFloat):
    """8-bit activation quantizer with a constant, power-of-two-restricted scale."""
    bit_width = 8
    scaling_impl_type = ScalingImplType.CONST
    # Brevitas picks the scale restriction from `restrict_scaling_type`.
    # `restrict_value_impl` is not an attribute the injector consumes, so the
    # power-of-two restriction was silently ignored.
    restrict_scaling_type = RestrictValueType.POWER_OF_TWO
    # With CONST scaling this appears to act as the range threshold rather
    # than the final scale factor -- TODO confirm.
    scaling_init = 1

In [None]:
# Identity quantizer configured through the MyQuantId subclass.
q_id2_fp = qnn.QuantIdentity(
    act_quant=MyQuantId,
)

In [None]:
# Inspect the constant-scale quantizer's parameters.
act_proxy = q_id2_fp.act_quant
print(f'Float scale: {act_proxy.scale()}')
print(f'Float zero: {act_proxy.zero_point()}')
print(f'Float bit: {act_proxy.bit_width()}')

In [None]:
# Constant-scale configuration passed as keyword overrides of the
# Uint8ActPerTensorFloat quantizer instead of through a subclass.
q_id3_fp = qnn.QuantIdentity(
    act_quant=Uint8ActPerTensorFloat,
    scaling_impl_type=ScalingImplType.CONST,
    # Brevitas reads the scale restriction from `restrict_scaling_type`;
    # `restrict_value_impl` is not an option it consumes, so it was ignored.
    restrict_scaling_type=RestrictValueType.POWER_OF_TWO,
    scaling_init=1,
)

In [None]:
# Inspect the keyword-configured float quantizer.
act_proxy = q_id3_fp.act_quant
print(f'Float scale: {act_proxy.scale()}')
print(f'Float zero: {act_proxy.zero_point()}')
print(f'Float bit: {act_proxy.bit_width()}')

In [None]:
# Constant-scale configuration applied on top of the fixed-point quantizer.
q_id3_fixed = qnn.QuantIdentity(
    act_quant=Uint8ActPerTensorFixedPoint,
    scaling_impl_type=ScalingImplType.CONST,
    # Brevitas reads the scale restriction from `restrict_scaling_type`;
    # `restrict_value_impl` is not an option it consumes, so it was ignored.
    restrict_scaling_type=RestrictValueType.POWER_OF_TWO,
    scaling_init=1,
)

In [None]:
# Inspect the keyword-configured fixed-point quantizer.
act_proxy = q_id3_fixed.act_quant
print(f'Fixed scale: {act_proxy.scale()}')
print(f'Fixed zero: {act_proxy.zero_point()}')
print(f'Fixed bit: {act_proxy.bit_width()}')

In [None]:
class MyQuantIdFixed(Uint8ActPerTensorFixedPoint):
    """4-bit fixed-point activation quantizer with a constant scale."""
    bit_width = 4
    scaling_impl_type = ScalingImplType.CONST
    # Brevitas picks the scale restriction from `restrict_scaling_type`;
    # `restrict_value_impl` is not an attribute the injector consumes.
    restrict_scaling_type = RestrictValueType.POWER_OF_TWO
    # Not a power of two on purpose, to observe how the restriction rounds it -- TODO confirm result.
    scaling_init = 5

In [None]:
# Pass the quantizer by keyword, matching the other QuantIdentity cells.
q_id4_fixed = qnn.QuantIdentity(act_quant=MyQuantIdFixed)

In [None]:
# Inspect the 4-bit constant-scale quantizer.
act_proxy = q_id4_fixed.act_quant
print(f'Fixed scale: {act_proxy.scale()}')
print(f'Fixed zero: {act_proxy.zero_point()}')
print(f'Fixed bit: {act_proxy.bit_width()}')

# Conv2d and Bias Test

This does not work as intended, because the `IntxxBias` quantizers (e.g. `Int16Bias`) compute the bias scale automatically from the scales of the input and of the weights.

In [None]:
class MyBias(Int16Bias):
    """16-bit bias quantizer with an (attempted) power-of-two scale restriction.

    The bias scale of Int16Bias comes from the input and weight scales, so the
    restriction below is not expected to change it.
    """
    # Was misspelled `bit_witdh`, which defined an unrelated attribute.
    bit_width = 16
    # Brevitas' enum knob is `restrict_scaling_type`, not `restrict_value_impl`.
    restrict_scaling_type = RestrictValueType.POWER_OF_TWO

In [None]:
# 1x1 conv, 3 -> 3 channels, fixed-point weights and the custom MyBias quantizer.
q_conv = qnn.QuantConv2d(
    in_channels=3,
    out_channels=3,
    kernel_size=1,
    bias=True,
    weight_quant=Int8WeightPerTensorFixedPoint,
    bias_quant=MyBias,
)

In [None]:
# Feed the fixed-point QuantTensor from q_id1_fixed into the convolution.
q_out_conv = q_conv(q_out1_fixed)

In [None]:
# Show the output of q_conv.
q_out_conv

In [None]:
# Quantization parameters of q_conv's weights.
weight_proxy = q_conv.weight_quant
print(f'Conv Weights Scale: {weight_proxy.scale()}')
print(f'Conv Weights Zero: {weight_proxy.zero_point()}')
print(f'Conv Weights bit width: {weight_proxy.bit_width()}')

In [None]:
# Raw bias followed by the quantization parameters of q_conv's bias.
print(f'Conv Bias: {q_conv.bias}')
bias_proxy = q_conv.bias_quant
print(f'Conv Bias Scale: {bias_proxy.scale()}')
print(f'Conv Bias zero: {bias_proxy.zero_point()}')
print(f'Conv Bias bit width: {bias_proxy.bit_width()}')

### Bias with Internal Scaling

In [None]:
# Same 1x1 conv, but with a bias quantizer that uses internal scaling.
q_conv_2 = qnn.QuantConv2d(
    in_channels=3,
    out_channels=3,
    kernel_size=1,
    bias=True,
    weight_quant=Int8WeightPerTensorFixedPoint,
    bias_quant=Int8BiasPerTensorFixedPointInternalScaling,
)

In [None]:
# Own name (matching q_out_conv_3 / q_out_conv_4) so that q_conv's output in
# `q_out_conv` is not silently overwritten; re-running the earlier display cell
# would otherwise show q_conv_2's output instead.
q_out_conv_2 = q_conv_2(q_out1_fixed)

In [None]:
# Quantization parameters of q_conv_2's weights.
weight_proxy = q_conv_2.weight_quant
print(f'Conv Weights Scale: {weight_proxy.scale()}')
print(f'Conv Weights Zero: {weight_proxy.zero_point()}')
print(f'Conv Weights bit width: {weight_proxy.bit_width()}')

In [None]:
# Raw bias followed by the internally scaled bias quantizer's parameters.
print(f'Conv Bias: {q_conv_2.bias}')
bias_proxy = q_conv_2.bias_quant
print(f'Conv Bias Scale: {bias_proxy.scale()}')
print(f'Conv Bias zero: {bias_proxy.zero_point()}')
print(f'Conv Bias bit width: {bias_proxy.bit_width()}')

# Per Channel Weights

In [None]:
class MyInt8WeightPerChannelFixedPoint(Int8WeightPerTensorFixedPoint):
    """Int8WeightPerTensorFixedPoint switched to one scale per output channel."""
    # Only the scaling granularity is overridden; the scale restriction is
    # inherited unchanged from the base quantizer.
    scaling_per_output_channel = True

In [None]:
# 1x1 conv with the hand-rolled per-channel weight quantizer.
q_conv_3 = qnn.QuantConv2d(
    in_channels=3,
    out_channels=3,
    kernel_size=1,
    bias=True,
    weight_quant=MyInt8WeightPerChannelFixedPoint,
    bias_quant=Int8BiasPerTensorFixedPointInternalScaling,
)

In [None]:
# Forward the fixed-point input through q_conv_3.
q_out_conv_3 = q_conv_3(q_out1_fixed)

In [None]:
# Per-channel weight quantization parameters of q_conv_3.
weight_proxy = q_conv_3.weight_quant
print(f'Conv Weights Scale: {weight_proxy.scale()}')
print(f'Conv Weights Zero: {weight_proxy.zero_point()}')
print(f'Conv Weights bit width: {weight_proxy.bit_width()}')

In [None]:
# Same layer as q_conv_3 but using Brevitas' built-in per-channel quantizer,
# so the two can be compared.
q_conv_4 = qnn.QuantConv2d(
    in_channels=3,
    out_channels=3,
    kernel_size=1,
    bias=True,
    weight_quant=Int8WeightPerChannelFixedPoint,
    bias_quant=Int8BiasPerTensorFixedPointInternalScaling,
)

In [None]:
# Forward the fixed-point input through q_conv_4.
q_out_conv_4 = q_conv_4(q_out1_fixed)

In [None]:
# Per-channel weight quantization parameters of q_conv_4.
weight_proxy = q_conv_4.weight_quant
print(f'Conv Weights Scale: {weight_proxy.scale()}')
print(f'Conv Weights Zero: {weight_proxy.zero_point()}')
print(f'Conv Weights bit width: {weight_proxy.bit_width()}')

# Per Channel with Auto Scale Bias

In [None]:
# Fresh random input: integers in [0, 255] (randint's `high` is exclusive), rescaled by 1/256.
inp_channel = torch.randint(low=0, high=256, size=(1, 3, 2, 2)) / 256.0
inp_channel

In [None]:
class MyQuantId2(Uint8ActPerTensorFloat):
    """8-bit activation quantizer with a constant, power-of-two-restricted scale."""
    bit_width = 8
    scaling_impl_type = ScalingImplType.CONST
    # Brevitas picks the scale restriction from `restrict_scaling_type`;
    # `restrict_value_impl` is not an attribute the injector consumes.
    restrict_scaling_type = RestrictValueType.POWER_OF_TWO
    scaling_init = 1

# return_quant_tensor=True: Int16Bias derives its scale from the input scale,
# so the conv must receive a QuantTensor.
q_id_fixed_channel = qnn.QuantIdentity(
                    act_quant=MyQuantId2,
                    return_quant_tensor=True)
q_conv_fixed_channel = qnn.QuantConv2d(
    3, 3, 1,
    weight_quant=Int8WeightPerChannelFixedPoint,
    bias=True,
    bias_quant=Int16Bias
)

In [None]:
# Quantize the input, then run the per-channel convolution on the QuantTensor.
q_out_channel = q_conv_fixed_channel(
    q_id_fixed_channel(inp_channel)
)

In [None]:
# Show the output of the per-channel convolution.
q_out_channel

### Print Weights

In [None]:
# Per-channel weight quantization parameters of q_conv_fixed_channel.
weight_proxy = q_conv_fixed_channel.weight_quant
print(f'Conv Weights Scale: {weight_proxy.scale()}')
print(f'Conv Weights Zero: {weight_proxy.zero_point()}')
print(f'Conv Weights bit width: {weight_proxy.bit_width()}')

In [None]:
# Shape of the per-channel scale tensor, then the scale of output channel 2 as a Python float.
weight_scale_of = q_conv_fixed_channel.weight_quant.scale
print(weight_scale_of().shape)
print(weight_scale_of()[2, 0, 0, 0].item())

### Print Bias

In [None]:
# Quantization parameters of the Int16Bias quantizer. The labels previously
# said "Conv Weights ...", which mislabeled the bias values in the output.
bias_proxy = q_conv_fixed_channel.bias_quant
print(f'Conv Bias Scale: {bias_proxy.scale()}')
print(f'Conv Bias Zero: {bias_proxy.zero_point()}')
print(f'Conv Bias bit width: {bias_proxy.bit_width()}')

In [None]:
# Show the raw (float) bias parameter of the per-channel convolution.
q_conv_fixed_channel.bias