In [1]:
from math import fabs
import numpy as np
from numpy.linalg import norm
from sklearn.datasets import make_regression

In [2]:
SAMPLES = 100
DIVERGENCE_VALUE = 10
MAX_ITERATIONS = 10000
STOP_THRESHOLD = 0.0001
DIMENSIONS = 7
LAMBDA = 0 #0.01
MU = 0 #0.0001

In [3]:
np.random.seed(0)

In [4]:
(X, y, coef) = make_regression(n_samples = SAMPLES, 
                               n_features = DIMENSIONS - 1, 
                               n_informative = DIMENSIONS - 1, 
                               effective_rank = 2,
                               n_targets = 1, 
                               coef = True,
                               bias = 3,
                               tail_strength = 0)

In [5]:
coef

array([47.93845494, 60.57119573, 63.74622774, 72.78881584, 81.19385617,
       11.56618719])

$\displaystyle \min_{\beta}\frac{1}{n}||y-X\beta||^2_{2}+\lambda||\beta||_1 + \mu||\beta||_2^2$

$L=\frac{1}{n}\displaystyle\sum_{i=1}^{n}\left(y_i - \left(\beta_0 + \displaystyle\sum_{j=1}^{6} \beta_j x_{ij}\right)\right)^2 + \lambda\displaystyle\sum_{j=0}^{6}|\beta_j| + \mu\displaystyle\sum_{j=0}^{6} \beta_j^2$

$\frac{\partial L}{\partial \beta_0} = -\frac{2}{n}\displaystyle\sum_{i=1}^{n}\left(y_i - \beta_0 -\displaystyle\sum_{j=1}^{6} \beta_j x_{ij}\right) + \lambda\frac{\beta_0}{|\beta_0|} + 2\mu\beta_0$

And for $\beta_{k\neq 0}$:

$\frac{\partial L}{\partial \beta_{k}} = -\frac{2}{n}\displaystyle\sum_{i=1}^{n}x_{ik}\left(y_i - \beta_0 -\displaystyle\sum_{j=1}^{6} \beta_j x_{ij}\right) + \lambda\frac{\beta_k}{|\beta_k|} + 2\mu\beta_k$

In [6]:
np.random.seed(0)

In [7]:
(X, y, coef) = make_regression(n_samples = SAMPLES, 
                               n_features = DIMENSIONS - 1, 
                               n_informative = DIMENSIONS - 1, 
                               effective_rank = 2,
                               n_targets = 1, 
                               coef = True,
                               bias = 3,
                               tail_strength = 0)

In [8]:
X = np.hstack((np.ones((SAMPLES, 1)), X))

In [9]:
# Auxiliary function to calculate gradient
# l = lambda
# m = mu
def calculate_gradient(x, y, l, m, current_params):
    db = np.zeros(DIMENSIONS)    
    
    # Common term
    common = (y - 
              current_params[0] - 
              np.sum(np.multiply(np.tile(current_params[1:], (SAMPLES, 1)), x[:, 1:]), axis=1))
    
    # Function for the regularisation factor
    def regularisation(param, l, m):
        return l * param / fabs(param) + 2 * m * param
    
    # db_0
    db[0] = - 2 / float(SAMPLES) * np.sum(common) + regularisation(current_params[0], l, m)

    # db_k, k != 0
    for k in range(1, DIMENSIONS):
        db[k] = -2 / float(SAMPLES) * np.sum(np.multiply(x[:, k], common)) + regularisation(current_params[k], l, m)
    return db

In [10]:
# Auxiliary function for backtracking line search
# Not using exact because I would have to do second partial derivatives
def line_search(current_params, gradient, beta, l, m):
    current_params = np.array(current_params)
    
    def _loss(params, l, m):
        y_pred = np.matmul(X, np.array(params))
        L = np.mean(np.power(y - y_pred, 2), axis = 0)
        L = L + l * np.sum(np.fabs(params)) + m * np.sum(np.power(params, 2))
        return L
        
    t = 1.0
    while _loss(current_params - t * gradient, l, m) > _loss(current_params, l, m) - \
                                                        t / 2.0  * np.matmul(gradient.T, gradient):
        #print((_loss(current_params - t * gradient, l, m), 
        #       _loss(current_params, l, m) - t / 2.0 * np.matmul(gradient.T, gradient)), 
        #      _loss(current_params - t * gradient, l, m) - 
        #       _loss(current_params, l, m) - t / 2.0 * np.matmul(gradient.T, gradient),
        #      t)
        t = t * beta
        
    return t

In [11]:
calculate_gradient(X, y, 1, 0.01, [1, 2, 3, 4, 5, 6, 7])

array([-2.90698589,  0.86882003,  0.87470091,  1.01822005,  1.388318  ,
        0.79615466,  0.80103953])

In [12]:
# We shouldn't initialise to zero to prevent divisions by zero
# when computing gradients
current_params = [0.1 for i in range(DIMENSIONS)]

In [13]:
# Coordinate gradient descent implementation
# Cycling through the dimensions
w = [np.array(current_params[:])]
it = 0
while it == 0 or (np.sum(np.abs(np.array(current_params) - np.array(w[-1]))) > STOP_THRESHOLD and \
                  np.all(np.array(current_params) - np.array(w[-1]) < DIVERGENCE_VALUE) and \
                  it < MAX_ITERATIONS):
    w.append(current_params)
    # Select the gradient for the next dimension
    d = it % DIMENSIONS
    g = np.zeros(DIMENSIONS)
    g[d] = calculate_gradient(X, y, LAMBDA, MU, current_params)[d]
    # Linear search in the direction of the selected dimension
    t = line_search(current_params, g, 0.99, LAMBDA, MU)
    #print("gradient: " + str(g))
    #print("linear search t: " + str(t))
    # Update parameters
    current_params = list(np.array(w[-1]) - t * np.array(g))[:]
    print("current_parameters: " + str(current_params))
    #print('---------------')
    # Wrapping up the iteration
    it = it + 1
w.append(current_params)

current_parameters: [2.936072683260908, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1]
current_parameters: [2.936072683260908, 0.24971000603738006, 0.1, 0.1, 0.1, 0.1, 0.1]
current_parameters: [2.936072683260908, 0.24971000603738006, 0.298576624765687, 0.1, 0.1, 0.1, 0.1]
current_parameters: [2.936072683260908, 0.24971000603738006, 0.298576624765687, 0.20812639583164988, 0.1, 0.1, 0.1]
current_parameters: [2.936072683260908, 0.24971000603738006, 0.298576624765687, 0.20812639583164988, -0.21666586794965112, 0.1, 0.1]
current_parameters: [2.936072683260908, 0.24971000603738006, 0.298576624765687, 0.20812639583164988, -0.21666586794965112, 0.46747142720226553, 0.1]
current_parameters: [2.936072683260908, 0.24971000603738006, 0.298576624765687, 0.20812639583164988, -0.21666586794965112, 0.46747142720226553, 0.5028249508097417]
current_parameters: [2.937096676104212, 0.24971000603738006, 0.298576624765687, 0.20812639583164988, -0.21666586794965112, 0.46747142720226553, 0.5028249508097417]
current_parameters

current_parameters: [2.9398682544225894, 2.6509498353484133, 3.406897197713794, 2.2279299950404954, -5.449688938248971, 6.630679653862764, 7.271596735759121]
current_parameters: [2.9398682544225894, 2.756694605174121, 3.406897197713794, 2.2279299950404954, -5.449688938248971, 6.630679653862764, 7.271596735759121]
current_parameters: [2.9398682544225894, 2.756694605174121, 3.5400973338902637, 2.2279299950404954, -5.449688938248971, 6.630679653862764, 7.271596735759121]
current_parameters: [2.9398682544225894, 2.756694605174121, 3.5400973338902637, 2.3311313718876012, -5.449688938248971, 6.630679653862764, 7.271596735759121]
current_parameters: [2.9398682544225894, 2.756694605174121, 3.5400973338902637, 2.3311313718876012, -5.687905990339014, 6.630679653862764, 7.271596735759121]
current_parameters: [2.9398682544225894, 2.756694605174121, 3.5400973338902637, 2.3311313718876012, -5.687905990339014, 6.9158948038093655, 7.271596735759121]
current_parameters: [2.9398682544225894, 2.756694605

current_parameters: [2.943867006936724, 4.527538038520713, 5.715725753063636, 4.294871169426523, -9.62928880235924, 11.711216996735462, 12.854868198102265]
current_parameters: [2.943867006936724, 4.527538038520713, 5.715725753063636, 4.294871169426523, -9.808914926071663, 11.711216996735462, 12.854868198102265]
current_parameters: [2.943867006936724, 4.527538038520713, 5.715725753063636, 4.294871169426523, -9.808914926071663, 11.933726469185787, 12.854868198102265]
current_parameters: [2.943867006936724, 4.527538038520713, 5.715725753063636, 4.294871169426523, -9.808914926071663, 11.933726469185787, 13.099014561139974]
current_parameters: [2.94408135321115, 4.527538038520713, 5.715725753063636, 4.294871169426523, -9.808914926071663, 11.933726469185787, 13.099014561139974]
current_parameters: [2.94408135321115, 4.600913872213161, 5.715725753063636, 4.294871169426523, -9.808914926071663, 11.933726469185787, 13.099014561139974]
current_parameters: [2.94408135321115, 4.600913872213161, 5.8

current_parameters: [2.9482387826025174, 5.723982633341671, 7.118065990719644, 5.951517707075572, -12.783211381318116, 15.684479119498514, 17.201399865213787]
current_parameters: [2.9482387826025174, 5.776564485863306, 7.118065990719644, 5.951517707075572, -12.783211381318116, 15.684479119498514, 17.201399865213787]
current_parameters: [2.9482387826025174, 5.776564485863306, 7.178268018067345, 5.951517707075572, -12.783211381318116, 15.684479119498514, 17.201399865213787]
current_parameters: [2.9482387826025174, 5.776564485863306, 7.178268018067345, 6.032610599835663, -12.783211381318116, 15.684479119498514, 17.201399865213787]
current_parameters: [2.9482387826025174, 5.776564485863306, 7.178268018067345, 6.032610599835663, -12.918854379111997, 15.684479119498514, 17.201399865213787]
current_parameters: [2.9482387826025174, 5.776564485863306, 7.178268018067345, 6.032610599835663, -12.918854379111997, 15.85901739750873, 17.201399865213787]
current_parameters: [2.9482387826025174, 5.7765

current_parameters: [2.9527678024938426, 6.68998446777528, 8.203461907897985, 7.593714989646743, -15.26799422233823, 18.947124120638453, 20.73177823058075]
current_parameters: [2.9527678024938426, 6.68998446777528, 8.203461907897985, 7.593714989646743, -15.369036988501573, 18.947124120638453, 20.73177823058075]
current_parameters: [2.9527678024938426, 6.68998446777528, 8.203461907897985, 7.593714989646743, -15.369036988501573, 19.083246847931296, 20.73177823058075]
current_parameters: [2.9527678024938426, 6.68998446777528, 8.203461907897985, 7.593714989646743, -15.369036988501573, 19.083246847931296, 20.87777383141246]
current_parameters: [2.95297702305476, 6.68998446777528, 8.203461907897985, 7.593714989646743, -15.369036988501573, 19.083246847931296, 20.87777383141246]
current_parameters: [2.95297702305476, 6.725485439841632, 8.203461907897985, 7.593714989646743, -15.369036988501573, 19.083246847931296, 20.87777383141246]
current_parameters: [2.95297702305476, 6.725485439841632, 8.24

current_parameters: [2.9567944861811055, 7.264955495859678, 8.82954419567016, 8.78090524614389, -17.04208148487047, 21.391772840100973, 23.329277751299593]
current_parameters: [2.9567944861811055, 7.289969835343786, 8.82954419567016, 8.78090524614389, -17.04208148487047, 21.391772840100973, 23.329277751299593]
current_parameters: [2.9567944861811055, 7.289969835343786, 8.856517423655879, 8.78090524614389, -17.04208148487047, 21.391772840100973, 23.329277751299593]
current_parameters: [2.9567944861811055, 7.289969835343786, 8.856517423655879, 8.837789267891015, -17.04208148487047, 21.391772840100973, 23.329277751299593]
current_parameters: [2.9567944861811055, 7.289969835343786, 8.856517423655879, 8.837789267891015, -17.118342169937836, 21.391772840100973, 23.329277751299593]
current_parameters: [2.9567944861811055, 7.289969835343786, 8.856517423655879, 8.837789267891015, -17.118342169937836, 21.49992545337272, 23.329277751299593]
current_parameters: [2.9567944861811055, 7.2899698353437

current_parameters: [2.9600965311384067, 7.6496385403415506, 9.243656585527583, 9.72500601591532, -18.26242082288257, 23.164511499554614, 25.167229329521902]
current_parameters: [2.9600965311384067, 7.66731539720459, 9.243656585527583, 9.72500601591532, -18.26242082288257, 23.164511499554614, 25.167229329521902]
current_parameters: [2.9600965311384067, 7.66731539720459, 9.26272782824143, 9.72500601591532, -18.26242082288257, 23.164511499554614, 25.167229329521902]
current_parameters: [2.9600965311384067, 7.66731539720459, 9.26272782824143, 9.772691359613102, -18.26242082288257, 23.164511499554614, 25.167229329521902]
current_parameters: [2.9600965311384067, 7.66731539720459, 9.26272782824143, 9.772691359613102, -18.3214971560778, 23.164511499554614, 25.167229329521902]
current_parameters: [2.9600965311384067, 7.66731539720459, 9.26272782824143, 9.772691359613102, -18.3214971560778, 23.25299521990826, 25.167229329521902]
current_parameters: [2.9600965311384067, 7.66731539720459, 9.26272

current_parameters: [2.963382109600538, 7.942962011382065, 9.563927568965497, 10.591700174116982, -19.296984222884202, 24.76417058583392, 26.77189194430711]
current_parameters: [2.963382109600538, 7.954637877846569, 9.563927568965497, 10.591700174116982, -19.296984222884202, 24.76417058583392, 26.77189194430711]
current_parameters: [2.963382109600538, 7.954637877846569, 9.576979522481377, 10.591700174116982, -19.296984222884202, 24.76417058583392, 26.77189194430711]
current_parameters: [2.963382109600538, 7.954637877846569, 9.576979522481377, 10.63046860107409, -19.296984222884202, 24.76417058583392, 26.77189194430711]
current_parameters: [2.963382109600538, 7.954637877846569, 9.576979522481377, 10.63046860107409, -19.34128540317004, 24.76417058583392, 26.77189194430711]
current_parameters: [2.963382109600538, 7.954637877846569, 9.576979522481377, 10.63046860107409, -19.34128540317004, 24.835524251117914, 26.77189194430711]
current_parameters: [2.963382109600538, 7.954637877846569, 9.5

current_parameters: [2.96600240223458, 8.125282714719841, 9.774684333529976, 11.26149172267866, -20.037107923156718, 25.94160510769189, 27.903127893126427]
current_parameters: [2.96600240223458, 8.125282714719841, 9.774684333529976, 11.26149172267866, -20.037107923156718, 26.000864588150648, 27.903127893126427]
current_parameters: [2.96600240223458, 8.125282714719841, 9.774684333529976, 11.26149172267866, -20.037107923156718, 26.000864588150648, 27.95858681502571]
current_parameters: [2.966139042819696, 8.125282714719841, 9.774684333529976, 11.26149172267866, -20.037107923156718, 26.000864588150648, 27.95858681502571]
current_parameters: [2.966139042819696, 8.132788713848312, 9.774684333529976, 11.26149172267866, -20.037107923156718, 26.000864588150648, 27.95858681502571]
current_parameters: [2.966139042819696, 8.132788713848312, 9.783854765608657, 11.26149172267866, -20.037107923156718, 26.000864588150648, 27.95858681502571]
current_parameters: [2.966139042819696, 8.132788713848312, 9

current_parameters: [2.9684474849638516, 8.234126043686386, 9.916708284095735, 11.778824435054734, -20.569010935164222, 26.97278615158566, 28.842539633488716]
current_parameters: [2.9684474849638516, 8.238689684845488, 9.916708284095735, 11.778824435054734, -20.569010935164222, 26.97278615158566, 28.842539633488716]
current_parameters: [2.9684474849638516, 8.238689684845488, 9.923325822579722, 11.778824435054734, -20.569010935164222, 26.97278615158566, 28.842539633488716]
current_parameters: [2.9684474849638516, 8.238689684845488, 9.923325822579722, 11.804591828353217, -20.569010935164222, 26.97278615158566, 28.842539633488716]
current_parameters: [2.9684474849638516, 8.238689684845488, 9.923325822579722, 11.804591828353217, -20.594490536376615, 26.97278615158566, 28.842539633488716]
current_parameters: [2.9684474849638516, 8.238689684845488, 9.923325822579722, 11.804591828353217, -20.594490536376615, 27.021942445890947, 28.842539633488716]
current_parameters: [2.9684474849638516, 8.23

current_parameters: [2.970596351399778, 8.298877669485364, 10.02443230629185, 12.220843127064937, -20.990822090612166, 27.83216573851965, 29.573817948549948]
current_parameters: [2.970596351399778, 8.301123738146481, 10.02443230629185, 12.220843127064937, -20.990822090612166, 27.83216573851965, 29.573817948549948]
current_parameters: [2.970596351399778, 8.301123738146481, 10.029169925674733, 12.220843127064937, -20.990822090612166, 27.83216573851965, 29.573817948549948]
current_parameters: [2.970596351399778, 8.301123738146481, 10.029169925674733, 12.241493208806224, -20.990822090612166, 27.83216573851965, 29.573817948549948]
current_parameters: [2.970596351399778, 8.301123738146481, 10.029169925674733, 12.241493208806224, -21.0096686065636, 27.83216573851965, 29.573817948549948]
current_parameters: [2.970596351399778, 8.301123738146481, 10.029169925674733, 12.241493208806224, -21.0096686065636, 27.873366405145834, 29.573817948549948]
current_parameters: [2.970596351399778, 8.301123738

In [14]:
y

array([ 3.23107784e+00,  5.57248637e+00,  3.98049002e+00,  4.65304810e+00,
        9.73648501e+00,  9.03136498e-03, -1.17493866e+00,  6.77555529e+00,
        3.83997861e+00, -1.29486692e+00,  1.22982188e+00, -2.71529988e-01,
        3.76706916e+00, -2.83126721e+00,  1.25123898e+00,  4.64621926e+00,
        2.97776653e-01,  7.02508335e+00,  6.28964156e-01,  7.42215544e-01,
       -6.06038611e+00,  9.67766369e+00, -3.79968777e+00, -2.81180075e+00,
       -1.30220493e+00,  1.44606617e+01,  5.81662544e+00, -1.12786472e+00,
        1.69793584e+00,  3.23562676e+00, -3.81989179e+00,  7.24090437e+00,
        3.16259494e+00,  1.38047507e+01, -5.28828147e-01,  1.54552823e+00,
        8.80122856e+00,  2.39490138e+00,  4.20068365e+00,  3.22401915e+00,
       -9.41208139e-01,  4.76850397e+00,  1.13055370e-01, -1.74108615e+00,
        3.14439394e+00,  4.53824982e+00,  1.49203288e+00,  1.03037261e+01,
        4.91960921e+00, -4.32038498e+00, -4.16031400e+00,  5.79678577e+00,
        2.26707086e+00,  

In [15]:
np.matmul(X, [3] + list(coef))

array([ 3.23107784e+00,  5.57248637e+00,  3.98049002e+00,  4.65304810e+00,
        9.73648501e+00,  9.03136498e-03, -1.17493866e+00,  6.77555529e+00,
        3.83997861e+00, -1.29486692e+00,  1.22982188e+00, -2.71529988e-01,
        3.76706916e+00, -2.83126721e+00,  1.25123898e+00,  4.64621926e+00,
        2.97776653e-01,  7.02508335e+00,  6.28964156e-01,  7.42215544e-01,
       -6.06038611e+00,  9.67766369e+00, -3.79968777e+00, -2.81180075e+00,
       -1.30220493e+00,  1.44606617e+01,  5.81662544e+00, -1.12786472e+00,
        1.69793584e+00,  3.23562676e+00, -3.81989179e+00,  7.24090437e+00,
        3.16259494e+00,  1.38047507e+01, -5.28828147e-01,  1.54552823e+00,
        8.80122856e+00,  2.39490138e+00,  4.20068365e+00,  3.22401915e+00,
       -9.41208139e-01,  4.76850397e+00,  1.13055370e-01, -1.74108615e+00,
        3.14439394e+00,  4.53824982e+00,  1.49203288e+00,  1.03037261e+01,
        4.91960921e+00, -4.32038498e+00, -4.16031400e+00,  5.79678577e+00,
        2.26707086e+00,  

In [16]:
np.matmul(X, current_params)

array([ 1.93280813,  4.63509131,  3.84946986,  3.84487046,  8.63330861,
        0.70191142, -1.28581677,  7.18868238,  3.42867866, -0.91314065,
        1.38585565,  0.34642245,  3.24293602, -2.92630287,  1.87278197,
        4.97797226,  1.17253116,  7.14516251, -0.42423515,  0.61646927,
       -6.02008149,  9.77318631, -2.19428692, -1.74949559,  0.08077515,
       13.20662109,  4.55325935, -1.34023956,  1.82911837,  3.58843994,
       -3.19961529,  6.89576098,  3.20670932, 12.98125153, -0.37887961,
        0.94006564,  8.1642303 ,  3.90272761,  3.75951777,  2.60852234,
       -1.14675286,  5.12559241,  1.64787458, -1.26647071,  3.63108032,
        4.79383581,  2.26785401,  9.35836887,  4.06465073, -3.37222673,
       -3.06385791,  4.4322517 ,  3.03320442,  8.10972681,  3.66723534,
       12.73849473,  3.26733567,  4.25860604, -1.15858248, -3.87646139,
       -1.08960457,  2.12069494,  3.08755462,  2.3337396 , 13.52563947,
        1.46379482,  6.16829049,  3.92382318,  3.19720827,  2.75