In [13]:
import jax.numpy as jnp
import matplotlib.pyplot as plt

# 3x3 matrix with -1e6 strictly below the diagonal and zeros on/above it.
# The diagonal displays as -0. because it is the negation of (1 - 1) = 0.
jnp.tril(-(1.0 - jnp.eye(3)) * 1e6)

Array([[      -0.,        0.,        0.],
       [-1000000.,       -0.,        0.],
       [-1000000., -1000000.,       -0.]], dtype=float32)

In [16]:
# Check which JAX dtype Python's builtin `int` maps to (the output shows int32).
jnp.ones((3,), dtype=int)

Array([1, 1, 1], dtype=int32)

In [17]:
def int_to_binary_array(x, num_bits):
    """
    Convert integers to fixed-width binary arrays, most significant bit first.

    Each bit is extracted as ``(x >> k) & 1`` for k = num_bits-1, ..., 1, 0.
    Shifting the value (instead of masking with ``2**k``) avoids materialising
    powers of two, which overflow the default int32 dtype once ``num_bits``
    reaches 32.

    Works for any input shape: a 1-D array gives a 2-D result (one row per
    integer) and a scalar gives a 1-D result. This means the function can also
    be mapped with ``jax.vmap`` over scalar elements.

    Parameters:
    - x: An integer scalar or array (anything ``jnp.asarray`` accepts) holding
      the numbers to convert.
    - num_bits: Integer, the fixed number of bits in each binary representation.

    Returns:
    - An int32 Jax array of shape ``x.shape + (num_bits,)``. The last axis holds
      the bits of the corresponding integer, most significant bit first.
      Negative values give their two's-complement bits (for ``num_bits`` up to
      the bit width of ``x``'s dtype).
    """
    x = jnp.asarray(x)

    # Bit positions from most to least significant: [num_bits-1, ..., 1, 0]
    shifts = jnp.arange(num_bits - 1, -1, -1)

    # Add a trailing axis so every value is shifted by every bit position
    # (works for scalars and N-D inputs alike), then keep only the lowest bit.
    bits = (x[..., None] >> shifts) & 1

    return bits.astype(jnp.int32)  # Ensure a consistent integer result dtype


In [27]:
# 4-bit, most-significant-bit-first encodings of the integers 1, 2, 3, 4.
int_to_binary_array(x=jnp.array([1, 2, 3, 4]), num_bits=4)

Array([[0, 0, 0, 1],
       [0, 0, 1, 0],
       [0, 0, 1, 1],
       [0, 1, 0, 0]], dtype=int32)

In [26]:
# Split the 4x4 bit matrix into 2x2 tiles and flatten each tile (row-major)
# into one output row: tile (i, j) becomes output row 2*i + j.
(
    int_to_binary_array(jnp.array([1, 2, 3, 4]), 4)
    .reshape(2, 2, 2, 2)
    .transpose(0, 2, 1, 3)
    .reshape(4, 4)
)

Array([[0, 0, 0, 0],
       [0, 1, 1, 0],
       [0, 0, 0, 1],
       [1, 1, 0, 0]], dtype=int32)