We read every piece of feedback, and take your input very seriously.
To see all available qualifiers, see our documentation.
There was an error while loading. Please reload this page.
1 parent 4091d7d commit 87427f9Copy full SHA for 87427f9
torchrec/sparse/jagged_tensor.py
@@ -1102,7 +1102,7 @@ def _maybe_compute_stride_kjt(
1102
elif (
1103
stride_per_key_per_rank is not None and stride_per_key_per_rank.numel() > 0
1104
):
1105
- stride = int(stride_per_key_per_rank.sum(dim=1).max().item())
+ stride = torch.sym_int(stride_per_key_per_rank.sum(dim=1).max().item())
1106
elif offsets is not None and offsets.numel() > 0:
1107
stride = (offsets.numel() - 1) // len(keys)
1108
elif lengths is not None:
0 commit comments