Optimize zero length input (#31602)

Summary:
Pull Request resolved: pytorch/pytorch#31602

Pull Request resolved: #3943

Zero-length input is something we hit fairly frequently in practice. The previous handling went through the global TensorPool, which costs two lock acquisitions per input (one to acquire the tensor, one to reclaim it). Here we instead use a specialized anchor tensor to host zero-length inputs. Note that the anchor is only padded to the max sequence length; if necessary, an easy extension could pad it to the max `InputPlaceholder.getType().size()`.
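
To make the saving concrete, here is a minimal, self-contained sketch of the scheme. It is illustrative only: `bindInput`, the simplified `Tensor` struct, `scratch_`, and `poolMutex_` are hypothetical stand-ins for Glow's TensorPool machinery, not the actual API; the real change is in the diff below.

#include <cstddef>
#include <cstdint>
#include <mutex>
#include <vector>

// Minimal sketch of the anchor-tensor idea; every name below is a
// simplified stand-in, not Glow's actual class or API.
struct Tensor {
  const std::int64_t *data = nullptr; // unowned when aliasing the anchor
  std::size_t size = 0;
};

class Graph {
public:
  // Called once at graph-initialization time: allocate a single zero-filled
  // anchor buffer, padded to the maximum sequence length.
  void setZeroLengthSequence(std::size_t maxSeqLength) {
    zeroLengthSequence_.assign(maxSeqLength, 0);
  }

  // Called for each input of each inference request.
  Tensor bindInput(const std::int64_t *buffer, std::size_t numElements) {
    if (buffer == nullptr && numElements <= zeroLengthSequence_.size()) {
      // Zero-length input: alias the shared anchor. No lock is taken, and
      // nothing has to be reclaimed when the request finishes.
      return Tensor{zeroLengthSequence_.data(), numElements};
    }
    // Regular input: the pool locks here to hand out a tensor, and locks
    // again later to reclaim it (the two locks the summary refers to).
    std::lock_guard<std::mutex> guard(poolMutex_);
    scratch_.emplace_back(numElements, 0);
    return Tensor{scratch_.back().data(), numElements};
  }

private:
  std::vector<std::int64_t> zeroLengthSequence_; // shared, always all zeros
  std::vector<std::vector<std::int64_t>> scratch_; // stands in for TensorPool
  std::mutex poolMutex_;
};

Binding a zero-length input is then just a pointer copy into the bindings map; the only correctness requirement is that the anchor be at least as large as the placeholder's type, which is exactly the `size() <=` guard added to `setIOAndRun` in the diff below.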

Reviewed By: jfix71

Differential Revision: D19192467

fbshipit-source-id: cafdc1eb7bf9b9d6ead04a0243b0be838f6b71cd
yinghai authored and facebook-github-bot committed Dec 27, 2019
1 parent 2459e30 commit 34eced91e556e283ea70b15c05b72a469d3838fc
@@ -135,6 +135,12 @@ std::pair<bool, onnxStatus> Event::waitFor(size_t timeoutMs) {
   return {/*signalled*/ true, status_};
 }
 
+void Graph::setZeroLengthSequence(dim_t maxSeqLength) {
+  Type ty(ElemKind::Int64ITy, {maxSeqLength});
+  zeroLengthSequence_.reset(ty);
+  zeroLengthSequence_.zero();
+}
+
 onnxStatus Graph::setIOAndRun(uint32_t inputsCount,
                               const onnxTensorDescriptorV1 *inputDescriptors,
                               uint32_t outputsCount,
@@ -201,6 +207,11 @@ onnxStatus Graph::setIOAndRun(uint32_t inputsCount,
       // remembers the actual size of the input.
       ctx->getPlaceholderBindings()->insert(
           inPhPtr, Tensor(inOnnxBuffer, inPhPtr->getType(), onnxBytes));
+    } else if (!inOnnxBuffer && inPhPtr->getType()->size() <=
+                                    zeroLengthSequence_.getType().size()) {
+      ctx->getPlaceholderBindings()->insert(
+          inPhPtr, Tensor((void *)(zeroLengthSequence_.getUnsafePtr()),
+                          inPhPtr->getType()));
     } else {
       Tensor *inputTensor = tensorPool_.get(inPhPtr->getType());
       if (!inputTensor) {
@@ -132,6 +132,7 @@ class Graph {
   virtual onnxStatus initGraph(const void *onnxModel, size_t onnxModelSize,
                                uint32_t weightCount,
                                const onnxTensorDescriptorV1 *weightDescriptors,
+                               uint32_t maxSeqLength,
                                void *deferedBlobReader) = 0;
 
   virtual onnxStatus run(std::unique_ptr<ExecutionContext> ctx,
@@ -161,6 +162,12 @@ class Graph {
   /// An object pool for tensors, to share allocations.
   TensorPool tensorPool_;
 
+  /// An anchor tensor specialized for zero length indices.
+  Tensor zeroLengthSequence_;
+
+  /// Set the zero length anchor tensor.
+  void setZeroLengthSequence(dim_t maxSeqLength);
+
 private:
   /// inference dump counter
   std::atomic<size_t> ioDumpCounter_{0};
@@ -162,9 +162,11 @@ onnxStatus HostManagerBackend::removeNetwork(const Graph *graph) {
   return ONNXIFI_STATUS_SUCCESS;
 }
 
-onnxStatus HostManagerGraph::initGraph(
-    const void *onnxModel, size_t onnxModelSize, uint32_t weightCount,
-    const onnxTensorDescriptorV1 *weightDescriptors, void *deferedBlobReader) {
+onnxStatus
+HostManagerGraph::initGraph(const void *onnxModel, size_t onnxModelSize,
+                            uint32_t weightCount,
+                            const onnxTensorDescriptorV1 *weightDescriptors,
+                            uint32_t maxSeqLength, void *deferedBlobReader) {
 
   netName_ = strFormat("onnxifi_function_%lu", makeUniqueGraphId());
 
@@ -180,6 +182,7 @@ onnxStatus HostManagerGraph::initGraph(
   onnxInputToPlaceholder_ = loader->getInputVarsMapping();
   onnxOutputToPlaceholder_ = loader->getOutputVarsMapping();
 
+  setZeroLengthSequence(maxSeqLength);
   // Make sure the pool is ready to go.
   for (auto &obj : onnxInputToPlaceholder_) {
     tensorPool_.reserve(obj.second->getType(), 10);
@@ -64,6 +64,7 @@ class HostManagerGraph : public Graph {
   onnxStatus initGraph(const void *onnxModel, size_t onnxModelSize,
                        uint32_t weightCount,
                        const onnxTensorDescriptorV1 *weightDescriptors,
+                       uint32_t maxSeqLengths,
                        void *deferedBlobReader) override;
 
   /// Async run HostManagerGraph with the given ExecutionContext \p ctx then
@@ -43,9 +43,11 @@ void computeModelHash(const void *onnxModel, size_t onnxModelSize,
 }
 } // namespace
 
-onnxStatus InlineGraph::initGraph(
-    const void *onnxModel, size_t onnxModelSize, uint32_t weightCount,
-    const onnxTensorDescriptorV1 *weightDescriptors, void *deferedBlobReader) {
+onnxStatus
+InlineGraph::initGraph(const void *onnxModel, size_t onnxModelSize,
+                       uint32_t weightCount,
+                       const onnxTensorDescriptorV1 *weightDescriptors,
+                       uint32_t maxSeqLength, void * /*unused*/) {
   function_ = executionEngine_.getModule().createFunction("function");
 
   std::unique_ptr<ONNXIFIModelLoader> loader =
@@ -59,6 +61,7 @@ onnxStatus InlineGraph::initGraph(
     saveOnnxifiModel(function_);
   }
 
+  setZeroLengthSequence(maxSeqLength);
   computeModelHash(onnxModel, onnxModelSize, modelHash_);
   optimize(function_, CompilationMode::Infer);
 
@@ -36,7 +36,7 @@ class InlineGraph : public Graph {
   onnxStatus initGraph(const void *onnxModel, size_t onnxModelSize,
                        uint32_t weightCount,
                        const onnxTensorDescriptorV1 *weightDescriptors,
-                       void *deferedBlobReader) override;
+                       uint32_t maxSeqLength, void *deferedBlobReader) override;
 
   onnxStatus run(std::unique_ptr<ExecutionContext> ctx, EventPtr outputEvent,
                  onnxTraceEventList *traceEvents) override;
@@ -400,7 +400,7 @@ GLOW_ONNXIFI_LIBRARY_FUNCTION_WRAPPER(onnxInitGraph)(
     onnxBackend backend, const uint64_t *auxPropertiesList,
     size_t onnxModelSize, const void *onnxModel, uint32_t weightsCount,
     const onnxTensorDescriptorV1 *weightDescriptors, onnxGraph *graph,
-    void *deferredBlobReader) {
+    uint32_t maxSeqLength, void *deferredBlobReader) {
   if (!onnxModel || (!weightDescriptors && weightsCount) || !graph) {
     return ONNXIFI_STATUS_INVALID_POINTER;
   }
@@ -425,8 +425,9 @@ GLOW_ONNXIFI_LIBRARY_FUNCTION_WRAPPER(onnxInitGraph)(
   }
 
   auto *glowGraph = manager.createGraph(glowBackend, quantizationMode);
-  auto ret = glowGraph->initGraph(onnxModel, onnxModelSize, weightsCount,
-                                  weightDescriptors, deferredBlobReader);
+  auto ret =
+      glowGraph->initGraph(onnxModel, onnxModelSize, weightsCount,
+                           weightDescriptors, maxSeqLength, deferredBlobReader);
   if (ret != ONNXIFI_STATUS_SUCCESS) {
     manager.release(glowGraph);
     return ret;
Submodule foxi updated 1 file
+2 −0 foxi/onnxifi.h
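
For context, the two lines added to foxi/onnxifi.h most plausibly thread the same parameter through the ONNXIFI entry point. A hypothetical sketch of the bumped declaration, mirroring the wrapper signature in the diff above rather than quoting the verbatim foxi change:

/* Hypothetical sketch only: assumes foxi's function-pointer typedef style,
 * with the new maxSeqLength parameter inserted before deferredBlobReader,
 * exactly as in the Glow wrapper above. Not the verbatim foxi diff. */
typedef ONNXIFI_CHECK_RESULT onnxStatus
    (ONNXIFI_ABI *onnxInitGraphFunction)(
        onnxBackend backend, const uint64_t *auxPropertiesList,
        size_t onnxModelSize, const void *onnxModel, uint32_t weightsCount,
        const onnxTensorDescriptorV1 *weightDescriptors, onnxGraph *graph,
        uint32_t maxSeqLength, void *deferredBlobReader);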
