Revert "Force synced KJT to trace unbacked SymInt (pytorch#108960)"
This reverts commit f9a250c.

Reverted pytorch#108960 on behalf of https://github.com/facebook-github-bot due to Diff reverted internally ([comment](pytorch#108960 (comment)))
pytorchmergebot authored and pragupta committed Sep 12, 2023
1 parent 9e540a7 commit 1f99f75
Showing 4 changed files with 5 additions and 247 deletions.
204 changes: 0 additions & 204 deletions test/dynamo/test_torchrec.py

This file was deleted.

17 changes: 4 additions & 13 deletions torch/_dynamo/variables/builder.py
@@ -23,7 +23,6 @@
 from torch._ops import HigherOrderOperator
 from torch._subclasses.fake_tensor import FakeTensor, is_fake, maybe_get_fake_mode
 from torch.fx.experimental.symbolic_shapes import (
-    _constrain_range_for_size,
     DimConstraint,
     DimDynamic,
     RelaxedUnspecConstraint,
@@ -886,7 +885,7 @@ def wrap_literal(self, value):
         elif unspec and type(value) is int:
             # unspecializing int by default, but still
             # specialize for the following conditions
-            if not TracingContext.get().force_unspec_int_unbacked_size_like and (
+            if (
                 value in self._common_constants()
                 # Assume integers from global variables want to be specialized
                 or not self.source.guard_source().is_local()
@@ -1084,21 +1083,11 @@ def wrap_unspecialized_primitive(self, value):
         if self.name in self.tx.output.unspec_variable_map:
             return self.tx.output.unspec_variable_map[self.name]
         else:
-            shape_env = self.tx.output.shape_env
-            if TracingContext.get().force_unspec_int_unbacked_size_like and isinstance(
-                value, int
-            ):
-                wrapped_value = shape_env.create_unbacked_symint()
-                _constrain_range_for_size(wrapped_value)
-                self.tx.output.tracked_fakes.append(
-                    TrackedFake(wrapped_value, self.source, None)
-                )
-
             # NB: We do not do float. For motivation, see
             # https://docs.google.com/document/d/1INSCdYu1PxXcr43HrD82OudeEuS-qxQe1yZmLg2wy6A/edit
             # but the general idea is that we generate kernels that can
             # take unspecialized floats and use them in sizevar computation
-            elif (
+            if (
                 isinstance(value, int)
                 and not is_constant_source(self.get_source())
                 and not isinstance(self.get_source(), RandomValueSource)
@@ -1112,6 +1101,8 @@ def wrap_unspecialized_primitive(self, value):
                         guards=self.make_guards(GuardBuilder.CONSTANT_MATCH),
                     )
 
+                shape_env = self.tx.output.shape_env
+
                 name = self.source.name()
                 if name not in self.tx.output.frame_state:
                     # Note - this esentially means that if this name gets reused as a tensor,
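
For context, the branch removed above is what let wrap_unspecialized_primitive trace a plain Python int as an unbacked, size-like SymInt rather than specializing on its concrete value (the point being, per the _guards.py comment further down, that lists such as KJT per-key lengths/offsets may contain 0/1 entries that must not be specialized on). Below is a minimal sketch of that mechanism using the same helpers the diff removes; it is illustrative only, not the actual Dynamo code path, and assumes the private symbolic_shapes helpers are importable in your PyTorch build:

# Sketch: allocate an unbacked SymInt and mark it size-like, mirroring the
# removed branch. Not the real builder.py logic.
from torch.fx.experimental.symbolic_shapes import ShapeEnv, _constrain_range_for_size

shape_env = ShapeEnv()
wrapped_value = shape_env.create_unbacked_symint()  # fresh symbol with no concrete hint
_constrain_range_for_size(wrapped_value)            # constrain the symbol to be size-like (non-negative)
print(wrapped_value)                                # prints the unbacked symbol, e.g. u0

After the revert, ints reaching this code fall back to the ordinary constant/unspecialized handling in the surviving branch.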
7 changes: 1 addition & 6 deletions torch/_dynamo/variables/user_defined.py
@@ -10,7 +10,6 @@
 from typing import Dict, List
 
 import torch.nn
-from torch._guards import TracingContext
 
 from .. import variables
 from ..allowed_functions import is_allowed
@@ -590,8 +589,4 @@ def __init__(self, value, **kwargs):
         assert type(value) is KeyedJaggedTensor
         super().__init__(value, **kwargs)
 
-    def var_getattr(self, tx, name):
-        if self.source is not None and name in ("_length_per_key", "_offset_per_key"):
-            with TracingContext.patch(force_unspec_int_unbacked_size_like=True):
-                return super().var_getattr(tx, name)
-        return super().var_getattr(tx, name)
+    # TODO Handle getattr for _length_per_key and _offset_per_key properly.
24 changes: 0 additions & 24 deletions torch/_guards.py
@@ -605,30 +605,6 @@ def __init__(self, fake_mode):
         # you ever do change this in aot_autograd.py; you should check
         # on permutations preferentially.)
         self.output_strides: Optional[List[Optional[List[int]]]] = None
-        # When this is True, whenever we encounter an int in Dynamo tracing,
-        # we will (1) force unspec it and (2) force it as a size-like unbacked
-        # integer. This is currently used when processing certain lists of
-        # ints that are known to be size-like and may have 0/1 entries that we
-        # must not specialize on.
-        self.force_unspec_int_unbacked_size_like = False
-
-    @staticmethod
-    @contextmanager
-    def patch(**kwargs):
-        prior = {}
-        ctx = TracingContext.get()
-        assert ctx is not None
-
-        for key in kwargs.keys():
-            # KeyError on invalid entry
-            prior[key] = getattr(ctx, key)
-        for key, val in kwargs.items():
-            setattr(ctx, key, val)
-        try:
-            yield
-        finally:
-            for key, val in prior.items():
-                setattr(ctx, key, val)
 
     @staticmethod
     def extract_stack():
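
The removed TracingContext.patch is a standard save/override/restore context manager. Here is a self-contained sketch of the same pattern in plain Python; the Ctx class and patch function are stand-ins for illustration, not PyTorch APIs:

from contextlib import contextmanager

class Ctx:
    def __init__(self):
        self.force_unspec_int_unbacked_size_like = False

@contextmanager
def patch(ctx, **kwargs):
    # Record prior values (getattr raises AttributeError for an unknown flag),
    # apply the overrides, and restore the priors even if the body raises.
    prior = {key: getattr(ctx, key) for key in kwargs}
    for key, val in kwargs.items():
        setattr(ctx, key, val)
    try:
        yield
    finally:
        for key, val in prior.items():
            setattr(ctx, key, val)

ctx = Ctx()
with patch(ctx, force_unspec_int_unbacked_size_like=True):
    assert ctx.force_unspec_int_unbacked_size_like is True
assert ctx.force_unspec_int_unbacked_size_like is False

The reverted change used this mechanism in KeyedJaggedTensorVariable.var_getattr to enable force_unspec_int_unbacked_size_like only while resolving _length_per_key and _offset_per_key on a KeyedJaggedTensor.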
