In [1]:
import jax.numpy as jnp
import matplotlib.pyplot as plt

jnp.tril(-(jnp.ones((3,3))-jnp.eye(3))*1e6)

Array([[      -0.,        0.,        0.],
       [-1000000.,       -0.,        0.],
       [-1000000., -1000000.,       -0.]], dtype=float32)

In [2]:
def int_to_binary_array(x, num_bits):
    """
    Converts an array of integers to their binary representation arrays with a fixed number of bits.
    This function is designed to be compatible with Jax's vmap for vectorization over an array of integers.

    Parameters:
    - x: An array of integers, the numbers to convert.
    - num_bits: Integer, the fixed number of bits for the binary representation.

    Returns:
    - A 2D Jax array where each row is the binary representation of an integer in 'x'.
    """
    # Create an array of bit positions: [2^(num_bits-1), 2^(num_bits-2), ..., 1]
    powers_of_two = 2 ** jnp.arange(num_bits - 1, -1, -1)

    # Expand dims of x and powers_of_two for broadcasting
    x_expanded = x[:, None]
    powers_of_two_expanded = powers_of_two[None, :]

    # Perform bitwise AND between each number and each power of two, then right shift to get the bit value
    binary_matrix = (x_expanded & powers_of_two_expanded) >> jnp.arange(num_bits - 1, -1, -1)

    return binary_matrix.astype(jnp.int32)  # Ensure the result is integer

In [5]:
int_to_binary_array(jnp.ones(3, dtype = int ), 4).shape

(3, 4)

In [27]:
int_to_binary_array(jnp.array([1,2,3,4]), 4)

Array([[0, 0, 0, 1],
       [0, 0, 1, 0],
       [0, 0, 1, 1],
       [0, 1, 0, 0]], dtype=int32)

In [26]:
jnp.transpose(int_to_binary_array(jnp.array([1,2,3,4]), 4).reshape(2, 2, 2, 2), (0, 2, 1, 3)).reshape(4, 4)

Array([[0, 0, 0, 0],
       [0, 1, 1, 0],
       [0, 0, 0, 1],
       [1, 1, 0, 0]], dtype=int32)

In [6]:
n_encode = 16
for i in range(0, n_encode, 4):
    print(i)

0
4
8
12


In [12]:
import torch.nn as nn
import torch
encoder_layer = nn.TransformerEncoderLayer(d_model=128, nhead=4)
transformer_encoder = nn.TransformerEncoder(encoder_layer, num_layers=2)
for i, layer in enumerate(transformer_encoder.layers):
    print(i)
    print(layer)

0
TransformerEncoderLayer(
  (self_attn): MultiheadAttention(
    (out_proj): NonDynamicallyQuantizableLinear(in_features=128, out_features=128, bias=True)
  )
  (linear1): Linear(in_features=128, out_features=2048, bias=True)
  (dropout): Dropout(p=0.1, inplace=False)
  (linear2): Linear(in_features=2048, out_features=128, bias=True)
  (norm1): LayerNorm((128,), eps=1e-05, elementwise_affine=True)
  (norm2): LayerNorm((128,), eps=1e-05, elementwise_affine=True)
  (dropout1): Dropout(p=0.1, inplace=False)
  (dropout2): Dropout(p=0.1, inplace=False)
)
1
TransformerEncoderLayer(
  (self_attn): MultiheadAttention(
    (out_proj): NonDynamicallyQuantizableLinear(in_features=128, out_features=128, bias=True)
  )
  (linear1): Linear(in_features=128, out_features=2048, bias=True)
  (dropout): Dropout(p=0.1, inplace=False)
  (linear2): Linear(in_features=2048, out_features=128, bias=True)
  (norm1): LayerNorm((128,), eps=1e-05, elementwise_affine=True)
  (norm2): LayerNorm((128,), eps=1e-05, e

In [14]:
torch.zeros([2, 0, 16, 128]) is not None

True

In [15]:
jnp.arange(5)%4

Array([0, 1, 2, 3, 0], dtype=int32)

In [18]:
def pos_2d(Ny, Nx, units):
    x_odd_f = jnp.repeat(jnp.array([1, 0, 0, 0]), units // 4)
    x_even_f = jnp.repeat(jnp.array([0, 1, 0, 0]), units // 4)
    y_odd_f = jnp.repeat(jnp.array([0, 0, 1, 0]), units // 4)
    y_even_f = jnp.repeat(jnp.array([0, 0, 0, 1]), units // 4)
    p = jnp.arange(units)/units
    x = jnp.arange(Ny*Nx+1) %  Nx
    y = jnp.arange(Ny*Nx+1) // Nx
    return jnp.sin(jnp.outer(x, 1/10000**(p)))*x_odd_f + jnp.cos(jnp.outer(x, 1/10000**(p)))*x_even_f + jnp.sin(jnp.outer(y, 1/10000**(p)))*y_odd_f + jnp.cos(jnp.outer(y, 1/10000**(p)))*y_even_f
pos_2d(4, 4, 128).shape

(17, 128)

In [19]:
N = 5
jnp.tril(-(jnp.ones((N, N)) - jnp.eye(N))*1e9).T

Array([[-0.e+00, -1.e+09, -1.e+09, -1.e+09, -1.e+09],
       [ 0.e+00, -0.e+00, -1.e+09, -1.e+09, -1.e+09],
       [ 0.e+00,  0.e+00, -0.e+00, -1.e+09, -1.e+09],
       [ 0.e+00,  0.e+00,  0.e+00, -0.e+00, -1.e+09],
       [ 0.e+00,  0.e+00,  0.e+00,  0.e+00, -0.e+00]], dtype=float32)

In [20]:
[3] * 4

[3, 3, 3, 3]

In [25]:
Ny = 4
Nx = 4
jnp.concatenate((jnp.array([0, ]), jnp.arange(Ny*Nx) % Nx))

Array([0, 0, 1, 2, 3, 0, 1, 2, 3, 0, 1, 2, 3, 0, 1, 2, 3], dtype=int32)