Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
74 changes: 52 additions & 22 deletions caffe2/quantization/server/activation_distribution_observer.cc
Original file line number Diff line number Diff line change
Expand Up @@ -151,11 +151,13 @@ void OutputMinMaxObserver::Stop() {
OutputMinMaxNetObserver::OutputMinMaxNetObserver(
NetBase* subject,
const string& out_file_name,
int dump_freq)
int dump_freq,
string delimiter)
: NetObserver(subject),
dump_freq_(dump_freq),
cnt_(0),
out_file_name_(out_file_name) {
out_file_name_(out_file_name),
delimiter_(delimiter) {
VLOG(2) << out_file_name;
min_max_infos_.resize(subject->GetOperators().size());
int i = 0;
Expand Down Expand Up @@ -184,15 +186,15 @@ void OutputMinMaxNetObserver::DumpAndReset_(
op_info->tensor_infos[i];

ostringstream ost;
ost << op_index << " " << op_info->type << " " << i << " "
<< tensor_info.name << " ";
ost << op_index << delimiter_ << op_info->type << delimiter_ << i
<< delimiter_ << tensor_info.name << delimiter_;
if (print_total_min_max) {
ost << tensor_info.total_min << " " << tensor_info.total_max;
ost << tensor_info.total_min << delimiter_ << tensor_info.total_max;
} else {
ost << tensor_info.min << " " << tensor_info.max;
ost << tensor_info.min << delimiter_ << tensor_info.max;
}

LOG(INFO) << this << " " << ost.str();
LOG(INFO) << this << delimiter_ << ost.str();
f << ost.str() << endl;

op_info->tensor_infos[i].min = numeric_limits<float>::max();
Expand Down Expand Up @@ -385,12 +387,32 @@ HistogramNetObserver::HistogramNetObserver(
const string& out_file_name,
int nbins,
int dump_freq,
bool mul_nets)
bool mul_nets,
string op_filter,
string delimiter)
: NetObserver(subject),
dump_freq_(dump_freq),
cnt_(0),
mul_nets_(mul_nets),
op_filter_(op_filter),
delimiter_(delimiter),
out_file_name_(out_file_name) {
net_name_ = subject->Name();
if (op_filter != "") {
bool has_op = false;
for (auto* op : subject->GetOperators()) {
if (op->debug_def().type() == op_filter) {
has_op = true;
break;
}
}
if (!has_op) {
LOG(INFO) << "Net " << net_name_ << " doesn't include operator "
<< op_filter;
return;
}
}

hist_infos_.resize(subject->GetOperators().size());

int i = 0;
Expand All @@ -414,8 +436,12 @@ HistogramNetObserver::HistogramNetObserver(
void HistogramNetObserver::DumpAndReset_(
const string& out_file_name,
bool print_total_min_max) {
if (hist_infos_.size() == 0) {
return;
}
stringstream file_name;
file_name << out_file_name;
LOG(INFO) << "Dumping histograms of net " << net_name_ << " in " << this;
if (mul_nets_) {
file_name << ".";
file_name << this;
Expand Down Expand Up @@ -447,16 +473,17 @@ void HistogramNetObserver::DumpAndReset_(
}

ostringstream ost;
ost << op_index << " " << info->min_max_info.type << " " << i << " "
<< info->min_max_info.tensor_infos[i].name << " " << hist->Min()
<< " " << hist->Max() << " " << hist->GetHistogram()->size();
ost << op_index << delimiter_ << info->min_max_info.type << delimiter_
<< i << delimiter_ << info->min_max_info.tensor_infos[i].name
<< delimiter_ << hist->Min() << delimiter_ << hist->Max()
<< delimiter_ << hist->GetHistogram()->size();

for (uint64_t c : *hist->GetHistogram()) {
ost << " " << c;
ost << delimiter_ << c;
}

if (print_total_min_max) {
LOG(INFO) << this << " " << ost.str();
LOG(INFO) << this << delimiter_ << ost.str();
}

f << ost.str() << endl;
Expand All @@ -466,11 +493,12 @@ void HistogramNetObserver::DumpAndReset_(
}
}
}
f.flush();
f.close();
}

HistogramNetObserver::~HistogramNetObserver() {
DumpAndReset_(out_file_name_, true);
DumpAndReset_(out_file_name_, false);
}

void HistogramNetObserver::Stop() {
Expand Down Expand Up @@ -512,12 +540,14 @@ OutputColumnMaxHistogramNetObserver::OutputColumnMaxHistogramNetObserver(
const std::vector<std::string>& observe_column_max_for_blobs,
int nbins,
int dump_freq,
bool mul_nets)
bool mul_nets,
string delimiter)
: NetObserver(subject),
dump_freq_(dump_freq),
cnt_(0),
mul_nets_(mul_nets),
out_file_name_(out_file_name) {
out_file_name_(out_file_name),
delimiter_(delimiter) {
if (observe_column_max_for_blobs.size() == 0) {
return;
}
Expand Down Expand Up @@ -591,17 +621,17 @@ void OutputColumnMaxHistogramNetObserver::DumpAndReset_(
}
ostringstream ost;
// op_idx, output_idx, blob_name, col, min, max, nbins
ost << it.first << " " << output_idx << " "
<< info->min_max_info.tensor_infos[i].name << " " << i << " "
<< hist->Min() << " " << hist->Max() << " "
<< hist->GetHistogram()->size();
ost << it.first << delimiter_ << output_idx << delimiter_
<< info->min_max_info.tensor_infos[i].name << delimiter_ << i
<< delimiter_ << hist->Min() << delimiter_ << hist->Max()
<< delimiter_ << hist->GetHistogram()->size();

// bins
for (uint64_t c : *hist->GetHistogram()) {
ost << " " << c;
ost << delimiter_ << c;
}
if (print_total_min_max) {
LOG(INFO) << this << " " << ost.str();
LOG(INFO) << this << delimiter_ << ost.str();
}
f << ost.str() << endl;
if (!print_total_min_max) {
Expand Down
15 changes: 12 additions & 3 deletions caffe2/quantization/server/activation_distribution_observer.h
Original file line number Diff line number Diff line change
Expand Up @@ -63,7 +63,8 @@ class OutputMinMaxNetObserver final : public NetObserver {
explicit OutputMinMaxNetObserver(
NetBase* subject,
const std::string& out_file_name,
int dump_freq = -1);
int dump_freq = -1,
string delimiter = " ");
~OutputMinMaxNetObserver();

private:
Expand All @@ -74,6 +75,7 @@ class OutputMinMaxNetObserver final : public NetObserver {

int dump_freq_, cnt_;
const std::string out_file_name_;
std::string delimiter_;
std::vector<std::shared_ptr<OutputMinMaxObserver::OperatorInfo>>
min_max_infos_;
};
Expand Down Expand Up @@ -135,7 +137,9 @@ class HistogramNetObserver final : public NetObserver {
const std::string& out_file_name,
int nbins,
int dump_freq = -1,
bool mul_nets = false);
bool mul_nets = false,
string op_filter = "",
string delimiter = " ");
~HistogramNetObserver();

private:
Expand All @@ -150,6 +154,9 @@ class HistogramNetObserver final : public NetObserver {
* files for the nets will be appended with netbase addresses.
*/
bool mul_nets_;
string net_name_;
string op_filter_;
string delimiter_;
const std::string out_file_name_;
std::vector<std::shared_ptr<HistogramObserver::Info>> hist_infos_;
};
Expand All @@ -162,7 +169,8 @@ class OutputColumnMaxHistogramNetObserver final : public NetObserver {
const std::vector<std::string>& observe_column_max_for_blobs,
int nbins,
int dump_freq = -1,
bool mul_nets = false);
bool mul_nets = false,
string delimiter = " ");
~OutputColumnMaxHistogramNetObserver();

private:
Expand All @@ -173,6 +181,7 @@ class OutputColumnMaxHistogramNetObserver final : public NetObserver {
int dump_freq_, cnt_;
bool mul_nets_;
const std::string out_file_name_;
std::string delimiter_;
std::unordered_set<std::string> col_max_blob_names_;

// {op_idx: {output_index: col_hists}}
Expand Down
49 changes: 34 additions & 15 deletions caffe2/quantization/server/pybind.cc
Original file line number Diff line number Diff line change
@@ -1,9 +1,9 @@
#include <pybind11/pybind11.h>
#include <pybind11/stl.h>
#include "activation_distribution_observer.h"
#include "caffe2/opt/custom/fakefp16_transform.h"
#include "caffe2_dnnlowp_utils.h"
#include "quantization_error_minimization.h"
#include "caffe2/opt/custom/fakefp16_transform.h"

namespace caffe2 {
namespace python {
Expand All @@ -20,35 +20,50 @@ PYBIND11_MODULE(dnnlowp_pybind11, m) {

m.def(
"ObserveMinMaxOfOutput",
[](const string& min_max_file_name, int dump_freq) {
[](const string& min_max_file_name, int dump_freq, string delimiter) {
AddGlobalNetObserverCreator(
[dump_freq, min_max_file_name](NetBase* net) {
[dump_freq, min_max_file_name, delimiter](NetBase* net) {
return make_unique<OutputMinMaxNetObserver>(
net, min_max_file_name, dump_freq);
net, min_max_file_name, dump_freq, delimiter);
});
},
pybind11::arg("min_max_file_name"),
pybind11::arg("dump_freq") = -1);
pybind11::arg("dump_freq") = -1,
pybind11::arg("delimiter") = " ");

m.def(
"ObserveHistogramOfOutput",
[](const string& out_file_name, int dump_freq, bool mul_nets) {
[](const string& out_file_name,
int dump_freq,
bool mul_nets,
string op_filter,
string delimiter) {
AddGlobalNetObserverCreator(
[out_file_name, dump_freq, mul_nets](NetBase* net) {
[out_file_name, dump_freq, mul_nets, op_filter, delimiter](
NetBase* net) {
return make_unique<HistogramNetObserver>(
net, out_file_name, 2048, dump_freq, mul_nets);
net,
out_file_name,
2048,
dump_freq,
mul_nets,
op_filter,
delimiter);
});
},
pybind11::arg("out_file_name"),
pybind11::arg("dump_freq") = -1,
pybind11::arg("mul_nets") = false);
pybind11::arg("mul_nets") = false,
pybind11::arg("op_filter") = "",
pybind11::arg("delimiter") = " ");

m.def(
"AddHistogramObserver",
[](const string& net_name,
const string& out_file_name,
int dump_freq,
bool mul_nets) {
bool mul_nets,
string delimiter) {
Workspace* gWorkspace = caffe2::python::GetCurrentWorkspace();
CAFFE_ENFORCE(gWorkspace);
CAFFE_ENFORCE(
Expand All @@ -59,23 +74,25 @@ PYBIND11_MODULE(dnnlowp_pybind11, m) {
const Observable<NetBase>::Observer* observer = nullptr;

observer = net->AttachObserver(make_unique<HistogramNetObserver>(
net, out_file_name, 2048, dump_freq, mul_nets));
net, out_file_name, 2048, dump_freq, mul_nets, delimiter));

CAFFE_ENFORCE(observer != nullptr);
return pybind11::cast(observer);
},
pybind11::arg("net_name"),
pybind11::arg("out_file_name"),
pybind11::arg("dump_freq") = -1,
pybind11::arg("mul_nets") = false);
pybind11::arg("mul_nets") = false,
pybind11::arg("delimiter") = " ");

m.def(
"AddOutputColumnMaxHistogramObserver",
[](const string& net_name,
const string& out_file_name,
const std::vector<std::string>& observe_column_max_for_blobs,
int dump_freq,
bool mul_nets) {
bool mul_nets,
string delimiter) {
Workspace* gWorkspace = caffe2::python::GetCurrentWorkspace();
CAFFE_ENFORCE(gWorkspace);
CAFFE_ENFORCE(
Expand All @@ -92,7 +109,8 @@ PYBIND11_MODULE(dnnlowp_pybind11, m) {
observe_column_max_for_blobs,
2048,
dump_freq,
mul_nets));
mul_nets,
delimiter));

CAFFE_ENFORCE(observer != nullptr);
return pybind11::cast(observer);
Expand All @@ -101,7 +119,8 @@ PYBIND11_MODULE(dnnlowp_pybind11, m) {
pybind11::arg("out_file_name"),
pybind11::arg("observe_column_max_for_blobs"),
pybind11::arg("dump_freq") = -1,
pybind11::arg("mul_nets") = false);
pybind11::arg("mul_nets") = false,
pybind11::arg("delimiter") = " ");

m.def(
"ChooseQuantizationParams",
Expand Down