diff --git a/test/test_profile_mp_mnist.py b/test/test_profile_mp_mnist.py
index f70a380132ef..2d438ee17f71 100644
--- a/test/test_profile_mp_mnist.py
+++ b/test/test_profile_mp_mnist.py
@@ -1,6 +1,29 @@
 """Fork of test_train_mp_mnist.py to demonstrate how to profile workloads."""
 import args_parse
 
+profile_opts = {
+    '--profile_step': {
+        'type': int,
+        'default': -1,
+        'help': 'Step at which to trigger a profile programmatically',
+    },
+    '--profile_epoch': {
+        'type': int,
+        'default': -1,
+        'help': 'Epoch at which to trigger a profile programmatically',
+    },
+    '--profile_logdir': {
+        'type': str,
+        'default': None,
+        'help': 'Path to store programmatically-triggered profiles',
+    },
+    '--profile_duration_ms': {
+        'type': int,
+        'default': 5000,
+        'help': 'Duration of programmatically-triggered profile captures'
+    },
+}
+
 FLAGS = args_parse.parse_common_options(
     datadir='/tmp/mnist-data',
     batch_size=128,
@@ -8,7 +31,8 @@
     lr=0.01,
     target_accuracy=98.0,
     num_epochs=18,
-    profiler_port=9012)
+    profiler_port=9012,
+    opts=profile_opts.items())
 
 import os
 import shutil
@@ -129,11 +153,20 @@ def train_mnist(flags,
   loss_fn = nn.NLLLoss()
 
   server = xp.start_server(flags.profiler_port)
+  profile_step = flags.profile_step
+  profile_epoch = flags.profile_epoch
 
-  def train_loop_fn(loader):
+  def train_loop_fn(loader, epoch):
     tracker = xm.RateTracker()
     model.train()
     for step, (data, target) in enumerate(loader):
+      if epoch == profile_epoch and step == profile_step and xm.is_master_ordinal(
+      ):
+        # Take a profile in a background thread
+        xp.trace_detached(
+            f'localhost:{flags.profiler_port}',
+            flags.profile_logdir,
+            duration_ms=flags.profile_duration_ms)
       if dynamic_graph:
         # testing purpose only: dynamic batch size and graph.
         index = max(-step, -flags.batch_size + 1)  # non-empty
@@ -177,7 +210,7 @@ def test_loop_fn(loader):
   accuracy, max_accuracy = 0.0, 0.0
   for epoch in range(1, flags.num_epochs + 1):
     xm.master_print('Epoch {} train begin {}'.format(epoch, test_utils.now()))
-    train_loop_fn(train_device_loader)
+    train_loop_fn(train_device_loader, epoch)
     xm.master_print('Epoch {} train end {}'.format(epoch, test_utils.now()))
 
     accuracy = test_loop_fn(test_device_loader)
diff --git a/test/test_profiler.py b/test/test_profiler.py
index 4d9969974c3d..eb6822d93110 100644
--- a/test/test_profiler.py
+++ b/test/test_profiler.py
@@ -5,6 +5,7 @@
 import os
 import sys
 import tempfile
+import time
 import unittest
 
 import args_parse
@@ -63,6 +64,13 @@ def train_worker():
           num_epochs=10)
       flags.fake_data = True
       flags.profiler_port = port
+
+      # Disable programmatic profiling
+      flags.profile_step = -1
+      flags.profile_epoch = -1
+      flags.profile_logdir = None
+      flags.profile_duration_ms = -1
+
       test_profile_mp_mnist.train_mnist(
           flags,
           training_started=training_started,
@@ -85,6 +93,44 @@ def train_worker():
     self._check_trace_namespace_exists(path)
     self._check_metrics_warnings_exist(self.fname)
 
+  def test_trace_detached(self):
+
+    port = xu.get_free_tcp_ports()[0]
+    training_started = multiprocessing.Event()
+    logdir = tempfile.mkdtemp()
+
+    def train_worker():
+      flags = args_parse.parse_common_options(
+          datadir='/tmp/mnist-data',
+          batch_size=16,
+          momentum=0.5,
+          lr=0.01,
+          num_epochs=10)
+      flags.fake_data = True
+      flags.profiler_port = port
+
+      # Set programmatic capture options
+      flags.profile_step = 10
+      flags.profile_epoch = 1
+      flags.profile_logdir = logdir
+      flags.profile_duration_ms = 5000
+
+      test_profile_mp_mnist.train_mnist(
+          flags,
+          training_started=training_started,
+          dynamic_graph=True,
+          fetch_often=True)
+
+    p = multiprocessing.Process(target=train_worker, daemon=True)
+    p.start()
+    training_started.wait(60)
+    # Delay to allow the profile to capture
+    time.sleep(5)
+    p.terminate()
+    path = self._check_xspace_pb_exist(logdir)
+    self._check_trace_namespace_exists(path)
+    self._check_metrics_warnings_exist(self.fname)
+
 
 if __name__ == '__main__':
   logging.getLogger().setLevel(logging.INFO)
diff --git a/torch_xla/debug/profiler.py b/torch_xla/debug/profiler.py
index a15f075c77ac..cc45c9d81a2a 100644
--- a/torch_xla/debug/profiler.py
+++ b/torch_xla/debug/profiler.py
@@ -1,4 +1,5 @@
 import functools
+import threading
 
 import torch_xla
 import torch_xla.core.xla_model as xm
@@ -89,6 +90,15 @@ def trace(service_addr: str,
       options=options)
 
 
+def trace_detached(*args, **kwargs):
+  """
+  Wraps the :func:`~torch_xla.debug.profiler.trace` method to capture a profile
+  in a background thread. See that method for the list of supported parameters
+  and their semantics.
+  """
+  threading.Thread(target=trace, args=args, kwargs=kwargs).start()
+
+
 class Trace(torch_xla._XLAC.profiler.TraceMe):
   """Context manager that produces a trace event for profiling.
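
Usage note (illustrative commentary, not part of the patch): trace_detached
simply runs the existing blocking trace call on a background thread, so a
script that already serves profiler requests via xp.start_server can trigger
its own capture at a chosen step without stalling training. A minimal sketch,
assuming a server on the default port 9012 and a writable /tmp/profiles
directory:

    import torch_xla.debug.profiler as xp

    # Keep a reference to the returned server object so it stays alive.
    server = xp.start_server(9012)

    # At the chosen step, on the master ordinal only: capture ~5 seconds of
    # trace on a background thread; the training step that triggers the
    # capture is not blocked while the profile is written to the logdir.
    xp.trace_detached('localhost:9012', '/tmp/profiles', duration_ms=5000)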