In [74]:
import math
import sys
from itertools import product

sys.path.append("../ares-sc2")  # required to import sc2_helper
sys.path.append("../ares-sc2/src")  # required to import ares
sys.path.append("phantom")

import cvxpy as cp
import highspy
import numpy as np
import scipy as sp

In [75]:
cp.installed_solvers()

['CLARABEL', 'ECOS', 'ECOS_BB', 'HIGHS', 'OSQP', 'QOCO', 'SCIPY', 'SCS']

In [76]:
SOLVER = cp.ECOS

In [77]:
N = 32
M = 32
max_assigned = math.ceil(N / M)
rng = np.random.default_rng()
W = sp.stats.lognorm.rvs(s=3, size=(N, M))
B = np.full(M, max_assigned)
W

array([[3.08094852e-01, 1.69887480e+02, 9.29891005e-02, ...,
        1.27293160e-01, 1.17934929e+00, 8.15065573e-01],
       [9.05757560e-03, 1.57636031e-01, 2.56029867e-01, ...,
        5.92415704e-02, 1.24605007e+01, 5.28834359e+01],
       [2.56030444e+00, 1.22913962e+02, 1.30779683e+04, ...,
        2.20466450e+00, 2.11503947e+00, 2.40414246e-02],
       ...,
       [4.06073467e-01, 5.04607137e-02, 1.26522040e-01, ...,
        1.19215138e-01, 2.44850191e-01, 2.40977440e+02],
       [5.20666734e+00, 6.05485326e+00, 5.15734625e-02, ...,
        7.57955784e+01, 1.36331779e-01, 1.22541020e+01],
       [2.71225121e-02, 6.68591839e+00, 5.58698169e-01, ...,
        2.66625246e-01, 4.08917379e-02, 2.79792184e+01]], shape=(32, 32))

In [78]:
x = cp.Variable((N, M), "x")
w = cp.Parameter((N, M), name="w")
constraints = [
    cp.sum(x, 0) <= max_assigned,  # enforce even distribution
    cp.sum(x, 1) == 1,
    x >= 0,
]
problem = cp.Problem(cp.Minimize(cp.vdot(w, x)), constraints)

In [79]:
def solve_cvxpy():
    w.value = W
    problem.solve(solver="ECOS")
    return x.value


solve_cvxpy().argmax(1)

array([16, 23,  6, 20, 18, 15, 24,  5,  8,  0, 17,  4, 11,  3, 14,  1, 22,
       25,  9, 31,  2, 19, 21, 12, 13, 30, 10, 29,  7, 28, 27, 26])

In [80]:
A_ub = np.tile(np.identity(M), (1, N))
b_ub = np.full(M, max_assigned)

A_eq = np.repeat(np.identity(N), M, axis=1)
b_eq = np.full(N, 1.0)

c = W.flatten()

In [81]:
def solve_highs():
    c = W.flatten()
    return sp.optimize.linprog(
        c=c,
        A_ub=A_ub,
        b_ub=b_ub,
        A_eq=A_eq,
        b_eq=b_eq,
        method="highs",
    ).x.reshape((N, M))


solve_highs().argmax(1)

array([16, 23,  6, 20, 18, 15, 24,  5,  8,  0, 17,  4, 11,  3, 14,  1, 22,
       25,  9, 31,  2, 19, 21, 12, 13, 30, 10, 29,  7, 28, 27, 26])

In [82]:
def solve_highspy():
    h = highspy.Highs()
    vars = {(i, j): h.addVariable(lb=0.0, ub=1.0) for i, j in product(range(N), range(M))}
    for j in range(M):
        h.addConstr(sum(vars[i, j] for i in range(N)) <= max_assigned)
    for i in range(N):
        h.addConstr(sum(vars[i, j] for j in range(M)) == 3.0)
    h.addConstr(sum(vars[i, j] * (10 + (j % 3)) for i in range(N) for j in range(M)) == 123.0)
    h.minimize(sum(W[i, j] * vars[i, j] for i in range(N) for j in range(M)))

    h.getLp()
    # print(lp.a_matrix_.start_)
    # print(lp.a_matrix_.index_)
    # print(lp.a_matrix_.value_)
    h.run()

    solution = h.getSolution()
    return np.array(solution.col_value).reshape((N, M))


solve_highspy().argmax(1)

array([18,  0,  6,  0,  0, 27,  9,  0,  9,  0,  0,  0,  0,  0,  0,  0,  0,
        0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0])

In [130]:
h = highspy.Highs()
h.setOptionValue("presolve", "off")
h.setOptionValue("log_to_console", False)
h.setOptionValue("output_flag", False)
vars = {(i, j): h.addVariable(lb=0.0, ub=1.0) for i, j in product(range(N), range(M))}
for i in range(N):
    h.addConstr(sum(vars[i, j] for j in range(M)) == 1.0)
for j in range(M):
    h.addConstr(sum(vars[i, j] for i in range(N)) <= max_assigned)
h.addConstr(sum(vars[i, j] for i in range(N) for j in range(M)) == 1.0)
h.minimize(sum(W[i, j] * vars[i, j] for i in range(N) for j in range(M)))
lp = h.getLp()

In [131]:
def solve_highspy_lp():
    lp.col_cost_[:] = W.flat
    lp.row_upper_[N : N + M] = [max_assigned]
    lp.row_upper_[-1] = [4]

    # h = highspy.Highs()
    h.passModel(lp)
    h.run()
    return np.reshape(h.getSolution().col_value, (N, M))


solve_highspy_lp().argmax(1)

array([16, 21,  9, 20,  9, 15, 24, 24,  8, 22, 17,  0, 11, 29, 21, 23, 22,
       17, 31, 31,  2,  2, 21,  5, 13,  0, 25, 27,  0, 28, 23, 26])

In [132]:
%%timeit
solve_highspy_lp()

302 μs ± 5.64 μs per loop (mean ± std. dev. of 7 runs, 1,000 loops each)


In [86]:
%%timeit
solve_cvxpy()

3.71 ms ± 59.7 μs per loop (mean ± std. dev. of 7 runs, 100 loops each)


In [87]:
%%timeit
solve_highspy()

17.2 ms ± 745 μs per loop (mean ± std. dev. of 7 runs, 100 loops each)


In [88]:
%%timeit
solve_highs()

242 ms ± 5.94 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
