-
Notifications
You must be signed in to change notification settings - Fork 27.8k
Expand file tree
/
Copy pathquantizer.py
More file actions
136 lines (116 loc) · 5.49 KB
/
quantizer.py
File metadata and controls
136 lines (116 loc) · 5.49 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
import torch
from torch import Tensor
import numpy as np
from torch.ao.quantization.experimental.apot_utils import float_to_apot, apot_to_float, quant_dequant_util
# class to store APoT quantizer and
# implement quantize and dequantize
class APoTQuantizer():
alpha: torch.Tensor
gamma: torch.Tensor
quantization_levels: torch.Tensor
level_indices: torch.Tensor
def __init__(
self,
alpha: torch.Tensor,
gamma: torch.Tensor,
quantization_levels: torch.Tensor,
level_indices: torch.Tensor) -> None:
self.alpha = alpha
self.gamma = gamma
self.quantization_levels = quantization_levels
self.level_indices = level_indices
r""" Quantizes fp Tensor to integer APoT representation.
Conversion is based on the qparams from a specified APoT non-uniform observer.
The approach follows the method outlined in the APoT paper: https://arxiv.org/pdf/1909.13144.pdf.
Args:
tensor2quantize: fp Tensor
Returns:
result: APoT Tensor representation of tensor2quantize
"""
def quantize(self, tensor2quantize: Tensor):
result = torch.tensor([])
# map float_to_apot over tensor2quantize elements
tensor2quantize = tensor2quantize.detach().apply_(lambda x: float_to_apot(x,
self.quantization_levels,
self.level_indices,
self.alpha))
# convert to APoT int representation for dtype
tensor2quantize = tensor2quantize.int()
from torch.ao.quantization.experimental.APoT_tensor import TensorAPoT
result = TensorAPoT(self, tensor2quantize)
return result
r""" Dequantizes integer Tensor to floating point (fp) representation
based on the calculated quantization levels from a specified APoT non-uniform observer.
The approach follows the method outlined in the APoT paper: https://arxiv.org/pdf/1909.13144.pdf.
Args:
tensor2quantize: fp Tensor
Returns:
result: fp reduced precision representation of input Tensor
"""
def dequantize(self, apot_tensor) -> Tensor:
orig_size = apot_tensor.data.size()
apot_tensor_data = apot_tensor.data.flatten()
print(apot_tensor_data)
# map apot_to_float over tensor2quantize elements
result_temp = np.empty(shape=apot_tensor_data.size())
for i in range(len(apot_tensor_data)):
new_ele = apot_to_float(apot_tensor_data[i], self.quantization_levels, self.level_indices)
result_temp[i] = new_ele
result = torch.from_numpy(result_temp).reshape(orig_size)
return result
r""" Returns result of quantize -> dequantize on a fp Tensor (reduced precision)
based on the calculated quantization levels from a specified APoT non-uniform observer.
The approach follows the method outlined in the APoT paper: https://arxiv.org/pdf/1909.13144.pdf.
Args:
apot_tensor: quantized APoT Tensor to dequantize
Returns:
result: fp representation of input Tensor
"""
def quant_dequant(self, tensor2quantize: Tensor) -> Tensor:
levels_lst = list(self.quantization_levels)
result = tensor2quantize.apply_(lambda x: quant_dequant_util(x, levels_lst))
return result
def q_apot_alpha(self) -> float:
raise NotImplementedError
def quantize_APoT(tensor2quantize: Tensor, alpha: Tensor, gamma: Tensor, quantization_levels: Tensor, level_indices: Tensor):
    r"""Global method to create a quantizer and call its quantize method.

    Args:
        tensor2quantize: fp Tensor to quantize
        alpha: Tensor qparam alpha (clipping level)
        gamma: Tensor qparam gamma (scale factor for quantization levels)
        quantization_levels: Tensor with fp quantization levels
        level_indices: Tensor with integer quantization level indices

    Returns:
        result: APoT Tensor representation of tensor2quantize
    """
    # a fresh quantizer is built from the given qparams on every call
    quantizer = APoTQuantizer(
        alpha=alpha,
        gamma=gamma,
        quantization_levels=quantization_levels,
        level_indices=level_indices)
    return quantizer.quantize(tensor2quantize)
def dequantize_APoT(apot_tensor) -> Tensor:
    r"""Global method to dequantize an APoT Tensor with its own quantizer.

    No new quantizer is created: the one stored on ``apot_tensor.quantizer``
    is used and its ``dequantize`` method is called on ``apot_tensor``.

    Args:
        apot_tensor: APoT Tensor to dequantize

    Returns:
        result: fp Tensor dequantized from apot_tensor
    """
    return apot_tensor.quantizer.dequantize(apot_tensor)
def quant_dequant_APoT(tensor2quantize: Tensor,
                       alpha: Tensor,
                       gamma: Tensor,
                       quantization_levels: Tensor,
                       level_indices: Tensor) -> Tensor:
    r"""Global method to create a quantizer and call its quant_dequant method.

    Args:
        tensor2quantize: fp Tensor to quantize
        alpha: Tensor qparam alpha (clipping level)
        gamma: Tensor qparam gamma (scale factor for quantization levels)
        quantization_levels: Tensor with fp quantization levels
        level_indices: Tensor with integer quantization level indices

    Returns:
        result: fp reduced precision Tensor from tensor2quantize
    """
    # a fresh quantizer is built from the given qparams on every call
    quantizer = APoTQuantizer(
        alpha=alpha,
        gamma=gamma,
        quantization_levels=quantization_levels,
        level_indices=level_indices)
    return quantizer.quant_dequant(tensor2quantize)