Force synced KJT to trace unbacked SymInt #108960

Closed · wants to merge 1 commit
204 changes: 204 additions & 0 deletions test/dynamo/test_torchrec.py
@@ -0,0 +1,204 @@
# Owner(s): ["module: dynamo"]
import sys
import unittest
from typing import Dict, List

import torch
import torch._dynamo.test_case
from torch import nn

from torch._dynamo.test_case import TestCase
from torch._dynamo.testing import CompileCounter
from torch.testing._internal.common_utils import NoTest

try:
    from torchrec.datasets.random import RandomRecDataset
    from torchrec.sparse.jagged_tensor import JaggedTensor, KeyedJaggedTensor

    HAS_TORCHREC = True
except ImportError:
    HAS_TORCHREC = False


class BucketizeMod(torch.nn.Module):
    def __init__(self, feature_boundaries: Dict[str, List[float]]):
        super().__init__()
        self.bucket_w = torch.nn.ParameterDict()
        self.boundaries_dict = {}
        for key, boundaries in feature_boundaries.items():
            self.bucket_w[key] = torch.nn.Parameter(
                torch.empty([len(boundaries) + 1]).fill_(1.0),
                requires_grad=True,
            )
            buf = torch.tensor(boundaries, requires_grad=False)
            self.register_buffer(
                f"{key}_boundaries",
                buf,
                persistent=False,
            )
            self.boundaries_dict[key] = buf

    def forward(self, features: "KeyedJaggedTensor") -> "KeyedJaggedTensor":
        weights_list = []
        for key, boundaries in self.boundaries_dict.items():
            jt = features[key]
            bucketized = torch.bucketize(jt.weights(), boundaries)
            # doesn't super matter I guess
            # hashed = torch.ops.fb.index_hash(bucketized, seed=0, modulo=len(boundaries))
            hashed = bucketized
            weights = torch.gather(self.bucket_w[key], dim=0, index=hashed)
            weights_list.append(weights)
        return KeyedJaggedTensor(
            keys=features.keys(),
            values=features.values(),
            weights=torch.cat(weights_list),
            lengths=features.lengths(),
            offsets=features.offsets(),
            stride=features.stride(),
            length_per_key=features.length_per_key(),
        )


if not HAS_TORCHREC:
    print("torchrec not available, skipping tests", file=sys.stderr)
    TestCase = NoTest  # noqa: F811


@unittest.skipIf(not HAS_TORCHREC, "these tests require torchrec")
class TorchRecTests(TestCase):
    def test_pooled(self):
        tables = [
            (nn.EmbeddingBag(2000, 8), ["a0", "b0"]),
            (nn.EmbeddingBag(2000, 8), ["a1", "b1"]),
            (nn.EmbeddingBag(2000, 8), ["b2"]),
        ]

        embedding_groups = {
            "a": ["a0", "a1"],
            "b": ["b0", "b1", "b2"],
        }

        counter = CompileCounter()

        @torch.compile(backend=counter, fullgraph=True, dynamic=True)
        def f(id_list_features: KeyedJaggedTensor):
            id_list_jt_dict: Dict[str, JaggedTensor] = id_list_features.to_dict()
            pooled_embeddings = {}
            # TODO: run feature processor
            for emb_module, feature_names in tables:
                features_dict = id_list_jt_dict
                for feature_name in feature_names:
                    f = features_dict[feature_name]
                    pooled_embeddings[feature_name] = emb_module(
                        f.values(), f.offsets()
                    )

            pooled_embeddings_by_group = {}
            for group_name, group_embedding_names in embedding_groups.items():
                group_embeddings = [
                    pooled_embeddings[name] for name in group_embedding_names
                ]
                pooled_embeddings_by_group[group_name] = torch.cat(
                    group_embeddings, dim=1
                )

            return pooled_embeddings_by_group

        dataset = RandomRecDataset(
            keys=["a0", "a1", "b0", "b1", "b2"],
            batch_size=4,
            hash_size=2000,
            ids_per_feature=3,
            num_dense=0,
        )
        di = iter(dataset)

        # unsync should work

        d1 = next(di).sparse_features.unsync()
        d2 = next(di).sparse_features.unsync()
        d3 = next(di).sparse_features.unsync()

        r1 = f(d1)
        r2 = f(d2)
        r3 = f(d3)

        self.assertEqual(counter.frame_count, 1)
        counter.frame_count = 0

        # sync should work too

        d1 = next(di).sparse_features.sync()
        d2 = next(di).sparse_features.sync()
        d3 = next(di).sparse_features.sync()

        r1 = f(d1)
        r2 = f(d2)
        r3 = f(d3)

        self.assertEqual(counter.frame_count, 1)

        # export only works with unsync

        gm = torch._dynamo.export(f)(next(di).sparse_features.unsync()).graph_module
        gm.print_readable()

        self.assertEqual(gm(d1), r1)
        self.assertEqual(gm(d2), r2)
        self.assertEqual(gm(d3), r3)

    def test_bucketize(self):
        mod = BucketizeMod({"f1": [0.0, 0.5, 1.0]})
        features = KeyedJaggedTensor.from_lengths_sync(
            keys=["f1"],
            values=torch.tensor([0, 1, 2, 3, 4, 5, 6, 7]),
            lengths=torch.tensor([2, 0, 1, 1, 1, 3]),
            weights=torch.tensor([0.0, 0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7]),
        ).unsync()

        def f(x):
            # This is a trick to populate the computed cache and instruct
            # ShapeEnv that they're all sizey
            x.to_dict()
            return mod(x)

        torch._dynamo.export(f, aten_graph=True)(features).graph_module.print_readable()

    @unittest.expectedFailure
    def test_simple(self):
        jag_tensor1 = KeyedJaggedTensor(
            values=torch.tensor([3.0, 4.0, 5.0, 6.0, 7.0, 8.0]),
            keys=["index_0", "index_1"],
            lengths=torch.tensor([0, 0, 1, 1, 1, 3]),
        ).sync()

        # ordinarily, this would trigger one specialization
        self.assertEqual(jag_tensor1.length_per_key(), [1, 5])

        counter = CompileCounter()

        @torch._dynamo.optimize(counter, nopython=True)
        def f(jag_tensor):
            # The indexing here requires more symbolic reasoning
            # and doesn't work right now
            return jag_tensor["index_0"].values().sum()

        f(jag_tensor1)

        self.assertEqual(counter.frame_count, 1)

        jag_tensor2 = KeyedJaggedTensor(
            values=torch.tensor([1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0]),
            keys=["index_0", "index_1"],
            lengths=torch.tensor([2, 0, 1, 1, 1, 3]),
        ).sync()

        f(jag_tensor2)

        self.assertEqual(counter.frame_count, 1)


if __name__ == "__main__":
    from torch._dynamo.test_case import run_tests

    run_tests()
17 changes: 13 additions & 4 deletions torch/_dynamo/variables/builder.py
@@ -23,6 +23,7 @@
from torch._ops import HigherOrderOperator
from torch._subclasses.fake_tensor import FakeTensor, is_fake, maybe_get_fake_mode
from torch.fx.experimental.symbolic_shapes import (
    _constrain_range_for_size,
    DimConstraint,
    DimDynamic,
    RelaxedUnspecConstraint,
@@ -885,7 +886,7 @@ def wrap_literal(self, value):
        elif unspec and type(value) is int:
            # unspecializing int by default, but still
            # specialize for the following conditions
            if (
            if not TracingContext.get().force_unspec_int_unbacked_size_like and (
                value in self._common_constants()
                # Assume integers from global variables want to be specialized
                or not self.source.guard_source().is_local()
@@ -1083,11 +1084,21 @@ def wrap_unspecialized_primitive(self, value):
        if self.name in self.tx.output.unspec_variable_map:
            return self.tx.output.unspec_variable_map[self.name]
        else:
            shape_env = self.tx.output.shape_env
            if TracingContext.get().force_unspec_int_unbacked_size_like and isinstance(
                value, int
            ):
                wrapped_value = shape_env.create_unbacked_symint()
                _constrain_range_for_size(wrapped_value)
                self.tx.output.tracked_fakes.append(
                    TrackedFake(wrapped_value, self.source, None)
                )

            # NB: We do not do float. For motivation, see
            # https://docs.google.com/document/d/1INSCdYu1PxXcr43HrD82OudeEuS-qxQe1yZmLg2wy6A/edit
            # but the general idea is that we generate kernels that can
            # take unspecialized floats and use them in sizevar computation
            if (
            elif (
                isinstance(value, int)
                and not is_constant_source(self.get_source())
                and not isinstance(self.get_source(), RandomValueSource)
@@ -1101,8 +1112,6 @@ def wrap_unspecialized_primitive(self, value):
                        guards=self.make_guards(GuardBuilder.CONSTANT_MATCH),
                    )

                shape_env = self.tx.output.shape_env

                name = self.source.name()
                if name not in self.tx.output.frame_state:
                    # Note - this essentially means that if this name gets reused as a tensor,
7 changes: 6 additions & 1 deletion torch/_dynamo/variables/user_defined.py
@@ -10,6 +10,7 @@
from typing import Dict, List

import torch.nn
from torch._guards import TracingContext

from .. import variables
from ..allowed_functions import is_allowed
@@ -589,4 +590,8 @@ def __init__(self, value, **kwargs):
        assert type(value) is KeyedJaggedTensor
        super().__init__(value, **kwargs)

    # TODO Handle getattr for _length_per_key and _offset_per_key properly.
    def var_getattr(self, tx, name):
        if self.source is not None and name in ("_length_per_key", "_offset_per_key"):
            with TracingContext.patch(force_unspec_int_unbacked_size_like=True):
                return super().var_getattr(tx, name)
        return super().var_getattr(tx, name)
24 changes: 24 additions & 0 deletions torch/_guards.py
@@ -605,6 +605,30 @@ def __init__(self, fake_mode):
        # you ever do change this in aot_autograd.py; you should check
        # on permutations preferentially.)
        self.output_strides: Optional[List[Optional[List[int]]]] = None
        # When this is True, whenever we encounter an int in Dynamo tracing,
        # we will (1) force unspec it and (2) force it as a size-like unbacked
        # integer. This is currently used when processing certain lists of
        # ints that are known to be size-like and may have 0/1 entries that we
        # must not specialize on.
        self.force_unspec_int_unbacked_size_like = False

    @staticmethod
    @contextmanager
    def patch(**kwargs):
        prior = {}
        ctx = TracingContext.get()
        assert ctx is not None

        for key in kwargs.keys():
            # KeyError on invalid entry
            prior[key] = getattr(ctx, key)
        for key, val in kwargs.items():
            setattr(ctx, key, val)
        try:
            yield
        finally:
            for key, val in prior.items():
                setattr(ctx, key, val)

    @staticmethod
    def extract_stack():
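Taken together, the new knob reduces to one pattern: wrap a region of Dynamo tracing in TracingContext.patch(force_unspec_int_unbacked_size_like=True), and every Python int wrapped inside that region is traced as a size-like unbacked SymInt (via shape_env.create_unbacked_symint() plus _constrain_range_for_size) instead of being specialized to a constant. The sketch below is only an illustration of that pattern, not code from this PR: it assumes a TracingContext is already active (as it is inside Dynamo, e.g. in the var_getattr change above), and the helper name trace_with_size_like_ints is invented for the example; outside of tracing, TracingContext.get() returns None and the assert in patch would fire.

from torch._guards import TracingContext

def trace_with_size_like_ints(fn, *args):
    # Hypothetical helper, not part of this PR: any int that Dynamo wraps
    # while the flag is True becomes a size-like unbacked SymInt rather
    # than a burned-in constant; patch() restores the prior flag value on exit.
    with TracingContext.patch(force_unspec_int_unbacked_size_like=True):
        return fn(*args)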