Closed
Changes from all commits
Commits
18 commits
67ccbdb
[quant][graphmode][refactor] Add registerQParams function
jerryzh168 Nov 28, 2019
9392f76
Update on "[quant][graphmode][refactor] Add registerQParams function"
jerryzh168 Dec 2, 2019
9ab2431
Update on "[quant][graphmode][refactor] Add registerQParams function"
jerryzh168 Dec 2, 2019
5650fcf
Update on "[quant][graphmode][refactor] Add registerQParams function"
jerryzh168 Dec 2, 2019
e7601a9
Update on "[quant][graphmode][refactor] Add registerQParams function"
jerryzh168 Dec 2, 2019
a6293ac
Update on "[quant][graphmode][refactor] Add registerQParams function"
jerryzh168 Dec 3, 2019
60a8596
Update on "[quant][graphmode][refactor] Add registerQParams function"
jerryzh168 Dec 3, 2019
5c1b97e
Update on "[quant][graphmode][refactor] Add registerQParams function"
jerryzh168 Dec 3, 2019
0d71efe
Update on "[quant][graphmode][refactor] Add registerQParams function"
jerryzh168 Dec 4, 2019
7ca13bf
Update on "[quant][graphmode][refactor] Add registerQParams function"
jerryzh168 Dec 4, 2019
eee3dba
Update on "[quant][graphmode][refactor] Add registerQParams function"
jerryzh168 Dec 4, 2019
f8a40d7
Update on "[quant][graphmode][refactor] Add registerQParams function"
jerryzh168 Dec 4, 2019
7cbdc98
Update on "[quant][graphmode][refactor] Add registerQParams function"
jerryzh168 Dec 4, 2019
defbdec
Update on "[quant][graphmode][refactor] Add registerQParams function"
jerryzh168 Dec 4, 2019
cc7e7f6
Update on "[quant][graphmode][refactor] Add registerQParams function"
jerryzh168 Dec 5, 2019
ec957e2
Update on "[quant][graphmode][refactor] Add registerQParams function"
jerryzh168 Dec 6, 2019
091dffc
Update on "[quant][graphmode][refactor] Add registerQParams function"
jerryzh168 Dec 6, 2019
6a7f8e8
Update on "[quant][graphmode][refactor] Add registerQParams function"
jerryzh168 Dec 6, 2019
47 changes: 28 additions & 19 deletions torch/csrc/jit/passes/quantization.cpp
@@ -618,6 +618,9 @@ class InsertQuantDeQuantHelper {
      Value* child_instance);
  void collectObserverNodesAndValueToQuantize(script::Module& module, Value*);
  void removeObservers(script::Module& module, Graph* g);
  bool registerQParams(script::Module& module,
      const std::tuple<IValue, IValue>& qparams_and_scalar_type,
      const std::string& prefix);
  void quantizeTensors(script::Module& module, Graph* g, Value* self);

private:
@@ -679,31 +682,37 @@ void InsertQuantDeQuantHelper::removeObservers(script::Module& module, Graph* g)
  }
}

bool InsertQuantDeQuantHelper::registerQParams(
    script::Module& module,
    const std::tuple<IValue, IValue>& qparams_and_scalar_type,
    const std::string& prefix) {
  auto qparams = std::get<0>(qparams_and_scalar_type);
  auto scalar_type = std::get<1>(qparams_and_scalar_type);
  // Register attributes for quantization parameters
  auto tp = qparams.toTuple();
  at::Tensor scale = tp->elements()[0].toTensor().to(at::kFloat);
  at::Tensor zero_point = tp->elements()[1].toTensor().to(at::kInt);
  // TODO: get this info from qscheme
  bool is_per_channel = scale.numel() > 1;
  if (is_per_channel) {
    module.register_attribute(prefix + "_scale", TensorType::get(), scale);
    module.register_attribute(prefix + "_zero_point", TensorType::get(), zero_point);
    module.register_attribute(prefix + "_axis", IntType::get(), tp->elements()[2].toInt());
  } else {
    module.register_attribute(prefix + "_scale", FloatType::get(), scale.item<double>());
    module.register_attribute(prefix + "_zero_point", IntType::get(), zero_point.item<int64_t>());
  }
  module.register_attribute(prefix + "_scalar_type", IntType::get(), scalar_type);
  return is_per_channel;
}

void InsertQuantDeQuantHelper::quantizeTensors(script::Module& module, Graph* g, Value* self) {
  if (!values_to_quantize_.count(g)) {
    return;
  }
  for (auto& v : values_to_quantize_.at(g)) {
    TORCH_INTERNAL_ASSERT(values_to_qparams_.at(g).count(v));
    auto qparams_and_scalar_type = values_to_qparams_.at(g).at(v);
    auto qparams = std::get<0>(qparams_and_scalar_type);
    auto scalar_type = std::get<1>(qparams_and_scalar_type);
    // Register attributes for quantization parameters
    auto tp = qparams.toTuple();
    at::Tensor scale = tp->elements()[0].toTensor().to(at::kFloat);
    at::Tensor zero_point = tp->elements()[1].toTensor().to(at::kInt);
    // TODO: get this info from qscheme
    bool is_per_channel = scale.numel() > 1;
    std::string prefix = v->debugName();
    if (is_per_channel) {
      module.register_attribute(prefix + "_scale", TensorType::get(), scale);
      module.register_attribute(prefix + "_zero_point", TensorType::get(), zero_point);
      module.register_attribute(prefix + "_axis", IntType::get(), tp->elements()[2].toInt());
    } else {
      module.register_attribute(prefix + "_scale", FloatType::get(), scale.item<double>());
      module.register_attribute(prefix + "_zero_point", IntType::get(), zero_point.item<int64_t>());
    }
    module.register_attribute(prefix + "_scalar_type", IntType::get(), scalar_type);
    auto is_per_channel = registerQParams(module, qparams_and_scalar_type, v->debugName());
    insertQuantDeQuantCall(self, v, is_per_channel);
  }
}
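
Note, for context rather than as part of the diff: registerQParams keys every attribute on the quantized value's debug name, so a per-tensor value ends up with <name>_scale (float), <name>_zero_point (int), and <name>_scalar_type (int) on the module, while the per-channel case stores tensor-typed scale/zero_point plus <name>_axis. Below is a minimal sketch of how such attributes could be read back, assuming the script::Module::attr accessor and a hypothetical debug name "conv_out"; it is not taken from this PR.

// Sketch only: dumps the per-tensor attributes that registerQParams registers.
// The prefix "conv_out" is a hypothetical debug name, not from this PR.
#include <iostream>
#include <string>
#include <torch/script.h>

void dumpQParams(torch::jit::script::Module& module, const std::string& prefix) {
  // Per-tensor case: scale was registered as FloatType, the rest as IntType.
  double scale = module.attr(prefix + "_scale").toDouble();
  int64_t zero_point = module.attr(prefix + "_zero_point").toInt();
  int64_t scalar_type = module.attr(prefix + "_scalar_type").toInt();
  std::cout << prefix << ": scale=" << scale << ", zero_point=" << zero_point
            << ", scalar_type=" << scalar_type << std::endl;
}

// Usage: dumpQParams(quantized_module, "conv_out");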