Skip to content

Commit

Permalink
[profiler] Support top-level memory events
Browse files Browse the repository at this point in the history
Summary:
Mark memory events that did not happen within an operator context
explicitly in the profiler output.
This PR also adds an API to track memory events outside of or partially
overlapping with the profiler scope.

Test Plan:
python test/test_profiler.py -k test_memory_profiler

ghstack-source-id: 50b3f09f7a5cf4978f575f4f7a6d01e9f821666d
Pull Request resolved: #51421
  • Loading branch information
ilia-cher committed Feb 2, 2021
1 parent 79e7544 commit 6bc0ab8
Show file tree
Hide file tree
Showing 12 changed files with 239 additions and 123 deletions.
11 changes: 11 additions & 0 deletions c10/core/Allocator.cpp
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
#include <c10/core/Allocator.h>

#include <c10/util/ThreadLocalDebugInfo.h>
#include <atomic>

namespace c10 {

Expand Down Expand Up @@ -34,7 +35,17 @@ at::Allocator* GetAllocator(const at::DeviceType& t) {
return alloc;
}

namespace {
std::atomic<bool> global_memory_reporting_ {false};
}
void enableGlobalMemoryReporting(bool enable) {
global_memory_reporting_ = enable;
}

bool memoryProfilingEnabled() {
if (global_memory_reporting_) {
return true;
}
auto* reporter_ptr = static_cast<MemoryReportingInfoBase*>(
ThreadLocalDebugInfo::get(DebugInfoKind::PROFILER_STATE));
return reporter_ptr && reporter_ptr->memoryProfilingEnabled();
Expand Down
1 change: 1 addition & 0 deletions c10/core/Allocator.h
Original file line number Diff line number Diff line change
Expand Up @@ -232,6 +232,7 @@ struct C10_API MemoryReportingInfoBase : public c10::DebugInfoBase {
virtual bool memoryProfilingEnabled() const = 0;
};

C10_API void enableGlobalMemoryReporting(bool);
C10_API bool memoryProfilingEnabled();
C10_API void reportMemoryUsageToProfiler(void* ptr, int64_t alloc_size, Device device);

Expand Down
4 changes: 4 additions & 0 deletions c10/core/CPUAllocator.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -292,6 +292,10 @@ void ProfiledCPUMemoryReporter::Delete(void* ptr) {
allocated = allocated_;
nbytes = it->second;
size_table_.erase(it);
} else {
LOG(WARNING) << "Memory block of unknown size was allocated before the profiling started, "
<< "use 'enable_global_memory_reporting' to track tensor sizes outside of "
<< "the profiling scope";
}
}
if (nbytes == 0) {
Expand Down
109 changes: 0 additions & 109 deletions test/test_autograd.py
Original file line number Diff line number Diff line change
Expand Up @@ -3258,115 +3258,6 @@ def test_profiler_aggregation_lstm(self):
with tempfile.NamedTemporaryFile() as trace_file:
prof.export_chrome_trace(trace_file.name)

def test_memory_profiler(self):
def run_profiler(tensor_creation_fn, metric):
# collecting allocs / deallocs
with profile(profile_memory=True, record_shapes=True, use_kineto=kineto_available()) as prof:
x = None
with record_function("test_user_scope_alloc"):
x = tensor_creation_fn()
with record_function("test_user_scope_dealloc"):
del x
stats = prof.key_averages(group_by_input_shape=True)
print(stats.table(sort_by=metric))
return stats

def check_metrics(stats, metric, allocs=None, deallocs=None):
stat_metrics = {}
for stat in stats:
stat_metrics[stat.key] = getattr(stat, metric)
if allocs is not None:
for alloc_fn in allocs:
self.assertTrue(alloc_fn in stat_metrics)
self.assertTrue(stat_metrics[alloc_fn] > 0)
if deallocs is not None:
for dealloc_fn in deallocs:
self.assertTrue(dealloc_fn in stat_metrics)
self.assertTrue(stat_metrics[dealloc_fn] < 0)

def create_cpu_tensor():
return torch.rand(10, 10)

def create_cuda_tensor():
return torch.rand(10, 10).cuda()

def create_mkldnn_tensor():
return torch.rand(10, 10, dtype=torch.float32).to_mkldnn()

print("Running CPU test")
stats = run_profiler(create_cpu_tensor, "cpu_memory_usage")
check_metrics(
stats,
"cpu_memory_usage",
allocs=[
"aten::empty",
"aten::rand",
"test_user_scope_alloc",
],
deallocs=[
"test_user_scope_dealloc",
]
)

if torch.cuda.is_available():
create_cuda_tensor()
print("Running CUDA test")
stats = run_profiler(create_cuda_tensor, "cuda_memory_usage")
check_metrics(
stats,
"cuda_memory_usage",
allocs=[
"test_user_scope_alloc",
"aten::to",
"aten::empty_strided",
],
deallocs=[
"test_user_scope_dealloc",
]
)
check_metrics(
stats,
"cpu_memory_usage",
allocs=[
"aten::rand",
"aten::empty",
]
)

if torch._C.has_mkldnn:
create_mkldnn_tensor()
print("Running MKLDNN test")
stats = run_profiler(create_mkldnn_tensor, "cpu_memory_usage")
check_metrics(
stats,
"cpu_memory_usage",
allocs=[
"test_user_scope_alloc",
"aten::rand",
"aten::empty",
"aten::to_mkldnn",
],
deallocs=[
"test_user_scope_dealloc",
]
)

# check partial overlap of tensor allocation with memory profiler
x = torch.rand(10, 10)
with profile(profile_memory=True, record_shapes=True, use_kineto=kineto_available()) as prof:
del x
x = torch.rand(10, 10)
del x
stats = prof.key_averages(group_by_input_shape=True)
check_metrics(
stats,
"cpu_memory_usage",
allocs=[
"aten::rand",
"aten::empty",
]
)

def test_record_function(self):
x = torch.randn(10, 10)

Expand Down
135 changes: 134 additions & 1 deletion test/test_profiler.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,9 @@
from torch.testing._internal.common_utils import (
TestCase, run_tests, TEST_WITH_ASAN, IS_WINDOWS, TemporaryFileName)
from torch.autograd.profiler import profile as _profile
from torch.profiler import profile, kineto_available, DeviceType, ProfilerActivity
from torch.profiler import (
kineto_available, profile, record_function, DeviceType, ProfilerActivity
)

try:
import psutil
Expand Down Expand Up @@ -162,6 +164,137 @@ def test_kineto_multigpu(self):
self.assertTrue(found_gemm_1)
self.assertTrue(found_cuda)

def test_memory_profiler(self):
def run_profiler(tensor_creation_fn, metric):
# collecting allocs / deallocs
with _profile(profile_memory=True, record_shapes=True, use_kineto=kineto_available()) as prof:
x = None
with record_function("test_user_scope_alloc"):
x = tensor_creation_fn()
with record_function("test_user_scope_dealloc"):
del x
stats = prof.key_averages(group_by_input_shape=True)
print(stats.table(sort_by=metric))
return stats

def check_metrics(stats, metric, allocs=None, deallocs=None):
stat_metrics = {}
for stat in stats:
stat_metrics[stat.key] = getattr(stat, metric)
if allocs is not None:
for alloc_fn in allocs:
self.assertTrue(alloc_fn in stat_metrics)
self.assertTrue(stat_metrics[alloc_fn] > 0)
if deallocs is not None:
for dealloc_fn in deallocs:
self.assertTrue(dealloc_fn in stat_metrics)
self.assertTrue(stat_metrics[dealloc_fn] < 0)

def create_cpu_tensor():
return torch.rand(10, 10)

def create_cuda_tensor():
return torch.rand(10, 10).cuda()

def create_mkldnn_tensor():
return torch.rand(10, 10, dtype=torch.float32).to_mkldnn()

print("Running CPU test")
stats = run_profiler(create_cpu_tensor, "cpu_memory_usage")
check_metrics(
stats,
"cpu_memory_usage",
allocs=[
"aten::empty",
"aten::rand",
"test_user_scope_alloc",
],
deallocs=[
"test_user_scope_dealloc",
]
)

if torch.cuda.is_available():
create_cuda_tensor()
print("Running CUDA test")
stats = run_profiler(create_cuda_tensor, "cuda_memory_usage")
check_metrics(
stats,
"cuda_memory_usage",
allocs=[
"test_user_scope_alloc",
"aten::to",
"aten::empty_strided",
],
deallocs=[
"test_user_scope_dealloc",
]
)
check_metrics(
stats,
"cpu_memory_usage",
allocs=[
"aten::rand",
"aten::empty",
]
)

if torch._C.has_mkldnn:
create_mkldnn_tensor()
print("Running MKLDNN test")
stats = run_profiler(create_mkldnn_tensor, "cpu_memory_usage")
check_metrics(
stats,
"cpu_memory_usage",
allocs=[
"test_user_scope_alloc",
"aten::rand",
"aten::empty",
"aten::to_mkldnn",
],
deallocs=[
"test_user_scope_dealloc",
]
)

# check top-level memory events and
# partial overlap of tensor lifetime and profiler
torch.enable_global_memory_reporting(True)
x = torch.rand(10, 10)
y = None
if torch.cuda.is_available():
y = torch.rand(10, 10).cuda()
# mem events are CPU events
with _profile(profile_memory=True, use_kineto=kineto_available()) as prof:
del x
if torch.cuda.is_available():
del y
gc.collect()
x = torch.rand(10, 10)
del x
stats = prof.key_averages(group_by_input_shape=True)
print(stats.table(sort_by="cpu_memory_usage", row_limit=-1))
check_metrics(
stats,
"cpu_memory_usage",
allocs=[
"aten::rand",
"aten::empty"
],
deallocs=[
"[memory]"
]
)
if torch.cuda.is_available():
check_metrics(
stats,
"cuda_memory_usage",
deallocs=[
"[memory]"
]
)
torch.enable_global_memory_reporting(False)

def test_high_level_trace(self):
"""Checks that python side high level events are recorded.
"""
Expand Down
1 change: 1 addition & 0 deletions tools/pyi/gen_pyi.py
Original file line number Diff line number Diff line change
Expand Up @@ -287,6 +287,7 @@ def gen_pyi(native_yaml_path: str, deprecated_yaml_path: str, fm: FileManager) -
'init_num_threads': ['def init_num_threads() -> None: ...'],
'get_num_interop_threads': ['def get_num_interop_threads() -> _int: ...'],
'set_num_interop_threads': ['def set_num_interop_threads(num: _int) -> None: ...'],
'enable_global_memory_reporting': ['def enable_global_memory_reporting(enable: _bool) -> None: ...'],
# These functions are explicitly disabled by
# SKIP_PYTHON_BINDINGS because they are hand bound.
# Correspondingly, we must hand-write their signatures.
Expand Down
3 changes: 2 additions & 1 deletion torch/_C/__init__.pyi.in
Original file line number Diff line number Diff line change
Expand Up @@ -486,6 +486,7 @@ def get_num_thread() -> _int: ... # THPModule_getNumThreads
def set_num_threads(nthreads: _int) -> None: ... # THPModule_setNumThreads
def get_num_interop_threads() -> _int: ... # THPModule_getNumInteropThreads
def set_num_interop_threads(nthreads: _int) -> None: ... # THPModule_setNumInteropThreads
def enable_global_memory_reporting(arg: _bool) -> None: ... # THPModule_enableGlobalMemoryReporting
def _get_cudnn_enabled() -> _bool: ... # THPModule_userEnabledCuDNN
def _set_cudnn_enabled(arg: _bool) -> None: ... # THPModule_setUserEnabledCuDNN
def _get_mkldnn_enabled() -> _bool: ... # THPModule_userEnabledMkldnn
Expand Down Expand Up @@ -654,7 +655,7 @@ class _TensorBase(object):
output_nr: _int
_version: _int
_base: Optional[Tensor]
_cdata: _int
_cdata: _int
grad_fn: Any
_grad_fn: Any
_grad: Optional[Tensor]
Expand Down
29 changes: 23 additions & 6 deletions torch/_torch_docs.py
Original file line number Diff line number Diff line change
Expand Up @@ -5112,10 +5112,10 @@ def merge_dicts(*dicts):
add_docstr(torch.fmax, r"""
fmax(input, other, *, out=None) -> Tensor
Computes the element-wise maximum of :attr:`input` and :attr:`other`.
Computes the element-wise maximum of :attr:`input` and :attr:`other`.
This is like :func:`torch.maximum` except it handles NaNs differently:
if exactly one of the two elements being compared is a NaN then the non-NaN element is taken as the maximum.
This is like :func:`torch.maximum` except it handles NaNs differently:
if exactly one of the two elements being compared is a NaN then the non-NaN element is taken as the maximum.
Only if both elements are NaN is NaN propagated.
This function is a wrapper around C++'s ``std::fmax`` and is similar to NumPy's ``fmax`` function.
Expand Down Expand Up @@ -5587,10 +5587,10 @@ def merge_dicts(*dicts):
add_docstr(torch.fmin, r"""
fmin(input, other, *, out=None) -> Tensor
Computes the element-wise minimum of :attr:`input` and :attr:`other`.
Computes the element-wise minimum of :attr:`input` and :attr:`other`.
This is like :func:`torch.minimum` except it handles NaNs differently:
if exactly one of the two elements being compared is a NaN then the non-NaN element is taken as the minimum.
This is like :func:`torch.minimum` except it handles NaNs differently:
if exactly one of the two elements being compared is a NaN then the non-NaN element is taken as the minimum.
Only if both elements are NaN is NaN propagated.
This function is a wrapper around C++'s ``std::fmin`` and is similar to NumPy's ``fmin`` function.
Expand Down Expand Up @@ -7588,6 +7588,23 @@ def merge_dicts(*dicts):
is started (e.g. JIT execution).
""")

add_docstr(torch.enable_global_memory_reporting, r"""
enable_global_memory_reporting(bool)
Allows profiler to track memory of tensors which lifetime partially
overlaps with profiling scope.
Example::
torch.enable_global_memory_reporting(True)
x = torch.rand(10, 10)
with profile(activities=[ProfilerActivity.CPU], profile_memory=True) as prof:
# use enable_global_memory_reporting to track the deallocation of x
del x
...
torch.enable_global_memory_reporting(False)
""")

add_docstr(torch.sigmoid, r"""
sigmoid(input, *, out=None) -> Tensor
Expand Down

0 comments on commit 6bc0ab8

Please sign in to comment.