Skip to content

Commit b4dfa13

Browse files
jd7-tr authored and facebook-github-bot committed
Add missing fields to KJT's PyTree flatten/unflatten logic for VBE KJT (pytorch#2952)
Summary: Pull Request resolved: pytorch#2952 # Context * Currently the torchrec IR serializer does not support exporting variable-batch KJTs, because the `stride_per_key_per_rank` and `inverse_indices` fields are needed for deserializing VBE KJTs but they are not included in the KJT's PyTree flatten/unflatten functions. * This diff updates the KJT's PyTree flatten/unflatten functions to include `stride_per_key_per_rank` and `inverse_indices`. # Ref Reviewed By: TroyGarden Differential Revision: D74295924
1 parent 7eee82f commit b4dfa13

File tree

2 files changed

+88
-43
lines changed

2 files changed

+88
-43
lines changed

torchrec/ir/tests/test_serializer.py

Lines changed: 45 additions & 35 deletions
Original file line numberDiff line numberDiff line change
@@ -207,8 +207,14 @@ def forward(
207207
num_embeddings=10,
208208
feature_names=["f2"],
209209
)
210+
config3 = EmbeddingBagConfig(
211+
name="t3",
212+
embedding_dim=5,
213+
num_embeddings=10,
214+
feature_names=["f3"],
215+
)
210216
ebc = EmbeddingBagCollection(
211-
tables=[config1, config2],
217+
tables=[config1, config2, config3],
212218
is_weighted=False,
213219
)
214220

@@ -293,42 +299,60 @@ def test_serialize_deserialize_ebc(self) -> None:
293299
self.assertEqual(deserialized.shape, orginal.shape)
294300
self.assertTrue(torch.allclose(deserialized, orginal))
295301

296-
@unittest.skip("Adding test for demonstrating VBE KJT flattening issue for now.")
297302
def test_serialize_deserialize_ebc_with_vbe_kjt(self) -> None:
298303
model = self.generate_model_for_vbe_kjt()
299-
id_list_features = KeyedJaggedTensor(
300-
keys=["f1", "f2"],
301-
values=torch.tensor([5, 6, 7, 1, 2, 3, 0, 1]),
302-
lengths=torch.tensor([3, 3, 2]),
303-
stride_per_key_per_rank=[[2], [1]],
304-
inverse_indices=(["f1", "f2"], torch.tensor([[0, 1, 0], [0, 0, 0]])),
304+
kjt_1 = KeyedJaggedTensor(
305+
keys=["f1", "f2", "f3"],
306+
values=torch.tensor([1, 2, 3, 4, 5, 6, 7, 8, 9, 0]),
307+
lengths=torch.tensor([1, 2, 3, 2, 1, 1]),
308+
stride_per_key_per_rank=torch.tensor([[3], [2], [1]]),
309+
inverse_indices=(
310+
["f1", "f2", "f3"],
311+
torch.tensor([[0, 1, 2], [0, 1, 0], [0, 0, 0]]),
312+
),
313+
)
314+
kjt_2 = KeyedJaggedTensor(
315+
keys=["f1", "f2", "f3"],
316+
values=torch.tensor([1, 2, 3, 4, 5, 6, 7, 8, 9, 0]),
317+
lengths=torch.tensor([1, 2, 3, 2, 1, 1]),
318+
stride_per_key_per_rank=torch.tensor([[1], [2], [3]]),
319+
inverse_indices=(
320+
["f1", "f2", "f3"],
321+
torch.tensor([[0, 0, 0], [0, 1, 0], [0, 1, 2]]),
322+
),
305323
)
306324

307-
eager_out = model(id_list_features)
325+
eager_out = model(kjt_1)
326+
eager_out_2 = model(kjt_2)
308327

309328
# Serialize EBC
310329
model, sparse_fqns = encapsulate_ir_modules(model, JsonSerializer)
311330
ep = torch.export.export(
312331
model,
313-
(id_list_features,),
332+
(kjt_1,),
314333
{},
315334
strict=False,
316335
# Allows KJT to not be unflattened and run a forward on unflattened EP
317336
preserve_module_call_signature=(tuple(sparse_fqns)),
318337
)
319338

320339
# Run forward on ExportedProgram
321-
ep_output = ep.module()(id_list_features)
340+
ep_output = ep.module()(kjt_1)
341+
ep_output_2 = ep.module()(kjt_2)
322342

343+
self.assertEqual(len(ep_output), len(kjt_1.keys()))
344+
self.assertEqual(len(ep_output_2), len(kjt_2.keys()))
323345
for i, tensor in enumerate(ep_output):
324-
self.assertEqual(eager_out[i].shape, tensor.shape)
346+
self.assertEqual(eager_out[i].shape[1], tensor.shape[1])
347+
for i, tensor in enumerate(ep_output_2):
348+
self.assertEqual(eager_out_2[i].shape[1], tensor.shape[1])
325349

326350
# Deserialize EBC
327351
unflatten_ep = torch.export.unflatten(ep)
328352
deserialized_model = decapsulate_ir_modules(unflatten_ep, JsonSerializer)
329353

330354
# check EBC config
331-
for i in range(5):
355+
for i in range(1):
332356
ebc_name = f"ebc{i + 1}"
333357
self.assertIsInstance(
334358
getattr(deserialized_model, ebc_name), EmbeddingBagCollection
@@ -343,36 +367,22 @@ def test_serialize_deserialize_ebc_with_vbe_kjt(self) -> None:
343367
self.assertEqual(deserialized.num_embeddings, orginal.num_embeddings)
344368
self.assertEqual(deserialized.feature_names, orginal.feature_names)
345369

346-
# check FPEBC config
347-
for i in range(2):
348-
fpebc_name = f"fpebc{i + 1}"
349-
assert isinstance(
350-
getattr(deserialized_model, fpebc_name),
351-
FeatureProcessedEmbeddingBagCollection,
352-
)
353-
354-
for deserialized, orginal in zip(
355-
getattr(
356-
deserialized_model, fpebc_name
357-
)._embedding_bag_collection.embedding_bag_configs(),
358-
getattr(
359-
model, fpebc_name
360-
)._embedding_bag_collection.embedding_bag_configs(),
361-
):
362-
self.assertEqual(deserialized.name, orginal.name)
363-
self.assertEqual(deserialized.embedding_dim, orginal.embedding_dim)
364-
self.assertEqual(deserialized.num_embeddings, orginal.num_embeddings)
365-
self.assertEqual(deserialized.feature_names, orginal.feature_names)
366-
367370
# Run forward on deserialized model and compare the output
368371
deserialized_model.load_state_dict(model.state_dict())
369-
deserialized_out = deserialized_model(id_list_features)
372+
deserialized_out = deserialized_model(kjt_1)
370373

371374
self.assertEqual(len(deserialized_out), len(eager_out))
372375
for deserialized, orginal in zip(deserialized_out, eager_out):
373376
self.assertEqual(deserialized.shape, orginal.shape)
374377
self.assertTrue(torch.allclose(deserialized, orginal))
375378

379+
deserialized_out_2 = deserialized_model(kjt_2)
380+
381+
self.assertEqual(len(deserialized_out_2), len(eager_out_2))
382+
for deserialized, orginal in zip(deserialized_out_2, eager_out_2):
383+
self.assertEqual(deserialized.shape, orginal.shape)
384+
self.assertTrue(torch.allclose(deserialized, orginal))
385+
376386
def test_dynamic_shape_ebc_disabled_in_oss_compatibility(self) -> None:
377387
model = self.generate_model()
378388
feature1 = KeyedJaggedTensor.from_offsets_sync(

torchrec/sparse/jagged_tensor.py

Lines changed: 43 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1723,6 +1723,8 @@ class KeyedJaggedTensor(Pipelineable, metaclass=JaggedTensorMeta):
17231723
"_weights",
17241724
"_lengths",
17251725
"_offsets",
1726+
"_stride_per_key_per_rank",
1727+
"_inverse_indices",
17261728
]
17271729

17281730
def __init__(
@@ -3015,13 +3017,35 @@ def dist_init(
30153017

30163018
def _kjt_flatten(
30173019
t: KeyedJaggedTensor,
3018-
) -> Tuple[List[Optional[torch.Tensor]], List[str]]:
3019-
return [getattr(t, a) for a in KeyedJaggedTensor._fields], t._keys
3020+
) -> Tuple[List[Optional[torch.Tensor]], Tuple[List[str], List[str]]]:
3021+
"""
3022+
Used by PyTorch's pytree utilities for serialization and processing.
3023+
Extracts tensor attributes of a KeyedJaggedTensor and returns them
3024+
as a flat list, along with the necessary metadata to reconstruct the KeyedJaggedTensor.
3025+
3026+
Component tensors are returned as dynamic attributes.
3027+
KJT metadata are added as static specs.
3028+
3029+
Returns:
3030+
Tuple containing:
3031+
- List[Optional[torch.Tensor]]: All tensor attributes (_values, _weights, _lengths,
3032+
_offsets, _stride_per_key_per_rank, and the tensor part of _inverse_indices if present)
3033+
- Tuple[List[str], List[str]]: Metadata needed for reconstruction:
3034+
- List of keys from the original KeyedJaggedTensor
3035+
- List of inverse indices keys (if present, otherwise empty list)
3036+
"""
3037+
values = [getattr(t, a) for a in KeyedJaggedTensor._fields[:-1]]
3038+
values.append(t._inverse_indices[1] if t._inverse_indices is not None else None)
3039+
3040+
return values, (
3041+
t._keys,
3042+
t._inverse_indices[0] if t._inverse_indices is not None else [],
3043+
)
30203044

30213045

30223046
def _kjt_flatten_with_keys(
30233047
t: KeyedJaggedTensor,
3024-
) -> Tuple[List[Tuple[KeyEntry, Optional[torch.Tensor]]], List[str]]:
3048+
) -> Tuple[List[Tuple[KeyEntry, Optional[torch.Tensor]]], Tuple[List[str], List[str]]]:
30253049
values, context = _kjt_flatten(t)
30263050
# pyre can't tell that GetAttrKey implements the KeyEntry protocol
30273051
return [ # pyre-ignore[7]
@@ -3030,15 +3054,26 @@ def _kjt_flatten_with_keys(
30303054

30313055

30323056
def _kjt_unflatten(
3033-
values: List[Optional[torch.Tensor]], context: List[str] # context is the _keys
3057+
values: List[Optional[torch.Tensor]],
3058+
context: Tuple[
3059+
List[str], List[str]
3060+
], # context is the (_keys, _inverse_indices[0]) tuple
30343061
) -> KeyedJaggedTensor:
3035-
return KeyedJaggedTensor(context, *values)
3062+
return KeyedJaggedTensor(
3063+
context[0],
3064+
*values[:-2],
3065+
stride_per_key_per_rank=values[-2],
3066+
inverse_indices=(context[1], values[-1]) if values[-1] is not None else None,
3067+
)
30363068

30373069

30383070
def _kjt_flatten_spec(
30393071
t: KeyedJaggedTensor, spec: TreeSpec
30403072
) -> List[Optional[torch.Tensor]]:
3041-
return [getattr(t, a) for a in KeyedJaggedTensor._fields]
3073+
values = [getattr(t, a) for a in KeyedJaggedTensor._fields[:-1]]
3074+
values.append(t._inverse_indices[1] if t._inverse_indices is not None else None)
3075+
3076+
return values
30423077

30433078

30443079
register_pytree_node(
@@ -3053,7 +3088,7 @@ def _kjt_flatten_spec(
30533088

30543089
def flatten_kjt_list(
30553090
kjt_arr: List[KeyedJaggedTensor],
3056-
) -> Tuple[List[Optional[torch.Tensor]], List[List[str]]]:
3091+
) -> Tuple[List[Optional[torch.Tensor]], List[Tuple[List[str], List[str]]]]:
30573092
_flattened_data = []
30583093
_flattened_context = []
30593094
for t in kjt_arr:
@@ -3064,7 +3099,7 @@ def flatten_kjt_list(
30643099

30653100

30663101
def unflatten_kjt_list(
3067-
values: List[Optional[torch.Tensor]], contexts: List[List[str]]
3102+
values: List[Optional[torch.Tensor]], contexts: List[Tuple[List[str], List[str]]]
30683103
) -> List[KeyedJaggedTensor]:
30693104
num_kjt_fields = len(KeyedJaggedTensor._fields)
30703105
length = len(values)

0 commit comments

Comments
 (0)