In [2]:
import tensorly as tl
from tensorly.decomposition import parafac, quantized_parafac
from tensorly.kruskal_tensor import kruskal_to_tensor, KruskalTensor
from tensorly.base import unfold
from tensorly.quantization import quantize_qint

import torch
# Run tensorly on the PyTorch backend so that tl.norm & co. operate directly
# on the torch tensors (and torch quantized dtypes) used throughout this notebook.
tl.set_backend('pytorch')

# Example 1. Tensor quantization
A quantization scheme can be either affine or symmetric.

The scale and zero_point values used for quantization are computed either per channel or per tensor (yielding vectors or scalars, respectively).

Thus, there are 4 quantization schemes:

    ``torch.per_tensor_affine``
    ``torch.per_tensor_symmetric``
    ``torch.per_channel_affine``
    ``torch.per_channel_symmetric``

##### Generate a random tensor

In [3]:
# Fix the torch RNG so the tensor, and every norm printed below, is the same
# on each Restart & Run All.
torch.manual_seed(0)

t = torch.randn(256, 256, 9)
print('||float_tensor|| = {}'.format(tl.norm(t)))

# Quantized integer type used by all Example 1 cells.
dtype = torch.qint8

||float_tensor|| = 767.585693359375


##### Per channel quantization

In [4]:
# For each per-channel scheme, quantize along every axis in turn and report
# how far the quantized tensor is from the float one.
per_channel_schemes = (torch.per_channel_affine, torch.per_channel_symmetric)
for qscheme in per_channel_schemes:
    header = "\nPer channel quantization, dtype: {}, qscheme: {}".format(dtype, qscheme)
    print(header)

    for dim in range(t.ndim):
        qt, scale, zero_point = quantize_qint(
            t, dtype, qscheme, dim=dim, return_scale_zeropoint=True
        )
        err_norm = tl.norm(t - qt)
        print('Per dim {}, ||float_tensor - quant_tensor|| = {}'.format(dim, err_norm))


Per channel quantization, dtype: torch.qint8, qscheme: torch.per_channel_affine
Per dim 0, ||float_tensor - quant_tensor|| = 7.277091026306152
Per dim 1, ||float_tensor - quant_tensor|| = 7.280121803283691
Per dim 2, ||float_tensor - quant_tensor|| = 7.608396053314209

Per channel quantization, dtype: torch.qint8, qscheme: torch.per_channel_symmetric
Per dim 0, ||float_tensor - quant_tensor|| = 7.277091026306152
Per dim 1, ||float_tensor - quant_tensor|| = 7.280121803283691
Per dim 2, ||float_tensor - quant_tensor|| = 7.608396053314209


##### Per tensor quantization

In [5]:
# Per-tensor schemes use a single (scalar) scale/zero_point for the whole
# tensor, so no channel axis is given. dim=None is passed explicitly, as in
# Example 2, so that this cell does not depend on a `dim` value left over from
# the per-channel loop.
for qscheme in [torch.per_tensor_affine, torch.per_tensor_symmetric]:
    print("\nPer tensor quantization, dtype: {}, qscheme: {}".format(dtype, qscheme))

    qt, scale, zero_point = quantize_qint(t,
                                          dtype,
                                          qscheme,
                                          dim=None,
                                          return_scale_zeropoint=True)
    print('Per tensor, ||float_tensor - quant_tensor|| = {}'.format(tl.norm(t - qt)))


Per tensor quantization, dtype: torch.qint8, qscheme: torch.per_tensor_affine
Per tensor, ||float_tensor - quant_tensor|| = 8.38266372680664

Per tensor quantization, dtype: torch.qint8, qscheme: torch.per_tensor_symmetric
Per tensor, ||float_tensor - quant_tensor|| = 8.38266372680664


# Example 2. Quantization of a tensor in Kruskal format
Two approaches are compared:

- a) quantization of the corresponding full tensor;
- b) quantization of the decomposition factors.

##### Generate tensor 

In [6]:
# Random rank-16 CP tensor with unit weights. The seed makes the factors,
# and hence all Example 2 outputs, reproducible.
torch.manual_seed(0)

rank = 16
shape = (64, 64, 9)

# one (size, rank) factor matrix per mode
factors = [torch.randn((size, rank)) for size in shape]
weights = torch.ones(rank)

# tensor in Kruskal format
krt = KruskalTensor((weights, factors))

# corresponding tensor in full format
t = kruskal_to_tensor(krt)

tnorm = tl.norm(t)
print('||float_factors||: {}'.format(tnorm))

||float_factors||: 737.1425170898438


##### Choose quantization scheme

In [7]:
# Quantization settings shared by parts (a) and (b) of Example 2.
dtype = torch.qint8

# Per-tensor quantization: one scale/zero_point, no channel axis.
qscheme = torch.per_tensor_affine
dim = None

# To quantize per channel along axis 0 instead, use:
# qscheme = torch.per_channel_affine
# dim = 0

##### a) Quantize the full tensor

In [8]:
# Quantize the full tensor directly. quantize_qint returns a float tensor
# here, since it is subtracted from `t` below.
t_quant = quantize_qint(t, dtype, qscheme, dim=dim)
full_quant_err = tl.norm(t - t_quant)
print('||float_factors - float_factors_quantized|| = {}'.format(full_quant_err))

||float_factors - float_factors_quantized|| = 10.051530838012695


##### b) Quantize several factors

In [9]:
# For k = 1..num_factors, quantize the first k factor matrices, keep the rest
# in float, and measure how far the resulting Kruskal tensor is from `t`.
num_factors = len(factors)
for num_quant_factors in range(1, num_factors + 1):
    qfactors = ([quantize_qint(factor, dtype, qscheme, dim=dim)
                 for factor in factors[:num_quant_factors]]
                + factors[num_quant_factors:])

    qkrt = KruskalTensor((weights, qfactors))
    qt = kruskal_to_tensor(qkrt)
    print('\n[{}/{}] factors are quantized'.format(num_quant_factors, num_factors))
    print('||quant_factors - float_factors|| = {}'.format(tl.norm(qt - t)))



[1/3] factors are quantized
||quant_factors - float_factors|| = 9.89563274383545

[2/3] factors are quantized
||quant_factors - float_factors|| = 13.896100044250488

[3/3] factors are quantized
||quant_factors - float_factors|| = 17.108083724975586


# Example 3. Quantized ALS
Compare the standard ALS algorithm for finding a CP decomposition with its quantized version, in which the approximated factor is quantized at the end of each ALS step.

##### Generate tensor

In [15]:
# Random rank-8 CP tensor with unit weights. The seed fixes the factors
# generated here. NOTE(review): the ALS runs below use init='random', which
# may draw from a separate RNG; this seed does not necessarily cover them.
torch.manual_seed(0)

rank = 8
shape = (128, 128, 9)

# one (size, rank) factor matrix per mode
factors = [torch.randn((size, rank)) for size in shape]
weights = torch.ones(rank)

# tensor in Kruskal format
krt = KruskalTensor((weights, factors))

# corresponding tensor in full format
t = kruskal_to_tensor(krt)

tnorm = tl.norm(t)
print('||float_factors||: {}'.format(tnorm))

||float_factors||: 992.4262084960938


##### Find an approximation using ALS

In [16]:
# Float baseline. No `qmodes` are passed, so presumably no factor is quantized
# (TODO confirm the default in quantized_parafac); the other arguments mirror
# a plain ALS run to tight tolerance.
normalize_factors = True
(factors_als, weights_als), _ = quantized_parafac(
    t,
    rank,
    n_iter_max=50000,
    init='random',
    tol=1e-8,
    svd=None,
    normalize_factors=normalize_factors,
)

In parafac original_tensor_norm = 992.4262084960938
parafac has stopped after iteration 63


In [17]:
# Reconstruction error of the float ALS solution.
t_als = kruskal_to_tensor(KruskalTensor((weights_als, factors_als)))
tl.norm(t - t_als)

tensor(0.0004)

In [18]:
# Component weights returned by the float ALS run.
weights_als

tensor([391.0013, 355.6662, 269.9358, 423.7779, 310.8251, 315.8976, 267.2900,
        430.1736])

##### Find an approximation using quantized ALS

In [19]:
# Quantization settings for the quantized ALS run.
dtype = torch.qint8

# Per-tensor quantization: a single scale/zero_point, no channel axis.
qscheme = torch.per_tensor_affine
dim = None

# For per-channel quantization along axis 0, use instead:
# qscheme = torch.per_channel_affine
# dim = 0

In [20]:
# Quantized ALS. The factors of the modes listed in `qmodes` are quantized.
# The returned scales/zero_points are None for mode 2, which is not listed.
# tol=None presumably disables early stopping, so all n_iter_max iterations
# run (TODO confirm).
normalize_factors = True
(factors_qals, weights_qals), _, scales, zero_points = quantized_parafac(
    t,
    rank,
    n_iter_max=1001,
    init='random',
    tol=None,
    svd=None,
    normalize_factors=normalize_factors,
    qmodes=[0, 1],
    quantize_every=2,  # presumably: quantize on every 2nd ALS step — TODO confirm
    qscheme=qscheme,
    dtype=dtype,
    dim=dim,
    return_scale_zeropoint=True,
)

In parafac original_tensor_norm = 992.4262084960938
iteration 500, diff_from_norm = 17.587434768676758, rel_rec_error = 0.017721654887901908
iteration 1000, diff_from_norm = 17.587434768676758, rel_rec_error = 0.017721654887901908


In [21]:
# Reconstruction error of the quantized ALS solution.
t_qals = kruskal_to_tensor(KruskalTensor((weights_qals, factors_qals)))
tl.norm(t - t_qals)

tensor(17.5874)

In [22]:
# Component weights of the quantized ALS solution. In the run shown, they
# appear to match `weights_als` up to a permutation of the components.
weights_qals

tensor([312.2208, 430.0728, 315.8293, 359.4211, 267.2516, 391.3284, 423.4779,
        271.4605])

In [23]:
# Per-mode quantization parameters returned by quantized ALS; entries for
# modes not in `qmodes` (here mode 2) are None.
scales, zero_points

((tensor(0.0021), tensor(0.0021), None),
 (tensor(0, dtype=torch.int32), tensor(0, dtype=torch.int32), None))

In [24]:
# First factor matrix of the quantized ALS solution, stored as float values
# that lie on the quantization grid.
factors_qals[0]

tensor([[-0.1367,  0.1035,  0.0601,  ...,  0.0352, -0.0518,  0.0911],
        [ 0.0911,  0.1139,  0.0083,  ..., -0.0393, -0.2340, -0.1470],
        [ 0.0166, -0.1346, -0.1077,  ..., -0.0476,  0.0145, -0.0041],
        ...,
        [-0.1035,  0.0021, -0.0456,  ...,  0.1346, -0.1160,  0.0704],
        [ 0.0725,  0.0414,  0.0393,  ...,  0.0766,  0.0311, -0.0683],
        [ 0.0456, -0.0663,  0.0725,  ...,  0.1077, -0.0228,  0.0704]])

In [25]:
# Recover the integer codes behind the first quantized factor. Under PyTorch's
# affine mapping, x_deq = (q - zero_point) * scale, so q = x_deq / scale + zero_point.
# Rounding removes float noise from the division. Adding zero_point keeps the
# result correct when the offset is non-zero (it is 0 in the run shown).
# NOTE(review): assumes a per-tensor scheme (scalar scale/zero_point); for
# per-channel scales the division would need to broadcast along `dim`.
torch.round(factors_qals[0] / scales[0]) + zero_points[0]

tensor([[ -66.,   50.,   29.,  ...,   17.,  -25.,   44.],
        [  44.,   55.,    4.,  ...,  -19., -113.,  -71.],
        [   8.,  -65.,  -52.,  ...,  -23.,    7.,   -2.],
        ...,
        [ -50.,    1.,  -22.,  ...,   65.,  -56.,   34.],
        [  35.,   20.,   19.,  ...,   37.,   15.,  -33.],
        [  22.,  -32.,   35.,  ...,   52.,  -11.,   34.]])

In [26]:
# Integer codes behind the second quantized factor: q = x_deq / scale + zero_point.
# Rounding removes the float noise visible in the plain division (e.g. -43.0000).
# NOTE(review): assumes a per-tensor scheme (scalar scale/zero_point).
torch.round(factors_qals[1] / scales[1]) + zero_points[1]

tensor([[ -43.0000,  -20.0000,  -18.0000,  ...,  -94.0000,   -2.0000,
            1.0000],
        [  94.0000,   19.0000,   -1.0000,  ...,   23.0000,   74.0000,
          -34.0000],
        [  -3.0000,    8.0000,  -23.0000,  ...,  -36.0000, -107.0000,
           37.0000],
        ...,
        [  59.0000,   38.0000,   10.0000,  ...,   14.0000,   -1.0000,
          -54.0000],
        [  -5.0000,   21.0000,   84.0000,  ...,   38.0000,  -10.0000,
           71.0000],
        [ -11.0000,  -23.0000,  108.0000,  ...,   -8.0000,  -16.0000,
           60.0000]])