diff --git a/docs/source/index.rst b/docs/source/index.rst
index 3cbe8fc07178..a334bffab01e 100644
--- a/docs/source/index.rst
+++ b/docs/source/index.rst
@@ -66,6 +66,7 @@ Features described in this documentation are classified by release status:
    torch.jit
    torch.linalg
    torch.overrides
+   profiler
    nn.init
    onnx
    optim
diff --git a/docs/source/profiler.rst b/docs/source/profiler.rst
new file mode 100644
index 000000000000..6541f08c4feb
--- /dev/null
+++ b/docs/source/profiler.rst
@@ -0,0 +1,17 @@
+.. currentmodule:: torch.profiler
+
+torch.profiler
+==============
+
+Overview
+--------
+.. automodule:: torch.profiler
+
+
+API Reference
+-------------
+
+.. autoclass:: torch.profiler.profile
+  :members:
+
+.. autofunction:: torch.profiler.schedule
diff --git a/test/test_profiler.py b/test/test_profiler.py
index c6598a1da9a8..8d3d24dde805 100644
--- a/test/test_profiler.py
+++ b/test/test_profiler.py
@@ -240,5 +240,47 @@ def test_flops(self):
         print(profiler_output)
         self.assertIn("FLOPS", profiler_output)
 
+    @unittest.skipIf(not kineto_available(), "Kineto is required")
+    @unittest.skipIf(not torch.cuda.is_available(), "CUDA is required")
+    def test_kineto_profiler_api(self):
+        called_num = [0]
+
+        with profile(use_cuda=True, use_kineto=True):
+            self.payload()
+
+        def trace_handler(p):
+            print(p.key_averages().table(
+                sort_by="self_cuda_time_total", row_limit=-1))
+            # p.export_chrome_trace("/tmp/test_trace_" + str(called_num[0]) + ".json")
+            called_num[0] += 1
+
+        with torch.profiler.profile(
+            activities=[
+                torch.profiler.ProfilerActivity.CPU,
+                torch.profiler.ProfilerActivity.CUDA],
+            schedule=torch.profiler.schedule(
+                wait=1,
+                warmup=1,
+                active=2),
+            on_trace_ready=trace_handler
+        ) as p:
+            for idx in range(8):
+                self.payload()
+                p.next_step()
+
+        self.assertEqual(called_num[0], 2)
+
+        # case without schedule (the default schedule is used)
+        with torch.profiler.profile(
+            activities=[
+                torch.profiler.ProfilerActivity.CPU,
+                torch.profiler.ProfilerActivity.CUDA]
+        ) as p:
+            self.payload()
+            self.payload()
+        print(p.key_averages().table(
+            sort_by="self_cuda_time_total", row_limit=-1))
+
+
 if __name__ == '__main__':
     run_tests()
diff --git a/torch/__init__.py b/torch/__init__.py
index 30c328c1da6f..04955623ab2a 100644
--- a/torch/__init__.py
+++ b/torch/__init__.py
@@ -596,6 +596,7 @@ def _assert(condition, message):
 import torch.utils.data
 import torch.__config__
 import torch.__future__
+import torch.profiler
 
 _C._init_names(list(torch._storage_classes))
diff --git a/torch/autograd/profiler.py b/torch/autograd/profiler.py
index 1cdf408aa70d..5aef3f95aa0a 100644
--- a/torch/autograd/profiler.py
+++ b/torch/autograd/profiler.py
@@ -464,6 +464,15 @@ def __enter__(self):
         torch.autograd._enable_profiler_legacy(self.config())
         return self
 
+    def _prepare_kineto_trace(self):
+        assert self.kineto_activities
+        self.entered = True
+        torch.autograd._prepare_profiler(self.config(), self.kineto_activities)
+
+    def _start_kineto_trace(self):
+        assert self.kineto_activities
+        torch.autograd._enable_profiler(self.config(), self.kineto_activities)
+
     def __exit__(self, exc_type, exc_val, exc_tb):
         if not self.enabled:
             return
diff --git a/torch/profiler/__init__.py b/torch/profiler/__init__.py
new file mode 100644
index 000000000000..dabbf91dff90
--- /dev/null
+++ b/torch/profiler/__init__.py
@@ -0,0 +1,12 @@
+# type: ignore
+r'''
+PyTorch Profiler is a tool that allows the collection of performance metrics during training and inference.
+Profiler's context manager API can be used to better understand which model operators are the most expensive,
+examine their input shapes and stack traces, study device kernel activity and visualize the execution trace.
+
+.. note::
+    An earlier version of the API in the ``torch.autograd`` module is considered legacy and will be deprecated.
+
+'''
+
+from .profiler import profile, schedule, ProfilerAction, ProfilerActivity
diff --git a/torch/profiler/profiler.py b/torch/profiler/profiler.py
new file mode 100644
index 000000000000..652a76262df2
--- /dev/null
+++ b/torch/profiler/profiler.py
@@ -0,0 +1,293 @@
+import torch.autograd.profiler as prof
+from torch.autograd import ProfilerActivity
+
+from enum import Enum
+from typing import Any, Callable, Iterable, Optional
+from warnings import warn
+
+
+class ProfilerAction(Enum):
+    NONE = 0
+    WARMUP = 1
+    RECORD = 2
+    RECORD_AND_SAVE = 3
+
+
+def schedule(*, wait: int, warmup: int, active: int):
+    """
+    Represents profiler behavior: wait for ``wait`` steps, then
+    do the warmup for the next ``warmup`` steps, then
+    do the active recording for the next ``active`` steps and then
+    repeat the cycle starting with the next step.
+    """
+    def schedule_fn(step: int) -> ProfilerAction:
+        assert step >= 0
+        num_steps = wait + warmup + active
+        mod_step = step % num_steps
+        if mod_step < wait:
+            return ProfilerAction.NONE
+        elif mod_step < wait + warmup:
+            return ProfilerAction.WARMUP
+        else:
+            return ProfilerAction.RECORD if mod_step < num_steps - 1 \
+                else ProfilerAction.RECORD_AND_SAVE
+    assert wait >= 0 and warmup >= 0 and active > 0, \
+        "Invalid profiler schedule arguments"
+    if warmup == 0:
+        warn("Profiler won't be using warmup, this can skew profiler results")
+    return schedule_fn
+
+
+def _default_schedule_fn(_: int) -> ProfilerAction:
+    """
+    Default profiler behavior - immediately starts recording the events and
+    keeps recording them on every profiler step.
+    """
+    return ProfilerAction.RECORD
+
+
+class profile(object):
+    """
+    Profiler context manager.
+
+    Arguments:
+
+    - ``activities`` - list of activity groups (CPU, CUDA) to use in profiling;
+    - ``schedule`` - callable that takes step (int) as a single parameter and returns
+      the ``ProfilerAction`` value that specifies the profiler action on each step;
+    - ``on_trace_ready`` (optional) - callable, called each time a new trace is ready
+      during profiling;
+    - ``record_shapes`` - save information about operators' input shapes;
+    - ``profile_memory`` - track tensor memory allocation/deallocation;
+    - ``with_stack`` - save stack traces;
+    - ``use_gpu`` - (deprecated, use ``activities``).
+
+    .. note::
+        Use ``torch.profiler.schedule`` to generate the callable schedule.
+        Non-default schedules are useful when profiling long training jobs
+        and allow the user to obtain multiple traces at different iterations
+        of the training process.
+        The default schedule simply records all the events continuously for the
+        duration of the context manager.
+
+    .. note::
+        Enabling shape and stack tracing results in additional overhead.
+
+    Examples:
+
+    .. code-block:: python
+
+        with torch.profiler.profile(
+            activities=[
+                torch.profiler.ProfilerActivity.CPU,
+                torch.profiler.ProfilerActivity.CUDA]
+        ) as p:
+            code_to_profile()
+        print(p.key_averages().table(
+            sort_by="self_cuda_time_total", row_limit=-1))
+
+    Using the profiler's ``schedule``, ``on_trace_ready`` and ``next_step`` functions:
+
+    .. code-block:: python
+
+        # Non-default profiler schedule allows user to turn profiler on and off
+        # on different iterations of the training loop;
+        # trace_handler is called every time a new trace becomes available
+        def trace_handler(prof):
+            print(prof.key_averages().table(
+                sort_by="self_cuda_time_total", row_limit=-1))
+            # prof.export_chrome_trace("/tmp/test_trace_" + str(prof.step()) + ".json")
+
+        with torch.profiler.profile(
+            activities=[
+                torch.profiler.ProfilerActivity.CPU,
+                torch.profiler.ProfilerActivity.CUDA],
+
+            # In this example with wait=1, warmup=1, active=2,
+            # profiler will skip the first step/iteration,
+            # start warming up on the second, record
+            # the third and the fourth iterations,
+            # after which the trace will become available
+            # and on_trace_ready (when set) is called;
+            # the cycle repeats starting with the next step
+
+            schedule=torch.profiler.schedule(
+                wait=1,
+                warmup=1,
+                active=2),
+            on_trace_ready=trace_handler
+        ) as p:
+            for iter in range(N):
+                code_iteration_to_profile(iter)
+                # send a signal to the profiler that the next iteration has started
+                p.next_step()
+    """
+    def __init__(
+            self,
+            *,
+            activities: Optional[Iterable[ProfilerActivity]] = None,
+            schedule: Optional[Callable[[int], ProfilerAction]] = None,
+            on_trace_ready: Optional[Callable[..., Any]] = None,
+            record_shapes: bool = False,
+            profile_memory: bool = False,
+            with_stack: bool = False,
+            # deprecated:
+            use_gpu: Optional[bool] = None):
+        if activities:
+            self.activities = activities
+        else:
+            if use_gpu is not None:
+                warn("use_gpu is deprecated, use activities argument instead")
+                self.activities = set([ProfilerActivity.CPU])
+                if use_gpu:
+                    self.activities.add(ProfilerActivity.CUDA)
+            else:
+                raise RuntimeError("Profiler activities are not specified")
+
+        if schedule:
+            self.schedule = schedule
+            # add step markers into the trace and table view
+            self.record_steps = True
+        else:
+            self.schedule = _default_schedule_fn
+            self.record_steps = False
+        self.on_trace_ready = on_trace_ready
+        self.record_shapes = record_shapes
+        self.profile_memory = profile_memory
+        self.with_stack = with_stack
+        self.step_num = 0
+        self.current_action = self.schedule(self.step_num)
+        self.profiler: Optional[prof.profile] = None
+        self.step_rec_fn: Optional[prof.record_function] = None
+
+    def __enter__(self):
+        self._enter_actions()
+        if self.record_steps:
+            self.step_rec_fn = prof.record_function("ProfilerStep#" + str(self.step_num))
+            self.step_rec_fn.__enter__()
+        return self
+
+    def __exit__(self, exc_type, exc_val, exc_tb):
+        if self.record_steps and self.step_rec_fn:
+            self.step_rec_fn.__exit__(None, None, None)
+        self._exit_actions()
+
+    def next_step(self):
+        """
+        Signals the profiler that the next profiling step has started.
+        """
+        if self.record_steps and self.step_rec_fn:
+            self.step_rec_fn.__exit__(None, None, None)
+        prev_action = self.current_action
+        self.step_num += 1
+        self.current_action = self.schedule(self.step_num)
+
+        if self.current_action == ProfilerAction.NONE:
+            if prev_action == ProfilerAction.NONE:
+                pass
+            elif prev_action == ProfilerAction.WARMUP:
+                warn("Incorrect schedule: WARMUP followed by NONE")
+                self._start_trace()
+                self._stop_trace()
+            elif prev_action == ProfilerAction.RECORD:
+                warn("Incorrect schedule: RECORD followed by NONE")
+                self._stop_trace()
+            else:
+                assert prev_action == ProfilerAction.RECORD_AND_SAVE
+                self._stop_trace()
+                if self.on_trace_ready:
+                    self.on_trace_ready(self)
+        elif self.current_action == ProfilerAction.WARMUP:
+            if prev_action == ProfilerAction.NONE:
+                self._start_warmup()
+            elif prev_action == ProfilerAction.WARMUP:
+                pass
+            elif prev_action == ProfilerAction.RECORD:
+                warn("Incorrect schedule: RECORD followed by WARMUP")
+                self._stop_trace()
+            else:
+                assert prev_action == ProfilerAction.RECORD_AND_SAVE
+                self._stop_trace()
+                if self.on_trace_ready:
+                    self.on_trace_ready(self)
+                self._start_warmup()
+        elif self.current_action in \
+                [ProfilerAction.RECORD, ProfilerAction.RECORD_AND_SAVE]:
+            if prev_action == ProfilerAction.NONE:
+                self._start_warmup()
+                self._start_trace()
+            elif prev_action == ProfilerAction.WARMUP:
+                self._start_trace()
+            elif prev_action == ProfilerAction.RECORD:
+                pass
+            else:
+                assert prev_action == ProfilerAction.RECORD_AND_SAVE
+                self._stop_trace()
+                if self.on_trace_ready:
+                    self.on_trace_ready(self)
+                self._start_warmup()
+                self._start_trace()
+
+        if self.record_steps:
+            self.step_rec_fn = prof.record_function("ProfilerStep#" + str(self.step_num))
+            self.step_rec_fn.__enter__()
+
+    def step(self):
+        """
+        Returns the current profiling step.
+        """
+        return self.step_num
+
+    def export_chrome_trace(self, path: str):
+        """
+        Exports the collected trace in Chrome JSON format.
+        """
+        assert self.profiler
+        return self.profiler.export_chrome_trace(path)
+
+    def key_averages(self, group_by_input_shape: bool = False, group_by_stack_n: int = 0):
+        """
+        Averages events, grouping them by operator name and (optionally) input shapes and
+        stack.
+        Note: to use the shape/stack functionality make sure to set record_shapes/with_stack
+        when creating the profiler context manager.
+        """
+        assert self.profiler
+        return self.profiler.key_averages(group_by_input_shape, group_by_stack_n)
+
+    def _enter_actions(self):
+        if self.current_action == ProfilerAction.WARMUP:
+            self._start_warmup()
+        elif self.current_action in \
+                [ProfilerAction.RECORD, ProfilerAction.RECORD_AND_SAVE]:
+            self._start_warmup()
+            self._start_trace()
+
+    def _exit_actions(self):
+        if self.current_action == ProfilerAction.WARMUP:
+            self._start_trace()
+            self._stop_trace()
+        elif self.current_action in \
+                [ProfilerAction.RECORD, ProfilerAction.RECORD_AND_SAVE]:
+            self._stop_trace()
+            if self.on_trace_ready:
+                self.on_trace_ready(self)
+
+    def _start_warmup(self):
+        self.profiler = prof.profile(
+            use_cuda=(ProfilerActivity.CUDA in self.activities),
+            use_cpu=(ProfilerActivity.CPU in self.activities),
+            record_shapes=self.record_shapes,
+            profile_memory=self.profile_memory,
+            with_stack=self.with_stack,
+            use_kineto=True,
+        )
+        self.profiler._prepare_kineto_trace()
+
+    def _start_trace(self):
+        assert self.profiler is not None
+        self.profiler._start_kineto_trace()
+
+    def _stop_trace(self):
+        assert self.profiler is not None
+        self.profiler.__exit__(None, None, None)
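
Note (not part of the diff above): a minimal, self-contained sketch that replicates the ``schedule`` cadence from torch/profiler/profiler.py, showing which ``ProfilerAction`` each step maps to for wait=1, warmup=1, active=2 and why ``test_kineto_profiler_api`` expects ``trace_handler`` to be called twice over 8 iterations. The standalone helper below is a replica for illustration only, not part of the new API.

from enum import Enum

class ProfilerAction(Enum):
    NONE = 0
    WARMUP = 1
    RECORD = 2
    RECORD_AND_SAVE = 3

def schedule_fn(step, wait=1, warmup=1, active=2):
    # same cadence computation as the schedule() factory in the diff
    num_steps = wait + warmup + active
    mod_step = step % num_steps
    if mod_step < wait:
        return ProfilerAction.NONE
    elif mod_step < wait + warmup:
        return ProfilerAction.WARMUP
    return (ProfilerAction.RECORD if mod_step < num_steps - 1
            else ProfilerAction.RECORD_AND_SAVE)

for step in range(8):
    print(step, schedule_fn(step).name)

# Prints:
#   0 NONE, 1 WARMUP, 2 RECORD, 3 RECORD_AND_SAVE,
#   4 NONE, 5 WARMUP, 6 RECORD, 7 RECORD_AND_SAVE
# A trace completes after each RECORD_AND_SAVE step (steps 3 and 7), and the
# following next_step() call invokes on_trace_ready, which is why the test
# asserts called_num[0] == 2 after 8 iterations.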