In [30]:
import importlib
import math
import os
import subprocess
import sys

sys.path.append("../ares-sc2")         # required to import sc2_helper
sys.path.append("../ares-sc2/src")     # required to import ares
sys.path.append("phantom")

import cvxpy as cp
import scipy as sp
import numpy as np

from phantom.cvxpygen.assign.assign import cpg_assign



In [31]:
# List the solvers CVXPY can use in this environment (SOLVER below must be one of them).
cp.installed_solvers()

['CLARABEL',
 'CVXOPT',
 'ECOS',
 'ECOS_BB',
 'GLPK',
 'GLPK_MI',
 'HIGHS',
 'MOSEK',
 'OSQP',
 'PIQP',
 'PROXQP',
 'SCIP',
 'SCIPY',
 'SCS',
 'SDPA']

In [32]:
# Solver for the generic CVXPY baseline; should match the solver solve_cvxpy uses.
SOLVER = cp.ECOS

In [33]:
# Problem size: W has one row per item to assign and one column per slot.
N, M = 32, 32
# Smallest uniform per-column capacity with M * max_assigned >= N, so the
# "every row sums to 1" constraint stays feasible.
max_assigned = math.ceil(N / M)
# sp.stats rvs draw from NumPy's global RandomState when no random_state is
# passed, so seeding it here makes W reproducible.
np.random.seed(0)
W = sp.stats.lognorm.rvs(s=2, size=(N, M))  # strictly positive, heavy-tailed costs
B = np.full(M, fill_value=max_assigned)  # per-column capacity vector
W

array([[3.40593534e+01, 2.22624079e+00, 7.08143073e+00, ...,
        1.88916030e+01, 1.36328176e+00, 2.13043256e+00],
       [1.69386618e-01, 1.90327721e-02, 4.98663231e-01, ...,
        4.87187450e-01, 1.96657319e-01, 3.16643067e-02],
       [1.42597000e+00, 4.47731359e-01, 3.83731726e-02, ...,
        7.05176594e+00, 2.03955734e+00, 4.10886295e+00],
       ...,
       [2.07292882e+00, 1.89659243e+01, 2.41803787e+01, ...,
        1.99986087e-01, 1.09212340e+02, 7.74345501e-02],
       [4.81378132e-01, 6.52855141e+00, 1.81025251e+00, ...,
        5.45421436e-01, 2.41588533e+00, 1.42987315e+00],
       [2.02129884e-01, 1.61862176e+00, 1.78289957e+00, ...,
        4.06365446e+00, 4.33894136e-01, 1.11359307e-01]], shape=(32, 32))

In [34]:
# Generic CVXPY formulation of the assignment LP:
#   minimize <w, x>  s.t.  column sums <= B,  row sums == 1,  x >= 0.
x = cp.Variable((N, M), name="x")
w = cp.Parameter((N, M), name="w")  # costs, filled in per solve
constraints = [
    # Per-column capacity taken from B (the same vector passed to cpg_assign),
    # so this baseline cannot silently diverge if B is edited.
    cp.sum(x, axis=0) <= B,
    cp.sum(x, axis=1) == 1,  # every row is fully assigned
    x >= 0,
]
problem = cp.Problem(cp.Minimize(cp.vdot(w, x)), constraints)

In [35]:
def solve_cpg():
    """Solve the assignment with the ``cpg_assign`` solver.

    ``cpg_assign`` is imported from ``phantom.cvxpygen``; presumably it is
    code generated by cvxpygen, but confirm against that package. This
    function passes the module-level cost matrix ``W`` and capacity vector
    ``B`` and returns the result unchanged. The result is used below as an
    (N, M) array.
    """
    costs, capacities = W, B
    return cpg_assign(costs, capacities)
solve_cpg().argmax(axis=1)

array([20, 10,  2,  6, 29, 23, 13, 21, 22, 19, 14, 30,  1, 11,  9, 26, 16,
       27,  4,  3, 28,  0, 31, 12,  7,  8, 18, 15, 17,  5, 24, 25])

In [36]:
def solve_cvxpy():
    """Solve the assignment LP through the generic CVXPY ``problem``.

    Loads the module-level cost matrix ``W`` into the ``w`` parameter and
    re-solves ``problem`` with the configured ``SOLVER``.

    Returns
    -------
    numpy.ndarray
        The (N, M) value of the decision variable ``x``.

    Raises
    ------
    RuntimeError
        If the solver does not report an optimal (or optimal-inaccurate)
        status.
    """
    w.value = W
    # Use the configured constant instead of a hard-coded "ECOS", so that
    # changing SOLVER actually changes the benchmark.
    problem.solve(solver=SOLVER)
    # On failure CVXPY leaves x.value as None; fail loudly here instead of
    # letting the caller's .argmax() raise an opaque AttributeError.
    if problem.status not in (cp.OPTIMAL, cp.OPTIMAL_INACCURATE):
        raise RuntimeError(f"CVXPY solve failed with status {problem.status!r}")
    return x.value

solve_cvxpy().argmax(1)

array([20, 10,  2,  6, 29, 23, 13, 21, 22, 19, 14, 30,  1, 11,  9, 26, 16,
       27,  4,  3, 28,  0, 31, 12,  7,  8, 18, 15, 17,  5, 24, 25])

In [37]:
# Decision variables are x flattened row-major: x[i, j] -> index i * M + j.

# Column capacities: A_ub = [I_M I_M ... I_M] (N copies), so row j of A_ub
# sums x[0, j], x[1, j], ..., x[N - 1, j].
A_ub = np.tile(np.identity(M), (1, N))
# Derive capacities from B so this LP matches cpg_assign(W, B) rather than
# silently diverging if B is edited.
b_ub = np.asarray(B, dtype=float)

# Row sums: each column of I_N repeated M times, so row i of A_eq sums
# x[i, 0], ..., x[i, M - 1]; every row must be fully assigned.
A_eq = np.repeat(np.identity(N), M, axis=1)
b_eq = np.ones(N)

c = W.flatten()  # same row-major order as the constraint columns above
assert A_ub.shape == (M, N * M) and A_eq.shape == (N, N * M)

In [38]:
def solve_highs(a_ub=None, a_eq=None):
    """Solve the assignment LP directly with SciPy's HiGHS backend.

    Parameters
    ----------
    a_ub, a_eq : array_like or sparse matrix, optional
        Constraint matrices to use instead of the module-level dense
        ``A_ub`` / ``A_eq`` (e.g. their sparse CSR copies). When omitted,
        the original dense matrices are used.

    Returns
    -------
    numpy.ndarray
        The (N, M) relaxed assignment matrix.

    Raises
    ------
    RuntimeError
        If HiGHS does not report success (``result.x`` would be ``None``).
    """
    # linprog's default bounds (0, None) supply the x >= 0 constraint.
    result = sp.optimize.linprog(
        c=c,
        A_ub=A_ub if a_ub is None else a_ub,
        b_ub=b_ub,
        A_eq=A_eq if a_eq is None else a_eq,
        b_eq=b_eq,
        method="highs",
    )
    if not result.success:
        raise RuntimeError(f"HiGHS failed (status {result.status}): {result.message}")
    return result.x.reshape((N, M))
solve_highs().argmax(1)

array([20, 10,  2,  6, 29, 23, 13, 21, 22, 19, 14, 30,  1, 11,  9, 26, 16,
       27,  4,  3, 28,  0, 31, 12,  7,  8, 18, 15, 17,  5, 24, 25])

In [39]:
# CSR copies of the dense constraint matrices built above.
# NOTE(review): no solve_* call in this notebook passes these to a solver —
# either benchmark HiGHS with them or delete this cell.
A_ub_sparse, A_eq_sparse = (sp.sparse.csr_matrix(dense) for dense in (A_ub, A_eq))

In [40]:
%%timeit
# Generic CVXPY path: parameter update plus problem.solve on every loop.
solve_cvxpy()

3.99 ms ± 1.18 ms per loop (mean ± std. dev. of 7 runs, 100 loops each)


In [41]:
%%timeit
# SciPy linprog with method="highs" on the dense A_ub / A_eq matrices.
solve_highs()

1.43 ms ± 22.4 μs per loop (mean ± std. dev. of 7 runs, 1,000 loops each)


In [42]:
%%timeit
# cpg_assign (imported from phantom.cvxpygen) on the same W and B.
solve_cpg()

2.38 ms ± 101 μs per loop (mean ± std. dev. of 7 runs, 100 loops each)
