# JAX: print array shapes and values during runtime

In [2]:
import jax
import jax.numpy as jnp
import jax.random as jrandom

In [12]:
jax.__version__

'0.4.13'

## Raw print statement

Fine for printing shapes, not good for printing values

In [3]:
def matmul(A, B):
    print(A.shape, B.shape)
    print(A)
    C = A @ B
    return C

In [4]:
rnd_key = jrandom.PRNGKey(42)
As_key, rnd_key = jrandom.split(rnd_key)
Bs_key, rnd_key = jrandom.split(rnd_key)

In [5]:
As = jrandom.normal(As_key, (10, 10, 2, 3))
Bs = jrandom.normal(As_key, (10, 10, 3, 4))

Printing shapes works fine, but array values are unreadable (the entire `As` is printed).

In [6]:
Cs = jax.vmap(jax.vmap(matmul, in_axes=0, out_axes=0), in_axes=0, out_axes=0)(As, Bs)
Cs.shape

(2, 3) (3, 4)
Traced<ShapedArray(float32[2,3])>with<BatchTrace(level=2/0)> with
  val = Traced<ShapedArray(float32[10,2,3])>with<BatchTrace(level=1/0)> with
    val = Array([[[[ 1.41631317e+00,  1.42119959e-01,  5.63737333e-01],
         [-8.08823943e-01, -5.55994034e-01,  1.40435195e+00]],

        [[ 1.47389174e+00,  1.84074283e+00, -3.30875933e-01],
         [-6.48018777e-01,  1.14413667e+00, -1.04831493e+00]],

        [[ 5.35528660e-01, -5.17392397e-01, -5.60730875e-01],
         [-5.07189035e-01, -5.83525896e-01,  1.43764162e+00]],

        [[-1.27212989e+00, -3.57313424e-01,  1.97515979e-01],
         [-3.89750123e-01,  5.70347130e-01,  5.38806558e-01]],

        [[-6.05107784e-01, -8.13592553e-01,  7.08761930e-01],
         [-7.96500325e-01,  2.42265034e+00,  8.75406325e-01]],

        [[ 7.06482470e-01, -1.24035847e+00, -6.96492374e-01],
         [-4.87383485e-01,  3.56826633e-01,  2.73126632e-01]],

        [[ 7.65678823e-01, -2.74423540e-01,  5.19199848e-01],
         [-7.73

(10, 10, 2, 4)

## `jax.debug.print`

Fine for printing shapes, not good for printing values (all As are printed)

In [7]:
def matmul(A, B):
    jax.debug.print("A shape {}" , A.shape)
    jax.debug.print("B shape {}" , B.shape)
    jax.debug.print("A values {}", A)
    C = A @ B
    return C

In [8]:
Cs = jax.vmap(jax.vmap(matmul, in_axes=0, out_axes=0), in_axes=0, out_axes=0)(As, Bs)
Cs.shape

A shape (2, 3)
B shape (3, 4)
A values [[ 1.4163132   0.14211996  0.56373733]
 [-0.80882394 -0.55599403  1.404352  ]]
A values [[-0.30616117 -0.23144977 -0.94812495]
 [-0.87843996 -0.42498556  0.54482937]]
A values [[-0.2443929  -0.2599617   0.720695  ]
 [-0.98735416 -0.17740317  1.6958072 ]]
A values [[ 0.16638309 -0.847358   -0.05225986]
 [ 1.6425724   1.0063661   0.24831384]]
A values [[-1.3084904  -0.74921197 -0.7374831 ]
 [ 0.4404689  -2.0633237  -0.39973775]]
A values [[ 0.76452166 -0.16549315 -1.270762  ]
 [-0.70556366 -1.2112911   0.8208053 ]]
A values [[ 1.3224034  -0.24695611  0.3411725 ]
 [-0.9176694  -0.04110909  0.5098628 ]]
A values [[-0.9061809 -0.939563  -0.1116497]
 [-1.6221457  1.0639017  0.5984525]]
A values [[ 1.3330323   1.9307811   0.07144103]
 [ 2.5634098  -0.20958011 -1.9244288 ]]
A values [[ 0.6843879   0.3470698   2.0005767 ]
 [-0.5939122   0.07250637  0.04880551]]
A values [[ 1.4738917   1.8407428  -0.33087593]
 [-0.6480188   1.1441367  -1.0483149 ]]
A values

(10, 10, 2, 4)

## `jax.debug.breakpoint`

Fine for printing shapes and values

In [9]:
def matmul(A, B):
    C = A @ B
    jax.debug.breakpoint()
    return C

In [10]:
Cs = jax.vmap(jax.vmap(matmul, in_axes=0, out_axes=0), in_axes=0, out_axes=0)(As, Bs)
Cs.shape

Entering jdb:


(jdb)  l


> /var/folders/3p/nlw2tdvn1jjg8pmw13cy2j5r0000gn/T/ipykernel_19042/2048950809.py(3)
    def matmul(A, B):
        C = A @ B
->      jax.debug.breakpoint()
        return C
    


(jdb)  A.shape


(2, 3)


(jdb)  A


Array([[ 1.4163132 ,  0.14211996,  0.56373733],
       [-0.80882394, -0.55599403,  1.404352  ]], dtype=float32)


(jdb)  C


Array([[-0.45385492, -0.41346142, -0.86743045, -0.56505036],
       [ 2.8056443 ,  0.34733254, -3.3420308 , -0.86739   ]],      dtype=float32)


(jdb)  q


SystemExit: 0

  warn("To exit: use 'exit', 'quit', or Ctrl-D.", stacklevel=1)


In [11]:
As[0, 0]

Array([[ 1.4163132 ,  0.14211996,  0.56373733],
       [-0.80882394, -0.55599403,  1.404352  ]], dtype=float32)