In [7]:
import numpy as np
import cvxpy as cp
from scipy.special import lambertw
from time import time

In [2]:
import wta


In [5]:

def lambertw_prox(q, bv, y, v=1, verbose=False):

    d = q@y
    z = d - lambertw(bv*np.exp(d))
    lam = np.exp(z)
    x = y - lam*v*q
    return x.real

# WTA resolvent
class lambertResolvent:
    '''Resolvent function'''
    def __init__(self, data):
        self.data = data
        self.v0 = data['v0']
        # prob, w, y = buildWTAProb(data)
        # self.prob = prob
        # self.w = w
        # self.y = y
        self.q = np.log(self.data['QQ'])
        self.b = self.q@self.q
        self.bv = self.b*self.data['VV']
        self.shape = self.data['Q'].shape
        #self.log = []

    def __call__(self, x):
        t = time()
        y = x[self.data['s'],:]
        # self.prob.solve(verbose=False, ignore_dpp=True)
        # st = time()
        w = lambertw_prox(self.q, self.bv, y, v=self.data['VV'])
        x[self.data['s'],:] = w
        #self.log.append((t,st))
        # You can implement logging here
        #self.log.append(fullValue(self.data, proj_full(x)))
        return x

    def __repr__(self):
        return "wtaResolvent"



In [13]:
def buildWTAProb(data):
    '''Builds the WTA problem'''
    QQ = data['QQ']
    VV = data['VV']

    # Get the number of targets and weapons
    m = QQ.shape

    # Create the variable
    w = cp.Variable(m) #, nonneg=True) #, integer=True)

    # Create the parameter
    y = cp.Parameter(m) # resolvent parameter, sum of weighted previous resolvent outputs and v_i

    # Create the objective
    weighted_weapons = cp.multiply(w, np.log(QQ)) # (wpns)
    survival_probs = cp.exp(cp.sum(weighted_weapons)) # ()
    obj = cp.Minimize(VV*survival_probs + .5*cp.sum_squares(w - y))

    # Create the problem 
    prob = cp.Problem(obj) #, cons)

    # Return the problem, variable, and parameters
    return prob, w, y

class cvxResolvent:
    '''Resolvent function'''
    def __init__(self, data):
        self.data = data
        self.v0 = data['v0']
        self.prob, self.w, self.y = buildWTAProb(data)
        self.shape = self.data['Q'].shape
        self.log = []

    def __call__(self, x):
        y = x[self.data['s'],:]
        self.y.value = y
        self.prob.solve(verbose=False, ignore_dpp=True)
        
        w = self.w.value
        x[self.data['s'],:] = w
        return x

    def __repr__(self):
        return "wtaResolvent"

In [48]:
# Loop through wpn sizes and compare solve time for each method on first and next forty iterations
for n in range(10, 1211, 200):
    print(f"n = {n}")
    Q = 0.8*np.random.rand(n, n) + 0.1
    QQ = Q[0,:]
    VV = np.random.rand()*10
    data = {'Q':Q, 'QQ':QQ, 'VV':VV, 'v0':1, 's':0}
    resolvent = cvxResolvent(data)
    resolvent2 = lambertResolvent(data)
    np.random.seed(0)
    x = np.random.rand(n, n)
    t = time()
    y = resolvent(x)
    print(f"cvxpy initial: {time()-t}")

    t = time()
    y = resolvent2(x)
    print(f"lambert initial: {time()-t}")

    a = x
    t = time()
    for i in range(100):
        a = resolvent(a)
    print(f"cvxpy: {time()-t}")


    a = x
    t = time()
    for i in range(100):
        a = resolvent2(a)
    print(f"lambert: {time()-t}")
    # print("cvxpy")
    # %timeit resolvent(x)
    # print("lambert")
    # %timeit resolvent2(x)

n = 10
cvxpy initial: 0.024228811264038086
lambert initial: 0.0


cvxpy: 2.823678731918335
lambert: 0.007822751998901367
n = 210
cvxpy initial: 0.024619340896606445
lambert initial: 0.0
cvxpy: 3.099766492843628
lambert: 0.004220724105834961
n = 410
cvxpy initial: 0.031328678131103516
lambert initial: 0.0
cvxpy: 3.2674660682678223
lambert: 0.00842142105102539
n = 610
cvxpy initial: 0.040284156799316406
lambert initial: 0.0
cvxpy: 3.3884711265563965
lambert: 0.0
n = 810
cvxpy initial: 0.049132347106933594
lambert initial: 0.0
cvxpy: 3.442697048187256
lambert: 0.008002281188964844
n = 1010
cvxpy initial: 0.03294062614440918
lambert initial: 0.0
cvxpy: 3.5520071983337402
lambert: 0.003082752227783203
n = 1210
cvxpy initial: 0.04124927520751953
lambert initial: 0.0
cvxpy: 4.09358811378479
lambert: 0.007039785385131836


In [49]:

print("cvxpy")
%timeit resolvent(x)
print("lambert")
%timeit resolvent2(x)

cvxpy
38.6 ms ± 4.69 ms per loop (mean ± std. dev. of 7 runs, 10 loops each)
lambert
29.8 µs ± 3.75 µs per loop (mean ± std. dev. of 7 runs, 10,000 loops each)


In [46]:
y1 = resolvent(np.zeros((n,n)))
y2 = resolvent2(np.zeros((n,n)))

assert(np.isclose(y1, y2, atol=1e-4).all())