[profiler] Support top-level memory events #51421
Changes from all commits: 289a14b, 6e974a3, 7d1d418, 479a7e0, 326f80c
```diff
@@ -12,7 +12,9 @@
 from torch.testing._internal.common_utils import (
     TestCase, run_tests, TEST_WITH_ASAN, IS_WINDOWS, TemporaryFileName, TemporaryDirectoryName)
 from torch.autograd.profiler import profile as _profile
-from torch.profiler import profile, kineto_available, DeviceType, ProfilerActivity
+from torch.profiler import (
+    kineto_available, profile, record_function, DeviceType, ProfilerActivity
+)

 try:
     import psutil
```
```diff
@@ -92,10 +94,6 @@ def forward(self, x):
         c = b.sum()
         c.backward()

-        print(p.key_averages(
-            group_by_stack_n=5).table(
-                sort_by="self_cpu_time_total", row_limit=-1))
-
        for e in p.function_events:
            if "aten::add" in e.name or "AddBackward" in e.name:
                self.assertTrue(any(["test_profiler" in entry for entry in e.stack]))
```
```diff
@@ -122,8 +120,9 @@ def test_kineto(self):
         # rerun to avoid initial start overhead
         with _profile(use_cuda=True, use_kineto=True) as p:
             self.payload()
-        print(p.key_averages().table(
-            sort_by="self_cuda_time_total", row_limit=-1))
+        output = p.key_averages().table(
+            sort_by="self_cuda_time_total", row_limit=-1)
+        # print(output)
         found_gemm = False
         found_memcpy = False
         for e in p.function_events:
```
```diff
@@ -163,6 +162,123 @@ def test_kineto_multigpu(self):
         self.assertTrue(found_gemm_1)
         self.assertTrue(found_cuda)

+    def test_memory_profiler(self):
+        def run_profiler(tensor_creation_fn, metric):
+            # collecting allocs / deallocs
+            with _profile(profile_memory=True, record_shapes=True, use_kineto=kineto_available()) as prof:
+                x = None
+                with record_function("test_user_scope_alloc"):
+                    x = tensor_creation_fn()
+                with record_function("test_user_scope_dealloc"):
+                    del x
+            return prof.key_averages(group_by_input_shape=True)
+
+        def check_metrics(stats, metric, allocs=None, deallocs=None):
+            stat_metrics = {}
+            for stat in stats:
+                stat_metrics[stat.key] = getattr(stat, metric)
+            if allocs is not None:
+                for alloc_fn in allocs:
+                    self.assertTrue(alloc_fn in stat_metrics)
+                    self.assertTrue(stat_metrics[alloc_fn] > 0)
+            if deallocs is not None:
+                for dealloc_fn in deallocs:
+                    self.assertTrue(dealloc_fn in stat_metrics)
+                    self.assertTrue(stat_metrics[dealloc_fn] < 0)
+
+        def create_cpu_tensor():
+            return torch.rand(10, 10)
+
+        def create_cuda_tensor():
+            return torch.rand(10, 10).cuda()
+
+        def create_mkldnn_tensor():
+            return torch.rand(10, 10, dtype=torch.float32).to_mkldnn()
+
+        stats = run_profiler(create_cpu_tensor, "cpu_memory_usage")
+        check_metrics(
+            stats,
+            "cpu_memory_usage",
+            allocs=[
+                "aten::empty",
+                "aten::rand",
+                "test_user_scope_alloc",
+            ],
+            deallocs=[
+                "test_user_scope_dealloc",
+            ]
+        )
+
+        if torch.cuda.is_available():
+            create_cuda_tensor()
+            stats = run_profiler(create_cuda_tensor, "cuda_memory_usage")
+            check_metrics(
+                stats,
+                "cuda_memory_usage",
+                allocs=[
+                    "test_user_scope_alloc",
+                    "aten::to",
+                    "aten::empty_strided",
+                ],
+                deallocs=[
+                    "test_user_scope_dealloc",
+                ]
+            )
+            check_metrics(
+                stats,
+                "cpu_memory_usage",
+                allocs=[
+                    "aten::rand",
+                    "aten::empty",
+                ]
+            )
+
+        if torch._C.has_mkldnn:
+            create_mkldnn_tensor()
+            stats = run_profiler(create_mkldnn_tensor, "cpu_memory_usage")
+            check_metrics(
+                stats,
+                "cpu_memory_usage",
+                allocs=[
+                    "test_user_scope_alloc",
+                    "aten::rand",
+                    "aten::empty",
+                    "aten::to_mkldnn",
+                ],
+                deallocs=[
+                    "test_user_scope_dealloc",
+                ]
+            )
+
+        # check top-level memory events
+        with _profile(profile_memory=True, use_kineto=kineto_available()) as prof:
+            x = torch.rand(10, 10)
+            del x
+            if torch.cuda.is_available():
+                y = torch.rand(10, 10).cuda()
+                del y
+            gc.collect()
+        stats = prof.key_averages(group_by_input_shape=True)
+        check_metrics(
+            stats,
+            "cpu_memory_usage",
+            allocs=[
+                "aten::rand",
+                "aten::empty"
+            ],
+            deallocs=[
+                "[memory]"
+            ]
+        )
+        if torch.cuda.is_available():
+            check_metrics(
+                stats,
+                "cuda_memory_usage",
+                deallocs=[
+                    "[memory]"
+                ]
+            )
+
     def test_high_level_trace(self):
         """Checks that python side high level events are recorded.
         """
```

Review comment (on the `with _profile(profile_memory=True, ...)` line): we should make a follow-up item to either migrate or duplicate the tests with …

Review comment: all our autograd.profiler tests should use …
```diff
@@ -267,7 +383,6 @@ def test_flops(self):
         with _profile(record_shapes=True, with_flops=True, use_kineto=kineto_available()) as prof:
             model(inputs)
         profiler_output = prof.key_averages(group_by_input_shape=True).table(sort_by="cpu_time_total", row_limit=10)
-        print(profiler_output)
         self.assertIn("FLOPS", profiler_output)

     @unittest.skipIf(not kineto_available(), "Kineto is required")
```
@@ -279,8 +394,9 @@ def test_kineto_profiler_api(self): | |
self.payload() | ||
|
||
def trace_handler(p): | ||
print(p.key_averages().table( | ||
sort_by="self_cuda_time_total", row_limit=-1)) | ||
output = p.key_averages().table( | ||
sort_by="self_cuda_time_total", row_limit=-1) | ||
# print(output) | ||
# p.export_chrome_trace("/tmp/test_trace_" + str(called_num[0]) + ".json") | ||
called_num[0] += 1 | ||
|
||
|
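For context, a hedged sketch of the `torch.profiler` API that `test_kineto_profiler_api` drives. The schedule values and the matmul workload below are illustrative, not the test's exact ones; `trace_handler` is invoked once per completed profiling cycle.

```python
# Sketch of the scheduled-profiling API exercised by the test above.
import torch
from torch.profiler import profile, schedule, ProfilerActivity

def trace_handler(p):
    # Summarize the completed cycle instead of printing it.
    output = p.key_averages().table(sort_by="self_cpu_time_total", row_limit=10)

with profile(
    activities=[ProfilerActivity.CPU],
    schedule=schedule(wait=1, warmup=1, active=2),
    on_trace_ready=trace_handler,
) as p:
    for _ in range(8):
        torch.mm(torch.rand(100, 100), torch.rand(100, 100))
        p.step()  # advance the profiler schedule
```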
```diff
@@ -308,8 +424,9 @@ def trace_handler(p):
         ) as p:
             self.payload()
             self.payload()
-        print(p.key_averages().table(
-            sort_by="self_cuda_time_total", row_limit=-1))
+        output = p.key_averages().table(
+            sort_by="self_cuda_time_total", row_limit=-1)
+        # print(output)

     def test_export_stacks(self):
         with _profile(with_stack=True, use_kineto=kineto_available()) as p:
```
Review comment: are there any changes required in CUDAAllocator?

Review reply: CUDACachingAllocator already saves block sizes.
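As a side note on that exchange, a small sketch (not from the PR) of the caching-allocator bookkeeping the reply refers to: allocated byte counts move on both allocation and free, which is what lets a free be sized even when no operator is active. The shapes here are illustrative.

```python
# Sketch: the caching allocator tracks block sizes on alloc and free.
import torch

if torch.cuda.is_available():
    before = torch.cuda.memory_allocated()
    y = torch.rand(10, 10, device="cuda")
    assert torch.cuda.memory_allocated() > before  # block recorded on alloc
    del y
    # The block is returned to the allocator; allocated bytes drop back.
    assert torch.cuda.memory_allocated() == before
```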