Skip to content

Commit

Permalink
Moved runtimeBundle to backendUtils and made it a class
Browse files Browse the repository at this point in the history
  • Loading branch information
gcatron committed Nov 20, 2018
1 parent f8ea2b9 commit a370c0d
Show file tree
Hide file tree
Showing 10 changed files with 122 additions and 131 deletions.
68 changes: 53 additions & 15 deletions include/glow/Backends/BackendUtils.h
Original file line number Diff line number Diff line change
Expand Up @@ -21,26 +21,64 @@
#include "glow/IR/IR.h"

namespace glow {
/// At compile time condense constants to a single block of memory.
/// This allows the graph to go away after compile time.
/// Allocates a block of memory of size \p constantMaxSize then walks the given
/// function \p F and and copies weights to their address as specified by
/// offsets contained in \p symbolTable.
uint8_t *collectConstants(
const IRFunction *F, uint64_t constantMaxSize,
const std::unordered_map<std::string, runtime::RuntimeSymbolInfo>
&symbolTable);
/// Helper function to retrieve offset for Value: \p v from \p symbolTable.
size_t
getValueOffset(Value *v,
const std::unordered_map<std::string, runtime::RuntimeSymbolInfo>
&symbolTable);
namespace runtime {
/// Contains information for initialization and handling of symbol at runtime.
struct RuntimeSymbolInfo {
/// The size in bytes.
size_t size;
/// Offset in bytes from the base address.
size_t offset;
/// Type of symbol.
Type type;
};

/// Contains the information needed to be passed forward from compile time to
/// runtime. In order to allocate and initialize memory.
class RuntimeBundle {
/// Map from symbol name to a RuntimeSymbolInfo.
std::unordered_map<std::string, RuntimeSymbolInfo> symbolTable_;
/// Pointer to memory containing the weights for execution.
uint8_t *constants_;
/// Amount of memory needed for weights.
const size_t constantWeightVarsMemSize_;
/// Amount of memory needed for mutable vars.
const size_t mutableWeightVarsMemSize_;
/// Amount of memory needed for activations.
const size_t activationsMemSize_;

public:
/// Get Constant Weights memory size.
size_t getConstantWeightSize() const { return constantWeightVarsMemSize_; }
/// Get Mutable Weights memory size.
size_t getMutableWeightSize() const { return mutableWeightVarsMemSize_; }
/// Get Activations Weights memory size.
size_t getActivationsSize() const { return activationsMemSize_; }
/// Get pointer to memory block of constants.
uint8_t *getConstants() const { return constants_; }
/// Helper function, gets offset of \p v.
size_t getValueOffset(const Named *v) const;
/// Helper function, gets symbol info for \p v.
RuntimeSymbolInfo getSymbolInfo(const Named *v) const;
/// At compile time condense constants to a single block of memory.
/// This allows the graph to go away after compile time.
/// Allocates a block of memory of size \p constantMaxSize then walks the
/// given function \p F and and copies weights to their address as specified
/// by offsets contained in symbolTable_.
void collectConstants(const IRFunction *F);
RuntimeBundle(std::unordered_map<std::string, RuntimeSymbolInfo> &symbolTable,
size_t constWeight, size_t mutableWeight, size_t activations)
: symbolTable_(std::move(symbolTable)),
constantWeightVarsMemSize_(constWeight),
mutableWeightVarsMemSize_(mutableWeight),
activationsMemSize_(activations) {}
};
} // namespace runtime
/// Computes offsets and total allocation for Constants, Placeholders, and
/// Activations to build runtime symbol table. Returns RuntimeBundle.
runtime::RuntimeBundle
generateRuntimeBundle(const IRFunction &F, MemoryAllocator &constantAllocator,
MemoryAllocator &placeholderAllocator,
MemoryAllocator &activationsAllocator);
} // end namespace glow

} // end namespace glow
#endif // GLOW_BACKENDS_BACKENDUTILS_H
31 changes: 0 additions & 31 deletions include/glow/Backends/CompiledFunction.h
Original file line number Diff line number Diff line change
Expand Up @@ -22,37 +22,6 @@
namespace glow {

class Context;
namespace runtime {
/// RuntimeSymbolInfo
/// Contains information for initialization and handling of symbol at runtime.
struct RuntimeSymbolInfo {
/// The size in bytes.
size_t size;
/// Offset in bytes from the base address.
size_t offset;
/// Type of symbol.
Type type;
};
/// Runtime Bundle
/// Contains the information needed to be passed forward from compile time to
/// runtime. In order to allocate and initialize memory.
struct RuntimeBundle {
/// Map from symbol name to a RuntimeSymbolInfo.
std::unordered_map<std::string, RuntimeSymbolInfo> symbolTable;
/// Pointer to memory containing the weights for execution.
uint8_t *constants;
/// Amount of memory needed for weights.
const size_t constantWeightVarsMemSize;
/// Amount of memory needed for mutable vars.
const size_t mutableWeightVarsMemSize;
/// Amount of memory needed for activations.
const size_t activationsMemSize;
RuntimeBundle(size_t constWeight, size_t mutableWeight, size_t activations)
: constantWeightVarsMemSize(constWeight),
mutableWeightVarsMemSize(mutableWeight),
activationsMemSize(activations) {}
};
} // end namespace runtime
/// Interface for executing a compiled function.
class CompiledFunction {
public:
Expand Down
46 changes: 21 additions & 25 deletions lib/Backends/BackendUtils.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -20,41 +20,40 @@ using namespace glow;
using llvm::cast;
using llvm::isa;

uint8_t *glow::collectConstants(
const IRFunction *F, uint64_t constantMaxSize,
const std::unordered_map<std::string, runtime::RuntimeSymbolInfo>
&symbolTable) {
void glow::runtime::RuntimeBundle::collectConstants(const IRFunction *F) {

// At compile time condense constants to a single block of memory.
// This allows the graph to go away after compile time.
// If there are no constants return nullptr.
if (constantMaxSize == 0) {
return nullptr;
if (constantWeightVarsMemSize_ == 0) {
constants_ = nullptr;
return;
}
uint8_t *baseConstantWeightVarsStore =
(uint8_t *)alignedAlloc(constantMaxSize, TensorAlignment);
constants_ =
(uint8_t *)alignedAlloc(constantWeightVarsMemSize_, TensorAlignment);
for (auto &v : F->getGraph()->getParent()->getConstants()) {
assert(isa<WeightVar>(F->getWeightForNode(v)));
auto *w = cast<WeightVar>(F->getWeightForNode(v));
auto payload = v->getPayload().getUnsafePtr();
auto numBytes = w->getSizeInBytes();
auto it = symbolTable.find(std::string(w->getName()));
assert(it != symbolTable.end() && "Symbol not found.");
auto addr = it->second.offset;
auto addr = getValueOffset(v);
// Copy weight to offset.
memcpy(baseConstantWeightVarsStore + addr, payload, numBytes);
memcpy(constants_ + addr, payload, numBytes);
}
return baseConstantWeightVarsStore;
}

/// Helper function, gets offset of \p v from \p symbolTable.
size_t glow::getValueOffset(
Value *v, const std::unordered_map<std::string, runtime::RuntimeSymbolInfo>
&symbolTable) {
auto it = symbolTable.find(std::string(v->getName()));
assert(it != symbolTable.end() && "Symbol not found.");
size_t glow::runtime::RuntimeBundle::getValueOffset(const Named *v) const {
auto it = symbolTable_.find(std::string(v->getName()));
assert(it != symbolTable_.end() && "Symbol not found.");
return it->second.offset;
}

runtime::RuntimeSymbolInfo
runtime::RuntimeBundle::getSymbolInfo(const Named *v) const {
auto it = symbolTable_.find(std::string(v->getName()));
assert(it != symbolTable_.end() && "Symbol not found.");
return it->second;
}
runtime::RuntimeBundle
glow::generateRuntimeBundle(const IRFunction &F,
MemoryAllocator &constantAllocator,
Expand All @@ -64,7 +63,6 @@ glow::generateRuntimeBundle(const IRFunction &F,
// Symbol table mapping symbol name to offset for runtime.
std::unordered_map<std::string, runtime::RuntimeSymbolInfo> symbolTable;
// Compute the offsets for Constants.

for (auto &v : F.getGraph()->getParent()->getConstants()) {
assert(isa<WeightVar>(F.getWeightForNode(v)) && "Expected WeightVar");
auto *w = cast<WeightVar>(F.getWeightForNode(v));
Expand All @@ -79,7 +77,6 @@ glow::generateRuntimeBundle(const IRFunction &F,
auto constantMaxSize = constantAllocator.getMaxMemoryUsage();

// Compute the offsets for Placeholders.

for (auto &v : F.getGraph()->getParent()->getPlaceholders()) {
// Get the WeightVar for each Placeholder to calculate offsets.
assert(isa<WeightVar>(F.getWeightForNode(v)) && "Expected WeightVar");
Expand Down Expand Up @@ -125,7 +122,7 @@ glow::generateRuntimeBundle(const IRFunction &F,
assert(symbolTable.count(std::string(tvSource->getName())) &&
"Source allocation not found!");
runtime::RuntimeSymbolInfo symbol;
symbol.offset = getValueOffset(tvSource, symbolTable) +
symbol.offset = symbolTable[std::string(tvSource->getName())].offset +
(offsetLength * TV->getType()->getElementSize());
symbol.size = TV->getSizeInBytes();
symbol.type = *TV->getType();
Expand All @@ -143,9 +140,8 @@ glow::generateRuntimeBundle(const IRFunction &F,
}
auto activationsMaxSize = activationsAllocator.getMaxMemoryUsage();

runtime::RuntimeBundle info(constantMaxSize, placeholderMaxSize,
runtime::RuntimeBundle info(symbolTable, constantMaxSize, placeholderMaxSize,
activationsMaxSize);
info.symbolTable = std::move(symbolTable);
info.constants = collectConstants(&F, constantMaxSize, info.symbolTable);
info.collectConstants(&F);
return info;
}
2 changes: 1 addition & 1 deletion lib/Backends/CPU/AllocationsInfo.h
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@ class Constant;
class Context;

namespace runtime {
struct RuntimeBundle;
class RuntimeBundle;
}
/// Information about allocations for activations, constant weight variables
/// and mutable weight variables.
Expand Down
28 changes: 12 additions & 16 deletions lib/Backends/CPU/CPUFunction.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -25,30 +25,27 @@ CPUFunction::CPUFunction(std::unique_ptr<llvm::orc::GlowJIT> JIT,
const runtime::RuntimeBundle &runtimeBundle)
: JIT_(std::move(JIT)), runtimeBundle_(runtimeBundle) {}

CPUFunction::~CPUFunction() { alignedFree(runtimeBundle_.constants); }
CPUFunction::~CPUFunction() { alignedFree(runtimeBundle_.getConstants()); }

void CPUFunction::setupRuns() {
if (runtimeBundle_.activationsMemSize != 0) {
if (runtimeBundle_.getActivationsSize() != 0) {
baseActivationsAddress_ = (uint8_t *)alignedAlloc(
runtimeBundle_.activationsMemSize, TensorAlignment);
runtimeBundle_.getActivationsSize(), TensorAlignment);
}

if (runtimeBundle_.mutableWeightVarsMemSize != 0) {
if (runtimeBundle_.getMutableWeightSize() != 0) {
baseMutableWeightVarsAddress_ = (uint8_t *)alignedAlloc(
runtimeBundle_.mutableWeightVarsMemSize, TensorAlignment);
runtimeBundle_.getMutableWeightSize(), TensorAlignment);
}
}

void CPUFunction::beforeRun(const Context &ctx) {
// Copy Placeholders into allocated memory.
for (auto PH : ctx.pairs()) {
auto payload = PH.second->getUnsafePtr();
auto symbolInfo =
runtimeBundle_.symbolTable.find(std::string(PH.first->getName()));
assert(symbolInfo != runtimeBundle_.symbolTable.end() &&
"Symbol not found");
auto addr = symbolInfo->second.offset;
auto numBytes = symbolInfo->second.size;
auto symbolInfo = runtimeBundle_.getSymbolInfo(PH.first);
auto addr = symbolInfo.offset;
auto numBytes = symbolInfo.size;
// copy PH to allocated memory.
memcpy(baseMutableWeightVarsAddress_ + addr, payload, numBytes);
}
Expand All @@ -57,10 +54,9 @@ void CPUFunction::beforeRun(const Context &ctx) {
void CPUFunction::afterRun(const Context &ctx) {
// Copy placeholders from device back into context.
for (auto PH : ctx.pairs()) {
auto symbolInfo =
runtimeBundle_.symbolTable.find(std::string(PH.first->getName()));
auto payload = baseMutableWeightVarsAddress_ + symbolInfo->second.offset;
auto numBytes = symbolInfo->second.size;
auto symbolInfo = runtimeBundle_.getSymbolInfo(PH.first);
auto payload = baseMutableWeightVarsAddress_ + symbolInfo.offset;
auto numBytes = symbolInfo.size;
auto addr = PH.second->getUnsafePtr();
// copy PH from allocated memory.
memcpy(addr, payload, numBytes);
Expand All @@ -86,7 +82,7 @@ void CPUFunction::execute() {
auto address = sym.getAddress();
if (address) {
JitFuncType funcPtr = reinterpret_cast<JitFuncType>(address.get());
funcPtr(runtimeBundle_.constants, baseMutableWeightVarsAddress_,
funcPtr(runtimeBundle_.getConstants(), baseMutableWeightVarsAddress_,
baseActivationsAddress_);
} else {
GLOW_ASSERT(false && "Error getting address.");
Expand Down
2 changes: 1 addition & 1 deletion lib/Backends/CPU/CPUFunction.h
Original file line number Diff line number Diff line change
Expand Up @@ -18,10 +18,10 @@

#include "GlowJIT.h"

#include "glow/Backends/BackendUtils.h"
#include "glow/Backends/CompiledFunction.h"

namespace glow {

/// A Glow IR function compiled for the CPU using LLVM.
class CPUFunction final : public CompiledFunction {
/// The LLVM JIT engine. The jit must be initialized after the ctor
Expand Down
14 changes: 6 additions & 8 deletions lib/Backends/Interpreter/InterpreterFunction.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -35,17 +35,15 @@ InterpreterFunction::~InterpreterFunction() {
}
tensors_.clear();
externalTensors_.clear();
alignedFree(bundle_.constants);
alignedFree(bundle_.getConstants());
}
void InterpreterFunction::setupRuns() {
if (bundle_.constantWeightVarsMemSize) {
if (bundle_.getConstantWeightSize()) {
for (const auto &v : F_->getGraph()->getParent()->getConstants()) {
auto it = bundle_.symbolTable.find(std::string(v->getName()));
if (it != bundle_.symbolTable.end()) {
auto addr = bundle_.constants + it->second.offset;
auto tensor = new Tensor(addr, &it->second.type);
constants_.emplace(it->first, tensor);
}
auto symbolInfo = bundle_.getSymbolInfo(v);
auto addr = bundle_.getConstants() + symbolInfo.offset;
auto tensor = new Tensor(addr, &symbolInfo.type);
constants_.emplace(std::string(v->getName()), tensor);
}
}
}
Expand Down
1 change: 1 addition & 0 deletions lib/Backends/Interpreter/InterpreterFunction.h
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
#ifndef GLOW_BACKENDS_INTERPRETER_INTERPRETERFUNCTION_H
#define GLOW_BACKENDS_INTERPRETER_INTERPRETERFUNCTION_H

#include "glow/Backends/BackendUtils.h"
#include "glow/Backends/CompiledFunction.h"
#include "glow/Base/Tensor.h"
#include "glow/Graph/Context.h"
Expand Down
Loading

0 comments on commit a370c0d

Please sign in to comment.