Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Make sure calibrator doesn't miss last batch #20234

Merged
merged 12 commits into from
Jul 2, 2018
Merged
2 changes: 1 addition & 1 deletion tensorflow/contrib/tensorrt/convert/convert_graph.cc
Original file line number Diff line number Diff line change
Expand Up @@ -152,7 +152,7 @@ tensorflow::Status ConvertCalibGraphToInferGraph(
"Need to run graph with calibration data first!");
}
if (cres->calibrator_) {
cres->calibrator_->setDone();
cres->calibrator_->waitAndSetDone();
cres->thr_->join();
const auto& calibration_table =
cres->calibrator_->getCalibrationTableAsString();
Expand Down
37 changes: 25 additions & 12 deletions tensorflow/contrib/tensorrt/resources/trt_int8_calibrator.cc
Original file line number Diff line number Diff line change
Expand Up @@ -36,27 +36,29 @@ TRTInt8Calibrator::TRTInt8Calibrator(
: batch_size_(batch_size),
done_(false),
dev_buffers_(dev_buffers),
// Make sure setBatch() waits until getBatch() is called (the first time).
calib_running_(true),
batch_is_set_(false),
engine_name_(engine_name) {}

// Builds a calibrator from already-available calibration data:
// calibration_table_ is initialized from calib_data, the batch size is 0, and
// the object starts with no calibration running and no batch pending.
TRTInt8Calibrator::TRTInt8Calibrator(const string& calib_data)
: batch_size_(0),
// NOTE(review): done_ appears twice in this initializer list (false, then
// true). This looks like a removed line and its replacement shown together;
// presumably only done_(true) is meant to remain — confirm against the file.
done_(false),
done_(true),
calib_running_(false),
batch_is_set_(false),
calibration_table_(calib_data) {}

bool TRTInt8Calibrator::setBatch(const std::unordered_map<string, void*>& data,
const cudaStream_t stream) {
tensorflow::mutex_lock lock(cond_mtx_);
// wait while calibration is running.
while ((calib_running_ || batch_is_set_) && !done_) {
cond_.wait(lock);
}

// Wait while the queue is full or calibration is running.
while ((calib_running_ || batch_is_set_) && !done_) cond_.wait(lock);
if (done_) return false;
CHECK(!calib_running_ && !batch_is_set_);
VLOG(1) << "Set Batch Waiting finished";

// Sets the batch.
for (const auto it : data) {
auto devptr = dev_buffers_.find(it.first);
if (devptr == dev_buffers_.end()) {
Expand All @@ -76,8 +78,8 @@ bool TRTInt8Calibrator::setBatch(const std::unordered_map<string, void*>& data,
}

// TODO(Sami, aaorey): Find an alternative way!
cudaStreamSynchronize(
stream); // we have to wait for the stream before returning!
// we have to wait for the stream before returning!
cudaStreamSynchronize(stream);
batch_is_set_ = true;
cond_.notify_all();
return true;
Expand All @@ -86,28 +88,39 @@ bool TRTInt8Calibrator::setBatch(const std::unordered_map<string, void*>& data,
// Hands the currently stored batch to the caller by filling `bindings`.
//
// bindings:     output array; entry i receives the pointer stored in
//               dev_buffers_ for names[i].
// names:        tensor names to look up, one per binding.
// num_bindings: number of entries in `names` / `bindings`.
//
// Returns true after the bindings have been filled in, or false if done_ was
// set while waiting (no batch is delivered in that case).
bool TRTInt8Calibrator::getBatch(void** bindings, const char** names,
int num_bindings) {
// All state below (calib_running_, batch_is_set_, done_) is accessed while
// holding cond_mtx_.
tensorflow::mutex_lock lock(cond_mtx_);
// Notify finish of last round of calibration.
// Clearing calib_running_ and notifying wakes threads that wait on
// (calib_running_ || batch_is_set_), such as waitAndSetDone().
calib_running_ = false;
cond_.notify_all();
// wait until new batch arrives
// NOTE(review): this loop and the single-line loop further down test the
// same condition. They look like the removed and added forms of one
// statement shown together — confirm which one the real file keeps.
while ((!batch_is_set_ && !done_)) {
cond_.wait(lock);
}

// Wait until new batch arrives
while ((!batch_is_set_ && !done_)) cond_.wait(lock);
// done_ takes priority: once it is set, report that no batch is available.
if (done_) return false;

// Gets the batch
for (int i = 0; i < num_bindings; i++) {
auto it = dev_buffers_.find(names[i]);
if (it == dev_buffers_.end()) {
// A name that is not in dev_buffers_ is treated as a fatal error.
LOG(FATAL) << "Calibration engine asked for unknown tensor name '"
<< names[i] << "' at position " << i;
}

// dev_buffers_ maps a name to a pair; the first element is the pointer
// handed back to the caller.
bindings[i] = it->second.first;
}
// Mark the batch as consumed and calibration as in progress; calib_running_
// is cleared again at the top of the next getBatch() call.
batch_is_set_ = false;
calib_running_ = true;
return true;
}

// Blocks until no batch is pending and no calibration round is running (or
// until done_ has already been set), then marks the calibrator as done and
// wakes every thread waiting on cond_. Does nothing further if done_ was
// already set.
void TRTInt8Calibrator::waitAndSetDone() {
  tensorflow::mutex_lock lock(cond_mtx_);

  // Keep waiting while a batch is still queued or being processed, so the
  // final batch is not dropped by finishing too early.
  while (!done_ && (batch_is_set_ || calib_running_)) {
    cond_.wait(lock);
  }

  // Someone else already finished the calibrator; nothing left to signal.
  if (done_) return;

  done_ = true;
  cond_.notify_all();
}

const void* TRTInt8Calibrator::readCalibrationCache(std::size_t& length) {
if (calibration_table_.empty()) return nullptr;
length = calibration_table_.size();
Expand Down
8 changes: 8 additions & 0 deletions tensorflow/contrib/tensorrt/resources/trt_int8_calibrator.h
Original file line number Diff line number Diff line change
Expand Up @@ -36,10 +36,13 @@ namespace tensorrt {

struct TRTInt8Calibrator : public nvinfer1::IInt8EntropyCalibrator {
public:
// Construct a calibrator for future calibration.
TRTInt8Calibrator(
const std::unordered_map<string, std::pair<void*, size_t>>& dev_buffers,
int batch_size, string engine_name);

// Construct a finalized calibrator where we don't need to run calibration any
// more, as the calibration data is provided.
TRTInt8Calibrator(const string& calibration_data);

~TRTInt8Calibrator();
Expand All @@ -52,6 +55,11 @@ struct TRTInt8Calibrator : public nvinfer1::IInt8EntropyCalibrator {
bool setBatch(const std::unordered_map<string, void*>& data,
const cudaStream_t stream);

// Wait until the last batch is consumed by the calibrator and set done.
void waitAndSetDone();

// Notify that calibration is done and future batches provided by setBatch()
// will be ignored.
void setDone();

// If not null, calibration is skipped.
Expand Down