# Minimization subject to bound constraints

This example IPython notebook shows how to use the Python wrapper for L-BFGS-B-NS to minimize differentiable but non-smooth functions subject to bound constraints.

### Function to minimize

The function to minimize in this example is a sparse Poisson regression objective:

$$
f(\mathbf{w}) = \mathbf{s}^T \mathbf{w} - \sum_i y_i \log(\mathbf{w}^T \mathbf{x}_i) + \lambda \lVert \mathbf{w} \rVert^2
\quad \text{s.t.} \quad \mathbf{w} \ge 0
$$

The gradient is given by:
$$
\nabla f(\mathbf{w}) = \mathbf{s} + 2 \lambda \mathbf{w} - \sum_i \frac{y_i}{\mathbf{w}^T \mathbf{x}_i} \mathbf{x}_i
$$
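Stacking the vectors $\mathbf{x}_i^T$ as the rows of a matrix $X$ and collecting the $y_i$ in a vector $\mathbf{y}$, the objective, its gradient and its Hessian can be written in matrix form ($\log$, the division $\oslash$ and the square are applied element-wise). The first two lines are what `eval_f` and `eval_g` compute with NumPy:

$$
\begin{aligned}
f(\mathbf{w}) &= \mathbf{s}^T \mathbf{w} - \mathbf{y}^T \log(X \mathbf{w}) + \lambda \mathbf{w}^T \mathbf{w} \\
\nabla f(\mathbf{w}) &= \mathbf{s} + 2 \lambda \mathbf{w} - X^T \left( \mathbf{y} \oslash X \mathbf{w} \right) \\
\nabla^2 f(\mathbf{w}) &= 2 \lambda I + X^T \operatorname{diag}\!\left( \mathbf{y} \oslash (X \mathbf{w})^2 \right) X
\end{aligned}
$$

Since $\lambda > 0$ and $y_i \ge 0$, the Hessian is positive definite on the domain $X \mathbf{w} > 0$, which is why $f$ is convex there.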

For more details, see [Fast Non-Bayesian Poisson Factorization for Implicit-Feedback Recommendations](https://arxiv.org/abs/1811.01908).

The function is convex but not smooth, so L-BFGS-B fails to optimize it, while L-BFGS-B-NS does reach the optimum. Note that many other methods can also solve this problem (e.g. Newton and truncated Newton, simple gradient descent, proximal methods), with varying degrees of scalability.
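As an illustration of one of the simpler alternatives mentioned above, projected gradient descent handles the constraint $\mathbf{w} \ge 0$ by clipping negative entries to zero after each gradient step (this clipping is the proximal operator of the non-negativity constraint). The following is a minimal sketch, not part of L-BFGS-B-NS: the function name `projected_gradient_descent`, its backtracking constants, and the synthetic `X_demo`, `y_demo`, `s_demo` data are all hypothetical, and the objective is assumed to return `+inf` outside its domain so that such steps get rejected.

```python
import numpy as np


def projected_gradient_descent(f, grad, w0, n_iter=200, step0=1.0, shrink=0.5, c=1e-4):
    """Minimize f(w) subject to w >= 0 (sketch, hypothetical helper).

    Each iteration takes a gradient step, projects it back onto the
    non-negative orthant, and halves the step until an Armijo-type
    sufficient-decrease test along the projection arc passes.
    """
    w = np.maximum(np.asarray(w0, dtype=float), 0.0)
    fw = f(w)
    for _ in range(n_iter):
        g = grad(w)
        step = step0
        while True:
            w_new = np.maximum(w - step * g, 0.0)  # projection onto w >= 0
            f_new = f(w_new)
            # For a projected step, g @ (w_new - w) <= 0 always holds.
            if f_new <= fw + c * (g @ (w_new - w)):
                break
            step *= shrink
            if step < 1e-16:  # no acceptable step found: stop early
                return w
        w, fw = w_new, f_new
    return w


# Hypothetical synthetic data, only to exercise the sketch.
rng = np.random.default_rng(0)
X_demo = rng.uniform(size=(30, 5))
y_demo = rng.integers(1, 6, size=30).astype(float)
s_demo = X_demo.sum(axis=0)
lam_demo = 1e-3


def f_demo(w):
    # log(0) -> -inf, so f -> +inf outside the domain X @ w > 0.
    with np.errstate(divide="ignore"):
        return s_demo @ w + lam_demo * (w @ w) - y_demo @ np.log(X_demo @ w)


def g_demo(w):
    return s_demo + 2 * lam_demo * w - X_demo.T @ (y_demo / (X_demo @ w))


w_start = rng.gamma(1.0, 1.0, size=5)
w_hat = projected_gradient_descent(f_demo, g_demo, w_start)
print(f_demo(w_start), "->", f_demo(w_hat))
```

The sufficient-decrease test guarantees the objective never increases, but no claim is made here about how many iterations such a first-order method would need on the notebook's data.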

In [1]:
```python
import numpy as np
```

Objective and gradient functions, followed by the data (copy-pasted here):

In [2]:
```python
def eval_f(w, X, y, s, lam):
    # f(w) = s^T w + lam * ||w||^2 - sum_i y_i * log(w^T x_i)
    return s.dot(w) + lam*(w.dot(w)) - (y * np.log(X.dot(w.reshape((-1,1)))).reshape(-1)).sum()

def eval_g(w, X, y, s, lam):
    # grad f(w) = s + 2 * lam * w - sum_i (y_i / w^T x_i) * x_i
    return s + 2 * lam * w - (X * (y / X.dot(w.reshape((-1,1))).reshape(-1)).reshape((-1,1))).sum(axis=0)

# Response values y_i, one per row of X
y = np.array([3, 3, 5, 5, 5, 5, 4, 5, 4, 3, 5, 5, 5, 5, 5, 4, 4, 5, 5, 4, 5, 4,
       5, 4, 5, 4, 5, 5, 4, 4, 2, 5, 1, 3, 1, 4, 2, 5, 1, 3, 4, 4, 5, 4,
       3, 3, 4, 2, 2, 3, 3, 4, 5, 4, 4, 3, 1, 1, 4])
# Vectors x_i, one row per observation
X = np.array(
[[0.26890296, 0.4793552, 0.67465879, 0.68423073, 0.82139131, 0.36466368, 0.67274953, 0.55307646, 0.40641524, 0.39795897, 0.6128798, 0.06134945, 0.76449976, 0.23364738, 0.16593851],
[0.81192287, 0.59692526, 0.4389951, 0.16875026, 0.78139256, 0.19409183, 0.99872077, 0.54260373, 0.00936649, 0.4187989, 0.24105192, 0.05313777, 0.71613341, 0.31649153, 0.82106911],
[0.9043358, 0.4547278, 0.59905518, 0.07055586, 0.52944031, 0.22131204, 0.38518965, 0.55999774, 0.81503312, 0.26284202, 0.7278154, 0.49204311, 0.30424486, 0.29040788, 0.24110222],
[0.93954204, 0.69262323, 0.5394226, 0.91872355, 0.08091639, 0.30942264, 0.3663583, 0.01174702, 0.54515794, 0.385575, 0.46514409, 0.05261418, 0.98054913, 0.8753724, 0.31261675],
[0.72776493, 0.80993513, 0.60148481, 0.13080874, 0.01089048, 0.29257282, 0.48388457, 0.07567084, 0.08349122, 0.97949434, 0.62427096, 0.82290832, 0.78100995, 0.51023816, 0.60696214],
[0.28245125, 0.93599604, 0.61686906, 0.40684928, 0.05017834, 0.10484496, 0.44393934, 0.02376267, 0.28857097, 0.21489546, 0.22151066, 0.19268382, 0.83915291, 0.46604407, 0.54313911],
[0.12663989, 0.7322891, 0.94973282, 0.25926606, 0.2837686, 0.44705647, 0.15861842, 0.56858139, 0.44953271, 0.8874232, 0.21483149, 0.16794997, 0.49563388, 0.51662551, 0.4061961, ],
[0.17807243, 0.22776821, 0.09641054, 0.54410508, 0.8361959, 0.34664012, 0.91259992, 0.68391931, 0.82603713, 0.92485427, 0.71894555, 0.77042814, 0.07401756, 0.18868538, 0.31405159],
[0.94695735, 0.61125081, 0.20179795, 0.26678681, 0.2997066, 0.71146243, 0.50951737, 0.21261754, 0.5805612, 0.87572001, 0.12632909, 0.99245371, 0.5733112, 0.07551522, 0.92469878],
[0.9371519, 0.23500392, 0.71728488, 0.04395249, 0.30362461, 0.48304848, 0.36442034, 0.19847324, 0.69868433, 0.92530212, 0.09985799, 0.23013764, 0.28484652, 0.66797306, 0.38661698],
[0.25587843, 0.03399583, 0.01537175, 0.54221875, 0.85536029, 0.44652281, 0.23563715, 0.93653405, 0.61920888, 0.86457616, 0.12811339, 0.44123838, 0.85755397, 0.89819361, 0.87726389],
[0.25054626, 0.44829559, 0.12635292, 0.51156824, 0.86558219, 0.35235348, 0.15564043, 0.30726115, 0.74287455, 0.09127636, 0.73558171, 0.61392196, 0.74025875, 0.00904973, 0.72047264],
[0.15508939, 0.61158701, 0.55495309, 0.0812221, 0.95856954, 0.5013834, 0.49989148, 0.54441703, 0.6879023, 0.67713299, 0.34870043, 0.21310997, 0.60768479, 0.77411473, 0.61732086],
[0.37747477, 0.72004947, 0.53793262, 0.43515652, 0.10903362, 0.87064151, 0.16265942, 0.75884205, 0.6356856, 0.45822136, 0.05152794, 0.30729255, 0.52923374, 0.46875605, 0.6271794, ],
[0.46135471, 0.33155524, 0.89312642, 0.20899859, 0.18088945, 0.07159581, 0.99880462, 0.20553914, 0.50899867, 0.23423862, 0.33686407, 0.42364602, 0.35306633, 0.97153865, 0.00137032],
[0.96778285, 0.34728093, 0.18688127, 0.2542811, 0.36221375, 0.14294638, 0.72696242, 0.38599482, 0.36326628, 0.22393526, 0.26002916, 0.54607151, 0.36103396, 0.02219305, 0.30208314],
[0.02755674, 0.20926923, 0.28707448, 0.274101, 0.9223348, 0.83516646, 0.6913888, 0.3682531, 0.0705815, 0.12776716, 0.72999313, 0.75732376, 0.12589225, 0.14392424, 0.25043938],
[0.79663973, 0.21676126, 0.03685473, 0.82128472, 0.9888421, 0.46434059, 0.66144242, 0.28239108, 0.46053527, 0.46871002, 0.32472131, 0.8106264, 0.93734508, 0.50881284, 0.47163883],
[0.51584772, 0.75746362, 0.54070642, 0.46964418, 0.16099654, 0.28888493, 0.87102677, 0.78683819, 0.92952629, 0.01808714, 0.49057107, 0.53086522, 0.03365724, 0.72380516, 0.69552564],
[0.2486874, 0.80772801, 0.71071957, 0.62115299, 0.34464496, 0.65273483, 0.51526695, 0.86960604, 0.35334841, 0.19872339, 0.59782921, 0.9076014, 0.62331728, 0.80105268, 0.81803121],
[0.73295651, 0.03718759, 0.19089695, 0.17250281, 0.83308635, 0.49478076, 0.78950798, 0.70694848, 0.8246838, 0.2408467, 0.57902663, 0.5776183, 0.70156551, 0.60343239, 0.26797012],
[0.93397292, 0.24993491, 0.38589828, 0.53411427, 0.80735088, 0.59396222, 0.78410159, 0.02210724, 0.36938248, 0.48856187, 0.25885675, 0.51478472, 0.38323223, 0.93483245, 0.5278117, ],
[0.92647018, 0.65173566, 0.93042034, 0.45536654, 0.11087123, 0.95793742, 0.10449456, 0.00994633, 0.45435149, 0.1560623, 0.95616025, 0.02818874, 0.93699319, 0.05904651, 0.26720624],
[0.98754218, 0.23147189, 0.74851453, 0.99138613, 0.35596539, 0.42521882, 0.9313783, 0.4589656, 0.8264721, 0.74497391, 0.1266161, 0.49408128, 0.66161822, 0.70914888, 0.50270141],
[5.41538225e-02, 1.48585021e-01, 1.42823795e-01, 5.18391921e-02, 2.44376266e-01, 3.64592703e-01, 4.58805738e-01, 6.34530785e-01, 4.71291844e-05, 6.69944889e-01, 8.60834516e-02, 2.64038655e-02, 7.26979305e-01, 2.98585685e-01, 7.27288103e-01],
[0.47736504, 0.14856354, 0.53232679, 0.08175642, 0.40394715, 0.83128722, 0.77530955, 0.23297161, 0.60899767, 0.13542966, 0.91816907, 0.9735848, 0.80945861, 0.38704194, 0.41832133],
[0.46385463, 0.03712152, 0.34295285, 0.92770956, 0.01578793, 0.89710183, 0.3876616, 0.31026901, 0.61385889, 0.77235037, 0.11182682, 0.89230274, 0.81432191, 0.76502007, 0.24453753],
[0.26014959, 0.07171677, 0.10600346, 0.33423277, 0.43895562, 0.95152668, 0.34771902, 0.5762678, 0.06888574, 0.35836251, 0.11985976, 0.87240663, 0.47504561, 0.73997711, 0.12543571],
[0.58692765, 0.16349972, 0.38272589, 0.11101517, 0.05152521, 0.21967367, 0.08353512, 0.91739459, 0.87417014, 0.8369208, 0.5019429, 0.92499628, 0.42421405, 0.21511599, 0.34331929],
[0.79131406, 0.86385854, 0.65651769, 0.30395871, 0.44538424, 0.63887687, 0.69919065, 0.40311601, 0.15856795, 0.24531276, 0.66382843, 0.82380756, 0.58915026, 0.49178452, 0.27363358],
[0.06164548, 0.28538775, 0.26417514, 0.97093013, 0.05328983, 0.23505105, 0.64322212, 0.15584001, 0.99838921, 0.32393999, 0.06597043, 0.04253201, 0.96011181, 0.64772181, 0.57325845],
[0.48967055, 0.29423546, 0.9660783, 0.69378867, 0.84083342, 0.71430695, 0.98801754, 0.87289694, 0.77678013, 0.74736133, 0.68398663, 0.46222109, 0.13298258, 0.41956661, 0.91677729],
[0.82866252, 0.85845207, 0.9815443, 0.31825541, 0.81971628, 0.59213125, 0.03726342, 0.77830176, 0.25838911, 0.80951617, 0.51393752, 0.80450235, 0.61833033, 0.46099634, 0.84182917],
[0.86813889, 0.03484487, 0.28036674, 0.94876037, 0.77910835, 0.45897993, 0.12323428, 0.6409947, 0.73642815, 0.55256346, 0.5624075, 0.94365872, 0.25034797, 0.23544492, 0.89781349],
[0.81601303, 0.82864021, 0.02312249, 0.75338475, 0.57007834, 0.41213263, 0.71978378, 0.16025205, 0.18239013, 0.93037745, 0.66429507, 0.0915567, 0.34797177, 0.05902354, 0.94456535],
[0.03654621, 0.43309998, 0.27649464, 0.78748632, 0.39135299, 0.36408323, 0.20541805, 0.22965936, 0.586909, 0.49436132, 0.16717467, 0.83147966, 0.19943844, 0.77709807, 0.51666557],
[0.07089073, 0.35301879, 0.87310177, 0.35124617, 0.98906433, 0.31582966, 0.30678718, 0.87581239, 0.57959415, 0.29594889, 0.18597422, 0.82258306, 0.85683342, 0.72858247, 0.32100066],
[0.11442869, 0.74939915, 0.73087364, 0.85698946, 0.94324916, 0.38179458, 0.61700002, 0.32764756, 0.40007236, 0.45801052, 0.76701058, 0.59603646, 0.8764864, 0.38530129, 0.09721154],
[0.85321362, 0.63889963, 0.28547568, 0.5883448, 0.3086684, 0.30789618, 0.37964855, 0.75876656, 0.95568724, 0.53956393, 0.95852592, 0.7400259, 0.23865271, 0.49396336, 0.47034203],
[0.50665892, 0.54931334, 0.28923876, 0.82574596, 0.12264865, 0.36723207, 0.27512329, 0.80415918, 0.22319407, 0.61716972, 0.88229105, 0.22447437, 0.71311534, 0.32772043, 0.65690548],
[0.38419129, 0.87911469, 0.39337355, 0.26150391, 0.87117385, 0.37244772, 0.28051146, 0.50002798, 0.8392601, 0.53672049, 0.19648284, 0.42845839, 0.27242811, 0.17959537, 0.88155908],
[0.27161705, 0.27204543, 0.31690306, 0.19119609, 0.84150988, 0.73021128, 0.3347017, 0.94023458, 0.37122267, 0.73789884, 0.14026661, 0.47495193, 0.71211803, 0.5089603, 0.55261062],
[0.58353128, 0.70115095, 0.64575901, 0.38333749, 0.13027427, 0.1508311, 0.64046588, 0.91003418, 0.32369276, 0.03891093, 0.02756851, 0.66156781, 0.96627056, 0.98877764, 0.91755132],
[0.50174668, 0.93002602, 0.73319439, 0.51011552, 0.06573494, 0.47306366, 0.57577709, 0.87092607, 0.99863104, 0.05329548, 0.03998092, 0.75447714, 0.96944527, 0.20951602, 0.39608867],
[0.35957079, 0.36944003, 0.3522621, 0.74755989, 0.30680238, 0.48237368, 0.57953709, 0.53442396, 0.92434142, 0.73492701, 0.23609777, 0.24344831, 0.28499476, 0.66044347, 0.85146711],
[0.67653632, 0.96993844, 0.09507005, 0.43829738, 0.66325213, 0.32400663, 0.41987306, 0.14386866, 0.99855599, 0.1765089, 0.45246587, 0.34398454, 0.85588093, 0.51812306, 0.9859589, ],
[9.38036658e-01, 4.05160481e-01, 5.99493314e-01, 8.92129475e-04, 1.71369120e-01, 2.67688984e-01, 2.62909538e-01, 3.65739034e-01, 1.84038891e-01, 6.05919580e-01, 4.57779823e-01, 9.47745094e-01, 1.80765298e-01, 6.90414007e-01, 8.13009587e-01],
[0.29006021, 0.83097929, 0.70830103, 0.29566137, 0.50939746, 0.4520919, 0.40311086, 0.59771864, 0.05819841, 0.32221178, 0.23415609, 0.48022314, 0.69060139, 0.07494378, 0.25356048],
[0.06592471, 0.1711945, 0.12302325, 0.03382358, 0.06716872, 0.67673438, 0.23598813, 0.47888445, 0.50767577, 0.01974602, 0.25483723, 0.99780243, 0.91212514, 0.35022052, 0.55635501],
[0.01852245, 0.99358311, 0.07082277, 0.3684959, 0.08999278, 0.95256405, 0.56136048, 0.57898186, 0.15610655, 0.11075968, 0.40092856, 0.29817516, 0.86190741, 0.10824085, 0.01938995],
[0.18884355, 0.4370917, 0.59668846, 0.59353179, 0.60153332, 0.82905647, 0.11314567, 0.12016627, 0.11992477, 0.22933418, 0.67068034, 0.26108968, 0.28377356, 0.27069217, 0.15369935],
[0.81473416, 0.43563647, 0.97971084, 0.07139221, 0.20060753, 0.13765658, 0.18884694, 0.27306173, 0.26223773, 0.84466746, 0.58753312, 0.97400398, 0.12580954, 0.63899443, 0.55245007],
[0.09920371, 0.44449103, 0.51382409, 0.01650504, 0.94035061, 0.22528694, 0.0370532, 0.46343415, 0.63176573, 0.82459713, 0.69767468, 0.63462817, 0.69951461, 0.49297231, 0.13924186],
[0.182928, 0.35377897, 0.6810879, 0.92709244, 0.63293265, 0.61762203, 0.93029598, 0.93104051, 0.15896521, 0.13441404, 0.42880622, 0.53503705, 0.87517637, 0.51451114, 0.35880925],
[0.13806742, 0.94372984, 0.39415335, 0.70962012, 0.36660026, 0.23569752, 0.71385347, 0.69712394, 0.82799552, 0.21201313, 0.2308793, 0.64042653, 0.15657698, 0.43215713, 0.67008979],
[0.80874168, 0.74862508, 0.02841058, 0.86491928, 0.2357665, 0.08097649, 0.25021752, 0.10530746, 0.24636018, 0.09814915, 0.64252061, 0.38261429, 0.20899012, 0.51627095, 0.19799734],
[0.7695709, 0.3344798, 0.1377126, 0.78875345, 0.21736218, 0.70249248, 0.80379623, 0.47566725, 0.8141711, 0.47236393, 0.43876934, 0.5841235, 0.02977462, 0.03565612, 0.39177575],
[0.88439056, 0.44441623, 0.68876386, 0.82874527, 0.74529948, 0.23510302, 0.82296779, 0.0912073, 0.97620601, 0.83033185, 0.77419735, 0.65947205, 0.01839246, 0.93563031, 0.66161265],
[0.7100493, 0.01777483, 0.17559461, 0.12476507, 0.82772198, 0.51597489, 0.19890218, 0.60864872, 0.27460697, 0.12983167, 0.71522872, 0.9004975, 0.69208513, 0.18076654, 0.96075785]]
)
# Vector s of the linear term s^T w
s = np.array([751.34736517, 733.67929468, 750.75105464, 744.55563618,
       775.88316266, 742.39858199, 734.17353075, 746.67685882,
       756.79907005, 751.75326741, 732.18001409, 757.18726975,
       756.02923567, 742.21839022, 745.22952588])
```
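Before optimizing, it can be worth confirming that `eval_g` really is the gradient of `eval_f`. SciPy's `scipy.optimize.check_grad` returns the 2-norm of the difference between the analytic gradient and a finite-difference approximation. The sketch below repeats the two functions verbatim so that it runs on its own; the small random problem (`X_chk`, `y_chk`, `s_chk`, `w_chk`) is a hypothetical stand-in, and the notebook's own `X`, `y`, `s` could be passed instead.

```python
import numpy as np
from scipy.optimize import check_grad


# Copied verbatim from the notebook so that this snippet is self-contained.
def eval_f(w, X, y, s, lam):
    return s.dot(w) + lam*(w.dot(w)) - (y * np.log(X.dot(w.reshape((-1,1)))).reshape(-1)).sum()

def eval_g(w, X, y, s, lam):
    return s + 2 * lam * w - (X * (y / X.dot(w.reshape((-1,1))).reshape(-1)).reshape((-1,1))).sum(axis=0)


# Hypothetical small problem, only for the check.
rng = np.random.default_rng(123)
X_chk = rng.uniform(size=(40, 6))
y_chk = rng.integers(1, 6, size=40).astype(float)
s_chk = rng.uniform(1.0, 10.0, size=6)
w_chk = rng.gamma(1.0, 1.0, size=6) + 0.1  # strictly positive, so X_chk @ w_chk > 0
lam_chk = 1e-3

# Extra positional arguments after x0 are forwarded to both eval_f and eval_g.
err = check_grad(eval_f, eval_g, w_chk, X_chk, y_chk, s_chk, lam_chk)
rel_err = err / np.linalg.norm(eval_g(w_chk, X_chk, y_chk, s_chk, lam_chk))
print(f"absolute error: {err:.2e}, relative error: {rel_err:.2e}")
```

A relative error well below 1e-4 indicates that the analytic gradient agrees with the objective; a wrong sign or a missing factor in `eval_g` would typically show up as an error of order one.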

Random starting point:

In [3]:
```python
np.random.seed(1)
# Gamma draws are positive, so X @ x0 > 0 and the log in eval_f is defined
x0 = np.random.gamma(1, 1, size=X.shape[1])
reg_param = 1e-3
eval_f(x0, X, y, s, reg_param)
```

```
5953.863287134411
```

First try: L-BFGS-B from SciPy:

In [4]:
```python
from scipy.optimize import minimize

# With bounds and no explicit method, SciPy would also pick L-BFGS-B
minimize(x0=x0, fun=eval_f, jac=eval_g, args=(X, y, s, reg_param),
         method="L-BFGS-B", bounds=[(0, None)] * x0.shape[0])
```

  
```
      fun: 3985.4767097696954
 hess_inv: <15x15 LbfgsInvHessProduct with dtype=float64>
      jac: array([711.17565903, 694.41870389, 712.70413774, 708.73256482,
       735.67934982, 702.42271243, 693.08650993, 705.93962614,
       715.50899896, 714.41274627, 695.85100937, 712.49948523,
       707.15182312, 704.47613935, 703.06070253])
  message: b'CONVERGENCE: REL_REDUCTION_OF_F_<=_FACTR*EPSMCH'
     nfev: 6
      nit: 2
   status: 0
  success: True
        x: array([3.62372219e-01, 8.55638622e-01, 7.68127842e-05, 2.41766511e-01,
       1.06581405e-01, 6.50623494e-02, 1.38416251e-01, 2.84721343e-01,
       3.39436578e-01, 5.19752570e-01, 3.64879473e-01, 7.76230845e-01,
       1.53599841e-01, 1.41340915e+00, 1.86486873e-02])
```

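Despite the success message, L-BFGS-B stopped after only 2 iterations without reaching the optimum: every coordinate of `x` is strictly positive, yet every gradient component is around 700 rather than zero. L-BFGS-B-NS is able to find the optimum: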

In [5]:
```python
from lbfgsb_ns import minimize_lbfgsb_ns

minimize_lbfgsb_ns(x0=x0, fun=eval_f, jac=eval_g, args=(X, y, s, reg_param),
                   bounds=[(0, None)] * x0.shape[0])
```

  
```
      fun: 638.7663441540454
 hess_inv: <15x15 LbfgsInvHessProduct with dtype=float64>
      jac: array([25.42669025, 18.23529538, 47.3814997 , 95.84634253, 79.92742222,
       67.72082891,  0.        , 39.24281044,  0.        , 63.15766363,
       94.66112649,  0.        ,  0.        , 32.2743374 ,  0.        ])
  message: b'ABNORMAL_TERMINATION_IN_LNSRCH'
     nfev: 102
      nit: 64
   status: 2
  success: False
        x: array([0.        , 0.        , 0.        , 0.        , 0.        ,
       0.        , 0.01860117, 0.        , 0.02179748, 0.        ,
       0.        , 0.10832202, 0.14202509, 0.        , 0.00597394])
```
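Whether each solver stopped at the optimum can be checked from the printed `x` and `jac` alone, using the first-order (KKT) conditions for the constraint $\mathbf{w} \ge 0$: every coordinate with $w_j > 0$ must have a zero gradient component, and every coordinate with $w_j = 0$ must have a non-negative one. Because $f$ is convex, these conditions are also sufficient for a global minimum. The helper `kkt_violation` is a hypothetical sketch, and the arrays are copied from the two outputs above, so the check is only as precise as their printed rounding.

```python
import numpy as np


def kkt_violation(w, g):
    """Largest violation of the first-order optimality conditions of
    min f(w) s.t. w >= 0, given a point w and the gradient g at w.

    Coordinates with w_i > 0 need g_i == 0; coordinates at the bound
    (w_i == 0) need g_i >= 0. A return value of 0 means all conditions hold.
    """
    w = np.asarray(w, dtype=float)
    g = np.asarray(g, dtype=float)
    violation = np.where(w > 0, np.abs(g), np.maximum(-g, 0.0))
    return violation.max()


# `x` and `jac` as printed by the two solvers above (rounded values).
x_lbfgsb = np.array([3.62372219e-01, 8.55638622e-01, 7.68127842e-05, 2.41766511e-01,
                     1.06581405e-01, 6.50623494e-02, 1.38416251e-01, 2.84721343e-01,
                     3.39436578e-01, 5.19752570e-01, 3.64879473e-01, 7.76230845e-01,
                     1.53599841e-01, 1.41340915e+00, 1.86486873e-02])
jac_lbfgsb = np.array([711.17565903, 694.41870389, 712.70413774, 708.73256482,
                       735.67934982, 702.42271243, 693.08650993, 705.93962614,
                       715.50899896, 714.41274627, 695.85100937, 712.49948523,
                       707.15182312, 704.47613935, 703.06070253])
x_ns = np.array([0., 0., 0., 0., 0.,
                 0., 0.01860117, 0., 0.02179748, 0.,
                 0., 0.10832202, 0.14202509, 0., 0.00597394])
jac_ns = np.array([25.42669025, 18.23529538, 47.3814997, 95.84634253, 79.92742222,
                   67.72082891, 0., 39.24281044, 0., 63.15766363,
                   94.66112649, 0., 0., 32.2743374, 0.])

print("L-BFGS-B:   ", kkt_violation(x_lbfgsb, jac_lbfgsb))
print("L-BFGS-B-NS:", kkt_violation(x_ns, jac_ns))
```

On these values the L-BFGS-B point violates the conditions by about 736, while the L-BFGS-B-NS point satisfies them, so its `success: False` status should not be read as a failure to reach the optimum here.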