In [9]:
import numpy as np
import jax
from jax import numpy as jnp

**Can't do update in places like numpy**

In [6]:
nparr = np.zeros((3,3), dtype=np.float32)
print(nparr)
nparr[1, :]=1
print(nparr)

# Update the second row to 1s.
jaxarr = jnp.zeros((3,3), dtype=jnp.float32)
print(jaxarr)
# Throw error.
# jaxarr[1, :]=1

[[0. 0. 0.]
 [0. 0. 0.]
 [0. 0. 0.]]
[[0. 0. 0.]
 [1. 1. 1.]
 [0. 0. 0.]]
[[0. 0. 0.]
 [0. 0. 0.]
 [0. 0. 0.]]


**Jax Numpy behaves differently from Numpy in __iadd__ style update**


In [7]:
nparr = np.array([10, 20])
print(nparr)
print(nparr)

jnparr = jnp.array([10, 20])
jnparr_up = jnparr
print(jnparr)
jnparr_up += 10
# Unchanged
print(jnparr)
# Updated
print(jnparr_up)
# A copy was created so they are not the same object.
print(id(jnparr) == id(jnparr_up))


[10 20]
[10 20]
[10 20]
[10 20]
[20 30]
False


**JNP array update array.at[idx].set(value)**

In [8]:
jnparray = jnp.zeros((3,3), dtype=jnp.float32)
print("Initial jnp array =", jnparray)
jnparray_up = jnparray.at[1,:].set(1.0)
print("Updated jnp array=", jnparray_up)
# Note the original jnp array isn't changed.
# print("Original jnp array after update:", jnparray)
print(id(jnparray) == id(jnparray_up))

Initial jnp array = [[0. 0. 0.]
 [0. 0. 0.]
 [0. 0. 0.]]
Updated jnp array= [[0. 0. 0.]
 [1. 1. 1.]
 [0. 0. 0.]]
Original jnp array after update: [[0. 0. 0.]
 [0. 0. 0.]
 [0. 0. 0.]]
False


**Array Index update in JIT**

In [13]:
@jax.jit
def update_second_row(x):
  ret_array = x.at[1, :].set(1)
  print(id(x) == id(ret_array))
  return ret_array

print(update_second_row(jnp.zeros([3,3], dtype=jnp.float32)))

False
[[0. 0. 0.]
 [1. 1. 1.]
 [0. 0. 0.]]
