Skip to content

Commit

Permalink
[JIT][DRAFT] profile guided memory planning
Browse files Browse the repository at this point in the history
ghstack-source-id: 5256d94abbf80ddda00980dade0623674bb05299
Pull Request resolved: #63873

reorder

refactor profiling allocator

put profiling back

reconcile profiling

ghstack-source-id: 5256d94abbf80ddda00980dade0623674bb05299
Pull Request resolved: #64351

use make_pair instead of make_tuple

incorporate uniqueliverange

fix tests

whoops

rename stuff

size_t for memory tracing
  • Loading branch information
Maksim Levental committed Sep 13, 2021
1 parent 4b5c319 commit 7f67655
Show file tree
Hide file tree
Showing 12 changed files with 429 additions and 1 deletion.
1 change: 1 addition & 0 deletions aten/src/ATen/core/interned_strings.h
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@ namespace c10 {
_(prim, AllocateSlab) \
_(prim, ReleaseSlab) \
_(prim, AllocateTensor) \
_(prim, PreallocateTensor) \
_(prim, ConstantMKLDNNTensor) \
_(prim, BroadcastMKLDNNTensors) \
_(prim, MKLDNNGroup) \
Expand Down
4 changes: 4 additions & 0 deletions c10/core/Allocator.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,10 @@ C10_API at::Allocator* allocator_array[at::COMPILE_TIME_MAX_DEVICE_TYPES];
// NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables,modernize-avoid-c-arrays,cppcoreguidelines-avoid-c-arrays)
C10_API uint8_t allocator_priority[at::COMPILE_TIME_MAX_DEVICE_TYPES] = {0};

// Returns the priority currently registered for the allocator of the given
// device type (see SetAllocator: a new allocator only replaces the current
// one when its priority is >= this value).
uint8_t GetAllocatorPriority(at::DeviceType t) {
  const auto device_index = static_cast<int>(t);
  return allocator_priority[device_index];
}

void SetAllocator(at::DeviceType t, at::Allocator* alloc, uint8_t priority) {
if (priority >= allocator_priority[static_cast<int>(t)]) {
allocator_array[static_cast<int>(t)] = alloc;
Expand Down
2 changes: 2 additions & 0 deletions c10/core/Allocator.h
Original file line number Diff line number Diff line change
Expand Up @@ -148,6 +148,8 @@ inline bool operator!=(std::nullptr_t, const DataPtr& dp) noexcept {
// possible, or the raw interface will incorrectly reported as unsupported,
// when it is actually possible.

C10_API uint8_t GetAllocatorPriority(at::DeviceType t);

struct C10_API Allocator {
virtual ~Allocator() = default;

Expand Down
6 changes: 6 additions & 0 deletions test/cpp/jit/test_memory_planning.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,12 @@ struct AllocAttrs {
c10::TensorTypePtr ttp;
};

// Expected attributes of a prim::PreallocateTensor node; used by the test
// helpers in this file to compare a planned graph against expectations.
struct PreAllocAttrs {
// planned size of the tensor's region within the slab (presumably bytes,
// matching attr::size set in memory_planning.cpp — confirm)
int64_t size;
// offset of the region from the start of the slab
int64_t offset;
// device the slab (and hence the preallocated tensor) lives on
DeviceType device_type;
};

void checkAllocNodes(
Graph& graph,
StorageAttrs expected_storage,
Expand Down
1 change: 1 addition & 0 deletions tools/build_variables.bzl
Original file line number Diff line number Diff line change
Expand Up @@ -229,6 +229,7 @@ core_sources_full_mobile_no_backend_interface = [
"torch/csrc/jit/passes/memory_planning/greedy_by_size.cpp",
"torch/csrc/jit/passes/memory_planning/greedy_by_breadth.cpp",
"torch/csrc/jit/passes/memory_planning/greedy_util.cpp",
"torch/csrc/jit/passes/memory_planning/MemoryPlanningAllocator.cpp",
"torch/csrc/jit/passes/normalize_ops.cpp",
"torch/csrc/jit/passes/peephole_dict_idioms.cpp",
"torch/csrc/jit/passes/peephole_list_idioms.cpp",
Expand Down
1 change: 1 addition & 0 deletions torch/csrc/jit/ir/alias_analysis.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -628,6 +628,7 @@ void AliasDb::analyzeImpl(Node* node) {
case prim::AllocateSlab:
case prim::ReleaseSlab:
case prim::AllocateTensor:
case prim::PreallocateTensor:
case prim::Closure:
case prim::CreateObject:
case prim::tolist:
Expand Down
232 changes: 231 additions & 1 deletion torch/csrc/jit/passes/memory_planning.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,6 @@

#include <jit/tensorexpr/kernel.h>
#include <torch/csrc/jit/ir/alias_analysis.h>
#include <torch/csrc/jit/jit_log.h>
#include <torch/csrc/jit/runtime/static/ops.h>
#include <limits>

Expand Down Expand Up @@ -182,6 +181,55 @@ void insertAllocNodes(
release_slab->insertBefore(graph->return_node());
}

struct frameNodeLTCmp {
size_t operator()(
const std::pair<FrameNodeId, std::vector<UniqueLiveRange>>& f1,
const std::pair<FrameNodeId, std::vector<UniqueLiveRange>>& f2) const {
return f1.first.pc < f2.first.pc;
}
};

// Rewrites `graph` to implement a memory plan produced from profiling:
// inserts one prim::AllocateSlab of `total_size`, one prim::PreallocateTensor
// (carrying size/offset/device attributes) before each allocating node for
// every live range that node owns, and a prim::ReleaseSlab before the graph's
// return node.
void insertPreallocNodes(
    std::shared_ptr<Graph>& graph,
    size_t total_size,
    const std::vector<MemAllocation>& allocations,
    std::vector<std::pair<FrameNodeId, std::vector<UniqueLiveRange>>>
        collected_node_live_ranges,
    c10::optional<at::Device> device_type = c10::nullopt) {
  auto slab = insertSlabNode(graph, total_size, device_type);
  auto release_slab = graph->create(prim::ReleaseSlab, 0);
  release_slab->addInput(slab->output());

  // index the planned regions by live range for O(log n) lookup below
  SortedLiveRangeMap<MemRegion> allocations_map;
  for (const auto& item : allocations) {
    allocations_map[item.ulvr] = item.reg;
  }

  // visit nodes in program order so prealloc nodes appear deterministically
  std::sort(
      collected_node_live_ranges.begin(),
      collected_node_live_ranges.end(),
      frameNodeLTCmp());

  for (auto& item : collected_node_live_ranges) {
    auto frame_id = item.first;
    auto lvrs = item.second;
    std::sort(lvrs.begin(), lvrs.end(), liveRangeStartCmp());
    auto node = frame_id.node;

    for (const auto& lvr : lvrs) {
      auto* alloc = graph->create(prim::PreallocateTensor, 1);
      alloc->insertBefore(node);
      alloc->addInput(slab->output());

      // Fail fast if the planner produced no region for this live range;
      // operator[] (as previously used) would silently default-construct a
      // zero-size/zero-offset region and corrupt the plan.
      const auto& region = allocations_map.at(lvr);
      alloc->i_(attr::size, (int64_t)region.size);
      alloc->i_(attr::offset, (int64_t)region.offset);
      alloc->i_(attr::device_type, slab->i(attr::device_type));
    }
  }
  release_slab->insertBefore(graph->return_node());
}

bool hasOutVariant(Node* node) {
if (!node->maybeSchema()) {
return false;
Expand Down Expand Up @@ -344,6 +392,188 @@ bool validateAllocations(
return true;
}

// Equality for FrameNodeId keys in the unordered_map below: same program
// counter, same operator schema, same printed node header. Must agree with
// frameNodeHash (which hashes pc and header).
struct frameNodeEqCmp {
  bool operator()(const FrameNodeId& lhs, const FrameNodeId& rhs) const {
    // Same pc and same underlying node are always equal. Without this,
    // nodes lacking a schema would compare unequal even to themselves,
    // violating the unordered_map invariant eq(a, a) == true (duplicate
    // buckets, failed lookups).
    if (lhs.pc == rhs.pc && lhs.node == rhs.node) {
      return true;
    }
    if (!lhs.node->maybeSchema() || !rhs.node->maybeSchema()) {
      return false;
    }

    return lhs.pc == rhs.pc && lhs.node->schema() == rhs.node->schema() &&
        getHeader(lhs.node) == getHeader(rhs.node);
  }
};

// Hash for FrameNodeId keys, combining the program counter with the node's
// printed header; must stay consistent with frameNodeEqCmp above.
struct frameNodeHash {
  size_t operator()(const FrameNodeId& frame_node_id) const {
    const size_t pc_hash = std::hash<size_t>()(frame_node_id.pc);
    const size_t header_hash =
        std::hash<std::string>()(getHeader(frame_node_id.node));
    return pc_hash ^ (header_hash << 1);
  }
};

std::vector<std::pair<FrameNodeId, std::vector<UniqueLiveRange>>>
collectLiveRangesPerNode(std::vector<std::pair<UniqueLiveRange, FrameNodeId>>
live_range_node_header) {
std::unordered_map<
FrameNodeId,
std::vector<UniqueLiveRange>,
frameNodeHash,
frameNodeEqCmp>
node_live_ranges;
for (const auto& item : live_range_node_header) {
auto lvr = item.first;
auto frame_node_id = item.second;
node_live_ranges[frame_node_id].emplace_back(lvr);
}

std::vector<std::pair<FrameNodeId, std::vector<UniqueLiveRange>>>
collected_node_live_ranges;
for (const auto& item : node_live_ranges) {
std::vector<UniqueLiveRange> lvrs(item.second.begin(), item.second.end());
std::sort(lvrs.begin(), lvrs.end(), liveRangeStartCmp());
collected_node_live_ranges.emplace_back(std::make_pair(item.first, lvrs));
}
std::sort(
collected_node_live_ranges.begin(),
collected_node_live_ranges.end(),
frameNodeLTCmp());
return collected_node_live_ranges;
}

// Reconstructs managed live ranges from a profiled memory-event trace.
// Each FREE is paired with its preceding ALLOCATE to form a UniqueLiveRange
// (keyed by a hash of the allocation's stack trace) mapped to the allocation
// size; also records which frame node performed each allocation. Allocations
// that are never freed are reported as leaks, tolerated only when they feed
// graph outputs.
std::pair<
    SortedLiveRangeMap<size_t>,
    std::vector<std::pair<UniqueLiveRange, FrameNodeId>>>
getManagedLiveRangesFromMemoryEvents(
    std::vector<MemoryEvent> mem_events,
    const std::shared_ptr<Graph> graph) {
  SortedLiveRangeMap<size_t> managed_live_ranges;
  std::vector<std::pair<UniqueLiveRange, FrameNodeId>> live_range_node_header;
  live_range_node_header.reserve(mem_events.size());

  // outstanding (not yet freed) allocations keyed by address
  std::unordered_map<intptr_t, MemoryEvent> allocs;
  auto trace_hasher = std::hash<std::string>();
  // pair up ALLOCATE/FREE events, validating as we go
  for (auto& mem_event : mem_events) {
    if (mem_event.type == MemoryEvent::EventType::ALLOCATE) {
      if (mem_event.frame_node_id.has_value()) {
        allocs.insert({mem_event.addr, mem_event});
      }
      // else: allocated before the interpreter started (e.g. inputs and
      // weights) — there is no frame to attribute it to, so skip it.
      // NOTE: the previous code asserted frame_node_id->pc == 0 in this
      // branch, dereferencing an EMPTY optional (undefined behavior).
    } else if (mem_event.type == MemoryEvent::EventType::FREE) {
      // single lookup replaces the previous count()/find()/at() triple
      auto it = allocs.find(mem_event.addr);
      TORCH_INTERNAL_ASSERT(it != allocs.end());
      const auto& alloc = it->second;
      TORCH_INTERNAL_ASSERT(
          alloc.type == MemoryEvent::EventType::ALLOCATE,
          " ",
          alloc.type,
          " ",
          MemoryEvent::EventType::ALLOCATE);
      TORCH_INTERNAL_ASSERT(
          alloc.size == mem_event.size, " ", alloc.size, " ", mem_event.size);
      TORCH_INTERNAL_ASSERT(
          alloc.ts < mem_event.ts, " ", alloc.ts, " ", mem_event.ts);

      // identify the live range by a hash of its allocation stack trace so
      // ranges from the same call site share an id across runs
      auto lvr = UniqueLiveRange{
          {alloc.ts, mem_event.ts},
          std::to_string(trace_hasher(mem_event.stack_trace.value()))};
      managed_live_ranges.insert({lvr, alloc.size});

      live_range_node_header.emplace_back(
          std::make_pair(lvr, alloc.frame_node_id.value()));
      // done with `alloc`; erase via the iterator we already hold
      allocs.erase(it);
    }
  }

  if (!allocs.empty()) {
    // Leaked allocations: acceptable only if they flow into graph outputs.
    // TODO: jit::Value* .count()>0 doesn't work for some reason
    // std::unordered_set<const jit::Value*> g_outputs;
    std::unordered_set<std::string> g_outputs;
    for (const auto& outp : graph->outputs()) {
      g_outputs.insert(outp->debugName());
    }
    for (auto& alloc : allocs) {
      TORCH_INTERNAL_ASSERT(
          alloc.second.type == MemoryEvent::EventType::ALLOCATE &&
          alloc.second.frame_node_id.has_value());
      GRAPH_DEBUG("leaked alloc: ", alloc.second, "\n");
      // TODO: this isn't a great heuristic (since tensors created within
      // the scope of an op could be leaked but not the actual output values.
      // a better way would be to connect allocs directly to values
      if (alloc.second.frame_node_id.value().node->outputs().size() > 0) {
        for (const auto& out :
             alloc.second.frame_node_id.value().node->outputs()) {
          TORCH_INTERNAL_ASSERT(
              g_outputs.count(out->debugName()) > 0, out->debugName());
        }
      }
      TORCH_WARN(alloc.second, " leaked");
    }
  }
  return std::make_pair(managed_live_ranges, live_range_node_header);
}

void planMemoryWithTracing(
std::shared_ptr<Graph>& graph,
Strategy strat,
std::vector<MemoryEvent> mem_events,
at::Device device_type) {
TORCH_INTERNAL_ASSERT(!mem_events.empty());
SortedLiveRangeMap<size_t> managed_live_ranges;
std::vector<std::pair<UniqueLiveRange, FrameNodeId>> live_range_node_header;
std::tie(managed_live_ranges, live_range_node_header) =
getManagedLiveRangesFromMemoryEvents(mem_events, graph);
std::vector<MemAllocation> allocations;

switch (strat) {
case Strategy::NAIVE: {
allocations = naive(managed_live_ranges);
break;
}
case Strategy::LINEAR_SCAN: {
allocations = linearScanHeuristic(managed_live_ranges);
break;
}
case Strategy::GREEDY_BY_SIZE_WITH_SMALLEST_GAP: {
allocations = greedyBySizeWithSmallestGap(managed_live_ranges);
break;
}
case Strategy::GREEDY_BY_SIZE_WITH_FIRST_GAP: {
allocations = greedyBySizeWithFirstGap(managed_live_ranges);
break;
}
case Strategy::GREEDY_BY_LONGEST_AND_SIZE_WITH_SMALLEST_GAP: {
allocations = greedyByLongestAndSizeWithSmallestGap(managed_live_ranges);
break;
}
case Strategy::GREEDY_BY_LONGEST_AND_SIZE_WITH_FIRST_GAP: {
allocations = greedyByLongestAndSizeWithFirstGap(managed_live_ranges);
break;
}
default:
return;
}

GRAPH_DEBUG("\nnumber of allocations\n", allocations.size());
auto total_size = getTotalAllocationSize(allocations);

TORCH_INTERNAL_ASSERT(
validateAllocations(allocations, managed_live_ranges, total_size),
"invalid allocation",
strat);

auto collected_node_live_ranges =
collectLiveRangesPerNode(live_range_node_header);

insertPreallocNodes(
graph, total_size, allocations, collected_node_live_ranges);
}

std::pair<size_t, FastMap<const Value*, std::pair<UniqueLiveRange, size_t>>>
planMemory(const std::shared_ptr<Graph>& graph, Strategy strat) {
FastMap<const Value*, std::pair<UniqueLiveRange, size_t>> managed_values,
Expand Down
7 changes: 7 additions & 0 deletions torch/csrc/jit/passes/memory_planning.h
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
#pragma once

#include <torch/csrc/jit/passes/memory_planning/memory_observer.h>
#include <torch/csrc/jit/ir/ir.h>
#include <torch/csrc/jit/runtime/static/impl.h>

Expand Down Expand Up @@ -144,6 +145,12 @@ TORCH_API std::
pair<size_t, FastMap<const Value*, std::pair<UniqueLiveRange, size_t>>>
planMemory(const std::shared_ptr<Graph>&, Strategy);

TORCH_API void planMemoryWithTracing(
std::shared_ptr<Graph>& graph,
Strategy strat,
std::vector<MemoryEvent> mem_events,
at::Device device_type);

} // namespace jit
} // namespace torch

Expand Down
Loading

0 comments on commit 7f67655

Please sign in to comment.