diff --git a/aten/src/ATen/record_function.cpp b/aten/src/ATen/record_function.cpp
index b640032bcd32..2e4a39b2e8b1 100644
--- a/aten/src/ATen/record_function.cpp
+++ b/aten/src/ATen/record_function.cpp
@@ -10,13 +10,13 @@ namespace {
 // Used to generate unique callback handles
 CallbackHandle next_unique_callback_handle() {
-  static std::atomic<uint64_t> unique_cb_id {0};
-  return CallbackHandle(++unique_cb_id);
+  static std::atomic<uint64_t> unique_cb_id {1};
+  return CallbackHandle(unique_cb_id++);
 }
 
 RecordFunctionHandle next_unique_record_function_handle() {
-  static std::atomic<uint64_t> unique_rf_id {0};
-  return RecordFunctionHandle(++unique_rf_id);
+  static std::atomic<uint64_t> unique_rf_id {1};
+  return RecordFunctionHandle(unique_rf_id++);
 }
 
 thread_local RecordFunctionTLS rf_tls_;
diff --git a/benchmarks/profiler_benchmark/profiler_bench.py b/benchmarks/profiler_benchmark/profiler_bench.py
index 6b187b03522e..75cd490fed2e 100644
--- a/benchmarks/profiler_benchmark/profiler_bench.py
+++ b/benchmarks/profiler_benchmark/profiler_bench.py
@@ -1,10 +1,9 @@
 import argparse
-import statistics
 import sys
 import timeit
 import torch
 
-from torch.utils._benchmark import Timer
+from torch.utils.benchmark import Timer
 
 PARALLEL_TASKS_NUM = 4
 INTERNAL_ITER = None
@@ -34,12 +33,12 @@ def parallel_task(x):
     parser.add_argument('--with_cuda', action='store_true')
     parser.add_argument('--with_stack', action='store_true')
     parser.add_argument('--use_script', action='store_true')
+    parser.add_argument('--use_kineto', action='store_true')
     parser.add_argument('--profiling_tensor_size', default=1, type=int)
     parser.add_argument('--workload', default='loop', type=str)
     parser.add_argument('--internal_iter', default=256, type=int)
-    parser.add_argument('--n', default=100, type=int)
-    parser.add_argument('--use_timer', action='store_true')
-    parser.add_argument('--timer_min_run_time', default=100, type=int)
+    parser.add_argument('--timer_min_run_time', default=10, type=int)
+    parser.add_argument('--cuda_only', action='store_true')
 
     args = parser.parse_args()
 
@@ -47,16 +46,17 @@ def parallel_task(x):
         print("No CUDA available")
         sys.exit()
 
-    print("Payload: {}; {} iterations, N = {}\n".format(
-        args.workload, args.internal_iter, args.n))
+    print("Payload: {}, {} iterations; timer min. runtime = {}\n".format(
+        args.workload, args.internal_iter, args.timer_min_run_time))
     INTERNAL_ITER = args.internal_iter
 
     for profiling_enabled in [False, True]:
-        print("Profiling {}, tensor size {}x{}, use cuda: {}, with stacks: {}, use script: {}".format(
+        print("Profiling {}, tensor size {}x{}, use cuda: {}, use kineto: {}, with stacks: {}, use script: {}".format(
            "enabled" if profiling_enabled else "disabled",
            args.profiling_tensor_size,
            args.profiling_tensor_size,
            args.with_cuda,
+           args.use_kineto,
            args.with_stack,
            args.use_script))
 
@@ -83,27 +83,18 @@ def payload():
                 x = None
                 with torch.autograd.profiler.profile(
                         use_cuda=args.with_cuda,
-                        with_stack=args.with_stack) as prof:
+                        with_stack=args.with_stack,
+                        use_kineto=args.use_kineto,
+                        use_cpu=not args.cuda_only) as prof:
                     x = workload(input_x)
                 return x
         else:
            def payload():
                return workload(input_x)
 
-        if args.use_timer:
-            t = Timer(
-                "payload()",
-                globals={"payload": payload},
-                timer=timeit.default_timer,
-            ).blocked_autorange(min_run_time=args.timer_min_run_time)
-            print(t)
-        else:
-            runtimes = timeit.repeat(payload, repeat=args.n, number=1)
-            avg_time = statistics.mean(runtimes) * 1000.0
-            stddev_time = statistics.stdev(runtimes) * 1000.0
-            print("\tavg. 
time: {:.3f} ms, stddev: {:.3f} ms".format( - avg_time, stddev_time)) - if args.workload == "loop": - print("\ttime per iteration: {:.3f} ms".format( - avg_time / args.internal_iter)) - print() + t = Timer( + "payload()", + globals={"payload": payload}, + timer=timeit.default_timer, + ).blocked_autorange(min_run_time=args.timer_min_run_time) + print(t) diff --git a/cmake/Dependencies.cmake b/cmake/Dependencies.cmake index f210fd62ed08..42fc53700fc2 100644 --- a/cmake/Dependencies.cmake +++ b/cmake/Dependencies.cmake @@ -1751,7 +1751,8 @@ endif() # # End ATen checks # - +set(TEMP_BUILD_SHARED_LIBS ${BUILD_SHARED_LIBS}) +set(BUILD_SHARED_LIBS OFF CACHE BOOL "Build shared libs" FORCE) add_subdirectory(${PROJECT_SOURCE_DIR}/third_party/fmt) # Disable compiler feature checks for `fmt`. @@ -1764,6 +1765,7 @@ add_subdirectory(${PROJECT_SOURCE_DIR}/third_party/fmt) set_target_properties(fmt-header-only PROPERTIES INTERFACE_COMPILE_FEATURES "") list(APPEND Caffe2_DEPENDENCY_LIBS fmt::fmt-header-only) +set(BUILD_SHARED_LIBS ${TEMP_BUILD_SHARED_LIBS} CACHE BOOL "Build shared libs" FORCE) # ---[ Kineto if(USE_KINETO) @@ -1774,8 +1776,34 @@ if(USE_KINETO) set(KINETO_LIBRARY_TYPE "static" CACHE STRING "") set(CUDA_SOURCE_DIR "${CUDA_TOOLKIT_ROOT_DIR}" CACHE STRING "") + message(STATUS "Configuring Kineto dependency:") + message(STATUS " KINETO_SOURCE_DIR = ${KINETO_SOURCE_DIR}") + message(STATUS " KINETO_BUILD_TESTS = ${KINETO_BUILD_TESTS}") + message(STATUS " KINETO_LIBRARY_TYPE = ${KINETO_LIBRARY_TYPE}") + message(STATUS " CUDA_SOURCE_DIR = ${CUDA_SOURCE_DIR}") + + if(EXISTS ${CUDA_SOURCE_DIR}/extras/CUPTI/include) + set(CUPTI_INCLUDE_DIR "${CUDA_SOURCE_DIR}/extras/CUPTI/include") + elseif(EXISTS ${CUDA_SOURCE_DIR}/include/cupti.h) + set(CUPTI_INCLUDE_DIR "${CUDA_SOURCE_DIR}/include") + endif() + + if((NOT DEFINED CUDA_cupti_LIBRARY) OR (${CUDA_cupti_LIBRARY} STREQUAL "CUDA_cupti_LIBRARY-NOTFOUND")) + if(EXISTS ${CUDA_SOURCE_DIR}/extras/CUPTI/lib64/libcupti_static.a) + set(CUDA_cupti_LIBRARY "${CUDA_SOURCE_DIR}/extras/CUPTI/lib64/libcupti_static.a") + elseif(EXISTS ${CUDA_SOURCE_DIR}/lib64/libcupti_static.a) + set(CUDA_cupti_LIBRARY "${CUDA_SOURCE_DIR}/lib64/libcupti_static.a") + elseif(EXISTS ${CUDA_SOURCE_DIR}/extras/CUPTI/lib64/libcupti.so) + set(CUDA_cupti_LIBRARY "${CUDA_SOURCE_DIR}/extras/CUPTI/lib64/libcupti.so") + elseif(EXISTS ${CUDA_SOURCE_DIR}/lib64/libcupti.so) + set(CUDA_cupti_LIBRARY "${CUDA_SOURCE_DIR}/lib64/libcupti.so") + endif() + endif() + message(STATUS " CUDA_cupti_LIBRARY = ${CUDA_cupti_LIBRARY}") + message(STATUS " CUPTI_INCLUDE_DIR = ${CUPTI_INCLUDE_DIR}") + add_subdirectory("${KINETO_SOURCE_DIR}") - message(STATUS "Configured libkineto as a dependency.") + message(STATUS "Configured Kineto as a dependency.") endif() list(APPEND Caffe2_DEPENDENCY_LIBS kineto) diff --git a/test/cpp/jit/test_misc.cpp b/test/cpp/jit/test_misc.cpp index d677b12b349c..f5201532f438 100644 --- a/test/cpp/jit/test_misc.cpp +++ b/test/cpp/jit/test_misc.cpp @@ -2163,7 +2163,7 @@ TEST(TLSFutureCallbacksTest, Basic) { // test running callbacks with propagation of TLS state. 
{ // Enable the profiler in this thread
-    torch::autograd::profiler::enableProfiler(
+    torch::autograd::profiler::enableProfilerLegacy(
         torch::autograd::profiler::ProfilerConfig(
             torch::autograd::profiler::ProfilerState::CPU, false, false));
     auto s1 = c10::make_intrusive<ivalue::Future>(IntType::get());
@@ -2172,12 +2172,12 @@ TEST(TLSFutureCallbacksTest, Basic) {
     // Since we join here, we can ensure that all callbacks corresponding to
     // markCompleted() have finished.
     t.join();
-    torch::autograd::profiler::disableProfiler();
+    torch::autograd::profiler::disableProfilerLegacy();
   }
   // then() with TLS State
   { // Enable the profiler in this thread
-    torch::autograd::profiler::enableProfiler(
+    torch::autograd::profiler::enableProfilerLegacy(
         torch::autograd::profiler::ProfilerConfig(
             torch::autograd::profiler::ProfilerState::CPU, false, false));
     auto s1 = c10::make_intrusive<ivalue::Future>(IntType::get());
@@ -2190,7 +2190,7 @@ TEST(TLSFutureCallbacksTest, Basic) {
     std::thread t([s1 = std::move(s1)]() { s1->markCompleted(); });
     t.join();
     s2->wait();
-    torch::autograd::profiler::disableProfiler();
+    torch::autograd::profiler::disableProfilerLegacy();
   }
 }
 
@@ -2199,7 +2199,7 @@ TEST(ProfilerDisableInCallbackTest, Basic) {
   auto profilerEnabledCb = []() {
     ASSERT_TRUE(torch::autograd::profiler::profilerEnabled());
   };
-  torch::autograd::profiler::enableProfiler(
+  torch::autograd::profiler::enableProfilerLegacy(
       torch::autograd::profiler::ProfilerConfig(
           torch::autograd::profiler::ProfilerState::CPU, false, false));
   auto s1 = c10::make_intrusive<ivalue::Future>(IntType::get());
@@ -2212,10 +2212,10 @@ TEST(ProfilerDisableInCallbackTest, Basic) {
     // Don't cleanup TLSState, and just consolidate.
     auto opts = torch::autograd::profiler::ProfilerDisableOptions(false, true);
     auto thread_event_lists =
-        torch::autograd::profiler::disableProfiler(std::move(opts));
+        torch::autograd::profiler::disableProfilerLegacy(std::move(opts));
     // Ensure that the events from this thread are still profiled and we obtain
     // the expected events in our consolidated list when calling
-    // disableProfiler().
+    // disableProfilerLegacy().
     bool found_ones = false;
     bool found_add = false;
     for (const auto& li : thread_event_lists) {
@@ -2237,13 +2237,13 @@ TEST(ProfilerDisableInCallbackTest, Basic) {
   s1->addCallback(verifyProfilerCb);
   // Disable the profiler, but do not consolidate results in the main thread.
   auto opts = torch::autograd::profiler::ProfilerDisableOptions(true, false);
-  torch::autograd::profiler::disableProfiler(std::move(opts));
+  torch::autograd::profiler::disableProfilerLegacy(std::move(opts));
   std::thread t([s1 = std::move(s1)]() { s1->markCompleted(at::IValue(1)); });
   t.join();
 
   // Similar to above test, but verifies correctness in the case where
   // continuation runs on the main thread.
-  torch::autograd::profiler::enableProfiler(
+  torch::autograd::profiler::enableProfilerLegacy(
      torch::autograd::profiler::ProfilerConfig(
          torch::autograd::profiler::ProfilerState::CPU, false, false));
   s1 = c10::make_intrusive<ivalue::Future>(IntType::get());
@@ -2251,7 +2251,7 @@ TEST(ProfilerDisableInCallbackTest, Basic) {
   // Runs callback inline
   s1->markCompleted(at::IValue(1));
   opts = torch::autograd::profiler::ProfilerDisableOptions(true, false);
-  torch::autograd::profiler::disableProfiler(std::move(opts));
+  torch::autograd::profiler::disableProfilerLegacy(std::move(opts));
 }
 
 TEST(IValueKWargsTest, Basic) {
diff --git a/test/test_autograd.py b/test/test_autograd.py
index 457a2234524a..8d4af26b72c8 100644
--- a/test/test_autograd.py
+++ b/test/test_autograd.py
@@ -33,7 +33,7 @@
     suppress_warnings, slowTest,
     load_tests, random_symmetric_matrix,
     IS_WINDOWS, IS_MACOS, CudaMemoryLeakCheck)
-from torch.autograd import Variable, Function, detect_anomaly
+from torch.autograd import Variable, Function, detect_anomaly, kineto_available
 from torch.autograd.function import InplaceFunction
 from torch.testing import randn_like
 from torch.testing._internal.common_methods_invocations import (method_tests,
@@ -2954,7 +2954,7 @@ def gen_matrices(p):
                          https://github.com/pytorch/pytorch/issues/34086""")
     def test_profiler_tracing(self):
         t1, t2 = torch.ones(1), torch.ones(1)
-        with torch.autograd.profiler.profile() as prof:
+        with torch.autograd.profiler.profile(use_kineto=kineto_available()) as prof:
             torch.add(t1, t2)
 
         with tempfile.NamedTemporaryFile(mode="w+") as f:
@@ -2969,7 +2969,7 @@ def test_profiler_tracing(self):
         device = torch.device("cuda:0")
         t1, t2 = torch.ones(1, device=device), torch.ones(1, device=device)
-        with torch.autograd.profiler.profile(use_cuda=True) as prof:
+        with torch.autograd.profiler.profile(use_cuda=True, use_kineto=kineto_available()) as prof:
             torch.add(t1, t2)
 
         with tempfile.NamedTemporaryFile(mode="w+") as f:
@@ -2980,7 +2980,7 @@ def test_profiler_tracing(self):
 
     def test_profiler(self):
         x = torch.randn(10, 10)
-        with profile() as p:
+        with profile(use_kineto=kineto_available()) as p:
             self.assertTrue(torch.autograd._profiler_enabled())
             y = x * 2 + 4
@@ -2991,22 +2991,21 @@ def test_profiler(self):
                  'aten::empty', 'aten::add', 'aten::to', 'aten::empty_strided',
                  'aten::copy_', 'aten::empty']
         top_level_names = ['aten::mul', 'aten::add']
-        top_level_iter = iter(top_level_names)
-        self.assertEqual(len(p.function_events), len(names))
-        for info, expected_name in zip(p.function_events, names):
-            if info.cpu_interval.start > last_end:
-                top_level_name_expected = next(top_level_iter)
-                self.assertEqual(info.name, top_level_name_expected)
-                last_end = info.cpu_interval.end
-            self.assertEqual(info.name, expected_name)
+        for evt in p.function_events:
+            if evt.time_range.start > last_end:
+                self.assertTrue(evt.name in top_level_names)
+                last_end = evt.time_range.end
+            self.assertTrue(evt.name in names)
 
     def test_profiler_seq_nr(self):
-        with profile() as p:
+        with profile(use_kineto=kineto_available()) as p:
             x = torch.randn(10, 10, requires_grad=True)
             y = torch.randn(10, 10, requires_grad=True)
             z = x + y
             s = z.sum()
             s.backward()
+        print(p.key_averages().table(
+            sort_by="self_cpu_time_total", row_limit=-1))
         # expecting aten::add, aten::sum to have the sequence numbers,
         # expecting the corresponding backward nodes to have the same numbers
         # as the forward ops
@@ -3049,7 +3048,7 @@ def test_profiler_seq_nr(self):
 
     def test_profiler_unboxed_only(self):
         x = torch.rand(3, 4)
 
-        with 
torch.autograd.profiler.profile() as prof: + with torch.autograd.profiler.profile(use_kineto=kineto_available()) as prof: x.resize_([3, 2]) def test_profiler_propagation(self): @@ -3074,7 +3073,7 @@ def bar(x): traced_bar = torch.jit.trace(bar, x) - with profile() as p: + with profile(use_kineto=kineto_available()) as p: traced_bar(x) found_foo = False @@ -3096,7 +3095,7 @@ def bar(x): def test_record_function_callbacks(self): x = torch.randn(10, 10) - with profile() as p: + with profile(use_kineto=kineto_available()) as p: with record_function("foo"): y = x * 2 + 4 @@ -3128,12 +3127,12 @@ def get_id(): node_id=0, name="", thread=thread, - cpu_start=range[0], - cpu_end=range[1], + start_us=range[0], + end_us=range[1], ) ) - events.populate_cpu_children() + events._populate_cpu_children() # Note that [1, 3] pushes out [0, 2] first. Then we record [1, 2] # as a child of [1, 3] @@ -3152,7 +3151,7 @@ def test_profiler_aggregation_table(self): """ x = torch.randn(1024) - with torch.autograd.profiler.profile() as prof: + with torch.autograd.profiler.profile(use_kineto=kineto_available()) as prof: torch.einsum("i->", x) prof_str = str(prof) @@ -3162,8 +3161,8 @@ def test_profiler_aggregation_table(self): def test_profiler_function_event_avg(self): avg = FunctionEventAvg() - avg.add(FunctionEvent(id=0, node_id=0, name="foo", thread=0, cpu_start=10, cpu_end=15)) - avg.add(FunctionEvent(id=1, node_id=0, name="foo", thread=0, cpu_start=20, cpu_end=30)) + avg.add(FunctionEvent(id=0, node_id=0, name="foo", thread=0, start_us=10, end_us=15)) + avg.add(FunctionEvent(id=1, node_id=0, name="foo", thread=0, start_us=20, end_us=30)) avg.add(avg) self.assertEqual(avg.key, "foo") @@ -3182,7 +3181,7 @@ def test_profiler_shapes(self): layer1 = torch.nn.Linear(20, 30) layer2 = torch.nn.Linear(30, 40) input = torch.randn(128, 20) - with profile(record_shapes=True) as prof: + with profile(record_shapes=True, use_kineto=kineto_available()) as prof: layer2(layer1(input)) print(prof.function_events) @@ -3198,18 +3197,18 @@ def test_profiler_shapes(self): last_end = 0 for event in prof.function_events: - if event.cpu_interval.start > last_end: + if event.time_range.start > last_end: name_expected, input_shape_expected = next(expected_iter) if name_expected is not None: self.assertEqual(event.name, name_expected) self.assertEqual(event.input_shapes, input_shape_expected) - last_end = event.cpu_interval.end + last_end = event.time_range.end def test_profiler_no_cuda(self): print("") layer = torch.nn.Linear(20, 30) x = torch.randn(128, 20) - with profile(use_cuda=False) as prof: + with profile(use_cuda=False, use_kineto=kineto_available()) as prof: layer(x) prof_str = str(prof) @@ -3221,7 +3220,7 @@ def test_profiler_aggregation_lstm(self): print("") rnn = torch.nn.LSTM(10, 20, 2) total_time_s = 0 - with profile(record_shapes=True) as prof: + with profile(record_shapes=True, use_kineto=kineto_available()) as prof: for i in range(20): input = torch.randn(5, 3, 10) h = torch.randn(2, 3, 20) @@ -3258,7 +3257,7 @@ def test_profiler_aggregation_lstm(self): def test_memory_profiler(self): def run_profiler(tensor_creation_fn, metric): # collecting allocs / deallocs - with profile(profile_memory=True, record_shapes=True) as prof: + with profile(profile_memory=True, record_shapes=True, use_kineto=kineto_available()) as prof: x = None with record_function("test_user_scope_alloc"): x = tensor_creation_fn() @@ -3350,7 +3349,7 @@ def create_mkldnn_tensor(): # check partial overlap of tensor allocation with memory profiler x = 
torch.rand(10, 10) - with profile(profile_memory=True, record_shapes=True) as prof: + with profile(profile_memory=True, record_shapes=True, use_kineto=kineto_available()) as prof: del x x = torch.rand(10, 10) del x @@ -3376,7 +3375,7 @@ def forward(x): forward(x) - with profile() as p: + with profile(use_kineto=kineto_available()) as p: forward(x) events = p.function_events @@ -3401,7 +3400,7 @@ def forward(x): def f(x, y): return x + y - with profile() as p: + with profile(use_kineto=kineto_available()) as p: f(1, 2) self.assertTrue('my_func' in str(p)) diff --git a/test/test_jit.py b/test/test_jit.py index 3daf9026f889..b2acf7b7040a 100644 --- a/test/test_jit.py +++ b/test/test_jit.py @@ -44,6 +44,7 @@ from torch.autograd import Variable from torch.jit.annotations import BroadcastingList2, BroadcastingList3, Any # noqa: F401 from torch.testing import FileCheck +import torch.autograd.profiler import torch.cuda import torch.jit import torch.jit._logging @@ -2552,10 +2553,10 @@ def fn(x): for e in prof.function_events: if e.name == "aten::mul": self.assertTrue(e.thread not in mul_events) - mul_events[e.thread] = e.cpu_interval.elapsed_us() + mul_events[e.thread] = e.time_range.elapsed_us() elif e.name == "other_fn": self.assertTrue(e.thread not in other_fn_events) - other_fn_events[e.thread] = e.cpu_interval.elapsed_us() + other_fn_events[e.thread] = e.time_range.elapsed_us() self.assertTrue(len(mul_events) == 2) self.assertTrue(len(other_fn_events) == 2) @@ -8268,7 +8269,7 @@ def _dtype_to_expect(self, dtype, dim=0): def _test_dtype_op_shape(self, ops, args, input_dims=1): if input_dims < 1: - raise 'input dims must be at least 1' + raise RuntimeError("input dims must be at least 1") dtypes = [torch.float32, torch.float64, torch.int64, torch.int32] str_args = ', '.join([str(arg) for arg in args]) + (', ' if len(args) else '') tensor_data = ('[' * input_dims) + '1, 2, 3' + (input_dims * ']') diff --git a/test/test_profiler.py b/test/test_profiler.py index f1feff1d0af3..797ad0995913 100644 --- a/test/test_profiler.py +++ b/test/test_profiler.py @@ -7,6 +7,7 @@ from torch.testing._internal.common_utils import ( TestCase, run_tests, TEST_WITH_ASAN, IS_WINDOWS) from torch.autograd.profiler import profile +from torch.autograd import kineto_available try: import psutil @@ -73,7 +74,7 @@ def forward(self, x): mod = DummyModule() - with profile(with_stack=True) as p: + with profile(with_stack=True, use_kineto=kineto_available()) as p: x = torch.randn(10, 10, requires_grad=True) y = torch.randn(10, 10, requires_grad=True) z = x + y @@ -99,6 +100,34 @@ def forward(self, x): torch._C._set_graph_executor_optimize(prev_opt) + def payload(self): + x = torch.randn(10, 10).cuda() + y = torch.randn(10, 10).cuda() + z = torch.mm(x, y) + z = z + y + z = z.cpu() + + @unittest.skipIf(not kineto_available(), "Kineto is required") + @unittest.skipIf(not torch.cuda.is_available(), "CUDA is required") + def test_kineto(self): + with profile(use_cuda=True, use_kineto=True): + self.payload() + + # rerun to avoid initial start overhead + with profile(use_cuda=True, use_kineto=True) as p: + self.payload() + print(p.key_averages().table( + sort_by="self_cuda_time_total", row_limit=-1)) + found_gemm = False + found_memcpy = False + for e in p.function_events: + if "gemm" in e.name: + found_gemm = True + if "Memcpy" in e.name or "memcpy" in e.name: + found_memcpy = True + self.assertTrue(found_gemm) + self.assertTrue(found_memcpy) + # p.export_chrome_trace("/tmp/test_trace.json") if __name__ == '__main__': run_tests() 
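
The new `test_kineto` case above doubles as usage documentation for the Kineto path. A minimal standalone sketch of the same pattern follows (illustrative only, not part of the PR; it assumes a CUDA-enabled PyTorch build with Kineto support, i.e. `USE_KINETO=1`):

```python
# Illustrative sketch mirroring test_kineto above; assumes a CUDA-enabled
# PyTorch build with Kineto support (USE_KINETO=1).
import torch
from torch.autograd import kineto_available
from torch.autograd.profiler import profile

def payload():
    x = torch.randn(10, 10).cuda()
    y = torch.randn(10, 10).cuda()
    z = torch.mm(x, y)   # GEMM kernel, surfaces as a CUDA device event
    z = (z + y).cpu()    # device-to-host copy, surfaces as a Memcpy event

if kineto_available() and torch.cuda.is_available():
    # First run absorbs one-time CUDA/profiler initialization overhead.
    with profile(use_cuda=True, use_kineto=True):
        payload()
    with profile(use_cuda=True, use_kineto=True) as p:
        payload()
    print(p.key_averages().table(sort_by="self_cuda_time_total", row_limit=-1))
```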
diff --git a/tools/build_variables.bzl b/tools/build_variables.bzl index 5d6146cf9268..6181239bd5b3 100644 --- a/tools/build_variables.bzl +++ b/tools/build_variables.bzl @@ -74,7 +74,8 @@ jit_core_sources = [ # list for the shared files. core_sources_common = [ - "torch/csrc/autograd/profiler.cpp", + "torch/csrc/autograd/profiler_legacy.cpp", + "torch/csrc/autograd/profiler_kineto.cpp", "torch/csrc/jit/frontend/edit_distance.cpp", "torch/csrc/jit/frontend/string_to_type.cpp", "torch/csrc/jit/mobile/type_parser.cpp", diff --git a/torch/_C/_autograd.pyi b/torch/_C/_autograd.pyi index 926457fe80ee..db37db44d879 100644 --- a/torch/_C/_autograd.pyi +++ b/torch/_C/_autograd.pyi @@ -1,4 +1,4 @@ -from typing import List +from typing import List, Set from enum import Enum # Defined in tools/autograd/init.cpp @@ -8,7 +8,16 @@ class ProfilerState(Enum): CPU = ... CUDA = ... NVTX = ... + KINETO = ... +class ProfilerActivity(Enum): + CPU = ... + CUDA = ... + +class DeviceType(Enum): + CPU = ... + CUDA = ... + ... class ProfilerConfig: def __init__( @@ -37,9 +46,25 @@ class ProfilerEvent: def thread_id(self) -> int: ... ... +class KinetoEvent: + def name(self) -> str: ... + def device_index(self) -> int: ... + def start_us(self) -> int: ... + def duration_us(self) -> int: ... + ... -def _enable_profiler(config: ProfilerConfig) -> None: ... -def _disable_profiler() -> List[List[ProfilerEvent]]: ... +class ProfilerResult: + def events(self) -> List[KinetoEvent]: ... + def legacy_events(self) -> List[List[ProfilerEvent]]: ... + def save(self, str) -> None: ... + +def _enable_profiler(config: ProfilerConfig, activities: Set[ProfilerActivity]) -> None: ... +def _prepare_profiler(config: ProfilerConfig, activities: Set[ProfilerActivity]) -> None: ... +def _disable_profiler() -> ProfilerResult: ... def _profiler_enabled() -> bool: ... +def kineto_available() -> bool: ... def _enable_record_function(enable: bool) -> None: ... def _set_empty_test_observer(is_global: bool, sampling_prob: float) -> None: ... + +def _enable_profiler_legacy(config: ProfilerConfig) -> None: ... +def _disable_profiler_legacy() -> List[List[ProfilerEvent]]: ... diff --git a/torch/autograd/__init__.py b/torch/autograd/__init__.py index eabb07fd9de0..71537c562013 100644 --- a/torch/autograd/__init__.py +++ b/torch/autograd/__init__.py @@ -18,7 +18,6 @@ from .grad_mode import no_grad, enable_grad, set_grad_enabled from .anomaly_mode import detect_anomaly, set_detect_anomaly from ..overrides import has_torch_function, handle_torch_function -from . import profiler from . 
import functional
 
 __all__ = ['Variable', 'Function', 'backward', 'grad_mode']
@@ -251,6 +250,10 @@ def variable(*args, **kwargs):
     raise RuntimeError("autograd initialization failed")
 
 # Import all native method/classes
-from torch._C._autograd import (ProfilerState, ProfilerConfig, ProfilerEvent,
-                                _enable_profiler, _disable_profiler, _profiler_enabled,
-                                _enable_record_function, _set_empty_test_observer)
+from torch._C._autograd import (DeviceType, ProfilerActivity, ProfilerState, ProfilerConfig, ProfilerEvent,
+                                _enable_profiler_legacy, _disable_profiler_legacy, _profiler_enabled,
+                                _enable_record_function, _set_empty_test_observer, kineto_available)
+
+if kineto_available():
+    from torch._C._autograd import (ProfilerResult, KinetoEvent,
+                                    _prepare_profiler, _enable_profiler, _disable_profiler)
diff --git a/torch/autograd/profiler.py b/torch/autograd/profiler.py
index ba7d44814421..b8ee67101393 100644
--- a/torch/autograd/profiler.py
+++ b/torch/autograd/profiler.py
@@ -1,12 +1,13 @@
 import itertools
 from typing import Any
 import torch
+from torch.autograd import DeviceType
 from torch.futures import Future
 
 from collections import defaultdict, namedtuple
 from operator import attrgetter
-from typing import List, Dict, Tuple, Optional
+from typing import Dict, List, Tuple, Optional
 
 try:
     # Available in Python >= 3.2
@@ -37,14 +38,38 @@ def __init__(self, *args, **kwargs):
         use_cuda = kwargs.pop('use_cuda', True)
         profile_memory = kwargs.pop('profile_memory', False)
         super(EventList, self).__init__(*args, **kwargs)
-        self._cpu_children_populated = False
         self._use_cuda = use_cuda
         self._profile_memory = profile_memory
+        self._tree_built = False
+
+    def _build_tree(self):
+        self._populate_cpu_children()
+        self._remove_dup_nodes()
+        self._set_backward_stacktraces()
+        self._tree_built = True
 
     def __str__(self):
         return self.table()
 
-    def populate_cpu_children(self):
+    def _remove_dup_nodes(self):
+        while True:
+            to_delete = []
+            for idx in range(len(self)):
+                if (self[idx].cpu_parent is not None and
+                        self[idx].cpu_parent.name == self[idx].name and
+                        len(self[idx].cpu_parent.cpu_children) == 1):
+                    self[idx].cpu_parent.cpu_children = self[idx].cpu_children
+                    self[idx].cpu_parent.kernels = self[idx].kernels  # lift kernels up
+                    for ch in self[idx].cpu_children:
+                        ch.cpu_parent = self[idx].cpu_parent
+                    to_delete.append(idx)
+            if len(to_delete) == 0:
+                break
+            new_evts = [ev for ind, ev in enumerate(self) if ind not in to_delete]
+            self.clear()
+            self.extend(new_evts)
+
+    def _populate_cpu_children(self):
         """Populates child events into each underlying FunctionEvent object.
         One event is a child of another if [s1, e1) is inside [s2, e2). Where
         s1 and e1 would be start and end of the child event's interval. And
@@ -56,13 +81,11 @@ def populate_cpu_children(self):
         If for any reason two intervals intersect only partially, this function
         will not record a parent child relationship between them.
         """
-        if self.cpu_children_populated:
-            return
         # Some events can be async (i.e. 
start and end on different threads), # since it's generally undefined how to attribute children ranges to # async ranges, we do not use them when calculating nested ranges and stats - sync_events = [evt for evt in self if not evt.is_async] + sync_events = [evt for evt in self if not evt.is_async and evt.device_type == DeviceType.CPU] events = sorted( sync_events, key=attrgetter("thread"), @@ -89,15 +112,15 @@ def populate_cpu_children(self): for thread_id, thread_events in threads: thread_events_ = sorted( thread_events, - key=lambda event: [event.cpu_interval.start, -event.cpu_interval.end], + key=lambda event: [event.time_range.start, -event.time_range.end], ) current_events: List[FunctionEvent] = [] cur_end = 0 for event in thread_events_: while len(current_events) > 0: parent = current_events[-1] - if event.cpu_interval.start >= parent.cpu_interval.end or \ - event.cpu_interval.end > parent.cpu_interval.end: + if event.time_range.start >= parent.time_range.end or \ + event.time_range.end > parent.time_range.end: # this can't be a parent current_events.pop() else: @@ -112,22 +135,18 @@ def populate_cpu_children(self): current_events.append(event) - self._cpu_children_populated = True - - def set_backward_stacktraces(self): - self.populate_cpu_children() - + def _set_backward_stacktraces(self): def bw_parent(evt): if evt is None: return None - elif evt.scope == 1: + elif evt.scope == 1: # BACKWARD_FUNCTION return evt else: return bw_parent(evt.cpu_parent) fwd_stacks = {} for evt in self: - if bw_parent(evt) is None: + if bw_parent(evt) is None and evt.stack is not None: t = (evt.sequence_nr, evt.thread) if t not in fwd_stacks: fwd_stacks[t] = evt.stack @@ -142,15 +161,10 @@ def bw_parent(evt): else: evt.stack = [] - @property def self_cpu_time_total(self): return sum([event.self_cpu_time_total for event in self]) - @property - def cpu_children_populated(self): - return self._cpu_children_populated - def table(self, sort_by=None, row_limit=100, max_src_column_width=75, header=None, top_level_events_only=False): """Prints an EventList as a nicely formatted table. @@ -205,8 +219,8 @@ def export_chrome_trace(self, path): '"args": {}}, ' % ( evt.name, - evt.cpu_interval.start, - evt.cpu_interval.elapsed_us(), + evt.time_range.start, + evt.time_range.elapsed_us(), evt.thread if not evt.is_remote else f'" node_id:{evt.node_id}, thread_id:{evt.thread} "', @@ -222,7 +236,7 @@ def export_chrome_trace(self, path): '"pid": "CPU functions", ' '"id": %s, ' '"cat": "cpu_to_cuda", ' - '"args": {}}, ' % (evt.name, evt.cpu_interval.start, + '"args": {}}, ' % (evt.name, evt.time_range.start, evt.thread, next_id)) f.write('{"name": "%s", ' '"ph": "f", ' @@ -262,11 +276,11 @@ def key_averages(self, group_by_input_shapes=False, group_by_stack_n=0): Returns: An EventList containing FunctionEventAvg objects. 
""" - self.populate_cpu_children() - stats: Dict[Tuple[int, Tuple[int, int]], FunctionEventAvg] = defaultdict(FunctionEventAvg) + assert self._tree_built + stats: Dict[Tuple[str, ...], FunctionEventAvg] = defaultdict(FunctionEventAvg) - def get_key(event, group_by_input_shapes, group_by_stack_n): - key = [str(event.key), str(event.node_id)] + def get_key(event, group_by_input_shapes, group_by_stack_n) -> Tuple[str, ...]: + key = [str(event.key), str(event.node_id), str(event.device_type), str(event.is_legacy)] if group_by_input_shapes: key.append(str(event.input_shapes)) if group_by_stack_n > 0: @@ -326,6 +340,11 @@ class profile(object): with_stack (bool, optional): record source information (file and line number) for the ops + use_kineto (bool, default False): experimental support for Kineto profiler + + use_cpu (default True) - whether to profile CPU events; setting to False requires + use_kineto=True and can be used to lower the overhead for GPU-only profiling + .. warning: Enabling memory profiling or source attribution incurs additional profiler overhead @@ -365,44 +384,83 @@ def __init__( use_cuda=False, record_shapes=False, profile_memory=False, - with_stack=False): - self.enabled = enabled - self.use_cuda = use_cuda - self.function_events = None + with_stack=False, + use_kineto=False, + use_cpu=True): + self.enabled: bool = enabled if not self.enabled: return + self.use_cuda = use_cuda + self.function_events = None self.entered = False self.record_shapes = record_shapes self.profile_memory = profile_memory self.with_stack = with_stack + self.use_cpu = use_cpu + self.kineto_results = None + if not self.use_cpu: + assert use_kineto, \ + "Device-only events supported only with Kineto (use_kineto=True)" + + self.profiler_kind = None + self.kineto_activities = set() + if use_kineto: + self.profiler_kind = torch.autograd.ProfilerState.KINETO + if self.use_cpu: + self.kineto_activities.add(torch.autograd.ProfilerActivity.CPU) + if self.use_cuda: + self.kineto_activities.add( + # uses CUPTI + torch.autograd.ProfilerActivity.CUDA) + assert len(self.kineto_activities) > 0, \ + "No activities specified for Kineto profiler" + elif self.use_cuda: + # legacy CUDA mode + self.profiler_kind = torch.autograd.ProfilerState.CUDA + else: + self.profiler_kind = torch.autograd.ProfilerState.CPU + + if self.profiler_kind == torch.autograd.ProfilerState.KINETO: + assert ( + torch.autograd.kineto_available() + ), """Requested Kineto profiling but Kineto is not available, + make sure PyTorch is built with USE_KINETO=1""" + + def config(self): + assert self.profiler_kind is not None + return torch.autograd.ProfilerConfig( + self.profiler_kind, + self.record_shapes, + self.profile_memory, + self.with_stack) def __enter__(self): if not self.enabled: return if self.entered: - raise RuntimeError("autograd profiler traces are not reentrant") + raise RuntimeError("profiler context manager is not reentrant") self.entered = True - profiler_kind = torch.autograd.ProfilerState.CUDA if self.use_cuda \ - else torch.autograd.ProfilerState.CPU - - config = torch.autograd.ProfilerConfig( - profiler_kind, - self.record_shapes, - self.profile_memory, - self.with_stack) - torch.autograd._enable_profiler(config) + if self.kineto_activities: + torch.autograd._prepare_profiler(self.config(), self.kineto_activities) + torch.autograd._enable_profiler(self.config(), self.kineto_activities) + else: + torch.autograd._enable_profiler_legacy(self.config()) return self def __exit__(self, exc_type, exc_val, exc_tb): if not self.enabled: 
return - records = torch.autograd._disable_profiler() + if self.kineto_activities: + self.kineto_results = torch.autograd._disable_profiler() + parsed_results = parse_kineto_results(self.kineto_results) + else: + records = torch.autograd._disable_profiler_legacy() + parsed_results = parse_legacy_records(records) self.function_events = EventList( - parse_event_records(records), + parsed_results, use_cuda=self.use_cuda, profile_memory=self.profile_memory) - if self.with_stack: - self.function_events.set_backward_stacktraces() + self.function_events._build_tree() return False def __repr__(self): @@ -413,13 +471,11 @@ def __repr__(self): def __str__(self): if self.function_events is None: return '' - self.function_events.populate_cpu_children() return str(self.function_events) def _check_finish(self): if self.function_events is None: raise RuntimeError("can't export a trace that didn't finish running") - self.function_events.populate_cpu_children() def table(self, sort_by=None, row_limit=100, max_src_column_width=75, header=None, top_level_events_only=False): self._check_finish() @@ -432,8 +488,11 @@ def table(self, sort_by=None, row_limit=100, max_src_column_width=75, header=Non def export_chrome_trace(self, path): self._check_finish() - assert self.function_events is not None - return self.function_events.export_chrome_trace(path) + if self.kineto_results is not None: + self.kineto_results.save(path) + else: + assert self.function_events is not None + return self.function_events.export_chrome_trace(path) export_chrome_trace.__doc__ = EventList.export_chrome_trace.__doc__ def key_averages(self, group_by_input_shape=False, group_by_stack_n=0): @@ -630,7 +689,7 @@ def __enter__(self): raise RuntimeError("NVTX annotation context manager is not reentrant") self.entered = True torch.cuda.synchronize() - torch.autograd._enable_profiler( + torch.autograd._enable_profiler_legacy( torch.autograd.ProfilerConfig( torch.autograd.ProfilerState.NVTX, self.record_shapes, @@ -643,7 +702,7 @@ def __exit__(self, exc_type, exc_val, exc_tb): if not self.enabled: return torch.cuda.synchronize() - torch.autograd._disable_profiler() + torch.autograd._disable_profiler_legacy() return False @@ -731,13 +790,14 @@ def elapsed_us(self): class FunctionEvent(FormattedTimesMixin): """Profiling information about a single function.""" def __init__( - self, id, node_id, name, thread, cpu_start, cpu_end, fwd_thread=None, input_shapes=None, + self, id, name, thread, start_us, end_us, fwd_thread=None, input_shapes=None, stack=None, scope=0, cpu_memory_usage=0, cuda_memory_usage=0, is_async=False, - is_remote=True, sequence_nr=-1): + is_remote=False, sequence_nr=-1, node_id=-1, device_type=DeviceType.CPU, device_index=0, + is_legacy=False): self.id: int = id self.node_id: int = node_id self.name: str = name - self.cpu_interval: Interval = Interval(cpu_start, cpu_end) + self.time_range: Interval = Interval(start_us, end_us) self.thread: int = thread self.fwd_thread: Optional[int] = fwd_thread self.kernels: List[Kernel] = [] @@ -752,8 +812,12 @@ def __init__( self.is_async: bool = is_async self.is_remote: bool = is_remote self.sequence_nr: int = sequence_nr + self.device_type: DeviceType = device_type + self.device_index: int = device_index + self.is_legacy: bool = is_legacy def append_kernel(self, name, device, start, end): + assert self.device_type == DeviceType.CPU self.kernels.append(Kernel(name, device, Interval(start, end))) def append_cpu_child(self, child): @@ -762,7 +826,9 @@ def append_cpu_child(self, child): One is 
supposed to append only direct children to the event to have correct self cpu time being reported. """ + assert(self.device_type == DeviceType.CPU) assert(isinstance(child, FunctionEvent)) + assert(child.device_type == DeviceType.CPU) self.cpu_children.append(child) def set_cpu_parent(self, parent): @@ -772,14 +838,16 @@ def set_cpu_parent(self, parent): the child's range interval is completely inside the parent's. We use this connection to determine the event is from top-level op or not. """ + assert(self.device_type == DeviceType.CPU) assert(isinstance(parent, FunctionEvent)) + assert(parent.device_type == DeviceType.CPU) self.cpu_parent = parent # Note: async events don't have children, are not used when computing 'self' # metrics of other events, have only total cpu time @property def self_cpu_memory_usage(self): - if self.is_async: + if self.is_async or self.device_type != DeviceType.CPU: return 0 return self.cpu_memory_usage - sum( [child.cpu_memory_usage for child in self.cpu_children] @@ -787,7 +855,7 @@ def self_cpu_memory_usage(self): @property def self_cuda_memory_usage(self): - if self.is_async: + if self.is_async or self.device_type != DeviceType.CPU: return 0 return self.cuda_memory_usage - sum( [child.cuda_memory_usage for child in self.cpu_children] @@ -795,7 +863,7 @@ def self_cuda_memory_usage(self): @property def self_cpu_time_total(self): - if self.is_async: + if self.is_async or self.device_type != DeviceType.CPU: return 0 return self.cpu_time_total - sum( [child.cpu_time_total for child in self.cpu_children] @@ -803,16 +871,37 @@ def self_cpu_time_total(self): @property def cuda_time_total(self): - return sum(kinfo.interval.elapsed_us() for kinfo in self.kernels) + if self.is_async: + return 0 + if self.device_type == DeviceType.CPU: + if not self.is_legacy: + # account for the kernels in the children ops + return (sum(kinfo.interval.elapsed_us() for kinfo in self.kernels) + + sum(ch.cuda_time_total for ch in self.cpu_children)) + else: + # each legacy cpu events has a single (fake) kernel + return sum(kinfo.interval.elapsed_us() for kinfo in self.kernels) + else: + assert self.device_type == DeviceType.CUDA + return self.time_range.elapsed_us() @property def self_cuda_time_total(self): - return sum(kinfo.interval.elapsed_us() for kinfo in self.kernels) - \ - sum([child.cuda_time_total for child in self.cpu_children]) + if self.is_async: + return 0 + if self.device_type == DeviceType.CPU: + return self.cuda_time_total - \ + sum([child.cuda_time_total for child in self.cpu_children]) + else: + assert(self.device_type == DeviceType.CUDA) + return self.cuda_time_total @property def cpu_time_total(self): - return self.cpu_interval.elapsed_us() + if self.device_type == DeviceType.CPU: + return self.time_range.elapsed_us() + else: + return 0 @property def key(self): @@ -820,14 +909,16 @@ def key(self): def __repr__(self): return ( - ''.format( + 'cpu_memory_usage={} cuda_memory_usage={} is_async={} is_remote={} seq_nr={} is_legacy={}>'.format( self.id, + self.name, + self.device_type, self.node_id, self.cpu_time_str, - self.cpu_interval.start, - self.cpu_interval.end, + self.time_range.start, + self.time_range.end, str([child.id for child in self.cpu_children]), self.cuda_time_str, self.name, @@ -838,6 +929,7 @@ def __repr__(self): self.is_async, self.is_remote, self.sequence_nr, + self.is_legacy, ) ) @@ -863,6 +955,8 @@ def __init__(self): self.self_cuda_memory_usage: int = 0 self.cpu_children: Optional[List[FunctionEvent]] = None self.cpu_parent: Optional[FunctionEvent] = 
None + self.device_type: DeviceType = DeviceType.CPU + self.is_legacy: bool = False def add(self, other): if self.key is None: @@ -878,6 +972,8 @@ def add(self, other): self.input_shapes = other.input_shapes self.stack = other.stack self.scope = other.scope + self.device_type = other.device_type + self.is_legacy = other.is_legacy assert isinstance(other, (FunctionEvent, FunctionEventAvg)) assert other.key == self.key @@ -923,10 +1019,111 @@ def __missing__(self, key): self[key] = torch._C._demangle(key) if len(key) > 1 else key return self[key] -def parse_event_records(thread_records): +def filter_stack_entry(entry): + filtered_entries = [ + ("autograd/__init__", "_make_grads"), + ("autograd/__init__", "backward"), + ("torch/tensor", "backward"), + ("_internal/common_utils", "prof_callable"), + ("_internal/common_utils", "prof_func_call"), + ("_internal/common_utils", "prof_meth_call"), + ] + return all([not (f[0] in entry and f[1] in entry) for f in filtered_entries]) + +def filter_name(name): + # ignoring the following utility ops + filtered_out_names = [ + "profiler::_record_function_enter", + "profiler::_record_function_exit", + "aten::is_leaf", + "aten::output_nr", + "aten::_version", + ] + return name in filtered_out_names + +# Parsing of kineto profiler events +def parse_kineto_results(result): + # result.events() has most of the events - PyTorch op-level and device-level events + # result.legacy_events() has events not yet ported to kineto + # (e.g. start/stop marks, tensor memory allocator events) + + # First, find __start_profile mark to get the absolute time of the start of the trace; + # save memory allocation records + start_record = None + mem_records = [] + for record in itertools.chain(*result.legacy_events()): + if record.kind() == 'mark' and record.name() == '__start_profile': + assert start_record is None + start_record = record + if record.kind() == 'memory_alloc': + mem_records.append(record) + assert start_record is not None, "Invalid profiler output, __start_profile is missing" + + # Create and return FunctionEvent list + string_table = StringTable() + function_events = [] + cuda_corr_map: Dict[int, List[torch.autograd.KinetoEvent]] = {} + for kineto_event in result.events(): + if filter_name(kineto_event.name()): + continue + rel_start_us = kineto_event.start_us() - start_record.start_us() + rel_end_us = rel_start_us + kineto_event.duration_us() + abs_end_us = kineto_event.start_us() + kineto_event.duration_us() + + cpu_memory_usage = 0 + cuda_memory_usage = 0 + if kineto_event.device_type() == DeviceType.CPU: + # find the corresponding memory allocation events + for mem_record in mem_records: + if (mem_record.start_us() >= kineto_event.start_us() and + mem_record.start_us() <= abs_end_us): + cpu_memory_usage += mem_record.cpu_memory_usage() + cuda_memory_usage += mem_record.cuda_memory_usage() + is_async = kineto_event.start_thread_id() != kineto_event.end_thread_id() + fe = FunctionEvent( + id=kineto_event.correlation_id(), + name=string_table[kineto_event.name()], + thread=kineto_event.start_thread_id(), + start_us=rel_start_us, + end_us=rel_end_us, + fwd_thread=kineto_event.fwd_thread_id(), + input_shapes=kineto_event.shapes(), + stack=[entry for entry in kineto_event.stack() if filter_stack_entry(entry)], + scope=kineto_event.scope(), + cpu_memory_usage=cpu_memory_usage, + cuda_memory_usage=cuda_memory_usage, + is_async=is_async, + sequence_nr=kineto_event.sequence_nr(), + device_type=kineto_event.device_type(), + device_index=kineto_event.device_index(), + ) + 
function_events.append(fe) + if kineto_event.device_type() == DeviceType.CUDA: + corr_id = kineto_event.linked_correlation_id() + if corr_id > 0: + if corr_id not in cuda_corr_map: + cuda_corr_map[corr_id] = [] + cuda_corr_map[corr_id].append(kineto_event) + + # associate CUDA kernels with CPU events + for fe in function_events: + if (fe.device_type == DeviceType.CPU and not fe.is_async and + fe.id in cuda_corr_map): + for k_evt in cuda_corr_map[fe.id]: + fe.append_kernel( + k_evt.name(), + k_evt.device_index(), + k_evt.start_us(), + k_evt.start_us() + k_evt.duration_us()) + + function_events.sort(key=lambda evt: [evt.time_range.start, -evt.time_range.end]) + return function_events + +# Parsing of legacy profiler events +def parse_legacy_records(thread_records): def get_record_key(record): """ - Returns a tuple to be used by parse_event_records for correlating start and + Returns a tuple to be used by parse_legacy_records for correlating start and end records. """ return (record.handle(), record.node_id()) @@ -938,26 +1135,6 @@ def get_record_key(record): record_stack = [] string_table = StringTable() - # ignoring the following utility ops - filtered_out_names = [ - "profiler::_record_function_enter", - "profiler::_record_function_exit", - "aten::is_leaf", - "aten::output_nr", - "aten::_version", - ] - - def filter_stack_entry(entry): - filtered_entries = [ - ("autograd/__init__", "_make_grads"), - ("autograd/__init__", "backward"), - ("torch/tensor", "backward"), - ("_internal/common_utils", "prof_callable"), - ("_internal/common_utils", "prof_func_call"), - ("_internal/common_utils", "prof_meth_call"), - ] - return all([not (f[0] in entry and f[1] in entry) for f in filtered_entries]) - # cuda start events and the overall profiler start event don't happen # at exactly the same time because we need to record an event on each device # and each record takes ~4us. So we adjust here by the difference @@ -994,7 +1171,7 @@ def adjusted_time(cuda_record, cuda_records_map): prev_record = None for record in thread_record_list: record_key = get_record_key(record) - if (record.name() in filtered_out_names or + if (filter_name(record.name()) or record_key in filtered_handles): filtered_handles.add(record_key) continue @@ -1035,8 +1212,8 @@ def adjusted_time(cuda_record, cuda_records_map): node_id=record.node_id(), name=string_table[start.name()], thread=start.thread_id(), - cpu_start=start_record.cpu_elapsed_us(start), - cpu_end=start_record.cpu_elapsed_us(record), + start_us=start_record.cpu_elapsed_us(start), + end_us=start_record.cpu_elapsed_us(record), fwd_thread=start.fwd_thread_id(), input_shapes=start.shapes(), stack=[entry for entry in start.stack() if filter_stack_entry(entry)], @@ -1046,6 +1223,8 @@ def adjusted_time(cuda_record, cuda_records_map): is_async=is_async, is_remote=is_remote_event, sequence_nr=start.sequence_nr(), + device_type=DeviceType.CPU, + is_legacy=True, ) # note: async events have only cpu total time if not is_async and start.has_cuda(): @@ -1074,7 +1253,7 @@ def adjusted_time(cuda_record, cuda_records_map): # granularity of the given clock tick)--we always show # the outermost nested call first. This adds stability # in how FunctionEvents appear - functions.sort(key=lambda evt: [evt.cpu_interval.start, -evt.cpu_interval.end]) + functions.sort(key=lambda evt: [evt.time_range.start, -evt.time_range.end]) return functions @@ -1121,8 +1300,8 @@ def parse_nvprof_trace(path): node_id=0, # missing a node_id when calling FunctionEvent. 
This is just to ensure # that pytorch doesn't crash when creating a FunctionEvent() object name=strings[row['name']], - cpu_start=row['start_time'], - cpu_end=row['end_time'], + start_us=row['start_time'], + end_us=row['end_time'], thread=0) # TODO: find in sqlite database functions.append(evt) functions_map[evt.id] = evt @@ -1153,7 +1332,7 @@ def parse_nvprof_trace(path): row['kernel_start'], row['kernel_end']) - functions.sort(key=lambda evt: evt.cpu_interval.start) + functions.sort(key=lambda evt: evt.time_range.start) return functions @@ -1182,7 +1361,9 @@ def build_table( has_input_shapes = any( [(event.input_shapes is not None and len(event.input_shapes) > 0) for event in events]) + MAX_NAME_COLUMN_WIDTH = 55 name_column_width = max([len(evt.key) for evt in events]) + 4 + name_column_width = min(name_column_width, MAX_NAME_COLUMN_WIDTH) DEFAULT_COLUMN_WIDTH = 12 @@ -1269,7 +1450,16 @@ def append(s): result.append('\n') # Yes, newline after the end as well self_cpu_time_total = sum([event.self_cpu_time_total for event in events]) - cuda_time_total = sum([evt.self_cuda_time_total for evt in events]) + cuda_time_total = 0 + for evt in events: + if evt.device_type == DeviceType.CPU: + # in legacy profiler, kernel info is stored in cpu events + if evt.is_legacy: + cuda_time_total += evt.self_cuda_time_total + elif evt.device_type == DeviceType.CUDA: + # in kineto mode, there're events with the correct device type (e.g. CUDA) + cuda_time_total += evt.self_cuda_time_total + # Actual printing if header is not None: append('=' * line_length) @@ -1290,8 +1480,11 @@ def append(s): continue else: event_limit += 1 + name = evt.key + if len(name) >= MAX_NAME_COLUMN_WIDTH - 3: + name = name[:(MAX_NAME_COLUMN_WIDTH - 3)] + "..." row_values = [ - evt.key, # Name + name, # Self CPU total, 0 for async events. 
% format_time_share(evt.self_cpu_time_total, self_cpu_time_total),
diff --git a/torch/csrc/autograd/init.cpp b/torch/csrc/autograd/init.cpp
index 7f673952a2e7..78336ded0d88 100644
--- a/torch/csrc/autograd/init.cpp
+++ b/torch/csrc/autograd/init.cpp
@@ -1,5 +1,6 @@
 #include
+#include
 #include
 #include
 #include
@@ -39,37 +40,132 @@ PyObject* THPAutograd_initExtension(PyObject* _unused, PyObject *unused) {
       .value("Disabled", ProfilerState::Disabled)
       .value("CPU", ProfilerState::CPU)
       .value("CUDA", ProfilerState::CUDA)
-      .value("NVTX", ProfilerState::NVTX);
+      .value("NVTX", ProfilerState::NVTX)
+      .value("KINETO", ProfilerState::KINETO);
+
+  py::enum_<ActivityType>(m, "ProfilerActivity")
+      .value("CPU", ActivityType::CPU)
+      .value("CUDA", ActivityType::CUDA);
 
   py::class_<ProfilerConfig>(m, "ProfilerConfig")
       .def(py::init<ProfilerState, bool, bool, bool>());
 
-  py::class_<Event>(m, "ProfilerEvent")
-      .def("kind", &Event::kind)
-      .def("name", [](const Event& e) { return e.name(); })
-      .def("thread_id", &Event::threadId)
-      .def("fwd_thread_id", &Event::fwdThreadId)
-      .def("device", &Event::device)
-      .def("cpu_elapsed_us", &Event::cpuElapsedUs)
-      .def("cuda_elapsed_us", &Event::cudaElapsedUs)
-      .def("has_cuda", &Event::hasCuda)
-      .def("shapes", &Event::shapes)
-      .def("cpu_memory_usage", &Event::cpuMemoryUsage)
-      .def("cuda_memory_usage", &Event::cudaMemoryUsage)
-      .def("handle", &Event::handle)
-      .def("node_id", &Event::nodeId)
-      .def("is_remote", &Event::isRemote)
-      .def("sequence_nr", &Event::sequenceNr)
-      .def("stack", &Event::stack)
-      .def("scope", &Event::scope);
+  py::class_<LegacyEvent>(m, "ProfilerEvent")
+      .def("kind", &LegacyEvent::kindStr)
+      .def("name", [](const LegacyEvent& e) { return e.name(); })
+      .def("thread_id", &LegacyEvent::threadId)
+      .def("fwd_thread_id", &LegacyEvent::fwdThreadId)
+      .def("device", &LegacyEvent::device)
+      .def("cpu_elapsed_us", &LegacyEvent::cpuElapsedUs)
+      .def("cuda_elapsed_us", &LegacyEvent::cudaElapsedUs)
+      .def("has_cuda", &LegacyEvent::hasCuda)
+      .def("shapes", &LegacyEvent::shapes)
+      .def("cpu_memory_usage", &LegacyEvent::cpuMemoryUsage)
+      .def("cuda_memory_usage", &LegacyEvent::cudaMemoryUsage)
+      .def("handle", &LegacyEvent::handle)
+      .def("node_id", &LegacyEvent::nodeId)
+      .def("is_remote", &LegacyEvent::isRemote)
+      .def("sequence_nr", &LegacyEvent::sequenceNr)
+      .def("stack", &LegacyEvent::stack)
+      .def("scope", &LegacyEvent::scope)
+      .def("correlation_id", &LegacyEvent::correlationId)
+      .def("start_us", &LegacyEvent::cpuUs);
 
-  py::class_<ProfilerDisableOptions>(m, "_ProfilerDisableOptions")
-      .def(py::init<bool, bool>());
+  py::enum_<c10::DeviceType>(m, "DeviceType")
+      .value("CPU", c10::DeviceType::CPU)
+      .value("CUDA", c10::DeviceType::CUDA)
+      .value("MKLDNN", c10::DeviceType::MKLDNN)
+      .value("OPENGL", c10::DeviceType::OPENGL)
+      .value("OPENCL", c10::DeviceType::OPENCL)
+      .value("IDEEP", c10::DeviceType::IDEEP)
+      .value("HIP", c10::DeviceType::HIP)
+      .value("FPGA", c10::DeviceType::FPGA)
+      .value("MSNPU", c10::DeviceType::MSNPU)
+      .value("XLA", c10::DeviceType::XLA)
+      .value("Vulkan", c10::DeviceType::Vulkan)
+      .value("Metal", c10::DeviceType::Metal);
+
+#ifdef USE_KINETO
+  py::class_<KinetoEvent>(m, "KinetoEvent")
+      // name of the event
+      .def("name", &KinetoEvent::name)
+      // PyTorch thread id of the start callback
+      .def("start_thread_id", [](const KinetoEvent& e) {
+        return e.startThreadId();
+      })
+      // PyTorch thread id of the end callback
+      .def("end_thread_id", [](const KinetoEvent& e) {
+        return e.endThreadId();
+      })
+      // for events of scope BACKWARD_FUNCTION - PyTorch thread id
+      // of the corresponding forward op
+      .def("fwd_thread_id", [](const KinetoEvent& e) {
+        return e.fwdThreadId();
+      })
+      // together with fwd_thread_id, used to uniquely identify
+      // the forward op
+      .def("sequence_nr", [](const KinetoEvent& e) {
+        return e.sequenceNr();
+      })
+      // absolute start time (since unix epoch) in us
+      .def("start_us", &KinetoEvent::startUs)
+      // duration in us
+      .def("duration_us", &KinetoEvent::durationUs)
+      // used for correlation between high-level PyTorch events
+      // and low-level device events
+      .def("correlation_id", [](const KinetoEvent& e) {
+        return e.correlationId();
+      })
+      // shapes of input tensors
+      .def("shapes", [](const KinetoEvent& e) {
+        if (e.hasShapes()) {
+          return e.shapes();
+        } else {
+          return std::vector<std::vector<int64_t>>();
+        }
+      })
+      // stack traces of the PyTorch CPU events
+      .def("stack", [](const KinetoEvent& e) {
+        if (e.hasStack()) {
+          return e.stack();
+        } else {
+          return std::vector<std::string>();
+        }
+      })
+      // type of the RecordFunction that generated a PyTorch CPU event
+      // (op, torchscript function, user label, etc)
+      .def("scope", [](const KinetoEvent& e) {
+        return e.scope();
+      })
+      // device number, for CPU - process id
+      .def("device_index", &KinetoEvent::deviceIndex)
+      // for CUDA - stream id, for CPU - start thread id
+      .def("device_resource_id", &KinetoEvent::deviceResourceId)
+      // device type
+      .def("device_type", [](const KinetoEvent& e) {
+        return e.deviceType();
+      })
+      // correlation id of a linked event
+      .def("linked_correlation_id", &KinetoEvent::linkedCorrelationId);
+
+  py::class_<ProfilerResult>(m, "ProfilerResult")
+      .def("events", &ProfilerResult::events)
+      .def("legacy_events", &ProfilerResult::legacy_events)
+      .def("save", &ProfilerResult::save);
 
   m.def("_enable_profiler", enableProfiler);
+  m.def("_disable_profiler", disableProfiler);
+  m.def("_prepare_profiler", prepareProfiler);
+#endif
+
+  m.def("kineto_available", kinetoAvailable);
+
+  m.def("_enable_profiler_legacy", enableProfilerLegacy);
+  py::class_<ProfilerDisableOptions>(m, "_ProfilerDisableOptions")
+      .def(py::init<bool, bool>());
   m.def(
-      "_disable_profiler",
-      disableProfiler,
+      "_disable_profiler_legacy",
+      disableProfilerLegacy,
       py::arg("profiler_disable_options") = ProfilerDisableOptions());
   m.def("_profiler_enabled", profilerEnabled);
   m.def("_enable_record_function", [](bool enable) {
diff --git a/torch/csrc/autograd/profiler.h b/torch/csrc/autograd/profiler.h
index 9cfe9ea1fd6e..7ac44096cda7 100644
--- a/torch/csrc/autograd/profiler.h
+++ b/torch/csrc/autograd/profiler.h
@@ -1,461 +1,4 @@
 #pragma once
-#include
-#include
-#include
-#include
-#include
-#include
-#include
-#include
-#include
-#include
-#include
-#ifndef _WIN32
-#include
-#endif
-#if defined(C10_IOS) && defined(C10_MOBILE)
-#include // for gettimeofday()
-#endif
-
-#include
-
-struct CUevent_st;
-typedef std::shared_ptr<CUevent_st> CUDAEventStub;
-
-namespace torch { namespace autograd {
-
-struct Node;
-
-namespace profiler {
-
-struct TORCH_API CUDAStubs {
-  virtual void record(int* device, CUDAEventStub* event, int64_t* cpu_ns) {
-    fail();
-  }
-  virtual float elapsed(const CUDAEventStub* event, const CUDAEventStub* event2) {
-    fail();
-    return 0.f;
-  }
-  virtual void nvtxMarkA(const char* name) {
-    fail();
-  }
-  virtual void nvtxRangePushA(const char* name) {
-    fail();
-  }
-  virtual void nvtxRangePop() {
-    fail();
-  }
-  virtual bool enabled() {
-    return false;
-  }
-  virtual void onEachDevice(std::function<void(int)> op) {
-    fail();
-  }
-  virtual void synchronize() {
-    fail();
-  }
-  virtual ~CUDAStubs();
-
-private:
-  void fail() {
-    AT_ERROR("CUDA used in profiler but not enabled.");
-  }
-};
-
-TORCH_API void registerCUDAMethods(CUDAStubs* stubs);
-
-constexpr inline 
size_t ceilToMultiple(size_t a, size_t b) { - return ((a + b - 1) / b) * b; -} - -inline int64_t getTime() { -#if defined(C10_IOS) && defined(C10_MOBILE) -// clock_gettime is only available on iOS 10.0 or newer. Unlike OS X, iOS can't rely on -// CLOCK_REALTIME, as it is defined no matter if clock_gettime is implemented or not - struct timeval now; - gettimeofday(&now, NULL); - return static_cast(now.tv_sec) * 1000000000 + static_cast(now.tv_usec) * 1000; -#elif defined(_WIN32) || defined(__MACH__) - using namespace std::chrono; - using clock = std::conditional::type; - return duration_cast(clock::now().time_since_epoch()).count(); -#else - // clock_gettime is *much* faster than std::chrono implementation on Linux - struct timespec t{}; - clock_gettime(CLOCK_MONOTONIC, &t); - return static_cast(t.tv_sec) * 1000000000 + static_cast(t.tv_nsec); -#endif -} - -// A struct to control settings of disableProfiler options. -struct TORCH_API ProfilerDisableOptions { - ProfilerDisableOptions() = default; - ProfilerDisableOptions(bool shouldCleanupTLSState, bool shouldConsolidate) - : cleanupTLSState(shouldCleanupTLSState), - consolidate(shouldConsolidate) {} - // Whether we should clean up profiler states that are thread local, such as - // ThreadLocalDebugInfo and thread local RecordFunction callbacks. - bool cleanupTLSState = true; - // Whether we should consolidate all currently recorded profiled events. If - // false, will not consolidate and other threads can continue to write to the - // event lists. - bool consolidate = true; -}; - -enum class C10_API_ENUM ProfilerState { - Disabled, - CPU, // CPU-only profiling - CUDA, // CPU + CUDA events - NVTX, // only emit NVTX markers -}; - -struct TORCH_API ProfilerConfig { - ProfilerConfig( - ProfilerState state, - bool report_input_shapes = false, - bool profile_memory = false, - bool with_stack = false) - : state(state), - report_input_shapes(report_input_shapes), - profile_memory(profile_memory), - with_stack(with_stack) {} - ~ProfilerConfig(); - ProfilerState state; - bool report_input_shapes; - bool profile_memory; - bool with_stack; - - // Returns IValues corresponding to ProfilerConfig struct, to be used for - // serialization. - at::IValue toIValue() const; - - // Reconstructs a ProfilerConfig from IValues given by toIValue. - static ProfilerConfig fromIValue(const at::IValue& profilerConfigIValue); - -}; - -enum class C10_API_ENUM EventKind : uint16_t { - Mark, - PushRange, - PopRange, - MemoryAlloc, -}; - -struct TORCH_API Event final { - Event( - EventKind kind, - at::StringView name, - uint16_t thread_id, - bool record_cuda, - at::RecordFunctionHandle handle = 0, - std::vector>&& shapes = {}, - int node_id = -1) - : name_(std::move(name)), - kind_(kind), - thread_id_(thread_id), - handle_(handle), - shapes_(shapes), - node_id_(node_id) { - record(record_cuda); - } - - // Constructor to be used in conjunction with Event::fromIValue. 
- Event( - EventKind kind, - at::StringView name, - uint16_t thread_id, - at::RecordFunctionHandle handle, - std::vector>&& shapes, - int node_id, - bool is_remote, - int64_t cpu_memory_usage, - int64_t cpu_ns, - bool cuda_recorded, - int64_t cuda_memory_usage = 0, - int device = -1, - double cuda_us = -1) - : cpu_ns_(cpu_ns), - name_(std::move(name)), - kind_(kind), - thread_id_(thread_id), - handle_(handle), - shapes_(shapes), - cpu_memory_usage_(cpu_memory_usage), - cuda_memory_usage_(cuda_memory_usage), - device_(device), - node_id_(node_id), - is_remote_(is_remote), - cuda_us_(cuda_us) { - // Sanity check values that were deserialized - TORCH_INTERNAL_ASSERT(cpu_ns_ > 0); - if (cuda_recorded) { - TORCH_INTERNAL_ASSERT(device_ >= 0); - TORCH_INTERNAL_ASSERT(cuda_us_ >= 0); - } - } - - // Returns IValues corresponding to event structure, to be used for - // serialization. - at::IValue toIValue() const; - - // Reconstructs an event from IValues given by toIValue. - static Event fromIValue(const at::IValue& eventIValue); - - void record(bool record_cuda); - std::string kind() const { - switch(kind_) { - case EventKind::Mark: return "mark"; - case EventKind::PushRange: return "push"; - case EventKind::PopRange: return "pop"; - case EventKind::MemoryAlloc: return "memory_alloc"; - } - throw std::runtime_error("unknown EventKind"); - } - - // Get enum kind of this event. - EventKind eventKind() const { - return kind_; - } - - const char* name() const { - return name_.str(); - } - - uint64_t threadId() const { - return thread_id_; - } - - std::vector> shapes() const { - return shapes_; - } - - double cpuElapsedUs(const Event& e) const { - return (e.cpu_ns_ - cpu_ns_)/(1000.0); - } - - double cpuUs() const { - return cpu_ns_ / (1000.0); - } - - double cudaElapsedUs(const Event& e) const; - - bool hasCuda() const { - return cuda_event != nullptr || (isRemote() && device_ != -1); - } - - int device() const { - return device_; - } - - void updateMemoryStats(int64_t alloc_size, c10::Device device) { - if (device.type() == c10::DeviceType::CUDA || - device.type() == c10::DeviceType::HIP) { - cuda_memory_usage_ = alloc_size; - } else if (device.type() == c10::DeviceType::CPU || - device.type() == c10::DeviceType::MKLDNN || - device.type() == c10::DeviceType::IDEEP) { - cpu_memory_usage_ = alloc_size; - } else { - LOG(WARNING) << "Unsupported memory profiling device: " << device; - } - } - - int64_t cpuMemoryUsage() const { - return cpu_memory_usage_; - } - - int64_t cudaMemoryUsage() const { - return cuda_memory_usage_; - } - - at::RecordFunctionHandle handle() const { - return handle_; - } - - // Node ID corresponding to this event. - int nodeId( ) const { - return node_id_; - } - - // Set Node ID on this event. 
- void setNodeId(int node_id) { - node_id_ = node_id; - } - - void setName(at::StringView newName_) { - name_ = std::move(newName_); - } - - bool isRemote() const { - return is_remote_; - } - - void setCudaUs(int64_t cuda_us) { - cuda_us_ = cuda_us; - } - - void setSequenceNr(int64_t sequence_nr) { - sequence_nr_ = sequence_nr; - } - - int64_t sequenceNr() const { - return sequence_nr_; - } - - const std::vector& stack() const { - return stack_; - } - - void setStack(const std::vector& stack) { - stack_ = stack; - } - - uint64_t fwdThreadId() const { - return fwd_thread_id_; - } - - void setFwdThreadId(uint64_t fwd_thread_id) { - fwd_thread_id_ = fwd_thread_id; - } - - uint8_t scope() const { - return scope_; - } - - void setScope(uint8_t scope) { - scope_ = scope; - } - - private: - // signed to allow for negative intervals, initialized for safety. - int64_t cpu_ns_ = 0; - at::StringView name_; - EventKind kind_; - uint64_t thread_id_; - uint64_t fwd_thread_id_; - at::RecordFunctionHandle handle_ {0}; - std::vector> shapes_; - int64_t cpu_memory_usage_ = 0; - int64_t cuda_memory_usage_ = 0; - int device_ = -1; - CUDAEventStub cuda_event = nullptr; - int node_id_ = 0; - bool is_remote_ = false; - int64_t cuda_us_ = -1; - int64_t sequence_nr_ = -1; - - std::vector stack_; - uint8_t scope_; -}; - -// a linked-list of fixed sized vectors, to avoid -// a std::vector resize from taking a large amount of time inside -// a profiling event -struct RangeEventList { - RangeEventList() { - events_.reserve(kReservedCapacity); - } - - template - void record(Args&&... args) { - std::lock_guard guard(mutex_); - events_.emplace_back(std::forward(args)...); - } - - std::vector consolidate() { - std::lock_guard lock(mutex_); - std::vector result; - result.insert( - result.begin(), - std::make_move_iterator(events_.begin()), - std::make_move_iterator(events_.end())); - events_.erase(events_.begin(), events_.end()); - return result; - } - - size_t size() { - std::lock_guard lock(mutex_); - return events_.size(); - } - - private: - // This mutex is used to serialize access when different threads are writing - // to the same instance of RangeEventList. - std::mutex mutex_; - std::vector events_; - - static const size_t kReservedCapacity = 1024; -}; - -using thread_event_lists = std::vector>; -// NOTE: profiler mode is thread local, with automatic propagation -// across thread boundary (e.g. at::launch tasks) -TORCH_API void enableProfiler(const ProfilerConfig&); -TORCH_API thread_event_lists disableProfiler(c10::optional profilerDisableOptions = c10::nullopt); -// adds profiledEvents to the current thread local recorded events. Each event -// will be marked with node ID given by fromNodeId. -TORCH_API void addEventList(std::vector&& profiledEvents); -// Returns if the profiler is currently enabled in the current thread. -TORCH_API bool profilerEnabled(); -// Retrieve the thread_local ProfilerConfig. -TORCH_API ProfilerConfig getProfilerConfig(); -// Writes profiled events to a stream. 
-TORCH_API void writeProfilerEventsToStream(std::ostream& out, const std::vector& events); - -// Usage: -// { -// RecordProfile guard("filename.trace"); -// // code you want to profile -// } -// Then open filename.trace in chrome://tracing -struct TORCH_API RecordProfile { - RecordProfile(std::ostream& out); - RecordProfile(const std::string& filename); - - ~RecordProfile(); -private: - void init(); - std::unique_ptr file_; - std::ostream& out_; - void processEvents(const std::vector& events); -}; - -// A guard that enables the profiler, taking in an optional callback to process -// the results -// Usage: -// { -// TLSProfilerGuard g([](thread_event_lists profilerResults) { -// // process profilerResults -// }); -// Code to profile -// } -struct TORCH_API TLSProfilerGuard { - explicit TLSProfilerGuard( - const ProfilerConfig& cfg, - c10::optional> - resultCallback = c10::nullopt, - c10::optional profilerDisableOptions = - c10::nullopt) - : cb_(std::move(resultCallback)), - profilerDisableOptions_(std::move(profilerDisableOptions)) { - enableProfiler(cfg); - } - ~TLSProfilerGuard() { - thread_event_lists event_lists = disableProfiler(profilerDisableOptions_); - if (cb_) { - try { - (*cb_)(event_lists); - } catch (const std::exception& e) { - LOG(ERROR) << "Got error processing profiler events: " << e.what(); - } - } - } - - private: - c10::optional> cb_; - const c10::optional profilerDisableOptions_; -}; - -} // namespace profiler -}} // namespace torch::autograd +#include +#include diff --git a/torch/csrc/autograd/profiler_cuda.cpp b/torch/csrc/autograd/profiler_cuda.cpp index ad677dbc6680..14dff19629fd 100644 --- a/torch/csrc/autograd/profiler_cuda.cpp +++ b/torch/csrc/autograd/profiler_cuda.cpp @@ -32,7 +32,7 @@ static inline void cudaCheck(cudaError_t result, const char * file, int line) { #define TORCH_CUDA_CHECK(result) cudaCheck(result,__FILE__,__LINE__); struct CUDAMethods : public CUDAStubs { - void record(int* device, CUDAEventStub* event, int64_t* cpu_ns) override { + void record(int* device, CUDAEventStub* event, int64_t* cpu_ns) const override { TORCH_CUDA_CHECK(cudaGetDevice(device)); CUevent_st* cuda_event_ptr; TORCH_CUDA_CHECK(cudaEventCreate(&cuda_event_ptr)); @@ -43,23 +43,28 @@ struct CUDAMethods : public CUDAStubs { *cpu_ns = getTime(); TORCH_CUDA_CHECK(cudaEventRecord(cuda_event_ptr, stream)); } - float elapsed(const CUDAEventStub* event, const CUDAEventStub* event2) override { + + float elapsed(const CUDAEventStub* event, const CUDAEventStub* event2) const override{ TORCH_CUDA_CHECK(cudaEventSynchronize(event->get())); TORCH_CUDA_CHECK(cudaEventSynchronize(event2->get())); float ms; TORCH_CUDA_CHECK(cudaEventElapsedTime(&ms, event->get(), event2->get())); return ms*1000.0; } - void nvtxMarkA(const char* name) override { + + void nvtxMarkA(const char* name) const override { ::nvtxMark(name); } - void nvtxRangePushA(const char* name) override { + + void nvtxRangePushA(const char* name) const override { ::nvtxRangePushA(name); } - void nvtxRangePop() override { + + void nvtxRangePop() const override { ::nvtxRangePop(); } - void onEachDevice(std::function op) override { + + void onEachDevice(std::function op) const override { at::cuda::OptionalCUDAGuard device_guard; int count = at::cuda::device_count(); for(int i = 0; i < count; i++) { @@ -67,13 +72,14 @@ struct CUDAMethods : public CUDAStubs { op(i); } } - void synchronize() override { + + void synchronize() const override { cudaDeviceSynchronize(); } - bool enabled() override { + + bool enabled() const override { return 
true;
+  }
-
 };
 
 struct RegisterCUDAMethods {
diff --git a/torch/csrc/autograd/profiler_kineto.cpp b/torch/csrc/autograd/profiler_kineto.cpp
new file mode 100644
index 000000000000..7c91e76490a1
--- /dev/null
+++ b/torch/csrc/autograd/profiler_kineto.cpp
@@ -0,0 +1,368 @@
+#include
+
+#include
+#include
+
+#include
+
+#ifdef USE_KINETO
+#include
+#include
+#endif
+
+namespace torch { namespace autograd { namespace profiler {
+
+#ifdef USE_KINETO
+namespace {
+// TODO: consider TLS (tid + tls counter)
+uint64_t next_correlation_id() {
+  static std::atomic<uint64_t> corr_id_ {1};
+  return corr_id_++;
+}
+
+inline int64_t getTimeUs() {
+  using namespace std::chrono;
+  return duration_cast<microseconds>(high_resolution_clock::now().time_since_epoch()).count();
+}
+
+std::string shapesToStr(const std::vector<std::vector<int64_t>>& shapes);
+
+struct TORCH_API KinetoThreadLocalState : public ProfilerThreadLocalState {
+  using ProfilerThreadLocalState::ProfilerThreadLocalState;
+  virtual ~KinetoThreadLocalState() override = default;
+
+  void reportClientActivity(
+      const at::RecordFunction& fn,
+      const KinetoObserverContext* ctx) {
+    if (!ctx) {
+      return;
+    }
+    libkineto::ClientTraceActivity op;
+    op.startTime = ctx->startUs;
+    op.endTime = getTimeUs();
+    op.opType = std::string(fn.name().str());
+    op.device = 0;
+    op.threadId = ctx->startThreadId;
+    op.correlation = ctx->correlationId;
+    // optimization - postpone shapesToStr until finalizeCPUTrace
+    // is called from disableProfiler
+    // if (ctx->shapes && !ctx->shapes->empty()) {
+    //   op.inputDims = shapesToStr(*ctx->shapes);
+    // }
+
+    // Not set at the moment
+    op.inputTypes = "[]";
+    op.arguments = "[]";
+    op.outputDims = "[]";
+    op.outputTypes = "[]";
+    op.inputNames = "[]";
+    op.outputNames = "[]";
+
+    // note: overwrites the RecordFunction thread id with the OS thread id
+    op.threadId = pthread_self();
+    {
+      std::lock_guard<std::mutex> guard(state_mutex_);
+      kineto_events_.emplace_back();
+      kineto_events_.back()
+          .activity(op)
+          .startThreadId(ctx->startThreadId)
+          .endThreadId(ctx->endThreadId)
+          .sequenceNr(ctx->sequenceNr)
+          .fwdThreadId(ctx->fwdThreadId)
+          .scope(ctx->recFunScope);
+      if (ctx->shapes && !ctx->shapes->empty()) {
+        kineto_events_.back().shapes(*ctx->shapes);
+      }
+      if (ctx->stack && !ctx->stack->empty()) {
+        kineto_events_.back().stack(*ctx->stack);
+      }
+      cpu_trace->activities.emplace_back(std::move(op));
+    }
+  }
+
+  // TODO: use kineto
+  void reportMemoryUsage(
+      void* /* unused */,
+      int64_t alloc_size,
+      c10::Device device) override {
+    if (config_.profile_memory && config_.state != ProfilerState::Disabled) {
+      uint64_t thread_id = at::RecordFunction::currentThreadId();
+      LegacyEvent evt(
+          EventKind::MemoryAlloc,
+          at::StringView(""),
+          thread_id,
+          config_.state == ProfilerState::CUDA);
+      evt.setCpuUs(getTimeUs()); // update the time using the same clock as Kineto events
+      evt.updateMemoryStats(alloc_size, device);
+      getEventList(thread_id).record(std::move(evt));
+    }
+  }
+
+  void addTraceEvents(libkineto::ActivityTraceInterface& trace) {
+    const auto& events = *(trace.activities());
+    for (const auto& ev_ptr : events) {
+      // ClientTraceActivity events are already processed
+      if (ev_ptr->type() != libkineto::ActivityType::CPU_OP) {
+        kineto_events_.emplace_back();
+        kineto_events_.back()
+            .activity(*ev_ptr);
+      }
+    }
+  }
+
+  void finalizeCPUTrace() {
+    TORCH_INTERNAL_ASSERT(cpu_trace->activities.size() == kineto_events_.size());
+    for (size_t idx = 0; idx < cpu_trace->activities.size(); ++idx) {
+      if (kineto_events_[idx].hasShapes()) {
+        cpu_trace->activities[idx].inputDims = shapesToStr(kineto_events_[idx].shapes());
+      } else {
+        cpu_trace->activities[idx].inputDims = "[]";
+      }
+    }
+  }
+
+  std::vector<KinetoEvent> kineto_events_;
+  std::unique_ptr<libkineto::CpuTraceBuffer> cpu_trace =
+      std::make_unique<libkineto::CpuTraceBuffer>();
+};
+
+KinetoThreadLocalState* getProfilerTLSState() {
+  const auto& state = c10::ThreadLocalDebugInfo::get(
+      c10::DebugInfoKind::PROFILER_STATE);
+  return static_cast<KinetoThreadLocalState*>(state);
+}
+
+void pushProfilingCallbacks() {
+  auto state_ptr = getProfilerTLSState();
+  TORCH_INTERNAL_ASSERT(state_ptr, "Expected profiler state set");
+  auto handle = at::addThreadLocalCallback(at::RecordFunctionCallback(
+      [](const at::RecordFunction& fn) {
+        auto state_ptr = getProfilerTLSState();
+        if (!state_ptr || state_ptr->config().state != ProfilerState::KINETO) {
+          return std::make_unique<KinetoObserverContext>();
+        }
+
+        auto corr_id = next_correlation_id();
+        libkineto::api().activityProfiler().pushCorrelationId(corr_id);
+
+        auto ctx_ptr = std::make_unique<KinetoObserverContext>();
+        ctx_ptr->startUs = getTimeUs();
+        ctx_ptr->correlationId = corr_id;
+        ctx_ptr->startThreadId = at::RecordFunction::currentThreadId();
+
+        if (state_ptr->config().report_input_shapes) {
+          ctx_ptr->shapes = inputSizes(fn);
+        }
+
+        ctx_ptr->sequenceNr = fn.seqNr();
+        ctx_ptr->fwdThreadId = fn.forwardThreadId();
+        ctx_ptr->recFunScope = (uint8_t)fn.scope();
+
+#ifndef C10_MOBILE
+        // backward nodes source range corresponds to the forward node
+        // TODO: consider using C++ stack trace
+        if (state_ptr->config().with_stack &&
+            fn.scope() != at::RecordScope::BACKWARD_FUNCTION) {
+          auto cs = prepareCallstack(jit::currentCallstack());
+          if (cs.empty()) {
+            cs = prepareCallstack(jit::tracer::pythonCallstack());
+          }
+          ctx_ptr->stack = callstackStr(cs);
+        }
+#endif
+        return ctx_ptr;
+      },
+      [](const at::RecordFunction& fn, at::ObserverContext* ctx_ptr) {
+        auto state_ptr = getProfilerTLSState();
+        if (!state_ptr || state_ptr->config().state != ProfilerState::KINETO) {
+          return;
+        }
+        auto* kineto_ctx_ptr = static_cast<KinetoObserverContext*>(ctx_ptr);
+        TORCH_INTERNAL_ASSERT(kineto_ctx_ptr != nullptr);
+
+        kineto_ctx_ptr->endThreadId = at::RecordFunction::currentThreadId();
+
+        state_ptr->reportClientActivity(fn, kineto_ctx_ptr);
+        libkineto::api().activityProfiler().popCorrelationId();
+      })
+    .needsInputs(state_ptr->config().report_input_shapes)
+    .needsIds(true));
+  state_ptr->setCallbackHandle(handle);
+}
+
+std::string shapesToStr(const std::vector<std::vector<int64_t>>& shapes) {
+  std::ostringstream oss;
+  oss << "[";
+  for (size_t t_idx = 0; t_idx < shapes.size(); ++t_idx) {
+    if (t_idx > 0) {
+      oss << ", ";
+    }
+    oss << "[";
+    for (size_t s_idx = 0; s_idx < shapes[t_idx].size(); ++s_idx) {
+      if (s_idx > 0) {
+        oss << ", ";
+      }
+      oss << shapes[t_idx][s_idx];
+    }
+    oss << "]";
+  }
+  oss << "]";
+  return oss.str();
+}
+
+} // namespace
+
+void prepareProfiler(
+    const ProfilerConfig& config,
+    const std::set<ActivityType>& activities) {
+  TORCH_CHECK(config.state == ProfilerState::KINETO,
+      "Supported only in Kineto profiler");
+
+  std::set<libkineto::ActivityType> cpuTypes = {
+    libkineto::ActivityType::CPU_OP,
+    libkineto::ActivityType::EXTERNAL_CORRELATION,
+    libkineto::ActivityType::CUDA_RUNTIME,
+  };
+
+  std::set<libkineto::ActivityType> cudaTypes = {
+    libkineto::ActivityType::GPU_MEMCPY,
+    libkineto::ActivityType::GPU_MEMSET,
+    libkineto::ActivityType::CONCURRENT_KERNEL,
+    // also including CUDA_RUNTIME
+    libkineto::ActivityType::CUDA_RUNTIME,
+  };
+
+  std::set<libkineto::ActivityType> k_activities;
+  if (activities.count(ActivityType::CPU)) {
+    k_activities.insert(cpuTypes.begin(), cpuTypes.end());
+  }
+  if (activities.count(ActivityType::CUDA)) {
+    k_activities.insert(cudaTypes.begin(), cudaTypes.end());
+  }
+
+  if (!libkineto::api().isProfilerRegistered()) {
+    libkineto_init();
+  }
+
+  if (!libkineto::api().isProfilerInitialized()) {
+    libkineto::api().initProfilerIfRegistered();
+  }
+
+  libkineto::api().activityProfiler().prepareTrace(k_activities);
+}
+
+void enableProfiler(
+    const ProfilerConfig& config,
+    const std::set<ActivityType>& activities) {
+  TORCH_CHECK(config.state == ProfilerState::KINETO);
+  TORCH_CHECK(!activities.empty(), "No activities specified for Kineto profiler");
+
+  auto state_ptr = getProfilerTLSState();
+  TORCH_CHECK(!state_ptr, "Profiler is already enabled on this thread");
+  auto state = std::make_shared<KinetoThreadLocalState>(config);
+  c10::ThreadLocalDebugInfo::_push(c10::DebugInfoKind::PROFILER_STATE, state);
+
+  state->cpu_trace = std::make_unique<libkineto::CpuTraceBuffer>();
+  state->cpu_trace->span.startTime = getTimeUs();
+  // TODO: number of GPU ops
+  state->cpu_trace->gpuOpCount = -1;
+  state->cpu_trace->span.name = "PyTorch Profiler";
+
+  if (activities.count(ActivityType::CPU)) {
+    pushProfilingCallbacks();
+  }
+
+  libkineto::api().activityProfiler().startTrace();
+
+  state->mark("__start_profile", false);
+}
+
+std::unique_ptr<ProfilerResult> disableProfiler() {
+  // all the DebugInfoBase objects are scope based and supposed to use DebugInfoGuard
+  auto state = c10::ThreadLocalDebugInfo::_pop(c10::DebugInfoKind::PROFILER_STATE);
+
+  auto state_ptr = static_cast<KinetoThreadLocalState*>(state.get());
+  TORCH_CHECK(state_ptr && state_ptr->config().state == ProfilerState::KINETO,
+      "Can't disable Kineto profiler when it's not running");
+
+  if (state_ptr->hasCallbackHandle()) {
+    at::removeCallback(state_ptr->callbackHandle());
+  }
+
+  state_ptr->mark("__stop_profile");
+
+  state_ptr->cpu_trace->span.endTime = getTimeUs();
+
+  state_ptr->finalizeCPUTrace();
+  libkineto::api().activityProfiler().transferCpuTrace(std::move(state_ptr->cpu_trace));
+
+  auto trace = std::move(libkineto::api().activityProfiler().stopTrace());
+  TORCH_CHECK(trace);
+  state_ptr->addTraceEvents(*trace);
+  return std::make_unique<ProfilerResult>(
+      std::move(state_ptr->kineto_events_),
+      std::move(state_ptr->consolidate()),
+      std::move(trace));
+}
+
+KinetoEvent& KinetoEvent::activity(const libkineto::TraceActivity& activity) {
+  name_ = activity.name();
+  device_index_ = activity.deviceId();
+  device_resource_id_ = activity.resourceId();
+  start_us_ = activity.timestamp();
+  duration_us_ = activity.duration();
+  correlation_id_ = activity.correlationId();
+  activity_type_ = (uint8_t)activity.type();
+  if (activity.linkedActivity()) {
+    linked_correlation_id_ = activity.linkedActivity()->correlationId();
+  }
+  return *this;
+}
+
+c10::DeviceType KinetoEvent::deviceType() const {
+  switch (activity_type_) {
+    case (uint8_t)libkineto::ActivityType::CPU_OP:
+      return c10::DeviceType::CPU;
+    case (uint8_t)libkineto::ActivityType::GPU_MEMCPY:
+      return
c10::DeviceType::CUDA; + case (uint8_t)libkineto::ActivityType::GPU_MEMSET: + return c10::DeviceType::CUDA; + case (uint8_t)libkineto::ActivityType::CONCURRENT_KERNEL: + return c10::DeviceType::CUDA; + case (uint8_t)libkineto::ActivityType::EXTERNAL_CORRELATION: + return c10::DeviceType::CPU; + case (uint8_t)libkineto::ActivityType::CUDA_RUNTIME: + return c10::DeviceType::CPU; + } + TORCH_CHECK(false, "Unknown activity type"); +} + +KinetoEvent::KinetoEvent() : activity_type_((uint8_t)libkineto::ActivityType::CPU_OP) {} + +ProfilerResult::ProfilerResult( + std::vector events, + thread_event_lists legacy_events, + std::unique_ptr trace) + : events_(std::move(events)), + legacy_events_(std::move(legacy_events)), + trace_(std::move(trace)) {} +ProfilerResult::~ProfilerResult() {} + +void ProfilerResult::save(const std::string& path) { + // Kineto's save is destructive + TORCH_CHECK(!saved_, "Trace is already saved"); + trace_->save(path); + saved_ = true; +} + +#endif + +bool kinetoAvailable() { +#ifdef USE_KINETO + return true; +#else + return false; +#endif +} + +}}} diff --git a/torch/csrc/autograd/profiler_kineto.h b/torch/csrc/autograd/profiler_kineto.h new file mode 100644 index 000000000000..a1c2b2122e41 --- /dev/null +++ b/torch/csrc/autograd/profiler_kineto.h @@ -0,0 +1,213 @@ +#pragma once + +#include + +#ifdef USE_KINETO +namespace libkineto { +class TraceActivity; +class ActivityTraceInterface; +} +#endif + +namespace torch { +namespace autograd { +namespace profiler { + +enum class C10_API_ENUM ActivityType { + CPU = 0, + CUDA, // CUDA kernels, runtime + NUM_KINETO_ACTIVITIES, // must be the last one +}; + +#ifdef USE_KINETO + +struct KinetoObserverContext : public at::ObserverContext { + int64_t startUs; + uint64_t correlationId; + uint64_t startThreadId; + uint64_t endThreadId; + c10::optional>> shapes; + int64_t sequenceNr; + uint64_t fwdThreadId; + uint8_t recFunScope; + c10::optional> stack; +}; + +struct TORCH_API KinetoEvent { + KinetoEvent(); + + uint64_t startThreadId() const { + return start_thread_id_; + } + + uint64_t endThreadId() const { + return end_thread_id_; + } + + uint8_t activityType() const { + return activity_type_; + } + + uint64_t fwdThreadId() const { + return fwd_thread_id_; + } + + bool hasShapes() const { + return shapes_ != c10::nullopt; + } + + const std::vector>& shapes() const { + return *shapes_; + } + + int64_t sequenceNr() const { + return sequence_nr_; + } + + bool hasStack() const { + return stack_ != c10::nullopt; + } + + const std::vector& stack() const { + return *stack_; + } + + uint8_t scope() const { + return scope_; + } + + KinetoEvent& startThreadId(uint64_t start_thread_id) { + start_thread_id_ = start_thread_id; + return *this; + } + + KinetoEvent& endThreadId(uint64_t end_thread_id) { + end_thread_id_ = end_thread_id; + return *this; + } + + KinetoEvent& fwdThreadId(uint64_t fwd_thread_id) { + fwd_thread_id_ = fwd_thread_id; + return *this; + } + + KinetoEvent& shapes(const std::vector>& shapes) { + shapes_ = shapes; + return *this; + } + + KinetoEvent& sequenceNr(int64_t sequence_nr) { + sequence_nr_ = sequence_nr; + return *this; + } + + KinetoEvent& stack(const std::vector& st) { + stack_ = st; + return *this; + } + + KinetoEvent& scope(uint8_t scope) { + scope_ = scope; + return *this; + } + + // Kineto fields + + KinetoEvent& activity(const libkineto::TraceActivity& activity); + + std::string name() const { + return name_; + } + + uint64_t deviceIndex() const { + return device_index_; + } + + uint64_t startUs() const { + 
return start_us_; + } + + uint64_t durationUs() const { + return duration_us_; + } + + uint64_t correlationId() const { + return correlation_id_; + } + + KinetoEvent& correlationId(uint64_t correlation_id) { + correlation_id_ = correlation_id; + return *this; + } + + uint64_t linkedCorrelationId() const { + return linked_correlation_id_; + } + + int64_t deviceResourceId() const { + return device_resource_id_; + } + + c10::DeviceType deviceType() const; + + uint64_t start_thread_id_ = 0; + uint64_t end_thread_id_ = 0; + uint64_t fwd_thread_id_ = 0; + int64_t sequence_nr_ = -1; + uint8_t scope_ = 0; + + uint8_t activity_type_; + c10::optional>> shapes_; + c10::optional> stack_; + + std::string name_; + uint64_t device_index_ = 0; + uint64_t start_us_ = 0; + uint64_t duration_us_ = 0; + uint64_t correlation_id_ = 0; + uint64_t linked_correlation_id_ = 0; + int64_t device_resource_id_ = 0; +}; + +// Consolidating events returned directly from Kineto +// with events manually created by us (e.g. start/stop marks, +// memory allocation events) +struct TORCH_API ProfilerResult { + ProfilerResult( + std::vector events, + thread_event_lists legacy_events, + std::unique_ptr trace); + ~ProfilerResult(); + + const std::vector& events() const { + return events_; + } + + const thread_event_lists& legacy_events() const { + return legacy_events_; + } + + void save(const std::string& path); + + private: + bool saved_ = false; + std::vector events_; + thread_event_lists legacy_events_; + std::unique_ptr trace_; +}; + +TORCH_API void enableProfiler( + const ProfilerConfig& config, + const std::set& activities); + +TORCH_API std::unique_ptr disableProfiler(); + +TORCH_API void prepareProfiler( + const ProfilerConfig& config, + const std::set& activities); +#endif // USE_KINETO + +TORCH_API bool kinetoAvailable(); + +} // namespace profiler +}} // namespace torch::autograd diff --git a/torch/csrc/autograd/profiler.cpp b/torch/csrc/autograd/profiler_legacy.cpp similarity index 61% rename from torch/csrc/autograd/profiler.cpp rename to torch/csrc/autograd/profiler_legacy.cpp index 7d50b794648f..88cf22321865 100644 --- a/torch/csrc/autograd/profiler.cpp +++ b/torch/csrc/autograd/profiler_legacy.cpp @@ -23,54 +23,33 @@ namespace torch { namespace autograd { namespace profiler { -namespace { - -enum EventIValueIdx { - KIND = 0, - NAME, - THREAD_ID, - HANDLE, - NODE_ID, - CPU_MEM_USAGE, - CPU_NS, - CUDA_RECORDED, - CUDA_MEM_USAGE, - CUDA_DEVICE, - CUDA_US, - SHAPES, - NUM_EVENT_IVALUE_IDX // must be last in list -}; - -enum ProfilerIValueIdx { - STATE = 0, - REPORT_INPUT_SHAPES, - PROFILE_MEMORY, - NUM_PROFILER_CFG_IVALUE_IDX // must be last in list -}; +std::vector prepareCallstack(const std::vector& cs) { + std::vector entries; + entries.reserve(cs.size()); + for (const auto& entry : cs) { + auto& range = entry.range; + if (range.source()) { + auto& src = range.source(); + if (src && src->filename()) { + auto line = src->starting_line_no() + + src->lineno_for_offset(range.start()); + entries.emplace_back(FileLineFunc{*(src->filename()), line, entry.filename}); + } + } + } + return entries; +} - const std::unordered_set disable_cuda_profiling = { - "aten::view", - "aten::t", - "aten::transpose", - "aten::stride", - "aten::empty", - "aten::empty_like", - "aten::empty_strided", - "aten::as_strided", - "aten::expand", - "aten::resize_", - "aten::squeeze", - "aten::unsqueeze", - "aten::slice", - "aten::_unsafe_view", - "aten::size" - }; - -CUDAStubs default_stubs; -constexpr CUDAStubs* default_stubs_addr = 
&default_stubs; -// Constant initialization, so it is guaranteed to be initialized before -// static initialization calls which may invoke registerCUDAMethods -static CUDAStubs* cuda_stubs = default_stubs_addr; +std::vector callstackStr(const std::vector& cs) { + std::vector cs_str; + cs_str.reserve(cs.size()); + for (const auto& entry : cs) { + std::stringstream loc; + loc << entry.filename << "(" << entry.line << "): " << entry.funcname; + cs_str.push_back(loc.str()); + } + return cs_str; +} // We decompose the profiler logic into the following components: // @@ -163,252 +142,267 @@ static CUDAStubs* cuda_stubs = default_stubs_addr; // - save profiling events into the profiling state // -struct FileLineFunc { - std::string filename; - size_t line; - std::string funcname; -}; +namespace { +const CUDAStubs default_stubs; +constexpr const CUDAStubs* default_stubs_addr = &default_stubs; +// Constant initialization, so it is guaranteed to be initialized before +// static initialization calls which may invoke registerCUDAMethods +inline const CUDAStubs*& cuda_stubs() { + static const CUDAStubs* stubs_ = default_stubs_addr; + return stubs_; +} +} // Profiler state -struct ProfilerThreadLocalState : public c10::MemoryReportingInfoBase { - explicit ProfilerThreadLocalState(const ProfilerConfig& config) - : config_(config), remoteProfiledEvents_{c10::nullopt} {} - ~ProfilerThreadLocalState() override = default; +const ProfilerConfig& ProfilerThreadLocalState::config() const { + return config_; +} - inline const ProfilerConfig& config() const { - return config_; +thread_event_lists ProfilerThreadLocalState::consolidate() { + std::lock_guard g(state_mutex_); + thread_event_lists result; + for (auto& kv : event_lists_map_) { + auto& list = kv.second; + result.emplace_back(list->consolidate()); } - - thread_event_lists consolidate() { - std::lock_guard g(state_mutex_); - thread_event_lists result; - for (auto& kv : event_lists_map_) { - auto& list = kv.second; - result.emplace_back(list->consolidate()); - } - // Consolidate remote events if applicable as well. - if (remoteProfiledEvents_) { - result.insert( - result.end(), - std::make_move_iterator(remoteProfiledEvents_->begin()), - std::make_move_iterator(remoteProfiledEvents_->end())); - } - return result; + // Consolidate remote events if applicable as well. 
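+  // Remote events (e.g. events sent back from an RPC callee via
+  // addEventList) arrive as ready-made event lists, so they are appended
+  // to the result wholesale rather than drained from the per-thread
+  // RangeEventLists above.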
+ if (remoteProfiledEvents_) { + result.insert( + result.end(), + std::make_move_iterator(remoteProfiledEvents_->begin()), + std::make_move_iterator(remoteProfiledEvents_->end())); } + return result; +} - void mark(std::string name, bool include_cuda = true) { - if (config_.state == ProfilerState::Disabled) { - return; - } - if (config_.state == ProfilerState::NVTX) { - cuda_stubs->nvtxMarkA(name.c_str()); - } else { - Event evt( - EventKind::Mark, - at::StringView(std::move(name)), - at::RecordFunction::currentThreadId(), - include_cuda && config_.state == ProfilerState::CUDA); - evt.setNodeId(at::RecordFunction::getDefaultNodeId()); - getEventList().record(std::move(evt)); - } +void ProfilerThreadLocalState::mark(std::string name, bool include_cuda) { + if (config_.state == ProfilerState::Disabled) { + return; + } + if (config_.state == ProfilerState::NVTX) { + cuda_stubs()->nvtxMarkA(name.c_str()); + } else { + LegacyEvent evt( + EventKind::Mark, + at::StringView(std::move(name)), + at::RecordFunction::currentThreadId(), + include_cuda && config_.state == ProfilerState::CUDA); + evt.setNodeId(at::RecordFunction::getDefaultNodeId()); + getEventList().record(std::move(evt)); } +} - void setOrAddRemoteProfiledEvents( - std::vector&& remoteProfiledEvents) { - // Lock to serialize access from multiple callback threads. - std::lock_guard guard(state_mutex_); - if (remoteProfiledEvents_) { - (*remoteProfiledEvents_).emplace_back(remoteProfiledEvents); - } else { - remoteProfiledEvents_ = {std::move(remoteProfiledEvents)}; - } +void ProfilerThreadLocalState::setOrAddRemoteProfiledEvents( + std::vector&& remoteProfiledEvents) { + // Lock to serialize access from multiple callback threads. + std::lock_guard guard(state_mutex_); + if (remoteProfiledEvents_) { + (*remoteProfiledEvents_).emplace_back(remoteProfiledEvents); + } else { + remoteProfiledEvents_ = {std::move(remoteProfiledEvents)}; } +} - void pushRange( - const at::RecordFunction& fn, - const bool record_cuda, - const char* msg = "", - std::vector>&& shapes = {}) { - if (config_.state == ProfilerState::Disabled) { - return; - } - if (config_.state == ProfilerState::NVTX) { - cuda_stubs->nvtxRangePushA(getNvtxStr( - fn.name(), msg, fn.seqNr(), shapes).c_str()); - } else { - Event evt( - EventKind::PushRange, - fn.name(), - at::RecordFunction::currentThreadId(), - record_cuda, - fn.handle(), - std::move(shapes), - at::RecordFunction::getDefaultNodeId()); - evt.setSequenceNr(fn.seqNr()); - evt.setFwdThreadId(fn.forwardThreadId()); - evt.setScope((uint8_t)fn.scope()); +void ProfilerThreadLocalState::pushRange( + const at::RecordFunction& fn, + const bool record_cuda, + const char* msg, + std::vector>&& shapes) { + if (config_.state == ProfilerState::Disabled) { + return; + } + if (config_.state == ProfilerState::NVTX) { + cuda_stubs()->nvtxRangePushA(getNvtxStr( + fn.name(), msg, fn.seqNr(), shapes).c_str()); + } else { + LegacyEvent evt( + EventKind::PushRange, + fn.name(), + at::RecordFunction::currentThreadId(), + record_cuda, + fn.handle(), + std::move(shapes), + at::RecordFunction::getDefaultNodeId()); + evt.setSequenceNr(fn.seqNr()); + evt.setFwdThreadId(fn.forwardThreadId()); + evt.setScope((uint8_t)fn.scope()); #ifndef C10_MOBILE - // backward nodes source range corresponds to the forward node - // TODO: consider using C++ stack trace - if (config_.with_stack && fn.scope() != at::RecordScope::BACKWARD_FUNCTION) { - auto cs = prepareCallstack(jit::currentCallstack()); - if (cs.empty()) { - cs = 
prepareCallstack(jit::tracer::pythonCallstack()); - } - evt.setStack(callstackStr(cs)); + // backward nodes source range corresponds to the forward node + // TODO: consider using C++ stack trace + if (config_.with_stack && fn.scope() != at::RecordScope::BACKWARD_FUNCTION) { + auto cs = prepareCallstack(jit::currentCallstack()); + if (cs.empty()) { + cs = prepareCallstack(jit::tracer::pythonCallstack()); } -#endif - getEventList().record(std::move(evt)); - } - } - - void popRange(const at::RecordFunction& fn, const bool record_cuda) { - if (config_.state == ProfilerState::Disabled) { - return; - } - if (config_.state == ProfilerState::NVTX) { - cuda_stubs->nvtxRangePop(); - } else { - // In some cases RecordFunction (and popRange) may be - // called on a different thread than pushRange - // As a convention, we put the async pop on the original - // thread and save current thread id in pop event - Event evt( - EventKind::PopRange, - at::StringView(""), - at::RecordFunction::currentThreadId(), - record_cuda, - fn.handle()); - evt.setNodeId(at::RecordFunction::getDefaultNodeId()); - getEventList(fn.threadId()).record(std::move(evt)); + evt.setStack(callstackStr(cs)); } +#endif + getEventList().record(std::move(evt)); } +} - void setCallbackHandle(at::CallbackHandle handle) { - handle_ = handle; - } - - at::CallbackHandle callbackHandle() const { - return handle_; - } - - void reportMemoryUsage( - void* /* unused */, - int64_t alloc_size, - c10::Device device) override { - if (config_.profile_memory && config_.state != ProfilerState::Disabled) { - uint64_t thread_id = at::RecordFunction::currentThreadId(); - Event evt( - EventKind::MemoryAlloc, - at::StringView(""), - thread_id, - config_.state == ProfilerState::CUDA); - evt.updateMemoryStats(alloc_size, device); - getEventList(thread_id).record(std::move(evt)); - } +void ProfilerThreadLocalState::popRange(const at::RecordFunction& fn, const bool record_cuda) { + if (config_.state == ProfilerState::Disabled) { + return; } - - bool memoryProfilingEnabled() const override { - return config_.profile_memory; + if (config_.state == ProfilerState::NVTX) { + cuda_stubs()->nvtxRangePop(); + } else { + // In some cases RecordFunction (and popRange) may be + // called on a different thread than pushRange + // As a convention, we put the async pop on the original + // thread and save current thread id in pop event + LegacyEvent evt( + EventKind::PopRange, + at::StringView(""), + at::RecordFunction::currentThreadId(), + record_cuda, + fn.handle()); + evt.setNodeId(at::RecordFunction::getDefaultNodeId()); + getEventList(fn.threadId()).record(std::move(evt)); } +} - private: - std::vector prepareCallstack(const std::vector& cs) { - std::vector entries; - entries.reserve(cs.size()); - for (const auto& entry : cs) { - auto& range = entry.range; - if (range.source()) { - auto& src = range.source(); - if (src && src->filename()) { - auto line = src->starting_line_no() + - src->lineno_for_offset(range.start()); - entries.emplace_back(FileLineFunc{*(src->filename()), line, entry.filename}); - } - } - } - return entries; +void ProfilerThreadLocalState::reportMemoryUsage( + void* /* unused */, + int64_t alloc_size, + c10::Device device) { + if (config_.profile_memory && config_.state != ProfilerState::Disabled) { + uint64_t thread_id = at::RecordFunction::currentThreadId(); + LegacyEvent evt( + EventKind::MemoryAlloc, + at::StringView(""), + thread_id, + config_.state == ProfilerState::CUDA); + evt.updateMemoryStats(alloc_size, device); + 
getEventList(thread_id).record(std::move(evt)); } +} - std::vector callstackStr(const std::vector& cs) { - std::vector cs_str; - cs_str.reserve(cs.size()); - for (const auto& entry : cs) { - std::stringstream loc; - loc << entry.filename << "(" << entry.line << "): " << entry.funcname; - cs_str.push_back(loc.str()); - } - return cs_str; - } +bool ProfilerThreadLocalState::memoryProfilingEnabled() const { + return config_.profile_memory; +} - std::string getNvtxStr( - const at::StringView& name, - const char* msg, - int64_t sequence_nr, - const std::vector>& shapes) const { - if (sequence_nr >= 0 || shapes.size() > 0) { - std::stringstream s; +std::string ProfilerThreadLocalState::getNvtxStr( + const at::StringView& name, + const char* msg, + int64_t sequence_nr, + const std::vector>& shapes) const { + if (sequence_nr >= 0 || shapes.size() > 0) { + std::stringstream s; #ifdef __HIP_PLATFORM_HCC__ - s << name.str(); + s << name.str(); #endif - if (sequence_nr >= 0) { + if (sequence_nr >= 0) { #ifdef __HIP_PLATFORM_HCC__ - s << msg << sequence_nr; + s << msg << sequence_nr; #else - s << name.str() << msg << sequence_nr; + s << name.str() << msg << sequence_nr; #endif - } - if (shapes.size() > 0) { - s << ", sizes = ["; - for (size_t idx = 0; idx < shapes.size(); ++idx) { - if (shapes[idx].size() > 0) { - s << "["; - for (size_t dim = 0; dim < shapes[idx].size(); ++dim) { - s << shapes[idx][dim]; - if (dim < shapes[idx].size() - 1) { - s << ", "; - } + } + if (shapes.size() > 0) { + s << ", sizes = ["; + for (size_t idx = 0; idx < shapes.size(); ++idx) { + if (shapes[idx].size() > 0) { + s << "["; + for (size_t dim = 0; dim < shapes[idx].size(); ++dim) { + s << shapes[idx][dim]; + if (dim < shapes[idx].size() - 1) { + s << ", "; } - s << "]"; - } else { - s << "[]"; - } - if (idx < shapes.size() - 1) { - s << ", "; } + s << "]"; + } else { + s << "[]"; + } + if (idx < shapes.size() - 1) { + s << ", "; } - s << "]"; } - return s.str(); - } else { - return name.str(); + s << "]"; } + return s.str(); + } else { + return name.str(); } +} + +RangeEventList& ProfilerThreadLocalState::getEventList(int64_t thread_id) { + if (thread_id < 0) { + thread_id = at::RecordFunction::currentThreadId(); + } + RangeEventList* list_ptr = nullptr; + std::lock_guard guard(state_mutex_); + auto it = event_lists_map_.find(thread_id); + if (it != event_lists_map_.end()) { + list_ptr = it->second.get(); + } else { + auto event_list = std::make_shared(); + event_lists_map_[thread_id] = event_list; + list_ptr = event_list.get(); + } + return *list_ptr; +} - RangeEventList& getEventList(int64_t thread_id = -1) { - if (thread_id < 0) { - thread_id = at::RecordFunction::currentThreadId(); +std::vector> inputSizes(const at::RecordFunction& fn) { + std::vector> sizes; + sizes.reserve(fn.inputs().size()); + for (const c10::IValue& input : fn.inputs()) { + if (!input.isTensor()) { + sizes.emplace_back(); + continue; } - RangeEventList* list_ptr = nullptr; - std::lock_guard guard(state_mutex_); - auto it = event_lists_map_.find(thread_id); - if (it != event_lists_map_.end()) { - list_ptr = it->second.get(); + const at::Tensor& tensor = input.toTensor(); + if (tensor.defined()) { + sizes.push_back(input.toTensor().sizes().vec()); } else { - auto event_list = std::make_shared(); - event_lists_map_[thread_id] = event_list; - list_ptr = event_list.get(); + sizes.emplace_back(); } - return *list_ptr; } + return sizes; +} - std::mutex state_mutex_; - std::unordered_map> - event_lists_map_; +namespace { - ProfilerConfig config_ = 
ProfilerConfig(ProfilerState::Disabled); - at::CallbackHandle handle_ = 0; - c10::optional>> remoteProfiledEvents_; +enum EventIValueIdx { + KIND = 0, + NAME, + THREAD_ID, + HANDLE, + NODE_ID, + CPU_MEM_USAGE, + CPU_NS, + CUDA_RECORDED, + CUDA_MEM_USAGE, + CUDA_DEVICE, + CUDA_US, + SHAPES, + NUM_EVENT_IVALUE_IDX // must be last in list +}; + +enum ProfilerIValueIdx { + STATE = 0, + REPORT_INPUT_SHAPES, + PROFILE_MEMORY, + NUM_PROFILER_CFG_IVALUE_IDX // must be last in list +}; + +const std::unordered_set disable_cuda_profiling = { + "aten::view", + "aten::t", + "aten::transpose", + "aten::stride", + "aten::empty", + "aten::empty_like", + "aten::empty_strided", + "aten::as_strided", + "aten::expand", + "aten::resize_", + "aten::squeeze", + "aten::unsqueeze", + "aten::slice", + "aten::_unsafe_view", + "aten::size" }; ProfilerThreadLocalState* getProfilerTLSState() { @@ -416,7 +410,7 @@ ProfilerThreadLocalState* getProfilerTLSState() { c10::ThreadLocalDebugInfo::get(c10::DebugInfoKind::PROFILER_STATE)); } -void pushProfilingCallbacks() { +void pushProfilingCallbacksLegacy() { auto state_ptr = getProfilerTLSState(); TORCH_INTERNAL_ASSERT(state_ptr, "Expected profiler state set"); auto handle = at::addThreadLocalCallback(at::RecordFunctionCallback( @@ -433,21 +427,8 @@ void pushProfilingCallbacks() { auto* msg = (fn.seqNr() >= 0) ? ", seq = " : ""; if (state_ptr->config().report_input_shapes) { - std::vector> inputSizes; - inputSizes.reserve(fn.inputs().size()); - for (const c10::IValue& input : fn.inputs()) { - if (!input.isTensor()) { - inputSizes.emplace_back(); - continue; - } - const at::Tensor& tensor = input.toTensor(); - if (tensor.defined()) { - inputSizes.push_back(input.toTensor().sizes().vec()); - } else { - inputSizes.emplace_back(); - } - } - state_ptr->pushRange(fn, record_cuda, msg, std::move(inputSizes)); + auto sizes = inputSizes(fn); + state_ptr->pushRange(fn, record_cuda, msg, std::move(sizes)); } else { state_ptr->pushRange(fn, record_cuda, msg); } @@ -474,11 +455,9 @@ const int kCUDAWarmupStart = 5; } // namespace void registerCUDAMethods(CUDAStubs* stubs) { - cuda_stubs = stubs; + cuda_stubs() = stubs; } -ProfilerConfig::~ProfilerConfig() = default; - at::IValue ProfilerConfig::toIValue() const { c10::impl::GenericList eventIValueList(at::AnyType::get()); eventIValueList.reserve(NUM_PROFILER_CFG_IVALUE_IDX); @@ -519,38 +498,40 @@ bool profilerEnabled() { return state_ptr && state_ptr->config().state != ProfilerState::Disabled; } -void enableProfiler(const ProfilerConfig& new_config) { - TORCH_CHECK(new_config.state != ProfilerState::NVTX || cuda_stubs->enabled(), +void enableProfilerLegacy(const ProfilerConfig& new_config) { + TORCH_CHECK(new_config.state != ProfilerState::NVTX || cuda_stubs()->enabled(), "Can't use NVTX profiler - PyTorch was compiled without CUDA"); + TORCH_CHECK(new_config.state != ProfilerState::KINETO); + auto state_ptr = getProfilerTLSState(); TORCH_CHECK(!state_ptr, "Profiler is already enabled on this thread"); auto state = std::make_shared(new_config); c10::ThreadLocalDebugInfo::_push(c10::DebugInfoKind::PROFILER_STATE, state); - pushProfilingCallbacks(); + pushProfilingCallbacksLegacy(); if (new_config.state == ProfilerState::CUDA) { // event recording appears to have some startup overhead, so we need to // to generate some dummy events first before recording synchronization events for (int idx = 0; idx < kCUDAWarmupStart; ++idx) { - cuda_stubs->onEachDevice([state](int /* unused */) { + cuda_stubs()->onEachDevice([state](int /* unused */) { 
state->mark("__cuda_startup"); - cuda_stubs->synchronize(); + cuda_stubs()->synchronize(); }); } // cuda events must be on the same device, so we need a start event recorded // for each gpu. we then use this event to synchronize time on the GPU // with the CPU clock. - cuda_stubs->onEachDevice([state](int d) { + cuda_stubs()->onEachDevice([state](int d) { state->mark("__cuda_start_event"); }); } state->mark("__start_profile", false); } -thread_event_lists disableProfiler(c10::optional profilerDisableOptions) { +thread_event_lists disableProfilerLegacy(c10::optional profilerDisableOptions) { auto cleanupTLSState = profilerDisableOptions ? profilerDisableOptions->cleanupTLSState : true; auto consolidate = profilerDisableOptions ? profilerDisableOptions->consolidate : true; // all the DebugInfoBase objects are scope based and supposed to use DebugInfoGuard @@ -578,21 +559,21 @@ thread_event_lists disableProfiler(c10::optional profile return state_ptr->consolidate(); } -void addEventList(std::vector&& profiledEvents) { +void addEventList(std::vector&& profiledEvents) { auto state_ptr = getProfilerTLSState(); TORCH_CHECK(state_ptr, "Profiler must be enabled."); state_ptr->setOrAddRemoteProfiledEvents(std::move(profiledEvents)); } -void Event::record(bool record_cuda) { +void LegacyEvent::record(bool record_cuda) { if (record_cuda) { - cuda_stubs->record(&device_, &cuda_event, &cpu_ns_); + cuda_stubs()->record(&device_, &cuda_event, &cpu_ns_); return; } cpu_ns_ = getTime(); } -/* static */ Event Event::fromIValue(const at::IValue& eventIValue) { +/* static */ LegacyEvent LegacyEvent::fromIValue(const at::IValue& eventIValue) { TORCH_INTERNAL_ASSERT( eventIValue.isList(), "Expected IValue to contain type c10::impl::GenericList"); @@ -601,7 +582,7 @@ void Event::record(bool record_cuda) { ivalues.size() >= NUM_EVENT_IVALUE_IDX, "Expected at least ", NUM_EVENT_IVALUE_IDX, - " elements to reconstruct Event."); + " elements to reconstruct LegacyEvent."); // Reconstruct input shapes from ivalues. 
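   // (shapes are serialized as a list of per-tensor dimension lists; see
   // LegacyEvent::toIValue for the writer side)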
auto shapeListIValue = ivalues.get(EventIValueIdx::SHAPES); @@ -627,7 +608,7 @@ void Event::record(bool record_cuda) { shapes.emplace_back(s); } - Event evt( + LegacyEvent evt( static_cast( ivalues.get(EventIValueIdx::KIND).toInt()), // EventKind at::StringView(ivalues.get(EventIValueIdx::NAME).toStringRef()), // name @@ -647,7 +628,7 @@ void Event::record(bool record_cuda) { return evt; } -at::IValue Event::toIValue() const { +at::IValue LegacyEvent::toIValue() const { c10::impl::GenericList eventIValueList(at::AnyType::get()); eventIValueList.reserve(NUM_EVENT_IVALUE_IDX); eventIValueList.emplace_back(static_cast(kind_)); @@ -679,7 +660,7 @@ at::IValue Event::toIValue() const { return at::IValue(eventIValueList); } -double Event::cudaElapsedUs(const Event& e) const { +double LegacyEvent::cudaElapsedUs(const LegacyEvent& e) const { TORCH_CHECK(e.hasCuda() && hasCuda(), "Events were not recorded for CUDA"); TORCH_CHECK( e.device() == device(), @@ -690,13 +671,12 @@ double Event::cudaElapsedUs(const Event& e) const { TORCH_INTERNAL_ASSERT(cuda_us_ >= 0 && e.cuda_us_ >= 0); return static_cast(e.cuda_us_ - cuda_us_); } - return cuda_stubs->elapsed(&cuda_event, &e.cuda_event); + return cuda_stubs()->elapsed(&cuda_event, &e.cuda_event); } CUDAStubs::~CUDAStubs() = default; - -static jit::CodeTemplate event_template(R"( +static const jit::CodeTemplate event_template(R"( { "name": "${name}", "ph": "X", @@ -707,10 +687,10 @@ static jit::CodeTemplate event_template(R"( "args": {} })"); -void writeProfilerEventsToStream(std::ostream& out, const std::vector& events) { +void writeProfilerEventsToStream(std::ostream& out, const std::vector& events) { TORCH_CHECK(out, "Could not open file"); - Event* profiler_start = nullptr; - for (Event* e : events) { + LegacyEvent* profiler_start = nullptr; + for (LegacyEvent* e : events) { if (0 == strcmp(e->name(), "__start_profile")) { profiler_start = e; break; @@ -724,20 +704,20 @@ void writeProfilerEventsToStream(std::ostream& out, const std::vector& e return std::hash()(p.first) ^ std::hash()(p.second); } }; - std::unordered_map, Event*, PairHash> events_map; + std::unordered_map, LegacyEvent*, PairHash> events_map; out << "[\n"; bool first = true; - for (Event* evt : events) { - if (evt->kind() == "push") { + for (LegacyEvent* evt : events) { + if (evt->kindStr() == "push") { events_map[std::make_pair(evt->handle(), evt->nodeId())] = evt; - } else if (evt->kind() == "pop") { + } else if (evt->kindStr() == "pop") { if (!first) { out << ",\n"; } first = false; auto it = events_map.find(std::make_pair(evt->handle(), evt->nodeId())); TORCH_CHECK(it != events_map.end(), "Unmatched pop event"); - Event* evt_start = it->second; + LegacyEvent* evt_start = it->second; events_map.erase(it); jit::TemplateEnv env; @@ -751,7 +731,6 @@ void writeProfilerEventsToStream(std::ostream& out, const std::vector& e out << "]\n"; } - RecordProfile::RecordProfile(std::ostream& out) : out_(out) { init(); @@ -763,24 +742,27 @@ RecordProfile::RecordProfile(const std::string& filename) } void RecordProfile::init() { - enableProfiler(ProfilerConfig(ProfilerState::CPU)); + enableProfilerLegacy(ProfilerConfig(ProfilerState::CPU)); } RecordProfile::~RecordProfile() { - thread_event_lists event_lists = disableProfiler(); - std::vector events; - for (auto& l : event_lists) { - for (auto& e : l) { - events.push_back(&e); + try { + thread_event_lists event_lists = disableProfilerLegacy(); + std::vector events; + for (auto& l : event_lists) { + for (auto& e : l) { + events.push_back(&e); + } } - 
} - processEvents(events); - if (file_){ - file_->close(); + processEvents(events); + } catch (const std::exception& e) { + LOG(ERROR) << e.what() << std::endl; + } catch (...) { + LOG(ERROR) << "Unknown error" << std::endl; } } -void RecordProfile::processEvents(const std::vector& events) { +void RecordProfile::processEvents(const std::vector& events) { writeProfilerEventsToStream(out_, events); } diff --git a/torch/csrc/autograd/profiler_legacy.h b/torch/csrc/autograd/profiler_legacy.h new file mode 100644 index 000000000000..8ccc3b3c1189 --- /dev/null +++ b/torch/csrc/autograd/profiler_legacy.h @@ -0,0 +1,544 @@ +#pragma once + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#ifndef _WIN32 +#include +#endif +#if defined(C10_IOS) && defined(C10_MOBILE) +#include // for gettimeofday() +#endif + +#include + +#include + +struct CUevent_st; +typedef std::shared_ptr CUDAEventStub; + +namespace torch { namespace autograd { + +struct Node; + +namespace profiler { + +struct TORCH_API CUDAStubs { + virtual void record(int* device, CUDAEventStub* event, int64_t* cpu_ns) const { + fail(); + } + virtual float elapsed(const CUDAEventStub* event, const CUDAEventStub* event2) const { + fail(); + return 0.f; + } + virtual void nvtxMarkA(const char* name) const { + fail(); + } + virtual void nvtxRangePushA(const char* name) const { + fail(); + } + virtual void nvtxRangePop() const { + fail(); + } + virtual bool enabled() const { + return false; + } + virtual void onEachDevice(std::function op) const { + fail(); + } + virtual void synchronize() const { + fail(); + } + virtual ~CUDAStubs(); + +private: + void fail() const { + AT_ERROR("CUDA used in profiler but not enabled."); + } +}; + +TORCH_API void registerCUDAMethods(CUDAStubs* stubs); + +constexpr inline size_t ceilToMultiple(size_t a, size_t b) { + return ((a + b - 1) / b) * b; +} + +inline int64_t getTime() { +#if defined(C10_IOS) && defined(C10_MOBILE) +// clock_gettime is only available on iOS 10.0 or newer. Unlike OS X, iOS can't rely on +// CLOCK_REALTIME, as it is defined no matter if clock_gettime is implemented or not + struct timeval now; + gettimeofday(&now, NULL); + return static_cast(now.tv_sec) * 1000000000 + static_cast(now.tv_usec) * 1000; +#elif defined(_WIN32) || defined(__MACH__) + using namespace std::chrono; + using clock = std::conditional::type; + return duration_cast(clock::now().time_since_epoch()).count(); +#else + // clock_gettime is *much* faster than std::chrono implementation on Linux + struct timespec t{}; + clock_gettime(CLOCK_MONOTONIC, &t); + return static_cast(t.tv_sec) * 1000000000 + static_cast(t.tv_nsec); +#endif +} + +enum class C10_API_ENUM EventKind : uint16_t { + Mark, + PushRange, + PopRange, + MemoryAlloc, +}; + +// To be deprecated, once we switch to Kineto profiling +struct TORCH_API LegacyEvent { + LegacyEvent( + EventKind kind, + at::StringView name, + uint16_t thread_id, + bool record_cuda, + at::RecordFunctionHandle handle = 0, + std::vector>&& shapes = {}, + int node_id = -1) + : name_(std::move(name)), + kind_(kind), + thread_id_(thread_id), + handle_(handle), + shapes_(shapes), + node_id_(node_id) { + record(record_cuda); + } + + // Constructor to be used in conjunction with LegacyEvent::fromIValue. 
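+  // All fields are passed explicitly because the event is being rebuilt
+  // from serialized IValues rather than recorded live; note that record()
+  // is deliberately not called here.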
+ LegacyEvent( + EventKind kind, + at::StringView name, + uint16_t thread_id, + at::RecordFunctionHandle handle, + std::vector>&& shapes, + int node_id, + bool is_remote, + int64_t cpu_memory_usage, + int64_t cpu_ns, + bool cuda_recorded, + int64_t cuda_memory_usage = 0, + int device = -1, + double cuda_us = -1) + : cpu_ns_(cpu_ns), + name_(std::move(name)), + kind_(kind), + thread_id_(thread_id), + handle_(handle), + shapes_(shapes), + cpu_memory_usage_(cpu_memory_usage), + cuda_memory_usage_(cuda_memory_usage), + device_(device), + node_id_(node_id), + is_remote_(is_remote), + cuda_us_(cuda_us) { + // Sanity check values that were deserialized + TORCH_INTERNAL_ASSERT(cpu_ns_ > 0); + if (cuda_recorded) { + TORCH_INTERNAL_ASSERT(device_ >= 0); + TORCH_INTERNAL_ASSERT(cuda_us_ >= 0); + } + } + + // Returns IValues corresponding to event structure, to be used for + // serialization. + at::IValue toIValue() const; + + // Reconstructs an event from IValues given by toIValue. + static LegacyEvent fromIValue(const at::IValue& eventIValue); + + void record(bool record_cuda); + + std::string kindStr() const { + switch (kind_) { + case EventKind::Mark: return "mark"; + case EventKind::PushRange: return "push"; + case EventKind::PopRange: return "pop"; + case EventKind::MemoryAlloc: return "memory_alloc"; + } + throw std::runtime_error("unknown event kind"); + } + + const char* name() const { + return name_.str(); + } + + uint64_t threadId() const { + return thread_id_; + } + + std::vector> shapes() const { + return shapes_; + } + + double cpuElapsedUs(const LegacyEvent& e) const { + return (e.cpu_ns_ - cpu_ns_)/(1000.0); + } + + void setCpuUs(int64_t cpu_us) { + cpu_ns_ = cpu_us * 1000.0; + } + + double cpuUs() const { + return cpu_ns_ / (1000.0); + } + + double cudaElapsedUs(const LegacyEvent& e) const; + + bool hasCuda() const { + return cuda_event != nullptr || (isRemote() && device_ != -1); + } + + int device() const { + return device_; + } + + void updateMemoryStats(int64_t alloc_size, c10::Device device) { + if (device.type() == c10::DeviceType::CUDA || + device.type() == c10::DeviceType::HIP) { + cuda_memory_usage_ = alloc_size; + } else if (device.type() == c10::DeviceType::CPU || + device.type() == c10::DeviceType::MKLDNN || + device.type() == c10::DeviceType::IDEEP) { + cpu_memory_usage_ = alloc_size; + } else { + LOG(WARNING) << "Unsupported memory profiling device: " << device; + } + } + + int64_t cpuMemoryUsage() const { + return cpu_memory_usage_; + } + + int64_t cudaMemoryUsage() const { + return cuda_memory_usage_; + } + + at::RecordFunctionHandle handle() const { + return handle_; + } + + // Node ID corresponding to this event. + int nodeId( ) const { + return node_id_; + } + + // Set Node ID on this event. 
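+  // (node IDs are used by the distributed RPC profiler to attribute an
+  // event to the worker it was recorded on)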
+ void setNodeId(int node_id) { + node_id_ = node_id; + } + + void setName(at::StringView newName_) { + name_ = std::move(newName_); + } + + bool isRemote() const { + return is_remote_; + } + + void setCudaUs(int64_t cuda_us) { + cuda_us_ = cuda_us; + } + + void setSequenceNr(int64_t sequence_nr) { + sequence_nr_ = sequence_nr; + } + + int64_t sequenceNr() const { + return sequence_nr_; + } + + void setCorrelationId(uint64_t correlation_id) { + correlation_id_ = correlation_id; + } + + uint64_t correlationId() const { + return correlation_id_; + } + + const std::vector& stack() const { + return stack_; + } + + void setStack(const std::vector& stack) { + stack_ = stack; + } + + uint64_t fwdThreadId() const { + return fwd_thread_id_; + } + + void setFwdThreadId(uint64_t fwd_thread_id) { + fwd_thread_id_ = fwd_thread_id; + } + + uint8_t scope() const { + return scope_; + } + + void setScope(uint8_t scope) { + scope_ = scope; + } + + private: + // signed to allow for negative intervals, initialized for safety. + int64_t cpu_ns_ = 0; + at::StringView name_; + EventKind kind_; + uint64_t thread_id_; + uint64_t fwd_thread_id_; + at::RecordFunctionHandle handle_ {0}; + std::vector> shapes_; + int64_t cpu_memory_usage_ = 0; + int64_t cuda_memory_usage_ = 0; + int device_ = -1; + CUDAEventStub cuda_event = nullptr; + int node_id_ = 0; + bool is_remote_ = false; + int64_t cuda_us_ = -1; + int64_t sequence_nr_ = -1; + + std::vector stack_; + uint8_t scope_; + uint64_t correlation_id_; +}; + +// a linked-list of fixed sized vectors, to avoid +// a std::vector resize from taking a large amount of time inside +// a profiling event +struct RangeEventList { + RangeEventList() { + events_.reserve(kReservedCapacity); + } + + template + void record(Args&&... args) { + std::lock_guard guard(mutex_); + events_.emplace_back(std::forward(args)...); + } + + std::vector consolidate() { + std::lock_guard lock(mutex_); + std::vector result; + result.insert( + result.begin(), + std::make_move_iterator(events_.begin()), + std::make_move_iterator(events_.end())); + events_.erase(events_.begin(), events_.end()); + return result; + } + + size_t size() { + std::lock_guard lock(mutex_); + return events_.size(); + } + + private: + // This mutex is used to serialize access when different threads are writing + // to the same instance of RangeEventList. + std::mutex mutex_; + std::vector events_; + + static const size_t kReservedCapacity = 1024; +}; + +enum class C10_API_ENUM ProfilerState { + Disabled = 0, + CPU, // CPU-only profiling + CUDA, // CPU + CUDA events + NVTX, // only emit NVTX markers + KINETO, // use libkineto + NUM_PROFILER_STATES, // must be the last one +}; + +struct TORCH_API ProfilerConfig { + ProfilerConfig( + ProfilerState state, + bool report_input_shapes = false, + bool profile_memory = false, + bool with_stack = false) + : state(state), + report_input_shapes(report_input_shapes), + profile_memory(profile_memory), + with_stack(with_stack) {} + ~ProfilerConfig() = default; + ProfilerState state; + bool report_input_shapes; + bool profile_memory; + bool with_stack; + + // Returns IValues corresponding to ProfilerConfig struct, to be used for + // serialization. + at::IValue toIValue() const; + + // Reconstructs a ProfilerConfig from IValues given by toIValue. + static ProfilerConfig fromIValue(const at::IValue& profilerConfigIValue); +}; + +// A struct to control settings of disableProfiler options. 
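+// A minimal usage sketch (hypothetical caller, not part of this change):
+//
+//   enableProfilerLegacy(ProfilerConfig(ProfilerState::CPU));
+//   // ... workload to profile ...
+//   thread_event_lists lists = disableProfilerLegacy(
+//       ProfilerDisableOptions(/*shouldCleanupTLSState=*/true,
+//                              /*shouldConsolidate=*/true));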
+// A struct to control settings of disableProfiler options.
+struct TORCH_API ProfilerDisableOptions {
+  ProfilerDisableOptions() = default;
+  ProfilerDisableOptions(bool shouldCleanupTLSState, bool shouldConsolidate)
+      : cleanupTLSState(shouldCleanupTLSState),
+        consolidate(shouldConsolidate) {}
+  // Whether we should clean up profiler states that are thread local, such as
+  // ThreadLocalDebugInfo and thread local RecordFunction callbacks.
+  bool cleanupTLSState = true;
+  // Whether we should consolidate all currently recorded profiled events. If
+  // false, will not consolidate, and other threads can continue to write to
+  // the event lists.
+  bool consolidate = true;
+};
+
+// NOTE: profiler mode is thread local, with automatic propagation
+// across thread boundaries (e.g. at::launch tasks)
+TORCH_API void enableProfilerLegacy(const ProfilerConfig&);
+using thread_event_lists = std::vector<std::vector<LegacyEvent>>;
+TORCH_API thread_event_lists disableProfilerLegacy(
+    c10::optional<ProfilerDisableOptions> profilerDisableOptions =
+        c10::nullopt);
+
+// Adds profiledEvents to the current thread-local recorded events. Each event
+// will be marked with the node ID given by fromNodeId.
+TORCH_API void addEventList(std::vector<LegacyEvent>&& profiledEvents);
+// Returns whether the profiler is currently enabled in the current thread.
+TORCH_API bool profilerEnabled();
+// Retrieves the thread_local ProfilerConfig.
+TORCH_API ProfilerConfig getProfilerConfig();
+// Writes profiled events to a stream.
+TORCH_API void writeProfilerEventsToStream(
+    std::ostream& out,
+    const std::vector<LegacyEvent*>& events);
+
+// Usage:
+// {
+//   RecordProfile guard("filename.trace");
+//   // code you want to profile
+// }
+// Then open filename.trace in chrome://tracing
+struct TORCH_API RecordProfile {
+  RecordProfile(std::ostream& out);
+  RecordProfile(const std::string& filename);
+
+  ~RecordProfile();
+
+ private:
+  void init();
+  std::unique_ptr<std::ofstream> file_;
+  std::ostream& out_;
+  void processEvents(const std::vector<LegacyEvent*>& events);
+};
+
+// A guard that enables the profiler, taking an optional callback to process
+// the results.
+// Usage:
+// {
+//   TLSProfilerGuard g([](thread_event_lists profilerResults) {
+//     // process profilerResults
+//   });
+//   // code to profile
+// }
+struct TORCH_API TLSProfilerGuard {
+  explicit TLSProfilerGuard(
+      const ProfilerConfig& cfg,
+      c10::optional<std::function<void(const thread_event_lists&)>>
+          resultCallback = c10::nullopt,
+      c10::optional<ProfilerDisableOptions> profilerDisableOptions =
+          c10::nullopt)
+      : cb_(std::move(resultCallback)),
+        profilerDisableOptions_(std::move(profilerDisableOptions)) {
+    enableProfilerLegacy(cfg);
+  }
+  ~TLSProfilerGuard() {
+    thread_event_lists event_lists =
+        disableProfilerLegacy(profilerDisableOptions_);
+    if (cb_) {
+      try {
+        (*cb_)(event_lists);
+      } catch (const std::exception& e) {
+        LOG(ERROR) << "Got error processing profiler events: " << e.what();
+      }
+    }
+  }
+
+ private:
+  c10::optional<std::function<void(const thread_event_lists&)>> cb_;
+  const c10::optional<ProfilerDisableOptions> profilerDisableOptions_;
+};
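Since callers migrate from `enableProfiler`/`disableProfiler` to the `Legacy` names in this patch, here is a hedged usage sketch of `TLSProfilerGuard` as declared above. It assumes an in-tree translation unit where `torch/csrc/autograd/profiler.h` is on the include path; the function name is made up for illustration:

```cpp
#include <torch/csrc/autograd/profiler.h>

namespace prof = torch::autograd::profiler;

// Hypothetical helper: profile a section of C++ code and inspect the events.
void profiledSection() {
  prof::TLSProfilerGuard guard(
      prof::ProfilerConfig(prof::ProfilerState::CPU),
      [](const prof::thread_event_lists& results) {
        // results holds one vector<LegacyEvent> per thread that recorded.
        for (const auto& thread_events : results) {
          for (const auto& evt : thread_events) {
            // e.g. inspect evt.name(), evt.kindStr(), evt.cpuUs()
          }
        }
      });
  // ... code to profile; the callback runs when `guard` is destroyed.
}
```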
+
+struct TORCH_API FileLineFunc {
+  std::string filename;
+  size_t line;
+  std::string funcname;
+};
+
+TORCH_API std::vector<FileLineFunc> prepareCallstack(
+    const std::vector<jit::StackEntry>& cs);
+TORCH_API std::vector<std::string> callstackStr(
+    const std::vector<FileLineFunc>& cs);
+TORCH_API std::vector<std::vector<int64_t>> inputSizes(
+    const at::RecordFunction& fn);
+
+struct TORCH_API ProfilerThreadLocalState : public c10::MemoryReportingInfoBase {
+  explicit ProfilerThreadLocalState(const ProfilerConfig& config)
+      : config_(config), remoteProfiledEvents_{c10::nullopt} {}
+  ~ProfilerThreadLocalState() override = default;
+
+  const ProfilerConfig& config() const;
+
+  thread_event_lists consolidate();
+
+  void mark(std::string name, bool include_cuda = true);
+
+  void setOrAddRemoteProfiledEvents(
+      std::vector<LegacyEvent>&& remoteProfiledEvents);
+
+  void pushRange(
+      const at::RecordFunction& fn,
+      const bool record_cuda,
+      const char* msg = "",
+      std::vector<std::vector<int64_t>>&& shapes = {});
+
+  void popRange(const at::RecordFunction& fn, const bool record_cuda);
+
+  void setCallbackHandle(at::CallbackHandle handle) {
+    handle_ = handle;
+  }
+
+  at::CallbackHandle callbackHandle() const {
+    return handle_;
+  }
+
+  bool hasCallbackHandle() {
+    return handle_ > 0;
+  }
+
+  void reportMemoryUsage(
+      void* /* unused */,
+      int64_t alloc_size,
+      c10::Device device) override;
+
+  bool memoryProfilingEnabled() const override;
+
+ protected:
+  std::string getNvtxStr(
+      const at::StringView& name,
+      const char* msg,
+      int64_t sequence_nr,
+      const std::vector<std::vector<int64_t>>& shapes) const;
+
+  RangeEventList& getEventList(int64_t thread_id = -1);
+
+  std::mutex state_mutex_;
+  std::unordered_map<uint64_t, std::shared_ptr<RangeEventList>>
+      event_lists_map_;
+
+  ProfilerConfig config_ = ProfilerConfig(ProfilerState::Disabled);
+  at::CallbackHandle handle_ = 0;
+  c10::optional<std::vector<std::vector<LegacyEvent>>> remoteProfiledEvents_;
+};
+
+} // namespace profiler
+}} // namespace torch::autograd
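The distributed changes that follow are mechanical renames of `Event` to `LegacyEvent`, but they lean on the `toIValue`/`fromIValue` pair declared in the header above: remote profiling serializes each event to an IValue, ships it in the RPC response, and rebuilds it on the caller. A short sketch of that round trip (in-tree headers assumed; the helper name is illustrative):

```cpp
#include <torch/csrc/autograd/profiler.h>

using torch::autograd::profiler::LegacyEvent;

// Hypothetical helper showing the serialize/deserialize pair the RPC
// profiling code below relies on.
LegacyEvent roundTrip(const LegacyEvent& e) {
  at::IValue wire = e.toIValue();        // sender side
  return LegacyEvent::fromIValue(wire);  // receiver side
}
```

The deserializing constructor at the top of this header, with its `cpu_ns_`/`device_` sanity checks, appears to be what `fromIValue` feeds.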
diff --git a/torch/csrc/distributed/autograd/rpc_messages/rpc_with_profiling_resp.cpp b/torch/csrc/distributed/autograd/rpc_messages/rpc_with_profiling_resp.cpp
index 3656a1b9dae4..2336711d07e9 100644
--- a/torch/csrc/distributed/autograd/rpc_messages/rpc_with_profiling_resp.cpp
+++ b/torch/csrc/distributed/autograd/rpc_messages/rpc_with_profiling_resp.cpp
@@ -13,7 +13,7 @@ constexpr auto kProfileEventsStartIdx = 3;
 RpcWithProfilingResp::RpcWithProfilingResp(
     rpc::MessageType messageType,
     rpc::Message&& wrappedMessage,
-    std::vector<torch::autograd::profiler::Event> profiledEvents,
+    std::vector<torch::autograd::profiler::LegacyEvent> profiledEvents,
     rpc::ProfilingId profilingId)
     : messageType_(messageType),
       wrappedMessage_(std::move(wrappedMessage)),
@@ -32,7 +32,7 @@ RpcWithProfilingResp::RpcWithProfilingResp(
     std::unique_ptr<RpcCommandBase> wrappedRpc,
     rpc::MessageType wrappedMessageType,
     std::vector<torch::Tensor> tensors,
-    std::vector<torch::autograd::profiler::Event> profiledEvents,
+    std::vector<torch::autograd::profiler::LegacyEvent> profiledEvents,
     rpc::ProfilingId profilingId)
     : messageType_(messageType),
       wrappedRpc_(std::move(wrappedRpc)),
@@ -52,7 +52,7 @@ rpc::MessageType RpcWithProfilingResp::wrappedMessageType() const {
   return wrappedMessageType_;
 }
 
-std::vector<torch::autograd::profiler::Event> RpcWithProfilingResp::
+std::vector<torch::autograd::profiler::LegacyEvent> RpcWithProfilingResp::
     getProfiledEvents() const {
   return profiledEvents_;
 }
@@ -119,15 +119,15 @@ std::unique_ptr<RpcWithProfilingResp> RpcWithProfilingResp::fromMessage(
       static_cast<rpc::MessageType>(tupleElements[0].toInt());
   rpc::ProfilingId profilingId = rpc::ProfilingId::fromIValue(tupleElements[1]);
   int profiledEventsSize = tupleElements[2].toInt();
-  std::vector<torch::autograd::profiler::Event> remoteEvents;
+  std::vector<torch::autograd::profiler::LegacyEvent> remoteEvents;
   remoteEvents.reserve(profiledEventsSize);
   for (int i = kProfileEventsStartIdx;
        i < kProfileEventsStartIdx + profiledEventsSize;
        ++i) {
     TORCH_CHECK(i < tupleElements.size());
     // Reconstruct remote event from the ivalues.
-    torch::autograd::profiler::Event fromIvalueEvent =
-        torch::autograd::profiler::Event::fromIValue(tupleElements[i]);
+    torch::autograd::profiler::LegacyEvent fromIvalueEvent =
+        torch::autograd::profiler::LegacyEvent::fromIValue(tupleElements[i]);
     remoteEvents.push_back(std::move(fromIvalueEvent));
   }
 
diff --git a/torch/csrc/distributed/autograd/rpc_messages/rpc_with_profiling_resp.h b/torch/csrc/distributed/autograd/rpc_messages/rpc_with_profiling_resp.h
index c4dc088017f0..ad7b54ea8b82 100644
--- a/torch/csrc/distributed/autograd/rpc_messages/rpc_with_profiling_resp.h
+++ b/torch/csrc/distributed/autograd/rpc_messages/rpc_with_profiling_resp.h
@@ -15,7 +15,7 @@ class TORCH_API RpcWithProfilingResp : public rpc::RpcCommandBase {
   RpcWithProfilingResp(
       rpc::MessageType messageType,
       rpc::Message&& wrappedMessage,
-      std::vector<torch::autograd::profiler::Event> profiledEvents,
+      std::vector<torch::autograd::profiler::LegacyEvent> profiledEvents,
       rpc::ProfilingId profilingId);
   // For receiving RPCs. Used in fromMessage when converting a message received
@@ -25,13 +25,13 @@ class TORCH_API RpcWithProfilingResp : public rpc::RpcCommandBase {
       std::unique_ptr<RpcCommandBase> wrappedRpc,
       rpc::MessageType wrappedMessageType,
       std::vector<torch::Tensor> tensors,
-      std::vector<torch::autograd::profiler::Event> profiledEvents,
+      std::vector<torch::autograd::profiler::LegacyEvent> profiledEvents,
       rpc::ProfilingId profilingId);
   rpc::Message toMessageImpl() && override;
   static std::unique_ptr<RpcWithProfilingResp> fromMessage(
       const rpc::Message& message);
   // Retrieve remote Events.
-  std::vector<torch::autograd::profiler::Event> getProfiledEvents() const;
+  std::vector<torch::autograd::profiler::LegacyEvent> getProfiledEvents() const;
   // Retrieve the globally unique profiling ID corresponding to this command.
   const rpc::ProfilingId& getProfilingId() const;
   // Retrieve the original RPC which this ProfilingRPC wraps.
@@ -51,7 +51,7 @@ class TORCH_API RpcWithProfilingResp : public rpc::RpcCommandBase {
   std::unique_ptr<RpcCommandBase> wrappedRpc_;
   rpc::MessageType wrappedMessageType_;
   std::vector<torch::Tensor> tensors_;
-  const std::vector<torch::autograd::profiler::Event> profiledEvents_;
+  const std::vector<torch::autograd::profiler::LegacyEvent> profiledEvents_;
   const rpc::ProfilingId profilingId_;
 };
 } // namespace autograd
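In request_callback_no_python.cpp below, note the `ProfilerDisableOptions opts(false, true)` call: the server thread turns the profiler off without tearing down thread-local state (the main thread owns that cleanup) while still consolidating everything recorded so far, so events from asynchronously completed work are captured. A sketch of that call pattern in isolation (in-tree headers assumed, function name illustrative):

```cpp
#include <torch/csrc/autograd/profiler.h>

namespace prof = torch::autograd::profiler;

// Hypothetical helper mirroring the server-side shutdown sequence.
prof::thread_event_lists collectAsyncEvents() {
  // shouldCleanupTLSState = false: leave TLS for the main thread to clean up.
  // shouldConsolidate = true: gather events recorded by all threads so far.
  prof::ProfilerDisableOptions opts(false, true);
  return prof::disableProfilerLegacy(opts);
}
```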
diff --git a/torch/csrc/distributed/rpc/request_callback_no_python.cpp b/torch/csrc/distributed/rpc/request_callback_no_python.cpp
index 843062b2270c..7f8db89f55bf 100644
--- a/torch/csrc/distributed/rpc/request_callback_no_python.cpp
+++ b/torch/csrc/distributed/rpc/request_callback_no_python.cpp
@@ -81,7 +81,7 @@ std::shared_ptr<FutureMessage> RequestCallbackNoPython::processMessage(
   if (serverProcessGlobalProfilerStateStackEntryPtr) {
     // Initialize thread-local profiler state from process-global
     // profiler state.
-    ::torch::autograd::profiler::enableProfiler(
+    ::torch::autograd::profiler::enableProfilerLegacy(
         serverProcessGlobalProfilerStateStackEntryPtr->statePtr()
             ->config());
   }
@@ -93,7 +93,7 @@ std::shared_ptr<FutureMessage> RequestCallbackNoPython::processMessage(
   if (serverProcessGlobalProfilerStateStackEntryPtr) {
     // Restore thread-local profiler state.
     ::torch::autograd::profiler::thread_event_lists event_lists =
-        ::torch::autograd::profiler::disableProfiler();
+        ::torch::autograd::profiler::disableProfilerLegacy();
     // Put thread_local event_lists into the process-global profiler
     // state.
     profiler::processglobal::pushResultRecursive(
@@ -509,7 +509,7 @@ void RequestCallbackNoPython::processRunWithProfilingReq(
         responseFuture,
         profilingKeyId,
         profilingConfig] {
-      std::vector<torch::autograd::profiler::Event> profiledEvents;
+      std::vector<torch::autograd::profiler::LegacyEvent> profiledEvents;
       // Defer consolidation of profiler events until async work has
       // completed (such as async UDF)
@@ -521,7 +521,7 @@ void RequestCallbackNoPython::processRunWithProfilingReq(
       // they will be cleaned up by the main thread, and consolidate all
      // events so we obtain asynchronously run events.
       torch::autograd::profiler::ProfilerDisableOptions opts(false, true);
-      auto event_lists = torch::autograd::profiler::disableProfiler(opts);
+      auto event_lists = torch::autograd::profiler::disableProfilerLegacy(opts);
       if (wrappedRpcResponseFuture->hasError()) {
         // Propagate error.
         // No need to propagate remote events in the case of an error.
diff --git a/torch/csrc/distributed/rpc/utils.cpp b/torch/csrc/distributed/rpc/utils.cpp
index b4fb2ff394e7..79c197505bbb 100644
--- a/torch/csrc/distributed/rpc/utils.cpp
+++ b/torch/csrc/distributed/rpc/utils.cpp
@@ -36,7 +36,7 @@ void processRemoteProfiledEvents(
       "Profiler was expected to be enabled. This can happen in callback "
      " continuations that run in different threads, and the TLS of the "
      " profiler was not propagated.");
-  std::vector<torch::autograd::profiler::Event> events =
+  std::vector<torch::autograd::profiler::LegacyEvent> events =
       rpcWithProfilingResp.getProfiledEvents();
   const auto& profilingId = rpcWithProfilingResp.getProfilingId();
   auto& remoteProfilerManager = RemoteProfilerManager::getInstance();
@@ -46,7 +46,7 @@ void processRemoteProfiledEvents(
   std::for_each(
       events.begin(),
       events.end(),
-      [&keyPrefixStr](torch::autograd::profiler::Event& event) {
+      [&keyPrefixStr](torch::autograd::profiler::LegacyEvent& event) {
         std::string name = keyPrefixStr + std::string(event.name());
         event.setName(at::StringView(name));
       });
@@ -511,9 +511,9 @@ std::vector<at::IValue> readWrappedPayload(
 }
 
 void populateRemoteProfiledEvents(
-    std::vector<torch::autograd::profiler::Event>& profiledEvents,
+    std::vector<torch::autograd::profiler::LegacyEvent>& profiledEvents,
     const torch::autograd::profiler::ProfilerConfig& profilingConfig,
-    const std::vector<std::vector<torch::autograd::profiler::Event>>&
+    const std::vector<std::vector<torch::autograd::profiler::LegacyEvent>>&
         eventLists) {
   // Gather all events into a vector.
   for (auto& l : eventLists) {
@@ -525,11 +525,11 @@ void populateRemoteProfiledEvents(
   bool cudaProfilingEnabled = profilingConfig.state ==
       torch::autograd::profiler::ProfilerState::CUDA;
   bool foundCpuStart = false;
-  const torch::autograd::profiler::Event* profilerStart = nullptr;
+  const torch::autograd::profiler::LegacyEvent* profilerStart = nullptr;
   // Each device has its own cudaProfilerStart, so we must take
   // care to use the correct one depending on the device the
   // operation ran on.
-  std::unordered_map<int, const torch::autograd::profiler::Event*>
+  std::unordered_map<int, const torch::autograd::profiler::LegacyEvent*>
       cudaProfilerStarts;
   for (auto& e : profiledEvents) {
     if (!foundCpuStart && 0 == strcmp(e.name(), "__start_profile")) {
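One detail worth calling out in utils.cpp above: `processRemoteProfiledEvents` rewrites each incoming event's name with a per-RPC key prefix so remote events are distinguishable in the local profile. The pattern, reduced to its essentials (in-tree headers assumed; the helper name and prefix format are illustrative):

```cpp
#include <string>
#include <vector>
#include <torch/csrc/autograd/profiler.h>

using torch::autograd::profiler::LegacyEvent;

// Hypothetical helper: tag remote events with the profiling key prefix.
void prefixRemoteEvents(std::vector<LegacyEvent>& events,
                        const std::string& keyPrefix) {
  for (auto& event : events) {
    std::string name = keyPrefix + std::string(event.name());
    event.setName(at::StringView(name));
  }
}
```

`populateRemoteProfiledEvents` then uses the `__start_profile` marker events, one per device, to rebase remote timestamps; hence the `cudaProfilerStarts` map keyed by device index.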
diff --git a/torch/csrc/distributed/rpc/utils.h b/torch/csrc/distributed/rpc/utils.h
index f91dfb4f4c7d..aa920d06cae8 100644
--- a/torch/csrc/distributed/rpc/utils.h
+++ b/torch/csrc/distributed/rpc/utils.h
@@ -82,9 +82,9 @@ TORCH_API std::vector<at::IValue> readWrappedPayload(
 // Takes a list of events from the autograd profiler and populates them into
 // profiledEvents to be carried over RPC.
 TORCH_API void populateRemoteProfiledEvents(
-    std::vector<torch::autograd::profiler::Event>& profiledEvents,
+    std::vector<torch::autograd::profiler::LegacyEvent>& profiledEvents,
     const torch::autograd::profiler::ProfilerConfig& profilerConfig,
-    const std::vector<std::vector<torch::autograd::profiler::Event>>&
+    const std::vector<std::vector<torch::autograd::profiler::LegacyEvent>>&
         eventLists);
 } // namespace rpc
diff --git a/torch/distributed/rpc/server_process_global_profiler.py b/torch/distributed/rpc/server_process_global_profiler.py
index 0f4ba8d53817..f113ef609e2b 100644
--- a/torch/distributed/rpc/server_process_global_profiler.py
+++ b/torch/distributed/rpc/server_process_global_profiler.py
@@ -103,7 +103,7 @@ def __enter__(self):
         if not self.enabled:
             return
 
-        if self.entered:
+        if self.entered:  # type: ignore[has-type]
             raise RuntimeError("autograd profiler traces are not reentrant")
         self.entered = True
 
@@ -145,13 +145,13 @@ def __exit__(self, exc_type, exc_val, exc_tb):
         process_global_function_events = []
         for thread_local_events in process_global_events:
             # Parse from ``Event``s to ``FunctionEvent``s.
-            thread_local_function_events = torch.autograd.profiler.parse_event_records(
+            thread_local_function_events = torch.autograd.profiler.parse_legacy_records(
                 thread_local_events
             )
             thread_local_function_events.sort(
                 key=lambda function_event: [
-                    function_event.cpu_interval.start,
-                    -(function_event.cpu_interval.end),
+                    function_event.time_range.start,
+                    -(function_event.time_range.end),
                 ]
             )
             process_global_function_events.append(thread_local_function_events)
@@ -164,6 +164,7 @@ def __exit__(self, exc_type, exc_val, exc_tb):
             use_cuda=self.use_cuda,
             profile_memory=self.profile_memory,
         )
+        self.function_events._build_tree()
         self.process_global_function_events = process_global_function_events
 
diff --git a/torch/lib/c10d/test/ProcessGroupNCCLTest.cpp b/torch/lib/c10d/test/ProcessGroupNCCLTest.cpp
index fa5e988273fc..6c8bec2d0f92 100644
--- a/torch/lib/c10d/test/ProcessGroupNCCLTest.cpp
+++ b/torch/lib/c10d/test/ProcessGroupNCCLTest.cpp
@@ -177,9 +177,9 @@ class AllreduceNCCLTest : public NCCLTest {
     // Make sure enabling the profiler does not cause any issues. Note that in
     // single-process multi-device mode we do not expect any events to be
     // populated for collectives, since profiling for that mode is not supported.
-    enableProfiler({ProfilerState::CPU});
+    enableProfilerLegacy({ProfilerState::CPU});
     auto results = pg_->allreduce(tensors_);
-    disableProfiler();
+    disableProfilerLegacy();
     return results;
   }
 };
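The NCCL test above passes `{ProfilerState::CPU}` directly; this compiles because `ProfilerConfig`'s single-argument constructor is non-explicit, so the brace-init converts implicitly. Spelled out, a sketch under the same in-tree assumptions as the earlier examples:

```cpp
#include <torch/csrc/autograd/profiler.h>

using namespace torch::autograd::profiler;

void profileCollectiveOnce() {
  // Equivalent to enableProfilerLegacy({ProfilerState::CPU}).
  enableProfilerLegacy(ProfilerConfig(ProfilerState::CPU));
  // ... run the collective under test ...
  disableProfilerLegacy(); // default options: clean up TLS and consolidate
}
```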
diff --git a/torch/testing/_internal/distributed/rpc/rpc_test.py b/torch/testing/_internal/distributed/rpc/rpc_test.py
index 31bd8524f55f..5abe981a1734 100644
--- a/torch/testing/_internal/distributed/rpc/rpc_test.py
+++ b/torch/testing/_internal/distributed/rpc/rpc_test.py
@@ -1567,8 +1567,8 @@ def _profiler_test_with_rpc(self, rpc_exec_mode, func, args, use_record_function
         scope_event = get_function_event(events, "foo")
         # Since the RPC call is within the scope, its CPU interval should be
         # contained within foo's interval.
-        self.assertTrue(scope_event.cpu_interval.start < rpc_event.cpu_interval.start)
-        self.assertTrue(scope_event.cpu_interval.end > rpc_event.cpu_interval.end)
+        self.assertTrue(scope_event.time_range.start < rpc_event.time_range.start)
+        self.assertTrue(scope_event.time_range.end > rpc_event.time_range.end)
         # The sender, dest worker, function run, and type of RPC should all
         # be recorded.
         self_worker_name = worker_name(self.rank)
@@ -1760,10 +1760,10 @@ def _assert_top_level_events(self, process_global_events, expected_top_level_eve
             last_end_time = 0
             for event in thread_local_events:
                 event_name = event.name
-                cpu_interval = event.cpu_interval
-                if cpu_interval.start > last_end_time:
+                time_range = event.time_range
+                if time_range.start > last_end_time:
                     top_level_event_names.append(event_name)
-                    last_end_time = cpu_interval.end
+                    last_end_time = time_range.end
         self.assertEqual(sorted(top_level_event_names), sorted(expected_top_level_event_names))
 
     @dist_init