This repository was archived by the owner on Aug 1, 2025. It is now read-only.

Conversation

williamwen42 (Member)

Fixes #1167 and related bugs where the current translator object is not properly given to FX graph nodes.

Example: https://gist.github.com/williamwen42/1efa7ce86b7ac797e90a13356668d90d
(Note: running the above script before this change requires this modification to the code:

--- a/torchdynamo/output_graph.py
+++ b/torchdynamo/output_graph.py
@@ -373,7 +373,7 @@ class OutputGraph(fx.Tracer):
             # the call to tabulate can cause a lot of memory to be allocated
             if config.log_level <= logging.INFO:
                 log.info(
-                    f"TRACED GRAPH\n {name} {gm.forward.__code__.co_filename} {format_graph_tabular(gm.graph)}\n"
+                    f"TRACED GRAPH\n {name} {gm.forward.__code__.co_filename} {gm.print_readable()}\n"
                 )

Running after this change requires setting torchdynamo.config.output_graph_code = True).
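
For reference, turning that flag on (together with INFO-level logging, which the log call in the diff above is gated on) would look roughly like this; a minimal sketch, not taken from the gist:

import logging
import torchdynamo

# Per the note above: after this change the readable graph dump is only
# printed when output_graph_code is enabled and logging is at INFO or below.
torchdynamo.config.output_graph_code = True
torchdynamo.config.log_level = logging.INFO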

Running python benchmarks/torchbench.py --only BERT_pytorch --performance --verbose now gives:

class GraphModule(torch.nn.Module):
    def forward(self, inputs_0_ : torch.Tensor, inputs_1_ : torch.Tensor):
        
        # Module stack: {'mod': 'BERT'}, File: /scratch/williamwen/work/torchbenchmark/torchbenchmark/models/BERT_pytorch/bert_pytorch/model/bert.py:40, code: mask = (x > 0).unsqueeze(1).repeat(1, x.size(1), 1).unsqueeze(1)
        gt = inputs_0_ > 0
        unsqueeze = gt.unsqueeze(1);  gt = None
        repeat = unsqueeze.repeat(1, 128, 1);  unsqueeze = None
        unsqueeze_1 = repeat.unsqueeze(1);  repeat = None
        
        # Module stack: {'mod': 'BERT', 'mod_embedding': 'BERTEmbedding', 'mod_embedding_token': 'TokenEmbedding'}, File: /scratch/williamwen/work/torchdynamo-env/lib/python3.8/site-packages/torch/nn/modules/sparse.py:161, code: input, self.weight, self.padding_idx, self.max_norm,
        mod_embedding_token_weight = self.mod_embedding_token_weight
        
        # Module stack: {'mod': 'BERT', 'mod_embedding': 'BERTEmbedding', 'mod_embedding_token': 'TokenEmbedding'}, File: /scratch/williamwen/work/torchdynamo-env/lib/python3.8/site-packages/torch/nn/modules/sparse.py:160, code: return F.embedding(
        embedding = torch.nn.functional.embedding(inputs_0_, mod_embedding_token_weight, 0, None, 2.0, False, False);  inputs_0_ = mod_embedding_token_weight = None

mlazos (Contributor) commented Oct 15, 2022

What's the difference between this new stack of tx's that you're adding to OutputGraph and the existing stack (see tx.parent)? Could that stack be used instead? Otherwise, LGTM.

voznesenskym (Contributor) left a comment

Lgtm

def fake_mode(self):
    return self.root_tx.fake_mode

def push_tx(self, tx):

Ah, you went with the stack after all :) glad it worked

def push_tx(self, tx):
    self._current_tx.append(tx)

def pop_tx(self):

Nit: idiomatic for pop to return what it popped
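
For illustration, the idiom the nit points at would look something like this (reusing the _current_tx name from the hunk above; not necessarily the PR's final code):

def pop_tx(self):
    # Return the popped translator so callers can use or assert on it,
    # mirroring list.pop().
    return self._current_tx.pop()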

    return result

def unpack_var_sequence(self, tx):
def unpack_var_sequence_range(self, tx, range):

Hmm, I wonder if we eventually need to do this for all unpacks? As in, make them all take a range or provide a default?
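
One hypothetical reading of that suggestion: fold the range into the existing unpack with a default, so callers that never slice keep the old behavior. The class and helper names below are made up for the sketch; only unpack_var_sequence comes from the hunk above:

class SequenceVariableSketch:
    def unpack_var_sequence(self, tx, rng=None):
        # Hypothetical signature: one unpack that optionally takes a range.
        items = self._unpack_items(tx)  # stand-in for the real per-type unpack
        if rng is not None:
            items = items[rng.start:rng.stop:rng.step]
        return items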

jansel (Contributor) left a comment

I'm about to switch over to core, let's hold off on landing.

jansel (Contributor) commented Oct 15, 2022

We have migrated torchdynamo to torch._dynamo and will use the pytorch/pytorch repo for future development. Please resubmit this PR to https://github.com/pytorch/pytorch/

More details and instructions to port this PR over can be found in #1588

jansel closed this Oct 15, 2022
williamwen42 added a commit to pytorch/pytorch that referenced this pull request Oct 17, 2022
williamwen42 (Member, Author)

What's the difference between this new stack of tx's that you're adding to OutputGraph and the existing stack (see tx.parent)? Could that stack be used instead? Otherwise, LGTM.

The output graph only keeps a reference to a root (first) tx, not the most current one. The tx stack replaces the current_tx kwarg, since passing around a translator object as a kwarg is quite prone to errors.
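
Pieced together from the review hunks above, the shape of the change is roughly the following sketch; the __init__ signature and the current_tx property are assumptions, while root_tx, fake_mode, _current_tx, push_tx, and pop_tx appear in the hunks:

class OutputGraph:
    def __init__(self, root_tx):
        self.root_tx = root_tx    # the first (root) translator, kept as before
        self._current_tx = []     # stack of active translators, innermost last

    @property
    def fake_mode(self):
        return self.root_tx.fake_mode

    @property
    def current_tx(self):
        # Innermost translator while an inlined frame is being traced, else the
        # root; this is what replaces passing current_tx around as a kwarg.
        return self._current_tx[-1] if self._current_tx else self.root_tx

    def push_tx(self, tx):
        self._current_tx.append(tx)

    def pop_tx(self):
        return self._current_tx.pop()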

pytorchmergebot pushed a commit to pytorch/pytorch that referenced this pull request Oct 19, 2022
pytorchmergebot pushed a commit to pytorch/pytorch that referenced this pull request Dec 6, 2022
kulinseth pushed a commit to kulinseth/pytorch that referenced this pull request Dec 10, 2022

Development

Successfully merging this pull request may close these issues: Very first node in FX graph has misattributed stack
