-
Notifications
You must be signed in to change notification settings - Fork 21.4k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[profiler] Support top-level memory events #51421
Changes from 1 commit
289a14b
6e974a3
7d1d418
479a7e0
326f80c
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -292,6 +292,10 @@ void ProfiledCPUMemoryReporter::Delete(void* ptr) { | |
allocated = allocated_; | ||
nbytes = it->second; | ||
size_table_.erase(it); | ||
} else { | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. are there any changes required in CUDAAllocator? There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. CUDACachingAllocator already saves block sizes |
||
LOG(WARNING) << "Memory block of unknown size was allocated before the profiling started, " | ||
<< "use 'enable_global_memory_reporting' to track tensor sizes outside of " | ||
<< "the profiling scope"; | ||
} | ||
} | ||
if (nbytes == 0) { | ||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -11,7 +11,9 @@ | |
from torch.testing._internal.common_utils import ( | ||
TestCase, run_tests, TEST_WITH_ASAN, IS_WINDOWS, TemporaryFileName) | ||
from torch.autograd.profiler import profile as _profile | ||
from torch.profiler import profile, kineto_available, DeviceType, ProfilerActivity | ||
from torch.profiler import ( | ||
kineto_available, profile, record_function, DeviceType, ProfilerActivity | ||
) | ||
|
||
try: | ||
import psutil | ||
|
@@ -162,6 +164,140 @@ def test_kineto_multigpu(self): | |
self.assertTrue(found_gemm_1) | ||
self.assertTrue(found_cuda) | ||
|
||
def test_memory_profiler(self): | ||
def run_profiler(tensor_creation_fn, metric): | ||
# collecting allocs / deallocs | ||
with _profile(profile_memory=True, record_shapes=True, use_kineto=kineto_available()) as prof: | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. we should make a follow-up item to either migrate or duplicate the tests with There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. all our autograd.profiler tests should use |
||
x = None | ||
with record_function("test_user_scope_alloc"): | ||
x = tensor_creation_fn() | ||
with record_function("test_user_scope_dealloc"): | ||
del x | ||
stats = prof.key_averages(group_by_input_shape=True) | ||
print(stats.table(sort_by=metric)) | ||
return stats | ||
|
||
def check_metrics(stats, metric, allocs=None, deallocs=None): | ||
stat_metrics = {} | ||
for stat in stats: | ||
stat_metrics[stat.key] = getattr(stat, metric) | ||
if allocs is not None: | ||
for alloc_fn in allocs: | ||
self.assertTrue(alloc_fn in stat_metrics) | ||
self.assertTrue(stat_metrics[alloc_fn] > 0) | ||
if deallocs is not None: | ||
for dealloc_fn in deallocs: | ||
self.assertTrue(dealloc_fn in stat_metrics) | ||
self.assertTrue(stat_metrics[dealloc_fn] < 0) | ||
|
||
def create_cpu_tensor(): | ||
return torch.rand(10, 10) | ||
|
||
def create_cuda_tensor(): | ||
return torch.rand(10, 10).cuda() | ||
|
||
def create_mkldnn_tensor(): | ||
return torch.rand(10, 10, dtype=torch.float32).to_mkldnn() | ||
|
||
print("Running CPU test") | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. remove debug print There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. it was originally added on purpose (together with profiler output), I think we have many tests that do this? There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Idk, usually tests don't do this (they are tests), but there are some stray prints. Do you need this output? There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. only for debug, will remove |
||
stats = run_profiler(create_cpu_tensor, "cpu_memory_usage") | ||
check_metrics( | ||
stats, | ||
"cpu_memory_usage", | ||
allocs=[ | ||
"aten::empty", | ||
"aten::rand", | ||
"test_user_scope_alloc", | ||
], | ||
deallocs=[ | ||
"test_user_scope_dealloc", | ||
] | ||
) | ||
|
||
if torch.cuda.is_available(): | ||
create_cuda_tensor() | ||
print("Running CUDA test") | ||
stats = run_profiler(create_cuda_tensor, "cuda_memory_usage") | ||
check_metrics( | ||
stats, | ||
"cuda_memory_usage", | ||
allocs=[ | ||
"test_user_scope_alloc", | ||
"aten::to", | ||
"aten::empty_strided", | ||
], | ||
deallocs=[ | ||
"test_user_scope_dealloc", | ||
] | ||
) | ||
check_metrics( | ||
stats, | ||
"cpu_memory_usage", | ||
allocs=[ | ||
"aten::rand", | ||
"aten::empty", | ||
] | ||
) | ||
|
||
if torch._C.has_mkldnn: | ||
create_mkldnn_tensor() | ||
print("Running MKLDNN test") | ||
stats = run_profiler(create_mkldnn_tensor, "cpu_memory_usage") | ||
check_metrics( | ||
stats, | ||
"cpu_memory_usage", | ||
allocs=[ | ||
"test_user_scope_alloc", | ||
"aten::rand", | ||
"aten::empty", | ||
"aten::to_mkldnn", | ||
], | ||
deallocs=[ | ||
"test_user_scope_dealloc", | ||
] | ||
) | ||
|
||
if kineto_available(): | ||
torch.enable_global_memory_reporting(True) | ||
# check top-level memory events and | ||
# partial overlap of tensor lifetime and profiler | ||
x = torch.rand(10, 10) | ||
y = None | ||
if torch.cuda.is_available(): | ||
y = torch.rand(10, 10).cuda() | ||
with profile( | ||
# mem events are CPU events | ||
activities=[ProfilerActivity.CPU], | ||
profile_memory=True) as prof: | ||
del x | ||
if torch.cuda.is_available(): | ||
del y | ||
gc.collect() | ||
x = torch.rand(10, 10) | ||
del x | ||
stats = prof.key_averages(group_by_input_shape=True) | ||
print(stats.table(sort_by="cpu_memory_usage", row_limit=-1)) | ||
check_metrics( | ||
stats, | ||
"cpu_memory_usage", | ||
allocs=[ | ||
"aten::rand", | ||
"aten::empty" | ||
], | ||
deallocs=[ | ||
"[memory]" | ||
] | ||
) | ||
if torch.cuda.is_available(): | ||
check_metrics( | ||
stats, | ||
"cuda_memory_usage", | ||
deallocs=[ | ||
"[memory]" | ||
] | ||
) | ||
torch.enable_global_memory_reporting(False) | ||
|
||
def test_high_level_trace(self): | ||
"""Checks that python side high level events are recorded. | ||
""" | ||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -214,6 +214,8 @@ def export_chrome_trace(self, path): | |
# this technique is proven to give a 4x speedup. | ||
f.write("[") | ||
for evt in self: | ||
if evt.trace_name is None: | ||
continue | ||
f.write( | ||
'{"name": "%s", ' | ||
'"ph": "X", ' | ||
|
@@ -850,7 +852,7 @@ def __init__( | |
self.id: int = id | ||
self.node_id: int = node_id | ||
self.name: str = name | ||
self.trace_name: str = trace_name if trace_name is not None else self.name | ||
self.trace_name: str = trace_name | ||
self.time_range: Interval = Interval(start_us, end_us) | ||
self.thread: int = thread | ||
self.fwd_thread: Optional[int] = fwd_thread | ||
|
@@ -1123,12 +1125,14 @@ def parse_kineto_results(result): | |
# save memory allocation records | ||
start_record = None | ||
mem_records = [] | ||
covered_mem_records = [] | ||
for record in itertools.chain(*result.legacy_events()): | ||
if record.kind() == 'mark' and record.name() == '__start_profile': | ||
assert start_record is None | ||
start_record = record | ||
if record.kind() == 'memory_alloc': | ||
mem_records.append(record) | ||
covered_mem_records.append(False) | ||
assert start_record is not None, "Invalid profiler output, __start_profile is missing" | ||
|
||
# Create and return FunctionEvent list | ||
|
@@ -1145,11 +1149,14 @@ def parse_kineto_results(result): | |
cuda_memory_usage = 0 | ||
if kineto_event.device_type() == DeviceType.CPU: | ||
# find the corresponding memory allocation events | ||
for mem_record in mem_records: | ||
for mem_record_idx in range(len(mem_records)): | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. nit: |
||
mem_record = mem_records[mem_record_idx] | ||
if (mem_record.start_us() >= kineto_event.start_us() and | ||
mem_record.start_us() <= abs_end_us): | ||
cpu_memory_usage += mem_record.cpu_memory_usage() | ||
cuda_memory_usage += mem_record.cuda_memory_usage() | ||
covered_mem_records[mem_record_idx] = True | ||
|
||
is_async = kineto_event.start_thread_id() != kineto_event.end_thread_id() | ||
fe = FunctionEvent( | ||
id=kineto_event.correlation_id(), | ||
|
@@ -1188,6 +1195,30 @@ def parse_kineto_results(result): | |
k_evt.start_us(), | ||
k_evt.start_us() + k_evt.duration_us()) | ||
|
||
# output top-level memory events | ||
for mem_record_idx in range(len(mem_records)): | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. with the previous proposal this would become:
|
||
if not covered_mem_records[mem_record_idx]: | ||
mem_record = mem_records[mem_record_idx] | ||
fe = FunctionEvent( | ||
id=mem_record.handle(), | ||
name="[memory]", | ||
trace_name=None, # not outputting in the trace | ||
thread=mem_record.thread_id(), | ||
start_us=mem_record.start_us(), | ||
end_us=mem_record.start_us(), # no duration | ||
fwd_thread=mem_record.fwd_thread_id(), | ||
input_shapes=[], | ||
stack=[], | ||
scope=mem_record.scope(), | ||
cpu_memory_usage=mem_record.cpu_memory_usage(), | ||
cuda_memory_usage=mem_record.cuda_memory_usage(), | ||
is_async=False, | ||
sequence_nr=-1, | ||
device_type=DeviceType.CPU, | ||
device_index=0, | ||
) | ||
function_events.append(fe) | ||
|
||
function_events.sort(key=lambda evt: [evt.time_range.start, -evt.time_range.end]) | ||
return function_events | ||
|
||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
should be
=enable
here?