In [16]:
import math
import sys
from itertools import product

sys.path.append("../ares-sc2")  # required to import sc2_helper
sys.path.append("../ares-sc2/src")  # required to import ares
sys.path.append("phantom")

import cvxpy as cp
import highspy
import numpy as np
import scipy as sp

In [17]:
cp.installed_solvers()

['CLARABEL', 'ECOS', 'ECOS_BB', 'HIGHS', 'OSQP', 'QOCO', 'SCIPY', 'SCS']

In [18]:
SOLVER = cp.ECOS

In [19]:
N = 4
M = 3
max_assigned = math.ceil(N / M)
np.random.seed(0)
W = sp.stats.lognorm.rvs(s=3, size=(N, M))
B = np.full(M, max_assigned)
W

array([[1.98771721e+02, 3.32168314e+00, 1.88443654e+01],
       [8.31041387e+02, 2.71150487e+02, 5.32992169e-02],
       [1.72923681e+01, 6.35037247e-01, 7.33698898e-01],
       [3.42737789e+00, 1.54053647e+00, 7.84781719e+01]])

In [20]:
x = cp.Variable((N, M), "x")
w = cp.Parameter((N, M), name="w")
constraints = [
    cp.sum(x, 0) <= max_assigned,  # enforce even distribution
    cp.sum(x, 1) == 1,
    x >= 0,
]
problem = cp.Problem(cp.Minimize(cp.vdot(w, x)), constraints)

In [21]:
def solve_cvxpy():
    w.value = W
    problem.solve(solver="ECOS")
    return x.value


solve_cvxpy().argmax(1)

array([1, 2, 2, 1])

In [22]:
A_ub = np.tile(np.identity(M), (1, N))
b_ub = np.full(M, max_assigned)

A_eq = np.repeat(np.identity(N), M, axis=1)
b_eq = np.full(N, 1.0)

c = W.flatten()

In [23]:
def solve_highs():
    c = W.flatten()
    return sp.optimize.linprog(
        c=c,
        A_ub=A_ub,
        b_ub=b_ub,
        A_eq=A_eq,
        b_eq=b_eq,
        method="highs",
    ).x.reshape((N, M))


solve_highs().argmax(1)

array([1, 2, 2, 1])

In [36]:
def solve_highspy():
    h = highspy.Highs()
    vars = {(i, j): h.addVariable(lb=0.0, ub=1.0) for i, j in product(range(N), range(M))}
    for j in range(M):
        h.addConstr(sum(vars[i, j] for i in range(N)) <= max_assigned)
    for i in range(N):
        h.addConstr(sum(vars[i, j] for j in range(M)) == 3.0)
    h.addConstr(sum(vars[i, j] * (10 + (j % 3)) for i in range(N) for j in range(M)) == 123.0)
    h.minimize(sum(W[i, j] * vars[i, j] for i in range(N) for j in range(M)))

    lp = h.getLp()
    print(lp.a_matrix_.start_)
    print(lp.a_matrix_.index_)
    print(lp.a_matrix_.value_)
    h.run()

    solution = h.getSolution()
    return np.array(solution.col_value).reshape((N, M))


solve_highspy().argmax(1)

[0, 3, 6, 9, 12, 15, 18, 21, 24, 27, 30, 33, 36]
[0, 3, 7, 1, 3, 7, 2, 3, 7, 0, 4, 7, 1, 4, 7, 2, 4, 7, 0, 5, 7, 1, 5, 7, 2, 5, 7, 0, 6, 7, 1, 6, 7, 2, 6, 7]
[1.0, 1.0, 10.0, 1.0, 1.0, 11.0, 1.0, 1.0, 12.0, 1.0, 1.0, 10.0, 1.0, 1.0, 11.0, 1.0, 1.0, 12.0, 1.0, 1.0, 10.0, 1.0, 1.0, 11.0, 1.0, 1.0, 12.0, 1.0, 1.0, 10.0, 1.0, 1.0, 11.0, 1.0, 1.0, 12.0]


array([0, 0, 0, 0])

In [25]:
h = highspy.Highs()
h.setOptionValue("presolve", "off")
h.setOptionValue("log_to_console", False)
h.setOptionValue("output_flag", False)
vars = {(i, j): h.addVariable(lb=0.0, ub=1.0) for i, j in product(range(N), range(M))}
for i in range(N):
    h.addConstr(sum(vars[i, j] for j in range(M)) == 1.0)
for j in range(M):
    h.addConstr(sum(vars[i, j] for i in range(N)) <= max_assigned)
h.minimize(sum(W[i, j] * vars[i, j] for i in range(N) for j in range(M)))
lp = h.getLp()

In [26]:
def solve_highspy_lp():
    lp.col_cost_ = W.flatten()
    lp.row_upper_ = np.concatenate((np.ones(N), np.full(M, max_assigned)))

    # h = highspy.Highs()
    h.passModel(lp)
    h.run()
    solution = list(h.getSolution().col_value)
    return np.asarray(solution).reshape((N, M))


solve_highspy_lp().argmax(1)

array([1, 2, 2, 1])

In [27]:
%%timeit
solve_highspy_lp()

51.4 μs ± 1.75 μs per loop (mean ± std. dev. of 7 runs, 10,000 loops each)


In [28]:
%%timeit
solve_cvxpy()

561 μs ± 20.3 μs per loop (mean ± std. dev. of 7 runs, 1,000 loops each)


In [29]:
%%timeit
solve_highspy()

[0, 2, 5, 8, 10, 13, 16, 18, 21, 24, 26, 29, 32]
[0, 3, 1, 3, 7, 2, 3, 7, 0, 4, 1, 4, 7, 2, 4, 7, 0, 5, 1, 5, 7, 2, 5, 7, 0, 6, 1, 6, 7, 2, 6, 7]
[1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 2.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 2.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 2.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 2.0]
[0, 2, 5, 8, 10, 13, 16, 18, 21, 24, 26, 29, 32]
[0, 3, 1, 3, 7, 2, 3, 7, 0, 4, 1, 4, 7, 2, 4, 7, 0, 5, 1, 5, 7, 2, 5, 7, 0, 6, 1, 6, 7, 2, 6, 7]
[1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 2.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 2.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 2.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 2.0]
[0, 2, 5, 8, 10, 13, 16, 18, 21, 24, 26, 29, 32]
[0, 3, 1, 3, 7, 2, 3, 7, 0, 4, 1, 4, 7, 2, 4, 7, 0, 5, 1, 5, 7, 2, 5, 7, 0, 6, 1, 6, 7, 2, 6, 7]
[1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 2.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 2.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 2.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 2.0]
[0, 2, 5, 8, 10, 13, 16, 18, 21, 24, 26, 29, 32]
[0, 3, 1, 3, 7, 2, 3, 7, 0, 4,

In [30]:
%%timeit
solve_highs()

874 μs ± 26.9 μs per loop (mean ± std. dev. of 7 runs, 1,000 loops each)
