In [117]:
import math
import sys

sys.path.append("../ares-sc2")  # required to import sc2_helper
sys.path.append("../ares-sc2/src")  # required to import ares
sys.path.append("phantom")

import cvxpy as cp
import numpy as np
import scipy as sp

In [118]:
cp.installed_solvers()

['CLARABEL', 'ECOS', 'ECOS_BB', 'OSQP', 'QOCO', 'SCIPY', 'SCS']

In [119]:
SOLVER = cp.ECOS

In [120]:
N = 3
M = 4
max_assigned = math.ceil(N / M)
np.random.seed(0)
W = sp.stats.lognorm.rvs(s=3, size=(N, M))
B = np.full(M, max_assigned)
W

array([[1.98771721e+02, 3.32168314e+00, 1.88443654e+01, 8.31041387e+02],
       [2.71150487e+02, 5.32992169e-02, 1.72923681e+01, 6.35037247e-01],
       [7.33698898e-01, 3.42737789e+00, 1.54053647e+00, 7.84781719e+01]])

In [121]:
x = cp.Variable((N, M), "x")
w = cp.Parameter((N, M), name="w")
constraints = [
    cp.sum(x, 0) <= max_assigned,  # enforce even distribution
    cp.sum(x, 1) == 1,
    x >= 0,
]
problem = cp.Problem(cp.Minimize(cp.vdot(w, x)), constraints)

In [122]:
def solve_cvxpy():
    w.value = W
    problem.solve(solver="ECOS")
    return x.value


solve_cvxpy().argmax(1)

array([1, 3, 0])

In [123]:
A_ub = np.tile(np.identity(M), (1, N))
b_ub = np.full(M, max_assigned)

A_eq = np.repeat(np.identity(N), M, axis=1)
b_eq = np.full(N, 1.0)

c = W.flatten()

In [124]:
def solve_highs():
    A_ub = np.tile(np.identity(M), (1, N))
    b_ub = np.full(M, max_assigned)

    A_eq = np.repeat(np.identity(N), M, axis=1)
    b_eq = np.full(N, 1.0)

    c = W.flatten()
    return sp.optimize.linprog(
        c=c,
        A_ub=A_ub,
        b_ub=b_ub,
        A_eq=A_eq,
        b_eq=b_eq,
        method="highs",
    ).x.reshape((N, M))


solve_highs().argmax(1)

array([1, 3, 0])

In [125]:
np.repeat(np.arange(M), N)

array([0, 0, 0, 1, 1, 1, 2, 2, 2, 3, 3, 3])

In [126]:
np.tile(np.arange(M), N)

array([0, 1, 2, 3, 0, 1, 2, 3, 0, 1, 2, 3])

In [127]:
A_ub

array([[1., 0., 0., 0., 1., 0., 0., 0., 1., 0., 0., 0.],
       [0., 1., 0., 0., 0., 1., 0., 0., 0., 1., 0., 0.],
       [0., 0., 1., 0., 0., 0., 1., 0., 0., 0., 1., 0.],
       [0., 0., 0., 1., 0., 0., 0., 1., 0., 0., 0., 1.]])

In [128]:
A_ub_sparse = sp.sparse.csc_array(
    (
        np.ones(N * M),
        (
            np.tile(np.arange(M), N),
            np.arange(N * M),
        ),
    )
)
A_ub_sparse.toarray()

array([[1., 0., 0., 0., 1., 0., 0., 0., 1., 0., 0., 0.],
       [0., 1., 0., 0., 0., 1., 0., 0., 0., 1., 0., 0.],
       [0., 0., 1., 0., 0., 0., 1., 0., 0., 0., 1., 0.],
       [0., 0., 0., 1., 0., 0., 0., 1., 0., 0., 0., 1.]])

In [129]:
A_eq

array([[1., 1., 1., 1., 0., 0., 0., 0., 0., 0., 0., 0.],
       [0., 0., 0., 0., 1., 1., 1., 1., 0., 0., 0., 0.],
       [0., 0., 0., 0., 0., 0., 0., 0., 1., 1., 1., 1.]])

In [130]:
A_eq_sparse = sp.sparse.csc_array(
    (
        np.ones(N * M),
        (
            np.repeat(np.arange(N), M),
            np.arange(N * M),
        ),
    )
)
A_eq_sparse.toarray()

array([[1., 1., 1., 1., 0., 0., 0., 0., 0., 0., 0., 0.],
       [0., 0., 0., 0., 1., 1., 1., 1., 0., 0., 0., 0.],
       [0., 0., 0., 0., 0., 0., 0., 0., 1., 1., 1., 1.]])

In [131]:
A_ub_sparse = sp.sparse.csc_array(A_ub)
A_eq_sparse = sp.sparse.csc_array(A_eq)

In [132]:
def solve_highs_sparse():
    A_ub_sparse = sp.sparse.csc_array(
        (
            np.ones(N * M),
            (
                np.repeat(np.arange(M), N),
                np.tile(np.arange(M), N),
            ),
        )
    )
    A_eq_sparse = sp.sparse.csc_array(A_eq)
    return sp.optimize.linprog(
        c=c,
        A_ub=A_ub_sparse,
        b_ub=b_ub,
        A_eq=A_eq_sparse,
        b_eq=b_eq,
        method="highs",
    ).x.reshape((N, M))


solve_highs().argmax(1)

array([1, 3, 0])

In [133]:
%%timeit
solve_cvxpy()

579 μs ± 17.1 μs per loop (mean ± std. dev. of 7 runs, 1,000 loops each)


In [134]:
%%timeit
solve_highs()

902 μs ± 22.5 μs per loop (mean ± std. dev. of 7 runs, 1,000 loops each)


In [135]:
%%timeit
solve_highs_sparse()

ValueError: Invalid input for linprog: A_ub must have exactly two dimensions, and the number of columns in A_ub must be equal to the size of c