Skip to content

Commit

Permalink
[profiler] Support top-level memory events
Browse files Browse the repository at this point in the history
Summary:
Mark memory events that did not happen within an operator context
explicitly in the profiler output.
This PR also adds an API to track memory events outside of or partially
overlapping with the profiler scope.

Test Plan:
python test/test_profiler.py -k test_memory_profiler

ghstack-source-id: 2857c59eb4ade824434727116c223a8a23a03f81
Pull Request resolved: #51421
  • Loading branch information
ilia-cher committed Feb 4, 2021
1 parent 0c60922 commit 8d1d934
Show file tree
Hide file tree
Showing 7 changed files with 192 additions and 134 deletions.
4 changes: 4 additions & 0 deletions c10/core/CPUAllocator.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -292,6 +292,10 @@ void ProfiledCPUMemoryReporter::Delete(void* ptr) {
allocated = allocated_;
nbytes = it->second;
size_table_.erase(it);
} else {
C10_LOG_EVERY_MS(WARNING, 1000)
<< "Memory block of unknown size was allocated before the profiling started, "
<< "profiler results will not include the deallocation event";
}
}
if (nbytes == 0) {
Expand Down
109 changes: 0 additions & 109 deletions test/test_autograd.py
Original file line number Diff line number Diff line change
Expand Up @@ -3258,115 +3258,6 @@ def test_profiler_aggregation_lstm(self):
with tempfile.NamedTemporaryFile() as trace_file:
prof.export_chrome_trace(trace_file.name)

def test_memory_profiler(self):
def run_profiler(tensor_creation_fn, metric):
# collecting allocs / deallocs
with profile(profile_memory=True, record_shapes=True, use_kineto=kineto_available()) as prof:
x = None
with record_function("test_user_scope_alloc"):
x = tensor_creation_fn()
with record_function("test_user_scope_dealloc"):
del x
stats = prof.key_averages(group_by_input_shape=True)
print(stats.table(sort_by=metric))
return stats

def check_metrics(stats, metric, allocs=None, deallocs=None):
stat_metrics = {}
for stat in stats:
stat_metrics[stat.key] = getattr(stat, metric)
if allocs is not None:
for alloc_fn in allocs:
self.assertTrue(alloc_fn in stat_metrics)
self.assertTrue(stat_metrics[alloc_fn] > 0)
if deallocs is not None:
for dealloc_fn in deallocs:
self.assertTrue(dealloc_fn in stat_metrics)
self.assertTrue(stat_metrics[dealloc_fn] < 0)

def create_cpu_tensor():
return torch.rand(10, 10)

def create_cuda_tensor():
return torch.rand(10, 10).cuda()

def create_mkldnn_tensor():
return torch.rand(10, 10, dtype=torch.float32).to_mkldnn()

print("Running CPU test")
stats = run_profiler(create_cpu_tensor, "cpu_memory_usage")
check_metrics(
stats,
"cpu_memory_usage",
allocs=[
"aten::empty",
"aten::rand",
"test_user_scope_alloc",
],
deallocs=[
"test_user_scope_dealloc",
]
)

if torch.cuda.is_available():
create_cuda_tensor()
print("Running CUDA test")
stats = run_profiler(create_cuda_tensor, "cuda_memory_usage")
check_metrics(
stats,
"cuda_memory_usage",
allocs=[
"test_user_scope_alloc",
"aten::to",
"aten::empty_strided",
],
deallocs=[
"test_user_scope_dealloc",
]
)
check_metrics(
stats,
"cpu_memory_usage",
allocs=[
"aten::rand",
"aten::empty",
]
)

if torch._C.has_mkldnn:
create_mkldnn_tensor()
print("Running MKLDNN test")
stats = run_profiler(create_mkldnn_tensor, "cpu_memory_usage")
check_metrics(
stats,
"cpu_memory_usage",
allocs=[
"test_user_scope_alloc",
"aten::rand",
"aten::empty",
"aten::to_mkldnn",
],
deallocs=[
"test_user_scope_dealloc",
]
)

# check partial overlap of tensor allocation with memory profiler
x = torch.rand(10, 10)
with profile(profile_memory=True, record_shapes=True, use_kineto=kineto_available()) as prof:
del x
x = torch.rand(10, 10)
del x
stats = prof.key_averages(group_by_input_shape=True)
check_metrics(
stats,
"cpu_memory_usage",
allocs=[
"aten::rand",
"aten::empty",
]
)

def test_record_function(self):
x = torch.randn(10, 10)

Expand Down
141 changes: 129 additions & 12 deletions test/test_profiler.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,9 @@
from torch.testing._internal.common_utils import (
TestCase, run_tests, TEST_WITH_ASAN, IS_WINDOWS, TemporaryFileName, TemporaryDirectoryName)
from torch.autograd.profiler import profile as _profile
from torch.profiler import profile, kineto_available, DeviceType, ProfilerActivity
from torch.profiler import (
kineto_available, profile, record_function, DeviceType, ProfilerActivity
)

try:
import psutil
Expand Down Expand Up @@ -92,10 +94,6 @@ def forward(self, x):
c = b.sum()
c.backward()

print(p.key_averages(
group_by_stack_n=5).table(
sort_by="self_cpu_time_total", row_limit=-1))

for e in p.function_events:
if "aten::add" in e.name or "AddBackward" in e.name:
self.assertTrue(any(["test_profiler" in entry for entry in e.stack]))
Expand All @@ -122,8 +120,9 @@ def test_kineto(self):
# rerun to avoid initial start overhead
with _profile(use_cuda=True, use_kineto=True) as p:
self.payload()
print(p.key_averages().table(
sort_by="self_cuda_time_total", row_limit=-1))
output = p.key_averages().table(
sort_by="self_cuda_time_total", row_limit=-1)
# print(output)
found_gemm = False
found_memcpy = False
for e in p.function_events:
Expand Down Expand Up @@ -163,6 +162,123 @@ def test_kineto_multigpu(self):
self.assertTrue(found_gemm_1)
self.assertTrue(found_cuda)

def test_memory_profiler(self):
def run_profiler(tensor_creation_fn, metric):
# collecting allocs / deallocs
with _profile(profile_memory=True, record_shapes=True, use_kineto=kineto_available()) as prof:
x = None
with record_function("test_user_scope_alloc"):
x = tensor_creation_fn()
with record_function("test_user_scope_dealloc"):
del x
return prof.key_averages(group_by_input_shape=True)

def check_metrics(stats, metric, allocs=None, deallocs=None):
stat_metrics = {}
for stat in stats:
stat_metrics[stat.key] = getattr(stat, metric)
if allocs is not None:
for alloc_fn in allocs:
self.assertTrue(alloc_fn in stat_metrics)
self.assertTrue(stat_metrics[alloc_fn] > 0)
if deallocs is not None:
for dealloc_fn in deallocs:
self.assertTrue(dealloc_fn in stat_metrics)
self.assertTrue(stat_metrics[dealloc_fn] < 0)

def create_cpu_tensor():
return torch.rand(10, 10)

def create_cuda_tensor():
return torch.rand(10, 10).cuda()

def create_mkldnn_tensor():
return torch.rand(10, 10, dtype=torch.float32).to_mkldnn()

stats = run_profiler(create_cpu_tensor, "cpu_memory_usage")
check_metrics(
stats,
"cpu_memory_usage",
allocs=[
"aten::empty",
"aten::rand",
"test_user_scope_alloc",
],
deallocs=[
"test_user_scope_dealloc",
]
)

if torch.cuda.is_available():
create_cuda_tensor()
stats = run_profiler(create_cuda_tensor, "cuda_memory_usage")
check_metrics(
stats,
"cuda_memory_usage",
allocs=[
"test_user_scope_alloc",
"aten::to",
"aten::empty_strided",
],
deallocs=[
"test_user_scope_dealloc",
]
)
check_metrics(
stats,
"cpu_memory_usage",
allocs=[
"aten::rand",
"aten::empty",
]
)

if torch._C.has_mkldnn:
create_mkldnn_tensor()
stats = run_profiler(create_mkldnn_tensor, "cpu_memory_usage")
check_metrics(
stats,
"cpu_memory_usage",
allocs=[
"test_user_scope_alloc",
"aten::rand",
"aten::empty",
"aten::to_mkldnn",
],
deallocs=[
"test_user_scope_dealloc",
]
)

# check top-level memory events
with _profile(profile_memory=True, use_kineto=kineto_available()) as prof:
x = torch.rand(10, 10)
del x
if torch.cuda.is_available():
y = torch.rand(10, 10).cuda()
del y
gc.collect()
stats = prof.key_averages(group_by_input_shape=True)
check_metrics(
stats,
"cpu_memory_usage",
allocs=[
"aten::rand",
"aten::empty"
],
deallocs=[
"[memory]"
]
)
if torch.cuda.is_available():
check_metrics(
stats,
"cuda_memory_usage",
deallocs=[
"[memory]"
]
)

def test_high_level_trace(self):
"""Checks that python side high level events are recorded.
"""
Expand Down Expand Up @@ -267,7 +383,6 @@ def test_flops(self):
with _profile(record_shapes=True, with_flops=True, use_kineto=kineto_available()) as prof:
model(inputs)
profiler_output = prof.key_averages(group_by_input_shape=True).table(sort_by="cpu_time_total", row_limit=10)
print(profiler_output)
self.assertIn("FLOPS", profiler_output)

@unittest.skipIf(not kineto_available(), "Kineto is required")
Expand All @@ -279,8 +394,9 @@ def test_kineto_profiler_api(self):
self.payload()

def trace_handler(p):
print(p.key_averages().table(
sort_by="self_cuda_time_total", row_limit=-1))
output = p.key_averages().table(
sort_by="self_cuda_time_total", row_limit=-1)
# print(output)
# p.export_chrome_trace("/tmp/test_trace_" + str(called_num[0]) + ".json")
called_num[0] += 1

Expand Down Expand Up @@ -308,8 +424,9 @@ def trace_handler(p):
) as p:
self.payload()
self.payload()
print(p.key_averages().table(
sort_by="self_cuda_time_total", row_limit=-1))
output = p.key_averages().table(
sort_by="self_cuda_time_total", row_limit=-1)
# print(output)

def test_export_stacks(self):
with _profile(with_stack=True, use_kineto=kineto_available()) as p:
Expand Down
2 changes: 1 addition & 1 deletion torch/_C/__init__.pyi.in
Original file line number Diff line number Diff line change
Expand Up @@ -654,7 +654,7 @@ class _TensorBase(object):
output_nr: _int
_version: _int
_base: Optional[Tensor]
_cdata: _int
_cdata: _int
grad_fn: Any
_grad_fn: Any
_grad: Optional[Tensor]
Expand Down
12 changes: 6 additions & 6 deletions torch/_torch_docs.py
Original file line number Diff line number Diff line change
Expand Up @@ -5151,10 +5151,10 @@ def merge_dicts(*dicts):
add_docstr(torch.fmax, r"""
fmax(input, other, *, out=None) -> Tensor
Computes the element-wise maximum of :attr:`input` and :attr:`other`.
Computes the element-wise maximum of :attr:`input` and :attr:`other`.
This is like :func:`torch.maximum` except it handles NaNs differently:
if exactly one of the two elements being compared is a NaN then the non-NaN element is taken as the maximum.
This is like :func:`torch.maximum` except it handles NaNs differently:
if exactly one of the two elements being compared is a NaN then the non-NaN element is taken as the maximum.
Only if both elements are NaN is NaN propagated.
This function is a wrapper around C++'s ``std::fmax`` and is similar to NumPy's ``fmax`` function.
Expand Down Expand Up @@ -5626,10 +5626,10 @@ def merge_dicts(*dicts):
add_docstr(torch.fmin, r"""
fmin(input, other, *, out=None) -> Tensor
Computes the element-wise minimum of :attr:`input` and :attr:`other`.
Computes the element-wise minimum of :attr:`input` and :attr:`other`.
This is like :func:`torch.minimum` except it handles NaNs differently:
if exactly one of the two elements being compared is a NaN then the non-NaN element is taken as the minimum.
This is like :func:`torch.minimum` except it handles NaNs differently:
if exactly one of the two elements being compared is a NaN then the non-NaN element is taken as the minimum.
Only if both elements are NaN is NaN propagated.
This function is a wrapper around C++'s ``std::fmin`` and is similar to NumPy's ``fmin`` function.
Expand Down

0 comments on commit 8d1d934

Please sign in to comment.