In [145]:
%matplotlib inline
import torch
import cvxpy as cp
from cvxpylayers.torch import CvxpyLayer
import numpy as np


In [146]:
# the linear program:

# stage 1: max w^T x  s.t.  S x <= b, x >= 0; get x*, then take b2 = b - S x*
# stage 2: max w2^T x2  s.t.  S2 x2 <= b2, x2 >= 0, over internal structures only
# all LP relaxations

In [147]:
# we're assuming each type here is specific to one hospital
# the fact that compatibility doesn't actually depend on hospital should be captured in the S matrix

In [148]:
# Layer 1: the matching problem over all centers at once.

# Problem sizes and regularisation strength.
num_centers = 2
num_structures = 8
lam = 0.1  # weight of the ||x1||_2 shrinkage term in the objective
num_types = 3 * num_centers  # 3 types for each center

# Decision variable: one amount per structure.
x1 = cp.Variable(num_structures)

# Parameters, filled in with tensors when the layer is called.
S = cp.Parameter((num_types, num_structures))  # valid structures
w = cp.Parameter(num_structures)  # structure weight
z = cp.Parameter(num_structures)  # control parameter
b = cp.Parameter(num_types)  # per-type upper bound (max bid)

# Nonnegative amounts, and per-type usage S x1 bounded by b.
constraints = [x1 >= 0, S @ x1 <= b]

# Weighted total of the chosen structures, penalised by the distance of x1
# from z and by lam * ||x1||_2. w is 1-D, so `w @ x1` is the same inner
# product as `w.T @ x1`.
objective = cp.Maximize(
    w @ x1
    - cp.norm(x1 - z, 2)
    - lam * cp.norm(x1, 2)
)
problem = cp.Problem(objective, constraints)

In [149]:
# Wrap the layer-1 problem as a torch layer. The order of `parameters`
# (S, w, b, z) is the positional order expected when the layer is called.
layer1 = CvxpyLayer(
    problem,
    parameters=[S, w, b, z],
    variables=[x1],
)

In [150]:
# Test inputs for layer 1.
# The incidence matrix is written one row per structure (8 rows x 6 types)
# and transposed into the (num_types, num_structures) layout used by S.
# NOTE(review): the 2nd and 6th rows are identical - confirm this is intended.
test_S = torch.tensor(
    [
        [1.0, 1.0, 0.0, 0.0, 0.0, 0.0],
        [0.0, 0.0, 0.0, 1.0, 1.0, 0.0],
        [1.0, 0.0, 0.0, 0.0, 1.0, 0.0],
        [0.0, 1.0, 0.0, 1.0, 0.0, 0.0],
        [0.0, 1.0, 1.0, 0.0, 0.0, 0.0],
        [0.0, 0.0, 0.0, 1.0, 1.0, 0.0],
        [0.0, 1.0, 0.0, 0.0, 0.0, 1.0],
        [0.0, 0.0, 1.0, 0.0, 1.0, 0.0],
    ],
    requires_grad=False,
).t()  # a constant
testW = torch.ones(num_structures)  # will come out of NN later
testZ = torch.zeros(num_structures)  # control point at the origin
testB = torch.full((num_types,), 5.0)  # every type bounded by 5

In [151]:
# Solve layer 1; the layer returns a 1-tuple holding the optimal x1.
# Argument order follows the `parameters` list: S, w, b, z.
(x1_out,) = layer1(test_S, testW, testB, testZ)

In [152]:
# Per-type usage implied by the layer-1 solution: S x1*.
resulting_allocations = torch.matmul(test_S, x1_out)

In [153]:
# Per-type capacity left over after layer 1: b - S x1*.
# The solver only satisfies S x1 <= b up to its numerical tolerance, so the
# raw difference can come out marginally negative. These values are fed to
# the second-stage problem as its bound b2, where x2 >= 0 together with
# S2 @ x2 <= b2 < 0 would be infeasible, so clamp the leftover at zero.
remaining_vals = torch.clamp(testB - resulting_allocations, min=0.0)

In [154]:
# Split the leftover capacity by center: the first three types go to
# center 1, the remaining ones to center 2.
center1 = remaining_vals[:3]
center2 = remaining_vals[3:]

In [155]:
# Layer 2: the internal matching problem for a single center.

# Problem sizes.
num_types_2 = 3
num_structures_2 = 2

# Decision variable: one amount per internal structure.
x2 = cp.Variable(num_structures_2)

# Parameters, filled in with tensors when the layer is called.
S2 = cp.Parameter((num_types_2, num_structures_2))  # internal structures
w2 = cp.Parameter(num_structures_2)  # structure weight
b2 = cp.Parameter(num_types_2)  # per-type upper bound

# Nonnegative amounts, and per-type usage S2 x2 bounded by b2.
constraints2 = [x2 >= 0, S2 @ x2 <= b2]

# Plain weighted total with no regularisation. w2 is 1-D, so `w2 @ x2` is
# the same inner product as `w2.T @ x2`.
objective2 = cp.Maximize(w2 @ x2)
problem2 = cp.Problem(objective2, constraints2)

In [156]:
# Test inputs for layer 2: two internal structures over three types, written
# one row per structure and transposed to (num_types_2, num_structures_2).
test_S2 = torch.tensor(
    [
        [1.0, 1.0, 0.0],
        [0.0, 1.0, 1.0],
    ],
    requires_grad=False,
).t()
testW2 = torch.ones(num_structures_2)

# Wrap the layer-2 problem as a torch layer; call order is S2, w2, b2.
layer2 = CvxpyLayer(
    problem2,
    parameters=[S2, w2, b2],
    variables=[x2],
)

In [157]:
# Solve layer 2 for center 1, using its leftover capacity as the bound b2.
(center1_internal_result,) = layer2(test_S2, testW2, center1)

In [158]:
# Per-type usage implied by center 1's internal solution: S2 x2*.
center1_internal_matched = torch.matmul(test_S2, center1_internal_result)

In [159]:
# Show center 1's internal per-type matches; in the recorded run they are
# zero up to floating-point noise (magnitudes around 1e-16).
center1_internal_matched

tensor([-4.3194e-16, -5.9655e-16, -1.6462e-16])

In [160]:
# Center 1's total per type: internal (layer 2) matches plus the layer-1
# allocation for its three types.
center1_total_pairs = torch.add(
    center1_internal_matched,
    resulting_allocations[:3],
)

In [161]:
# Show the combined per-type totals for center 1.
center1_total_pairs

tensor([2.5000, 5.0000, 2.5000])