test/test_profile_mp_mnist.py (36 additions, 3 deletions)
@@ -1,14 +1,38 @@
 """Fork of test_train_mp_mnist.py to demonstrate how to profile workloads."""
 import args_parse
 
+profile_opts = {
+    '--profile_step': {
+        'type': int,
+        'default': -1,
+        'help': 'Step at which to trigger a profile programmatically',
+    },
+    '--profile_epoch': {
+        'type': int,
+        'default': -1,
+        'help': 'Epoch at which to trigger a profile programmatically',
+    },
+    '--profile_logdir': {
+        'type': str,
+        'default': None,
+        'help': 'Path to store programmatically-triggered profiles',
+    },
+    '--profile_duration_ms': {
+        'type': int,
+        'default': 5000,
+        'help': 'Duration of programmatically-triggered profile captures'
+    },
+}
+
 FLAGS = args_parse.parse_common_options(
     datadir='/tmp/mnist-data',
     batch_size=128,
     momentum=0.5,
     lr=0.01,
     target_accuracy=98.0,
     num_epochs=18,
-    profiler_port=9012)
+    profiler_port=9012,
+    opts=profile_opts.items())
 
 import os
 import shutil
@@ -129,11 +153,20 @@ def train_mnist(flags,
   loss_fn = nn.NLLLoss()
 
   server = xp.start_server(flags.profiler_port)
+  profile_step = flags.profile_step
+  profile_epoch = flags.profile_epoch
 
-  def train_loop_fn(loader):
+  def train_loop_fn(loader, epoch):
     tracker = xm.RateTracker()
     model.train()
     for step, (data, target) in enumerate(loader):
+      if epoch == profile_epoch and step == profile_step and xm.is_master_ordinal(
+      ):
+        # Take a profile in a background thread
+        xp.trace_detached(
+            f'localhost:{flags.profiler_port}',
+            flags.profile_logdir,
+            duration_ms=flags.profile_duration_ms)
       if dynamic_graph:
         # testing purpose only: dynamic batch size and graph.
         index = max(-step, -flags.batch_size + 1)  # non-empty
@@ -177,7 +210,7 @@ def test_loop_fn(loader):
   accuracy, max_accuracy = 0.0, 0.0
   for epoch in range(1, flags.num_epochs + 1):
     xm.master_print('Epoch {} train begin {}'.format(epoch, test_utils.now()))
-    train_loop_fn(train_device_loader)
+    train_loop_fn(train_device_loader, epoch)
     xm.master_print('Epoch {} train end {}'.format(epoch, test_utils.now()))
 
     accuracy = test_loop_fn(test_device_loader)
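Taken together, the changes to test_profile_mp_mnist.py add four command-line options and a capture trigger inside the training loop. The snippet below is a minimal sketch of the same pattern outside the MNIST harness; the port, step, epoch, and log directory values are illustrative placeholders, while start_server, trace_detached, and is_master_ordinal are the same calls used in the diff above.

import torch_xla.core.xla_model as xm
import torch_xla.debug.profiler as xp

# Serve profiling requests for this process.
server = xp.start_server(9012)

def train_loop_fn(loader, epoch):
  for step, batch in enumerate(loader):
    # Fire a one-shot capture at the chosen (epoch, step) on the master
    # process only. trace_detached returns immediately, so the capture
    # overlaps the steps that follow instead of stalling the loop.
    if epoch == 1 and step == 10 and xm.is_master_ordinal():
      xp.trace_detached('localhost:9012', '/tmp/profiles', duration_ms=5000)
    ...  # forward/backward/optimizer step as usual
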
test/test_profiler.py (46 additions, 0 deletions)
@@ -5,6 +5,7 @@
 import os
 import sys
 import tempfile
+import time
 import unittest
 
 import args_parse
@@ -63,6 +64,13 @@ def train_worker():
           num_epochs=10)
       flags.fake_data = True
       flags.profiler_port = port
+
+      # Disable programmatic profiling
+      flags.profile_step = -1
+      flags.profile_epoch = -1
+      flags.profile_logdir = None
+      flags.profile_duration_ms = -1
+
       test_profile_mp_mnist.train_mnist(
           flags,
           training_started=training_started,
@@ -85,6 +93,44 @@ def train_worker():
       self._check_trace_namespace_exists(path)
       self._check_metrics_warnings_exist(self.fname)
 
+  def test_trace_detached(self):
+
+    port = xu.get_free_tcp_ports()[0]
+    training_started = multiprocessing.Event()
+    logdir = tempfile.mkdtemp()
+
+    def train_worker():
+      flags = args_parse.parse_common_options(
+          datadir='/tmp/mnist-data',
+          batch_size=16,
+          momentum=0.5,
+          lr=0.01,
+          num_epochs=10)
+      flags.fake_data = True
+      flags.profiler_port = port
+
+      # Set programmatic capture options
+      flags.profile_step = 10
+      flags.profile_epoch = 1
+      flags.profile_logdir = logdir
+      flags.profile_duration_ms = 5000
+
+      test_profile_mp_mnist.train_mnist(
+          flags,
+          training_started=training_started,
+          dynamic_graph=True,
+          fetch_often=True)
+
+    p = multiprocessing.Process(target=train_worker, daemon=True)
+    p.start()
+    training_started.wait(60)
+    # Delay to allow the profile to capture
+    time.sleep(5)
+    p.terminate()
+    path = self._check_xspace_pb_exist(logdir)
+    self._check_trace_namespace_exists(path)
+    self._check_metrics_warnings_exist(self.fname)
+
 
 if __name__ == '__main__':
   logging.getLogger().setLevel(logging.INFO)
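The three assertions at the end of test_trace_detached reuse helper methods defined earlier in test_profiler.py that the diff does not show. As a rough sketch of what the first helper plausibly does, assuming the capture is written in TensorBoard's profile-plugin layout (plugins/profile/<run>/<host>.xplane.pb), which is how torch_xla profiles appear in TensorBoard:

import glob
import os

def _check_xspace_pb_exist(logdir):
  # Expect exactly one serialized XSpace trace under the profile plugin
  # directory, and hand it back so later checks can inspect its contents.
  pattern = os.path.join(logdir, 'plugins', 'profile', '*', '*.xplane.pb')
  paths = glob.glob(pattern)
  assert len(paths) == 1, f'expected one xplane.pb, found {len(paths)}'
  return paths[0]
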
torch_xla/debug/profiler.py (10 additions, 0 deletions)
@@ -1,4 +1,5 @@
 import functools
+import threading
 import torch_xla
 import torch_xla.core.xla_model as xm
 
@@ -89,6 +90,15 @@ def trace(service_addr: str,
       options=options)
 
 
+def trace_detached(*args, **kwargs):
+  """
+  Wraps the :func:`~torch_xla.debug.profiler.trace` method to capture a profile
+  in a background thread. See that method for the list of supported parameters
+  and their semantics.
+  """
+  threading.Thread(target=trace, args=args, kwargs=kwargs).start()
+
+
 class Trace(torch_xla._XLAC.profiler.TraceMe):
   """Context manager that produces a trace event for profiling.
 
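Since trace_detached forwards its arguments unchanged to trace on a new thread, the two calls below differ only in blocking behavior; the address, log directory, and duration mirror the values used in the tests above.

import torch_xla.debug.profiler as xp

# Non-blocking: starts the capture on a background thread and returns at once.
xp.trace_detached('localhost:9012', '/tmp/profiles', duration_ms=5000)

# Blocking equivalent: returns only once the capture has completed.
xp.trace('localhost:9012', '/tmp/profiles', duration_ms=5000)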