In [1]:
from brevitas.nn import QuantReLU
from brevitas.quant import Uint8ActPerTensorFloat
from brevitas.core.restrict_val import RestrictValueType
from brevitas.nn.quant_layer import QuantNonLinearActLayer as QuantNLAL
from brevitas.nn.quant_layer import ActQuantType
from brevitas.inject.defaults import Int8ActPerTensorFloatMinMaxInit


from typing import Optional

import torch
import torch.nn as nn
import torch.nn.functional as F

In [2]:
class CommonUintActQuant(Uint8ActPerTensorFloat):
    """
    Common unsigned act quantizer with bit-width set to None so that it's forced to be specified by
    each layer.
    """
    scaling_min_val = 2e-16
    bit_width = None
    restrict_scaling_type = RestrictValueType.POWER_OF_TWO

Esto no vale, porque deberían estar los valores truncados a 6

In [3]:
quant_relu6 = QuantReLU(
    act_quant=CommonUintActQuant,
    bit_width=4,
    max_val = 6.0
)    

In [4]:
dummy_in = 24*torch.rand(4)-1
print(f'in: {dummy_in}')
out = quant_relu6(dummy_in)
print(f'out: {out}')

in: tensor([20.8909, 12.6119, 12.6123, 13.6961])
out: tensor([19.5852, 13.0568, 13.0568, 13.0568], grad_fn=<MulBackward0>)


  return super().rename(names)


In [5]:
class QuantReLU6(QuantNLAL):

    def __init__(
            self,
            act_quant: Optional[ActQuantType] = Uint8ActPerTensorFloat,
            input_quant: Optional[ActQuantType] = None,
            return_quant_tensor: bool = False,
            **kwargs):
        QuantNLAL.__init__(
            self,
            act_impl=nn.ReLU6,
            passthrough_act=True,
            input_quant=input_quant,
            act_quant=act_quant,
            return_quant_tensor=return_quant_tensor,
            **kwargs)

In [6]:
quant_relu6 = QuantReLU6(
    act_quant=CommonUintActQuant,
    bit_width=4,
)  

In [7]:
dummy_in = 24*torch.rand(4)-1
print(f'in: {dummy_in}')
out = quant_relu6(dummy_in)
print(f'out: {out}')

in: tensor([2.0931, 8.1025, 9.3996, 9.0934])
out: tensor([2.3499, 8.2246, 8.8121, 8.8121], grad_fn=<MulBackward0>)


In [8]:
class act_relu6(nn.Module):
    def __init__(self):
        super(act_relu6, self).__init__()
    
    def forward(self, x):
        return F.relu6(x)

In [9]:
my_relu6 = act_relu6()

In [10]:
dummy_in = 24*torch.rand(4)-1
print(f'in: {dummy_in}')
out = my_relu6(dummy_in)
print(f'out: {out}')

in: tensor([21.4349, 18.8414,  3.1319, 13.9103])
out: tensor([6.0000, 6.0000, 3.1319, 6.0000])


In [11]:
class myact_relu6(nn.Hardtanh):
    def __init__(self):
        super().__init__(0.0, 6.0)

    def forward(self,x):
        return F.hardtanh(x, 0.0, 6.0)

In [12]:
pt_relu6 = myact_relu6()

In [13]:
dummy_in = 24*torch.rand(4)-12
print(f'in: {dummy_in}')
out = pt_relu6(dummy_in)
print(f'out: {out}')

in: tensor([ 7.0071,  0.5854, -2.7119, -3.6952])
out: tensor([6.0000, 0.5854, 0.0000, 0.0000])


In [14]:
class QuantRelu6(QuantNLAL):

    def __init__(
            self,
            act_quant: Optional[ActQuantType] = Int8ActPerTensorFloatMinMaxInit,
            input_quant: Optional[ActQuantType] = None,
            return_quant_tensor: bool = False,
            **kwargs):
        QuantNLAL.__init__(
            self,
            act_impl=myact_relu6,
            passthrough_act=True,
            input_quant=input_quant,
            act_quant=act_quant,
            return_quant_tensor=return_quant_tensor,
            **kwargs)

In [15]:
my_relu6 = QuantRelu6(
    min_val=0,
    max_val=6,
    signed=False
)

In [16]:
dummy_in = 24*torch.rand(4)-12
print(f'in: {dummy_in}')
out = my_relu6(dummy_in)
print(f'out: {out}')

in: tensor([-1.6754, -3.0991,  1.1525, -5.5890])
out: tensor([0.0000, 0.0000, 1.1529, 0.0000], grad_fn=<MulBackward0>)


In [23]:
class my_relu6_model(nn.Module):
    def __init__(self):
        super(my_relu6_model, self).__init__()

        self.qrelu6 = QuantRelu6(min_val=0.0, max_val=6.0, signed=False)

    def forward(self, x):
        return self.qrelu6(x)

In [24]:
relu6_model = my_relu6_model()

In [25]:
from brevitas.export import export_qonnx

In [26]:
input_shape = (2, 4)

In [34]:
# export_qonnx(relu6_model, torch.randn(input_shape), 'relu6_model__QONNX.onnx');

In [32]:
dummy_in = 24*torch.rand(4,4)-12
print(f'in: {dummy_in}')
out = relu6_model(dummy_in)
print(f'out: {out}')

in: tensor([[ -8.4986,  -9.6218,  -0.4147,  10.5439],
        [ -3.7002,   6.3618,  -2.1464,   6.0561],
        [-10.6148,   1.9633,   1.5422,  -8.8977],
        [  7.7430,   7.5358,   1.9808,   2.2477]])
out: tensor([[0.0000, 0.0000, 0.0000, 6.0000],
        [0.0000, 6.0000, 0.0000, 6.0000],
        [0.0000, 1.9529, 1.5529, 0.0000],
        [6.0000, 6.0000, 1.9765, 2.2588]], grad_fn=<MulBackward0>)
