In [1]:
import numpy as np
from exputils.RoM.actual import calculate_RoM_actual
from exputils.RoM.custom import calculate_RoM_custom
from exputils.state.random import make_random_quantum_state
from exputils.RoM.dot import get_topK_indices, make_Amat_from_column_index
from scipy.sparse import hstack
from exputils.dot.dot_product import compute_all_dot_products
from exputils.dot.load_data import load_dot_data
from time import perf_counter

In [2]:
# Numerical tolerance used by the column-generation routines in this notebook:
# it is the slack in the `|dual dot| > 1 + eps` violation test and the cut-off
# below which an LP coefficient is treated as zero.
eps = 10**-5

In [3]:
def column_generation(
    n,
    rho,
    K,
):
    """Solve for the RoM of ``rho`` with a restricted LP that grows over time.

    The restricted matrix starts from the columns that ``get_topK_indices``
    selects out of the dot products of ``rho``. After every solve, all
    columns whose dot product with the dual solution exceeds ``1 + eps`` in
    absolute value are appended, and the LP is solved again. The loop ends
    when no such column is left.

    Args:
        n: Size parameter forwarded to ``load_dot_data``,
            ``compute_all_dot_products`` and ``make_Amat_from_column_index``
            (presumably the number of qubits -- confirm).
        rho: State handed to ``compute_all_dot_products`` and to the LP solver.
        K: Selection parameter forwarded to ``get_topK_indices``.

    Returns:
        The RoM reported by ``calculate_RoM_custom`` for the last restricted
        problem, or the NaN it reported if any solve produced NaN.
    """
    col_data, col_rows = load_dot_data(n)

    # Seed the restricted problem from the dot products of rho itself.
    seed_columns = get_topK_indices(
        compute_all_dot_products(n, rho, col_data, col_rows), K
    )
    restricted = make_Amat_from_column_index(n, seed_columns, col_data, col_rows)

    while True:
        rom_value, _unused_coeff, dual_vec = calculate_RoM_custom(
            restricted,
            rho,
            method="gurobi",
            return_dual=True,
            crossover=False,
            presolve=False,
        )

        # A NaN result is propagated to the caller unchanged.
        if np.isnan(rom_value):
            return rom_value

        # Price every column against the current dual solution.
        pricing = compute_all_dot_products(n, dual_vec, col_data, col_rows)
        violated = np.abs(pricing) > 1 + eps

        # No violated column left: the restricted solution is the answer.
        if not np.any(violated):
            return rom_value

        # Append all violated columns and solve again.
        new_columns = make_Amat_from_column_index(
            n, np.nonzero(violated)[0], col_data, col_rows
        )
        restricted = hstack((restricted, new_columns))

In [4]:
def column_generation_discard(
    n,
    rho,
    K,
    discard_current_threshold,
):
    """Column generation for the RoM of ``rho`` that also prunes columns.

    Works like plain column generation (start from the columns chosen by
    ``get_topK_indices``, repeatedly add every column whose dot product with
    the dual solution exceeds ``1 + eps`` in absolute value), but before new
    columns are appended the current restricted matrix is thinned out. A
    column is kept only if

    * its LP coefficient is non-zero (``|coeff| > eps``), or
    * its dot product with the dual is close to binding
      (``|dual @ column| >= discard_current_threshold - eps``).

    Args:
        n: Size parameter forwarded to ``load_dot_data``,
            ``compute_all_dot_products`` and ``make_Amat_from_column_index``
            (presumably the number of qubits -- confirm).
        rho: State handed to ``compute_all_dot_products`` and to the LP solver.
        K: Selection parameter forwarded to ``get_topK_indices``.
        discard_current_threshold: Lower bound on ``|dual @ column|`` for a
            zero-coefficient column to survive pruning.

    Returns:
        The RoM reported by ``calculate_RoM_custom`` for the last restricted
        problem, or the NaN it reported if any solve produced NaN.
    """
    data_per_col, rows_per_col = load_dot_data(n)
    rho_dots = compute_all_dot_products(n, rho, data_per_col, rows_per_col)
    indices = get_topK_indices(rho_dots, K)
    current_Amat = make_Amat_from_column_index(n, indices, data_per_col, rows_per_col)
    while True:
        RoM, coeff, dual = calculate_RoM_custom(
            current_Amat,
            rho,
            method="gurobi",
            return_dual=True,
            crossover=False,
            presolve=False,
        )

        # A NaN result is propagated to the caller unchanged.
        if np.isnan(RoM):
            return RoM

        dual_dots = compute_all_dot_products(n, dual, data_per_col, rows_per_col)
        violated_mask = np.abs(dual_dots) > 1 + eps

        # Test for termination BEFORE pruning: on the final iteration the
        # pruned matrix would be discarded, so the matrix-vector product and
        # the sparse column slice below would be wasted work.
        if not np.any(violated_mask):
            return RoM

        # Prune the restricted matrix. Columns carrying a non-zero
        # coefficient form the support of the current solution and are
        # always kept; the rest stay only when nearly binding in the dual.
        support_mask = np.abs(coeff) > eps
        critical_mask = np.abs(dual @ current_Amat) >= (
            discard_current_threshold - eps
        )
        keep_mask = np.logical_or(support_mask, critical_mask)
        kept_Amat = current_Amat[:, keep_mask]

        # Append every violated column and solve again.
        # NOTE(review): there is no iteration cap; confirm that pruning
        # cannot make the loop cycle for the thresholds in use.
        violated_columns = np.where(violated_mask)[0]
        extra_Amat = make_Amat_from_column_index(
            n, violated_columns, data_per_col, rows_per_col
        )
        current_Amat = hstack((kept_Amat, extra_Amat))

### benchmark

In [5]:
# Problem size shared by every benchmark cell below
# (presumably the number of qubits -- confirm against exputils).
n = 5
# State used for the warm-up runs; 999 sits in the argument position where the
# timing loops pass their `seed`, so it does not collide with seeds 0-9.
rho = make_random_quantum_state("mixed", n, 999)

# Forwarded to get_topK_indices to size the initial restricted problem
# (presumably a fraction of all columns -- confirm).
K = 0.005

In [6]:
# Warm-up run for caching: executed once before the timed loop so that one-off
# start-up cost is not charged to the first measured seed
# (presumably data loading / solver start-up -- confirm what is actually cached).
# Uses the `rho` created in the configuration cell above.
start = perf_counter()
RoM_cg = column_generation(n, rho, K)
end = perf_counter()
# column_generation returns NaN when the LP solve reports NaN; fail loudly then.
assert not np.isnan(RoM_cg)
print(end - start)

Set parameter Username
Academic license - for non-commercial use only - expires 2024-06-16
4.930721553999319


In [7]:
# Time column_generation on ten states (seeds 0-9), one wall-clock sample each.
times_cg = []
for seed in range(10):
    print(seed)
    rho = make_random_quantum_state("mixed", n, seed)

    # Only the solver call sits between the two perf_counter() reads, so
    # generating the state is excluded from the measurement.
    start = perf_counter()
    RoM_cg = column_generation(n, rho, K)
    end = perf_counter()
    # A NaN result means the LP solve failed; abort rather than time a failure.
    assert not np.isnan(RoM_cg)
    print(end - start)
    times_cg.append(end - start)

0
5.266581851999945
1
4.766703998000594
2
4.941473894999945
3
5.180784238000342
4
5.5144107039996015
5
4.672717871999339
6
4.7594281019992195
7
5.217170906998945
8
6.178248479998729
9
4.33026479900218


In [8]:
# Mean, then standard deviation, of the first ten column_generation timings (one per line).
print(np.mean(times_cg[:10]), np.std(times_cg[:10]), sep="\n")

5.082778484699884
0.4905775273568628


In [9]:
# Pruning threshold for column_generation_discard: a column with a (near-)zero
# coefficient is kept only if |dual @ column| >= this value (minus eps).
discard_current_threshold = 0.9

In [10]:
# Warm-up run for caching: executed once before the timed loop so that one-off
# start-up cost is not charged to the first measured seed
# (presumably data loading / solver start-up -- confirm what is actually cached).
# The state is rebuilt here (999 in the seed position) because the previous
# timing loop left `rho` holding its last state.
rho = make_random_quantum_state("mixed", n, 999)

start = perf_counter()
RoM_cg = column_generation_discard(n, rho, K, discard_current_threshold)
end = perf_counter()
# column_generation_discard returns NaN when the LP solve reports NaN.
assert not np.isnan(RoM_cg)
print(end - start)

2.0876536020005005


In [11]:
# Time column_generation_discard on the same ten states (seeds 0-9) used for
# column_generation, one wall-clock sample each.
times_dis = []
for seed in range(10):
    print(seed)
    rho = make_random_quantum_state("mixed", n, seed)

    # Only the solver call sits between the two perf_counter() reads, so
    # generating the state is excluded from the measurement.
    start = perf_counter()
    RoM_dis = column_generation_discard(n, rho, K, discard_current_threshold)
    end = perf_counter()
    # A NaN result means the LP solve failed; abort rather than time a failure.
    assert not np.isnan(RoM_dis)
    print(end - start)
    times_dis.append(end - start)

0
2.107932068000082
1
2.247236094000982
2
2.3499844139987545
3
2.3346061470001587
4
2.5250588189992413
5
2.0028024050006934
6
1.9490256159988348
7
2.3386435929969593
8
2.8588698509993264
9
1.9349171350004326


In [12]:
# Mean, then standard deviation, of the first ten column_generation_discard timings (one per line).
print(np.mean(times_dis[:10]), np.std(times_dis[:10]), sep="\n")

2.2649076141995463
0.27209017412983405


---

In [13]:
# Warm-up run for caching: executed once before the timed loop so that one-off
# start-up cost is not charged to the first measured seed
# (presumably data loading / solver start-up -- confirm what is actually cached).
# Build the warm-up state explicitly, as the column_generation_discard warm-up
# does; otherwise this cell silently reuses whatever `rho` the previous timing
# loop left behind and its input depends on cell execution order.
rho = make_random_quantum_state("mixed", n, 999)

start = perf_counter()
# Element 0 of the returned sequence is taken as the RoM value.
RoM_lp = calculate_RoM_actual(n, rho, method="gurobi", crossover=False)[0]
end = perf_counter()
assert not np.isnan(RoM_lp)
print(end - start)

114.01958278600068


In [14]:
# Time calculate_RoM_actual (presumably the full LP without column generation
# -- confirm) on the same ten states (seeds 0-9), one wall-clock sample each.
times_lp = []
for seed in range(10):
    print(seed)
    rho = make_random_quantum_state("mixed", n, seed)

    # Only the solver call sits between the two perf_counter() reads, so
    # generating the state is excluded from the measurement.
    start = perf_counter()
    # Element 0 of the returned sequence is taken as the RoM value.
    RoM_lp = calculate_RoM_actual(n, rho, method="gurobi", crossover=False)[0]
    end = perf_counter()
    # Abort on a NaN result rather than time a failed solve.
    assert not np.isnan(RoM_lp)
    print(end - start)
    times_lp.append(end - start)

0
113.90787842999998
1
113.64439415000015
2
111.51517917100136
3
118.93526448100238
4
118.87568161299714
5
115.76000435199967
6
126.26192376600011
7
119.96230427299815
8
125.79034131899971
9
115.59787922000032


In [15]:
# Mean, then standard deviation, of the first ten calculate_RoM_actual timings (one per line).
print(np.mean(times_lp[:10]), np.std(times_lp[:10]), sep="\n")

118.0250850774999
4.722204069715391
