In [154]:
import math
import sys
from itertools import product

sys.path.append("../ares-sc2")  # required to import sc2_helper
sys.path.append("../ares-sc2/src")  # required to import ares
sys.path.append("phantom")

import cvxpy as cp
import highspy
import numpy as np
import scipy as sp

In [155]:
# List the solver backends cvxpy can use in this environment.
cp.installed_solvers()

['CLARABEL', 'ECOS', 'ECOS_BB', 'HIGHS', 'OSQP', 'QOCO', 'SCIPY', 'SCS']

In [156]:
# cvxpy solver identifier used for the benchmark; ECOS is one of the backends
# reported by cp.installed_solvers() above.
SOLVER = cp.ECOS

In [157]:
# Problem size: W has N rows and M columns; each row is distributed over the columns.
N = 32
M = 32
# Per-column capacity: the smallest integer cap for which M columns can hold all N rows.
max_assigned = math.ceil(N / M)
# Seed the global NumPy RNG before sampling so the cost matrix is reproducible.
np.random.seed(0)
# Random cost matrix drawn from a log-normal distribution with shape parameter s=3.
W = sp.stats.lognorm.rvs(s=3, size=(N, M))
# Capacity vector with one entry per column.
# NOTE(review): B is not referenced by the other cells visible here; b_ub is built
# later with the same expression.
B = np.full(M, max_assigned)
W

array([[1.98771721e+02, 3.32168314e+00, 1.88443654e+01, ...,
        8.21113547e+01, 1.59176311e+00, 3.10957966e+00],
       [6.97137823e-02, 2.62574817e-03, 3.52136481e-01, ...,
        3.40051064e-01, 8.72097758e-02, 5.63449470e-03],
       [1.70280761e+00, 2.99589280e-01, 7.51694825e-03, ...,
        1.87260782e+01, 2.91275448e+00, 8.32880097e+00],
       ...,
       [2.98453438e+00, 8.25963812e+01, 1.18903503e+02, ...,
        8.94333859e-02, 1.14132037e+03, 2.15477620e-02],
       [3.33986980e-01, 1.66811214e+01, 2.43561555e+00, ...,
        4.02808200e-01, 3.75503919e+00, 1.70980375e+00],
       [9.08752858e-02, 2.05929261e+00, 2.38062133e+00, ...,
        8.19172111e+00, 2.85808850e-01, 3.71612040e-02]], shape=(32, 32))

In [158]:
# Assignment LP in cvxpy. The cost matrix is a Parameter so the same problem
# object can be re-solved with new costs without being rebuilt.
x = cp.Variable((N, M), name="x")
w = cp.Parameter((N, M), name="w")
constraints = [
    cp.sum(x, axis=0) <= max_assigned,  # enforce even distribution
    cp.sum(x, axis=1) == 1,  # every row of x sums to one
    x >= 0,
]
problem = cp.Problem(
    cp.Minimize(cp.vdot(w, x)),
    constraints,
)

In [159]:
def solve_cvxpy():
    """Solve the assignment LP with the prebuilt cvxpy problem.

    Loads the cost matrix ``W`` into the parameter ``w`` and re-solves
    ``problem`` with the backend selected by the module-level ``SOLVER``.

    Returns:
        The value of the ``(N, M)`` variable ``x`` after the solve.
    """
    w.value = W
    # Use the configured SOLVER rather than a hard-coded backend name so the
    # benchmark follows the notebook's solver setting.
    problem.solve(solver=SOLVER)
    return x.value


solve_cvxpy().argmax(1)

array([20, 10,  2,  6, 29, 23, 13, 21, 22, 19, 14, 30,  1, 11,  9, 26, 16,
       27,  4,  3, 28,  0, 31, 12,  7,  8, 18, 15, 17,  5, 24, 25])

In [160]:
# Dense LP data for the decision matrix flattened row-major into N * M variables.

# Capacity rows: N copies of the M x M identity placed side by side, so row j
# picks out every variable that belongs to column j.
A_ub = np.kron(np.ones((1, N)), np.identity(M))
b_ub = np.full(M, max_assigned)

# Assignment rows: each identity column repeated M times, so row i picks out
# the M consecutive variables that belong to row i.
A_eq = np.kron(np.identity(N), np.ones((1, M)))
b_eq = np.ones(N)

# Objective coefficients in the same row-major order (a fresh copy of W's data).
c = W.reshape(-1).copy()

In [161]:
def solve_highs():
    """Solve the assignment LP with SciPy's ``linprog`` (HiGHS) on dense matrices.

    All LP data is rebuilt from ``N``, ``M``, ``max_assigned`` and ``W`` on
    every call, so the construction cost is part of each invocation.

    Returns:
        The solution vector reshaped to ``(N, M)``.
    """
    # Column-capacity constraints: N identity blocks side by side.
    capacity_matrix = np.kron(np.ones((1, N)), np.identity(M))
    capacity_rhs = np.full(M, max_assigned)

    # Row-assignment constraints: each identity column repeated M times.
    assignment_matrix = np.kron(np.identity(N), np.ones((1, M)))
    assignment_rhs = np.ones(N)

    result = sp.optimize.linprog(
        c=W.flatten(),
        A_ub=capacity_matrix,
        b_ub=capacity_rhs,
        A_eq=assignment_matrix,
        b_eq=assignment_rhs,
        method="highs",
    )
    return result.x.reshape((N, M))


solve_highs().argmax(1)

array([20, 10,  2,  6, 29, 23, 13, 21, 22, 19, 14, 30,  1, 11,  9, 26, 16,
       27,  4,  3, 28,  0, 31, 12,  7,  8, 18, 15, 17,  5, 24, 25])

In [162]:
def solve_highspy():
    """Build and solve the assignment LP from scratch with the highspy modelling API.

    Returns:
        The solution's column values reshaped to ``(N, M)``.
    """
    model = highspy.Highs()

    # One variable in [0, 1] per (i, j) pair, created in row-major order so the
    # flat solution vector reshapes directly to (N, M).
    cells = {}
    for i, j in product(range(N), range(M)):
        cells[i, j] = model.addVariable(lb=0.0, ub=1.0)

    # Each column j may carry at most max_assigned.
    for j in range(M):
        column_load = sum(cells[i, j] for i in range(N))
        model.addConstr(column_load <= max_assigned)

    # Each row i must sum to exactly one.
    for i in range(N):
        row_total = sum(cells[i, j] for j in range(M))
        model.addConstr(row_total == 1.0)

    total_cost = sum(W[i, j] * cells[i, j] for i in range(N) for j in range(M))
    model.minimize(total_cost)
    # NOTE(review): minimize() may already trigger a solve in highspy; the
    # explicit run() is kept so the call sequence is unchanged.
    model.run()

    return np.array(model.getSolution().col_value).reshape((N, M))


solve_highspy().argmax(1)

array([20, 10,  2,  6, 29, 23, 13, 21, 22, 19, 14, 30,  1, 11,  9, 26, 16,
       27,  4,  3, 28,  0, 31, 12,  7,  8, 18, 15, 17,  5, 24, 25])

In [163]:
# Combined right-hand side: the N equality targets first, then the M column capacities.
np.concatenate((b_eq, b_ub))

array([1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.,
       1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.,
       1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.,
       1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.])

In [164]:
# Inspect the dense equality-constraint matrix (one row per row of the decision matrix).
A_eq

array([[1., 1., 1., ..., 0., 0., 0.],
       [0., 0., 0., ..., 0., 0., 0.],
       [0., 0., 0., ..., 0., 0., 0.],
       ...,
       [0., 0., 0., ..., 0., 0., 0.],
       [0., 0., 0., ..., 0., 0., 0.],
       [0., 0., 0., ..., 1., 1., 1.]], shape=(32, 1024))

In [165]:
# Stack the equality rows on top of the capacity rows: one (N + M) x (N * M) matrix.
np.concatenate((A_eq, A_ub))

array([[1., 1., 1., ..., 0., 0., 0.],
       [0., 0., 0., ..., 0., 0., 0.],
       [0., 0., 0., ..., 0., 0., 0.],
       ...,
       [0., 0., 0., ..., 1., 0., 0.],
       [0., 0., 0., ..., 0., 1., 0.],
       [0., 0., 0., ..., 0., 0., 1.]], shape=(64, 1024))

In [166]:
# Row-wise (start, index, value) triple for the stacked constraint matrix when every
# row stores all N * M entries, zeros included: row offsets, column indices, values.
# NOTE(review): the offsets array has N + M entries and no trailing end pointer.
np.arange((N + M) * N * M, step=N * M), np.tile(np.arange(N * M), N + M), np.concatenate((A_eq, A_ub)).flatten()

(array([    0,  1024,  2048,  3072,  4096,  5120,  6144,  7168,  8192,
         9216, 10240, 11264, 12288, 13312, 14336, 15360, 16384, 17408,
        18432, 19456, 20480, 21504, 22528, 23552, 24576, 25600, 26624,
        27648, 28672, 29696, 30720, 31744, 32768, 33792, 34816, 35840,
        36864, 37888, 38912, 39936, 40960, 41984, 43008, 44032, 45056,
        46080, 47104, 48128, 49152, 50176, 51200, 52224, 53248, 54272,
        55296, 56320, 57344, 58368, 59392, 60416, 61440, 62464, 63488,
        64512]),
 array([   0,    1,    2, ..., 1021, 1022, 1023], shape=(65536,)),
 array([1., 1., 1., ..., 0., 0., 1.], shape=(65536,)))

In [167]:
# Same (start, index, value) triple restricted to the N equality rows of A_eq;
# every row again stores all N * M entries, zeros included.
np.arange(N * N * M, step=N * M), np.tile(np.arange(N * M), N), A_eq.flatten()

(array([    0,  1024,  2048,  3072,  4096,  5120,  6144,  7168,  8192,
         9216, 10240, 11264, 12288, 13312, 14336, 15360, 16384, 17408,
        18432, 19456, 20480, 21504, 22528, 23552, 24576, 25600, 26624,
        27648, 28672, 29696, 30720, 31744]),
 array([   0,    1,    2, ..., 1021, 1022, 1023], shape=(32768,)),
 array([1., 1., 1., ..., 1., 1., 1.], shape=(32768,)))

In [168]:
# Inspect the dense capacity-constraint matrix (one row per column of the decision matrix).
A_ub

array([[1., 0., 0., ..., 0., 0., 0.],
       [0., 1., 0., ..., 0., 0., 0.],
       [0., 0., 1., ..., 0., 0., 0.],
       ...,
       [0., 0., 0., ..., 1., 0., 0.],
       [0., 0., 0., ..., 0., 1., 0.],
       [0., 0., 0., ..., 0., 0., 1.]], shape=(32, 1024))

In [169]:
# Same (start, index, value) triple restricted to the M capacity rows of A_ub;
# every row again stores all N * M entries, zeros included.
np.arange(M * N * M, step=N * M), np.tile(np.arange(N * M), M), A_ub.flatten()

(array([    0,  1024,  2048,  3072,  4096,  5120,  6144,  7168,  8192,
         9216, 10240, 11264, 12288, 13312, 14336, 15360, 16384, 17408,
        18432, 19456, 20480, 21504, 22528, 23552, 24576, 25600, 26624,
        27648, 28672, 29696, 30720, 31744]),
 array([   0,    1,    2, ..., 1021, 1022, 1023], shape=(32768,)),
 array([1., 0., 0., ..., 0., 0., 1.], shape=(32768,)))

In [170]:
# Build the HiGHS model once at cell level so that solve_highspy_lp can reuse it.
h = highspy.Highs()
h.setOptionValue("presolve", "off")
# Turn off solver logging.
h.setOptionValue("log_to_console", False)
h.setOptionValue("output_flag", False)
# One variable in [0, 1] per (i, j) pair, created in row-major order.
# NOTE(review): at cell level this name shadows the builtin `vars` for the rest of the session.
vars = {(i, j): h.addVariable(lb=0.0, ub=1.0) for i, j in product(range(N), range(M))}
# Added first: the N constraints forcing each row i to sum to exactly one.
for i in range(N):
    h.addConstr(sum(vars[i, j] for j in range(M)) == 1.0)
# Added second: the M constraints capping each column j at max_assigned.
# This N-then-M order is what solve_highspy_lp assumes when it writes row_upper_.
for j in range(M):
    h.addConstr(sum(vars[i, j] for i in range(N)) <= max_assigned)
# NOTE(review): sets the objective; in highspy this may also run a solve — confirm.
h.minimize(sum(W[i, j] * vars[i, j] for i in range(N) for j in range(M)))
# Keep a handle on the LP data; solve_highspy_lp overwrites fields on it before re-solving.
lp = h.getLp()

In [171]:
def solve_highspy_lp():
    """Re-solve the prebuilt HiGHS model after refreshing its costs and row bounds.

    Works on the module-level ``h`` / ``lp`` pair: only the objective
    coefficients and the row upper bounds are overwritten before the model is
    handed back to the solver.

    Returns:
        The solution's column values reshaped to ``(N, M)``.
    """
    # Row order follows model construction: N assignment rows, then M capacity rows.
    assignment_upper = np.ones(N)
    capacity_upper = np.full(M, max_assigned)

    lp.col_cost_ = W.flatten()
    lp.row_upper_ = np.concatenate((assignment_upper, capacity_upper))

    h.passModel(lp)
    h.run()

    col_values = np.asarray(list(h.getSolution().col_value))
    return col_values.reshape((N, M))


solve_highspy_lp().argmax(1)

array([20, 10,  2,  6, 29, 23, 13, 21, 22, 19, 14, 30,  1, 11,  9, 26, 16,
       27,  4,  3, 28,  0, 31, 12,  7,  8, 18, 15, 17,  5, 24, 25])

In [172]:
# Index pattern 0, 0, ..., 1, 1, ...: each value in range(M) repeated N times in a row.
np.repeat(np.arange(M), N)

array([ 0,  0,  0, ..., 31, 31, 31], shape=(1024,))

In [173]:
# Index pattern 0, 1, ..., M - 1 cycled N times.
np.tile(np.arange(M), N)

array([ 0,  1,  2, ..., 29, 30, 31], shape=(1024,))

In [174]:
# Dense capacity matrix again, as the reference for the sparse construction that follows.
A_ub

array([[1., 0., 0., ..., 0., 0., 0.],
       [0., 1., 0., ..., 0., 0., 0.],
       [0., 0., 1., ..., 0., 0., 0.],
       ...,
       [0., 0., 0., ..., 1., 0., 0.],
       [0., 0., 0., ..., 0., 1., 0.],
       [0., 0., 0., ..., 0., 0., 1.]], shape=(32, 1024))

In [175]:
# Capacity matrix in sparse form, built from coordinates: flat variable k
# (row-major) contributes a 1 to constraint row k % M. No explicit shape is
# given, so it is inferred from the largest row and column index.
A_ub_sparse = sp.sparse.csc_array(
    (np.ones(N * M), (np.arange(N * M) % M, np.arange(N * M)))
)
A_ub_sparse.toarray()

array([[1., 0., 0., ..., 0., 0., 0.],
       [0., 1., 0., ..., 0., 0., 0.],
       [0., 0., 1., ..., 0., 0., 0.],
       ...,
       [0., 0., 0., ..., 1., 0., 0.],
       [0., 0., 0., ..., 0., 1., 0.],
       [0., 0., 0., ..., 0., 0., 1.]], shape=(32, 1024))

In [176]:
# Dense equality matrix again, as the reference for the sparse construction that follows.
A_eq

array([[1., 1., 1., ..., 0., 0., 0.],
       [0., 0., 0., ..., 0., 0., 0.],
       [0., 0., 0., ..., 0., 0., 0.],
       ...,
       [0., 0., 0., ..., 0., 0., 0.],
       [0., 0., 0., ..., 0., 0., 0.],
       [0., 0., 0., ..., 1., 1., 1.]], shape=(32, 1024))

In [177]:
# Equality matrix in sparse form, built from coordinates: flat variable k
# (row-major) contributes a 1 to constraint row k // M. No explicit shape is
# given, so it is inferred from the largest row and column index.
A_eq_sparse = sp.sparse.csc_array(
    (np.ones(N * M), (np.arange(N * M) // M, np.arange(N * M)))
)
A_eq_sparse.toarray()

array([[1., 1., 1., ..., 0., 0., 0.],
       [0., 0., 0., ..., 0., 0., 0.],
       [0., 0., 0., ..., 0., 0., 0.],
       ...,
       [0., 0., 0., ..., 0., 0., 0.],
       [0., 0., 0., ..., 0., 0., 0.],
       [0., 0., 0., ..., 1., 1., 1.]], shape=(32, 1024))

In [178]:
# Convert the dense matrices to CSC directly; this rebinds the A_ub_sparse /
# A_eq_sparse names that the earlier cells built from index arrays.
A_ub_sparse = sp.sparse.csc_array(A_ub)
A_eq_sparse = sp.sparse.csc_array(A_eq)

In [179]:
def solve_highs_sparse():
    """Solve the assignment LP with SciPy's ``linprog`` (HiGHS) on sparse constraints.

    The decision matrix is flattened row-major, so flat variable
    ``k = i * M + j`` belongs to row ``i`` and column ``j``. Both constraint
    matrices are built directly in COO form with one column per flat variable;
    the cost vector and right-hand sides come from the module-level ``c``,
    ``b_ub`` and ``b_eq``.

    Returns:
        The solution vector reshaped to ``(N, M)``.
    """
    flat_index = np.arange(N * M)
    ones = np.ones(N * M)

    # Capacity rows: variable k counts towards column j = k % M. The column
    # index must run over all N * M variables; pairing repeat/tile over
    # range(M) alone would only span an M x M matrix and not match ``c``.
    A_ub_sparse = sp.sparse.coo_array(
        (ones, (np.tile(np.arange(M), N), flat_index)),
        shape=(M, N * M),
    )
    # Assignment rows: variable k counts towards row i = k // M.
    A_eq_sparse = sp.sparse.coo_array(
        (ones, (np.repeat(np.arange(N), M), flat_index)),
        shape=(N, N * M),
    )
    return sp.optimize.linprog(
        c=c,
        A_ub=A_ub_sparse,
        b_ub=b_ub,
        A_eq=A_eq_sparse,
        b_eq=b_eq,
        method="highs",
    ).x.reshape((N, M))


# Exercise the sparse variant defined above (not the dense solve_highs).
solve_highs_sparse().argmax(1)

array([20, 10,  2,  6, 29, 23, 13, 21, 22, 19, 14, 30,  1, 11,  9, 26, 16,
       27,  4,  3, 28,  0, 31, 12,  7,  8, 18, 15, 17,  5, 24, 25])

In [180]:
%%timeit
# Times the cvxpy path: parameter update plus a solve of the prebuilt problem.
solve_cvxpy()

3.99 ms ± 785 μs per loop (mean ± std. dev. of 7 runs, 100 loops each)


In [181]:
%%timeit
# Times the SciPy linprog path; the dense constraint matrices are rebuilt inside the call.
solve_highs()

357 ms ± 134 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)


In [182]:
%%timeit
# Times the highspy modelling path; the whole model is rebuilt on every call.
solve_highspy()

14.3 ms ± 2.07 ms per loop (mean ± std. dev. of 7 runs, 100 loops each)


In [183]:
%%timeit
# Times the reuse path: only costs and row upper bounds are refreshed on the prebuilt model.
solve_highspy_lp()

634 μs ± 18.5 μs per loop (mean ± std. dev. of 7 runs, 1,000 loops each)
