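# Column Generation

Compute the stabilizer extent of random pure states by column generation. The loop has three steps:

1. Solve the SOCP over a restricted working set of candidate stabilizer states (`calculate_extent_custom`).
2. Add the candidates whose overlap with the dual solution exceeds 1.
3. Repeat until no dual constraint is violated.

Results are checked against `calculate_extent_actual`.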


In [1]:
import numpy as np
import scipy.sparse
import matplotlib.pyplot as plt
from exputils.Amat.get import get_Amat_sparse
from exputils.extent.custom import calculate_extent_custom
from exputils.state.random_ket import make_random_quantum_state
from exputils.extent.actual import calculate_extent_actual

In [4]:
from exputils.math.q_binom import q_binomial
from exputils.stabilizer_group import total_stabilizer_group_size


n = 10
# Presumably the total number of n-qubit stabilizer states; printed next to
# 2**64 to show (per the recorded output) that it exceeds the uint64 range.
print(total_stabilizer_group_size(n), 2**64, sep="\n")

87876754128408960000
18446744073709551616


In [6]:
from exputils.dot.get_topK_Amat import get_topK_Amat


def CG(
    n: int,
    psi: np.ndarray,
    K: float,
    method: str = "mosek",
    iter_max: int = 100,
    eps: float = 1e-8,
    discard_current_threshold: float = 0.8,
    violation_max: int = 10000,
):
    """Compute the stabilizer extent of ``psi`` by column generation.

    The working set of columns starts as ``get_topK_Amat(n, psi, False)``.
    Each iteration does the following:

    1. Solve the restricted SOCP with ``calculate_extent_custom``.
    2. Query candidates with ``get_topK_Amat(n, dual, True)``.
    3. Add those whose overlap ``|<dual|a>|`` exceeds ``1 + eps`` (dual
       constraint violations).

    The loop stops when no violation is found.

    Args:
        n: Number of qubits (forwarded to the helpers).
        psi: Target state vector (presumably length ``2**n`` -- TODO confirm).
        K: Not used by the algorithm; only echoed in the first log line.
            Kept for interface compatibility.
        method: Solver name forwarded to ``calculate_extent_custom``.
        iter_max: Maximum number of column-generation iterations (>= 1).
        eps: Tolerance for violations, nonzero coefficients and the keep test.
        discard_current_threshold: A working-set column is kept if its
            coefficient is nonzero or ``|<dual|a>| >= threshold - eps``.
        violation_max: Maximum number of violating columns added per
            iteration (>= 1).

    Returns:
        The stabilizer extent reported by the last restricted SOCP solve. If
        ``iter_max`` is reached without a violation-free iteration, a
        ``RuntimeWarning`` is emitted and the value may not be optimal.

    Raises:
        ValueError: If ``iter_max`` or ``violation_max`` is smaller than 1.
    """
    import warnings

    if iter_max < 1:
        raise ValueError(f"iter_max must be at least 1, got {iter_max}")
    if violation_max < 1:
        raise ValueError(f"violation_max must be at least 1, got {violation_max}")

    print(f"CG: {n=}, {K=}, {method=}")
    print("start: calculate dots")
    current_Amat = get_topK_Amat(n, psi, False)
    for it in range(iter_max):
        print(f"iteration: {it + 1} / {iter_max}, Amat.shape = {current_Amat.shape}")
        print("start: solve SOCP")
        stabilizer_extent, coeff, dual = calculate_extent_custom(
            n, current_Amat, psi, method
        )
        print(f"{stabilizer_extent=}")

        # Bra <dual| shared by the pricing test and the keep/discard test, so
        # both measure the same overlap |<dual|a>| (the keep test previously
        # used `dual @ A`, which drops the complex conjugate).
        dual_bra = dual.conj().T

        # Pricing: candidate columns with large overlap against the dual.
        print("start: calculate dual dots")
        dual_dots_state = get_topK_Amat(n, dual, True)
        dual_dots = np.abs(dual_bra @ dual_dots_state)
        violated_mask = dual_dots > 1 + eps
        violated_count = int(np.sum(violated_mask))
        # If every returned candidate is violated, the candidate list may have
        # been truncated, so additional violations could exist.
        maybe_more = violated_count == dual_dots.size and dual_dots.size > 0
        print(f"# of violations: {violated_count}" + ("+ more" if maybe_more else ""))

        # Shrink the working set: keep columns in the current support
        # (nonzero coefficient) and columns whose dual overlap is near the bound.
        support_mask = np.abs(coeff) > eps
        critical_mask = np.abs(dual_bra @ current_Amat) >= (
            discard_current_threshold - eps
        )
        current_Amat = current_Amat[:, np.logical_or(support_mask, critical_mask)]

        if violated_count == 0:
            print("OPTIMAL!")
            break

        # Add the (at most violation_max) most violated candidate columns.
        extra_size = min(violation_max, violated_count)
        top_indices = np.argpartition(dual_dots, -extra_size)[-extra_size:]
        extra_Amat = dual_dots_state[:, top_indices]
        print(f"{current_Amat.shape=}, {extra_Amat.shape=}")
        # Request CSC explicitly: hstack's default output format is chosen by
        # SciPy and may be one (e.g. COO) that does not support the column
        # slicing performed at the next iteration.
        current_Amat = scipy.sparse.hstack([current_Amat, extra_Amat], format="csc")
        print(current_Amat.shape)
    else:
        warnings.warn(
            f"CG did not converge within {iter_max} iterations; returning the "
            "extent of the last restricted problem",
            RuntimeWarning,
            stacklevel=2,
        )
    return stabilizer_extent

In [7]:
# Check CG against calculate_extent_actual on three random 4-qubit pure states.
# Note: `seed` stays bound at module level after this loop (last value 2).
n = 4
for seed in range(3):
    print("=" * 20)
    np.random.seed(seed)
    psi = make_random_quantum_state("pure", n, seed)
    # Snapshot of the input, used below to assert that CG does not modify psi.
    psi_check = psi.copy()
    stabilizerExtent = CG(n, psi, 0.01)
    print(f"{stabilizerExtent=}")
    # Reference value: first element of calculate_extent_actual's result.
    stabilizerExtent_check = calculate_extent_actual(n, psi)[0]
    print(f"{stabilizerExtent_check=}")
    assert np.allclose(psi, psi_check, atol=1e-5)
    assert np.isclose(stabilizerExtent, stabilizerExtent_check, atol=1e-5)
    print("CORRECT!")

CG: n=4, K=0.01, method='mosek'
start: calculate dots
[k|progress|range]: 1 | 496/36720 | [0.00646055, 0.584277]
[k|progress|range]: 2 | 4976/36720 | [0.0024813, 0.661496]
[k|progress|range]: 3 | 20336/36720 | [0.387347, 0.692573]
[k|progress|range]: 4 | 36720/36720 | [0.387347, 0.692573]
 calculation time : 5[ms]
 values sort time : 0[ms]
branch cut / total: 0/36720
iteration: 1 / 100, Amat.shape = (16, 500)
start: solve SOCP
stabilizer_extent=2.5436045590731933
start: calculate dual dots
[k|progress|range]: 1 | 496/36720 | [999, -999]
[k|progress|range]: 2 | 4976/36720 | [999, -999]
[k|progress|range]: 3 | 20336/36720 | [999, -999]
[k|progress|range]: 4 | 36720/36720 | [999, -999]
 calculation time : 3[ms]
 values sort time : 0[ms]
branch cut / total: 10240/36720
# of violations: 0
OPTIMAL!
stabilizerExtent=2.5436045590731933
stabilizerExtent_check=2.5436046106260144
CORRECT!
CG: n=4, K=0.01, method='mosek'
start: calculate dots
[k|progress|range]: 1 | 496/36720 | [0.00544214, 0.5963

In [8]:
import time

# Bind the seed explicitly. This cell previously read `seed` leaked from the
# `for seed in range(3)` loop in an earlier cell (last value 2), so it failed
# on a fresh kernel or when that cell was skipped.
seed = 2

# Time CG for increasing qubit counts on a random pure state.
for n in [4, 5, 6, 7]:
    print("=" * 20)
    print(f"{n=}")
    np.random.seed(seed)
    psi = make_random_quantum_state("pure", n, seed)
    t0 = time.perf_counter()
    stabilizer_extent = CG(n, psi, 0.01)
    t1 = time.perf_counter()
    print(f"{stabilizer_extent=} {t1-t0=}")

n=4
CG: n=4, K=0.01, method='mosek'
start: calculate dots
[k|progress|range]: 1 | 496/36720 | [0.0146328, 0.6006]
[k|progress|range]: 2 | 4976/36720 | [0.000825603, 0.668332]
[k|progress|range]: 3 | 20336/36720 | [0.382695, 0.690228]
[k|progress|range]: 4 | 36720/36720 | [0.382695, 0.733214]
 calculation time : 5[ms]
 values sort time : 0[ms]
branch cut / total: 0/36720
iteration: 1 / 100, Amat.shape = (16, 500)
start: solve SOCP
stabilizer_extent=2.5336438140932094
start: calculate dual dots
[k|progress|range]: 1 | 496/36720 | [999, -999]
[k|progress|range]: 2 | 4976/36720 | [999, -999]
[k|progress|range]: 3 | 20336/36720 | [999, -999]
[k|progress|range]: 4 | 36720/36720 | [999, -999]
 calculation time : 3[ms]
 values sort time : 0[ms]
branch cut / total: 8704/36720
# of violations: 0
OPTIMAL!
stabilizerExtent=2.5336438140932094
n=5
CG: n=5, K=0.01, method='mosek'
start: calculate dots
[k|progress|range]: 1 | 2016/2423520 | [0.00340892, 0.438436]
[k|progress|range]: 2 | 41696/2423520 