In [1]:
from chop.nn.quantized.functional import linearMXInt
import torch
from chop.tools.logger import set_logging_verbosity

# Switch the chop logger to debug verbosity; the call itself logs
# "Set logging level to debug" as confirmation.
set_logging_verbosity("debug")

[32mINFO    [0m [34mSet logging level to debug[0m


In [2]:

# One width / exponent-width / block-size triple, reused for every operand.
width, exp_width, parallel = 8, 3, [1, 2]

# Expand the triple into "<operand>_<field>" entries. Key order is
# weight_*, bias_*, data_in_*, and all block_size entries share the same list.
quant_config = {
    f"{operand}_{field}": value
    for operand in ("weight", "bias", "data_in")
    for field, value in (
        ("width", width),
        ("exponent_width", exp_width),
        ("block_size", parallel),
    )
}

# Random operands: input of shape (1, 8, 8), weight (4, 8), bias (4,).
x = torch.rand([1, 8, 8])
weight = torch.rand([4, 8])
bias = torch.rand([4])

print(weight, bias, sep="\n")
print()

# Functional MXInt linear result; later cells compare against `y`.
y = linearMXInt(x, weight, bias=bias, config=quant_config)

print(y)
y.shape

tensor([[0.7745, 0.4369, 0.5191, 0.6159, 0.8102, 0.9801, 0.1147, 0.3168],
        [0.6965, 0.9143, 0.9351, 0.9412, 0.5995, 0.0652, 0.5460, 0.1872],
        [0.0340, 0.9442, 0.8802, 0.0012, 0.5936, 0.4158, 0.4177, 0.2711],
        [0.6923, 0.2038, 0.6833, 0.7529, 0.8579, 0.6870, 0.0051, 0.1757]])
tensor([0.7497, 0.6047, 0.1100, 0.2121])

tensor([[[2.8045, 2.5164, 1.8172, 1.7331],
         [2.1832, 2.4806, 1.3740, 1.4209],
         [3.6595, 3.0752, 2.2102, 2.5467],
         [2.6786, 2.8069, 2.1306, 1.8768],
         [2.4971, 1.8619, 1.5163, 1.4707],
         [2.2756, 2.4578, 1.6089, 1.4782],
         [3.4909, 3.8820, 2.3495, 2.8353],
         [2.9913, 2.1998, 1.3497, 2.1574]]])


torch.Size([1, 8, 4])

In [3]:
from chop.nn.quantized.modules.linear import LinearMXInt

# Float reference layer that carries the `weight` / `bias` tensors defined earlier.
layer = torch.nn.Linear(8, 4)
layer.weight = torch.nn.Parameter(weight)
layer.bias = torch.nn.Parameter(bias)

# Constructor arguments for the quantized module, taken from the float layer.
kwargs = {
    "in_features": layer.in_features,
    "out_features": layer.out_features,
    "config": quant_config,
}

# Quantized module sharing the very same Parameter objects as `layer`.
new_layer = LinearMXInt(**kwargs)
new_layer.weight = layer.weight
new_layer.bias = layer.bias

# Call the module itself (not `.forward`) so nn.Module.__call__ runs any
# registered hooks. The difference to the functional result `y` is printed.
print(new_layer(x) - y)

tensor([[[0., 0., 0., 0.],
         [0., 0., 0., 0.],
         [0., 0., 0., 0.],
         [0., 0., 0., 0.],
         [0., 0., 0., 0.],
         [0., 0., 0., 0.],
         [0., 0., 0., 0.],
         [0., 0., 0., 0.]]], grad_fn=<SubBackward0>)


In [4]:
from chop.passes.graph.transforms import quantize_transform_pass
from chop.passes.graph.analysis import report_node_type_analysis_pass, init_metadata_analysis_pass, add_common_metadata_analysis_pass
from chop.ir.graph.mase_graph import MaseGraph

In [5]:
class MLP(torch.nn.Module):
    """
    Toy model consisting of a single fully connected layer (8 inputs -> 4 outputs).

    The layer is initialised from the notebook-level ``weight`` and ``bias``
    tensors, so both must be defined before the class is instantiated.
    """

    def __init__(self) -> None:
        super().__init__()

        self.fc1 = torch.nn.Linear(8, 4, bias=True)
        # Copy the values in instead of rebinding ``.data``: the parameters
        # keep their own storage (no aliasing of the notebook-level tensors)
        # and an incompatible shape raises instead of silently replacing the
        # parameter.
        with torch.no_grad():
            self.fc1.weight.copy_(weight)
            self.fc1.bias.copy_(bias)

    def forward(self, x):
        """Apply the single linear layer to ``x`` and return the result."""
        return self.fc1(x)

In [6]:
mlp = MLP()
mg = MaseGraph(model=mlp)

# Dummy input with the same shape as `x`, handed to the metadata pass.
dummy_in = {"x": torch.randn((1, 8, 8))}

mg, _ = init_metadata_analysis_pass(mg, None)
mg, _ = add_common_metadata_analysis_pass(
    mg, {"dummy_in": dummy_in, "add_value": False}
)

# Build the pass config from a *copy* of quant_config. Writing the "name"
# key through quan_args would otherwise mutate the shared quant_config dict
# and leak hidden state into every other cell that uses it.
quan_args = {
    "by": "type",
    "report": True,
    "default": {"config": {**quant_config, "name": "mxint"}},
}

mg, _ = quantize_transform_pass(mg, quan_args)

# Assign to `_` to keep the pass's return value out of the cell output;
# the report itself is emitted through the logger.
_ = report_node_type_analysis_pass(mg)


[36mDEBUG   [0m [34mgraph():
    %x : [num_users=1] = placeholder[target=x]
    %fc1 : [num_users=1] = call_module[target=fc1](args = (%x,), kwargs = {})
    return fc1[0m
[32mINFO    [0m [34mInspecting graph [add_common_node_type_analysis_pass][0m
[32mINFO    [0m [34m
Node name    Fx Node op    Mase type            Mase op      Value type
-----------  ------------  -------------------  -----------  ------------
x            placeholder   placeholder          placeholder  NA
fc1          call_module   module_related_func  linear       mxint
output       output        output               output       NA[0m


In [7]:
# Run each model once, calling the modules (not `.forward`) so that
# nn.Module.__call__ handles hooks, and reuse the results below.
out_float = mlp(x)
out_quant = mg.model(x)

print(out_float)
print(out_quant)

# Difference of the transformed graph's output to the `mlp` output, and to
# the functional linearMXInt result `y`.
print(out_quant - out_float)
print(out_quant - y)

tensor([[[2.8076, 2.5207, 1.8242, 1.7314],
         [2.1888, 2.4864, 1.3769, 1.4250],
         [3.6686, 3.0870, 2.2192, 2.5497],
         [2.6791, 2.8037, 2.1303, 1.8741],
         [2.4995, 1.8638, 1.5168, 1.4671],
         [2.2822, 2.4645, 1.6113, 1.4794],
         [3.4983, 3.8944, 2.3525, 2.8433],
         [2.9918, 2.2058, 1.3558, 2.1525]]], grad_fn=<ViewBackward0>)
tensor([[[2.8045, 2.5164, 1.8172, 1.7331],
         [2.1832, 2.4806, 1.3740, 1.4209],
         [3.6595, 3.0752, 2.2102, 2.5467],
         [2.6786, 2.8069, 2.1306, 1.8768],
         [2.4971, 1.8619, 1.5163, 1.4707],
         [2.2756, 2.4578, 1.6089, 1.4782],
         [3.4909, 3.8820, 2.3495, 2.8353],
         [2.9913, 2.1998, 1.3497, 2.1574]]], grad_fn=<ViewBackward0>)
tensor([[[-0.0030, -0.0043, -0.0070,  0.0017],
         [-0.0057, -0.0057, -0.0029, -0.0041],
         [-0.0091, -0.0118, -0.0090, -0.0031],
         [-0.0006,  0.0032,  0.0003,  0.0027],
         [-0.0024, -0.0019, -0.0005,  0.0036],
         [-0.0066, -0.0