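This notebook measures the largest weapon-target assignment (WTA) problem that can be solved within five minutes on this machine. In each of three scenarios (varying the number of weapon types m, varying the number of targets n, and varying both), the problem size is increased until the solve time exceeds five minutes.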

We then plot the results.

In [3]:
# %load wta.py
import cvxpy as cp
import numpy as np
from microbench import MicroBench
basic_bench = MicroBench()
from time import time 
np.random.seed(1)

@basic_bench
def solve_wta(q, V, W, integer=True, lasso=False, verbose=False):
    """
    Build and solve a weapon-target assignment (WTA) problem with CVXPY.

    Minimizes the value-weighted survival of the targets,
    sum_i V[i] * prod_j q[i,j]**x[i,j], subject to the per-type weapon
    budget sum_i x[i,j] <= W[j] and x >= 0.

    Inputs:
        q: (n,m) array of survival probabilities; an (n,) vector is treated
           as a single weapon type and reshaped to (n,1)
        V: (n,) array of target values
        W: (m,) array of weapon counts
        integer: boolean, restrict the assignment to integers when True,
                 otherwise leave it continuous
        lasso: boolean, add the penalty 0.1*min(V)*sum(x) to the objective
        verbose: boolean, forwarded to the CVXPY solve call
    Returns:
        (prob.value, x.value): objective value and (n,m) assignment as
        reported by CVXPY
    """
    # Promote a 1-D survival vector to a single-column matrix.
    if len(q.shape) == 1:
        q = q.reshape((q.shape[0], 1))
    n, m = q.shape

    # Decision variable: number of weapons of type j assigned to target i.
    x = cp.Variable((n, m), integer=True) if integer else cp.Variable((n, m))

    # prod_j q_ij**x_ij is written as exp(sum_j x_ij*log(q_ij)).
    log_q = np.log(q)
    survival_probs = cp.exp(cp.sum(cp.multiply(x, log_q), axis=1))  # (n,)

    obj_fun = V @ survival_probs
    if lasso:
        # Penalize the total number of weapons used.
        obj_fun = obj_fun + 0.1 * min(V) * cp.sum(x)

    # Column sums are bounded by the available weapons of each type.
    prob = cp.Problem(
        cp.Minimize(obj_fun),
        [cp.sum(x, axis=0) <= W, x >= 0],
    )
    prob.solve(verbose=verbose)

    # Report any non-optimal outcome, but still return what CVXPY produced.
    if prob.status != 'optimal':
        print("Problem status:", prob.status)

    return prob.value, x.value

def get_final_surv_prob(q, x):
    """
    Compute the final survival probability of each target.

    Each entry q[i,j] is raised to the power x[i,j] and the result is
    multiplied across the weapon axis.

    Inputs:
        q: (n,m) array of survival probabilities
        x: (n,m) array of weapon assignments
    Returns:
        (n,) array with prod_j q[i,j]**x[i,j] for every target i
    """
    per_weapon_survival = np.power(q, x)  # (n,m)
    return per_weapon_survival.prod(axis=1)

def get_ind_value(q, V, W):
    """
    Get the total value if each platform solves independently.

    Every weapon type (column of q) is assigned on its own with solve_wta,
    ignoring the other types; the per-type assignments are then stacked and
    evaluated jointly.

    Inputs:
        q: (n,m) array of survival probabilities
        V: (n,) array of target values
        W: (m,) array of weapon counts
    Returns:
        (value, x): the value-weighted survival V @ get_final_surv_prob(q, x)
        and the (n,m) array of stacked independent assignments
    """
    # Loop through platforms
    n, m = q.shape
    x = np.zeros((n,m))
    for i in range(m):
        # Solve the WTA problem for platform i alone. solve_wta reshapes the
        # 1-D column q_i to (n,1), so the assignment comes back as (n,1).
        q_i = q[:,i]
        _, x_i = solve_wta(q_i, V, W[i])
        x[:,i] = x_i[:,0]
    return V@get_final_surv_prob(q, x), x

def generate_random_problem(n=5, m=3):
    """
    Draw a random WTA instance from NumPy's global random state.

    Inputs:
        n: number of targets
        m: number of weapon types
    Returns:
        q: (n,m) survival probabilities, uniform in [0.1, 0.9)
        V: (n,) target values, uniform in [0, 100)
        W: (m,) integer weapon counts, each between 1 and 9 inclusive
    """
    # The draw order (q, then V, then W) determines the result for a given seed.
    q = 0.1 + 0.8 * np.random.rand(n, m)
    V = 100 * np.random.rand(n)
    W = np.random.randint(low=1, high=10, size=m)
    return q, V, W

In [4]:
# n = 10, m varies
# Track the slowest solve seen so far. Testing the most recent `t` alone
# would only ever look at the last trial of each batch, so a slow trial
# earlier in the batch would not stop the sweep.
max_t = 0
n = 10
m = 20
m_results = []
while max_t < 5*60:
    # Ten independent random problems per problem size
    for i in range(10):
        q, V, W = generate_random_problem(n, m)
        # tic
        start = time()
        solve_wta(q, V, W)
        # toc
        t = time() - start
        max_t = max(max_t, t)
        print(n, m, t)
        # Record (targets, weapon types, trial index, wall-clock seconds)
        m_results.append((n, m, i, t))
    m += 1

print(m_results)
    

optimal
10 20 2.391085386276245
optimal
10 20 20.251699209213257
optimal
10 20 0.9659976959228516
optimal
10 20 43.98425650596619
optimal
10 20 277.1131443977356
optimal
10 20 91.27977085113525
optimal
10 20 4.179135322570801
optimal
10 20 133.22278833389282
optimal
10 20 1.7264719009399414
optimal
10 20 4.19076943397522
optimal
10 21 6.225506782531738
optimal
10 21 14.212860584259033
optimal
10 21 3.8664727210998535
optimal
10 21 9.952459812164307
optimal
10 21 6.105466842651367
optimal
10 21 3.7517969608306885
optimal
10 21 44.332135915756226
optimal
10 21 247.22594809532166
optimal
10 21 0.7213754653930664
optimal
10 21 2.8147335052490234
optimal
10 22 3.143475294113159
optimal
10 22 8.359628200531006
optimal
10 22 6.739048004150391
optimal
10 22 5.231008529663086
optimal
10 22 6.550501346588135
optimal
10 22 50.21237564086914
optimal
10 22 73.98284482955933
optimal
10 22 41.84234571456909
optimal
10 22 2.282503604888916
optimal
10 22 36.8442063331604
optimal
10 23 353.1196329593658

In [None]:
# Scenario: m = 10 weapon types held fixed, number of targets n grows.
n, m = 20, 10
max_t = 0
n_results = []
while max_t < 5*60:
    # Ten fresh random instances are timed at every problem size.
    for i in range(10):
        q, V, W = generate_random_problem(n, m)
        start = time()
        solve_wta(q, V, W)
        t = time() - start
        # Remember the slowest solve; it decides when the sweep stops.
        if t > max_t:
            max_t = t
        print(n, m, t)
        # Record (targets, weapon types, trial index, wall-clock seconds).
        n_results.append((n, m, i, t))
    n += 1

print(n_results)

In [None]:
# Scenario: targets n and weapon types m grow together, one step at a time.
n, m = 14, 14
max_t = 0
nm_results = []
while max_t < 5*60:
    # Ten fresh random instances are timed at every problem size.
    for i in range(10):
        q, V, W = generate_random_problem(n, m)
        start = time()
        solve_wta(q, V, W)
        t = time() - start
        # Remember the slowest solve; it decides when the sweep stops.
        if t > max_t:
            max_t = t
        print(n, m, t)
        # Record (targets, weapon types, trial index, wall-clock seconds).
        nm_results.append((n, m, i, t))
    # Alternate the growth: bump n when the sizes are equal, otherwise let m
    # catch up.
    if n != m:
        m += 1
    else:
        n += 1

print(nm_results)