In [1]:
import jax
import jax.numpy as jnp

In [2]:
fish_position = jnp.array([3, 5])


def is_fish_pixel(pos, center=None, radius=1):
    """Return whether ``pos`` lies inside the square of cells covered by a fish.

    A cell belongs to the fish when both its row and its column are at most
    ``radius`` cells away from ``center`` (Chebyshev distance, inclusive). With
    the defaults this is the 3x3 block centred on ``fish_position``.

    Args:
        pos: (row, col) of the cell to test, or an array of such pairs with the
            coordinates on the last axis.
        center: (row, col) of the fish; defaults to the module-level
            ``fish_position``.
        radius: half-width of the square; ``1`` gives the 3x3 block.

    Returns:
        Boolean JAX array: a scalar for a single position, one flag per
        position for a batch.
    """
    if center is None:
        center = fish_position
    offset = jnp.abs(jnp.asarray(pos) - jnp.asarray(center))
    # Inside the square <=> every coordinate offset is within the radius.
    return jnp.all(offset <= radius, axis=-1)

In [None]:
is_fish_pixel(pos=jnp.array([2, 6]))

In [None]:
# Scatter 1s onto the diagonal cells (1, 1) and (2, 2) of an empty 5x5 grid.
aa = jnp.zeros((5, 5)).at[(1, 2), (1, 2)].set(1)
print(aa)

In [None]:
@jax.jit
def set_value(arr: jnp.ndarray, pos: jnp.ndarray):
    """Set the 3x3 neighbourhood of ``pos`` in the 2-D grid ``arr`` to 1.

    Neighbours falling outside the grid are skipped on every side. A bare
    negative index would otherwise wrap around to the opposite edge, while an
    index equal to the size would be dropped.

    Args:
        arr: 2-D array to update (returned as a new array).
        pos: (row, col) integer centre of the neighbourhood.

    Returns:
        Copy of ``arr`` with the in-bounds neighbourhood cells set to 1.
    """
    offsets = jnp.array([[dr, dc] for dr in (-1, 0, 1) for dc in (-1, 0, 1)])
    neighbours = pos[jnp.newaxis, :] + offsets  # (9, 2)
    rows, cols = neighbours[:, 0], neighbours[:, 1]

    height, width = arr.shape
    in_bounds = (rows >= 0) & (rows < height) & (cols >= 0) & (cols < width)
    # Push out-of-bounds cells past the end so mode="drop" discards them
    # instead of letting negative indices wrap.
    rows = jnp.where(in_bounds, rows, height)
    cols = jnp.where(in_bounds, cols, width)

    # One vectorised scatter replaces the original 9-step fori_loop.
    return arr.at[rows, cols].set(1, mode="drop")

# Start from an empty grid so the 3x3 block written by set_value is visible;
# on an all-ones grid the call changes nothing.
arr = jnp.zeros((6, 6))
pos = jnp.array([2, 3])

arr = set_value(arr, pos)
print(arr)

In [6]:
@jax.jit
def copy(arr: jnp.ndarray):
    """Return ``jnp.array(arr)`` computed inside a jitted function."""
    return jnp.array(arr)

arr = copy(jnp.zeros((5, 5)))

In [None]:
# from functools import partial

# @partial(jax.jit, static_argnames="n")
def get_position(arr: jnp.ndarray, key, n, replace=True):
    """Sample ``n`` coordinates of zero-valued cells of ``arr``.

    Args:
        arr: array to search; each returned row is an index into it.
        key: JAX PRNG key.
        n: number of positions to draw.
        replace: passed to ``jax.random.choice``. The default ``True`` may
            return the same cell more than once; ``False`` draws distinct
            cells and requires ``n`` to be at most the number of zero cells.

    Returns:
        Integer array of shape ``(n, arr.ndim)``.

    Not jit-compatible as written: ``jnp.argwhere`` without ``size=`` has a
    data-dependent output shape.
    """
    # Same rows, in the same row-major order, as jnp.array(jnp.where(arr == 0)).T.
    zero_cells = jnp.argwhere(arr == 0)
    return jax.random.choice(key, zero_cells, (n,), replace=replace)

seed = 0
key = jax.random.key(seed)

# Mark (1, 1) and (2, 2) as occupied, then sample three free cells.
arr = jnp.zeros((5, 5)).at[1, 1].set(1).at[2, 2].set(1)
pos = get_position(arr, key, 3)
print(pos)

In [21]:
@jax.jit
def pick_one_zero(arr, key):
    """Pick a uniformly random (row, col) of a zero-valued cell of ``arr``.

    The cells of the 2-D ``arr`` are visited in the order given by
    ``jax.random.permutation(key, height * width)``. The first zero cell in
    that order is returned, so every zero cell is equally likely.

    Args:
        arr: 2-D array to search.
        key: JAX PRNG key controlling the visiting order.

    Returns:
        int32 array of shape (2,) holding (row, col). If ``arr`` has no zero
        cell, the first cell of the permutation is returned; it is NOT zero,
        so check ``arr[pos]`` when the grid can be full.
    """
    height, width = arr.shape

    # Flat cell ids in random order; id k is cell (k // width, k % width).
    order = jax.random.permutation(key, height * width)
    rows, cols = order // width, order % width

    is_zero = (arr[rows, cols] == 0).astype(jnp.int32)
    # argmax returns the first maximal entry, i.e. the first zero in permuted
    # order. Unlike a while_loop it always terminates, even when nothing is zero.
    first = jnp.argmax(is_zero)
    return jnp.stack([rows[first], cols[first]])


seed = 19
key = jax.random.key(seed)

# Occupy the three diagonal cells, then ask for a random free one.
food_positions = jnp.array([[1, 1], [2, 2], [3, 3]])
food_map = jnp.zeros((5, 5)).at[tuple(food_positions.T)].set(1)
print(food_map)

pick_one_zero(food_map, key)

[[0. 0. 0. 0. 0.]
 [0. 1. 0. 0. 0.]
 [0. 0. 1. 0. 0.]
 [0. 0. 0. 1. 0.]
 [0. 0. 0. 0. 0.]]


Array([0, 4], dtype=int32)

In [22]:

# @jax.jit
def update(food_map, food_positions, is_eaten, key):
    """Move every eaten food item to a randomly chosen free cell.

    Items are processed in index order. For each item, ``key`` is split into a
    new carry key and a subkey. When ``is_eaten[i]`` is true, ``pick_one_zero``
    draws a zero cell of the running food map using that subkey. The cell is
    stored in row ``i`` of the positions and set to 1 in the map, so later
    items cannot land on it.

    Args:
        food_map: 2-D grid; zero cells are treated as free.
        food_positions: per-item (row, col) rows, indexed by item.
        is_eaten: boolean flag per food item.
        key: JAX PRNG key.

    Returns:
        The updated ``food_positions``. The updated map is discarded.
    """
    def respawn(operand):
        grid, positions, subkey, i = operand
        new_cell = pick_one_zero(arr=grid, key=subkey)
        positions = positions.at[i].set(new_cell)
        grid = grid.at[new_cell[0], new_cell[1]].set(1)
        return grid, positions

    def keep(operand):
        return operand[0], operand[1]

    def body(i, state):
        grid, positions, rng = state
        rng, subkey = jax.random.split(rng)
        grid, positions = jax.lax.cond(
            is_eaten[i], respawn, keep, (grid, positions, subkey, i)
        )
        return grid, positions, rng

    _, food_positions, _ = jax.lax.fori_loop(
        0, len(is_eaten), body, (food_map, food_positions, key)
    )
    return food_positions


seed = 0
key = jax.random.key(seed)

# Three food items on the diagonal; items 0 and 2 have just been eaten.
food_positions = jnp.array([[1, 1], [2, 2], [3, 3]])
food_map = jnp.zeros((5, 5)).at[tuple(food_positions.T)].set(1)
is_eaten = jnp.array([True, False, True])

print(food_map)

food_positions = update(food_map, food_positions, is_eaten, key)

print(food_positions)

[[0. 0. 0. 0. 0.]
 [0. 1. 0. 0. 0.]
 [0. 0. 1. 0. 0.]
 [0. 0. 0. 1. 0.]
 [0. 0. 0. 0. 0.]]
[[1 0]
 [2 2]
 [4 0]]


In [None]:
def a(x):
    """Return the first two items of ``x`` as a tuple."""
    first, second = x[0], x[1]
    return first, second


aa = (3, [4, 5], {6})
bb = a(aa)
print(bb)

In [18]:
import jax
import jax.numpy as jnp
from jaxfish.data_classes import frozen, MINIMUM_BRAIN
from functools import partial


@partial(jax.jit, static_argnames=("brain", "length"))
def get_psp_history(brain, length):
    """Build a (num_neurons, length) array of each neuron's baseline rate.

    Row ``k`` repeats ``brain.neurons[k].baseline_rate`` ``length`` times.
    ``brain`` and ``length`` are static jit arguments, so ``brain`` must be
    hashable.
    """
    rates = jnp.array([neuron.baseline_rate for neuron in brain.neurons])
    num_neurons = len(brain.neurons)
    template = jnp.ones((num_neurons, length), dtype=float)
    # Broadcast the per-neuron rate column across all time steps.
    return template * rates[:, jnp.newaxis]


from copy import deepcopy

# Work on a private copy. `brain = MINIMUM_BRAIN` would only alias the shared
# module-level template, so the assignment below would silently change it
# for every other user in this kernel.
# NOTE(review): assumes MINIMUM_BRAIN is deep-copyable (e.g. a plain dataclass) - confirm.
brain = deepcopy(MINIMUM_BRAIN)
brain.neurons[1].baseline_rate = 1.
brain = frozen(brain)
psp_history = get_psp_history(brain, 1000)
print(psp_history)



[[0. 0. 0. ... 0. 0. 0.]
 [1. 1. 1. ... 1. 1. 1.]]


In [7]:
import jax
import jax.numpy as jnp

@jax.jit
def last_one_index(arr):
    """Return the index of the last element of 1-D ``arr`` equal to 1.

    Returns -1 (int32) when no element equals 1, including for an empty array.
    """
    positions = jnp.arange(arr.shape[0])
    candidates = jnp.where(arr == 1, positions, -1)
    # initial=-1 makes the reduction defined for empty input; a plain max has
    # no identity there and raises.
    return jnp.max(candidates, initial=-1).astype(jnp.int32)

aa = jnp.array([0, 1, 0, 1, 1, 0, 0, 0])
print(last_one_index(arr=aa))

4


In [8]:
# Setting a boolean into a float array stores it as 1.0.
aa = jnp.zeros(5).at[2].set(True)

print(aa)

[0. 0. 1. 0. 0.]


In [7]:
import jax
import jax.numpy as jnp

aa = jnp.arange(0, 10)
False * aa


Array([0, 0, 0, 0, 0, 0, 0, 0, 0, 0], dtype=int32)