In [11]:
import jax.numpy as jnp

DTYPE_SIZE = {
    jnp.float32: 4,
    jnp.float64: 8,
    jnp.complex64: 8,
    jnp.complex128: 16,
}

def bytes_of(shape, dtype):
    n = 1
    for s in shape:
        n *= s
    return n * DTYPE_SIZE[dtype]

In [7]:
chol = jnp.zeros((200,20,20))
print(chol.shape,chol.dtype)

(200, 20, 20) float32


In [15]:
mem_chol = bytes_of(chol.shape,jnp.complex128)
print(mem_chol/1000/1000)

1.28


In [16]:
import opt_einsum as oe

# Define your shapes and the Einstein summation string
subscripts = "bij,bjk,bkl->bil"
shape_A = (100, 500, 500)
shape_B = (100, 500, 500)
shape_C = (100, 500, 500)

# Get the path and the info object
path, path_info = oe.contract_path(subscripts, shape_A, shape_B, shape_C)

# Extract the peak memory (number of elements)
peak_elements = path_info.largest_intermediate
print(f"Largest intermediate has {peak_elements} elements.")

ValueError: Einstein sum subscript 'bij' does not contain the correct number of indices for operand 0.

In [18]:
g,i,j,p,q = 1000,15,14,77,66
path, info = oe.contract_path(
    "gpi,ji->gpj",
    jnp.empty((g,p,i)),
    jnp.empty((j, i)),
    optimize="optimal",
)

max_elements = info.largest_intermediate
bytes_estimate = max_elements * 16
print(max_elements,bytes_estimate)

1078000 17248000


In [22]:
def estimate_with_opt_einsum(expr, shapes, dtype):
    import opt_einsum as oe
    import numpy as np

    dummy = [np.empty(s) for s in shapes]
    _, info = oe.contract_path(expr, *dummy, optimize="optimal")

    bytes_intermediate = info.largest_intermediate * dtype.itemsize

    bytes_inputs = sum(
        np.prod(s) * dtype.itemsize for s in shapes
    )

    # conservative estimate
    return bytes_intermediate + bytes_inputs

In [26]:
expr = "gpi,ji->gpj"
shapes = [[g,p,i],[j,i]]
dtype = jnp.complex128.dtype

In [29]:
size = estimate_with_opt_einsum(expr, shapes, dtype)
print(size/1000)

35731.36
