# Estimate Memory

This notebook estimates the memory required to store the Amat in the Scipy csc_matrix format.

The document about the csc_matrix format is [here](https://docs.scipy.org/doc/scipy/reference/generated/scipy.sparse.csc_matrix.html).

Also, We can confirm how the data type of indices and indptr is determined by [this implementation code](https://github.com/scipy/scipy/blob/v1.11.2/scipy/sparse/_compressed.py#L36), and especially about the dtype, [this code](https://github.com/scipy/scipy/blob/main/scipy/sparse/_sputils.py#L147)


In [1]:
import numpy as np
from exputils.Amat.get import get_Amat_sparse
from exputils.stabilizer_group import total_stabilizer_group_size
from exputils.math.q_binom import q_binomial


def estimate_each_memory(n: int):
    # https://github.com/scipy/scipy/blob/v1.11.2/scipy/sparse/_compressed.py#L36
    M = 2**n
    N = total_stabilizer_group_size(n)
    max_val = max(M, N)

    # https://github.com/scipy/scipy/blob/main/scipy/sparse/_sputils.py#L147
    int32max = np.int32(np.iinfo(np.int32).max)
    dtype = np.int32 if np.intc().itemsize == 4 else np.int64
    if max_val > int32max:
        dtype = np.int64
    idx_dtype_size = 4 if dtype == np.int32 else 8
    assert idx_dtype_size == 4 if n <= 6 else idx_dtype_size == 8

    sz = N

    # the number of non-zero elements are
    # \sum_{k=0}^{n} 2^k (2^{n+k*(k+1)/2}q_binom(n,k))
    nnz = sum(
        (2**k) * (q_binomial(n, k) * (1 << (n + k * (k + 1) // 2)))
        for k in range(n + 1)
    )

    # 16 byte (np.complex128)
    data_estimate = 16 * nnz

    # 4 byte (np.int32) if max_val < 2**31 else 8 byte (np.int64)
    indices_estimate = idx_dtype_size * nnz
    indptr_estimate = idx_dtype_size * (sz + 1)

    return (data_estimate, indices_estimate, indptr_estimate)


def check_each_memory(n):
    Amat = get_Amat_sparse(n)
    return (Amat.data.nbytes, Amat.indices.nbytes, Amat.indptr.nbytes)


# check the correctness
for n in range(1, 5 + 1):
    assert estimate_each_memory(n) == check_each_memory(n)
    print(f"{n} ok")
print("all test passes")

for n in range(1, 8 + 1):
    if n <= 6:
        assert total_stabilizer_group_size(n) < 2**31
    else:
        assert total_stabilizer_group_size(n) > 2**31


# estimate the memory
unit = ["B", "KiB", "MiB", "GiB", "TiB", "PiB", "EiB", "ZiB", "YiB"]
for n in range(4, 10 + 1):
    memory = sum(estimate_each_memory(n))
    exponent = 0
    tmp = memory
    while tmp / 1024 >= 1:
        tmp /= 1024
        exponent += 1
    print(f"{n:>2} {round(tmp):>4} {unit[exponent]} ({memory:,} B)")

1 ok
2 ok
3 ok
4 ok
5 ok
all test passes
 4    8 MiB (8,225,284 B)
 5 1011 MiB (1,059,886,084 B)
 6  254 GiB (272,209,766,404 B)
 7  153 TiB (167,771,952,691,208 B)
 8  153 PiB (171,801,080,671,334,408 B)
 9  305 EiB (351,849,950,188,283,289,608 B)
10    1 YiB (1,441,178,767,705,906,944,000,008 B)
