## 1.1 Dataset construction

In [1]:
import numpy as np

In [29]:
# create a mxd design matrix X
X = np.random.rand(150,75)

In [39]:
# create weight vector theta
theta_1 = np.zeros(65)
t = [-10,10]
theta_2 = np.random.choice(t,10)
theta = np.concatenate([theta_2, theta_1])

In [40]:
# create the response y
epsilon = 0.1 * np.random.randn(150)
y = np.dot(X, theta) + epsilon

In [41]:
X_train = X[:80,:]
y_train = y[:80]
X_validation = X[80:100,:]
y_validation = y[80:100]
X_test = X[100:,:]
y_test = y[100:]

## 1.2 Ridge Regression

In [58]:
import numpy
from scipy.optimize import minimize

(N,D) = X.shape

w = numpy.random.rand(D,1)
def ridge(X, y, Lambda):
  def ridge_obj(theta):
    return ((numpy.linalg.norm(numpy.dot(X,theta) - y))**2)/(2*N) + Lambda*(numpy.linalg.norm(theta))**2
  return ridge_obj

def compute_loss(X, y, theta):
  return ((numpy.linalg.norm(numpy.dot(X,theta) - y))**2)/(2*N)

def find_best_lambda(X_train, y_train, X_validation, y_validation, tol=10**-3 ):
    result = []
    lambda_list = []
    theta_list=[]
    zero_count_list=[]
    
    #generate list of lambda, loss_result, zero_count(tol <)
    for i in range(-10,4):
      Lambda = 10**i;
      w_opt = minimize(ridge(X_train, y_train, Lambda), w)
      
      theta_list.append(w_opt.x)
      #zero_count of w_opt(theta_imitate)
      
      w_opt_greater_than_tol_tuple = np.where( abs(np.asarray(w_opt.x[0:10]))<=tol )+np.where(abs(np.asarray(w_opt.x[10:-1]>tol)))
      zero_count = int(len(w_opt_greater_than_tol_tuple[0]))
      zero_count_list.append(zero_count)
      result.append(compute_loss(X_validation, y_validation, w_opt.x))
      lambda_list.append(Lambda)
    
    #select best lambda based on loss_result
    result_array = np.asarray(result)
    lambda_array = np.asarray(lambda_list)
    min_result_index = result_array.argmin()
    best_lambda = lambda_list[min_result_index]
    theta_opt = theta_list[min_result_index]
    return lambda_array, result_array, best_lambda, theta_opt, zero_count_list

## 2.Coordinate Descent for Lasso (a.k.a. The Shooting algorithm)

In [114]:
#initialize 
#lambda_max = 2 * (infinity_norm||X.T*y||)
lambda_max = 2*(np.max(abs(np.dot(X_train.T, y_train))))
lambda_max_int = int(lambda_max)
theta_opt = find_best_lambda(X_train, y_train, X_validation, y_validation)[3]

# define loss function
def loss_fuc(X, y, theta):
    loss = (np.linalg.norm(np.dot(X, theta)-y, ord=2)**2)/2*X_train.shape[0]
    return loss

def loss_obj_fuc(X, y, theta, lambda_reg):
    loss_obj = loss_fuc(X, y, theta) + lambda_reg*np.linalg.norm(theta, ord=1)
    return loss_obj

def Lasso_opt_slow (X_train, y_train, X_validation, y_validation, theta_opt, max_iter=10000000, tol=10**-8):
    loss_results = []
    lambda_reg_list = []
    w_list = []
    for k in range(-10,4):
        lambda_reg = 10**k
        previous_loss = 0
        current_loss = 0
        w = np.zeros(X.shape[1])
        num_iter = 0
        diff = 1
        while(num_iter<=max_iter and diff>=tol):
            previous_loss = loss_fuc(X_validation, y_validation, theta_opt)
            for j in range(X_train.shape[1]):
                a = 0
                c = 0
                for i in range(X_train.shape[0]):
                    a += 2*X_train[i,j]**2
                    c += 2*X_train[i,j]*(y_train[i]-np.dot(w.T,X_train[i])+w[j]*X_train[i,j])
                    w[j] = np.sign(c/a) * max(0,abs(c/a)-lambda_reg)
            current_loss = loss_fuc(X_validation, y_validation, w)
            diff = current_loss-previous_loss
            #print(diff)
            num_iter += 1
        w_list.append(w)
        loss_results.append(current_loss)
        lambda_reg_list.append(lambda_reg)
    
    result_arr = np.array(loss_results)
    min_result_index = result_arr.argmin()
    min_result = result_arr[min_result_index]
    best_lambda = lambda_reg_list[min_result_index]
    w_opt = w_list[min_result_index]
    return min_result, best_lambda, w_opt, w_list
    
                    
                    


In [None]:
Lasso_opt_slow (X_train, y_train, X_validation, y_validation, theta_opt, max_iter=100, tol=10**-8)

In [99]:
import matplotlib.pyplot as plt
%matplotlib inline

ridge_lambda = find_best_lambda(X_train, y_train, X_validation, y_validation)[0]
ridge_loss = find_best_lambda(X_train, y_train, X_validation, y_validation)[1]
lasso_lambda =  Lasso_opt_slow (X_train, y_train, X_validation, y_validation, theta_opt,max_iter=200, tol=10**-1)[2]
lasso_loss =  Lasso_opt_slow (X_train, y_train, X_validation, y_validation, theta_opt,max_iter=200, tol=10**-1)[3]

plt.figure(figsize=(20,10))
plt.xlim([0,1000])
plt.ylim([0,150])
plt.plot(ridge_lambda, ridge_loss, 'r--', lasso_lambda, lasso_loss, 'b--')

21509.258655846053
19342.71545806464
18674.74377852873
18545.767124774804
18659.646622622724
18807.012175660322
18856.73271529637
18759.738754466747
18523.220514849258
18180.509504734477
17770.16975194529
17325.43761518894
16870.759648268144
16422.07757020459
15988.606750896623
15574.883286525135
15182.531669998652
14811.581256268597
14461.342275276957
14130.92508769023
13819.502665740032
13526.40638344397
13251.125736007623
12993.262056409229
12752.468506685364
12528.394884232763
12320.646073823234
12128.756680800625
11952.18071168903
11790.29335232009
11642.401310601677
11507.758332768126
11385.583030007476
11275.07682194484
11175.440478569677
11085.888337808365
11005.659754662107
10934.02769826308
10870.304658329485
10813.84617541638
10764.052383407732
10720.36797417213
10682.280977888895
10649.320711050268
10621.055191952346
10597.088266173374
10577.056631014057
10560.626895682706
10547.4927729923
10537.372462312984
10530.006256057966
10525.154381308854
10522.59507352721
10522.1228

10813.84479049303
10764.050983425714
10720.366555531646
10682.279537782988
10649.319247119662
10621.053710302354
10597.08675938237
10577.055093265122
10560.62532267714
10547.491161649603
10537.370810436552
10530.004562049267
10525.15264395262
10522.593291843257
10522.121042766526
10523.545229098287
10526.688652167084
10531.386413688759
10537.484884033409
10544.84078808813
10553.320391690986
10562.798773984485
10573.159183244206
10584.292414399011
10596.096307133681
10608.475247386274
10621.339722204339
10634.60591343089
10648.195326196324
10662.034451398022
10676.054439981252
10690.190844777206
10704.383344330165
10718.575525817292
10732.714600012992
10746.751288960057
10760.639599030432
10774.336663953329
10787.802591129535
10801.000320958467
10813.895497254727
10826.456347277086
10838.653570121976
10850.460232386511
10861.851670131267
10872.805396285376
10883.30101274271
10893.320126497081
10902.846269257694
10911.864820072624
10920.362930566313
10928.32945246814
10935.754867171358
1

10787.493817860943
10800.684875645555
10813.573351463092
10826.127491550693
10838.318003155075
10850.117954973632
10861.502682318122
10872.449696463635
10882.938597813221
10892.95099254237
10902.4704123921
10911.482237287944
10919.973620483568
10927.93341595955
10935.35210784617
10942.221741678368
10948.53585733018
10954.28942351018
10959.47877373096
10964.101543690282
10968.156610022923
10971.643846849538
10974.563916561454
10976.9203229981
10978.71613066141
10979.955408854892
10980.641482240731
10980.782367693007
10980.384640559438
10979.455321940235
10978.002036623397
10976.034592130636
10973.562577883018
10970.595997157694
10967.145337812508
10963.221574760651
10958.836373321652
10954.000974914385
10948.727602657516
10943.02894964764
10936.919065710526
10930.40926132885
10923.514169751215
10916.245492847787
10908.618594935864
10900.647492281985
10892.346197573685
10883.72878240183
10874.809377972882
10865.602155093093
10856.118802866667
10846.37415851908
10836.384706698185
10826.16

10913.121669564176
10908.541356800139
10903.467576408599
10897.919385589483
10891.918574402816
10885.485888354706
10878.60175620304
10871.44137752533
10863.793836234594
10855.724843398124
10847.288381142682
10838.511559026012
10829.411986389297
10820.004362487161
10810.302745240177
10800.321185699231
10789.700419785131
10778.787346206676
10767.676972221207
10756.392979067392
10744.945156957498
10733.339214429132
10721.580509625861
10709.675350619254
10697.631302396745
10685.457100217078
10673.16242873372
10660.757678543392
10648.253724779972
10635.661741773874
10622.993054443175
10610.259021956896
10597.470947881358
10584.640011293941
10571.777214176487
10558.89334131552
10545.998929758694
10533.104245555432
10520.219266048894
10507.353666406894
10494.516809403529
10481.717737711771
10468.965168156628
10456.267487520114
10443.632749596103
10431.068673270645
10418.582641460782
10406.18170078517
10393.87256187008
10381.66160021433
10369.554857552503
10357.558043666184
10345.676538601565


3539.931407378903
3513.384794897904
3487.365783979582
3461.95799538809
3436.9744440552595
3412.505385542465
3388.680793288766
3365.487247902592
3342.9006750582425
3320.892982025599
3299.4271816419296
3278.5980478840465
3258.3819906359467
3238.74568477746
3219.665211404003
3201.1185621150844
3182.984507691279
3162.095225419921
3142.1887063789354
3124.212109860612
3107.0954628798263
3090.6001380189828
3074.5261231065915
3058.753984602367
3043.209891511569
3027.8455717829274
3012.4886118790655
2997.17306700674
2981.94804201559
2967.9015180003307
2955.7393645134994
2944.876441905697
2934.816661823914
2925.245197001909
2915.6609247645783
2905.708550176265
2895.3096820835704
2884.713238210915
2873.948058375089
2863.0390675241106
2852.0058438433853
2840.8646566157468
2829.6302080258556
2818.316616336574
2806.937883201722
21558.531238230673
18803.761577288016
17963.20715395859
17474.509155434644
16978.755540392634
16493.85893016155
16011.178877798682
15277.850265991803
14467.311049598655
13570

36136.71574381477
36136.71574381477
36136.71574381477
36136.71574381477
36136.71574381477
36136.71574381477
36136.71574381477
36136.71574381477
36136.71574381477
36136.71574381477
36136.71574381477
36136.71574381477
36136.71574381477
36136.71574381477
36136.71574381477
36136.71574381477
36136.71574381477
36136.71574381477
36136.71574381477
36136.71574381477
36136.71574381477
36136.71574381477
36136.71574381477
36136.71574381477
36136.71574381477
36136.71574381477
36136.71574381477
36136.71574381477
36136.71574381477
36136.71574381477
36136.71574381477
36136.71574381477
36136.71574381477
36136.71574381477
36136.71574381477
36136.71574381477
36136.71574381477
36136.71574381477
36136.71574381477
36136.71574381477
36136.71574381477
36136.71574381477
36136.71574381477
36136.71574381477
36136.71574381477
36136.71574381477
36136.71574381477
36136.71574381477
36136.71574381477
36136.71574381477
36136.71574381477
36136.71574381477
36136.71574381477
36136.71574381477
36136.71574381477
36136.7157

36136.71574381477
36136.71574381477
36136.71574381477
36136.71574381477
36136.71574381477
36136.71574381477
36136.71574381477
36136.71574381477
36136.71574381477
36136.71574381477
36136.71574381477
36136.71574381477
36136.71574381477
36136.71574381477
36136.71574381477
36136.71574381477
36136.71574381477
36136.71574381477
36136.71574381477
36136.71574381477
36136.71574381477
36136.71574381477
36136.71574381477
36136.71574381477
36136.71574381477
36136.71574381477
36136.71574381477
36136.71574381477
36136.71574381477
36136.71574381477
36136.71574381477
36136.71574381477
36136.71574381477
36136.71574381477
36136.71574381477
36136.71574381477
36136.71574381477
36136.71574381477
36136.71574381477
36136.71574381477
36136.71574381477
36136.71574381477
36136.71574381477
36136.71574381477
36136.71574381477
36136.71574381477
36136.71574381477
36136.71574381477
36136.71574381477
36136.71574381477
36136.71574381477
36136.71574381477
36136.71574381477
36136.71574381477
36136.71574381477
36136.7157

10816.439956169585
10805.81000278003
10794.990449446748
10783.994962316146
10772.837059117634
10761.53009335212
10750.087239301356
10738.521477844315
10726.845583066442
10715.07210964727
10703.21338101237
10691.28147823485
10679.288229672291
10667.245201324762
10655.163687899112
10643.05470456556
10630.928979391443
10618.796946437751
10606.668739503402
10594.55418650229
10582.462804457919
10570.403795100241
10558.386041049223
10546.41810256944
10534.508214879841
10522.66428600281
10510.893895136333
10499.20429153292
10487.602393869367
10476.094790090374
10464.687737710015
10453.387164554404
10442.19866992882
10431.12752619307
10420.178680728413
10409.356758279651
10398.66606365582
10388.11058477366
10377.693996027043
10367.419661966895
10357.290641275447
10347.309691019194
10337.479271165064
10327.801549344586
10318.278405850962
10308.911438854198
10299.701969819935
10290.65104911762
10281.759461804026
10273.027733568671
10264.45613682752
10256.044696952236
10247.793198622001
10239.701

10377.594214948704
10367.31920325957
10357.189506489676
10347.207881512404
10337.37678810763
10327.698393726278
10318.174578487495
10308.806940394164
10299.596800751697
10290.545209776328
10281.652952378725
10272.920554109356
10264.348287252178
10255.936177053845
10247.684008075446
10239.59133065471
10231.657467466444
10223.881520169634
10216.262376129813
10208.798715205905
10201.489016590369
10194.33156569307
10187.324461058275
10180.465621305637
10173.752792085494
10167.183553039875
10160.75532476037
10154.465375734751
10148.310829274267
10142.28867041407
10136.395752779445
10130.628805410788
21509.258047489995
19342.706696182675
18674.733789522823
18545.754986490938
18659.624868881114
18806.979876300913
18856.69161450925
18759.69012040162
18523.161912601223
18180.441746388646
17770.09378069749
17325.35443532861
16870.666971755087
16421.977253261823
15988.500021066575
15574.772495847314
15182.417415450704
14811.46222413539
14461.219651363228
14130.79954118896
13819.374327036572
13526

KeyboardInterrupt: 

In [145]:
# count zero of lasso shooting algorithm
import pandas as pd
def count_zero_lasso_shooting(tol=10**-3):
    w_list = Lasso_opt_slow(X_train, y_train, X_validation, y_validation, theta_opt, max_iter=100, tol=10**-1)[-1]
    lambda_list = []
    zero_count_list = []
    for k in range(-10,4):
        lambda_reg = 10**k
        lambda_list.append(lambda_reg)
    for i in range(len(w_list)):
        zero_count = (abs(w_list[i][0:10])<=tol).sum()+(abs(w_list[i][10:-1])>tol).sum()
        zero_count_list.append(zero_count)
    return pd.DataFrame(lambda_list,zero_count_list)

In [146]:
count_zero_lasso_shooting(tol=10**-3)

Unnamed: 0,0
64,1e-10
64,1e-09
64,1e-08
64,1e-07
64,1e-06
64,1e-05
64,0.0001
63,0.001
62,0.01
34,0.1
