In [1]:
import os
os.environ['CUDA_DEVICE_ORDER'] = 'PCI_BUS_ID'
os.environ['CUDA_VISIBLE_DEVICES'] = '2'

import functools
import glob
import logging
import numpy as np
import os
import torch
import importlib as imp

from tqdm import tqdm
tprint = tqdm.write

from torch_geometric.data import Data
from torch_geometric.nn.pool import radius_graph
from torch_geometric.transforms import FixedPoints
from tqdm import tqdm
from typing import Callable, Dict, List, Optional, Union

from typing import Callable, Optional, Union

import torch
from torch import Tensor
from torch.nn import Linear

# from torch_geometric.nn.conv import PointNetConv
from torch_geometric.nn.conv import MessagePassing
from torch_geometric.nn.inits import reset
from torch_geometric.typing import (
    Adj,
    OptTensor,
    PairOptTensor,
    PairTensor
)
from torch_geometric.utils import add_self_loops, remove_self_loops
from torch_geometric.nn.norm import BatchNorm
import pytorch_lightning as pl
import pandas as pd

import aegnn
from aegnn.models.networks.my_fuse import MyConvBNReLU

from torch.nn import Linear
from torch.nn import Parameter as P
from torch.nn.functional import relu

# Fix the global seed so the random tensors and layer initialisations below
# are repeatable between runs (the call logs "Global seed set to 12345").
pl.seed_everything(12345)
# Handle to the CUDA device; which physical GPU this maps to is governed by the
# CUDA_VISIBLE_DEVICES value set at the top of this cell.
device = torch.device('cuda')

# Print tensors with 4 decimal places to keep the comparison output compact.
torch.set_printoptions(precision=4)

def quant_tensor(real, scale, bit, signed):
    """Quantize ``real`` onto an integer grid and saturate it to ``bit`` bits.

    Args:
        real: tensor of real values to quantize.
        scale: step size of the grid; ``real`` is divided by it before rounding.
        bit: bit width of the target integer representation.
        signed: if true, saturate to the symmetric range
            ``[-(2**(bit-1) - 1), 2**(bit-1) - 1]``; otherwise to
            ``[0, 2**bit - 1]``.

    Returns:
        A tensor of rounded, saturated values (same dtype as ``real / scale``).
        A message is printed whenever values had to be saturated.
    """
    # Symmetric range for signed values, [0, 2**bit - 1] for unsigned ones.
    upper = pow(2, bit-1) - 1 if signed else pow(2, bit) - 1
    lower = -upper if signed else 0

    quant = (real / scale).round()

    # Report saturation before clamping so the offending value is visible.
    if torch.max(quant) > upper:
        print(f'overflow: max={torch.max(quant).item()}')
    if torch.min(quant) < lower:
        print(f'underflow: min={torch.min(quant).item()}')

    return quant.clamp(min=lower, max=upper)

def dequant_tensor(quant, scale):
    """Map quantized values back to real values by multiplying with ``scale``."""
    return quant * scale


def q(tensor, bit, signed, replace_scale = None):
    """Quantize ``tensor`` and return the integers, the scale and the round trip.

    Args:
        tensor: real-valued tensor to quantize.
        bit: bit width of the integer representation.
        signed: symmetric signed range if true, unsigned range starting at 0
            otherwise (see ``quant_tensor``).
        replace_scale: when given, this scale is used as-is instead of one
            derived from the value range of ``tensor`` (used in this notebook
            to force a bias onto the ``input_scale * weight_scale`` grid).

    Returns:
        ``(qtensor, scale, dqtensor)`` where ``qtensor`` holds the rounded and
        saturated values, ``scale`` is the step size that was used, and
        ``dqtensor`` is ``qtensor * scale``.
    """
    if replace_scale is not None:
        scale = replace_scale
    else:
        # The range is only needed when the scale has to be derived.
        t_max = torch.max(tensor)
        t_min = torch.min(tensor)
        t_abs_max = torch.maximum(torch.abs(t_max), torch.abs(t_min))

        if signed:
            q_max = pow(2, bit-1) - 1
            q_min = -q_max
            scale = 2*t_abs_max / (q_max-q_min)
        else:
            q_max = pow(2, bit) - 1
            q_min = 0
            scale = (t_abs_max - 0.0) / (q_max-q_min)

        # An all-zero tensor (e.g. a fully inactive ReLU output) yields a zero
        # scale; dividing by it would turn every quantized value into NaN.
        # Any non-zero scale represents zeros exactly, so fall back to 1.
        if scale == 0:
            scale = torch.ones_like(scale)

    qtensor = quant_tensor(tensor, scale=scale, bit=bit, signed=signed)

    dqtensor = dequant_tensor(qtensor, scale)

    return qtensor, scale, dqtensor

def reg(fc, w, b):
    """Install ``w`` and ``b`` as the weight and bias parameters of ``fc``.

    Both tensors are wrapped in ``torch.nn.Parameter`` (imported as ``P``)
    before being assigned. The same ``fc`` object is returned for convenience.
    """
    for attr_name, value in (('weight', w), ('bias', b)):
        setattr(fc, attr_name, P(value))
    return fc






  from .autonotebook import tqdm as notebook_tqdm
Global seed set to 12345


In [5]:
# test pass
# Sanity check of integer-only inference for two stacked Linear+ReLU layers:
# run the float reference, quantize activations/weights/biases, run the same
# layers on the quantized values, then compare the dequantized result to `y`.

# Float layers and the layers that will receive the quantized parameters.
fc1 = Linear(2,3)
fc2 = Linear(3,4)
qfc1 = Linear(2,3)
qfc2 = Linear(3,4)

# --- float reference path -------------------------------------------------
x0 = torch.rand(2,2)
print(f'x0 = \n{x0}\n')
x1 = fc1(x0)
print(f'x1 = \n{x1}\n')
xr1 = relu(x1)
print(f'xr1 = \n{xr1}\n')
x2 = fc2(xr1)
print(f'x2 = \n{x2}\n')
y = relu(x2)
print(f'y = \n{y}\n')

# --- layer 1: unsigned 8-bit input, signed 8-bit weight, 32-bit bias --------
# The bias is forced onto the x0_scale*w1_scale grid so that it can be added
# directly to the integer product of quantized input and weight.
qx0, x0_scale, dqx0 = q(x0, bit=8, signed=False)
qw1, w1_scale, dqw1 = q(fc1.weight, bit=8, signed=True)
qb1, b1_scale, dqb1 = q(fc1.bias, bit=32, signed=True, replace_scale=x0_scale*w1_scale)
qfc1 = reg(qfc1, qw1, qb1)
print(x0)
print(dqx0)
print('')

# --- layer 2: same scheme, input is the ReLU output of layer 1 --------------
qxr1, xr1_scale, dqxr1 = q(xr1, bit=8, signed=False)
qw2, w2_scale, dqw2 = q(fc2.weight, bit=8, signed=True)
qb2, b2_scale, dqb2 = q(fc2.bias, bit=32, signed=True, replace_scale=xr1_scale*w2_scale)
qfc2 = reg(qfc2, qw2, qb2)
print(xr1)
print(dqxr1)
print('')

# Output activation scale, taken from the float reference output.
qy, y_scale, dqy = q(y, bit=8, signed=False)
print(y)
print(dqy)
print('')

# Requantization multipliers: convert an accumulator expressed in
# (input_scale * weight_scale) units into the next activation's scale.
M1 = x0_scale*w1_scale/xr1_scale
M2 = xr1_scale*w2_scale/y_scale
print(M1)
print(M2)

# --- compare quantized parameters with the originals -----------------------
print(f'qfc1.weight = \n{qfc1.weight}')
dw1 = dequant_tensor(qfc1.weight, scale=w1_scale)
print(f'd qfc1.weight = \n{dw1}')
print(f'fc1.weight = \n{fc1.weight}')
print('')
print(f'qfc1.bias = \n{qfc1.bias}')
db1 = dequant_tensor(qfc1.bias, scale=b1_scale)
# Label fixed: this is the dequantized bias, not the weight.
print(f'd qfc1.bias = \n{db1}')
print(f'fc1.bias = \n{fc1.bias}')
print('')

print(f'qfc2.weight = \n{qfc2.weight}')
dw2 = dequant_tensor(qfc2.weight, scale=w2_scale)
print(f'd qfc2.weight = \n{dw2}')
print(f'fc2.weight = \n{fc2.weight}')
print('')
print(f'qfc2.bias = \n{qfc2.bias}')
db2 = dequant_tensor(qfc2.bias, scale=b2_scale)
# Label fixed: this is the dequantized bias, not the weight.
print(f'd qfc2.bias = \n{db2}')
print(f'fc2.bias = \n{fc2.bias}')
print('')

# --- quantized forward path ------------------------------------------------
# Each layer: integer-valued linear op, rescale by M, round, then ReLU.
print(f'qx0 = \n{qx0}\n')
c_qx1 = torch.round(qfc1(qx0)*M1)
print(f'c_qx1 = \n{c_qx1}\n')
c_qxr1 = relu(c_qx1)
print(f'c_qxr1 = \n{c_qxr1}\n')

c_qx2 = torch.round(qfc2(c_qxr1)*M2)
print(f'c_qx2 = \n{c_qx2}\n')
c_qy = relu(c_qx2)
print(f'c_qy = \n{c_qy}\n')

# Dequantize the final result and show it next to the float reference.
dc_qy = dequant_tensor(c_qy, scale=y_scale)
print(f'dc_qy = \n{dc_qy}\n')
print(f'y = \n{y}\n')

x0 = 
tensor([[0.1759, 0.8837],
        [0.4747, 0.1197]])

x1 = 
tensor([[ 0.5631,  0.0649, -0.1187],
        [ 0.2499, -0.0433, -0.6513]], grad_fn=<AddmmBackward0>)

xr1 = 
tensor([[0.5631, 0.0649, 0.0000],
        [0.2499, 0.0000, 0.0000]], grad_fn=<ReluBackward0>)

x2 = 
tensor([[ 0.4669, -0.6187, -0.6027, -0.2498],
        [ 0.2995, -0.4151, -0.4490, -0.4279]], grad_fn=<AddmmBackward0>)

y = 
tensor([[0.4669, 0.0000, 0.0000, 0.0000],
        [0.2995, 0.0000, 0.0000, 0.0000]], grad_fn=<ReluBackward0>)

tensor([[0.1759, 0.8837],
        [0.4747, 0.1197]])
tensor([[0.1767, 0.8837],
        [0.4748, 0.1213]])

tensor([[0.5631, 0.0649, 0.0000],
        [0.2499, 0.0000, 0.0000]], grad_fn=<ReluBackward0>)
tensor([[0.5631, 0.0640, 0.0000],
        [0.2495, 0.0000, 0.0000]], grad_fn=<MulBackward0>)

tensor([[0.4669, 0.0000, 0.0000, 0.0000],
        [0.2995, 0.0000, 0.0000, 0.0000]], grad_fn=<ReluBackward0>)
tensor([[0.4669, 0.0000, 0.0000, 0.0000],
        [0.3003, 0.0000, 0.0000, 0.0000]]

In [8]:
class QNet(torch.nn.Module):
    """Two-layer Linear+ReLU toy network with a hand-written quantized twin.

    Every layer receives its feature input concatenated with ``pos_dim`` extra
    columns (``CONST``). ``lin1``/``lin2`` are the float layers; ``qlin1``/
    ``qlin2`` receive the quantized weights and biases.

    Intended call order: ``forward`` (caches the float activations on the
    instance), then ``quantize`` (derives scales from those activations), then
    ``qforward`` (runs the quantized layers and prints a comparison).
    """

    # Largest magnitude of the symmetric signed 16-bit range used for weights.
    _INT16_MAX = 32767

    def __init__(self) -> None:
        super().__init__()

        # Number of trailing input columns that carry CONST instead of features.
        self.pos_dim = 2
        self.lin1 = Linear(2+self.pos_dim,3)
        self.qlin1 = Linear(2+self.pos_dim,3)

        self.lin2 = Linear(3+self.pos_dim,4)
        self.qlin2 = Linear(3+self.pos_dim,4)

    def forward(self, x0, CONST):
        """Float forward pass; caches inputs and activations and prints them.

        Args:
            x0: feature input, concatenated with ``CONST`` along dim 1.
            CONST: extra columns appended to the input of both layers.

        Returns:
            The ReLU output of the second layer (also stored as ``self.x2``).
        """
        self.x0 = x0
        self.CONST = CONST
        self.x0_c = torch.cat([self.x0, self.CONST], dim=1)
        self.x1 = relu(self.lin1(self.x0_c))

        self.x1_c = torch.cat([self.x1, self.CONST], dim=1)
        self.x2 = relu(self.lin2(self.x1_c))

        print(f'x0 = \n{self.x0}')
        print(f'x0_c = \n{self.x0_c}')
        print(f'x1 = \n{self.x1}')
        print(f'x1_c = \n{self.x1_c}')
        print(f'x2 = \n{self.x2}\n')

        return self.x2

    def _quantize_weight(self, weight, act_max, layer):
        """Quantize one layer's weight to 16 bit, feature and CONST columns separately.

        Args:
            weight: float weight matrix; its last ``pos_dim`` columns multiply
                the CONST part of the input.
            act_max: maximum of the layer's float feature input.
            layer: layer number, used only in the printed labels.

        Returns:
            ``(wx_scale, wpos_scale, wx_quant, wpos_quant, w_quant)`` where
            ``w_quant`` is the two quantized parts concatenated along dim 1.
        """
        w_feat = weight[:, :-self.pos_dim]
        w_pos = weight[:, -self.pos_dim:]
        int16_span = self._INT16_MAX - (-self._INT16_MAX)

        # Feature columns: symmetric scale from their own absolute maximum.
        wx_abs_max = torch.maximum(torch.abs(torch.max(w_feat)), torch.abs(torch.min(w_feat)))
        wx_scale = 2 * wx_abs_max / int16_span

        # The real range of the CONST columns is printed for inspection only.
        wpos_abs_max = torch.maximum(torch.abs(torch.max(w_pos)), torch.abs(torch.min(w_pos)))
        print(f'wpos{layer}_abs_max = {wpos_abs_max}')
        # The CONST columns use a "pseudo" maximum instead of their own range.
        # With dpos_scale = 4 / 2**19 this makes dpos_scale * wpos_scale
        # approximately equal to act_scale * wx_scale (2**16 vs 2**16 - 1
        # levels), so both products can share one accumulator and one
        # requantization factor M. Weights beyond the pseudo maximum are
        # saturated by quant_tensor, which reports it.
        wpos_pseudo_max = wx_abs_max * act_max * 2.0
        print(f'wpos{layer}_pseudo_max = {wpos_pseudo_max}')
        wpos_scale = 2 * wpos_pseudo_max / int16_span

        print(f'quant wx')
        wx_quant = quant_tensor(w_feat, scale=wx_scale, bit=16, signed=True)
        print(f'quant wpos')
        wpos_quant = quant_tensor(w_pos, scale=wpos_scale, bit=16, signed=True)
        w_quant = torch.cat([wx_quant, wpos_quant], dim=1)

        return wx_scale, wpos_scale, wx_quant, wpos_quant, w_quant

    def quantize(self):
        """Derive all scales and load quantized parameters into ``qlin1``/``qlin2``.

        Requires a prior ``forward`` call: the activation scales are computed
        from the cached ``self.x0``, ``self.x1`` and ``self.x2``.
        """
        # Scale for the CONST columns (quantized as unsigned 19 bit in qforward).
        # Author's note: the expected maximum is the radius (3); 4 is used
        # instead for fast calculation.
        self.dpos_scale = 4.0 / pow(2,19)

        self.qx0, self.x0_scale, self.dqx0 = q(self.x0, bit=16, signed=False)

        # ---- layer 1 ----
        x0_max = torch.max(self.x0)
        (self.wx1_scale, self.wpos1_scale, self.wx1_quant,
         self.wpos1_quant, self.w1_quant) = self._quantize_weight(self.lin1.weight, x0_max, layer=1)

        # Bias lives on the accumulator grid (input scale * weight scale).
        self.qb1, self.b1_scale, self.dqb1 = q(self.lin1.bias, bit=32, signed=True, replace_scale=self.x0_scale*self.wx1_scale)
        reg(self.qlin1, w=self.w1_quant, b=self.qb1)

        self.qx1, self.x1_scale, self.dqx1 = q(self.x1, bit=16, signed=False)
        # Requantization factor from the layer-1 accumulator to the x1 scale.
        self.M1 = self.x0_scale*self.wx1_scale / self.x1_scale

        # ---- layer 2 ----
        x1_max = torch.max(self.x1)
        (self.wx2_scale, self.wpos2_scale, self.wx2_quant,
         self.wpos2_quant, self.w2_quant) = self._quantize_weight(self.lin2.weight, x1_max, layer=2)

        self.qb2, self.b2_scale, self.dqb2 = q(self.lin2.bias, bit=32, signed=True, replace_scale=self.x1_scale*self.wx2_scale)
        reg(self.qlin2, w=self.w2_quant, b=self.qb2)

        self.qx2, self.x2_scale, self.dqx2 = q(self.x2, bit=16, signed=False)
        # Requantization factor from the layer-2 accumulator to the x2 scale.
        self.M2 = self.x1_scale*self.wx2_scale / self.x2_scale

    def qforward(self):
        """Run the quantized layers and print them next to the float activations.

        Requires prior ``forward`` and ``quantize`` calls. Results are stored
        on the instance (``c_qx1``, ``c_qx2``, ``dc_qx1``, ``dc_qx2``); the
        method itself returns ``None``.
        """
        print(f'\nqforward start\n')
        # CONST is quantized as unsigned 19 bit with the fixed dpos_scale.
        self.qc = quant_tensor(self.CONST, scale=self.dpos_scale, bit=19, signed=False)
        self.c_qx0_c = torch.cat([self.qx0, self.qc], dim=1)
        # Rescale the accumulator by M, apply ReLU, then round back to integers.
        self.c_qx1 = torch.round(relu(self.M1 * self.qlin1(self.c_qx0_c)))
        self.c_qx1_c = torch.cat([self.c_qx1, self.qc], dim=1)
        self.c_qx2 = torch.round(relu(self.M2 * self.qlin2(self.c_qx1_c)))

        # Dequantize for comparison with the float activations.
        self.dc_qx1 = dequant_tensor(self.c_qx1, scale=self.x1_scale)
        self.dc_qx2 = dequant_tensor(self.c_qx2, scale=self.x2_scale)

        print(f'c_qx0_c = \n{self.c_qx0_c}')
        print(f'x0 = \n{self.x0}\n')
        print(f'c_qx1 = \n{self.c_qx1}')
        print(f'dc_qx1 = \n{self.dc_qx1}')
        print(f'x1 = \n{self.x1}\n')
        print(f'c_qx2 = \n{self.c_qx2}')
        print(f'dc_qx2 = \n{self.dc_qx2}')
        print(f'x2 = \n{self.x2}\n')


In [10]:
# Drive the QNet experiment end to end with autograd disabled:
# float forward pass -> derive scales / quantized parameters -> quantized pass.
with torch.no_grad():
    # Random feature input; drawn before QNet() is built, so swapping the two
    # lines would change the random values produced under a fixed seed.
    x0 = torch.rand(2,2)
    # Extra columns that QNet.forward concatenates to the input of both layers.
    c = torch.tensor([[3.0,0.0],[0.0,3.0]])
    net = QNet()

    # Float reference; also caches the activations that quantize() reads.
    x2 = net(x0,c)

    net.quantize()
    # Prints quantized vs. float activations; returns nothing.
    net.qforward()
    # dc_qx2 = net.qforward()


x0 = 
tensor([[0.1975, 0.9458],
        [0.6762, 0.7570]])
x0_c = 
tensor([[0.1975, 0.9458, 3.0000, 0.0000],
        [0.6762, 0.7570, 0.0000, 3.0000]])
x1 = 
tensor([[0.0000, 1.2900, 1.0098],
        [0.2307, 0.1481, 0.1677]])
x1_c = 
tensor([[0.0000, 1.2900, 1.0098, 3.0000, 0.0000],
        [0.2307, 0.1481, 0.1677, 0.0000, 3.0000]])
x2 = 
tensor([[0.4949, 1.9529, 0.0000, 0.0000],
        [0.0000, 0.0000, 0.6714, 1.1601]])

wpos1_abs_max = 0.34854263067245483
wpos1_pseudo_max = 0.8938669562339783
quant wx
quant wpos
wpos2_abs_max = 0.43047958612442017
wpos2_pseudo_max = 1.1475560665130615
quant wx
quant wpos

qforward start

c_qx0_c = 
tensor([[ 13687.,  65535., 393216.,      0.],
        [ 46855.,  52456.,      0., 393216.]])
x0 = 
tensor([[0.1975, 0.9458],
        [0.6762, 0.7570]])

c_qx1 = 
tensor([[    0., 65537., 51302.],
        [11720.,  7522.,  8519.]])
dc_qx1 = 
tensor([[0.0000, 1.2900, 1.0098],
        [0.2307, 0.1481, 0.1677]])
x1 = 
tensor([[0.0000, 1.2900, 1.0098],
      