In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F

def LinearQuantizeOut(x, k, alpha):                # only quantize >0 values (relu must be preceded)
    L = 2.**k - 1
    xdiv = x.div(alpha)
    xc = xdiv.clamp(min=0., max=1.)
    xq = xc.mul(L).round()
    xmul = xq.div(L).mul(alpha)
    return xmul

def LinearQuantizeW(x, k, max_val, min_val):       # asymetric quant
    delta = max_val - min_val
    L= 2 ** k - 1
    stepSize = delta / L
    index = torch.clamp(torch.round((x-min_val) / delta * L), 0, L)
    return index, index * stepSize + min_val, stepSize


# Ideally, same to inputQ * weightQ if ADC=32bits
def MAC(inputQ, weightQ, abits, wbits, adcbits, output_size, subArray):
    assert wbits >= 2
    outputreal = F.linear(inputQ, weightQ, None)
    # print(f'real: {F.linear(inputQ, weightQ, None)}')
    outputShiftIN = torch.zeros_like(outputreal)
    for z in range(abits):
        inputB = torch.fmod(inputQ, 2)              # 12,10,14 = [0,0,0] / [0,1,1] / [1,0,1] / [1,1,1]
        inputQ = torch.round((inputQ-inputB)/2)     # 12,10,14 = [6,5,7] / [3,2,3] / [1,1,1] / [0,0,0]
        weightQb = weightQ
        outputShiftW = torch.zeros_like(outputreal)
        for k in range (wbits):
            weightB = torch.fmod(weightQb, 2)
            weightQb = torch.round((weightQb-weightB)/2)
            outputPartial = F.linear(inputB, weightB, None)
            # Add ADC quanization effects here !!!
            outputADC = LinearQuantizeOut(outputPartial, adcbits, subArray)
            # shift per w bit sequence
            outputShiftW = outputShiftW + outputADC * (2 ** k)
        # shift per input bit sequence
        outputShiftIN = outputShiftIN + outputShiftW * (2 ** z)
    # since inputQ [0, 15] when k=4, rescale output by divide 16
    # output = output + inputS * (outputIN * wS + outputIND * w.min())     # suppose I=[0~15], W=[-8~7] -> I*W = I[0~15]*W[0~15] + I[0~15]*W_constant[-8] (which is w.min())
    if output_size != outputShiftIN.size():
        outputShiftIN = outputShiftIN.transpose(1,2).reshape(output_size)
    return outputShiftIN

# Ideally, same to inputQ * weightQ if ADC=32bits
def CAM(inputQ, weightQ, abits, wbits, adcbits, output_size, subArray):
    # print(f'real: {F.linear(inputQ, weightQi, None)}')
    outputreal = F.linear(inputQ, weightQ, None)
    outputShiftIN = torch.zeros_like(outputreal)
    for z in range(abits):
        inputB = torch.fmod(inputQ, 2)              # 12,10,14 = [0,0,0] / [0,1,1] / [1,0,1] / [1,1,1]
        inputQ = torch.round((inputQ-inputB)/2)     # 12,10,14 = [6,5,7] / [3,2,3] / [1,1,1] / [0,0,0]
        uniqs = torch.unique(weightQ)
        outputShiftW = torch.zeros_like(outputreal)
        # outputShiftW = F.linear(inputB, weightQ, None)  # approximately ~16.5s
        for un in uniqs:    # approximately ~28s
            maskCAM = (weightQ == un).float()
            outML = F.linear(inputB, maskCAM, None)
            # outMLADC = LinearQuantizeOut(outML, adcbits, subArray)
            outMLADC = outML # (original is of course better)
            outputShiftW = outputShiftW + outMLADC * un
        outputShiftIN = outputShiftIN + outputShiftW * (2 ** z)
        
    if output_size != outputShiftIN.size():
        outputShiftIN = outputShiftIN.transpose(1,2).reshape(output_size)
    return outputShiftIN

# Ideally, same to inputQ * weightQ if ADC=32bits
def ZP_MAC(inputQ, weightQ, abits, adcbits, output_size, subArray, zero_point_opt):
    outputreal = F.linear(inputQ, weightQ, None)
    outputDummyShift = torch.zeros_like(outputreal)
    if zero_point_opt:
        outputDummyShift = F.linear(inputQ, weightQ, None)
    else:
        for z in range(abits):
            inputB = torch.fmod(inputQ, 2)              
            inputQ = torch.round((inputQ-inputB)/2)     
            outputDummy = F.linear(inputB, weightQ, None)
            outputDummyADC = LinearQuantizeOut(outputDummy, adcbits, subArray)
            outputDummyShift = outputDummyShift + outputDummyADC * (2 ** z)
    # since inputQ [0, 15] when k=4, rescale output by divide 16
    # output = output + inputS * (outputIN * wS + outputIND * w.min())     # suppose I=[0~15], W=[-8~7] -> I*W = I[0~15]*W[0~15] + I[0~15]*W_constant[-8] (which is w.min())
    if output_size != outputDummyShift.size():
        outputDummyShift = outputDummyShift.transpose(1,2).reshape(output_size)
    return outputDummyShift
    
    
# default: vanilla
# inference = 0: real (FP32)
# inference = 1: Activation/Weight Quantize
# inference = 2: Activation/Weight/Output Quantize
# inference = 3: PIM (Activation/Weight/Output Quantize)
# inference = -1: PIM-mimic (where consider dummy)
class QConv2dCAM(nn.Conv2d):
    def __init__(self, in_channels, out_channels, kernel_size, stride=1, padding=0, dilation=1, groups=1, bias=False, bitActivation=8,bitWeight=[4,4],
                 inference=0,subArray=128,bitADC=5,zero_point_opt=False):
        super(QConv2dCAM, self).__init__(in_channels, out_channels, kernel_size, stride, padding, dilation, groups, bias)
        assert isinstance(bitWeight, list)
        self.bitWeight = bitWeight
        self.bitWeightMSB, self.bitWeightLSB = bitWeight
        self.bitActivation = bitActivation
        self.inference = inference
        self.subArray = subArray
        self.bitADC = bitADC
        self.zero_point_opt = zero_point_opt
        
    def extra_repr(self):
        s = ('{in_channels}, {out_channels}, kernel_size={kernel_size}'
             ', stride={stride}')
        if self.padding != (0,) * len(self.padding):
            s += ', padding={padding}'
        if self.dilation != (1,) * len(self.dilation):
            s += ', dilation={dilation}'
        if self.output_padding != (0,) * len(self.output_padding):
            s += ', output_padding={output_padding}'
        if self.groups != 1:
            s += ', groups={groups}'
        if self.bias is None:
            s += ', bias=False'
        if self.padding_mode != 'zeros':
            s += ', padding_mode={padding_mode}'
        s += ', ibits={bitActivation}, wbits={bitWeight}, inference={inference}, subArray={subArray}, ADCbits={bitADC}'
        return s.format(**self.__dict__)
        
    # convert input -> # [N, OH * OW, IC * KH * KW]
    def input_2d_mapping(self, input):
        fold_param = dict(kernel_size=self.kernel_size, dilation=self.dilation, padding=self.padding, stride=self.stride)
        unfold_module = nn.Unfold(**fold_param)
        unfold_out = unfold_module(input)
        return unfold_out.transpose(1,2)
    
    # convert: weight -> # [IC * KH * KW, OC]
    def weight_2d_mapping(self, weight):
        return weight.reshape(weight.shape[0], -1)
    
    def forward(self, input):
        # make input & weight to 2D tensors
        input2D = self.input_2d_mapping(input)                                                           
        weight2D = self.weight_2d_mapping(self.weight)                              

        inputQ, inputQS, inputS = LinearQuantizeW(input2D, self.bitActivation, input2D.max(), input2D.min())
        weightQ, weightQS, weightS = LinearQuantizeW(weight2D, sum(self.bitWeight), weight2D.max(), weight2D.min())
        
        outputreal = F.conv2d(input, self.weight, self.bias, self.stride, self.padding)
        weightQM = weightQ // (2.**self.bitWeightLSB)
        if self.bitWeightMSB == 0:                              # if MSB=0bit, then use weightQ
            weightQL = weightQ
        else:
            weightQL = weightQ % (2.**self.bitWeightLSB)
        
        numSubArray = int(weight2D.shape[1]/self.subArray)
        if self.inference == 3:
            if numSubArray == 0:
                outputM = CAM(inputQ, weightQM, self.bitActivation, self.bitWeightMSB, self.bitADC, outputreal.size(), weight2D.shape[1])
                outputL = MAC(inputQ, weightQL, self.bitActivation, self.bitWeightLSB, self.bitADC, outputreal.size(), weight2D.shape[1])
                outputDL = ZP_MAC(inputQ, torch.ones_like(weightQ), self.bitActivation, self.bitADC, outputreal.size(), weight2D.shape[1], self.zero_point_opt)
                
                outputP = (outputM * (2.**self.bitWeightLSB) + outputL) * weightS
                outputD = outputDL * weight2D.min()
                out = inputS * (outputP + outputD)
            else:
                numSubRow = [self.subArray] * numSubArray + ([] if weight2D.shape[1] % self.subArray == 0 else [weight2D.shape[1] % self.subArray])
                out = torch.zeros_like(outputreal)
                for s, rowArray in enumerate(numSubRow):
                    mask = torch.zeros_like(weight2D)
                    mask[:,(s*self.subArray):(s+1)*self.subArray] = 1
                    outputM = CAM(inputQ, weightQM*mask, self.bitActivation, self.bitWeightMSB, self.bitADC, outputreal.size(), rowArray)
                    outputL = MAC(inputQ, weightQL*mask, self.bitActivation, self.bitWeightLSB, self.bitADC, outputreal.size(), rowArray)
                    outputDL = ZP_MAC(inputQ, torch.ones_like(weightQ)*mask, self.bitActivation, self.bitADC, outputreal.size(), rowArray, self.zero_point_opt)
                    
                    outputP = (outputM * (2.**self.bitWeightLSB) + outputL) * weightS
                    outputD = outputDL * weight2D.min()
                    out = out + inputS * (outputP + outputD)
            if self.bias is not None:
                out = out + self.bias
        elif self.inference == 2:
            output = F.linear(inputQS, weightQS, self.bias)
            _, out, _ = LinearQuantizeW(output, self.bitADC, output.max(), output.min())
        elif self.inference == 1:
            out = F.linear(inputQS, weightQS, self.bias)
        elif self.inference == -1:
            outputM = F.linear(inputQ, weightQM, None).transpose(1,2).reshape(outputreal.size())
            outputL = F.linear(inputQ, weightQL, None).transpose(1,2).reshape(outputreal.size())
            outputDL = F.linear(inputQ, torch.ones_like(weightQ), None).transpose(1,2).reshape(outputreal.size())
            
            outputP = (outputM * (2.**self.bitWeightLSB) + outputL) * weightS
            outputD = outputDL * weight2D.min()
            out = inputS * (outputP + outputD)
            if self.bias is not None:
                out = out + self.bias
        else:
            out = outputreal
        return out

In [2]:
abits= 8
wbits= 8
inference= 3
subarray= 100_000
adcbits= 6
zero_point_opt = True
LSB_bits = 4

x = torch.load("/app/base/base_input.pth")[0][:10]
w = torch.load("/app/base/base_weight.pth")
oc, ic, k, _ = w.size()
s, p, d, g = 1, 0, 1, 1

# test(x, w, abits, wbits, LSB_bits)
print(f'subarray=original', '=='*10)
L = nn.Conv2d(ic, oc, k, bias=False)
L.weight.data = w
origout = L(x)
outputconvsize = origout.size()
print(origout.flatten()[:10])
print('=='*20)

qx, qxs, xs = LinearQuantizeW(x, abits, x.max(), x.min())
# print('qx: ', torch.unique(qx), torch.unique(qxs, return_counts=True))
qw, qws, ws = LinearQuantizeW(w, wbits, w.max(), w.min())
# print('qw: ', torch.unique(qw), torch.unique(qws, return_counts=True))
L = QConv2dCAM(ic, oc, k, s, p, d, g, False, abits, [wbits-LSB_bits, LSB_bits], inference, subarray, adcbits, zero_point_opt)
# 2D mapping
iqx = L.input_2d_mapping(qx)
iqw = L.weight_2d_mapping(qw)

def linear(x, w, outputconvsize):
    return F.linear(x, w, None).transpose(1,2).reshape(outputconvsize)

print(f'subarray=quant MSB', '-'*10)
msb_iqw = iqw // (2.**LSB_bits)
msb_out = linear(iqx, msb_iqw, outputconvsize)
msb_out_scale = msb_out * ws
print(msb_out_scale.flatten()[:10])

print(f'subarray=quant LSB', '-'*10)
lsb_iqw = iqw % (2.**LSB_bits)
lsb_out = linear(iqx, lsb_iqw, outputconvsize)
lsb_out_scale = lsb_out * ws
print(lsb_out_scale.flatten()[:10])

print(f'subarray=quant zero_point', '-'*10)
lsb_dum_out = linear(iqx, torch.ones_like(iqw), outputconvsize) * w.min()
print(lsb_dum_out.flatten()[:10])

out = xs * (msb_out_scale * (2.**LSB_bits) + lsb_out_scale + lsb_dum_out) 
print(out.flatten()[:10])
print(f'MSE Error: {round(abs(torch.sum((out - origout)**2).item() / origout.nelement()), 2)}')
print('=='*20)

print(f'subarray=qconv2dcam inference: 0', '-'*10)
inference = 0
layer = QConv2dCAM(ic, oc, k, 1, 0, 1, 1, False, abits, [wbits-LSB_bits,LSB_bits], inference, 100_000, adcbits)
layer.weight.data = w
out = layer(x)
print(out.flatten()[:10])
print('=='*20)

tensor([-0.3429, -0.3891, -0.4199, -0.3860, -0.3704, -0.4122, -0.4087, -0.4121,
        -0.3671, -0.3403], device='cuda:0', grad_fn=<SliceBackward0>)
subarray=quant MSB ----------
tensor([67.0770, 70.0010, 70.0984, 69.7911, 69.7089, 69.3103, 69.5355, 69.6085,
        69.7911, 68.3610], device='cuda:0', grad_fn=<SliceBackward0>)
subarray=quant LSB ----------
tensor([59.9326, 64.0829, 64.3324, 64.3020, 64.6793, 64.3507, 64.5728, 64.7767,
        65.5130, 63.6630], device='cuda:0', grad_fn=<SliceBackward0>)
subarray=quant zero_point ----------
tensor([-1164.2705, -1219.4513, -1224.0117, -1215.8030, -1213.5228, -1210.7866,
        -1214.4349, -1215.8030, -1215.3469, -1188.4407], device='cuda:0',
       grad_fn=<SliceBackward0>)
tensor([-0.3453, -0.3925, -0.4230, -0.3868, -0.3719, -0.4160, -0.4140, -0.4140,
        -0.3683, -0.3442], device='cuda:0', grad_fn=<SliceBackward0>)
MSE Error: 0.0
subarray=qconv2dcam inference: 0 ----------
tensor([-0.3429, -0.3891, -0.4199, -0.3860, -0.3704, -0.4

In [3]:
print(f'subarray=qconv2dcam inference: 3', '-'*10)
L = QConv2dCAM(ic, oc, k, s, p, d, g, False, abits, [wbits-LSB_bits, LSB_bits], inference, subarray, adcbits, zero_point_opt)
L.weight.data = w
out = L(x)
print(out.flatten()[:10])
print(f'MSE Error: {round(abs(torch.sum((out - origout)**2).item() / origout.nelement()), 2)}')
print('=='*20)

subarray=qconv2dcam inference: 3 ----------
tensor([-0.3429, -0.3891, -0.4199, -0.3860, -0.3704, -0.4122, -0.4087, -0.4121,
        -0.3671, -0.3403], device='cuda:0', grad_fn=<SliceBackward0>)
MSE Error: 0.0
