In [1]:
from proj1_helpers import *
import numpy as np

In [2]:
## Load data
y_train, x_train, idx_train = load_csv_data("train.csv")
y_test, x_test, idx_test = load_csv_data("test.csv")

In [None]:
def split_data(x_train):
    tx_der = x_train[:, :14]
    tx_pri = x_train[:, 13:]
    return tx_der, tx_pri

In [None]:
np.corrcoef(x_train)


## helper function

In [3]:
## Strandardize
def standardize(x):
    mean_x = np.mean(x)
    x = x - mean_x
    std_x = np.std(x)
    x = x / std_x
    return x, mean_x, std_x

In [4]:
## Batch_iter
def batch_iter(y, tx, batch_size, num_batches=1, shuffle=True):
    """
    Generate a minibatch iterator for a dataset.
    Takes as input two iterables (here the output desired values 'y' and the input data 'tx')
    Outputs an iterator which gives mini-batches of `batch_size` matching elements from `y` and `tx`.
    Data can be randomly shuffled to avoid ordering in the original data messing with the randomness of the minibatches.
    Example of use :
    for minibatch_y, minibatch_tx in batch_iter(y, tx, 32):
        <DO-SOMETHING>
    """
    data_size = len(y)

    if shuffle:
        shuffle_indices = np.random.permutation(np.arange(data_size))
        shuffled_y = y[shuffle_indices]
        shuffled_tx = tx[shuffle_indices]
    else:
        shuffled_y = y
        shuffled_tx = tx
    for batch_num in range(num_batches):
        start_index = batch_num * batch_size
        end_index = min((batch_num + 1) * batch_size, data_size)
        if start_index != end_index:
            yield shuffled_y[start_index:end_index], shuffled_tx[start_index:end_index]

In [5]:
## Build k_indices (Ridge Regression)
def build_k_indices(y, k_fold, seed):
    num_row = y.shape[0]
    interval = int(num_row / k_fold)
    np.random.seed(seed)
    indices = np.random.permutation(num_row)
    k_indices = [indices[k * interval: (k + 1) * interval]
                 for k in range(k_fold)]
    return np.array(k_indices)

## Transform Data

In [6]:
# remove -999
x_train[x_train == -999] = 0
x_test[x_test == -999] = 0
# Standardize

tx_train, tx_trmean, tx_trstd = standardize(x_train)
tx_test, tx_temena, tx_testd = standardize(x_test)

# ## split_data into der and pri
# tx_train_der, tx_train_pri = split_data(tx_train)
# tx_test_der, tx_test_pri = split_data(tx_test)


## Regression

### Least Squares GD

In [None]:
def compute_gradient(y, tx, w):
    """Compute the gradient."""
    err = y - tx.dot(w)
    grad = -tx.T.dot(err) / len(err)
    return grad, err

def calculate_mse(y, tx, w):
    """mean square error"""
    err = y - tx.dot(w)
    return 1/2*np.mean(err**2)


In [None]:
def least_squares_GD(y, tx, initial_w, max_iters, gamma):

    threshold = 1e-8

    losses = []
    w = initial_w

    for i in range(max_iters):
        # compute loss, gradient
        grad, err = compute_gradient(y, tx, w)
        loss = calculate_mse(y, tx, w)
        # gradient w by descent update
        w = w - gamma * grad
        # store w and loss
        losses.append(loss)
        
        print("GD({bi}/{ti}): loss={l}".format(
              bi=i, ti=max_iters - 1, l=loss))  
        
        # converge criteria
        if len(losses) > 1 and np.abs(losses[-1] - losses[-2]) < threshold:
            break
    return losses[-1], w

In [None]:
## training by least_squares_GD
max_iters = 10000
gamma = 0.01
initial_w = np.zeros(tx_train_der.shape[1])
lgd_loss, lgd_w = least_squares_GD(y_train, tx_train_der, initial_w, max_iters, gamma) 

In [None]:
lgd_w

In [None]:
lgd_loss

### Least Squares SGD

In [None]:
def compute_stoch_gradient(y, tx, w):
    e = y - tx.dot(w)
    gradient = -tx.T.dot(e) / len(e)
    return gradient, e

In [None]:
def least_squares_SGD(y, tx, initial_w, max_iters=1000, gamma=0.01):
    threshold = 1e-8
    losses = []
    w = initial_w

    for n_iter in range(max_iters):
        for y_batch, tx_batch in batch_iter(y, tx, 1, 1):
            
            grad, err = compute_stoch_gradient(y_batch, tx_batch, w)
            w = w - gamma * grad
            
            loss = calculate_mse(y, tx, w)
            losses.append(loss)
        
        print("SGD({bi}/{ti}): loss={l}".format(
              bi=n_iter, ti=max_iters - 1, l=loss))  
        
        if len(losses) > 1 and np.abs(losses[-1] - losses[-2]) < threshold:
            break

    return losses[-1], w

In [None]:
## training by least_squares_SGD
max_iters = 10000
gamma = 0.001
initial_w = np.zeros(tx_train.shape[1])
lsgd_loss, lsgd_w = least_squares_SGD(y_train, tx_train, initial_w, max_iters, gamma) 

In [None]:
lsgd_w

In [None]:
lsgd_loss

### Ridge Regression

In [None]:
def ridge_regression(y, tx, lambda_):
    Xt = tx.T
    term1 = Xt.dot(tx) + 2 * tx.shape[0] * lambda_ * np.identity(tx.shape[1])
    weights = np.linalg.inv(term1).dot(Xt).dot(y)

    loss = compute_loss(y, tx, weights)
    return loss, weights

### Least Squares

In [None]:
# Least squares
def least_squares(y, tx):
    
    """least squares regression using normal equations"""
    
    a = tx.T.dot(tx)
    b = tx.T.dot(y)
    w = np.linalg.solve(a, b)
    loss = calculate_mse(y, tx, w)
    
    return loss, w


In [None]:
ls_loss, ls_w = least_squares(y_train, tx_train_pri)

In [None]:
ls_loss

In [None]:
ls_w

In [None]:
ls_loss

### Logistic Regression (SGD)

In [17]:
def box_cox(x):
    if (x == 0):
        

def sigmoid(x):
    return np.exp(-np.logaddexp(0, -x))

def calculate_gradient(y, tx, w):
    """compute the gradient of loss."""
    sig = sigmoid(tx.dot(w))
    gradient = tx.T.dot(sig - y)
    return gradient

def calculate_negative_log_likelihood(y, tx, w):
    """compute the cost by negative log likelihood."""
    epsilon = 1e-12
    sig = sigmoid(tx.dot(w))
    loss = y.T.dot(np.log(sig + epsilon )) + (1 - y).T.dot(np.log(1 - sig + epsilon))

    return np.squeeze(-loss)

def learning_by_gradient_descent(y, tx, w, gamma):

    loss = calculate_negative_log_likelihood(y, tx, w)
    gradient = calculate_gradient(y, tx, w)
    w = w - gamma * gradient
    return loss, w

In [18]:
def logistic_regression_SGD(y, tx, initial_w, max_iters, gamma):
    """Logistic regression using stochastic SGD"""

    threshold = 1e-8
    losses = []
    w = initial_w
    for n_iter in range(max_iters):
        for y_batch, tx_batch in batch_iter(y, tx, batch_size=1, num_batches=1):
            # compute a stochastic gradient and loss
            grad = calculate_gradient(y_batch, tx_batch, w)
            # update w through the stochastic gradient update
            w = w - gamma * grad
            # calculate loss
            loss = calculate_negative_log_likelihood(y, tx, w)
            # store w and loss

            losses.append(loss)

        print("LR_SGD({bi}/{ti}): loss={l}".format(
              bi=n_iter, ti=max_iters - 1, l=loss))
        if len(losses) > 1 and np.abs(losses[-1] - losses[-2]) < threshold:
            break
        
    return losses[-1], w

In [19]:
## training by logistci_regression
max_iters = 10000
gamma = 0.001
initial_w = np.zeros(tx_train.shape[1])
lrsgd_loss, lrsgd_w = logistic_regression_SGD(y_train, tx_train, initial_w, max_iters, gamma) 

LR_SGD(0/4999): loss=174002.69733087323
LR_SGD(1/4999): loss=175029.6066102339
LR_SGD(2/4999): loss=173040.3773916068
LR_SGD(3/4999): loss=171336.86583928028
LR_SGD(4/4999): loss=172724.62706090487
LR_SGD(5/4999): loss=170040.14680903545
LR_SGD(6/4999): loss=170925.7633430998
LR_SGD(7/4999): loss=168487.04627542468
LR_SGD(8/4999): loss=169307.82849522715
LR_SGD(9/4999): loss=165621.49757442775
LR_SGD(10/4999): loss=164258.39389870805
LR_SGD(11/4999): loss=164825.39624994906
LR_SGD(12/4999): loss=165839.16480029433
LR_SGD(13/4999): loss=163587.78288350243
LR_SGD(14/4999): loss=161131.23483095237
LR_SGD(15/4999): loss=162103.74784343078
LR_SGD(16/4999): loss=159010.15820905886
LR_SGD(17/4999): loss=156085.75221938777
LR_SGD(18/4999): loss=154679.77710999548
LR_SGD(19/4999): loss=148374.58746120255
LR_SGD(20/4999): loss=149458.66604861311
LR_SGD(21/4999): loss=150314.593860671
LR_SGD(22/4999): loss=148432.43986384541
LR_SGD(23/4999): loss=142017.23144045443
LR_SGD(24/4999): loss=138205.21

LR_SGD(200/4999): loss=-24489.391433863682
LR_SGD(201/4999): loss=-25735.300435974146
LR_SGD(202/4999): loss=-24901.302326325298
LR_SGD(203/4999): loss=-25907.06236913039
LR_SGD(204/4999): loss=-26937.628256793134
LR_SGD(205/4999): loss=-27437.01909461929
LR_SGD(206/4999): loss=-28597.863179760607
LR_SGD(207/4999): loss=-29510.759906984007
LR_SGD(208/4999): loss=-30781.277798096635
LR_SGD(209/4999): loss=-31793.271563045753
LR_SGD(210/4999): loss=-32546.352130045998
LR_SGD(211/4999): loss=-30852.557377446123
LR_SGD(212/4999): loss=-31717.54804566645
LR_SGD(213/4999): loss=-31094.93507878638
LR_SGD(214/4999): loss=-29692.79120559407
LR_SGD(215/4999): loss=-28715.141154749814
LR_SGD(216/4999): loss=-29620.33173780926
LR_SGD(217/4999): loss=-31384.052994943777
LR_SGD(218/4999): loss=-30037.948107414573
LR_SGD(219/4999): loss=-31358.244855960234
LR_SGD(220/4999): loss=-33347.77946870777
LR_SGD(221/4999): loss=-34532.445656080934
LR_SGD(222/4999): loss=-33357.01028628745
LR_SGD(223/4999): l

LR_SGD(397/4999): loss=-109976.97472180505
LR_SGD(398/4999): loss=-111161.3880173828
LR_SGD(399/4999): loss=-112169.80825060465
LR_SGD(400/4999): loss=-111059.15639983566
LR_SGD(401/4999): loss=-110465.46140706068
LR_SGD(402/4999): loss=-111298.4637627971
LR_SGD(403/4999): loss=-112072.34329307705
LR_SGD(404/4999): loss=-112875.26577852179
LR_SGD(405/4999): loss=-113406.9490788104
LR_SGD(406/4999): loss=-114065.43864862222
LR_SGD(407/4999): loss=-115009.31024431848
LR_SGD(408/4999): loss=-116159.06615957302
LR_SGD(409/4999): loss=-115525.52627698127
LR_SGD(410/4999): loss=-116554.52582567587
LR_SGD(411/4999): loss=-117773.87949460628
LR_SGD(412/4999): loss=-116719.3237810052
LR_SGD(413/4999): loss=-117365.17196018319
LR_SGD(414/4999): loss=-116693.2680297381
LR_SGD(415/4999): loss=-116164.88909636682
LR_SGD(416/4999): loss=-115708.88385831853
LR_SGD(417/4999): loss=-116840.7338320939
LR_SGD(418/4999): loss=-117627.14350662436
LR_SGD(419/4999): loss=-117082.31807521418
LR_SGD(420/4999):

LR_SGD(590/4999): loss=-155672.92323972925
LR_SGD(591/4999): loss=-155116.7479843762
LR_SGD(592/4999): loss=-154472.3694757152
LR_SGD(593/4999): loss=-153907.6812679806
LR_SGD(594/4999): loss=-152975.90816398704
LR_SGD(595/4999): loss=-153900.47389117494
LR_SGD(596/4999): loss=-154737.68665839895
LR_SGD(597/4999): loss=-156138.77808860037
LR_SGD(598/4999): loss=-157077.4326179462
LR_SGD(599/4999): loss=-157996.60329276772
LR_SGD(600/4999): loss=-157094.54866245342
LR_SGD(601/4999): loss=-158738.19023246592
LR_SGD(602/4999): loss=-159559.68284460457
LR_SGD(603/4999): loss=-160142.31048570093
LR_SGD(604/4999): loss=-160878.4281962764
LR_SGD(605/4999): loss=-161419.81617173378
LR_SGD(606/4999): loss=-162908.0773580874
LR_SGD(607/4999): loss=-163673.40542187775
LR_SGD(608/4999): loss=-164491.765861215
LR_SGD(609/4999): loss=-162730.84742407413
LR_SGD(610/4999): loss=-163555.1579279533
LR_SGD(611/4999): loss=-164632.8045656521
LR_SGD(612/4999): loss=-165482.65804009
LR_SGD(613/4999): loss=-

LR_SGD(783/4999): loss=-235684.70197219058
LR_SGD(784/4999): loss=-236578.25312323374
LR_SGD(785/4999): loss=-237228.25544212316
LR_SGD(786/4999): loss=-236462.3907069865
LR_SGD(787/4999): loss=-237124.53279209856
LR_SGD(788/4999): loss=-236447.04637289172
LR_SGD(789/4999): loss=-237246.93453702654
LR_SGD(790/4999): loss=-237858.00662062853
LR_SGD(791/4999): loss=-237210.26558373915
LR_SGD(792/4999): loss=-236951.7014948195
LR_SGD(793/4999): loss=-237523.6902052235
LR_SGD(794/4999): loss=-238060.79482568474
LR_SGD(795/4999): loss=-238592.70221458958
LR_SGD(796/4999): loss=-239254.12598709462
LR_SGD(797/4999): loss=-239932.97814820806
LR_SGD(798/4999): loss=-240447.5851186408
LR_SGD(799/4999): loss=-241004.46672894817
LR_SGD(800/4999): loss=-241551.15277969808
LR_SGD(801/4999): loss=-240606.83579137502
LR_SGD(802/4999): loss=-240009.46694782583
LR_SGD(803/4999): loss=-240771.37532097896
LR_SGD(804/4999): loss=-240235.35616639355
LR_SGD(805/4999): loss=-239567.17660896998
LR_SGD(806/4999

LR_SGD(979/4999): loss=-283012.7253340373
LR_SGD(980/4999): loss=-283595.24934140884
LR_SGD(981/4999): loss=-284153.2554496648
LR_SGD(982/4999): loss=-284818.78886092256
LR_SGD(983/4999): loss=-285720.7422090117
LR_SGD(984/4999): loss=-285027.748272225
LR_SGD(985/4999): loss=-285663.10812166776
LR_SGD(986/4999): loss=-286170.5659844896
LR_SGD(987/4999): loss=-286927.42797988077
LR_SGD(988/4999): loss=-287575.2911187638
LR_SGD(989/4999): loss=-288131.04179672716
LR_SGD(990/4999): loss=-287622.27472773794
LR_SGD(991/4999): loss=-289669.0074474861
LR_SGD(992/4999): loss=-290325.4849389041
LR_SGD(993/4999): loss=-290908.69378184236
LR_SGD(994/4999): loss=-291593.2615119033
LR_SGD(995/4999): loss=-292047.5982695178
LR_SGD(996/4999): loss=-292695.84505729045
LR_SGD(997/4999): loss=-293546.09542721196
LR_SGD(998/4999): loss=-292853.0488401685
LR_SGD(999/4999): loss=-291531.0273977985
LR_SGD(1000/4999): loss=-292056.71137305634
LR_SGD(1001/4999): loss=-291332.12979837175
LR_SGD(1002/4999): los

LR_SGD(1171/4999): loss=-335544.57318314817
LR_SGD(1172/4999): loss=-336084.71862327453
LR_SGD(1173/4999): loss=-336510.5477343211
LR_SGD(1174/4999): loss=-335965.75092537026
LR_SGD(1175/4999): loss=-335543.6471735666
LR_SGD(1176/4999): loss=-335493.97889237176
LR_SGD(1177/4999): loss=-334886.67553099815
LR_SGD(1178/4999): loss=-334313.1003413306
LR_SGD(1179/4999): loss=-335042.923604005
LR_SGD(1180/4999): loss=-335465.42812207196
LR_SGD(1181/4999): loss=-334804.75951789087
LR_SGD(1182/4999): loss=-333927.8593454328
LR_SGD(1183/4999): loss=-334533.77377728326
LR_SGD(1184/4999): loss=-335203.22699970775
LR_SGD(1185/4999): loss=-334633.73006462253
LR_SGD(1186/4999): loss=-334656.11352695076
LR_SGD(1187/4999): loss=-335356.06094042666
LR_SGD(1188/4999): loss=-334990.765803816
LR_SGD(1189/4999): loss=-335517.3314287756
LR_SGD(1190/4999): loss=-334879.3114411203
LR_SGD(1191/4999): loss=-336152.4434591861
LR_SGD(1192/4999): loss=-335769.1139419685
LR_SGD(1193/4999): loss=-336128.517508759
LR

LR_SGD(1363/4999): loss=-361158.7113009783
LR_SGD(1364/4999): loss=-361777.5658676632
LR_SGD(1365/4999): loss=-362245.7973250772
LR_SGD(1366/4999): loss=-362770.7274639982
LR_SGD(1367/4999): loss=-362026.4655197535
LR_SGD(1368/4999): loss=-362537.2281580877
LR_SGD(1369/4999): loss=-364080.4507477252
LR_SGD(1370/4999): loss=-364740.7324592876
LR_SGD(1371/4999): loss=-365105.302042044
LR_SGD(1372/4999): loss=-365572.52461719135
LR_SGD(1373/4999): loss=-365039.1211456737
LR_SGD(1374/4999): loss=-364582.5009073015
LR_SGD(1375/4999): loss=-364187.3311521986
LR_SGD(1376/4999): loss=-363790.43597590004
LR_SGD(1377/4999): loss=-364421.704408221
LR_SGD(1378/4999): loss=-365103.2892737562
LR_SGD(1379/4999): loss=-365835.2831125192
LR_SGD(1380/4999): loss=-366674.4617387032
LR_SGD(1381/4999): loss=-367185.5563770079
LR_SGD(1382/4999): loss=-366326.46087114944
LR_SGD(1383/4999): loss=-366798.7002508946
LR_SGD(1384/4999): loss=-366369.99540184205
LR_SGD(1385/4999): loss=-367347.47349417757
LR_SGD(1

LR_SGD(1556/4999): loss=-414325.6684966084
LR_SGD(1557/4999): loss=-414941.06767560093
LR_SGD(1558/4999): loss=-415364.5375107102
LR_SGD(1559/4999): loss=-415499.595068283
LR_SGD(1560/4999): loss=-416640.8919088239
LR_SGD(1561/4999): loss=-417147.49769988947
LR_SGD(1562/4999): loss=-416613.8296313566
LR_SGD(1563/4999): loss=-415960.40691108076
LR_SGD(1564/4999): loss=-416726.09432749165
LR_SGD(1565/4999): loss=-416294.71051744814
LR_SGD(1566/4999): loss=-416748.80085538013
LR_SGD(1567/4999): loss=-416166.7397718276
LR_SGD(1568/4999): loss=-416803.83027041174
LR_SGD(1569/4999): loss=-417373.564575996
LR_SGD(1570/4999): loss=-418002.4322811641
LR_SGD(1571/4999): loss=-418485.1269561863
LR_SGD(1572/4999): loss=-419057.4192784681
LR_SGD(1573/4999): loss=-419644.34476994496
LR_SGD(1574/4999): loss=-420002.74410071754
LR_SGD(1575/4999): loss=-419430.1254348929
LR_SGD(1576/4999): loss=-418448.03796176467
LR_SGD(1577/4999): loss=-419642.1539034372
LR_SGD(1578/4999): loss=-420087.09510889364
LR

LR_SGD(1745/4999): loss=-455775.1843620554
LR_SGD(1746/4999): loss=-456328.60415053274
LR_SGD(1747/4999): loss=-457099.2676123786
LR_SGD(1748/4999): loss=-457660.39505043835
LR_SGD(1749/4999): loss=-457896.6437939507
LR_SGD(1750/4999): loss=-459012.4613838492
LR_SGD(1751/4999): loss=-459681.03916787595
LR_SGD(1752/4999): loss=-458556.5405418184
LR_SGD(1753/4999): loss=-459016.02159949945
LR_SGD(1754/4999): loss=-459474.43506936857
LR_SGD(1755/4999): loss=-459974.0316431091
LR_SGD(1756/4999): loss=-459675.7458044528
LR_SGD(1757/4999): loss=-458902.3047207451
LR_SGD(1758/4999): loss=-459434.0479338314
LR_SGD(1759/4999): loss=-458803.76060866565
LR_SGD(1760/4999): loss=-459229.8054433638
LR_SGD(1761/4999): loss=-458703.21192751464
LR_SGD(1762/4999): loss=-458230.8875748844
LR_SGD(1763/4999): loss=-458835.6523998683
LR_SGD(1764/4999): loss=-459483.32145768375
LR_SGD(1765/4999): loss=-459657.1030997025
LR_SGD(1766/4999): loss=-460098.3054884088
LR_SGD(1767/4999): loss=-459383.15792969876
LR

LR_SGD(1934/4999): loss=-498021.9565893892
LR_SGD(1935/4999): loss=-498456.5928718086
LR_SGD(1936/4999): loss=-498108.9221324121
LR_SGD(1937/4999): loss=-498641.4064710627
LR_SGD(1938/4999): loss=-498055.77381883183
LR_SGD(1939/4999): loss=-498638.4877129735
LR_SGD(1940/4999): loss=-499431.1646531879
LR_SGD(1941/4999): loss=-499979.1775167702
LR_SGD(1942/4999): loss=-500407.3031649442
LR_SGD(1943/4999): loss=-500994.83611828764
LR_SGD(1944/4999): loss=-500418.6955093383
LR_SGD(1945/4999): loss=-499975.6658354243
LR_SGD(1946/4999): loss=-499314.60090023035
LR_SGD(1947/4999): loss=-500120.3487882187
LR_SGD(1948/4999): loss=-500537.275781531
LR_SGD(1949/4999): loss=-501087.1563014806
LR_SGD(1950/4999): loss=-501486.9674623889
LR_SGD(1951/4999): loss=-502010.65326776996
LR_SGD(1952/4999): loss=-503598.3880941718
LR_SGD(1953/4999): loss=-504080.6542700455
LR_SGD(1954/4999): loss=-503144.2494354209
LR_SGD(1955/4999): loss=-503730.9095596682
LR_SGD(1956/4999): loss=-502861.7038848304
LR_SGD(1

LR_SGD(2126/4999): loss=-528460.5442162909
LR_SGD(2127/4999): loss=-528859.0737404037
LR_SGD(2128/4999): loss=-529265.4039017388
LR_SGD(2129/4999): loss=-528689.9667179955
LR_SGD(2130/4999): loss=-528100.7852492194
LR_SGD(2131/4999): loss=-529304.610428435
LR_SGD(2132/4999): loss=-528611.0595191632
LR_SGD(2133/4999): loss=-529269.739636192
LR_SGD(2134/4999): loss=-529876.1688986159
LR_SGD(2135/4999): loss=-530690.7254532367
LR_SGD(2136/4999): loss=-530119.2413338799
LR_SGD(2137/4999): loss=-529603.6853619086
LR_SGD(2138/4999): loss=-529118.182725926
LR_SGD(2139/4999): loss=-528631.448042456
LR_SGD(2140/4999): loss=-529054.084668665
LR_SGD(2141/4999): loss=-529723.7771837487
LR_SGD(2142/4999): loss=-531232.6344879692
LR_SGD(2143/4999): loss=-530551.0064278294
LR_SGD(2144/4999): loss=-531330.5802853118
LR_SGD(2145/4999): loss=-530809.7537106895
LR_SGD(2146/4999): loss=-531341.5657980435
LR_SGD(2147/4999): loss=-531823.2857284172
LR_SGD(2148/4999): loss=-532387.752141139
LR_SGD(2149/4999)

LR_SGD(2318/4999): loss=-560218.0577684335
LR_SGD(2319/4999): loss=-559839.2511415146
LR_SGD(2320/4999): loss=-560425.3965300897
LR_SGD(2321/4999): loss=-559756.0325773993
LR_SGD(2322/4999): loss=-558692.4023130554
LR_SGD(2323/4999): loss=-559209.8091496382
LR_SGD(2324/4999): loss=-559807.4312513093
LR_SGD(2325/4999): loss=-560320.1050045356
LR_SGD(2326/4999): loss=-561095.1554886912
LR_SGD(2327/4999): loss=-562164.9632896781
LR_SGD(2328/4999): loss=-562556.6379242837
LR_SGD(2329/4999): loss=-563224.9800251141
LR_SGD(2330/4999): loss=-562603.317425763
LR_SGD(2331/4999): loss=-563124.1583495404
LR_SGD(2332/4999): loss=-562176.0611336156
LR_SGD(2333/4999): loss=-562629.0861100911
LR_SGD(2334/4999): loss=-562077.5951368846
LR_SGD(2335/4999): loss=-561329.6010929742
LR_SGD(2336/4999): loss=-561801.7937061827
LR_SGD(2337/4999): loss=-562524.7756677106
LR_SGD(2338/4999): loss=-561853.6454044727
LR_SGD(2339/4999): loss=-562253.2599570454
LR_SGD(2340/4999): loss=-562999.8288232463
LR_SGD(2341/

LR_SGD(2512/4999): loss=-605588.5993701438
LR_SGD(2513/4999): loss=-606424.0531830911
LR_SGD(2514/4999): loss=-606955.0224392968
LR_SGD(2515/4999): loss=-607646.4163712498
LR_SGD(2516/4999): loss=-607000.75855195
LR_SGD(2517/4999): loss=-606321.6883314159
LR_SGD(2518/4999): loss=-607269.3056587767
LR_SGD(2519/4999): loss=-607420.4852490065
LR_SGD(2520/4999): loss=-607367.0304040196
LR_SGD(2521/4999): loss=-606809.7065676863
LR_SGD(2522/4999): loss=-607896.5796007403
LR_SGD(2523/4999): loss=-607362.7469828902
LR_SGD(2524/4999): loss=-607818.2911898486
LR_SGD(2525/4999): loss=-608470.6836904463
LR_SGD(2526/4999): loss=-608917.167203487
LR_SGD(2527/4999): loss=-608421.3097260065
LR_SGD(2528/4999): loss=-607827.7740201547
LR_SGD(2529/4999): loss=-608477.5128252867
LR_SGD(2530/4999): loss=-608921.5553629176
LR_SGD(2531/4999): loss=-609899.4639013278
LR_SGD(2532/4999): loss=-610571.667518823
LR_SGD(2533/4999): loss=-611287.0638945691
LR_SGD(2534/4999): loss=-611877.5686918403
LR_SGD(2535/499

LR_SGD(2704/4999): loss=-652959.5435872711
LR_SGD(2705/4999): loss=-653676.2547246261
LR_SGD(2706/4999): loss=-653175.1793931412
LR_SGD(2707/4999): loss=-652277.9541438483
LR_SGD(2708/4999): loss=-652879.0243333866
LR_SGD(2709/4999): loss=-653542.197175086
LR_SGD(2710/4999): loss=-653031.162485218
LR_SGD(2711/4999): loss=-653585.7559407668
LR_SGD(2712/4999): loss=-654628.2730903566
LR_SGD(2713/4999): loss=-655508.2193160879
LR_SGD(2714/4999): loss=-655026.9184704771
LR_SGD(2715/4999): loss=-654547.9577979103
LR_SGD(2716/4999): loss=-654137.1813266559
LR_SGD(2717/4999): loss=-653430.2602387868
LR_SGD(2718/4999): loss=-652949.1659323962
LR_SGD(2719/4999): loss=-653373.853410643
LR_SGD(2720/4999): loss=-653851.8798722872
LR_SGD(2721/4999): loss=-654448.3372900786
LR_SGD(2722/4999): loss=-654132.5421486922
LR_SGD(2723/4999): loss=-654682.1692952887
LR_SGD(2724/4999): loss=-654081.6928382433
LR_SGD(2725/4999): loss=-654659.6049072826
LR_SGD(2726/4999): loss=-654839.8212223344
LR_SGD(2727/49

LR_SGD(2899/4999): loss=-687773.846944856
LR_SGD(2900/4999): loss=-687774.8203610134
LR_SGD(2901/4999): loss=-688407.6834807054
LR_SGD(2902/4999): loss=-689005.358662797
LR_SGD(2903/4999): loss=-689457.766252249
LR_SGD(2904/4999): loss=-689968.8841714781
LR_SGD(2905/4999): loss=-689250.3824777261
LR_SGD(2906/4999): loss=-689767.0742983803
LR_SGD(2907/4999): loss=-690151.0505539145
LR_SGD(2908/4999): loss=-689400.345010509
LR_SGD(2909/4999): loss=-689848.004151589
LR_SGD(2910/4999): loss=-690316.0630277169
LR_SGD(2911/4999): loss=-689478.5898488947
LR_SGD(2912/4999): loss=-688588.3402973528
LR_SGD(2913/4999): loss=-689094.4306912322
LR_SGD(2914/4999): loss=-689645.239074969
LR_SGD(2915/4999): loss=-690580.1101018749
LR_SGD(2916/4999): loss=-691218.0310058529
LR_SGD(2917/4999): loss=-691643.3232322905
LR_SGD(2918/4999): loss=-691102.6868727216
LR_SGD(2919/4999): loss=-691632.4885067136
LR_SGD(2920/4999): loss=-692258.590810924
LR_SGD(2921/4999): loss=-693050.1747100969
LR_SGD(2922/4999):

LR_SGD(3094/4999): loss=-722951.56494558
LR_SGD(3095/4999): loss=-723571.9304860786
LR_SGD(3096/4999): loss=-724234.8844527606
LR_SGD(3097/4999): loss=-724686.1475898075
LR_SGD(3098/4999): loss=-724200.7950104001
LR_SGD(3099/4999): loss=-725066.6360263934
LR_SGD(3100/4999): loss=-725568.7410333082
LR_SGD(3101/4999): loss=-725112.8607705208
LR_SGD(3102/4999): loss=-725579.2183454664
LR_SGD(3103/4999): loss=-726125.1802436478
LR_SGD(3104/4999): loss=-727401.6730230948
LR_SGD(3105/4999): loss=-726749.5891006226
LR_SGD(3106/4999): loss=-727283.1588525891
LR_SGD(3107/4999): loss=-726733.1618569567
LR_SGD(3108/4999): loss=-727175.0238124821
LR_SGD(3109/4999): loss=-727686.896618324
LR_SGD(3110/4999): loss=-726822.7813992451
LR_SGD(3111/4999): loss=-727454.5268976159
LR_SGD(3112/4999): loss=-727945.7658878479
LR_SGD(3113/4999): loss=-729248.8306832408
LR_SGD(3114/4999): loss=-729907.9465400771
LR_SGD(3115/4999): loss=-729268.6291351838
LR_SGD(3116/4999): loss=-729694.8987287369
LR_SGD(3117/49

LR_SGD(3288/4999): loss=-763901.8777147474
LR_SGD(3289/4999): loss=-764515.8340326119
LR_SGD(3290/4999): loss=-763647.8696788193
LR_SGD(3291/4999): loss=-762938.3946014736
LR_SGD(3292/4999): loss=-762183.5821027753
LR_SGD(3293/4999): loss=-762686.6151494312
LR_SGD(3294/4999): loss=-763231.7770699039
LR_SGD(3295/4999): loss=-763614.0452129967
LR_SGD(3296/4999): loss=-764247.0215715565
LR_SGD(3297/4999): loss=-764674.7262215079
LR_SGD(3298/4999): loss=-763946.8556570635
LR_SGD(3299/4999): loss=-764466.1378783593
LR_SGD(3300/4999): loss=-764825.7553404997
LR_SGD(3301/4999): loss=-765717.2510788641
LR_SGD(3302/4999): loss=-766220.196479493
LR_SGD(3303/4999): loss=-765577.0330607919
LR_SGD(3304/4999): loss=-765029.8048971501
LR_SGD(3305/4999): loss=-765567.5104019374
LR_SGD(3306/4999): loss=-767409.3886135445
LR_SGD(3307/4999): loss=-767994.3803752608
LR_SGD(3308/4999): loss=-766962.854300915
LR_SGD(3309/4999): loss=-767808.0232861644
LR_SGD(3310/4999): loss=-768254.8575862668
LR_SGD(3311/4

LR_SGD(3482/4999): loss=-800785.6735859932
LR_SGD(3483/4999): loss=-801607.2862065946
LR_SGD(3484/4999): loss=-801090.7866191084
LR_SGD(3485/4999): loss=-801498.7441025667
LR_SGD(3486/4999): loss=-800768.6569402908
LR_SGD(3487/4999): loss=-801343.0416961324
LR_SGD(3488/4999): loss=-801969.2134281961
LR_SGD(3489/4999): loss=-801452.4475689804
LR_SGD(3490/4999): loss=-802095.2614432911
LR_SGD(3491/4999): loss=-802933.6804379055
LR_SGD(3492/4999): loss=-802200.2935272716
LR_SGD(3493/4999): loss=-803033.8253189388
LR_SGD(3494/4999): loss=-803366.2182348366
LR_SGD(3495/4999): loss=-802459.6015899811
LR_SGD(3496/4999): loss=-801914.0595227369
LR_SGD(3497/4999): loss=-800976.7293162977
LR_SGD(3498/4999): loss=-801791.9347841546
LR_SGD(3499/4999): loss=-802504.0319215881
LR_SGD(3500/4999): loss=-803363.8448173206
LR_SGD(3501/4999): loss=-803689.9269035804
LR_SGD(3502/4999): loss=-803186.4325628289
LR_SGD(3503/4999): loss=-803860.8071360164
LR_SGD(3504/4999): loss=-804914.6488093898
LR_SGD(3505

LR_SGD(3675/4999): loss=-840153.4135560783
LR_SGD(3676/4999): loss=-840996.8422972666
LR_SGD(3677/4999): loss=-840484.7797490075
LR_SGD(3678/4999): loss=-840928.1097696187
LR_SGD(3679/4999): loss=-840217.1196921201
LR_SGD(3680/4999): loss=-840067.697050789
LR_SGD(3681/4999): loss=-839437.7632583379
LR_SGD(3682/4999): loss=-840220.6123053302
LR_SGD(3683/4999): loss=-839503.5634401729
LR_SGD(3684/4999): loss=-838782.6928044796
LR_SGD(3685/4999): loss=-838328.392781141
LR_SGD(3686/4999): loss=-839009.6708240227
LR_SGD(3687/4999): loss=-839447.1185724037
LR_SGD(3688/4999): loss=-840020.7310450027
LR_SGD(3689/4999): loss=-840410.8829768309
LR_SGD(3690/4999): loss=-841116.1033797925
LR_SGD(3691/4999): loss=-840610.9178556101
LR_SGD(3692/4999): loss=-841051.3720138718
LR_SGD(3693/4999): loss=-841648.4497885554
LR_SGD(3694/4999): loss=-842167.2890197266
LR_SGD(3695/4999): loss=-842595.4483532165
LR_SGD(3696/4999): loss=-843017.0335831651
LR_SGD(3697/4999): loss=-843540.0909680409
LR_SGD(3698/4

LR_SGD(3868/4999): loss=-880087.0127851564
LR_SGD(3869/4999): loss=-880750.2973328133
LR_SGD(3870/4999): loss=-880337.1228806258
LR_SGD(3871/4999): loss=-881009.8386493075
LR_SGD(3872/4999): loss=-880573.3527648947
LR_SGD(3873/4999): loss=-881271.8315999841
LR_SGD(3874/4999): loss=-879910.6799816956
LR_SGD(3875/4999): loss=-880789.5761790358
LR_SGD(3876/4999): loss=-880600.2133354215
LR_SGD(3877/4999): loss=-881065.3946145955
LR_SGD(3878/4999): loss=-881708.4088813234
LR_SGD(3879/4999): loss=-882245.027600699
LR_SGD(3880/4999): loss=-882758.2330883571
LR_SGD(3881/4999): loss=-883312.262946529
LR_SGD(3882/4999): loss=-883866.6586298507
LR_SGD(3883/4999): loss=-883235.614145958
LR_SGD(3884/4999): loss=-884302.919539467
LR_SGD(3885/4999): loss=-885035.1100554359
LR_SGD(3886/4999): loss=-885111.2709092393
LR_SGD(3887/4999): loss=-885575.7895143097
LR_SGD(3888/4999): loss=-885095.1733078321
LR_SGD(3889/4999): loss=-885581.3529458977
LR_SGD(3890/4999): loss=-885961.9789491789
LR_SGD(3891/499

LR_SGD(4062/4999): loss=-917482.2282991776
LR_SGD(4063/4999): loss=-917243.2535493075
LR_SGD(4064/4999): loss=-916840.5310326657
LR_SGD(4065/4999): loss=-917397.2102616575
LR_SGD(4066/4999): loss=-916964.5972559137
LR_SGD(4067/4999): loss=-917575.8522756265
LR_SGD(4068/4999): loss=-918253.7322658004
LR_SGD(4069/4999): loss=-918167.4572411593
LR_SGD(4070/4999): loss=-918666.8646430235
LR_SGD(4071/4999): loss=-919381.0721051791
LR_SGD(4072/4999): loss=-919879.6271382523
LR_SGD(4073/4999): loss=-919932.4067611768
LR_SGD(4074/4999): loss=-920370.4615941002
LR_SGD(4075/4999): loss=-919937.2656415058
LR_SGD(4076/4999): loss=-920381.685878175
LR_SGD(4077/4999): loss=-920630.6661927304
LR_SGD(4078/4999): loss=-920671.4609474355
LR_SGD(4079/4999): loss=-920236.7793877814
LR_SGD(4080/4999): loss=-919777.3325855912
LR_SGD(4081/4999): loss=-919092.9494027357
LR_SGD(4082/4999): loss=-918741.1871861678
LR_SGD(4083/4999): loss=-919559.5525306045
LR_SGD(4084/4999): loss=-919008.4642976873
LR_SGD(4085/

LR_SGD(4254/4999): loss=-952244.4185101133
LR_SGD(4255/4999): loss=-952684.4313363056
LR_SGD(4256/4999): loss=-953267.2712969803
LR_SGD(4257/4999): loss=-953271.3219130577
LR_SGD(4258/4999): loss=-953761.4688953386
LR_SGD(4259/4999): loss=-954611.3493928338
LR_SGD(4260/4999): loss=-955078.2085860481
LR_SGD(4261/4999): loss=-954499.5614328238
LR_SGD(4262/4999): loss=-954055.0384436961
LR_SGD(4263/4999): loss=-954491.0231271513
LR_SGD(4264/4999): loss=-955066.4710813018
LR_SGD(4265/4999): loss=-954431.1713224271
LR_SGD(4266/4999): loss=-954081.1609706423
LR_SGD(4267/4999): loss=-953398.6021227902
LR_SGD(4268/4999): loss=-953886.9037017587
LR_SGD(4269/4999): loss=-953452.6681560245
LR_SGD(4270/4999): loss=-954213.8115076763
LR_SGD(4271/4999): loss=-954616.434079368
LR_SGD(4272/4999): loss=-955205.4714925308
LR_SGD(4273/4999): loss=-954402.9179491487
LR_SGD(4274/4999): loss=-954787.2905277874
LR_SGD(4275/4999): loss=-954223.9620123265
LR_SGD(4276/4999): loss=-954710.2005229527
LR_SGD(4277/

LR_SGD(4448/4999): loss=-979780.5636456341
LR_SGD(4449/4999): loss=-980512.854095316
LR_SGD(4450/4999): loss=-980889.4794135503
LR_SGD(4451/4999): loss=-981422.9341067112
LR_SGD(4452/4999): loss=-981425.3797242916
LR_SGD(4453/4999): loss=-982251.62805579
LR_SGD(4454/4999): loss=-982652.2076235336
LR_SGD(4455/4999): loss=-983077.8029357961
LR_SGD(4456/4999): loss=-983507.870225423
LR_SGD(4457/4999): loss=-984100.1278910256
LR_SGD(4458/4999): loss=-984734.144552678
LR_SGD(4459/4999): loss=-984171.5165964306
LR_SGD(4460/4999): loss=-983696.5801615587
LR_SGD(4461/4999): loss=-984238.8764598272
LR_SGD(4462/4999): loss=-985643.8255495349
LR_SGD(4463/4999): loss=-986141.2472002007
LR_SGD(4464/4999): loss=-986816.1964411839
LR_SGD(4465/4999): loss=-987523.3864403397
LR_SGD(4466/4999): loss=-987751.0366632864
LR_SGD(4467/4999): loss=-987246.2918968423
LR_SGD(4468/4999): loss=-987659.423515182
LR_SGD(4469/4999): loss=-988165.739630304
LR_SGD(4470/4999): loss=-989035.2219296581
LR_SGD(4471/4999):

LR_SGD(4641/4999): loss=-1008951.8335789791
LR_SGD(4642/4999): loss=-1009525.9213244566
LR_SGD(4643/4999): loss=-1010256.4622300351
LR_SGD(4644/4999): loss=-1010768.5555527476
LR_SGD(4645/4999): loss=-1011137.2682151933
LR_SGD(4646/4999): loss=-1010554.7788771207
LR_SGD(4647/4999): loss=-1011793.6476629768
LR_SGD(4648/4999): loss=-1011195.1727533508
LR_SGD(4649/4999): loss=-1010676.7727490446
LR_SGD(4650/4999): loss=-1010205.3797764657
LR_SGD(4651/4999): loss=-1010447.8067614823
LR_SGD(4652/4999): loss=-1010832.35142667
LR_SGD(4653/4999): loss=-1010187.2266158556
LR_SGD(4654/4999): loss=-1009602.7719050428
LR_SGD(4655/4999): loss=-1010171.9998343978
LR_SGD(4656/4999): loss=-1010584.7659645743
LR_SGD(4657/4999): loss=-1009975.1659855373
LR_SGD(4658/4999): loss=-1010565.4836381614
LR_SGD(4659/4999): loss=-1011267.5636214361
LR_SGD(4660/4999): loss=-1010394.7093246614
LR_SGD(4661/4999): loss=-1009939.1354330528
LR_SGD(4662/4999): loss=-1010492.1444170695
LR_SGD(4663/4999): loss=-1011048.1

LR_SGD(4829/4999): loss=-1038704.1268341827
LR_SGD(4830/4999): loss=-1038245.3239708432
LR_SGD(4831/4999): loss=-1037721.2030346827
LR_SGD(4832/4999): loss=-1037997.781304884
LR_SGD(4833/4999): loss=-1037288.2761313325
LR_SGD(4834/4999): loss=-1037758.4817434745
LR_SGD(4835/4999): loss=-1038206.819197384
LR_SGD(4836/4999): loss=-1038602.3452980428
LR_SGD(4837/4999): loss=-1039095.5802339448
LR_SGD(4838/4999): loss=-1039794.8978958213
LR_SGD(4839/4999): loss=-1040258.7956456756
LR_SGD(4840/4999): loss=-1041262.9463206516
LR_SGD(4841/4999): loss=-1041924.3946813942
LR_SGD(4842/4999): loss=-1042746.0850303433
LR_SGD(4843/4999): loss=-1043176.3645820749
LR_SGD(4844/4999): loss=-1043483.5449820863
LR_SGD(4845/4999): loss=-1042733.3399589525
LR_SGD(4846/4999): loss=-1042081.0986994827
LR_SGD(4847/4999): loss=-1042866.3751273626
LR_SGD(4848/4999): loss=-1043357.0254270534
LR_SGD(4849/4999): loss=-1042907.4294505067
LR_SGD(4850/4999): loss=-1043664.0470052388
LR_SGD(4851/4999): loss=-1044309.1

In [20]:
lrsgd_w

array([-0.80445368, -0.99812596, -1.0321358 ,  0.18005652,  0.74885115,
        1.28938028,  0.73451031,  0.69746881,  0.29480696, -1.67685883,
        0.70716195,  0.76729255,  0.74825969,  0.23677741,  0.74669472,
        0.74740555, -0.28458231,  0.74739993,  0.74569285, -0.1336589 ,
        0.74927495, -2.72017334,  0.73347159,  0.22603133,  0.74600548,
        0.74638318,  0.55940817,  0.74785392,  0.74601221, -0.13437263])

In [13]:
lrsgd_loss

-859519.2005315799

In [21]:
lrsgd_loss

-1064692.0518033062

###  Regualarize Logistic Regression

In [35]:

def penalized_logistic_regression(y, tx, w, lambda_):
    
    loss = calculate_negative_log_likelihood(y, tx, w) + lambda_ * np.squeeze(w.T.dot(w))
    gradient = calculate_gradient(y, tx, w) + 2 * lambda_ * w

    return loss, gradient

def reg_logistic_regression(y, tx, lambda_ , initial_w, max_iters, gamma):

    threshold = 1e-8
    losses = []
    w = initial_w
    
    # start the logistic regression
    for n_iter in range(max_iters):
        # get loss and update w.
        loss, gradient = penalized_logistic_regression(y, tx, w, lambda_)
        w = w - gamma * gradient

        # converge criteria
        losses.append(loss)
        print("LR_SGD({bi}/{ti}): loss={l}, norm of {w}".format(
              bi=n_iter, ti=max_iters - 1, l=loss, w= np.inner(w,w)))
        if len(losses) > 1 and np.abs(losses[-1] - losses[-2]) < threshold:
            break
            
    return losses[-1], w

In [37]:
## training by regularize logistci_regression

max_iters = 10000
gamma = 0.001
lambda_ = 0.1
initial_w = np.zeros(tx_train.shape[1])

rlrsgd_loss, rlrsgd_w = reg_logistic_regression(y_train, tx_train, lambda_, initial_w, max_iters, gamma) 

LR_SGD(0/9999): loss=173286.79513500613, norm of 359675.8399892417
LR_SGD(1/9999): loss=-1956519.1641846022, norm of 648240.6502098794
LR_SGD(2/9999): loss=-1931067.2114647499, norm of 1041795.6837580389
LR_SGD(3/9999): loss=-1911980.4348648079, norm of 1535636.9567420378
LR_SGD(4/9999): loss=-1875557.570681922, norm of 2127326.028329721
LR_SGD(5/9999): loss=-1821981.129957403, norm of 2816207.907059348
LR_SGD(6/9999): loss=-1756151.7730197443, norm of 3602050.93624609
LR_SGD(7/9999): loss=-1678535.5376062843, norm of 4484738.868256625
LR_SGD(8/9999): loss=-1591284.8667556585, norm of 5464190.157558347
LR_SGD(9/9999): loss=-1494094.31244508, norm of 6540334.098856934
LR_SGD(10/9999): loss=-1386996.5037327711, norm of 7713106.919103111
LR_SGD(11/9999): loss=-1270104.6293714917, norm of 8982448.11321969
LR_SGD(12/9999): loss=-1143450.2118311636, norm of 10348298.264770927
LR_SGD(13/9999): loss=-1007005.518884633, norm of 11810598.455969466
LR_SGD(14/9999): loss=-860931.7703172709, norm o

LR_SGD(125/9999): loss=73982216.91175236, norm of 772152452.9360347
LR_SGD(126/9999): loss=75172610.67859533, norm of 784146442.9139947
LR_SGD(127/9999): loss=76372010.86290525, norm of 796230443.703159
LR_SGD(128/9999): loss=77580411.84114912, norm of 808404400.3412186
LR_SGD(129/9999): loss=78797808.09982772, norm of 820668257.8916392
LR_SGD(130/9999): loss=80024194.25750393, norm of 833021961.4436514
LR_SGD(131/9999): loss=81259564.92809539, norm of 845465456.112239
LR_SGD(132/9999): loss=82503914.6727725, norm of 857998687.0381278
LR_SGD(133/9999): loss=83757238.02320956, norm of 870621599.387775
LR_SGD(134/9999): loss=85019529.5014877, norm of 883334138.3533579
LR_SGD(135/9999): loss=86290783.6283936, norm of 896136249.1527628
LR_SGD(136/9999): loss=87570994.92627898, norm of 909027877.0295738
LR_SGD(137/9999): loss=88860157.91996713, norm of 922008967.2530621
LR_SGD(138/9999): loss=90158267.1370003, norm of 935079465.118174
LR_SGD(139/9999): loss=91465317.10766146, norm of 948239

LR_SGD(249/9999): loss=288593702.4656512, norm of 2929023850.8848777
LR_SGD(250/9999): loss=290859766.8539617, norm of 2951767929.4241467
LR_SGD(251/9999): loss=293134175.2906803, norm of 2974595396.6173983
LR_SGD(252/9999): loss=295416922.58610415, norm of 2997506200.6154337
LR_SGD(253/9999): loss=297708003.55239207, norm of 3020500289.593486
LR_SGD(254/9999): loss=300007413.00345415, norm of 3043577611.7512193
LR_SGD(255/9999): loss=302315145.75487256, norm of 3066738115.3127117
LR_SGD(256/9999): loss=304631196.6238914, norm of 3089981748.526447
LR_SGD(257/9999): loss=306955560.4295387, norm of 3113308459.6653037
LR_SGD(258/9999): loss=309288231.99294513, norm of 3136718197.0265465
LR_SGD(259/9999): loss=311629206.13790494, norm of 3160210908.931811
LR_SGD(260/9999): loss=313978477.69164604, norm of 3183786543.7270994
LR_SGD(261/9999): loss=316336041.48565656, norm of 3207445049.7827654
LR_SGD(262/9999): loss=318701892.356304, norm of 3231186375.493505
LR_SGD(263/9999): loss=32107602

LR_SGD(372/9999): loss=628377152.8952876, norm of 6336739525.949954
LR_SGD(373/9999): loss=631631344.8709617, norm of 6369358684.989763
LR_SGD(374/9999): loss=634893260.8819728, norm of 6402055035.439564
LR_SGD(375/9999): loss=638162896.0339043, norm of 6434828528.378324
LR_SGD(376/9999): loss=641440245.4346582, norm of 6467679114.908192
LR_SGD(377/9999): loss=644725304.194454, norm of 6500606746.154478
LR_SGD(378/9999): loss=648018067.425827, norm of 6533611373.26565
LR_SGD(379/9999): loss=651318530.2436278, norm of 6566692947.413324
LR_SGD(380/9999): loss=654626687.7650219, norm of 6599851419.79225
LR_SGD(381/9999): loss=657942535.1094873, norm of 6633086741.620305
LR_SGD(382/9999): loss=661266067.3988152, norm of 6666398864.138478
LR_SGD(383/9999): loss=664597279.7571069, norm of 6699787738.610872
LR_SGD(384/9999): loss=667936167.3107753, norm of 6733253316.3246765
LR_SGD(385/9999): loss=671282725.1885418, norm of 6766795548.590177
LR_SGD(386/9999): loss=674636948.5214368, norm of 6

KeyboardInterrupt: 

In [None]:
rlrsgd_w

In [None]:
rlrsgd_loss

## Prediction 


In [25]:
# y_pred
y_pred  = predict_labels(lrsgd_w, tx_test)

In [26]:
# Output to csv
create_csv_submission(idx_test, y_pred, "lrsgd_00000001")