torch/_dynamo/variables/distributed.py (3 additions, 0 deletions)
@@ -107,6 +107,9 @@ def is_placement_type(value):

Contributor:

@wanchaol Can you remind me why we needed this variable?

Collaborator:

This boils down to the "custom fx type" that we've been discussing. Fundamentally, we need to inline Placement and DeviceMesh as constants in the closure function (i.e. from_local) and put that closure function into the fx graph as a call_function node.

But without this PlacementClassVariable, dynamo would trace the DTensor's metadata construction (i.e. Shard(1)) as a UserDefinedClass/Object. UDTs are not constant variables, since it's hard to tell whether a UDT is a ConstantVariable unless the user explicitly tells dynamo that the objects are constant. This PlacementClassVariable lets us turn Shard(1) into a PlacementVariable (which is a constant variable), so the sharding metadata can be inlined into the closure.
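
To make that concrete, here is a minimal sketch (my illustration, not code from this PR) of the user-level path being described. DTensor.from_local, DeviceMesh, and Shard are the real DTensor APIs; the mesh setup is an assumption and requires an already-initialized 2-rank process group:

```python
import torch
from torch.distributed._tensor import DTensor, DeviceMesh, Shard

# Assumes torch.distributed is initialized with 2 ranks (not shown here).
mesh = DeviceMesh("cuda", [0, 1])

@torch.compile
def fn(local_tensor):
    # Shard(1) constructs a Placement. With PlacementClassVariable, dynamo
    # recognizes the Shard class, folds Shard(1) into a constant, and can
    # inline the sharding metadata into the from_local closure that lands
    # in the fx graph as a call_function node.
    return DTensor.from_local(local_tensor, mesh, [Shard(1)])
```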

```diff
         return type(value) is type and issubclass(value, Placement)
 
+    def as_python_constant(self):
+        return self.value
+
     def call_function(
         self, tx, args: "List[VariableTracker]", kwargs: "Dict[str, VariableTracker]"
     ) -> "VariableTracker":
```
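
For context on the three added lines: implementing as_python_constant is how a VariableTracker opts into being treated as a Python constant during tracing. A paraphrased sketch of that contract (an assumption based on torch/_dynamo/variables/base.py, not part of this diff):

```python
class VariableTracker:
    def as_python_constant(self):
        # Base class: not a constant unless a subclass overrides this.
        raise NotImplementedError(f"{self} is not a constant")

    def is_python_constant(self):
        # A variable counts as a constant iff as_python_constant() succeeds,
        # which is why PlacementClassVariable returning self.value is enough.
        try:
            self.as_python_constant()
            return True
        except NotImplementedError:
            return False
```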