From 9f928c84f72161c02d54a70a2f7bed29f980dd05 Mon Sep 17 00:00:00 2001 From: Ishan Kumar Date: Tue, 9 Nov 2021 20:08:05 +0530 Subject: [PATCH 01/10] added PyTorch Profiler --- ignite/handlers/pytorch_profiler.py | 145 ++++++++++++++++++++++++++++ 1 file changed, 145 insertions(+) create mode 100644 ignite/handlers/pytorch_profiler.py diff --git a/ignite/handlers/pytorch_profiler.py b/ignite/handlers/pytorch_profiler.py new file mode 100644 index 00000000000..dacb2ae3316 --- /dev/null +++ b/ignite/handlers/pytorch_profiler.py @@ -0,0 +1,145 @@ +# coding: utf-8 +from typing import Any, Callable, Union +import os +from ignite.engine import Engine, Events +import ignite.distributed as idist +import datetime + +import torch + + +class PyTorchProfiler: + """PyTorch Profiler for performance debugging. + + The PyTorch profiler is a tool that collects both GPU hardware and PyTorch related + information, correlates them, performs automatic detection of bottlenecks in the model, + and generates recommendations on how to resolve these bottlenecks. + + Examples: + .. code-block:: python + + from ignite.handlers import PyTorchProfiler + + trainer = ... + model = ... + optimizer = ... + + pt_profiler = PyTorchProfiler(on_trace_ready="tensorboard", output_path="logs/train") + pt_profiler.attach(trainer) + + # Get profiler results of time + pt_profiler.print_results() + + # Save profiler result to CSV file (requires pandas) + pt_profiler.write_results() + + Both these methods can also be used as the on_trace_ready function which gets called after trace is ready. + + pt_profiler = PyTorchProfiler(on_trace_ready=profiler.write_to_file(10), output_path="logs/train") + + .. versionadded:: 0.4.8 + """ + + def __init__( + self, + cuda_activity: bool = False, + on_trace_ready: Union[Callable[..., Any], str] = "tensorboard", + record_shapes: bool = False, + profile_memory: bool = False, + with_stack: bool = False, + with_flops: bool = False, + with_modules: bool = False, + output_path: str = None, + wait: int = 2, + warmup: int = 2, + active: int = 6, + repeat: int = 1, + ) -> None: + + self.activities = [torch.profiler.ProfilerActivity.CPU] + if cuda_activity and torch.cuda.is_available(): + self.activities.append(torch.profiler.ProfilerActivity.GPU) + + self.output_path = output_path + + self.schedule = torch.profiler.schedule(wait=wait, warmup=warmup, active=active, repeat=repeat) + + self.trace_handler = ( + torch.profiler.tensorboard_trace_handler(self.output_path) + if on_trace_ready == "tensorboard" + else on_trace_ready + ) + + self.record_shapes = record_shapes + self.profile_memory = profile_memory + self.with_stack = with_stack + self.with_flops = with_flops + self.with_modules = with_modules + + self.SORT_KEYS = { + "cpu_time", + "cuda_time", + "cpu_time_total", + "cuda_time_total", + "cpu_memory_usage", + "cuda_memory_usage", + "self_cpu_memory_usage", + "self_cuda_memory_usage", + "count", + } + + def _profiler_create(self): + self._profiler = torch.profiler.profile( + activities=self.activities, + schedule=self.schedule, + on_trace_ready=self.trace_handler, + record_shapes=self.record_shapes, + profile_memory=self.profile_memory, + with_stack=self.with_stack, + with_flops=self.with_flops, + with_modules=self.with_modules, + ) + self._profiler.__enter__() + + def _exit_profiler(self): + self._profiler.__exit__() + + def _profiler_step(self): + self.profiler.step() + + def attach(self, engine: Engine,) -> None: + """Attach the profiler to the engine. + + Args: + engine: engine object. + """ + + engine._event_handlers[Events.EPOCH_STARTED].append((self._profiler_create, {}, {})) + engine._event_handlers[Events.GET_BATCH_COMPLETED].append((self._profiler_step, {}, {})) + engine._event_handlers[Events.EPOCH_COMPLETED].append((self._profile_create.__exit__(), {}, {})) + + def get_results(self, n: int = -1, sort_key: str = "self_cuda_memory_usage", top_level_events_only=False): + if sort_key not in self.SORT_KEYS: + raise ValueError( + f" The sort_key {sort_key} is not accepted. Please choose a sort key from {self.SORT_KEYS}" + ) + + return self.profiler.key_averages().table( + sort_by=sort_key, row_limit=n, top_level_events_only=top_level_events_only + ) + + def write_results(self, n: int = -1, sort_key: str = "self_cuda_memory_usage", top_level_events_only=False): + try: + import pandas as pd + except ImportError: + raise RuntimeError("Need pandas to write results as files") + + results_df = pd.DataFrame(self.get_results(n, sort_key, top_level_events_only)) + + now = datetime.now().strftime("%Y%m%d-%H%M%S") + file_name = f"{idist.backend()}_{now}.csv" + + results_df.to_csv(os.path.join(self.output_path, file_name), index=False) + + def print_results(self, n: int = -1, sort_key: str = "self_cuda_memory_usage", top_level_events_only=False): + print(self.get_results(n, sort_key, top_level_events_only)) From 333057e14f22528af96e64181ed741e804317d5d Mon Sep 17 00:00:00 2001 From: Ishan-Kumar2 Date: Tue, 9 Nov 2021 14:39:17 +0000 Subject: [PATCH 02/10] autopep8 fix --- ignite/handlers/pytorch_profiler.py | 9 +++++---- 1 file changed, 5 insertions(+), 4 deletions(-) diff --git a/ignite/handlers/pytorch_profiler.py b/ignite/handlers/pytorch_profiler.py index dacb2ae3316..7417f839005 100644 --- a/ignite/handlers/pytorch_profiler.py +++ b/ignite/handlers/pytorch_profiler.py @@ -1,12 +1,13 @@ # coding: utf-8 -from typing import Any, Callable, Union -import os -from ignite.engine import Engine, Events -import ignite.distributed as idist import datetime +import os +from typing import Any, Callable, Union import torch +import ignite.distributed as idist +from ignite.engine import Engine, Events + class PyTorchProfiler: """PyTorch Profiler for performance debugging. From 58c2d1855709ebab262bd45aeed6758aa26dd402 Mon Sep 17 00:00:00 2001 From: Ishan Kumar Date: Mon, 27 Dec 2021 01:43:07 +0530 Subject: [PATCH 03/10] updated attach --- ignite/handlers/pytorch_profiler.py | 14 ++++++-------- 1 file changed, 6 insertions(+), 8 deletions(-) diff --git a/ignite/handlers/pytorch_profiler.py b/ignite/handlers/pytorch_profiler.py index 7417f839005..4333dda5001 100644 --- a/ignite/handlers/pytorch_profiler.py +++ b/ignite/handlers/pytorch_profiler.py @@ -98,15 +98,14 @@ def _profiler_create(self): profile_memory=self.profile_memory, with_stack=self.with_stack, with_flops=self.with_flops, - with_modules=self.with_modules, ) self._profiler.__enter__() def _exit_profiler(self): - self._profiler.__exit__() + self._profiler.__exit__(0, 0, 0) def _profiler_step(self): - self.profiler.step() + self._profiler.step() def attach(self, engine: Engine,) -> None: """Attach the profiler to the engine. @@ -114,10 +113,9 @@ def attach(self, engine: Engine,) -> None: Args: engine: engine object. """ - - engine._event_handlers[Events.EPOCH_STARTED].append((self._profiler_create, {}, {})) - engine._event_handlers[Events.GET_BATCH_COMPLETED].append((self._profiler_step, {}, {})) - engine._event_handlers[Events.EPOCH_COMPLETED].append((self._profile_create.__exit__(), {}, {})) + engine.add_event_handler(Events.EPOCH_STARTED, self._profiler_create) + engine.add_event_handler(Events.GET_BATCH_COMPLETED, self._profiler_step) + engine.add_event_handler(Events.EPOCH_COMPLETED, self._exit_profiler) def get_results(self, n: int = -1, sort_key: str = "self_cuda_memory_usage", top_level_events_only=False): if sort_key not in self.SORT_KEYS: @@ -143,4 +141,4 @@ def write_results(self, n: int = -1, sort_key: str = "self_cuda_memory_usage", t results_df.to_csv(os.path.join(self.output_path, file_name), index=False) def print_results(self, n: int = -1, sort_key: str = "self_cuda_memory_usage", top_level_events_only=False): - print(self.get_results(n, sort_key, top_level_events_only)) + print(self.get_results(n, sort_key, top_level_events_only)) \ No newline at end of file From dafb4d5b9be850068cbef3eeffe2c77f79dc5b7d Mon Sep 17 00:00:00 2001 From: Ishan Kumar Date: Mon, 27 Dec 2021 01:44:35 +0530 Subject: [PATCH 04/10] formatting fix --- ignite/handlers/pytorch_profiler.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/ignite/handlers/pytorch_profiler.py b/ignite/handlers/pytorch_profiler.py index 4333dda5001..18ee91e6bbf 100644 --- a/ignite/handlers/pytorch_profiler.py +++ b/ignite/handlers/pytorch_profiler.py @@ -141,4 +141,4 @@ def write_results(self, n: int = -1, sort_key: str = "self_cuda_memory_usage", t results_df.to_csv(os.path.join(self.output_path, file_name), index=False) def print_results(self, n: int = -1, sort_key: str = "self_cuda_memory_usage", top_level_events_only=False): - print(self.get_results(n, sort_key, top_level_events_only)) \ No newline at end of file + print(self.get_results(n, sort_key, top_level_events_only)) From f99a2ba1d66769013dd9f432caf53bfbf6984a53 Mon Sep 17 00:00:00 2001 From: Ishan Kumar Date: Wed, 29 Dec 2021 18:29:17 +0530 Subject: [PATCH 05/10] added import in __init__.py --- ignite/handlers/__init__.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/ignite/handlers/__init__.py b/ignite/handlers/__init__.py index d14d6c3963a..69a0bd48a5f 100644 --- a/ignite/handlers/__init__.py +++ b/ignite/handlers/__init__.py @@ -18,6 +18,7 @@ PiecewiseLinear, create_lr_scheduler_with_warmup, ) +from ignite.handlers.pytorch_profiler import PyTorchProfiler from ignite.handlers.state_param_scheduler import ( ExpStateScheduler, LambdaStateScheduler, @@ -62,6 +63,7 @@ "ExpStateScheduler", "StepStateScheduler", "MultiStepStateScheduler", + "PyTorchProfiler", ] From e738d691008b1316ce9ce2d205d5e6dfdfc8dffa Mon Sep 17 00:00:00 2001 From: Ishan-Kumar2 Date: Wed, 29 Dec 2021 13:21:15 +0000 Subject: [PATCH 06/10] autopep8 fix --- ignite/handlers/pytorch_profiler.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/ignite/handlers/pytorch_profiler.py b/ignite/handlers/pytorch_profiler.py index 18ee91e6bbf..21998cf861f 100644 --- a/ignite/handlers/pytorch_profiler.py +++ b/ignite/handlers/pytorch_profiler.py @@ -107,7 +107,10 @@ def _exit_profiler(self): def _profiler_step(self): self._profiler.step() - def attach(self, engine: Engine,) -> None: + def attach( + self, + engine: Engine, + ) -> None: """Attach the profiler to the engine. Args: From 74647fbe79b96a37af9b669a80349f4ebcd7b51b Mon Sep 17 00:00:00 2001 From: Ishan Kumar Date: Mon, 3 Jan 2022 11:43:24 +0530 Subject: [PATCH 07/10] add tests and modified write_results to store as txt --- ignite/handlers/pytorch_profiler.py | 21 +++----- .../ignite/handlers/test_pytorch_profiler.py | 51 +++++++++++++++++++ 2 files changed, 57 insertions(+), 15 deletions(-) create mode 100644 tests/ignite/handlers/test_pytorch_profiler.py diff --git a/ignite/handlers/pytorch_profiler.py b/ignite/handlers/pytorch_profiler.py index 21998cf861f..4f5003dd4b5 100644 --- a/ignite/handlers/pytorch_profiler.py +++ b/ignite/handlers/pytorch_profiler.py @@ -1,5 +1,5 @@ # coding: utf-8 -import datetime +from datetime import datetime import os from typing import Any, Callable, Union @@ -107,10 +107,7 @@ def _exit_profiler(self): def _profiler_step(self): self._profiler.step() - def attach( - self, - engine: Engine, - ) -> None: + def attach(self, engine: Engine,) -> None: """Attach the profiler to the engine. Args: @@ -126,22 +123,16 @@ def get_results(self, n: int = -1, sort_key: str = "self_cuda_memory_usage", top f" The sort_key {sort_key} is not accepted. Please choose a sort key from {self.SORT_KEYS}" ) - return self.profiler.key_averages().table( + return self._profiler.key_averages().table( sort_by=sort_key, row_limit=n, top_level_events_only=top_level_events_only ) def write_results(self, n: int = -1, sort_key: str = "self_cuda_memory_usage", top_level_events_only=False): - try: - import pandas as pd - except ImportError: - raise RuntimeError("Need pandas to write results as files") - - results_df = pd.DataFrame(self.get_results(n, sort_key, top_level_events_only)) - now = datetime.now().strftime("%Y%m%d-%H%M%S") - file_name = f"{idist.backend()}_{now}.csv" + file_name = f"{idist.backend()}_{now}.txt" - results_df.to_csv(os.path.join(self.output_path, file_name), index=False) + with open(os.path.join(self.output_path, file_name), "w") as f: + f.write(self.get_results(n, sort_key, top_level_events_only)) def print_results(self, n: int = -1, sort_key: str = "self_cuda_memory_usage", top_level_events_only=False): print(self.get_results(n, sort_key, top_level_events_only)) diff --git a/tests/ignite/handlers/test_pytorch_profiler.py b/tests/ignite/handlers/test_pytorch_profiler.py new file mode 100644 index 00000000000..491123076fd --- /dev/null +++ b/tests/ignite/handlers/test_pytorch_profiler.py @@ -0,0 +1,51 @@ +import glob +import os + +import pytest +import torch + +import ignite.distributed as idist +from ignite.engine import Engine +from ignite.handlers import PyTorchProfiler + + +def update_fn(engine, batch): + a = torch.empty((2, 3), dtype=torch.int32) + b = torch.empty((3, 3), dtype=torch.int32) + + return a + torch.mm(a, b) + + +def get_engine(): + dummy_trainer = Engine(update_fn) + return dummy_trainer + + +def test_get_results(tmp_path): + trainer = get_engine() + pt_profiler = PyTorchProfiler(on_trace_ready="tensorboard", output_path=tmp_path) + pt_profiler.attach(trainer) + trainer.run(range(10), max_epochs=1) + + with pytest.raises(ValueError, match=r" The sort_key cpu_times is not accepted. Please choose a sort key from"): + pt_profiler.get_results(sort_key="cpu_times") + + +def test_write_results(tmp_path): + n = 5 + + trainer = get_engine() + pt_profiler = PyTorchProfiler(on_trace_ready="tensorboard", output_path=tmp_path) + pt_profiler.attach(trainer) + trainer.run(range(10), max_epochs=1) + pt_profiler.write_results(n=n) + + fp = glob.glob(os.path.join(tmp_path, f"{idist.backend()}_*"))[0 - 1] + assert os.path.isfile(fp) + + file_length = 0 + with open(fp, "r") as fp: + for _ in fp: + file_length += 1 + + assert file_length == n + 5 From bbfbf8f3c2f9114e6b1fe4ba12ba6d5c1de82f70 Mon Sep 17 00:00:00 2001 From: Ishan-Kumar2 Date: Mon, 3 Jan 2022 06:14:45 +0000 Subject: [PATCH 08/10] autopep8 fix --- ignite/handlers/pytorch_profiler.py | 7 +++++-- 1 file changed, 5 insertions(+), 2 deletions(-) diff --git a/ignite/handlers/pytorch_profiler.py b/ignite/handlers/pytorch_profiler.py index 4f5003dd4b5..09bb391c757 100644 --- a/ignite/handlers/pytorch_profiler.py +++ b/ignite/handlers/pytorch_profiler.py @@ -1,6 +1,6 @@ # coding: utf-8 -from datetime import datetime import os +from datetime import datetime from typing import Any, Callable, Union import torch @@ -107,7 +107,10 @@ def _exit_profiler(self): def _profiler_step(self): self._profiler.step() - def attach(self, engine: Engine,) -> None: + def attach( + self, + engine: Engine, + ) -> None: """Attach the profiler to the engine. Args: From bf753bc580a4a362d193749252e9599442867787 Mon Sep 17 00:00:00 2001 From: Ishan Kumar Date: Wed, 19 Jan 2022 18:49:35 +0530 Subject: [PATCH 09/10] added more tests and refactored code --- ignite/handlers/pytorch_profiler.py | 177 ++++++++++++++---- .../ignite/handlers/test_pytorch_profiler.py | 158 ++++++++++++++-- 2 files changed, 284 insertions(+), 51 deletions(-) diff --git a/ignite/handlers/pytorch_profiler.py b/ignite/handlers/pytorch_profiler.py index 4f5003dd4b5..cf7d9d73d23 100644 --- a/ignite/handlers/pytorch_profiler.py +++ b/ignite/handlers/pytorch_profiler.py @@ -1,6 +1,7 @@ # coding: utf-8 -from datetime import datetime import os +import socket +from datetime import datetime from typing import Any, Callable, Union import torch @@ -16,6 +17,26 @@ class PyTorchProfiler: information, correlates them, performs automatic detection of bottlenecks in the model, and generates recommendations on how to resolve these bottlenecks. + Args: + cuda_activity: If true, records GPU activity in addition to CPU activity, + on_trace_ready: Function that takes a reference to the profiler as an input + and is called by the profiler each time the new trace is ready, + Accepts custom function definition, as well as `tensorboard`, `flame_graph` and `chrome` as handlers. + record_shapes: whether to record shapes of the inputs (necessary if you want to group profiler output by shapes) + profile_memory: whether to report amount of memory consumed by model's Tensors + with_stack: whether to record source information for the operations (necessary for flamegraph), + with_flops: whether to use formula to estimate the FLOPS of specific ops (matrix multiplication and 2D conv), + with_modules: whether to record module hierarchy (including function names) corresponding + to the callstack of the op. e.g. If module A's forward call's module B's forward which + contains an aten::add op, then aten::add's module hierarchy is A.B + output_path: Directory where file should be placed, + file_name: name of output file generated, + skip_first: Scheduling parameter, the profiler first skips the first `skip_first` number of steps + wait: Scheduling parameter, the profiler waits for `wait` number of steps + warmup: Scheduling Parameter, the profile warms up for `warmup` number of steps + active: Scheduling Parameter, the profiler does active profiling for the `active` number of steps + repeat: Scheduling Parameter, number of cycles, 0 means that cycles will continue until profiling is finished. + Examples: .. code-block:: python @@ -31,14 +52,46 @@ class PyTorchProfiler: # Get profiler results of time pt_profiler.print_results() - # Save profiler result to CSV file (requires pandas) + # Save profiler result to text file pt_profiler.write_results() Both these methods can also be used as the on_trace_ready function which gets called after trace is ready. - pt_profiler = PyTorchProfiler(on_trace_ready=profiler.write_to_file(10), output_path="logs/train") - .. versionadded:: 0.4.8 + #The on_trace_handler accepts 3 strings `tensorboard`, `chrome` and `flamegraph` + #Tensorboard + pt_profiler = PyTorchProfiler(on_trace_ready="tensorboard", output_path="./logs/train") + + #To view this file enusre you have the PyTorch Profiler Tensorboard Plugin + pip install torch_tb_profiler + + #Then launch tensorboard + tensorboard --logdir=./logs + + #Chrome + #Profiling results can be outputted as a .json trace file which can be viewed in the Chrome Trace viewer + pt_profiler = PyTorchProfiler(on_trace_ready="chrome", output_path="./logs/train") + + #Open `chrome://tracing` on chrome and upload this file + + #Flamegraph + Execution times can be visualised as a flamegraph + pt_profiler = PyTorchProfiler(on_trace_ready="flamegraph", + output_path="./logs/train", file_name = "fg", with_stack=True) + + # To view as an interactive SVG + # git clone https://github.com/brendangregg/FlameGraph + # cd FlameGraph + # ./flamegraph.pl --title "CPU time" --countname "us." ./logs/train/fg_cpu_flamegraph.txt > perf_viz.svg + + #Custom Trace Handlers can also be used + def trace_handler(p): + output = p.key_averages().table(sort_by="self_cuda_time_total", row_limit=10) + print(output) + p.export_chrome_trace("/tmp/trace_" + str(p.step_num) + ".json") + pt_profiler = PyTorchProfiler(on_trace_ready=trace_handler, output_path="logs/train") + + .. versionadded:: 0.5.0 """ def __init__( @@ -51,33 +104,70 @@ def __init__( with_flops: bool = False, with_modules: bool = False, output_path: str = None, - wait: int = 2, - warmup: int = 2, - active: int = 6, + file_name: str = None, + skip_first: int = 0, + wait: int = 1, + warmup: int = 1, + active: int = 3, repeat: int = 1, ) -> None: - self.activities = [torch.profiler.ProfilerActivity.CPU] + self._activities = [torch.profiler.ProfilerActivity.CPU] if cuda_activity and torch.cuda.is_available(): - self.activities.append(torch.profiler.ProfilerActivity.GPU) + self._activities.append(torch.profiler.ProfilerActivity.GPU) - self.output_path = output_path + self._output_path = output_path + self._file_name = file_name - self.schedule = torch.profiler.schedule(wait=wait, warmup=warmup, active=active, repeat=repeat) + now = datetime.now().strftime("%Y%m%d-%H%M%S") + if not self._file_name: + self._file_name = f"{idist.backend()}_{now}_{socket.gethostname()}_{str(os.getpid())}" - self.trace_handler = ( - torch.profiler.tensorboard_trace_handler(self.output_path) - if on_trace_ready == "tensorboard" - else on_trace_ready - ) + self._with_stack = with_stack - self.record_shapes = record_shapes - self.profile_memory = profile_memory - self.with_stack = with_stack - self.with_flops = with_flops - self.with_modules = with_modules + self._schedule = torch.profiler.schedule( + wait=wait, warmup=warmup, active=active, repeat=repeat, skip_first=skip_first + ) - self.SORT_KEYS = { + if on_trace_ready == "tensorboard": + self._trace_handler = torch.profiler.tensorboard_trace_handler(self._output_path) + + elif on_trace_ready == "chrome": + + def chrome_trace(prof) -> None: + prof.export_chrome_trace(os.path.join(self._output_path, self._file_name + "_chrome_trace.json")) + + self._trace_handler = chrome_trace + + elif on_trace_ready == "flamegraph": + if not with_stack: + raise ValueError("The flag with_stack must be true in order to use flamegraph") + + def flamegraph_trace(prof) -> None: + prof.export_stacks( + os.path.join(self._output_path, self._file_name + "_cpu_flamegraph.txt"), "self_cpu_time_total" + ) + if cuda_activity: + prof.export_stacks( + os.path.join(self._output_path, self._file_name + "_gpu_flamegraph.json"), + "self_cuda_time_total", + ) + + self._trace_handler = flamegraph_trace + else: + if not isinstance(on_trace_ready, Callable): + raise ValueError( + "Trace Handler should be a callable or one of" + f"[`tensorboard`, `chrome`, `flamegraph`]. Found: {on_trace_ready}" + ) + self._trace_handler = on_trace_ready + + self._record_shapes = record_shapes + self._profile_memory = profile_memory + self._with_flops = with_flops + self._with_modules = with_modules + + self._SORT_KEYS = { "cpu_time", "cuda_time", "cpu_time_total", @@ -91,18 +181,20 @@ def __init__( def _profiler_create(self): self._profiler = torch.profiler.profile( - activities=self.activities, - schedule=self.schedule, - on_trace_ready=self.trace_handler, - record_shapes=self.record_shapes, - profile_memory=self.profile_memory, - with_stack=self.with_stack, - with_flops=self.with_flops, + activities=self._activities, + schedule=self._schedule, + on_trace_ready=self._trace_handler, + record_shapes=self._record_shapes, + profile_memory=self._profile_memory, + with_stack=self._with_stack, + with_flops=self._with_flops, ) + + def _profiler_enter(self): self._profiler.__enter__() def _exit_profiler(self): - self._profiler.__exit__(0, 0, 0) + self._profiler.__exit__(None, None, None) def _profiler_step(self): self._profiler.step() @@ -113,25 +205,34 @@ def attach(self, engine: Engine,) -> None: Args: engine: engine object. """ + if not isinstance(engine, Engine): + raise TypeError(f"Argument engine should be ignite.engine.Engine, but given {type(engine)}") + engine.add_event_handler(Events.EPOCH_STARTED, self._profiler_create) - engine.add_event_handler(Events.GET_BATCH_COMPLETED, self._profiler_step) + engine.add_event_handler(Events.EPOCH_STARTED, self._profiler_enter) + engine.add_event_handler(Events.ITERATION_COMPLETED, self._profiler_step) engine.add_event_handler(Events.EPOCH_COMPLETED, self._exit_profiler) - def get_results(self, n: int = -1, sort_key: str = "self_cuda_memory_usage", top_level_events_only=False): - if sort_key not in self.SORT_KEYS: + def get_results( + self, n: int = -1, sort_key: str = "self_cuda_memory_usage", top_level_events_only=False, group_by_shapes=False + ): + if sort_key not in self._SORT_KEYS: raise ValueError( - f" The sort_key {sort_key} is not accepted. Please choose a sort key from {self.SORT_KEYS}" + f" The sort_key {sort_key} is not accepted. Please choose a sort key from {self._SORT_KEYS}" ) - return self._profiler.key_averages().table( + if group_by_shapes and self._record_shapes is False: + raise ValueError( + "Running with group_by_input_shape=True requires running the profiler with record_shapes=True" + ) + + return self._profiler.key_averages(group_by_input_shape=group_by_shapes).table( sort_by=sort_key, row_limit=n, top_level_events_only=top_level_events_only ) def write_results(self, n: int = -1, sort_key: str = "self_cuda_memory_usage", top_level_events_only=False): - now = datetime.now().strftime("%Y%m%d-%H%M%S") - file_name = f"{idist.backend()}_{now}.txt" - with open(os.path.join(self.output_path, file_name), "w") as f: + with open(os.path.join(self._output_path, self._file_name + ".txt"), "w") as f: f.write(self.get_results(n, sort_key, top_level_events_only)) def print_results(self, n: int = -1, sort_key: str = "self_cuda_memory_usage", top_level_events_only=False): diff --git a/tests/ignite/handlers/test_pytorch_profiler.py b/tests/ignite/handlers/test_pytorch_profiler.py index 491123076fd..fe87390aef6 100644 --- a/tests/ignite/handlers/test_pytorch_profiler.py +++ b/tests/ignite/handlers/test_pytorch_profiler.py @@ -1,19 +1,21 @@ -import glob import os import pytest import torch -import ignite.distributed as idist from ignite.engine import Engine from ignite.handlers import PyTorchProfiler -def update_fn(engine, batch): - a = torch.empty((2, 3), dtype=torch.int32) - b = torch.empty((3, 3), dtype=torch.int32) +def clean_string(s): + return s.lstrip().rstrip() + - return a + torch.mm(a, b) +def update_fn(engine, batch): + x = torch.randn((1, 8), requires_grad=True) + y = torch.randn((8, 1), requires_grad=True) + z = torch.matmul(x, y) + z.backward() def get_engine(): @@ -21,26 +23,156 @@ def get_engine(): return dummy_trainer -def test_get_results(tmp_path): +def output_string_to_dict(output_string): + output_string = output_string.split("\n") + + # Removing the formatting and headers + output_string = output_string[3:-3] + + output_string_split = dict() + + for _output_string in output_string: + split_string = _output_string.split(" ") + split_string = [clean_string(i) for i in split_string if i != ""] + # Using name and shape as key to distinguish between same operation with different shapes + output_string_split[split_string[0] + split_string[-1]] = split_string[1:] + + return output_string_split + + +def check_profiler_output(data, sort_key="cpu_time", wait=1, warmup=1, active=3, repeat=1): + # Returns output of PyTorch Profiler directly (Without using Ignite handler) for comparison + + from torch.profiler import ProfilerActivity, profile, schedule + + with profile( + activities=[ProfilerActivity.CPU, ProfilerActivity.CUDA], + schedule=schedule(wait=wait, warmup=warmup, active=active, repeat=repeat), + record_shapes=True, + ) as prof: + for d in data: + x = torch.randn((1, 8), requires_grad=True) + y = torch.randn((8, 1), requires_grad=True) + z = torch.matmul(x, y) + z.backward() + prof.step() + return prof.key_averages(group_by_input_shape=True).table(sort_by=sort_key) + + +def get_both_profiler_outputs(data_len, path, epoch, wait=1, warmup=1, active=3, repeat=1): + data = [i for i in range(data_len)] trainer = get_engine() - pt_profiler = PyTorchProfiler(on_trace_ready="tensorboard", output_path=tmp_path) + pt_profiler = PyTorchProfiler( + on_trace_ready="tensorboard", + output_path=path, + record_shapes=True, + wait=wait, + warmup=warmup, + active=active, + repeat=repeat, + with_stack=True, + ) pt_profiler.attach(trainer) - trainer.run(range(10), max_epochs=1) + trainer.run(data, max_epochs=epoch) + output_string = pt_profiler.get_results(sort_key="cpu_time", group_by_shapes=True) + + if not torch.cuda.is_available(): + with pytest.warns(UserWarning): + ref_output = check_profiler_output(data, "cpu_time", wait=wait, warmup=warmup, active=active, repeat=repeat) + else: + ref_output = check_profiler_output(data, "cpu_time", wait=wait, warmup=warmup, active=active, repeat=repeat) + return ref_output, output_string + + +def test_profilers_wrong_inputs(): + pt_profiler = PyTorchProfiler() + + with pytest.raises(TypeError, match=r"Argument engine should be ignite.engine.Engine"): + pt_profiler.attach(None) with pytest.raises(ValueError, match=r" The sort_key cpu_times is not accepted. Please choose a sort key from"): pt_profiler.get_results(sort_key="cpu_times") + with pytest.raises( + ValueError, + match=r"Running with group_by_input_shape=True requires running the profiler with record_shapes=True", + ): + pt_profiler.get_results(group_by_shapes=True) + + with pytest.raises(ValueError, match=r"The flag with_stack must be true in order to use flamegraph"): + pt_profiler = PyTorchProfiler(on_trace_ready="flamegraph", with_stack=False) + + with pytest.raises(ValueError, match=r"Trace Handler should be a callable or one of"): + pt_profiler = PyTorchProfiler(on_trace_ready=10, with_stack=False) + + +@pytest.mark.parametrize("data_len", [1, 6, 10]) +@pytest.mark.parametrize("epoch", [1, 2, 10]) +def test_get_results(epoch, data_len, tmp_path): + ref_output, output_string = get_both_profiler_outputs(data_len, tmp_path, epoch) + print(output_string, ref_output) + output_dict = output_string_to_dict(output_string) + ref_output_dict = output_string_to_dict(ref_output) + + for _key in output_dict.keys(): + # Checks number of calls are same in both profilers + assert output_dict[_key][5] == ref_output_dict[_key][5] + # Checks shapes + assert output_dict[_key][6] == ref_output_dict[_key][6] + + # Check number of elements recorded + assert len(output_dict) == len(ref_output_dict) + + +@pytest.mark.parametrize("wait,warmup,active,repeat", [(99, 2, 1, 1), (2, 99, 1, 1), (99, 2, 1, 2)]) +@pytest.mark.parametrize("epoch", [1, 2, 10]) +def test_none_output(epoch, tmp_path, wait, warmup, active, repeat): + trainer = get_engine() + pt_profiler = PyTorchProfiler( + on_trace_ready="tensorboard", output_path=tmp_path, wait=wait, warmup=warmup, active=active, repeat=repeat + ) + pt_profiler.attach(trainer) + trainer.run(range(100), max_epochs=epoch) + assert pt_profiler.get_results() == "" + + +@pytest.mark.parametrize("wait,warmup,active,repeat", [(1, 1, 2, 1), (6, 2, 92, 2), (99, 1, 10, 10)]) +@pytest.mark.parametrize("epoch", [1, 2, 10]) +def test_schedule(epoch, tmp_path, wait, warmup, active, repeat): + ref_output, output_string = get_both_profiler_outputs(100, tmp_path, epoch, wait, warmup, active, repeat) + + output_dict = output_string_to_dict(output_string) + ref_output_dict = output_string_to_dict(ref_output) + print(output_string, ref_output) + + for _key in output_dict.keys(): + assert output_dict[_key][5] == ref_output_dict[_key][5], print(_key) + assert output_dict[_key][6] == ref_output_dict[_key][6] + + # Check number of elements recorded + assert len(output_dict) == len(ref_output_dict) + + +@pytest.mark.parametrize("epoch", [1, 5, 100]) +def test_multiple_epochs_files(epoch, tmp_path): + # Number of files should be same as epochs + trainer = get_engine() + pt_profiler = PyTorchProfiler(on_trace_ready="tensorboard", output_path=tmp_path, with_stack=True) + pt_profiler.attach(trainer) + trainer.run(range(20), max_epochs=epoch) + assert epoch == len(os.listdir(tmp_path)) -def test_write_results(tmp_path): - n = 5 +@pytest.mark.parametrize("n", [1, 5, 10]) +def test_write_results(n, tmp_path): + # File Length should be equal to n (row limit) trainer = get_engine() - pt_profiler = PyTorchProfiler(on_trace_ready="tensorboard", output_path=tmp_path) + pt_profiler = PyTorchProfiler(on_trace_ready="tensorboard", output_path=tmp_path, file_name="testing_file") pt_profiler.attach(trainer) trainer.run(range(10), max_epochs=1) pt_profiler.write_results(n=n) - fp = glob.glob(os.path.join(tmp_path, f"{idist.backend()}_*"))[0 - 1] + fp = os.path.join(tmp_path, "testing_file.txt") assert os.path.isfile(fp) file_length = 0 From 27dc96fb483c2fdd65a62acb2fca347b39e30236 Mon Sep 17 00:00:00 2001 From: Ishan-Kumar2 Date: Wed, 19 Jan 2022 13:22:50 +0000 Subject: [PATCH 10/10] autopep8 fix --- ignite/handlers/pytorch_profiler.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/ignite/handlers/pytorch_profiler.py b/ignite/handlers/pytorch_profiler.py index cf7d9d73d23..86359df9078 100644 --- a/ignite/handlers/pytorch_profiler.py +++ b/ignite/handlers/pytorch_profiler.py @@ -199,7 +199,10 @@ def _exit_profiler(self): def _profiler_step(self): self._profiler.step() - def attach(self, engine: Engine,) -> None: + def attach( + self, + engine: Engine, + ) -> None: """Attach the profiler to the engine. Args: