In [1]:
%matplotlib inline
%load_ext autoreload
%autoreload 2

In [166]:
import numpy as np

import sys
sys.path.append('../')

from scripts.proj1_helpers import *
from scripts.implementations import *

SEED = 42

# Load train data

In [144]:
# Load data
y, tx, ids = load_csv_data('../data/train.csv')

# Normalise data
tx, mean_tx, std_tx = standardise(tx)

# Check shape of data
print('Shape: y: {}, x:{}\n'.format(y.shape, tx.shape))

num_samples, num_dim = tx.shape

# Check that data is normalised
print(np.mean(tx, axis=0), np.std(tx, axis=0))

# Assign fixed training and evaluation indices
num_train = 200000

Shape: y: (250000,), x:(250000, 30)

[-2.50602916e-15  4.49575133e-15 -3.48448848e-15  7.18646387e-15
 -2.36304576e-14 -3.26035021e-15  1.26038877e-14  2.16223188e-14
  6.40057962e-15  2.86143687e-15 -6.98486646e-15  3.63458152e-15
 -1.27422117e-14 -5.95722149e-15  1.35646161e-16  7.13136217e-17
  2.58023760e-14 -1.06327391e-16 -1.87188487e-16  8.24115935e-15
  1.41040513e-16 -8.99509711e-15 -6.01698247e-16 -4.92204144e-15
  3.11615622e-15 -1.67606551e-15 -9.40773592e-15  1.79148900e-14
 -5.09692022e-15 -1.77122317e-15] [1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1.
 1. 1. 1. 1. 1. 1.]


# Load test data

In [139]:
y_test, tx_test, ids_test = load_csv_data('../data/test.csv')

# Don't forget to standardise to same mean and std
tx_test = standardise_to_fixed(tx_test, mean_tx, std_tx)

# 0 Random baseline guessing

In [142]:
# Get baseline frequency of the two classes in training data
prior_probs = [sum(y == 1)/len(y), sum(y == -1)/len(y)]

y_test_pred = np.random.choice([1., -1.], size=len(y_test), p=prior_probs)

# Save in submission file
create_csv_submission(ids_test, y_test_pred, '../data/random_basline_submission.csv')

The accuracy on the test after submission to the AICrowd platform for the random guess model is 55%. This is thus our baseline. Anything that goes below that is probably overfitting or a model that diverged. 

# 1 Least-squares gradient descent

## 1.1 Least-quares on the full training set 

In [149]:
# Define hyperparameters for gradient descent
max_iters = 1000
gamma = .1

# Initial weights vector to train a linear model
initial_w = np.zeros(num_dim)

# Run gradient descent under MSE loss to find optimal weights
w_GD, final_loss = least_squares_GD(y[:num_train], tx[:num_train], initial_w, max_iters, gamma)

Gradient Descent(0/999): loss=0.46302369257530496, gradient=0.7908081706112834
Gradient Descent(1/999): loss=0.44782831225546155, gradient=0.4231183169268697
Gradient Descent(2/999): loss=0.4382439498837941, gradient=0.325546515051735
Gradient Descent(3/999): loss=0.43175901491845725, gradient=0.2662201416801101
Gradient Descent(4/999): loss=0.4272113966264951, gradient=0.22201985134552246
Gradient Descent(5/999): loss=0.42390440456240824, gradient=0.18851349590642563
Gradient Descent(6/999): loss=0.42140632371606807, gradient=0.16311584449542835
Gradient Descent(7/999): loss=0.41944644929357205, gradient=0.14385347902931647
Gradient Descent(8/999): loss=0.4178530446904408, gradient=0.12918949612887667
Gradient Descent(9/999): loss=0.4165156849613261, gradient=0.1179388923814379
Gradient Descent(10/999): loss=0.4153622603653287, gradient=0.10920298782622377
Gradient Descent(11/999): loss=0.4143448480295925, gradient=0.10231246993725816
Gradient Descent(12/999): loss=0.4134309718652348,

Gradient Descent(102/999): loss=0.3920327235587327, gradient=0.025062996232642173
Gradient Descent(103/999): loss=0.39197160701641554, gradient=0.024789413109721818
Gradient Descent(104/999): loss=0.3919118146043944, gradient=0.024519287275240127
Gradient Descent(105/999): loss=0.39185331542727564, gradient=0.024252564108145056
Gradient Descent(106/999): loss=0.3917960793846402, gradient=0.023989190175049118
Gradient Descent(107/999): loss=0.39174007714773396, gradient=0.023729113192294732
Gradient Descent(108/999): loss=0.3916852801369714, gradient=0.02347228198974433
Gradient Descent(109/999): loss=0.39163166050020976, gradient=0.023218646476192532
Gradient Descent(110/999): loss=0.39157919109177397, gradient=0.0229681576063026
Gradient Descent(111/999): loss=0.391527845452189, gradient=0.02272076734898251
Gradient Descent(112/999): loss=0.39147759778859836, gradient=0.022476428657115542
Gradient Descent(113/999): loss=0.3914284229558347, gradient=0.02223509543857135
Gradient Descent

Gradient Descent(211/999): loss=0.3893967479696381, gradient=0.008193106308547147
Gradient Descent(212/999): loss=0.38939019661940677, gradient=0.008113711119389163
Gradient Descent(213/999): loss=0.38938377151565123, gradient=0.008035140375648363
Gradient Descent(214/999): loss=0.38937747013992835, gradient=0.007957384649672428
Gradient Descent(215/999): loss=0.3893712900261917, gradient=0.00788043463761372
Gradient Descent(216/999): loss=0.38936522875964197, gradient=0.007804281157530653
Gradient Descent(217/999): loss=0.38935928397560654, gradient=0.007728915147522407
Gradient Descent(218/999): loss=0.3893534533584418, gradient=0.007654327663897324
Gradient Descent(219/999): loss=0.38934773464046435, gradient=0.00758050987937362
Gradient Descent(220/999): loss=0.38934212560090475, gradient=0.0075074530813124245
Gradient Descent(221/999): loss=0.3893366240648864, gradient=0.007435148669981448
Gradient Descent(222/999): loss=0.38933122790242836, gradient=0.007363588156849745
Gradient 

Gradient Descent(317/999): loss=0.3890963947373059, gradient=0.0030136794609826618
Gradient Descent(318/999): loss=0.3890955070584135, gradient=0.0029862049913179337
Gradient Descent(319/999): loss=0.38909463547984346, gradient=0.002958996082556012
Gradient Descent(320/999): loss=0.38909377970082104, gradient=0.0029320500756464516
Gradient Descent(321/999): loss=0.38909293942632184, gradient=0.002905364340325179
Gradient Descent(322/999): loss=0.38909211436695956, gradient=0.00287893627476276
Gradient Descent(323/999): loss=0.3890913042388747, gradient=0.0028527633052176323
Gradient Descent(324/999): loss=0.389090508763626, gradient=0.0028268428856942918
Gradient Descent(325/999): loss=0.38908972766808525, gradient=0.0028011724976059397
Gradient Descent(326/999): loss=0.38908896068433174, gradient=0.002775749649442294
Gradient Descent(327/999): loss=0.3890882075495523, gradient=0.0027505718764418256
Gradient Descent(328/999): loss=0.3890874680059399, gradient=0.002725636740268563
Gradi

Gradient Descent(418/999): loss=0.38905470456701324, gradient=0.0012315294722451344
Gradient Descent(419/999): loss=0.38905455608301936, gradient=0.00122112088489584
Gradient Descent(420/999): loss=0.3890544100949687, gradient=0.0012108113839698616
Gradient Descent(421/999): loss=0.3890542665582454, gradient=0.0012006000271121747
Gradient Descent(422/999): loss=0.38905412542905365, gradient=0.0011904858811354634
Gradient Descent(423/999): loss=0.38905398666440144, gradient=0.0011804680219236476
Gradient Descent(424/999): loss=0.38905385022208655, gradient=0.0011705455343367223
Gradient Descent(425/999): loss=0.38905371606068073, gradient=0.001160717512116654
Gradient Descent(426/999): loss=0.3890535841395162, gradient=0.0011509830577942925
Gradient Descent(427/999): loss=0.3890534544186708, gradient=0.0011413412825974146
Gradient Descent(428/999): loss=0.3890533268589549, gradient=0.0011317913063598427
Gradient Descent(429/999): loss=0.38905320142189603, gradient=0.0011223322574314884


Gradient Descent(525/999): loss=0.3890471253530169, gradient=0.0005314489561316192
Gradient Descent(526/999): loss=0.38904709760197415, gradient=0.0005277178534600986
Gradient Descent(527/999): loss=0.3890470702379888, gradient=0.0005240223970066309
Gradient Descent(528/999): loss=0.38904704325468425, gradient=0.0005203622434241767
Gradient Descent(529/999): loss=0.3890470166457957, gradient=0.0005167370525070577
Gradient Descent(530/999): loss=0.3890469904051685, gradient=0.0005131464871618233
Gradient Descent(531/999): loss=0.3890469645267562, gradient=0.0005095902133783268
Gradient Descent(532/999): loss=0.3890469390046186, gradient=0.0005060679002012802
Gradient Descent(533/999): loss=0.3890469138329194, gradient=0.0005025792197018978
Gradient Descent(534/999): loss=0.3890468890059255, gradient=0.0004991238469500395
Gradient Descent(535/999): loss=0.3890468645180034, gradient=0.0004957014599866286
Gradient Descent(536/999): loss=0.3890468403636187, gradient=0.0004923117397962557
Gr

Gradient Descent(630/999): loss=0.3890455166338639, gradient=0.0002825894996055408
Gradient Descent(631/999): loss=0.389045508745773, gradient=0.0002812018156648169
Gradient Descent(632/999): loss=0.3890455009345652, gradient=0.0002798266383914093
Gradient Descent(633/999): loss=0.38904549319917775, gradient=0.00027846384039694914
Gradient Descent(634/999): loss=0.38904548553856366, gradient=0.0002771132955655907
Gradient Descent(635/999): loss=0.3890454779516939, gradient=0.0002757748790433806
Gradient Descent(636/999): loss=0.3890454704375553, gradient=0.0002744484672278429
Gradient Descent(637/999): loss=0.38904546299515075, gradient=0.0002731339377572026
Gradient Descent(638/999): loss=0.389045455623499, gradient=0.00027183116950019686
Gradient Descent(639/999): loss=0.38904544832163496, gradient=0.00027054004254547923
Gradient Descent(640/999): loss=0.38904544108860856, gradient=0.0002692604381914804
Gradient Descent(641/999): loss=0.3890454339234844, gradient=0.000267992238936098

Gradient Descent(738/999): loss=0.3890449582284646, gradient=0.00018472834122836373
Gradient Descent(739/999): loss=0.3890449548419485, gradient=0.0001841648985585047
Gradient Descent(740/999): loss=0.3890449514759438, gradient=0.00018360570473605802
Gradient Descent(741/999): loss=0.3890449481302342, gradient=0.00018305072015299196
Gradient Descent(742/999): loss=0.3890449448046067, gradient=0.00018249990564572064
Gradient Descent(743/999): loss=0.389044941498851, gradient=0.00018195322249022272
Gradient Descent(744/999): loss=0.38904493821275976, gradient=0.0001814106323969471
Gradient Descent(745/999): loss=0.3890449349461274, gradient=0.000180872097505887
Gradient Descent(746/999): loss=0.3890449316987524, gradient=0.0001803375803819253
Gradient Descent(747/999): loss=0.3890449284704346, gradient=0.00017980704400959643
Gradient Descent(748/999): loss=0.3890449252609772, gradient=0.00017928045178864605
Gradient Descent(749/999): loss=0.38904492207018565, gradient=0.00017875776752911

Gradient Descent(840/999): loss=0.3890446905357401, gradient=0.00014387022514785158
Gradient Descent(841/999): loss=0.3890446884757666, gradient=0.0001435946179188134
Gradient Descent(842/999): loss=0.38904468642363244, gradient=0.00014332082173141946
Gradient Descent(843/999): loss=0.38904468437927064, gradient=0.000143048823304446
Gradient Descent(844/999): loss=0.3890446823426163, gradient=0.00014277860948404945
Gradient Descent(845/999): loss=0.3890446803136039, gradient=0.00014251016724201908
Gradient Descent(846/999): loss=0.38904467829216954, gradient=0.00014224348367432454
Gradient Descent(847/999): loss=0.3890446762782497, gradient=0.00014197854599949227
Gradient Descent(848/999): loss=0.3890446742717805, gradient=0.00014171534155709296
Gradient Descent(849/999): loss=0.38904467227270084, gradient=0.0001414538578063365
Gradient Descent(850/999): loss=0.389044670280948, gradient=0.00014119408232437828
Gradient Descent(851/999): loss=0.3890446682964612, gradient=0.00014093600280

Gradient Descent(946/999): loss=0.38904450566114085, gradient=0.0001226737161124199
Gradient Descent(947/999): loss=0.3890445041605165, gradient=0.00012253458503782305
Gradient Descent(948/999): loss=0.38904450266327184, gradient=0.00012239634170433633
Gradient Descent(949/999): loss=0.3890445011693815, gradient=0.0001222589805118648
Gradient Descent(950/999): loss=0.3890444996788202, gradient=0.00012212249589848447
Gradient Descent(951/999): loss=0.3890444981915634, gradient=0.0001219868823400083
Gradient Descent(952/999): loss=0.3890444967075858, gradient=0.00012185213434972188
Gradient Descent(953/999): loss=0.38904449522686335, gradient=0.00012171824647798729
Gradient Descent(954/999): loss=0.38904449374937167, gradient=0.00012158521331191161
Gradient Descent(955/999): loss=0.38904449227508625, gradient=0.00012145302947500937
Gradient Descent(956/999): loss=0.38904449080398396, gradient=0.00012132168962681915
Gradient Descent(957/999): loss=0.3890444893360403, gradient=0.0001211911

In [164]:
# Test error on evaluation set
y_pred = predict_labels(w_GD, tx[num_train:])

acc = get_accuracy(y_pred, y[num_train:])

f1 = get_f1_score(y_pred, y[num_train:])

print('Accuracy on evaluation set: ', acc)
print('F1 Score on evaluation set:', f1)

Accuracy on evaluation set:  0.71658
F1 Score on evaluation set: 0.6647742057578123


In [165]:
# Get predictions
# Get predictions from current model
y_test_pred = predict_labels(w_GD, tx_test)

# Save in submission file
create_csv_submission(ids_test, y_test_pred, '../data/test_full_gd_submission.csv')

## 1.2 Least-squares gradient descent with shuffle splits

In [122]:
# Define hyperparameters for gradient descent
max_iters = 50
gamma = .1

num_samples, num_dim = tx.shape
# Initial weights vector to train a linear model
initial_w = np.zeros(num_dim)

res = {
    'weights': {},
    'accuracy': {},
    'loss': {}
}
n_iter = 0

for train_data, eval_data in train_eval_split(y, tx, train_size=.7, num_splits=5):
    # Get training data
    y_train, tx_train = train_data
    
    # Run gradient descent under MSE loss to find optimal weights
    final_w, final_loss = least_squares_GD(y_train, tx_train, initial_w, max_iters, gamma)
    
    # Get validation set
    y_eval, tx_eval = eval_data
    
    # Get predictions from current model
    y_pred = predict_labels(final_w, tx_eval)
    
    acc = get_accuracy(y_pred, y_eval)
    f1 = get_f1_score(y_pred, y_eval)
    
    print('Accuracy on evaluation set', acc)
    print('F1 Score on evaluation set:', f1, '\n')
    
    res['weights'][n_iter] = w
    res['loss'][n_iter] = loss
    res['accuracy'][n_iter] = acc
    
    n_iter += 1

# Select model with highest accuracy on validation set
iter_max_acc = max(res['accuracy'], key=res['accuracy'].get)
w_max_acc = res['weights'][iter_max_acc]

Gradient Descent(0/49): loss=0.4636087303159818, gradient=0.7850384279684134
Gradient Descent(1/49): loss=0.44856737946369746, gradient=0.42135566888668025
Gradient Descent(2/49): loss=0.43907337842098365, gradient=0.3239920655829415
Gradient Descent(3/49): loss=0.4326351490635258, gradient=0.2651820191539501
Gradient Descent(4/49): loss=0.428106852004358, gradient=0.2214672862075326
Gradient Descent(5/49): loss=0.4248031527475808, gradient=0.18834801945977317
Gradient Descent(6/49): loss=0.42229957092091547, gradient=0.163237290452147
Gradient Descent(7/49): loss=0.4203298629959851, gradient=0.14417116938505284
Gradient Descent(8/49): loss=0.41872496575343865, gradient=0.12962655796411296
Gradient Descent(9/49): loss=0.41737599713032336, gradient=0.11843430296549144
Gradient Descent(10/49): loss=0.4162116878207794, gradient=0.10971152110227202
Gradient Descent(11/49): loss=0.4151845274903265, gradient=0.10280296649612077
Gradient Descent(12/49): loss=0.41426219842172696, gradient=0.09

Gradient Descent(10/49): loss=0.3921403376329163, gradient=0.0229728546973048
Gradient Descent(11/49): loss=0.39208910207371, gradient=0.02270146646570106
Gradient Descent(12/49): loss=0.39203904358779296, gradient=0.022438162604586195
Gradient Descent(13/49): loss=0.39199011906774084, gradient=0.02218175426612966
Gradient Descent(14/49): loss=0.39194229016810234, gradient=0.02193132993247839
Gradient Descent(15/49): loss=0.39189552227461916, gradient=0.02168618563760748
Gradient Descent(16/49): loss=0.39184978373460055, gradient=0.02144577334208858
Gradient Descent(17/49): loss=0.3918050452788412, gradient=0.021209662635221685
Gradient Descent(18/49): loss=0.39176127958464035, gradient=0.020977512237130347
Gradient Descent(19/49): loss=0.3917184609432256, gradient=0.020749048721963103
Gradient Descent(20/49): loss=0.3916765650047805, gradient=0.02052405057518716
Gradient Descent(21/49): loss=0.39163556858143445, gradient=0.020302336202183465
Gradient Descent(22/49): loss=0.39159544949

Gradient Descent(17/49): loss=0.39010295529124167, gradient=0.007877280576943044
Gradient Descent(18/49): loss=0.39009691382045275, gradient=0.007793295909653595
Gradient Descent(19/49): loss=0.39009099861799945, gradient=0.007711208291203996
Gradient Descent(20/49): loss=0.3900852058441769, gradient=0.007630810567350824
Gradient Descent(21/49): loss=0.3900795319729104, gradient=0.007551938401720225
Gradient Descent(22/49): loss=0.39007397373532643, gradient=0.007474459832794413
Gradient Descent(23/49): loss=0.3900685280765424, gradient=0.007398267636524544
Gradient Descent(24/49): loss=0.39006319212219526, gradient=0.007323273686185222
Gradient Descent(25/49): loss=0.3900579631522196, gradient=0.007249404746061085
Gradient Descent(26/49): loss=0.3900528385800815, gradient=0.007176599303582832
Gradient Descent(27/49): loss=0.3900478159361572, gradient=0.007104805160387658
Gradient Descent(28/49): loss=0.39004289285429455, gradient=0.007033977583143347
Gradient Descent(29/49): loss=0.39

In [121]:
# Get predictions from current model
y_test_pred = predict_labels(w_max_acc, tx_test)

# Save in submission file
create_csv_submission(ids_test, y_test_pred, '../data/test_gd_submission.csv')

# 2 Stochastic Gradients Descent

## 2.1 SGD on full train set

In [168]:
# Define hyperparameters for gradient descent
max_iters = 1000
gamma = .05

# Initial weights vector to train a linear model
initial_w = np.zeros(num_dim)

# Run gradient descent under MSE loss to find optimal weights
w_SGD, final_loss = least_squares_SGD(y[:num_train], tx[:num_train], initial_w, max_iters, gamma)

Stochastic GD(0/999): loss=1.0865822393851747, gradient=6.525919471192617
Stochastic GD(1/999): loss=0.9363195151602712, gradient=1.025062591200858
Stochastic GD(2/999): loss=0.5268541412074781, gradient=9.660429064750197
Stochastic GD(3/999): loss=0.6883064226293868, gradient=4.267578519170905
Stochastic GD(4/999): loss=1.6160450639036497, gradient=13.944059306019678
Stochastic GD(5/999): loss=1.2914355804503512, gradient=8.06334569933903
Stochastic GD(6/999): loss=0.70226911397233, gradient=12.916730927432324
Stochastic GD(7/999): loss=0.7231163660158121, gradient=0.31264957677856164
Stochastic GD(8/999): loss=0.7833571400794119, gradient=8.965745356647332
Stochastic GD(9/999): loss=0.8353742240977803, gradient=8.004047082123268
Stochastic GD(10/999): loss=0.9088542320174826, gradient=3.1523079330177612
Stochastic GD(11/999): loss=1.076731339133384, gradient=2.595904927429233
Stochastic GD(12/999): loss=1.3723312235977123, gradient=6.700153812596393
Stochastic GD(13/999): loss=1.1618

Stochastic GD(114/999): loss=1.164525266099925, gradient=2.8845140894141124
Stochastic GD(115/999): loss=1.1767568913343613, gradient=8.268182041854752
Stochastic GD(116/999): loss=1.1717983210462695, gradient=4.012042186344007
Stochastic GD(117/999): loss=1.1610092981798414, gradient=0.24841412350071787
Stochastic GD(118/999): loss=1.1742100635084023, gradient=1.7139143454244947
Stochastic GD(119/999): loss=1.1351625956711893, gradient=0.7456435115982656
Stochastic GD(120/999): loss=1.1401834459328635, gradient=0.24400604251183225
Stochastic GD(121/999): loss=1.104484356168274, gradient=21.628244187875804
Stochastic GD(122/999): loss=1.071041669599375, gradient=10.025936185620571
Stochastic GD(123/999): loss=1.1076253141837353, gradient=1.521658903580734
Stochastic GD(124/999): loss=1.199487531569984, gradient=1.2851973291531207
Stochastic GD(125/999): loss=1.1346213191463723, gradient=0.7268304234907952
Stochastic GD(126/999): loss=2.049194273676389, gradient=8.674590137210973
Stocha

Stochastic GD(223/999): loss=13.998186816828623, gradient=1.4415834675362624
Stochastic GD(224/999): loss=10.236833560819194, gradient=16.422669922400672
Stochastic GD(225/999): loss=9.03942141599162, gradient=4.3807939231995086
Stochastic GD(226/999): loss=6.372050065216415, gradient=12.31626057966485
Stochastic GD(227/999): loss=6.776691871190891, gradient=2.9288653868350356
Stochastic GD(228/999): loss=7.752933201476841, gradient=7.333391668356544
Stochastic GD(229/999): loss=6.062737927013622, gradient=22.725706432225
Stochastic GD(230/999): loss=7.185964079843374, gradient=5.668679288558566
Stochastic GD(231/999): loss=6.428679074204069, gradient=8.471538455326575
Stochastic GD(232/999): loss=5.848490108446267, gradient=6.768843607403906
Stochastic GD(233/999): loss=5.170227357327626, gradient=5.168488392831792
Stochastic GD(234/999): loss=5.917302983866327, gradient=6.643330197431717
Stochastic GD(235/999): loss=4.392194185929733, gradient=21.343600226672173
Stochastic GD(236/999

Stochastic GD(334/999): loss=2.957047856980724, gradient=2.483472776815927
Stochastic GD(335/999): loss=3.0631447285956335, gradient=1.4618991797608876
Stochastic GD(336/999): loss=1.903414048627217, gradient=21.950077565068433
Stochastic GD(337/999): loss=7.713817357521416, gradient=27.91737870310512
Stochastic GD(338/999): loss=6.876043056517786, gradient=6.3395158449741595
Stochastic GD(339/999): loss=3.047858647015711, gradient=18.578390977806848
Stochastic GD(340/999): loss=3.050184898456895, gradient=6.68098209336275
Stochastic GD(341/999): loss=2.6025278420033584, gradient=15.057419052002848
Stochastic GD(342/999): loss=2.7720485467581732, gradient=11.947988459262733
Stochastic GD(343/999): loss=2.764212816143823, gradient=3.5514879245493978
Stochastic GD(344/999): loss=3.086860259821067, gradient=20.628742501494102
Stochastic GD(345/999): loss=3.0729977696419675, gradient=0.7605579922054277
Stochastic GD(346/999): loss=4.046466963920949, gradient=4.377561328008037
Stochastic GD

Stochastic GD(446/999): loss=115.7931613204399, gradient=64.85107817787723
Stochastic GD(447/999): loss=418.833887238492, gradient=246.7971767577736
Stochastic GD(448/999): loss=123.6214745289024, gradient=112.0494916933814
Stochastic GD(449/999): loss=80.29992499285719, gradient=83.30257943264797
Stochastic GD(450/999): loss=81.3063222665424, gradient=4.859695172613611
Stochastic GD(451/999): loss=88.08725460466239, gradient=23.32899424510628
Stochastic GD(452/999): loss=87.17302686997738, gradient=15.570299170440238
Stochastic GD(453/999): loss=58.85832438117472, gradient=45.235116052033604
Stochastic GD(454/999): loss=51.85330379030135, gradient=132.19862452959887
Stochastic GD(455/999): loss=40.473779967065965, gradient=55.92094962546169
Stochastic GD(456/999): loss=51.68177735921977, gradient=56.196099758129755
Stochastic GD(457/999): loss=50.78252760718044, gradient=21.79920497459572
Stochastic GD(458/999): loss=68.22372570182857, gradient=91.27887137958133
Stochastic GD(459/999)

Stochastic GD(560/999): loss=34.09749158676234, gradient=54.33870499211575
Stochastic GD(561/999): loss=39.842783568692184, gradient=41.7628673578068
Stochastic GD(562/999): loss=20.7322915872465, gradient=36.914110484204116
Stochastic GD(563/999): loss=21.99212872202456, gradient=4.050725537525123
Stochastic GD(564/999): loss=23.203035911288, gradient=3.0892314032672066
Stochastic GD(565/999): loss=23.1827910282679, gradient=5.320175097008159
Stochastic GD(566/999): loss=19.44567271013995, gradient=10.134590667160648
Stochastic GD(567/999): loss=19.256210848599594, gradient=2.8431100265338993
Stochastic GD(568/999): loss=20.08740557731888, gradient=3.4538676990731587
Stochastic GD(569/999): loss=22.246966027321182, gradient=55.824138849947914
Stochastic GD(570/999): loss=16.372719016433727, gradient=48.991401840135985
Stochastic GD(571/999): loss=17.85297437161979, gradient=13.083853491176399
Stochastic GD(572/999): loss=25.315548311145815, gradient=18.70855143313284
Stochastic GD(573

Stochastic GD(674/999): loss=219.91158922262852, gradient=101.10836970970335
Stochastic GD(675/999): loss=92.70509186701103, gradient=145.48705184939905
Stochastic GD(676/999): loss=137.43938500486414, gradient=170.71627248416462
Stochastic GD(677/999): loss=118.2806905737928, gradient=55.83427995417974
Stochastic GD(678/999): loss=113.68576783293564, gradient=29.986360007005946
Stochastic GD(679/999): loss=111.4396388792542, gradient=3.357705139367266
Stochastic GD(680/999): loss=57.68140758970869, gradient=54.89536841700876
Stochastic GD(681/999): loss=37.29995498926726, gradient=24.714977528229987
Stochastic GD(682/999): loss=30.85587680328066, gradient=23.306478972043767
Stochastic GD(683/999): loss=29.585876272244953, gradient=9.24358810125027
Stochastic GD(684/999): loss=29.651313370238118, gradient=5.639252268997131
Stochastic GD(685/999): loss=29.311493901938526, gradient=14.127908390075314
Stochastic GD(686/999): loss=32.57301941080194, gradient=32.87277075289814
Stochastic GD

Stochastic GD(784/999): loss=4.083658556257722, gradient=0.8632362696489172
Stochastic GD(785/999): loss=3.1911584374899107, gradient=16.08241563448345
Stochastic GD(786/999): loss=3.0941675811364355, gradient=9.160151293705226
Stochastic GD(787/999): loss=2.4314544132530105, gradient=15.398088663732716
Stochastic GD(788/999): loss=4.934401268475685, gradient=18.274833432294194
Stochastic GD(789/999): loss=5.277297708647494, gradient=1.1739149462095428
Stochastic GD(790/999): loss=5.482827034246987, gradient=3.070450816263617
Stochastic GD(791/999): loss=7.516448373994164, gradient=14.863412962794504
Stochastic GD(792/999): loss=2.3601107698785864, gradient=21.064285588898713
Stochastic GD(793/999): loss=7.690484399142604, gradient=19.678962347429014
Stochastic GD(794/999): loss=3.0708370818831225, gradient=15.81808259866107
Stochastic GD(795/999): loss=3.052036880073408, gradient=2.4724582072928083
Stochastic GD(796/999): loss=3.9453049162407057, gradient=5.2131776421490805
Stochastic

Stochastic GD(896/999): loss=571.8429637125446, gradient=195.1983908797485
Stochastic GD(897/999): loss=469.2584030807271, gradient=158.38518563649168
Stochastic GD(898/999): loss=669.1833002853998, gradient=195.32757544965958
Stochastic GD(899/999): loss=503.2057586826922, gradient=136.06290865421974
Stochastic GD(900/999): loss=462.81252598509974, gradient=55.27646464472657
Stochastic GD(901/999): loss=443.7914445009039, gradient=50.077684807128826
Stochastic GD(902/999): loss=461.16022842143985, gradient=47.37590977155046
Stochastic GD(903/999): loss=562.6624036905679, gradient=88.55981806348484
Stochastic GD(904/999): loss=612.0837021159268, gradient=34.317776773894785
Stochastic GD(905/999): loss=3502.8379096126646, gradient=621.7416132188725
Stochastic GD(906/999): loss=928.4461144139889, gradient=377.58562818393705
Stochastic GD(907/999): loss=857.3713885403967, gradient=19.59612256930251
Stochastic GD(908/999): loss=892.2551221064604, gradient=88.3991058961939
Stochastic GD(909

In [169]:
# Test error on evaluation set
y_pred = predict_labels(w_SGD, tx[num_train:])

acc = get_accuracy(y_pred, y[num_train:])

f1 = get_f1_score(y_pred, y[num_train:])

print('Accuracy on evaluation set: ', acc)
print('F1 Score on evaluation set:', f1)

Accuracy on evaluation set:  0.44978
F1 Score on evaluation set: 0.42485313486505133


- Did not converge
- Much lower accuracy than using GD

## 2.1 SGD with decreasing step size on full train set

In [171]:
# Define hyperparameters for gradient descent
max_iters = 1000

# Initial weights vector to train a linear model
initial_w = np.zeros(num_dim)

# Run gradient descent under MSE loss to find optimal weights
w_SGD, final_loss = least_squares_SGD_robbinson(y[:num_train], tx[:num_train], initial_w, max_iters, r_gamma=.9)

Stochastic GD(0/999): loss=16.35255658039335, gradient=3.8023062458213737, gamma=1.0
Stochastic GD(1/999): loss=95.934992576537, gradient=12.6390394818709, gamma=0.5358867312681466
Stochastic GD(2/999): loss=99.66671244489346, gradient=24.48823507126843, gamma=0.3720410580113015
Stochastic GD(3/999): loss=1340.3337702415868, gradient=70.32707992489746, gamma=0.2871745887492588
Stochastic GD(4/999): loss=206744.6026092897, gradient=1053.4956040605766, gamma=0.23492378861760377
Stochastic GD(5/999): loss=381525.8913760723, gradient=2554.2290131936975, gamma=0.19937186647521926
Stochastic GD(6/999): loss=523892.8763343596, gradient=4250.018806282206, gamma=0.1735448634341524
Stochastic GD(7/999): loss=26351298.082916945, gradient=17671.722071076252, gamma=0.1538930516681145
Stochastic GD(8/999): loss=25444428.046165343, gradient=1842.6124219351257, gamma=0.1384145488461686
Stochastic GD(9/999): loss=199139601.98674992, gradient=82384.53650109218, gamma=0.1258925411794167
Stochastic GD(10/

Stochastic GD(85/999): loss=7184769.220637817, gradient=0.39160612685572915, gamma=0.018153124821286473
Stochastic GD(86/999): loss=7361034.093857449, gradient=8612.598532500744, gamma=0.017965225491403707
Stochastic GD(87/999): loss=7437377.310543047, gradient=1768.8174925956464, gamma=0.017781385398068894
Stochastic GD(88/999): loss=8099812.981549001, gradient=14107.398094735096, gamma=0.017601472199121834
Stochastic GD(89/999): loss=7926188.030924297, gradient=2177.222470625811, gamma=0.017425359290446665
Stochastic GD(90/999): loss=11849642.816809403, gradient=139142.62174858587, gamma=0.017252925496828494
Stochastic GD(91/999): loss=9012651.472398793, gradient=30464.37033447014, gamma=0.017084054782646442
Stochastic GD(92/999): loss=8890805.178752044, gradient=5244.682822088282, gamma=0.0169186359809306
Stochastic GD(93/999): loss=9125135.612766331, gradient=13803.372651612914, gamma=0.016756562539434097
Stochastic GD(94/999): loss=8874039.41913435, gradient=3703.0066951271556, ga

Stochastic GD(166/999): loss=3138104.987693898, gradient=1908.1943048622386, gamma=0.009989762791730734
Stochastic GD(167/999): loss=3078311.7805720824, gradient=4905.304761779567, gamma=0.009936230242919435
Stochastic GD(168/999): loss=3069833.451440691, gradient=871.771040577718, gamma=0.009883299717957457
Stochastic GD(169/999): loss=3060652.2128203386, gradient=3668.0199943983575, gamma=0.009830960943993114
Stochastic GD(170/999): loss=3033282.948295755, gradient=2036.981663682319, gamma=0.009779203882536683
Stochastic GD(171/999): loss=2928098.1759401397, gradient=16523.44298178031, gamma=0.009728018722781867
Stochastic GD(172/999): loss=2926422.130041741, gradient=575.6100846247995, gamma=0.009677395875155075
Stochastic GD(173/999): loss=2861264.8772900533, gradient=30964.705489127664, gamma=0.009627325965083516
Stochastic GD(174/999): loss=2860841.2113097534, gradient=4619.001489054357, gamma=0.009577799826973434
Stochastic GD(175/999): loss=2860472.839809608, gradient=516.78200

Stochastic GD(246/999): loss=2354121.0382039165, gradient=11833.804851666955, gamma=0.007023809790436891
Stochastic GD(247/999): loss=2335429.318021749, gradient=5558.590905541952, gamma=0.0069983150114799495
Stochastic GD(248/999): loss=2393516.773758779, gradient=5079.385955711548, gamma=0.006973014810144937
Stochastic GD(249/999): loss=2214444.232164853, gradient=20547.15587226046, gamma=0.006947906928878747
Stochastic GD(250/999): loss=2220510.0617780266, gradient=1729.2460672322793, gamma=0.006922989145212803
Stochastic GD(251/999): loss=2229738.547857016, gradient=1396.605962358033, gamma=0.006898259271080725
Stochastic GD(252/999): loss=2215298.0453764345, gradient=4285.444029022799, gamma=0.006873715152151909
Stochastic GD(253/999): loss=2751982.366471568, gradient=43322.74517585272, gamma=0.006849354667180587
Stochastic GD(254/999): loss=2614666.9625814212, gradient=6997.620864189729, gamma=0.006825175727369958
Stochastic GD(255/999): loss=2604621.567549966, gradient=12337.687

Stochastic GD(330/999): loss=1763253.3784556636, gradient=1577.1084303629218, gamma=0.00539702967012292
Stochastic GD(331/999): loss=1767327.537369438, gradient=1706.001482411605, gamma=0.005382396962178927
Stochastic GD(332/999): loss=1670333.7120568121, gradient=13965.796336280353, gamma=0.005367847756716303
Stochastic GD(333/999): loss=1563162.790396309, gradient=9751.321598024464, gamma=0.005353381328604738
Stochastic GD(334/999): loss=1525744.750372198, gradient=10194.891784983953, gamma=0.0053389969611570024
Stochastic GD(335/999): loss=1526677.3959219204, gradient=132.61208457502417, gamma=0.005324693946005798
Stochastic GD(336/999): loss=1488999.282912948, gradient=5613.278882003491, gamma=0.005310471582982767
Stochastic GD(337/999): loss=1427527.232938977, gradient=9356.806729723174, gamma=0.005296329179999617
Stochastic GD(338/999): loss=1354411.8454399558, gradient=12294.373337334922, gamma=0.005282266052931309
Stochastic GD(339/999): loss=1350684.2365938255, gradient=2654.0

Stochastic GD(410/999): loss=1058149.1693317085, gradient=1348.9431143660393, gamma=0.004441629821203413
Stochastic GD(411/999): loss=1062316.4514418063, gradient=2349.585086930607, gamma=0.004431926053240175
Stochastic GD(412/999): loss=1069505.0462566216, gradient=3045.1962306274477, gamma=0.004422266932720329
Stochastic GD(413/999): loss=1069385.242087634, gradient=97.81401629071735, gamma=0.004412652146858248
Stochastic GD(414/999): loss=1077338.4545111544, gradient=4689.893132267426, gamma=0.004403081385808092
Stochastic GD(415/999): loss=1080606.3874898432, gradient=1363.530348524372, gamma=0.004393554342629172
Stochastic GD(416/999): loss=1074239.9699550509, gradient=4889.3301832453335, gamma=0.0043840707132518164
Stochastic GD(417/999): loss=1082140.6855090533, gradient=3120.663874726164, gamma=0.004374630196443704
Stochastic GD(418/999): loss=1078113.0210344992, gradient=2989.568204449233, gamma=0.0043652324937766855
Stochastic GD(419/999): loss=1055661.5024383736, gradient=95

Stochastic GD(494/999): loss=917206.6075724778, gradient=3448.1033765162415, gamma=0.0037571222024642362
Stochastic GD(495/999): loss=921156.5754770181, gradient=7018.845369811097, gamma=0.003750304155886792
Stochastic GD(496/999): loss=918773.3949983905, gradient=1678.0250461162912, gamma=0.0037435121769011925
Stochastic GD(497/999): loss=926629.83689611, gradient=5544.655350482632, gamma=0.003736746113692946
Stochastic GD(498/999): loss=927514.1353705536, gradient=5635.889645239736, gamma=0.003730005815634205
Stochastic GD(499/999): loss=926745.1114733404, gradient=7903.20372937772, gamma=0.0037232911332721382
Stochastic GD(500/999): loss=913220.8240043976, gradient=9441.659259737566, gamma=0.0037166019183174355
Stochastic GD(501/999): loss=905184.2248322933, gradient=7113.141187840996, gamma=0.0037099380236329488
Stochastic GD(502/999): loss=902101.2724116843, gradient=4826.897209608455, gamma=0.003703299303222471
Stochastic GD(503/999): loss=896502.892158165, gradient=11863.0004643

Stochastic GD(577/999): loss=809125.323631913, gradient=1111.485567561473, gamma=0.0032678713092796838
Stochastic GD(578/999): loss=802630.329122183, gradient=3517.3229929128684, gamma=0.003262791277640699
Stochastic GD(579/999): loss=812230.470499801, gradient=5307.7587524784085, gamma=0.003257727888920602
Stochastic GD(580/999): loss=790048.2619972802, gradient=6657.98970845685, gamma=0.0032526810600408654
Stochastic GD(581/999): loss=794198.9640735043, gradient=2381.801003421128, gamma=0.0032476507084797253
Stochastic GD(582/999): loss=793047.973233449, gradient=476.68824800409766, gamma=0.003242636752267496
Stochastic GD(583/999): loss=785843.5588548294, gradient=3574.0623217646944, gamma=0.0032376391099819364
Stochastic GD(584/999): loss=784458.1398558625, gradient=1139.7789268271285, gamma=0.003232657700743671
Stochastic GD(585/999): loss=787772.762672688, gradient=2086.2303454048333, gamma=0.0032276924442116426
Stochastic GD(586/999): loss=781661.4923060148, gradient=3729.015084

Stochastic GD(659/999): loss=705675.7545848422, gradient=5551.6078864420515, gamma=0.0029000832035602935
Stochastic GD(660/999): loss=707080.6713739558, gradient=2012.3111181588222, gamma=0.0028961342286334944
Stochastic GD(661/999): loss=708284.209135618, gradient=1738.7175083187522, gamma=0.002892196588479375
Stochastic GD(662/999): loss=706941.0777380007, gradient=1471.1433513768472, gamma=0.0028882702335152467
Stochastic GD(663/999): loss=706562.3047976329, gradient=4018.84738358006, gamma=0.002884355114449667
Stochastic GD(664/999): loss=704442.4327341992, gradient=3101.594878190796, gamma=0.002880451182280293
Stochastic GD(665/999): loss=703730.4228347833, gradient=979.9734170142525, gamma=0.0028765583882917525
Stochastic GD(666/999): loss=706086.6389532988, gradient=11759.839658752717, gamma=0.0028726766840535397
Stochastic GD(667/999): loss=772265.877949239, gradient=36982.395102160386, gamma=0.002868806021417921
Stochastic GD(668/999): loss=773057.2355572925, gradient=386.0002

Stochastic GD(743/999): loss=645023.1134578531, gradient=11522.716134922226, gamma=0.0026036605211673743
Stochastic GD(744/999): loss=631683.3468784709, gradient=24138.230819166376, gamma=0.00260051494826188
Stochastic GD(745/999): loss=634235.5766629742, gradient=4777.13907302143, gamma=0.002597377387406497
Stochastic GD(746/999): loss=634285.4741444248, gradient=1446.6539202136128, gamma=0.0025942478074947968
Stochastic GD(747/999): loss=625901.6863286791, gradient=25637.450454380625, gamma=0.0025911261775825476
Stochastic GD(748/999): loss=626730.43901585, gradient=5049.842722507226, gamma=0.0025880124668866532
Stochastic GD(749/999): loss=626782.4970137761, gradient=99.86345884014696, gamma=0.002584906644784101
Stochastic GD(750/999): loss=628712.4992142755, gradient=9332.303124158001, gamma=0.0025818086808109156
Stochastic GD(751/999): loss=631025.0967859031, gradient=2441.714602337855, gamma=0.002578718544661124
Stochastic GD(752/999): loss=623703.1401745855, gradient=13818.59286

Stochastic GD(827/999): loss=634620.1775489583, gradient=11650.020444173566, gamma=0.002364681735203236
Stochastic GD(828/999): loss=634831.3056442805, gradient=467.0356552159872, gamma=0.002362114374553321
Stochastic GD(829/999): loss=654668.5683597331, gradient=6208.8271270839805, gamma=0.002359552891348319
Stochastic GD(830/999): loss=657287.2182106065, gradient=777.4434627425305, gamma=0.002356997265076057
Stochastic GD(831/999): loss=633190.0249039126, gradient=9419.256168425129, gamma=0.0023544474753205177
Stochastic GD(832/999): loss=613353.5048756183, gradient=7597.653503740058, gamma=0.0023519035017612734
Stochastic GD(833/999): loss=609635.6908330228, gradient=1530.0394301095173, gamma=0.0023493653241729276
Stochastic GD(834/999): loss=610522.5249234467, gradient=1840.7028349655086, gamma=0.0023468329224245544
Stochastic GD(835/999): loss=613759.5112507752, gradient=4816.807467165511, gamma=0.0023443062764791464
Stochastic GD(836/999): loss=615182.8593194935, gradient=4191.77

Stochastic GD(908/999): loss=503953.69985467574, gradient=1811.0314026634005, gamma=0.0021741650248709166
Stochastic GD(909/999): loss=503800.9995811914, gradient=2818.2007292267954, gamma=0.0021720146335748895
Stochastic GD(910/999): loss=505436.15680515073, gradient=11701.834736502022, gamma=0.0021698687274246792
Stochastic GD(911/999): loss=504686.7612092477, gradient=6561.4702362260905, gamma=0.002167727292157526
Stochastic GD(912/999): loss=499859.9349310616, gradient=10060.145389994828, gamma=0.0021655903135715995
Stochastic GD(913/999): loss=499817.2190371988, gradient=801.5310970856251, gamma=0.0021634577775256703
Stochastic GD(914/999): loss=499639.0282454583, gradient=1405.7979827382906, gamma=0.0021613296699387876
Stochastic GD(915/999): loss=499653.99545086693, gradient=3244.066586716093, gamma=0.0021592059767899545
Stochastic GD(916/999): loss=502466.8713347104, gradient=13421.186465389654, gamma=0.002157086684117809
Stochastic GD(917/999): loss=503062.7872618855, gradient

Stochastic GD(991/999): loss=479001.54178041033, gradient=5689.490253787048, gamma=0.0020097382353595184
Stochastic GD(992/999): loss=477060.01234670717, gradient=4214.6064797528725, gamma=0.0020079166285909165
Stochastic GD(993/999): loss=476794.5326486432, gradient=1428.8358072358396, gamma=0.002006098503942084
Stochastic GD(994/999): loss=476971.5538447424, gradient=2455.601344752711, gamma=0.002004283851263619
Stochastic GD(995/999): loss=473356.27975351305, gradient=16273.49907698161, gamma=0.002002472660445863
Stochastic GD(996/999): loss=472558.09887811093, gradient=3756.572954236715, gamma=0.0020006649214187053
Stochastic GD(997/999): loss=472268.15124078665, gradient=5901.102497372167, gamma=0.001998860624151391
Stochastic GD(998/999): loss=472447.6164560726, gradient=10777.76049069184, gamma=0.0019970597586523244
Stochastic GD(999/999): loss=472576.5982912012, gradient=412.6488524607178, gamma=0.0019952623149688794
