# Flat vs Multi-Index

In [2]:
import numpy as np
shape = [2, 3, 3, 2]
flat_index = np.arange(0, np.prod(shape), 3)

r_multi_index = np.unravel_index(
    flat_index, shape, order="C"
)
print(np.stack([flat_index] +list(r_multi_index)).T)

[[ 0  0  0  0  0]
 [ 3  0  0  1  1]
 [ 6  0  1  0  0]
 [ 9  0  1  1  1]
 [12  0  2  0  0]
 [15  0  2  1  1]
 [18  1  0  0  0]
 [21  1  0  1  1]
 [24  1  1  0  0]
 [27  1  1  1  1]
 [30  1  2  0  0]
 [33  1  2  1  1]]


In [3]:
c_multi_index = np.unravel_index(
    flat_index, shape, order="F"
)
print(np.stack([flat_index] +list(c_multi_index)).T)

[[ 0  0  0  0  0]
 [ 3  1  1  0  0]
 [ 6  0  0  1  0]
 [ 9  1  1  1  0]
 [12  0  0  2  0]
 [15  1  1  2  0]
 [18  0  0  0  1]
 [21  1  1  0  1]
 [24  0  0  1  1]
 [27  1  1  1  1]
 [30  0  0  2  1]
 [33  1  1  2  1]]


## Flat index and multi-index interconversion

Flat index to multi-index

In [8]:
def ind2sub(flat_index, shape, order="C"):
    """Convert flat index to multi-index."""
    if order == "C":
        eta = np.cumprod(shape[1:][::-1])
        eta = np.concatenate([[1], eta])[::-1]
    else: # "F"
        eta = np.cumprod(shape)[:-1]
        eta = np.concatenate([[1], eta])

    # (num_indices, 1)
    flat_index = np.asarray(flat_index)[:, None]
    shape = np.asarray(shape)

    # Compute multi-index
    return flat_index // eta % shape

In [13]:
# Verify ind2sub
import numpy as np
shape = np.array([2, 3, 3, 2])
flat_index = np.arange(np.prod(shape))

# Use numpy unravel_index
r_multi_index = np.vstack(np.unravel_index(flat_index, shape, order="C")).T
c_multi_index = np.vstack(np.unravel_index(flat_index, shape, order="F")).T

# Call our own implementations
r_sub = ind2sub(flat_index, shape, order="C")
c_sub = ind2sub(flat_index, shape, order="F")

print("r_sub:\n", r_sub)
print("c_sub:\n", c_sub)

assert np.allclose(r_multi_index, r_sub)
assert np.allclose(c_multi_index, c_sub)

r_sub:
 [[0 0 0 0]
 [0 0 0 1]
 [0 0 1 0]
 [0 0 1 1]
 [0 0 2 0]
 [0 0 2 1]
 [0 1 0 0]
 [0 1 0 1]
 [0 1 1 0]
 [0 1 1 1]
 [0 1 2 0]
 [0 1 2 1]
 [0 2 0 0]
 [0 2 0 1]
 [0 2 1 0]
 [0 2 1 1]
 [0 2 2 0]
 [0 2 2 1]
 [1 0 0 0]
 [1 0 0 1]
 [1 0 1 0]
 [1 0 1 1]
 [1 0 2 0]
 [1 0 2 1]
 [1 1 0 0]
 [1 1 0 1]
 [1 1 1 0]
 [1 1 1 1]
 [1 1 2 0]
 [1 1 2 1]
 [1 2 0 0]
 [1 2 0 1]
 [1 2 1 0]
 [1 2 1 1]
 [1 2 2 0]
 [1 2 2 1]]
c_sub:
 [[0 0 0 0]
 [1 0 0 0]
 [0 1 0 0]
 [1 1 0 0]
 [0 2 0 0]
 [1 2 0 0]
 [0 0 1 0]
 [1 0 1 0]
 [0 1 1 0]
 [1 1 1 0]
 [0 2 1 0]
 [1 2 1 0]
 [0 0 2 0]
 [1 0 2 0]
 [0 1 2 0]
 [1 1 2 0]
 [0 2 2 0]
 [1 2 2 0]
 [0 0 0 1]
 [1 0 0 1]
 [0 1 0 1]
 [1 1 0 1]
 [0 2 0 1]
 [1 2 0 1]
 [0 0 1 1]
 [1 0 1 1]
 [0 1 1 1]
 [1 1 1 1]
 [0 2 1 1]
 [1 2 1 1]
 [0 0 2 1]
 [1 0 2 1]
 [0 1 2 1]
 [1 1 2 1]
 [0 2 2 1]
 [1 2 2 1]]


Multi-index to flat index

In [14]:
import numpy as np
coordinates = [ # point coordinates
    (0, 1, 1, 0), 
    (0, 2, 0, 1),
    (1, 0, 1, 0),
    (1, 1, 1, 1),
]
# one array for each dimension
multi_indices = np.array(coordinates).T.tolist()

shape = [2, 3, 3, 2]

# C-order
r_flat_index = np.ravel_multi_index(
    multi_indices, dims=shape, order="C"
)
print(np.concatenate([r_flat_index[:, None], coordinates], axis=1))

[[ 8  0  1  1  0]
 [13  0  2  0  1]
 [20  1  0  1  0]
 [27  1  1  1  1]]


In [15]:
# F-order
c_flat_index = np.ravel_multi_index(
    multi_indices, dims=shape, order="F"
)
print(np.concatenate([c_flat_index[:, None], coordinates], axis=1))

[[ 8  0  1  1  0]
 [22  0  2  0  1]
 [ 7  1  0  1  0]
 [27  1  1  1  1]]


In [24]:
def sub2ind(multi_index, shape, order="C"):
    """Convert multi-index to flat index."""
    if order == "C":
        eta = np.cumprod(shape[1:][::-1])
        eta = np.concatenate([[1], eta])[::-1]
    else: # "F"
        eta = np.cumprod(shape)[:-1]
        eta = np.concatenate([[1], eta])

    # (num_indices, d)
    multi_index = np.asarray(multi_index)

    return np.sum(multi_index * eta, axis=1)

c_ind = sub2ind(coordinates, shape=shape, order="C")
f_ind = sub2ind(coordinates, shape=shape, order="F")
print("multi-index:", coordinates)
print("shape:", shape)
print("C:", c_ind)
print("F:", f_ind)

multi-index: [(0, 1, 1, 0), (0, 2, 0, 1), (1, 0, 1, 0), (1, 1, 1, 1)]
shape: [2, 3, 3, 2]
C: [ 8 13 20 27]
F: [ 8 22  7 27]


# Boolean indexing

In [17]:
import numpy as np
X = np.array([
    [5, 2, 8, 3],
    [4, 9, 6, 7],
    [8, 11, 12, 10]
])

# Boolean condition of odd elements
# same shape as X
condition = X % 2 == 1

# Extract the odd elements
X[condition]

array([ 5,  3,  9,  7, 11])

In [25]:
import numpy as np
X = np.array([
    [5, 2, 8, 3],
    [4, 9, 6, 7],
    [8, 11, 12, 10]
])

# indices of odd elements
row_index, col_index = np.where(X % 2 == 1)
print("row_index:", row_index)
print("col_index:", col_index)

row_index: [0 0 1 1 2]
col_index: [0 3 1 3 1]


# Slicing

`None`

In [27]:
import numpy as np
X = np.random.rand(4, 5) # shape (4, 5)
# Add an extra 3rd dimension
# shape (4, 5, 1)
assert np.allclose(
    X[:, :, None].shape, 
    np.expand_dims(X, axis=2).shape
)
# Insert an extra dimension in the first dimension
# shape (1, 4, 5)
assert np.allclose(
    X[None, :, :].shape,
    np.expand_dims(X, axis=0).shape
)
# Insert an extra dimension between the original 
# 1st and 2nd dimensions: shape (4, 1, 5)
assert np.allclose(
    X[:, np.newaxis, :].shape,
    np.expand_dims(X, axis=1).shape
)

Ellipsis `...`

In [28]:
import numpy as np
X = np.random.rand(5, 2, 3, 8, 7)
# Selecting subset of a batch
assert np.allclose(
    X[2:4, :, :, :, :].shape,
    X[2:4, ...].shape
)
# Selecting the first slice of the last dimension
assert np.allclose(
    X[:, :, :, :, 0].shape,
    X[..., 0].shape
)
# Selecting a subset of the first dimension and 
# the last slice from the last dimensions
assert np.allclose(
    X[1:3, :, :, :, -1].shape,
    X[1:3, ..., -1].shape
)

Variable number of dimensions example for ellipsis `...`

In [29]:
import numpy as np

def extract_last_event(X):
    return X[:, -1, ...]

# (batch_size, seq_length)
X1 = np.random.rand(16, 50)
assert np.allclose(
    extract_last_event(X1).shape,
    (16, )
)

# (batch_size, seq_length, feature_size)
X2 = np.random.rand(16, 50, 32)
assert np.allclose(
    extract_last_event(X2).shape,
    (16, 32)
)
# (batch_size, seq_length, feature_size, num_instances)
X3 = np.random.rand(16, 50, 32, 8)
assert np.allclose(
    extract_last_event(X3).shape,
    (16, 32, 8)
)

# Reusing slice configurations

In [30]:
import numpy as np
X = np.random.rand(5, 8, 6)
Y = np.random.rand(3, 9, 7)

# Build the slice object
s = np.s_[:4, 2:4, 5:]

# Shape may not necessarily be identical
print(X[s].shape) # shape (4, 2, 1)
print(Y[s].shape) # shape (3, 2, 2)

(4, 2, 1)
(3, 2, 2)


In [32]:
x = [1, 2, 3, 4, 5]

# stop beyond the size of the array
print(x[:7]) # [1, 2, 3, 4, 5]

# stop beyond the size of the array, stride 2
print(x[:7:2]) # [1, 3, 5]

# empty, no elements to take 
# since the slice is beyond the length of the array
print(x[6:8]) # []

# The following, however, will throw an error
# print(x[7]) # IndexError: list index out of range

[1, 2, 3, 4, 5]
[1, 3, 5]
[]


# Case Study: Get Consecutive Index

In [33]:
import numpy as np

def get_consecutive_index(t, N=1, interval=True):
    """
    Given a sorted array of integers, find the start and the end
    of consecutive blocks
    e.g. t = [-1, 1,2,3,4, 7, 9,10,11,12,13, 15],
    return [[1,4], [6,10]]

    Inputs:
    * t: the sorted array of integers
    * N: filter for at least N consecutive. Default 1
    * interval: if True, we are filtering by N consecutive
            intervals instead of N consecutive numbers
    """
    # Determine if the (next element) = (1 + previous)
    x = np.diff(t) == 1
    # Determine which element is the start or end of the block
    f = np.concatenate(([False], x)) != np.concatenate((x, [False]))
    # Index of the start and end of the block
    f = np.where(f)[0]
    # reshape: the start and end must come in pairs
    f = np.reshape(f, (-1, 2))
    # filter for at least N consecutive
    interval = 1 - int(interval)  # 0 if True, 1 if False
    # Determine which interval has more than desired N consecutive numbers
    f = f[(interval + np.diff(f, n=1, axis=1).T[0]) >= N, :]
    return f

In [35]:
t = [-1, 1, 2, 3, 4, 7, 9, 10, 11, 12, 13, 15]
f = get_consecutive_index(t)
f

array([[ 1,  4],
       [ 6, 10]])

# Take

In [36]:
import numpy as np
X = np.array([2, 3, 5, 7, 11, 13, 17, 19, 23, 29])
# 1D array slicing
flat_index = [2, 3, 5, 7]
elements = np.take(X, flat_index)
elements

array([ 5,  7, 13, 19])

In [37]:
import numpy as np
X = np.arange(24).reshape(4, 6)
# Taking elements using flat indices
flat_index = [2, 3, 5, 18, 23]
elements = np.take(X, flat_index)
elements

array([ 2,  3,  5, 18, 23])

In [38]:
import numpy as np
X = np.arange(20).reshape(4, 5)
flat_index = [1, 2, 3, 4, 5]
assert np.allclose(
    np.take(np.asarray(X, order="C"), flat_index),
    np.take(np.asarray(X, order="F"), flat_index)
)

In [40]:
import numpy as np
X = np.arange(10)
# Flat index with a desired output shape
flat_index = [[4, 5], [3, 6], [1, 8]] # (3, 2)
elements = np.take(X, flat_index)
elements

array([[4, 5],
       [3, 6],
       [1, 8]])

In [41]:
import numpy as np
X = np.arange(30).reshape(5, 3, 2)
# Slicing along a single dimension
slice_index = [0, 3]
block = np.take(X, slice_index, axis=0)
block

array([[[ 0,  1],
        [ 2,  3],
        [ 4,  5]],

       [[18, 19],
        [20, 21],
        [22, 23]]])

In [43]:
# This operation is equivalent to slicing using indexing
assert np.allclose(
    block,
    X[slice_index, :, :]
)

# Slicing along a single dimension
# And arrange the slices across multiple dimensions
slice_index = [[0, 3], [1, 2]]
block = np.take(X, slice_index, axis=0) # (2, 2, 3, 2)
assert np.allclose(
    block,
    X[slice_index, :, :]
)

# Take Along Axis

In [44]:
import numpy as np
X = np.array([
    [0, 1, 2, 3, 4],
    [1, 2, 3, 4, 5],
    [2, 3, 4, 5, 6],
    [3, 4, 5, 6, 7],
])
# Reordering index
reorder_index = np.array([
    [3, 1, 0, 4, 2],
    [2, 4, 3, 0, 1],
    [3, 0, 4, 1, 2],
    [0, 2, 1, 4, 3],
])
# Apply the reordering index
y = np.take_along_axis(X, reorder_index, axis=1)

y

array([[3, 1, 0, 4, 2],
       [3, 5, 4, 1, 2],
       [5, 2, 6, 3, 4],
       [3, 5, 4, 7, 6]])

In [46]:
import numpy as np
X1 = np.array([
    [0, 1, 2, 3, 4],
    [1, 2, 3, 4, 5],
    [2, 3, 4, 5, 6],
    [3, 4, 5, 6, 7],
])
X2 = np.array([
    [ 0, 10, 20, 30, 40],
    [10, 20, 30, 40, 50],
    [20, 30, 40, 50, 60],
    [30, 40, 50, 60, 70],
])
X3 = np.array([
    [-4, -3, -2, -1,  0],
    [-5, -4, -3, -2, -1],
    [-6, -5, -4, -3, -2],
    [-7, -6, -5, -4, -3],
])
# (4, 5, 3)
X = np.dstack([X1, X2, X3])

# Reordering index
reorder_index = np.array([
    [3, 1, 0, 4, 2],
    [2, 4, 3, 0, 1],
    [3, 0, 4, 1, 2],
    [0, 2, 1, 4, 3],
])[..., None] # last dimension broadcastable to X

# Apply the reordering index
y = np.take_along_axis(X, reorder_index, axis=1)

# Examine the results one matrix at a time

In [47]:
y[:, :, 0]

array([[3, 1, 0, 4, 2],
       [3, 5, 4, 1, 2],
       [5, 2, 6, 3, 4],
       [3, 5, 4, 7, 6]])

In [48]:
y[:, :, 1]

array([[30, 10,  0, 40, 20],
       [30, 50, 40, 10, 20],
       [50, 20, 60, 30, 40],
       [30, 50, 40, 70, 60]])

In [49]:
y[:, :, 2]

array([[-1, -3, -4,  0, -2],
       [-3, -1, -2, -5, -4],
       [-3, -6, -2, -5, -4],
       [-7, -5, -6, -3, -4]])

batching

In [50]:
import numpy as np
X = np.array([[0, 1, 2, 3, 4]])[:, :, None] # (1, 5, 1)
reorder_index = np.array([
    [3, 1, 0, 4, 2],
    [2, 4, 3, 0, 1],
    [3, 0, 4, 1, 2],
    [0, 2, 1, 4, 3],
])[..., None] # (4, 5, 1)
# Apply the reordering index
y = np.take_along_axis(X, reorder_index, axis=1)
assert np.allclose(
    y.shape,
    (4, 5, 1)
)
assert np.allclose(
    y,
    np.array([[3, 1, 0, 4, 2],
              [2, 4, 3, 0, 1],
              [3, 0, 4, 1, 2],
              [0, 2, 1, 4, 3]])[:, :, None]
)

No broadcasting

In [51]:
import numpy as np
X1 = np.array([
    [0, 1, 2, 3, 4],
    [1, 2, 3, 4, 5],
    [2, 3, 4, 5, 6],
    [3, 4, 5, 6, 7],
])
X2 = np.array([
    [ 0, 10, 20, 30, 40],
    [10, 20, 30, 40, 50],
    [20, 30, 40, 50, 60],
    [30, 40, 50, 60, 70],
])
X3 = np.array([
    [-4, -3, -2, -1,  0],
    [-5, -4, -3, -2, -1],
    [-6, -5, -4, -3, -2],
    [-7, -6, -5, -4, -3],
])
# (4, 5, 3)
X = np.dstack([X1, X2, X3])

# Reordering index
index_1 = np.array([
    [3, 1, 0, 4, 2],
    [2, 4, 3, 0, 1],
    [3, 0, 4, 1, 2],
    [0, 2, 1, 4, 3],
])
index_2 = np.array([
    [1, 0, 2, 3, 4],
    [3, 0, 4, 1, 2],
    [4, 0, 1, 2, 3],
    [0, 1, 2, 3, 4],
])
index_3 = np.array([
    [3, 0, 1, 4, 2],
    [2, 0, 1, 3, 4],
    [2, 1, 4, 3, 0],
    [0, 2, 4, 1, 3],
])
# (4, 5, 3)
reorder_index = np.dstack([index_1, index_2, index_3])

# Apply the reordering index
y = np.take_along_axis(X, reorder_index, axis=1)

# Examine several slices

In [52]:
y[0, :, 0]

array([3, 1, 0, 4, 2])

In [53]:
y[1, :, 2]

array([-3, -5, -4, -2, -1])

In [54]:
y[2, :, 1]

array([60, 20, 30, 40, 50])

In [55]:
# Verify
for ii in range(X.shape[0]):
    for jj in range(X.shape[2]):
        assert np.allclose(
            np.take(X[ii, :, jj], reorder_index[ii, :, jj]),
            y[ii, :, jj],
        )

# Gather

In [57]:
import numpy as np
import tensorflow as tf
import torch
X = np.arange(10)
index = [3, 5]
y_np = np.take(X, index)
y_tf = tf.gather(X, index)
y_torch = torch.gather(
    torch.tensor(X), 0, torch.tensor(index)
)
y_np, y_tf, y_torch

(array([3, 5]),
 <tf.Tensor: shape=(2,), dtype=int64, numpy=array([3, 5])>,
 tensor([3, 5]))

In [59]:
import numpy as np
import tensorflow as tf
import torch
X = np.arange(20).reshape(4, 5)
index = np.array([1, 3])
y_np = np.take(X, index, axis=1)
y_tf = tf.gather(X, index, axis=1)
y_torch = torch.gather(
    torch.tensor(X), 
    0,
    torch.tensor(np.repeat([index], X.shape[1], axis=0).T)
)
y_np, y_tf, y_torch

(array([[ 1,  3],
        [ 6,  8],
        [11, 13],
        [16, 18]]),
 <tf.Tensor: shape=(4, 2), dtype=int64, numpy=
 array([[ 1,  3],
        [ 6,  8],
        [11, 13],
        [16, 18]])>,
 tensor([[ 5,  6,  7,  8,  9],
         [15, 16, 17, 18, 19]]))

In [60]:
y_torch_index = torch.gather(
    torch.tensor(X), 0, torch.tensor(index[:, None])
)
y_torch_index

tensor([[ 5],
        [15]])

In [61]:
import numpy as np
import tensorflow as tf
X = np.arange(30).reshape(3, 5, 2)
slice_index = [[0, 3], [1, 1]]
block_np = np.take(X, slice_index, axis=1)
block_tf = tf.gather(X, slice_index, axis=1)
assert np.allclose(
    block_np, block_tf
)

Take a long axis / dim

In [62]:
import numpy as np
import tensorflow as tf
import torch
X = np.random.randint(0, 10, size=(4, 3, 5))
reorder_index = np.random.randint(0, 5, size=(4, 3, 5))
y_np = np.take_along_axis(X, reorder_index, axis=2)
y_torch = torch.take_along_dim(
    torch.tensor(X), torch.tensor(reorder_index), axis=2
)
y_tf = tf.gather(X, reorder_index, axis=2, batch_dims=2)
y_np, y_torch, y_tf

(array([[[5, 5, 5, 5, 5],
         [3, 9, 3, 3, 3],
         [4, 3, 2, 1, 9]],
 
        [[6, 3, 6, 3, 4],
         [6, 1, 1, 1, 1],
         [9, 0, 7, 0, 7]],
 
        [[1, 1, 1, 3, 3],
         [7, 0, 0, 0, 4],
         [9, 8, 8, 3, 9]],
 
        [[8, 5, 1, 5, 1],
         [6, 1, 6, 7, 6],
         [5, 2, 2, 9, 2]]]),
 tensor([[[5, 5, 5, 5, 5],
          [3, 9, 3, 3, 3],
          [4, 3, 2, 1, 9]],
 
         [[6, 3, 6, 3, 4],
          [6, 1, 1, 1, 1],
          [9, 0, 7, 0, 7]],
 
         [[1, 1, 1, 3, 3],
          [7, 0, 0, 0, 4],
          [9, 8, 8, 3, 9]],
 
         [[8, 5, 1, 5, 1],
          [6, 1, 6, 7, 6],
          [5, 2, 2, 9, 2]]]),
 <tf.Tensor: shape=(4, 3, 5), dtype=int64, numpy=
 array([[[5, 5, 5, 5, 5],
         [3, 9, 3, 3, 3],
         [4, 3, 2, 1, 9]],
 
        [[6, 3, 6, 3, 4],
         [6, 1, 1, 1, 1],
         [9, 0, 7, 0, 7]],
 
        [[1, 1, 1, 3, 3],
         [7, 0, 0, 0, 4],
         [9, 8, 8, 3, 9]],
 
        [[8, 5, 1, 5, 1],
         [6, 1, 6, 7,

Broadcasting

In [63]:
import numpy as np
import tensorflow as tf
import torch
X = np.random.randint(0, 10, size=(4, 3, 5, 6))
# 3rd dimension as the broadcasting dimension
reorder_index = np.random.randint(0, 5, size=(4, 3, 1, 7))
# Returns a tensor of shape (4, 3, 5, 7)
y_np = np.take_along_axis(
    X, reorder_index, axis=3
)
y_torch = torch.take_along_dim(
    torch.tensor(X),
    torch.tensor(reorder_index),
    axis=3
)
y_tf = tf.gather(
    # taking out the singletone dimension
    X, reorder_index[..., 0, :], axis=3, batch_dims=2
)
y_np

array([[[[9, 9, 0, 6, 9, 6, 9],
         [9, 6, 6, 1, 6, 1, 9],
         [2, 2, 4, 8, 2, 8, 7],
         [5, 5, 1, 2, 5, 2, 5],
         [9, 6, 2, 3, 6, 3, 8]],

        [[2, 0, 8, 1, 8, 0, 2],
         [4, 5, 9, 9, 9, 5, 4],
         [4, 6, 7, 6, 7, 6, 4],
         [4, 3, 4, 1, 4, 3, 4],
         [1, 4, 5, 4, 5, 4, 1]],

        [[0, 0, 1, 1, 1, 0, 2],
         [2, 1, 3, 8, 3, 1, 1],
         [9, 4, 5, 4, 5, 4, 4],
         [5, 5, 6, 7, 6, 5, 3],
         [7, 9, 4, 8, 4, 9, 6]]],


       [[[8, 1, 8, 4, 4, 1, 1],
         [7, 8, 7, 3, 3, 8, 2],
         [7, 2, 7, 6, 6, 9, 6],
         [4, 3, 4, 3, 3, 9, 1],
         [7, 9, 7, 2, 2, 1, 7]],

        [[8, 5, 6, 8, 6, 6, 5],
         [1, 6, 4, 1, 4, 4, 6],
         [6, 1, 0, 6, 0, 0, 1],
         [4, 2, 2, 4, 2, 2, 2],
         [1, 2, 3, 1, 3, 3, 2]],

        [[9, 2, 2, 9, 9, 0, 2],
         [9, 8, 8, 8, 8, 0, 8],
         [0, 2, 2, 3, 3, 4, 2],
         [4, 3, 3, 6, 6, 7, 3],
         [2, 6, 6, 3, 3, 6, 6]]],


       [[[5, 0, 0, 8, 5,

In [64]:
y_torch

tensor([[[[9, 9, 0, 6, 9, 6, 9],
          [9, 6, 6, 1, 6, 1, 9],
          [2, 2, 4, 8, 2, 8, 7],
          [5, 5, 1, 2, 5, 2, 5],
          [9, 6, 2, 3, 6, 3, 8]],

         [[2, 0, 8, 1, 8, 0, 2],
          [4, 5, 9, 9, 9, 5, 4],
          [4, 6, 7, 6, 7, 6, 4],
          [4, 3, 4, 1, 4, 3, 4],
          [1, 4, 5, 4, 5, 4, 1]],

         [[0, 0, 1, 1, 1, 0, 2],
          [2, 1, 3, 8, 3, 1, 1],
          [9, 4, 5, 4, 5, 4, 4],
          [5, 5, 6, 7, 6, 5, 3],
          [7, 9, 4, 8, 4, 9, 6]]],


        [[[8, 1, 8, 4, 4, 1, 1],
          [7, 8, 7, 3, 3, 8, 2],
          [7, 2, 7, 6, 6, 9, 6],
          [4, 3, 4, 3, 3, 9, 1],
          [7, 9, 7, 2, 2, 1, 7]],

         [[8, 5, 6, 8, 6, 6, 5],
          [1, 6, 4, 1, 4, 4, 6],
          [6, 1, 0, 6, 0, 0, 1],
          [4, 2, 2, 4, 2, 2, 2],
          [1, 2, 3, 1, 3, 3, 2]],

         [[9, 2, 2, 9, 9, 0, 2],
          [9, 8, 8, 8, 8, 0, 8],
          [0, 2, 2, 3, 3, 4, 2],
          [4, 3, 3, 6, 6, 7, 3],
          [2, 6, 6, 3, 3, 6, 6]

In [65]:
y_tf

<tf.Tensor: shape=(4, 3, 5, 7), dtype=int64, numpy=
array([[[[9, 9, 0, 6, 9, 6, 9],
         [9, 6, 6, 1, 6, 1, 9],
         [2, 2, 4, 8, 2, 8, 7],
         [5, 5, 1, 2, 5, 2, 5],
         [9, 6, 2, 3, 6, 3, 8]],

        [[2, 0, 8, 1, 8, 0, 2],
         [4, 5, 9, 9, 9, 5, 4],
         [4, 6, 7, 6, 7, 6, 4],
         [4, 3, 4, 1, 4, 3, 4],
         [1, 4, 5, 4, 5, 4, 1]],

        [[0, 0, 1, 1, 1, 0, 2],
         [2, 1, 3, 8, 3, 1, 1],
         [9, 4, 5, 4, 5, 4, 4],
         [5, 5, 6, 7, 6, 5, 3],
         [7, 9, 4, 8, 4, 9, 6]]],


       [[[8, 1, 8, 4, 4, 1, 1],
         [7, 8, 7, 3, 3, 8, 2],
         [7, 2, 7, 6, 6, 9, 6],
         [4, 3, 4, 3, 3, 9, 1],
         [7, 9, 7, 2, 2, 1, 7]],

        [[8, 5, 6, 8, 6, 6, 5],
         [1, 6, 4, 1, 4, 4, 6],
         [6, 1, 0, 6, 0, 0, 1],
         [4, 2, 2, 4, 2, 2, 2],
         [1, 2, 3, 1, 3, 3, 2]],

        [[9, 2, 2, 9, 9, 0, 2],
         [9, 8, 8, 8, 8, 0, 8],
         [0, 2, 2, 3, 3, 4, 2],
         [4, 3, 3, 6, 6, 7, 3],
        

Expected output shape

In [68]:
inputs = np.random.rand(5, 3, 4, 8)
axis = 1
batch_dims = 1
np.concatenate([
    inputs.shape[:axis],
    index.shape[batch_dims:],
    inputs.shape[axis+1:]
]).astype(int)

array([5, 4, 8])

# N-dimensional Gather

In [69]:
import numpy as np
import tensorflow as tf
import torch
X = np.array([
    [1, 3, 2, 4],
    [6, 8, 7, 9],
])
# 2D indices
index = np.array([
    [0, 1],
    [0, 3],
    [1, 0],
    [1, 2],
])

# gather_nd
y_tf = tf.gather_nd(X, index)
# For numpy and torch, use indexing
y_np = X[index.T[0], index.T[1]]
y_torch = torch.tensor(X)[index.T[0], index.T[1]]
y_tf, y_np, y_torch

(<tf.Tensor: shape=(4,), dtype=int64, numpy=array([3, 4, 6, 7])>,
 array([3, 4, 6, 7]),
 tensor([3, 4, 6, 7]))

In [70]:
import tensorflow as tf
X = tf.constant([
    [1, 3, 2, 4],
    [6, 8, 7, 9],
])
# 2D indices
index = tf.constant([
    [0, 1],
    [0, 3],
    [1, 0],
    [1, 2],
])
# This throws an error
X[index[:, 0], index[:, 1]]

TypeError: Only integers, slices (`:`), ellipsis (`...`), tf.newaxis (`None`) and scalar tf.int32/tf.int64 tensors are valid indices, got <tf.Tensor: shape=(4,), dtype=int32, numpy=array([0, 0, 1, 1], dtype=int32)>

In [73]:
import tensorflow as tf
X = tf.random.uniform((4, 5, 6))
# Multi-indexing the first 2 dimensions, 
# using last dimension as the slice/block dimension
index = tf.constant([
    [0, 3],
    [1, 2],
    [3, 0],
    [2, 2],
]) #(4, 2)

# Gather 4 slices from the 3D tensor
y = tf.gather_nd(X, index) # (4, 6)
y

<tf.Tensor: shape=(4, 6), dtype=float32, numpy=
array([[0.73575985, 0.6488168 , 0.2867304 , 0.5872328 , 0.7951604 ,
        0.1250056 ],
       [0.3266816 , 0.29827225, 0.25954556, 0.35885704, 0.14033401,
        0.3662238 ],
       [0.5695604 , 0.853847  , 0.39001572, 0.6766658 , 0.4703468 ,
        0.03319669],
       [0.34298062, 0.0477289 , 0.22252929, 0.0397253 , 0.89708626,
        0.8542186 ]], dtype=float32)>

Multi-dimensional batch

In [75]:
import tensorflow as tf
X = tf.random.uniform((2, 3, 5, 4))
index = tf.experimental.numpy.random.randint(
    0, 4, (2, 3, 4, 2)
)
batch_dims=2
y = tf.gather_nd(X, index, batch_dims=batch_dims)
y

<tf.Tensor: shape=(2, 3, 4), dtype=float32, numpy=
array([[[0.9784707 , 0.6681063 , 0.46982932, 0.6225172 ],
        [0.6100161 , 0.88331604, 0.15604722, 0.64186656],
        [0.44745743, 0.9578006 , 0.721297  , 0.85305524]],

       [[0.5169152 , 0.10794866, 0.08310223, 0.75352144],
        [0.05132163, 0.85858   , 0.1718756 , 0.1718756 ],
        [0.45064688, 0.687662  , 0.687662  , 0.17509532]]], dtype=float32)>

# Set Values at Index

In [76]:
import numpy as np
import tensorflow as tf
import torch

# Python list
X_py = [1, 2, 3, 4]
X_py[3] = 100
X_py

[1, 2, 3, 100]

In [78]:
# NumPy
X_np = np.array([1, 2, 3, 4])
X_np[3] = 100
X_np

array([  1,   2,   3, 100])

In [79]:
# PyTorch
X_torch = torch.tensor([1, 2, 3, 4])
X_torch[3] = 100
X_torch

tensor([  1,   2,   3, 100])

In [80]:
# Tensorflow
X_tf = tf.constant([1, 2, 3, 4])
X_tf[3] = 100 # throws an error

TypeError: 'tensorflow.python.framework.ops.EagerTensor' object does not support item assignment

In [81]:
import numpy as np
X = np.array([1, 2, 3, 4])
# create a new variable that points to X
Y = X 
# modify X
X[3] = 100

In [82]:
X

array([  1,   2,   3, 100])

In [83]:
Y

array([  1,   2,   3, 100])

Multi-dimensional array with multi-index assignment

In [85]:
import numpy as np
X = np.arange(20).reshape(4, 5)
X

array([[ 0,  1,  2,  3,  4],
       [ 5,  6,  7,  8,  9],
       [10, 11, 12, 13, 14],
       [15, 16, 17, 18, 19]])

In [86]:
X[3, 2] -= 20 # 17 -> -3
X

array([[ 0,  1,  2,  3,  4],
       [ 5,  6,  7,  8,  9],
       [10, 11, 12, 13, 14],
       [15, 16, -3, 18, 19]])

Replacing slice

In [87]:
import numpy as np
X = np.arange(20).reshape(4, 5)
X

array([[ 0,  1,  2,  3,  4],
       [ 5,  6,  7,  8,  9],
       [10, 11, 12, 13, 14],
       [15, 16, 17, 18, 19]])

In [88]:
# replacing the last column with -99
X[:, -1] = -99
X

array([[  0,   1,   2,   3, -99],
       [  5,   6,   7,   8, -99],
       [ 10,  11,  12,  13, -99],
       [ 15,  16,  17,  18, -99]])

Broadcasting

In [89]:
import numpy as np
X = np.arange(20).reshape(4, 5)
X

array([[ 0,  1,  2,  3,  4],
       [ 5,  6,  7,  8,  9],
       [10, 11, 12, 13, 14],
       [15, 16, 17, 18, 19]])

In [90]:
# replacing the last two columns with -88, -99, respectively
X[:, -2:] = np.array([-88, -99])
X

array([[  0,   1,   2, -88, -99],
       [  5,   6,   7, -88, -99],
       [ 10,  11,  12, -88, -99],
       [ 15,  16,  17, -88, -99]])

Setting diagonal example

In [91]:
import numpy as np
X = np.arange(20).reshape(4, 5)
X

array([[ 0,  1,  2,  3,  4],
       [ 5,  6,  7,  8,  9],
       [10, 11, 12, 13, 14],
       [15, 16, 17, 18, 19]])

In [92]:
# Assign the diagonal entries with different values
rows = np.array([0, 1, 2, 3])
cols = np.array([0, 1, 2, 3])
values = np.array([-1, -2, -3, -4])
X[rows, cols] = values
X

array([[-1,  1,  2,  3,  4],
       [ 5, -2,  7,  8,  9],
       [10, 11, -3, 13, 14],
       [15, 16, 17, -4, 19]])

# Put

In [94]:
import numpy as np
X = np.zeros(10, dtype=int)
indices = np.array([2, 4, 6, 8])
values = np.array([1, 2, 3, 4])
# in-place operation, modifies X directly
np.put(X, indices, values)
X

array([0, 0, 1, 0, 2, 0, 3, 0, 4, 0])

In [95]:
import numpy as np
X = np.zeros((3, 3), dtype=int)
indices = np.array([1, 2, 5, 8])
values = np.array([1, 2, 3, 4])
np.put(X, indices, values)
X

array([[0, 1, 2],
       [0, 0, 3],
       [0, 0, 4]])

In [96]:
import numpy as np
X = np.zeros((3, 3), dtype=int)
indices = np.array([[1, 2], [5, 8]])
values = np.array([1, 2, 3, 4])
np.put(X, indices, values)
X

array([[0, 1, 2],
       [0, 0, 3],
       [0, 0, 4]])

# Put along axis

In [99]:
import numpy as np
X = np.zeros((3, 4), dtype=int)
indices = np.array([
    [0, 2],
    [1, 3],
    [2, 1]
])
# (3, 1) broadcastable to (3, 2)
# this will reuse the value along the rows
values = np.array([
    [1],
    [2],
    [3],
])

np.put_along_axis(X, indices, values, axis=1)
X

array([[1, 0, 1, 0],
       [0, 2, 0, 2],
       [0, 3, 3, 0]])

# Multi-index Scatter Replacement

In [100]:
import tensorflow as tf
# [row, col]
indices = tf.constant([
    [0, 2],
    [1, 3],
    [2, 0],
    [3, 1],
    [4, 2],
])
values = tf.constant([1, 3, 5, 7, 9])
shape = (5, 4)
X_tf = tf.scatter_nd(indices, values, shape)
X_tf

<tf.Tensor: shape=(5, 4), dtype=int32, numpy=
array([[0, 0, 1, 0],
       [0, 0, 0, 3],
       [5, 0, 0, 0],
       [0, 7, 0, 0],
       [0, 0, 9, 0]], dtype=int32)>

Alternative ways

In [101]:
import numpy as np
import tensorflow as tf
import torch
# [row, col]
indices = tf.constant([
    [0, 2],
    [1, 3],
    [2, 0],
    [3, 1],
    [4, 2],
])
values = tf.constant([1, 3, 5, 7, 9])
shape = (5, 4)
X_tf = tf.zeros(shape, dtype=values.dtype)
X_tf = tf.tensor_scatter_nd_update(
    X_tf, indices, values
)

# equivalently in NumPy
X_np = np.zeros(shape, dtype=int)
X_np[tuple(indices.numpy().T)] = values
assert np.allclose(X_tf, X_np)

# 2 equivalent operations in PyTorch
# Converting to PyTorch dtypes
indices = torch.tensor(indices.numpy())
values = torch.tensor(values.numpy(), dtype=int)
# Method 1: Using indexed assignment like NumPy
X_torch = torch.zeros(shape, dtype=int)
X_torch[tuple(indices.T)] = values
assert np.allclose(X_tf, X_torch)

# Method 2: Using torch.index_put
X_torch = torch.zeros(shape, dtype=int)
X_torch = torch.index_put(X_torch, tuple(indices.T), values)

# Method 2 In-place: Using torch.Tensor.index_put_
X_torch = torch.zeros(shape, dtype=int)
# Also returns the modified tensor itself
X_torch.index_put_(tuple(indices.T), values)

tensor([[0, 0, 1, 0],
        [0, 0, 0, 3],
        [5, 0, 0, 0],
        [0, 7, 0, 0],
        [0, 0, 9, 0]])

Duplicated entries

In [107]:
import numpy as np
import tensorflow as tf
import torch
indices = torch.tensor([
    [0, 2],
    [1, 3],
    [2, 0],
    [3, 1],
    [4, 2],
    [4, 2],
])
values = torch.tensor([1, 3, 5, 7, 9, 11])
shape = (5, 4)

# NumPy:
X_np = np.zeros(shape, dtype=int)
X_np[tuple(indices.T)] = values

# Tensorflow:
X_tf = tf.zeros(shape, dtype=tf.int32)
X_tf = tf.tensor_scatter_nd_update(X_tf, indices, values)

# PyTorch:
X_torch = torch.zeros(shape, dtype=int)
X_torch = torch.index_put(
    X_torch, tuple(indices.T), values, accumulate=False
)

assert np.allclose(X_np, X_tf)
assert np.allclose(X_tf, X_torch)

X_np

array([[ 0,  0,  1,  0],
       [ 0,  0,  0,  3],
       [ 5,  0,  0,  0],
       [ 0,  7,  0,  0],
       [ 0,  0, 11,  0]])

In [108]:
# Accumulate via summation 
# NumPy:
X_np = np.zeros(shape, dtype=int)
np.add.at(X_np, tuple(indices.T), values) # in-place
# Alternatively,
# X_np[tuple(np.array(indices).T)] += np.array(values)

# Tensorflow:
X_tf = tf.zeros(shape, dtype=tf.int32)
X_tf = tf.tensor_scatter_nd_add(X_tf, indices, values)
X_tf_alt = tf.scatter_nd(indices, values, shape)

# PyTorch:
X_torch = torch.zeros(shape, dtype=int)
X_torch = torch.index_put(
    X_torch, tuple(indices.T), values, accumulate=True
)

assert np.allclose(X_np, X_tf)
assert np.allclose(X_tf, X_tf_alt)
assert np.allclose(X_tf, X_torch)

X_torch

tensor([[ 0,  0,  1,  0],
        [ 0,  0,  0,  3],
        [ 5,  0,  0,  0],
        [ 0,  7,  0,  0],
        [ 0,  0, 20,  0]])

Tensorflow scattering 1D

In [110]:
import tensorflow as tf
indices = tf.constant([0, 2, 4])[:, None]
values = tf.constant([1, 3, 5])
shape = (5, )
X = tf.scatter_nd(indices, values, shape)
X

<tf.Tensor: shape=(5,), dtype=int32, numpy=array([1, 0, 3, 0, 5], dtype=int32)>

PyTorch scattering 1D

In [111]:
import torch
indices = torch.tensor([0, 2, 4])
values = torch.tensor([1, 3, 5])
shape = (5, )
X_torch = torch.zeros(shape, dtype=int)
X_torch = torch.index_put(X_torch, (indices, ), values)
X_torch

tensor([1, 0, 3, 0, 5])

Tensorflow scattering slicesm

In [112]:
import tensorflow as tf
values = tf.random.uniform((2, 3, 4))
indices = tf.constant([3, 0])[:, None]
X = tf.scatter_nd(indices, values, shape=(5, 3, 4))
X

<tf.Tensor: shape=(5, 3, 4), dtype=float32, numpy=
array([[[0.5854825 , 0.6443523 , 0.2271676 , 0.77668583],
        [0.181005  , 0.5815594 , 0.8054135 , 0.31492925],
        [0.08400917, 0.6241863 , 0.73857045, 0.40204942]],

       [[0.        , 0.        , 0.        , 0.        ],
        [0.        , 0.        , 0.        , 0.        ],
        [0.        , 0.        , 0.        , 0.        ]],

       [[0.        , 0.        , 0.        , 0.        ],
        [0.        , 0.        , 0.        , 0.        ],
        [0.        , 0.        , 0.        , 0.        ]],

       [[0.8988935 , 0.32361162, 0.713194  , 0.03325117],
        [0.34480882, 0.4378265 , 0.5213853 , 0.306126  ],
        [0.5065043 , 0.9440452 , 0.08750558, 0.40185368]],

       [[0.        , 0.        , 0.        , 0.        ],
        [0.        , 0.        , 0.        , 0.        ],
        [0.        , 0.        , 0.        , 0.        ]]], dtype=float32)>

Higher dimensional scattering

In [114]:
import tensorflow as tf
values = tf.random.uniform((2, 3, 4))
indices = tf.constant([[1, 0], [0, 1]]) # (2, 2)
X = tf.scatter_nd(indices, values, shape=(2, 2, 3, 4))
X

<tf.Tensor: shape=(2, 2, 3, 4), dtype=float32, numpy=
array([[[[0.        , 0.        , 0.        , 0.        ],
         [0.        , 0.        , 0.        , 0.        ],
         [0.        , 0.        , 0.        , 0.        ]],

        [[0.29027224, 0.57536054, 0.49918687, 0.47940528],
         [0.42972338, 0.05375266, 0.56053984, 0.14314198],
         [0.57762897, 0.21365106, 0.14313734, 0.57539463]]],


       [[[0.5233065 , 0.65773225, 0.97701263, 0.30892646],
         [0.11707819, 0.5031285 , 0.97141457, 0.20344663],
         [0.58398056, 0.05704939, 0.38306272, 0.41080678]],

        [[0.        , 0.        , 0.        , 0.        ],
         [0.        , 0.        , 0.        , 0.        ],
         [0.        , 0.        , 0.        , 0.        ]]]],
      dtype=float32)>

Batching with scattering

In [115]:
import tensorflow as tf
values = tf.random.uniform((4, 3, 4))
indices = tf.constant([
    [1, 0],
    [0, 1],
    [0, 0],
    [2, 2],
]) # (4, 2)
shape = (3, 3, 3, 4)
X = tf.scatter_nd(indices, values, shape=shape)

# fold indices and values into a 2 x 2 grid -> (2, 2, 2)
folded_indices = tf.reshape(indices, (2, 2, 2))
folded_values = tf.reshape(values, (2, 2, 3, 4))
X_folded = tf.scatter_nd(indices, values, shape=shape)

# same output results
assert np.allclose(X, X_folded)
X

<tf.Tensor: shape=(3, 3, 3, 4), dtype=float32, numpy=
array([[[[0.22356176, 0.46682   , 0.78744555, 0.6579579 ],
         [0.03612351, 0.79817665, 0.82357204, 0.4398501 ],
         [0.92811894, 0.27770078, 0.99358535, 0.6468997 ]],

        [[0.81874   , 0.2970785 , 0.19359171, 0.5107218 ],
         [0.9423139 , 0.8019619 , 0.11886549, 0.6156517 ],
         [0.63097227, 0.26805937, 0.5931915 , 0.4372431 ]],

        [[0.        , 0.        , 0.        , 0.        ],
         [0.        , 0.        , 0.        , 0.        ],
         [0.        , 0.        , 0.        , 0.        ]]],


       [[[0.5076268 , 0.9843929 , 0.18080878, 0.29653823],
         [0.8087219 , 0.66494715, 0.71060205, 0.36653388],
         [0.75558305, 0.15773058, 0.6330824 , 0.7372726 ]],

        [[0.        , 0.        , 0.        , 0.        ],
         [0.        , 0.        , 0.        , 0.        ],
         [0.        , 0.        , 0.        , 0.        ]],

        [[0.        , 0.        , 0.        , 0. 

# PyTorch scattering

Example 1: Scattering along `dim=0`

In [116]:
import torch
values = torch.tensor([
    [3, 8, 2, 1, 4],
    [5, 4, 7, 6, 2],
    # [1, 3, 2, 4, 6], this extra row will be ignored
], dtype=int)
indices = torch.tensor([
    [1, 2, 3, 0, 1],
    [2, 1, 0, 3, 2],
])

# Scatter along dim=0
X = torch.zeros((4, 5), dtype=int)
y = torch.scatter_reduce(
    input=X, dim=0, index=indices, src=values, reduce="sum"
)
y

tensor([[0, 0, 7, 1, 0],
        [3, 4, 0, 0, 4],
        [5, 8, 0, 0, 2],
        [0, 0, 2, 6, 0]])

In [117]:
# Not exactly matching the full columns:
# indices will be left aligned
y = torch.scatter_reduce(
    input=X,
    dim=0,
    index=indices[:, :-1],
    src=values[:, :-1],
    reduce="sum",
)
y

tensor([[0, 0, 7, 1, 0],
        [3, 4, 0, 0, 0],
        [5, 8, 0, 0, 0],
        [0, 0, 2, 6, 0]])

Example 2: Scattering along `dim=1`

In [118]:
import torch
values = torch.tensor([
    [3, 7],
    [8, 6],
    [4, 2],
    [5, 9],
], dtype=int)
indices = torch.tensor([
    [2, 4],
    [0, 3],
    [4, 0],
    [2, 1],
])

# Scatter along dim=1
X = torch.zeros((4, 5), dtype=int)
y = torch.scatter_reduce(
    input=X, dim=1, index=indices, src=values, reduce="sum"
)
y

tensor([[0, 0, 3, 0, 7],
        [8, 0, 0, 6, 0],
        [2, 0, 0, 0, 4],
        [0, 9, 5, 0, 0]])

In [119]:
# Not exactly matching the full columns: 
# indices will be left aligned
y = torch.scatter_reduce(
    input=X,
    dim=1,
    index=indices[:-1, :],
    src=values[:-1, :],
    reduce="sum",
)
y

tensor([[0, 0, 3, 0, 7],
        [8, 0, 0, 6, 0],
        [2, 0, 0, 0, 4],
        [0, 0, 0, 0, 0]])

`torch.index_add` example

In [120]:
import torch
values = torch.rand((2, 3, 4), dtype=float)
indices = torch.tensor([3, 0])
X = torch.zeros((5, 3, 4), dtype=float)
X = torch.index_add(X, 0, indices, values)
X

tensor([[[0.8246, 0.4835, 0.3055, 0.7646],
         [0.7673, 0.5859, 0.9572, 0.9006],
         [0.7628, 0.1679, 0.4563, 0.6049]],

        [[0.0000, 0.0000, 0.0000, 0.0000],
         [0.0000, 0.0000, 0.0000, 0.0000],
         [0.0000, 0.0000, 0.0000, 0.0000]],

        [[0.0000, 0.0000, 0.0000, 0.0000],
         [0.0000, 0.0000, 0.0000, 0.0000],
         [0.0000, 0.0000, 0.0000, 0.0000]],

        [[0.6473, 0.6150, 0.1845, 0.2119],
         [0.3595, 0.3015, 0.9837, 0.0020],
         [0.6153, 0.9747, 0.0458, 0.9802]],

        [[0.0000, 0.0000, 0.0000, 0.0000],
         [0.0000, 0.0000, 0.0000, 0.0000],
         [0.0000, 0.0000, 0.0000, 0.0000]]], dtype=torch.float64)

# Case Study: Batch-wise Scatter Values

In [126]:
import tensorflow as tf
import torch

index = tf.constant([
    [1, 3, 6, 2, 4],
    [2, 4, 8, 9, 5],
    [6, 5, 0, 1, 3],
    [3, 7, 2, 5, 4],    
], dtype=tf.int64)
inputs = tf.constant([
    [2, 5, 6, 2, 8],
    [6, 1, 7, 0, 4],    
    [9, 5, 4, 5, 1],    
    [0, 6, 8, 3, 7],    
])
output = torch.scatter(
    torch.zeros((4, 10), dtype=int),
    1, 
    torch.tensor(index.numpy(), dtype=int), 
    torch.tensor(inputs.numpy(), dtype=int),
)
output

tensor([[0, 2, 2, 5, 8, 0, 6, 0, 0, 0],
        [0, 0, 6, 0, 1, 4, 0, 0, 7, 0],
        [4, 5, 0, 1, 0, 5, 9, 0, 0, 0],
        [0, 0, 8, 0, 7, 3, 0, 6, 0, 0]])

In [127]:
import tensorflow as tf

def scatter_batch_2d(index, inputs):
    # Determines the length of each row for the output
    row_length = tf.cast(tf.reduce_max(index), tf.int32) + 1
    row_length = tf.maximum(tf.shape(inputs)[1], row_length)
    # start and end index of each row
    # row 0: inc[0]:inc[1], row 1: inc[1]:inc[2], ...
    inc = row_length * tf.range(tf.shape(inputs)[0]+1)
    # Compute flat indices for each element
    indices = tf.cast(index, inc.dtype) + tf.reshape(inc[:-1], (-1, 1))
    # Flatten to column vector
    indices = tf.reshape(indices, (-1, 1))
    # Flatten updates
    updates = tf.reshape(inputs, (-1,))
    shape = (inc[-1], )
    out = tf.scatter_nd(indices, updates, shape)
    # reshape: (batch_size, row_length)
    out = tf.reshape(out, (tf.shape(index)[0], -1))
    return out

In [128]:
out_tf = scatter_batch_2d(index, inputs)
out_tf

<tf.Tensor: shape=(4, 10), dtype=int32, numpy=
array([[0, 2, 2, 5, 8, 0, 6, 0, 0, 0],
       [0, 0, 6, 0, 1, 4, 0, 0, 7, 0],
       [4, 5, 0, 1, 0, 5, 9, 0, 0, 0],
       [0, 0, 8, 0, 7, 3, 0, 6, 0, 0]], dtype=int32)>

In [129]:
np.allclose(out_tf, output)

True