In [1]:
import sys
from time import time

import cvxpy as cp
import numpy as np

import oars
import wta

In [34]:
# Functions

def get_ind_value(q, V, W, **kwargs):
    """
    Get the total value if each platform solves independently.

    Each weapon platform (column j of ``q``) solves its own WTA problem with
    ``wta.wta`` using only its own survival column and weapon count; the
    resulting per-platform allocations are stacked column-wise and scored
    together with ``wta.fullValue``.

    Inputs:
        q: (n,m) array of survival probabilities
        V: (n,) array of target values
        W: (m,) array of weapon counts
        **kwargs: forwarded unchanged to every ``wta.wta`` call
    Returns:
        (value, x) where x is the (n,m) stacked allocation and value is
        ``wta.fullValue(q, V, x)``
    """
    num_targets, num_platforms = q.shape
    allocation = np.zeros((num_targets, num_platforms))
    for j in range(num_platforms):
        # Per-platform objective value is not needed; only the allocation is kept
        _, x_j = wta.wta(q[:, j], V, W[j], **kwargs)
        # x_j is indexed as 2-D; presumably (n, 1) -- take its single column
        allocation[:, j] = x_j[:, 0]
    return wta.fullValue(q, V, allocation), allocation

def buildWTAProb(data):
    '''
    Builds the WTA resolvent subproblem for a single node.

    The cvxpy problem is
        minimize    VV @ exp(sum(w * log(QQ), axis=1))
                    + 0.5 * ||(1 - Lii) * w - y||^2
        subject to  w >= 0
                    sum(w, axis=0) <= WW
    where y is a cvxpy Parameter whose value must be set before each solve.

    Inputs:
    data is a dictionary containing (at least) the following keys:
    QQ is the (tgts, wpns) array of survival probabilities for the targets in
       the node (generateSplitData fills the rows of other targets with 1, so
       they contribute log(1) = 0)
    VV is the (tgts,) array of target values for the node
    WW is the (wpns,) array of available weapons of each type
    Lii is the diagonal element of L (typically 0)
    Any other keys (e.g. v0) are not read here.

    Returns:
    prob is the cvxpy Problem
    w is the (tgts, wpns) allocation Variable
    y is the (tgts, wpns) resolvent input Parameter (the sum of weighted
      previous resolvent outputs and v_i)
    '''
    QQ = data['QQ']
    VV = data['VV']
    WW = data['WW']
    Lii = data['Lii']

    # Shape of the allocation: (number of targets, number of weapon types)
    shape = QQ.shape

    # Decision variable: amount of each weapon type assigned to each target
    w = cp.Variable(shape)

    # Resolvent input; the caller sets y.value before every solve
    y = cp.Parameter(shape)

    # Objective: surviving target value plus the resolvent (proximal) term.
    # QQ**w is written as exp(w * log(QQ)) so cvxpy sees exp of an affine
    # expression (convex; assumes VV >= 0).
    weighted_weapons = cp.multiply(w, np.log(QQ))  # (tgts, wpns)
    survival_probs = cp.exp(cp.sum(weighted_weapons, axis=1))  # (tgts,)
    obj = cp.Minimize(VV@survival_probs + .5*cp.sum_squares((1-Lii)*w - y))

    # Non-negative allocations, and per weapon type the total over all
    # targets may not exceed the inventory WW
    cons = [w >= 0, cp.sum(w, axis=0) <= WW]

    prob = cp.Problem(obj, cons)

    # Return the problem, variable, and parameter
    return prob, w, y

# WTA resolvent
class wtaResolvent:
    '''Resolvent function for one node's share of the WTA problem.

    Wraps the cvxpy problem built by ``buildWTAProb(data)``. Calling the
    object with ``x`` sets the problem's parameter ``y`` to ``x``, solves, and
    returns the optimal allocation ``w``.

    Attributes:
        data: the node's data dict; must also contain 'Q' and 'V' because
            every call scores the solution with ``fullValue(data, w)``
        prob, w, y: cvxpy Problem, Variable and Parameter from buildWTAProb
        shape: shape of the decision variable ``w``
        log: list with one dict per call, keys 'time' (start, end wall-clock
            seconds), 'w' (solution), 'x' (input) and 'value' (full value)
    '''
    def __init__(self, data):
        self.data = data
        prob, w, y = buildWTAProb(data)
        self.prob = prob
        self.w = w
        self.y = y
        self.shape = w.shape
        self.log = []

    def __call__(self, x):
        '''Solve the resolvent subproblem at ``x`` and return the optimal w.'''
        t = time()
        self.y.value = x
        self.prob.solve(verbose=False)
        st = time()
        w_opt = self.w.value
        # Log copies, not references: the caller may later modify ``x`` or
        # the returned array in place, which would otherwise silently rewrite
        # the recorded history.
        self.log.append({'time': (t, st),
                         'w': np.array(w_opt, copy=True),
                         'x': np.array(x, copy=True),
                         'value': fullValue(self.data, w_opt)})
        return w_opt

    def __repr__(self):
        return "wtaResolvent"


def fullValue(d, w):
    '''Return the value of allocation ``w`` on the full (un-split) problem.

    ``d`` must provide 'V' (target values) and 'Q' (survival probabilities);
    the result is ``d['V'] @ wta.get_final_surv_prob(d['Q'], w)``.
    '''
    target_values = d['V']
    final_survival = wta.get_final_surv_prob(d['Q'], w)
    return target_values @ final_survival

class fVal:
    '''Callable that scores an allocation on the full problem.

    Captures the full-problem target values 'V' and survival probabilities
    'Q' from ``data``; calling it with an allocation ``x`` returns
    ``V @ wta.get_final_surv_prob(Q, x)``.
    '''
    def __init__(self, data):
        self.V = data['V']
        self.Q = data['Q']

    def __call__(self, x):
        final_survival = wta.get_final_surv_prob(self.Q, x)
        return self.V @ final_survival

    def __repr__(self):
        return "Full value function"



In [3]:
# Survival probabilities: q[i, j] is the probability that target i survives a
# single weapon of type j (5 targets x 6 weapon types).
q = np.array([[0.665, 0.671, 0.857, 0.871, 0.756, 0.767],
 [0.664, 0.672, 0.86,  0.875, 0.76,  0.773],
 [0.679, 0.683, 0.872, 0.885, 0.775, 0.783],
 [0.658, 0.669, 0.853, 0.866, 0.753, 0.758],
 [0.672, 0.677, 0.868, 0.876, 0.765, 0.772]])
# Value of each of the 5 targets
V = np.array([50, 10, 10, 30, 40])
# Number of weapons available of each of the 6 types
WW = np.array([4, 4, 6, 6, 2, 2])

In [None]:
# Baseline: every weapon platform solves its own WTA problem (integer=False)
pv, xind = get_ind_value(q, V, WW, integer=False)
print(pv, xind, sep="\n")

In [24]:
def generateSplitData(n, Q, V, WW, node_tgts, num_nodes_per_tgt, L, v0=None):
    '''Generate the data for the splitting.

    Builds one data dict per node, followed by one dict for the full problem.

    Inputs:
        n: number of nodes (only node_tgts[0..n-1] and L[i, i] for i < n are used)
        Q: (tgts, wpns) survival probabilities of the full problem
        V: (tgts,) target values of the full problem
        WW: weapon counts per type, shared unchanged by every node
        node_tgts: node_tgts[i] is the list of target indices owned by node i
        num_nodes_per_tgt: divisor applied to each node's target values so a
            target shared by several nodes is not over-counted (a scalar here;
            presumably a per-target array also works via broadcasting)
        L: matrix whose diagonal entry L[i, i] is stored as 'Lii' for node i
        v0: initial consensus value; defaults to WW spread evenly over targets.
            The same object is stored in every node dict.
    Returns:
        list of n + 1 dicts. Node dicts have keys 'QQ', 'VV', 'WW', 'v0',
        'Lii', 'Q', 'V'; the final dict has keys 'Q', 'V', 'WW'.
    '''
    num_tgts, num_wpns = Q.shape
    shape = (num_tgts, num_wpns)
    if v0 is None:
        # Even split: each target starts with WW / num_tgts of every weapon type
        v0 = 1/num_tgts*np.array(WW)*np.ones(shape)

    data = []
    for i in range(n):
        tgts_i = node_tgts[i]
        # Targets outside the node get survival probability 1 (log 1 = 0) ...
        q_node = np.ones(shape)
        q_node[tgts_i] = Q[tgts_i]
        # ... and value 0, so they drop out of the node's objective
        v_node = np.zeros(num_tgts)
        v_node[tgts_i] = V[tgts_i]
        v_node = v_node/num_nodes_per_tgt
        data.append({'QQ': q_node, 'VV': v_node, 'WW': WW, 'v0': v0,
                     'Lii': L[i, i], 'Q': Q, 'V': V})

    # Trailing entry describing the full (un-split) problem
    data.append({'Q': Q, 'V': V, 'WW': WW})
    return data

In [30]:
# DR splitting with two nodes: node 0 owns targets 0-2, node 1 owns targets 3-4
n = 2
Ldr = np.array([[0, 0],
                [2, 0]])
Wdr = np.array([[1, -1],
                [-1, 1]])
node_tgts = [[0, 1, 2], [3, 4]]

# Start the consensus value at zero instead of the default even spread;
# every target belongs to exactly one node, so num_nodes_per_tgt is 1
v0 = np.zeros(q.shape)
data = generateSplitData(n, q, V, WW, node_tgts, 1, Ldr, v0=v0)

In [31]:
# Inspect the per-node data dicts (the last entry holds the full problem)
data

[{'QQ': array([[0.665, 0.671, 0.857, 0.871, 0.756, 0.767],
         [0.664, 0.672, 0.86 , 0.875, 0.76 , 0.773],
         [0.679, 0.683, 0.872, 0.885, 0.775, 0.783],
         [1.   , 1.   , 1.   , 1.   , 1.   , 1.   ],
         [1.   , 1.   , 1.   , 1.   , 1.   , 1.   ]]),
  'VV': array([50., 10., 10.,  0.,  0.]),
  'WW': array([4, 4, 6, 6, 2, 2]),
  'v0': array([[0., 0., 0., 0., 0., 0.],
         [0., 0., 0., 0., 0., 0.],
         [0., 0., 0., 0., 0., 0.],
         [0., 0., 0., 0., 0., 0.],
         [0., 0., 0., 0., 0., 0.]]),
  'Lii': 0,
  'Q': array([[0.665, 0.671, 0.857, 0.871, 0.756, 0.767],
         [0.664, 0.672, 0.86 , 0.875, 0.76 , 0.773],
         [0.679, 0.683, 0.872, 0.885, 0.775, 0.783],
         [0.658, 0.669, 0.853, 0.866, 0.753, 0.758],
         [0.672, 0.677, 0.868, 0.876, 0.765, 0.772]]),
  'V': array([50, 10, 10, 30, 40])},
 {'QQ': array([[1.   , 1.   , 1.   , 1.   , 1.   , 1.   ],
         [1.   , 1.   , 1.   , 1.   , 1.   , 1.   ],
         [1.   , 1.   , 1.   , 1. 

In [35]:
# Run the splitting algorithm serially with an iteration cap and variable tolerance
itrs = 10
resolvents = [wtaResolvent for _ in range(n)]  # the same resolvent class for every node
alg_x, results = oars.solve(n, data, resolvents, Wdr, Ldr, vartol=1e-3,
                            itrs=itrs, parallel=False, verbose=True)
print(alg_x)

# Score the returned allocation on the full (un-split) problem
fullValue(data[0], alg_x)


Starting Serial Algorithm
Iteration 0
Serial Algorithm Loop Time: 0.2554805278778076
[[1.442 1.474 1.601 1.521 0.8   0.667]
 [0.171 0.137 0.991 0.928 0.043 0.021]
 [0.121 0.27  0.784 0.775 0.023 0.022]
 [0.997 0.782 1.427 1.416 0.502 0.632]
 [1.269 1.337 1.197 1.36  0.632 0.658]]


34.126101418980184

In [28]:
# Reference: solve the whole problem centrally (integer=False)
prov_val, x_full = wta.wta(q, V, WW, integer=False)
print(prov_val, x_full, sep="\n")

optimal
33.647394862973385
[[1.572 0.    5.271 0.    2.    0.   ]
 [0.997 0.    0.    0.    0.    0.   ]
 [0.    0.933 0.    0.    0.    0.   ]
 [0.    0.    0.729 6.    0.    2.   ]
 [1.431 3.067 0.    0.    0.    0.   ]]


In [33]:
# Dump results to a text file.
# str() of numpy arrays follows the global print options (judging by the
# displayed outputs, precision 3 is in effect) and summarizes arrays with more
# than `threshold` elements as "...", so force untruncated, round-trip-exact
# float formatting while writing.
with np.printoptions(threshold=sys.maxsize, floatmode='unique', suppress=False):
    with open('results-DR-0.txt', 'w', encoding='utf-8') as f:
        f.write(str(results))

In [36]:
# First log entry of results[0] (presumably node 0's wtaResolvent log)
results[0]['log'][0]

{'time': (1707421941.2711337, 1707421941.2917852),
 'w': array([[1.885, 1.873, 0.77 , 0.69 , 1.062, 1.042],
        [1.055, 1.052, 0.444, 0.393, 0.474, 0.476],
        [1.06 , 1.075, 0.429, 0.382, 0.464, 0.483],
        [0.   , 0.   , 0.   , 0.   , 0.   , 0.   ],
        [0.   , 0.   , 0.   , 0.   , 0.   , 0.   ]]),
 'x': array([[0., 0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0., 0.]]),
 'value': 81.06587647046965}

In [37]:
# First log entry of results[1] (presumably node 1's wtaResolvent log)
results[1]['log'][0]

{'time': (1707421941.2917852, 1707421941.307312),
 'w': array([[1.724, 1.735, 1.541, 1.379, 0.679, 0.691],
        [0.062, 0.093, 0.889, 0.787, 0.   , 0.   ],
        [0.073, 0.138, 0.857, 0.765, 0.   , 0.   ],
        [0.967, 0.883, 1.145, 1.036, 0.597, 0.603],
        [1.174, 1.15 , 1.147, 1.073, 0.725, 0.705]]),
 'x': array([[3.771, 3.747, 1.541, 1.379, 2.125, 2.083],
        [2.109, 2.104, 0.889, 0.787, 0.948, 0.951],
        [2.12 , 2.149, 0.857, 0.765, 0.927, 0.966],
        [0.   , 0.   , 0.   , 0.   , 0.   , 0.   ],
        [0.   , 0.   , 0.   , 0.   , 0.   , 0.   ]]),
 'value': 35.72552105101482}