From a4d4124e6d87eb03b09bd304cd4af5c81e7db2f9 Mon Sep 17 00:00:00 2001 From: ilia-cher Date: Fri, 16 Oct 2020 07:56:54 -0700 Subject: [PATCH] Use libkineto in profiler Summary: Adding ability to use Kineto (CUPTI) to profile CUDA kernels Test Plan: python test/test_profiler.py [ghstack-poisoned] --- test/test_profiler.py | 20 ++++ torch/autograd/__init__.py | 6 +- torch/autograd/profiler.py | 65 ++++++++++--- torch/csrc/autograd/init.cpp | 14 ++- torch/csrc/autograd/profiler.cpp | 151 +++++++++++++++++++++++++++---- torch/csrc/autograd/profiler.h | 40 +++++++- 6 files changed, 256 insertions(+), 40 deletions(-) diff --git a/test/test_profiler.py b/test/test_profiler.py index f1feff1d0af3..44973546429e 100644 --- a/test/test_profiler.py +++ b/test/test_profiler.py @@ -99,6 +99,26 @@ def forward(self, x): torch._C._set_graph_executor_optimize(prev_opt) + @unittest.skipIf(not torch.autograd.kineto_available(), "Kineto is required") + @unittest.skipIf(not torch.cuda.is_available(), "CUDA is required") + def test_kineto(self): + x = torch.randn(10, 10).cuda() + y = torch.randn(10, 10).cuda() + with profile(use_cuda=True, use_kineto=True) as p: + z = torch.mm(x, y) + z = z + y + z = z.cpu() + print(p.key_averages().table( + sort_by="self_cuda_time_total", row_limit=-1)) + found_gemm = False + found_memcpy = False + for e in p.function_events: + if "gemm" in e.name: + found_gemm = True + if "Memcpy" in e.name or "memcpy" in e.name: + found_memcpy = True + self.assertTrue(found_gemm) + self.assertTrue(found_memcpy) if __name__ == '__main__': run_tests() diff --git a/torch/autograd/__init__.py b/torch/autograd/__init__.py index 4e44536d931c..cec103ea4c8c 100644 --- a/torch/autograd/__init__.py +++ b/torch/autograd/__init__.py @@ -242,6 +242,6 @@ def variable(*args, **kwargs): raise RuntimeError("autograd initialization failed") # Import all native method/classes -from torch._C._autograd import (ProfilerState, ProfilerConfig, ProfilerEvent, - _enable_profiler, _disable_profiler, _profiler_enabled, - _enable_record_function, _set_empty_test_observer) +from torch._C._autograd import (ProfilerActivity, ProfilerState, ProfilerConfig, ProfilerEvent, + _prepare_profiler, _enable_profiler, _disable_profiler, _profiler_enabled, + _enable_record_function, _set_empty_test_observer, kineto_available) diff --git a/torch/autograd/profiler.py b/torch/autograd/profiler.py index eba7368cb03e..c4d23f9efeb4 100644 --- a/torch/autograd/profiler.py +++ b/torch/autograd/profiler.py @@ -364,16 +364,47 @@ def __init__( use_cuda=False, record_shapes=False, profile_memory=False, - with_stack=False): + with_stack=False, + use_kineto=False): self.enabled = enabled - self.use_cuda = use_cuda - self.function_events = None if not self.enabled: return + self.use_cuda = use_cuda + self.function_events = None self.entered = False self.record_shapes = record_shapes self.profile_memory = profile_memory self.with_stack = with_stack + self.use_kineto = use_kineto + + self.profiler_kind = None + self.kineto_activities = [] + if self.use_kineto: + if self.use_cuda: + self.profiler_kind = torch.autograd.ProfilerState.KINETO + self.kineto_activities = [ + torch.autograd.ProfilerActivity.CPU, + # uses CUPTI + torch.autograd.ProfilerActivity.CUDA_RUNTIME, + torch.autograd.ProfilerActivity.CUDA] + else: + # intially we're not using Kineto for CPU only case + self.profiler_kind = torch.autograd.ProfilerState.CPU + elif self.use_cuda: + # legacy CUDA mode + self.profiler_kind = torch.autograd.ProfilerState.CUDA + else: + self.profiler_kind = torch.autograd.ProfilerState.CPU + self.kineto_activities = set(self.kineto_activities) + + if self.profiler_kind == torch.autograd.ProfilerState.KINETO: + assert torch.autograd.kineto_available() + + self.config = torch.autograd.ProfilerConfig( + self.profiler_kind, + self.record_shapes, + self.profile_memory, + self.with_stack) def __enter__(self): if not self.enabled: @@ -381,15 +412,8 @@ def __enter__(self): if self.entered: raise RuntimeError("autograd profiler traces are not reentrant") self.entered = True - profiler_kind = torch.autograd.ProfilerState.CUDA if self.use_cuda \ - else torch.autograd.ProfilerState.CPU - - config = torch.autograd.ProfilerConfig( - profiler_kind, - self.record_shapes, - self.profile_memory, - self.with_stack) - torch.autograd._enable_profiler(config) + torch.autograd._prepare_profiler(self.config, self.kineto_activities) + torch.autograd._enable_profiler(self.config) return self def __exit__(self, exc_type, exc_val, exc_tb): @@ -732,7 +756,7 @@ class FunctionEvent(FormattedTimesMixin): def __init__( self, id, node_id, name, thread, cpu_start, cpu_end, fwd_thread=None, input_shapes=None, stack=None, scope=0, cpu_memory_usage=0, cuda_memory_usage=0, is_async=False, - is_remote=True, sequence_nr=-1): + is_remote=True, sequence_nr=-1, device_id=-1): self.id: int = id self.node_id: int = node_id self.name: str = name @@ -751,6 +775,7 @@ def __init__( self.is_async: bool = is_async self.is_remote: bool = is_remote self.sequence_nr: int = sequence_nr + self.device_id: int = device_id def append_kernel(self, name, device, start, end): self.kernels.append(Kernel(name, device, Interval(start, end))) @@ -802,15 +827,21 @@ def self_cpu_time_total(self): @property def cuda_time_total(self): + if self.device_id >= 0: + return self.cpu_interval.elapsed_us() return sum(kinfo.interval.elapsed_us() for kinfo in self.kernels) @property def self_cuda_time_total(self): + if self.device_id >= 0: + return self.cuda_time_total - sum([child.cuda_time_total for child in self.cpu_children]) return sum(kinfo.interval.elapsed_us() for kinfo in self.kernels) - \ sum([child.cuda_time_total for child in self.cpu_children]) @property def cpu_time_total(self): + if self.device_id >= 0: + return 0 return self.cpu_interval.elapsed_us() @property @@ -1045,6 +1076,7 @@ def adjusted_time(cuda_record, cuda_records_map): is_async=is_async, is_remote=is_remote_event, sequence_nr=start.sequence_nr(), + device_id=start.device_id(), ) # note: async events have only cpu total time if not is_async and start.has_cuda(): @@ -1180,7 +1212,9 @@ def build_table( has_input_shapes = any( [(event.input_shapes is not None and len(event.input_shapes) > 0) for event in events]) + MAX_NAME_COLUMN_WIDTH = 55 name_column_width = max([len(evt.key) for evt in events]) + 4 + name_column_width = min(name_column_width, MAX_NAME_COLUMN_WIDTH) DEFAULT_COLUMN_WIDTH = 12 @@ -1288,8 +1322,11 @@ def append(s): continue else: event_limit += 1 + name = evt.key + if len(name) >= MAX_NAME_COLUMN_WIDTH-3: + name = name[:(MAX_NAME_COLUMN_WIDTH-3)] + "..." row_values = [ - evt.key, # Name + name, # Self CPU total, 0 for async events. % format_time_share(evt.self_cpu_time_total, self_cpu_time_total), diff --git a/torch/csrc/autograd/init.cpp b/torch/csrc/autograd/init.cpp index 045a732a2016..698931911878 100644 --- a/torch/csrc/autograd/init.cpp +++ b/torch/csrc/autograd/init.cpp @@ -39,7 +39,13 @@ PyObject* THPAutograd_initExtension(PyObject* _unused, PyObject *unused) { .value("Disabled", ProfilerState::Disabled) .value("CPU", ProfilerState::CPU) .value("CUDA", ProfilerState::CUDA) - .value("NVTX", ProfilerState::NVTX); + .value("NVTX", ProfilerState::NVTX) + .value("KINETO", ProfilerState::KINETO); + + py::enum_(m, "ProfilerActivity") + .value("CPU", ActivityType::CPU) + .value("CUDA_RUNTIME", ActivityType::CUDA_RUNTIME) + .value("CUDA", ActivityType::CUDA); py::class_(m, "ProfilerConfig") .def(py::init()); @@ -61,11 +67,15 @@ PyObject* THPAutograd_initExtension(PyObject* _unused, PyObject *unused) { .def("is_remote", &Event::isRemote) .def("sequence_nr", &Event::sequenceNr) .def("stack", &Event::stack) - .def("scope", &Event::scope); + .def("scope", &Event::scope) + .def("device_id", &Event::device); py::class_(m, "_ProfilerDisableOptions") .def(py::init()); + m.def("kineto_available", kinetoAvailable); + + m.def("_prepare_profiler", prepareProfiler); m.def("_enable_profiler", enableProfiler); m.def( "_disable_profiler", diff --git a/torch/csrc/autograd/profiler.cpp b/torch/csrc/autograd/profiler.cpp index 5cbb7606e579..0b6203e695fe 100644 --- a/torch/csrc/autograd/profiler.cpp +++ b/torch/csrc/autograd/profiler.cpp @@ -21,6 +21,10 @@ #include +#ifdef USE_KINETO +#include "libkineto.h" +#endif + namespace torch { namespace autograd { namespace profiler { namespace { @@ -48,23 +52,23 @@ enum ProfilerIValueIdx { NUM_PROFILER_CFG_IVALUE_IDX // must be last in list }; - const std::unordered_set disable_cuda_profiling = { - "aten::view", - "aten::t", - "aten::transpose", - "aten::stride", - "aten::empty", - "aten::empty_like", - "aten::empty_strided", - "aten::as_strided", - "aten::expand", - "aten::resize_", - "aten::squeeze", - "aten::unsqueeze", - "aten::slice", - "aten::_unsafe_view", - "aten::size" - }; +const std::unordered_set disable_cuda_profiling = { + "aten::view", + "aten::t", + "aten::transpose", + "aten::stride", + "aten::empty", + "aten::empty_like", + "aten::empty_strided", + "aten::as_strided", + "aten::expand", + "aten::resize_", + "aten::squeeze", + "aten::unsqueeze", + "aten::slice", + "aten::_unsafe_view", + "aten::size" +}; CUDAStubs default_stubs; constexpr CUDAStubs* default_stubs_addr = &default_stubs; @@ -169,6 +173,14 @@ struct FileLineFunc { std::string funcname; }; +static std::atomic corr_id_ {}; +size_t next_correlation_id() { + return corr_id_++; +} +size_t peek_correlation_id() { + return corr_id_; +} + // Profiler state struct ProfilerThreadLocalState : public c10::MemoryReportingInfoBase { explicit ProfilerThreadLocalState(const ProfilerConfig& config) @@ -193,6 +205,12 @@ struct ProfilerThreadLocalState : public c10::MemoryReportingInfoBase { std::make_move_iterator(remoteProfiledEvents_->begin()), std::make_move_iterator(remoteProfiledEvents_->end())); } + if (kinetoEvents_) { + result.insert( + result.end(), + std::make_move_iterator(kinetoEvents_->begin()), + std::make_move_iterator(kinetoEvents_->end())); + } return result; } @@ -224,6 +242,11 @@ struct ProfilerThreadLocalState : public c10::MemoryReportingInfoBase { } } + void setKinetoEvents(std::vector>&& kinetoEvents) { + std::lock_guard guard(state_mutex_); + kinetoEvents_ = std::move(kinetoEvents); + } + void pushRange( const at::RecordFunction& fn, const bool record_cuda, @@ -247,6 +270,7 @@ struct ProfilerThreadLocalState : public c10::MemoryReportingInfoBase { evt.setSequenceNr(fn.seqNr()); evt.setFwdThreadId(fn.forwardThreadId()); evt.setScope((uint8_t)fn.scope()); + evt.setCorrelationId(peek_correlation_id()); #ifndef C10_MOBILE // backward nodes source range corresponds to the forward node // TODO: consider using C++ stack trace @@ -409,6 +433,7 @@ struct ProfilerThreadLocalState : public c10::MemoryReportingInfoBase { ProfilerConfig config_ = ProfilerConfig(ProfilerState::Disabled); at::CallbackHandle handle_ = 0; c10::optional>> remoteProfiledEvents_; + c10::optional>> kinetoEvents_; }; ProfilerThreadLocalState* getProfilerTLSState() { @@ -451,6 +476,11 @@ void pushProfilingCallbacks() { } else { state_ptr->pushRange(fn, record_cuda, msg); } +#ifdef USE_KINETO + if (state_ptr->config().state == ProfilerState::KINETO) { + libkineto::api().pushCorrelationId(next_correlation_id()); + } +#endif }, [](const at::RecordFunction& fn) { auto state_ptr = getProfilerTLSState(); @@ -463,6 +493,11 @@ void pushProfilingCallbacks() { record_cuda = false; } state_ptr->popRange(fn, record_cuda); +#ifdef USE_KINETO + if (state_ptr->config().state == ProfilerState::KINETO) { + libkineto::api().popCorrelationId(); + } +#endif }) .needsInputs(state_ptr->config().report_input_shapes) .needsIds(true)); @@ -519,10 +554,48 @@ bool profilerEnabled() { return state_ptr && state_ptr->config().state != ProfilerState::Disabled; } +bool kinetoAvailable() { +#ifdef USE_KINETO + return true; +#else + return false; +#endif +} + +void prepareProfiler( + const ProfilerConfig& new_config, + const std::set& activities) { +#ifdef USE_KINETO + if (new_config.state == ProfilerState::KINETO) { + std::set k_activities; + if (activities.count(ActivityType::CPU)) { + k_activities.insert(libkineto::ActivityType::EXTERNAL_CORRELATION); + } + if (activities.count(ActivityType::CUDA_RUNTIME)) { + k_activities.insert(libkineto::ActivityType::CUDA_RUNTIME); + } + if (activities.count(ActivityType::CUDA)) { + k_activities.insert(libkineto::ActivityType::GPU_MEMCPY); + k_activities.insert(libkineto::ActivityType::GPU_MEMSET); + k_activities.insert(libkineto::ActivityType::CONCURRENT_KERNEL); + } + + if (!libkineto::api().hasProfilerRegistered()) { + libkineto::api().registerProfiler( + std::make_unique(false)); + } + libkineto::api().initProfilerIfRegistered(); + libkineto::api().prepareTrace(k_activities); + } +#endif +} + void enableProfiler(const ProfilerConfig& new_config) { TORCH_CHECK(new_config.state != ProfilerState::NVTX || cuda_stubs->enabled(), "Can't use NVTX profiler - PyTorch was compiled without CUDA"); + TORCH_CHECK(new_config.state != ProfilerState::KINETO || kinetoAvailable()); + auto state_ptr = getProfilerTLSState(); TORCH_CHECK(!state_ptr, "Profiler is already enabled on this thread"); auto state = std::make_shared(new_config); @@ -530,6 +603,12 @@ void enableProfiler(const ProfilerConfig& new_config) { pushProfilingCallbacks(); +#ifdef USE_KINETO + if (new_config.state == ProfilerState::KINETO) { + libkineto::api().startTrace(); + } +#endif + if (new_config.state == ProfilerState::CUDA) { // event recording appears to have some startup overhead, so we need to // to generate some dummy events first before recording synchronization events @@ -569,6 +648,44 @@ thread_event_lists disableProfiler(c10::optional profile at::removeCallback(state_ptr->callbackHandle()); } +#ifdef USE_KINETO + if (state_ptr->config().state == ProfilerState::KINETO) { + auto k_events = libkineto::api().stopTrace(); + std::unordered_map>> events; + for (auto& k_evt : k_events) { + auto& evt_list = events[k_evt.deviceId][k_evt.threadId]; + Event push_evt( + EventKind::PushRange, + at::StringView(k_evt.name), + k_evt.threadId, + false, + k_evt.correlationId); + push_evt.setDevice(k_evt.deviceId); + push_evt.setCpuUS(k_evt.startUs); + push_evt.setCorrelationId(k_evt.correlationId); + evt_list.emplace_back(std::move(push_evt)); + + Event pop_evt( + EventKind::PopRange, + at::StringView(k_evt.name), + k_evt.threadId, + false, + k_evt.correlationId); + pop_evt.setDevice(k_evt.deviceId); + pop_evt.setCpuUS(k_evt.endUs); + pop_evt.setCorrelationId(k_evt.correlationId); + evt_list.emplace_back(std::move(pop_evt)); + } + std::vector> events_list; + for (const auto& it : events) { + for (const auto& it2 : it.second) { + events_list.emplace_back(it2.second); + } + } + state_ptr->setKinetoEvents(std::move(events_list)); + } +#endif + if (!consolidate || state_ptr->config().state == ProfilerState::NVTX) { return thread_event_lists(); } diff --git a/torch/csrc/autograd/profiler.h b/torch/csrc/autograd/profiler.h index 9cfe9ea1fd6e..3bc6022b20fa 100644 --- a/torch/csrc/autograd/profiler.h +++ b/torch/csrc/autograd/profiler.h @@ -104,10 +104,19 @@ struct TORCH_API ProfilerDisableOptions { }; enum class C10_API_ENUM ProfilerState { - Disabled, - CPU, // CPU-only profiling - CUDA, // CPU + CUDA events - NVTX, // only emit NVTX markers + Disabled = 0, + CPU, // CPU-only profiling + CUDA, // CPU + CUDA events + NVTX, // only emit NVTX markers + KINETO, // use libkineto + NUM_PROFILER_STATES, // must be the last one +}; + +enum class C10_API_ENUM ActivityType { + CPU = 0, + CUDA_RUNTIME, // CUDA host events + CUDA, // CUDA kernels + NUM_KINETO_ACTIVITIES, // must be the last one }; struct TORCH_API ProfilerConfig { @@ -238,6 +247,10 @@ struct TORCH_API Event final { return cpu_ns_ / (1000.0); } + void setCpuUS(double cpu_us) { + cpu_ns_ = (int64_t)(cpu_us * 1000); + } + double cudaElapsedUs(const Event& e) const; bool hasCuda() const { @@ -248,6 +261,10 @@ struct TORCH_API Event final { return device_; } + void setDevice(int device) { + device_ = device; + } + void updateMemoryStats(int64_t alloc_size, c10::Device device) { if (device.type() == c10::DeviceType::CUDA || device.type() == c10::DeviceType::HIP) { @@ -303,6 +320,14 @@ struct TORCH_API Event final { return sequence_nr_; } + void setCorrelationId(uint64_t correlation_id) { + correlation_id_ = correlation_id; + } + + uint64_t correlationId() const { + return correlation_id_; + } + const std::vector& stack() const { return stack_; } @@ -347,6 +372,8 @@ struct TORCH_API Event final { std::vector stack_; uint8_t scope_; + + uint64_t correlation_id_; }; // a linked-list of fixed sized vectors, to avoid @@ -403,6 +430,11 @@ TORCH_API ProfilerConfig getProfilerConfig(); // Writes profiled events to a stream. TORCH_API void writeProfilerEventsToStream(std::ostream& out, const std::vector& events); +TORCH_API bool kinetoAvailable(); +TORCH_API void prepareProfiler( + const ProfilerConfig& new_config, + const std::set& activities); + // Usage: // { // RecordProfile guard("filename.trace");