Add data and model parallel transformation (#3753)

Summary:
Pull Request resolved: #3753

Add support for data and model parallel transformation.

Reviewed By: nrsatish

Differential Revision: D18171649

fbshipit-source-id: 459f8c170d0b44e1d1454305fded08bdbd8d7529
jfix71 authored and facebook-github-bot committed Nov 8, 2019
1 parent 297c8c9 commit 55bb797c4a1234100d34c04412927388564564f3
@@ -161,6 +161,16 @@ class Node : public Named,
  /// responsible to update these if need be.
  void setType(unsigned idx, TypeRef ty);

  /// Set the \p idx'th result type of the node. Unlike setType(), this does
  /// not check that the dims of the old type match the dims of the new one.
  /// \note This setter only changes the type of this one result. If that
  /// type is incompatible with the inputs of the node, the caller is
  /// responsible to update these if need be.
  void setTypeUnsafe(unsigned idx, TypeRef ty);

  /// Methods that forward to the result type (that must be valid):
  /// @{
  ElemKind getElementType(unsigned resNo) const;
@@ -86,6 +86,8 @@ struct NodeValue {
  TypeRef getType() const;
  /// Set the type of the referenced value.
  void setType(TypeRef ty);
  /// Set the type of the referenced value. Does not check that dims() match.
  void setTypeUnsafe(TypeRef ty);

  /// Methods that forward to the result type (that must be valid):
  /// @{
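The new setter is what lets the parallelization pass below legally shrink a result's shape in place. A minimal sketch of the intended use, assuming a Function *F and a hypothetical someNode whose first result has dims {8, 16}:

  // setType() would assert here because the dims change; setTypeUnsafe()
  // permits it, leaving the caller responsible for re-slicing the node's
  // inputs so the graph becomes consistent again.
  NodeValue res = someNode->getNthResult(0); // hypothetical node, dims {8, 16}
  TypeRef halved =
      F->getParent()->uniqueTypeWithNewShape(res.getType(), {4, 16});
  res.setTypeUnsafe(halved); // dims intentionally differ from the old type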
@@ -95,6 +95,19 @@ Error executeConstantFunction(Backend &backend, Function &F,
bool executeVerticalFCWeightsSplit(Function *F, unsigned numOfChunks,
                                   unsigned minKToSplit);

/// Represents what kind of parallelization transformation should be performed
/// by \ref parallelizeOps().
enum class ParallelTransformKind { None, Data, Model };

/// Perform a data or model parallel transformation of supported Nodes in
/// \p F. \p numOfChunksMap maps Nodes to the number of chunks they should be
/// split into; Nodes not listed fall back to \p numOfChunks. \p parOpts maps
/// Nodes to the kind of parallelism to apply to them.
bool parallelizeOps(
    Function *F, const llvm::DenseMap<Node *, size_t> &numOfChunksMap,
    const llvm::DenseMap<Node *, ParallelTransformKind> &parOpts,
    size_t numOfChunks);

} // namespace glow

#endif // GLOW_OPTIMIZER_GRAPHOPTIMIZER_GRAPHOPTIMIZER_H
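A minimal usage sketch of the new entry point, assuming a hypothetical fcNode (a FullyConnected node already present in F). Data parallelism splits along the batch dimension of an op's inputs; Model parallelism splits its parameters (e.g. FC weights along the output dimension). Nodes absent from parOpts are left untouched.

  // Split fcNode into two data-parallel chunks along its input's batch
  // dimension, leaving every other node in F untouched.
  llvm::DenseMap<Node *, size_t> numOfChunksMap; // empty: listed nodes fall
                                                 // back to numOfChunks below
  llvm::DenseMap<Node *, ParallelTransformKind> parOpts;
  parOpts[fcNode] = ParallelTransformKind::Data;
  bool changed = parallelizeOps(F, numOfChunksMap, parOpts,
                                /*numOfChunks=*/2);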
@@ -32,9 +32,13 @@ TypeRef Node::getType(unsigned idx) const {
}

void Node::setType(unsigned idx, TypeRef ty) {
  assert(idx < getNumResults() && "Result number does not exist.");
  assert(types_[idx]->dims() == ty->dims() &&
         "Better create a new node at this point");
  setTypeUnsafe(idx, ty);
}

void Node::setTypeUnsafe(unsigned idx, TypeRef ty) {
  assert(idx < getNumResults() && "Result number does not exist.");
  types_[idx] = ty;
}

@@ -91,6 +91,7 @@ llvm::iterator_range<NodeValueConstIterator> NodeValue::getUsers() const {

TypeRef NodeValue::getType() const { return node_->getType(resNo_); }
void NodeValue::setType(TypeRef ty) { node_->setType(resNo_, ty); }
void NodeValue::setTypeUnsafe(TypeRef ty) { node_->setTypeUnsafe(resNo_, ty); }

ElemKind NodeValue::getElementType() const {
  return getType()->getElementType();
@@ -3340,3 +3340,223 @@ bool glow::executeVerticalFCWeightsSplit(Function *F, unsigned numOfChunks,

  return changed;
}

/// Helper to parallelize a node \p curNode from \p F into \p numOfChunksNode
/// Nodes by slicing its inputs, creating clones of it with the slices as
/// inputs, and then concatenating all of the clones together and replacing
/// \p curNode with the concat. \p inputBatchIdx is the idx of the input from
/// \p curNode that is split (there may be more than one input to split, but
/// their split dimensions must all have the same size). \p splitDims
/// represents what dimension to split for each of the inputs to \p curNode;
/// -1 means do not split that input. \p resultIdx is the result index from
/// \p curNode that is split, and \p resultDim is the dimension of that
/// result on which the chunks are concatenated back together.
static void parallelizeAndReplaceNode(Function *F, Node *curNode,
                                      size_t numOfChunksNode,
                                      size_t inputBatchIdx, size_t resultIdx,
                                      llvm::ArrayRef<int> splitDims,
                                      size_t resultDim) {
  const int inputIdx = splitDims[inputBatchIdx];
  CHECK_GE(inputIdx, 0) << "Input batch idx must be split";
  const size_t batchSize = curNode->getNthInput(inputBatchIdx).dims()[inputIdx];
  const size_t elemPerChunk = batchSize / numOfChunksNode;
  const size_t remain = batchSize % numOfChunksNode;

  std::vector<NodeValue> newNodes(numOfChunksNode);
  for (size_t i = 0; i < numOfChunksNode; ++i) {
    // Calculate the out type of this chunk.
    const size_t sliceStart = i * elemPerChunk + std::min(i, remain);
    const size_t sliceEnd = sliceStart + elemPerChunk + ((i < remain) ? 1 : 0);
    std::cout << "\tChunk " << i << ": start: " << sliceStart
              << " end: " << sliceEnd << "\n";
    auto outDims = curNode->dims(resultIdx).vec();
    outDims[resultDim] = (sliceEnd - sliceStart);
    std::cout << "outDims: ";
    std::copy(outDims.begin(), outDims.end(),
              std::ostream_iterator<int>(std::cout, " "));
    std::cout << "\n";
    // Clone the original Node, so that it keeps all of the inputs/members of
    // the original Node. Then modify the output type so that its new shape is
    // correct, and below change the inputs to the sliced inputs.
    Node *clone = curNode->clone();
    clone->getNthResult(resultIdx).setTypeUnsafe(
        F->getParent()->uniqueTypeWithNewShape(curNode->getType(resultIdx),
                                               outDims));
    F->addNode(clone);

    // Loop over all of the inputs and slice those inputs that need to be
    // sliced, and set them on the clone.
    for (int j = 0, e = curNode->getNumInputs(); j < e; j++) {
      int dim = splitDims[j];
      if (dim == -1) {
        continue;
      }

      NodeValue currInput = curNode->getNthInput(j);
      auto sliceDimsStart = std::vector<size_t>(currInput.dims().size(), 0);
      sliceDimsStart[dim] = sliceStart;
      auto sliceDimsEnd = currInput.dims().vec();
      sliceDimsEnd[dim] = sliceEnd;
      std::cout << "start: ";
      std::copy(sliceDimsStart.begin(), sliceDimsStart.end(),
                std::ostream_iterator<int>(std::cout, " "));
      std::cout << "\nend: ";
      std::copy(sliceDimsEnd.begin(), sliceDimsEnd.end(),
                std::ostream_iterator<int>(std::cout, " "));
      std::cout << "\n";
      std::cout << "Input name: " << currInput.getNode()->getName().str()
                << "\n";

      auto *inputSlice =
          F->createSlice("dp_slice." + currInput.getNode()->getName().str() +
                             "." + std::to_string(i),
                         currInput, sliceDimsStart, sliceDimsEnd);
      clone->setNthInput(j, inputSlice);
    }

    newNodes[i] = clone;
  }

  // Now that we have split the node into many, concat all of the pieces back
  // together and replace the original by the concat.
  std::cout << "\tCreating concat\n";
  auto *concat = F->createConcat("concat." + curNode->getName().str(),
                                 newNodes, resultDim);
  curNode->getNthResult(resultIdx).replaceAllUsesOfWith(concat);
}
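The slice boundaries above distribute the remainder across the first chunks rather than pushing it all into the last one. A standalone sketch of just that arithmetic, under the assumption of a batch of 10 split into 3 chunks (which yields [0,4), [4,7), [7,10)):

  #include <algorithm>
  #include <cstddef>
  #include <cstdio>

  int main() {
    const size_t batchSize = 10, numOfChunks = 3;
    const size_t elemPerChunk = batchSize / numOfChunks; // 3
    const size_t remain = batchSize % numOfChunks;       // 1
    for (size_t i = 0; i < numOfChunks; ++i) {
      // The first `remain` chunks each absorb one leftover element.
      const size_t start = i * elemPerChunk + std::min(i, remain);
      const size_t end = start + elemPerChunk + (i < remain ? 1 : 0);
      std::printf("chunk %zu: [%zu, %zu)\n", i, start, end);
    }
    return 0;
  }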

bool glow::parallelizeOps(
    Function *F, const llvm::DenseMap<Node *, size_t> &numOfChunksMap,
    const llvm::DenseMap<Node *, ParallelTransformKind> &parOpts,
    size_t numOfChunks) {
  // Since we will be transforming the original list of nodes, reverse iterate.
  auto &nodes = F->getNodes();
  size_t numProcessedNodes = 0;
  for (auto it = nodes.rbegin(), e = nodes.rend(); it != e; it++) {
    Node *curNode = &*it;
    auto numOfChunksIt = numOfChunksMap.find(curNode);
    if (numOfChunksIt != numOfChunksMap.end()) {
      numOfChunks = numOfChunksIt->second;
    }

    ParallelTransformKind parTransformMode = ParallelTransformKind::None;
    auto parOptsIt = parOpts.find(curNode);
    if (parOptsIt != parOpts.end()) {
      parTransformMode = parOptsIt->second;
      ++numProcessedNodes;
    }

    std::cout << "Node name: " << curNode->getName().str() << "\n";

    // Use this vector to communicate what dims to split to
    // parallelizeAndReplaceNode(). -1 represents not splitting at all.
    llvm::SmallVector<int, 3> splitDims(curNode->getNumInputs(), -1);
    switch (parTransformMode) {
    case ParallelTransformKind::Data: {
      switch (curNode->getKind()) {
      case Kinded::Kind::FullyConnectedNodeKind: {
        splitDims[FullyConnectedNode::InputIdx] = 0;
        parallelizeAndReplaceNode(F, curNode, numOfChunks,
                                  FullyConnectedNode::InputIdx,
                                  FullyConnectedNode::ResultIdx, splitDims, 0);
        break;
      }
      case Kinded::Kind::AddNodeKind: {
        splitDims[AddNode::LHSIdx] = 0;
        splitDims[AddNode::RHSIdx] = 0;
        parallelizeAndReplaceNode(F, curNode, numOfChunks, AddNode::LHSIdx,
                                  AddNode::ResultIdx, splitDims, 0);
        break;
      }
      case Kinded::Kind::MulNodeKind: {
        splitDims[MulNode::LHSIdx] = 0;
        splitDims[MulNode::RHSIdx] = 0;
        parallelizeAndReplaceNode(F, curNode, numOfChunks, MulNode::LHSIdx,
                                  MulNode::ResultIdx, splitDims, 0);
        break;
      }
      case Kinded::Kind::SigmoidNodeKind: {
        splitDims[SigmoidNode::InputIdx] = 0;
        parallelizeAndReplaceNode(F, curNode, numOfChunks,
                                  SigmoidNode::InputIdx,
                                  SigmoidNode::ResultIdx, splitDims, 0);
        break;
      }
      case Kinded::Kind::TanhNodeKind: {
        splitDims[TanhNode::InputIdx] = 0;
        parallelizeAndReplaceNode(F, curNode, numOfChunks, TanhNode::InputIdx,
                                  TanhNode::ResultIdx, splitDims, 0);
        break;
      }
      case Kinded::Kind::TransposeNodeKind: {
        splitDims[TransposeNode::InputIdx] = 0;
        unsigned_t resultDim = cast<TransposeNode>(curNode)->getShuffle()[0];
        parallelizeAndReplaceNode(
            F, curNode, numOfChunks, TransposeNode::InputIdx,
            TransposeNode::ResultIdx, splitDims, resultDim);
        break;
      }
      case Kinded::Kind::ReluNodeKind: {
        splitDims[ReluNode::InputIdx] = 0;
        parallelizeAndReplaceNode(F, curNode, numOfChunks, ReluNode::InputIdx,
                                  ReluNode::ResultIdx, splitDims, 0);
        break;
      }
      case Kinded::Kind::ConvertToNodeKind: {
        splitDims[ConvertToNode::InputIdx] = 0;
        parallelizeAndReplaceNode(F, curNode, numOfChunks,
                                  ConvertToNode::InputIdx,
                                  ConvertToNode::ResultIdx, splitDims, 0);
        break;
      }
      default:
        std::cout << "Op Type: " << curNode->getKindName()
                  << " not yet supported" << std::endl;
        break;
      }
      break;
    }

    case ParallelTransformKind::Model: {
      switch (curNode->getKind()) {
      case Kinded::Kind::FullyConnectedNodeKind: {
        splitDims[FullyConnectedNode::WeightsIdx] = 1;
        splitDims[FullyConnectedNode::BiasIdx] = 0;
        parallelizeAndReplaceNode(F, curNode, numOfChunks,
                                  FullyConnectedNode::WeightsIdx,
                                  FullyConnectedNode::ResultIdx, splitDims, 1);
        break;
      }
      case Kinded::Kind::ReluNodeKind: {
        if (curNode->getNthInput(ReluNode::InputIdx).dims().size() < 2) {
          break;
        }
        splitDims[ReluNode::InputIdx] = 1;
        parallelizeAndReplaceNode(F, curNode, numOfChunks, ReluNode::InputIdx,
                                  ReluNode::ResultIdx, splitDims, 1);
        break;
      }
      default:
        break;
      }
      break;
    }

    case ParallelTransformKind::None:
      break;
    }

std::cout << "Done : " << curNode->getName().str() << "\n";
}

  // Because we transformed Node types unsafely, make sure all types in the
  // Function are still valid.
  bool ret = F->verify();
  DCHECK(ret) << "Verification issue post parallelization";

  const bool allProcessed = numProcessedNodes == parOpts.size();
  DCHECK(allProcessed) << "Not all Nodes specified in parOpts were processed.";

  return ret && allProcessed;
}
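One way to sanity-check the transformation, sketched here as a hypothetical test using the countNodeKind() helper added below: splitting a single FC into two data-parallel chunks should leave two FC clones, one Concat stitching their results, and one Slice per chunk for the split input (weights and bias are shared, so they are not sliced).

  parOpts[fcNode] = ParallelTransformKind::Data; // hypothetical FC in F
  EXPECT_TRUE(parallelizeOps(F, numOfChunksMap, parOpts, /*numOfChunks=*/2));
  ::glow::optimize(F, CompilationMode::Infer); // DCE away the orphaned FC
  EXPECT_EQ(countNodeKind(F, Kinded::Kind::FullyConnectedNodeKind), 2);
  EXPECT_EQ(countNodeKind(F, Kinded::Kind::SliceNodeKind), 2);
  EXPECT_EQ(countNodeKind(F, Kinded::Kind::ConcatNodeKind), 1);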
@@ -364,6 +364,16 @@ std::unordered_set<Tensor *> cloneFunInsideFun(FunctionTensorPair FTP,
  return resultTensors;
}

unsigned countNodeKind(Function *F, Kinded::Kind kind) {
  unsigned count = 0;
  for (auto &n : F->getNodes()) {
    if (n.getKind() == kind) {
      count++;
    }
  }
  return count;
}

void inferIntLookupTableNet(Tensor *input, Tensor *out,
                            llvm::ArrayRef<int8_t> table,
                            llvm::StringRef kind) {
@@ -287,6 +287,9 @@ std::unordered_set<Tensor *> cloneFunInsideFun(FunctionTensorPair FTP,
                                               CompilationContext &cctx,
                                               unsigned parallelCount);

/// \returns the number of nodes in \p F of kind \p kind.
unsigned countNodeKind(Function *F, Kinded::Kind kind);

void inferConvNet(Tensor *inputs, Tensor *filter, Tensor *bias, Tensor *out,
llvm::StringRef kind);
