Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
69 changes: 39 additions & 30 deletions experimental/torch_xla2/torch_xla2/ops/jaten.py
Original file line number Diff line number Diff line change
Expand Up @@ -613,37 +613,46 @@ def _aten__embedding_bag(
"""
embedded = _aten_embedding(weight, indices, padding_idx)

def static_dynamic_slice(x, start, size):
return jax.lax.dynamic_slice_in_dim(x, start, size)


# TODO not jittable
def reduce_by_segment(start, size, x, reducer):
res = []
for starti, sizei in zip(start, size):
res.append(reducer(static_dynamic_slice(x, starti, sizei), axis=0))
return jnp.stack(res)

def segsum(x, offsets, reducer):
start, end = offsets, jnp.concat([offsets[1:], jnp.array([x.shape[0]])])
return reduce_by_segment(start, end - start, x, reducer)

if mode not in (0, 1, 2):
raise ValueError("Invalid mode. Please choose 0 (sum) or 1 (mean).")
if mode == 0: # sum
reducer = jnp.sum
elif mode == 1: # mean
reducer = jnp.mean
elif mode == 2: # max
reducer = jnp.max

if indices.ndim == 1 and offsets is not None:
output = segsum(embedded, offsets, reducer)
if offsets is None:
# offsets is None only when indices.ndim > 1
if mode == 0: # sum
output = jnp.sum(embedded, axis=1)
elif mode == 1: # mean
output = jnp.mean(embedded, axis=1)
elif mode == 2: # max
output = jnp.max(embedded, axis=1)
return output, None, None, None

if isinstance(offsets, jax.Array):
offsets_np = np.array(offsets)
else:
output = reducer(embedded, axis=1)

# TODO: return output, offset2bag, bag_size, max_indices
return output, None, None, None
offsets_np = offsets
offset2bag = np.zeros(indices.shape[0], dtype=np.int64)
bag_size = np.zeros(offsets_np.shape[0], dtype=np.int64)
max_indices = jnp.full_like(indices, -1)

for bag in range(offsets_np.shape[0]):
start = int(offsets_np[bag])

end = int(indices.shape[0] if bag + 1 == offsets_np.shape[0] else offsets_np[bag + 1])
bag_size[bag] = end - start
offset2bag = offset2bag.at[start:end].set(bag)

if end - start > 0:
if mode == 0:
output_bag = jnp.sum(embedded[start:end], axis=0)
elif mode == 1:
output_bag = jnp.mean(embedded[start:end], axis=0)
elif mode == 2:
output_bag = jnp.max(embedded[start:end], axis=0)
max_indices = max_indices.at[start:end].set(jnp.argmax(embedded[start:end], axis=0))

# The original code returned offset2bag, bag_size, and max_indices as numpy arrays.
# Converting them to JAX arrays for consistency.
offset2bag = jnp.array(offset2bag)
bag_size = jnp.array(bag_size)

return output_bag, offset2bag, bag_size, max_indices


@op(torch.ops.aten.rsqrt)
Expand Down