[ONNX] Reduce exporter memory usage by removing intermediate values
This commit reduces the exporter's memory usage by as much as 50%.
During the shape inference step, the exporter caches the values of
intermediate tensors, which can consume as much memory as the model
itself, or even more. For example, model weight tensors are often
fed to a Transpose layer whose output is the same size as the
weights themselves. This commit fixes the issue by removing each
intermediate tensor value once every consumer of that tensor has
been processed.
ilyasher authored and pytorchmergebot committed May 30, 2023
1 parent b02f48b commit 35b26fd
Showing 3 changed files with 40 additions and 0 deletions.
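Before the diffs, a minimal standalone C++ sketch of the erase-after-last-consumer idea the commit implements. Everything here (the Value and Node structs, the cache map, removeProcessedInputs) is a hypothetical stand-in for the JIT IR and ConstantValueMap, not code from this commit:

// Standalone sketch: a cached value is dropped as soon as every node
// that consumes it has been processed. Hypothetical stand-ins only.
#include <algorithm>
#include <iostream>
#include <string>
#include <unordered_map>
#include <vector>

struct Node;

struct Value {
  std::string name;
  std::vector<const Node*> consumers;  // nodes that read this value
};

struct Node {
  std::vector<Value*> inputs;
  bool processed = false;  // stands in for "shape inference ran on this node"
};

// Cache of materialized tensors keyed by value name; the size_t payload
// stands in for a cached tensor's data.
std::unordered_map<std::string, size_t> cache;

// After a node is processed, drop any cached input whose consumers have
// all been processed -- no later node will ever read it again.
void removeProcessedInputs(const Node& n) {
  for (Value* input : n.inputs) {
    const bool noLongerNeeded = std::all_of(
        input->consumers.begin(), input->consumers.end(),
        [](const Node* consumer) { return consumer->processed; });
    if (cache.count(input->name) != 0 && noLongerNeeded) {
      cache.erase(input->name);
    }
  }
}

int main() {
  // One weight value consumed by two nodes, e.g. a Transpose and a MatMul.
  Value weight{"weight", {}};
  Node transpose{{&weight}, false};
  Node matmul{{&weight}, false};
  weight.consumers = {&transpose, &matmul};

  cache["weight"] = 64;  // pretend 64 MB of weight data is cached

  transpose.processed = true;
  removeProcessedInputs(transpose);
  std::cout << cache.size() << "\n";  // 1: matmul still needs the weight

  matmul.processed = true;
  removeProcessedInputs(matmul);
  std::cout << cache.size() << "\n";  // 0: last consumer done, cache freed
}

In the real pass, "processed" is detected by checking whether shape inference has already recorded a rank for a node's outputs, as the shape_type_inference.cpp diff below shows.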
4 changes: 4 additions & 0 deletions torch/csrc/jit/passes/onnx/constant_map.cpp
@@ -78,6 +78,10 @@ c10::optional<at::Tensor> ConstantValueMap::GetValue(
   return ConstantValueMap::getInstance().tensorValueMap[tensorName];
 }
 
+void ConstantValueMap::EraseValue(const std::string& tensorName) {
+  ConstantValueMap::getInstance().tensorValueMap.erase(tensorName);
+}
+
 std::vector<int64_t> ConstantValueMap::GetCompleteShapeInto1DInt64Vector(
     const c10::SymbolicShape& shape) {
   TORCH_INTERNAL_ASSERT(shape.isComplete());
1 change: 1 addition & 0 deletions torch/csrc/jit/passes/onnx/constant_map.h
@@ -36,6 +36,7 @@ class ConstantValueMap {
   static void SetValue(const std::string& tensorName, const at::Tensor& value);
   static bool HasValue(const std::string& tensorName);
   static c10::optional<at::Tensor> GetValue(const std::string& tensorName);
+  static void EraseValue(const std::string& tensorName);
 
   static std::vector<int64_t> GetCompleteShapeInto1DInt64Vector(
       const c10::SymbolicShape& shape);
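With the declaration in place, EraseValue composes with the existing accessors. A hypothetical call sequence (the tensor name and surrounding code are illustrative only and assume the usual torch JIT headers; this snippet is not from the commit):

// Sketch only: cache an intermediate, let the last consumer read it,
// then free the cached copy.
at::Tensor weight = at::rand({1024, 1024});
at::Tensor transposed = weight.t();                           // an intermediate
ConstantValueMap::SetValue("onnx::Transpose_7", transposed);  // cache it
if (ConstantValueMap::HasValue("onnx::Transpose_7")) {
  auto value = ConstantValueMap::GetValue("onnx::Transpose_7");  // last read
}
ConstantValueMap::EraseValue("onnx::Transpose_7");            // then free it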
35 changes: 35 additions & 0 deletions torch/csrc/jit/passes/onnx/shape_type_inference.cpp
@@ -1821,6 +1821,40 @@ void FetchBlockInputMetadataFromParent(Block* b) {
   }
 }
 
+void RemoveProcessedInputs(const Node* n) {
+  // After processing a node for shape inference, remove intermediate tensors
+  // that are stored in ConstantValueMap to reduce memory usage.
+  // This will only remove tensors that are no longer needed by any other node.
+
+  // Returns whether a node was already processed for shape inference.
+  const auto isNodeProcessed = [](const Node* node) {
+    const auto& outputs = node->outputs();
+    return std::any_of(outputs.begin(), outputs.end(), [](const Value* output) {
+      // Assumes shape inference can at least determine the rank of the outputs.
+      // If this assumption is wrong, some intermediate tensors will only be
+      // deleted once shape inference is completed for the entire graph.
+      return ConstantValueMap::HasRank(output->debugName());
+    });
+  };
+
+  // An input value is no longer needed if all of its consumer nodes
+  // have already been processed.
+  const auto isValueNoLongerNeeded = [isNodeProcessed](const Value* input) {
+    const auto& uses = input->uses();
+    return std::all_of(
+        uses.begin(), uses.end(), [isNodeProcessed](const Use& use) {
+          return isNodeProcessed(use.user);
+        });
+  };
+
+  for (const auto* input : n->inputs()) {
+    if (ConstantValueMap::HasValue(input->debugName()) &&
+        isValueNoLongerNeeded(input)) {
+      ConstantValueMap::EraseValue(input->debugName());
+    }
+  }
+}
+
 void ONNXShapeTypeInference(
     Block* b,
     const ParamMap& params_dict,
@@ -1850,6 +1884,7 @@ void ONNXShapeTypeInference(
       ONNXShapeTypeInference(subblock, params_dict, opset_version);
     }
     ONNXShapeTypeInference(n, params_dict, opset_version);
+    RemoveProcessedInputs(n);
   }
 }
 

