Skip to content

Commit 24f1a32

Browse files
committed
fix CI
1 parent e293596 commit 24f1a32

File tree

1 file changed

+9
-2
lines changed

1 file changed

+9
-2
lines changed

torch_xla/core/dynamo_bridge.py

Lines changed: 9 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,7 @@
1818
import torch_xla.runtime as xr
1919
import torch_xla.utils.utils as xu
2020

21-
debug = os.environ.get("TORCH_XLA_DEBUG") == "1"
21+
debug = os.environ.get("XLA_DYNAMO_DEBUG") == "1"
2222

2323

2424
@dataclasses.dataclass
@@ -322,7 +322,12 @@ def extract_graph_helper(xla_model: torch.fx.GraphModule):
322322

323323

324324
def extract_internal(xla_model: torch.fx.GraphModule):
325-
xm.mark_step()
325+
if debug:
326+
print('after partitioner')
327+
xla_model._graph.print_tabular()
328+
if any(
329+
torch_xla._XLAC._check_tensor_need_materialization(xla_model.xla_args)):
330+
xm.mark_step(wait=True)
326331
(xla_args_sharding_spec, args_and_out, graph_hash,
327332
arg_index_to_need_update_index, none_remover, graph_input_matcher,
328333
dumb_return_handler, xla_args_need_update) = extract_graph_helper(xla_model)
@@ -445,6 +450,8 @@ def call_module(self, target, args, kwargs):
445450

446451

447452
def extract_compiled_graph(xla_model: torch.fx.GraphModule, xla_args):
453+
if debug:
454+
xla_model._graph.print_tabular()
448455
# This call is critical to make sure xla_args' tensor id show up in graph_input_tensor_ids
449456
xm.mark_step()
450457

0 commit comments

Comments
 (0)