test/test_profiler.py (57 changes: 32 additions & 25 deletions)
@@ -14,6 +14,31 @@
 import torch_xla.utils.utils as xu
 
 
+# This function must remain a top-level function. Using spawn
+# as the start method requires this function to be pickle-able.
+def train_worker(port, training_started):
+  flags = args_parse.parse_common_options(
+      datadir='/tmp/mnist-data',
+      batch_size=16,
+      momentum=0.5,
+      lr=0.01,
+      num_epochs=10)
+  flags.fake_data = True
+  flags.profiler_port = port
+
+  # Disable programmatic profiling
+  flags.profile_step = -1
+  flags.profile_epoch = -1
+  flags.profile_logdir = None
+  flags.profile_duration_ms = -1
+
+  test_profile_mp_mnist.train_mnist(
+      flags,
+      training_started=training_started,
+      dynamic_graph=True,
+      fetch_often=True)
+
+
 class ProfilerTest(unittest.TestCase):
 
   def setUp(self):
@@ -51,33 +76,15 @@ def _check_trace_namespace_exists(self, path):
         f'Expected "build_graph" trace in: {path}')
 
   def test_trace_and_metrics(self):
+    # Create a new context for starting processes with the spawn method.
+    # This is necessary to avoid CUDA initialization issues when both
+    # PyTorch and PyTorch/XLA were compiled with CUDA support.
+    context = multiprocessing.get_context("spawn")
Collaborator:

@ysiraichi IIUC, the failure happens when we initialize CUDA in the parent process and use CUDA in the child process. I wonder where we initialize CUDA in the parent process before your change in this PR.

Collaborator (author):

After some investigation, I believe it comes from importing torch_xla. Specifically, the following import chain:

  • torch_xla
  • stablehlo
  • dynamo_bridge
  • torch._inductor.fx_passes.post_grad

I guess one way to solve this issue would be to move ConstructorMoverPass out of the inductor tree.

Collaborator:

Thanks for the reply.

I'm curious how you know that torch._inductor.fx_passes.post_grad initializes a CUDA context.
Also, what do you mean by moving ConstructorMoverPass out of the inductor tree?

Collaborator (author):

> I'm curious how you know that torch._inductor.fx_passes.post_grad initializes a CUDA context.

Just by commenting it out, the problem goes away.

> What do you mean by moving ConstructorMoverPass out of the inductor tree?

This class is declared under the inductor module, so importing it means loading the inductor module itself, which is what initializes a CUDA context. If the class were declared somewhere else (which is possible, since it doesn't really depend on anything in inductor), that initialization would go away.
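
A minimal way to sanity-check the claim above (illustrative only, not part of this PR; it assumes a CUDA build of PyTorch with torch_xla installed) is to look at torch.cuda.is_initialized() before and after the import in a fresh interpreter:

```python
# check_cuda_init_on_import.py -- hypothetical script, not part of the PR.
# Run it in a fresh interpreter so the process has no CUDA state beforehand.
import torch

# Nothing has touched CUDA yet in this process.
print("before torch_xla import:", torch.cuda.is_initialized())

# Importing torch_xla pulls in stablehlo -> dynamo_bridge ->
# torch._inductor.fx_passes.post_grad, the suspected culprit.
import torch_xla  # noqa: F401

# If this prints True, the import chain eagerly created a CUDA context,
# which is what breaks fork-started children that also want to use CUDA.
print("after torch_xla import:", torch.cuda.is_initialized())
```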


+
     port = xu.get_free_tcp_ports()[0]
-    training_started = multiprocessing.Event()
-
-    def train_worker():
-      flags = args_parse.parse_common_options(
-          datadir='/tmp/mnist-data',
-          batch_size=16,
-          momentum=0.5,
-          lr=0.01,
-          num_epochs=10)
-      flags.fake_data = True
-      flags.profiler_port = port
-
-      # Disable programmatic profiling
-      flags.profile_step = -1
-      flags.profile_epoch = -1
-      flags.profile_logdir = None
-      flags.profile_duration_ms = -1
-
-      test_profile_mp_mnist.train_mnist(
-          flags,
-          training_started=training_started,
-          dynamic_graph=True,
-          fetch_often=True)
-
-    p = multiprocessing.Process(target=train_worker, daemon=True)
+    training_started = context.Event()
+    p = context.Process(
+        target=train_worker, args=(port, training_started), daemon=True)
     p.start()
     training_started.wait(60)
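
For readers skimming the diff, the overall pattern the test adopts looks roughly like the self-contained sketch below (the worker body, port number, and names are placeholders, not the test's actual code):

```python
import multiprocessing


# The worker must be a module-level function: the "spawn" start method
# pickles the target callable, and nested functions are not picklable.
def worker(port, started_event):
  # Signal the parent that the child process is up; real work would follow.
  started_event.set()
  print(f"worker running with port={port}")


def main():
  # "spawn" starts a fresh interpreter, so any CUDA state created in the
  # parent process (e.g. as a side effect of imports) is not inherited.
  context = multiprocessing.get_context("spawn")
  started = context.Event()
  p = context.Process(target=worker, args=(1234, started), daemon=True)
  p.start()
  started.wait(60)
  p.join()


if __name__ == "__main__":
  main()
```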
