Skip to content

Commit

Permalink
Make a function to be const member function in partitioner if possibl…
Browse files Browse the repository at this point in the history
…e. (#3150)

Summary:
This PR is created from the following discussion.
#3142 (comment)

Basically,  in Partitioner, some functions were not const member function (although it won't modify any value in this class), and they can't be used in const class objects.  This PR marks those functions as const.

Documentation:

[Optional Fixes #issue]
Pull Request resolved: #3150

Test Plan: Please see a detailed explanation of how to fill out the fields in the relevant sections in PULL_REQUEST.md.

Differential Revision: D15926744

Pulled By: beicy

fbshipit-source-id: 1cf5e9a69719efa586f86b68c185d2375c3893b7
  • Loading branch information
Man Wang authored and facebook-github-bot committed Jun 20, 2019
1 parent 0ff81fa commit cf68258
Show file tree
Hide file tree
Showing 2 changed files with 25 additions and 13 deletions.
21 changes: 15 additions & 6 deletions include/glow/Partitioner/Partitioner.h
Original file line number Diff line number Diff line change
Expand Up @@ -86,7 +86,7 @@ class NodeToFunctionMap {
functionToBackendName_[F] = backendName;
}

std::string getPartitionBackendName(Function *F) {
std::string getPartitionBackendName(Function *F) const {
DCHECK(functionToBackendName_.find(F) != functionToBackendName_.end())
<< "Unknown partition in Function: " << F->getName().str();
return functionToBackendName_.find(F)->second;
Expand All @@ -99,8 +99,11 @@ class NodeToFunctionMap {
const FunctionList &getPartitions() const { return functions_; }

/// Get the list of logical device ID related to this function \p F.
std::vector<DeviceIDTy> getLogicalDeviceIDList(Function *F) {
return logicalDeviceIDMap_[F];
const std::vector<DeviceIDTy> getLogicalDeviceIDList(Function *F) const {
if (logicalDeviceIDMap_.find(F) == logicalDeviceIDMap_.end()) {
return {};
}
return logicalDeviceIDMap_.at(F);
}

void appendLogicalDeviceID(Function *F, DeviceIDTy id) {
Expand Down Expand Up @@ -144,7 +147,12 @@ class NodeToFunctionMap {
}

/// Get the memory consumption for a partition \p func.
GraphMemInfo getGraphMemInfo(Function *func) { return partitionCost_[func]; }
GraphMemInfo getGraphMemInfo(Function *func) const {
if (partitionCost_.find(func) == partitionCost_.end()) {
return GraphMemInfo{};
}
return partitionCost_.find(func)->second;
}
};

/// Given a module, partitions each of the its functions into multiple ones
Expand Down Expand Up @@ -231,11 +239,12 @@ class Partitioner {
/// Check if \p partitions satisfies number of physical devices restriction.
/// I.e. check if the number of logical devices is less than the given
/// physical devices.
llvm::Error logicalDevicesValidation(NodeToFunctionMap &partitions);
llvm::Error
logicalDevicesValidation(const NodeToFunctionMap &partitions) const;

/// Check if the memory usage of each partition meets the physical device
/// memory restriction.
llvm::Error memoryUsageValidation(NodeToFunctionMap &partitions);
llvm::Error memoryUsageValidation(const NodeToFunctionMap &partitions) const;

/// Duplicates all networks in the module order to saturate the Host.
void saturateHost(unsigned logicalDeviceCount);
Expand Down
17 changes: 10 additions & 7 deletions lib/Partitioner/Partitioner.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -105,8 +105,8 @@ void Partitioner::dumpDAG(llvm::StringRef dotFilename) const {
return;
}

llvm::Error
Partitioner::logicalDevicesValidation(NodeToFunctionMap &partitions) {
llvm::Error Partitioner::logicalDevicesValidation(
const NodeToFunctionMap &partitions) const {
std::map<std::string, std::set<DeviceIDTy>> partitionsNum;
for (auto &func : partitions.getPartitions()) {
auto backendName = partitions.getPartitionBackendName(func);
Expand All @@ -117,27 +117,30 @@ Partitioner::logicalDevicesValidation(NodeToFunctionMap &partitions) {
for (size_t i = 0, e = logicalIDList.size(); i < e; i++) {
partitionsNum[backendName].insert(logicalIDList[i]);
}
auto backendNum = backendMap_.at(backendName).num;
RETURN_ERR_IF_NOT(
partitionsNum[backendName].size() <= backendMap_[backendName].num,
partitionsNum[backendName].size() <= backendNum,
llvm::formatv("Partition failed: the number of given({0}) devices({1}) "
"is fewer than the required minimal partitions({2}).",
backendName, backendMap_[backendName].num,
backendName, backendNum,
partitionsNum[backendName].size())
.str());
}
return llvm::Error::success();
}

llvm::Error Partitioner::memoryUsageValidation(NodeToFunctionMap &partitions) {
llvm::Error
Partitioner::memoryUsageValidation(const NodeToFunctionMap &partitions) const {
for (auto &func : partitions.getPartitions()) {
auto backendName = partitions.getPartitionBackendName(func);
auto usedMemSize = partitions.getGraphMemInfo(func).getTotalMemSize();
auto availableMemSize = backendMap_.at(backendName).memSize;
RETURN_ERR_IF_NOT(
usedMemSize <= backendMap_[backendName].memSize,
usedMemSize <= availableMemSize,
llvm::formatv(
"Partition failed: the memory usage({0}) of one partition exceeds "
"the available memory({1}) of given devices({2}).",
usedMemSize, backendMap_[backendName].memSize, backendName)
usedMemSize, availableMemSize, backendName)
.str());
}
return llvm::Error::success();
Expand Down

0 comments on commit cf68258

Please sign in to comment.