In [1]:
%matplotlib inline
%load_ext autoreload
%autoreload 2

In [12]:
import numpy as np

import sys
sys.path.append('../')

from scripts.proj1_helpers import *
from scripts.implementations import *

# Load train data

In [13]:
# Load data
y, tx, ids = load_csv_data('../data/train.csv')

# Normalise data
tx, mean_tx, std_tx = standardise(tx)

# Check shape of data
print('Shape: y: {}, x:{}\n'.format(y.shape, tx.shape))

# Check that data is normalised
print(np.mean(tx, axis=0), np.std(tx, axis=0))

# Load test data

In [None]:
y_test, tx_test, ids_test = load_csv_data('../data/test.csv')

# Don't forget to standardise to same mean and std
tx_test = standardise_to_fixed(tx_test, mean_tx, std_tx)

# 1 Least-squares gradient descent with shuffle split 

In [87]:
# Define hyperparameters for gradient descent
max_iters = 100
gamma = .8

num_samples, num_dim = tx.shape
# Initial weights vector to train a linear model
initial_w = np.random.randn(num_dim)

res = {
    'weights': {},
    'accuracy': {},
    'loss': {}
}
n_iter = 0

for train_data, eval_data in train_eval_split(y, tx, train_size=.8, num_splits=5):
    # Get training data
    y_train, tx_train = train_data
    
    # Run gradient descent under MSE loss to find optimal weights
    final_w, final_loss = least_squares_GD(y_train, tx_train, initial_w, max_iters, gamma)
    
    # Get validation set
    y_eval, tx_eval = eval_data
    
    # Get predictions from current model
    y_pred = predict_labels(final_w, tx_eval)
    
    acc = get_accuracy(y_pred, y_eval)
    
    print('Accuracy of predictions using least-squares gradient descent', acc)
    
    res['weights'][n_iter] = w
    res['loss'][n_iter] = loss
    res['accuracy'][n_iter] = acc
    
    n_iter += 1

# Select model with highest accuracy on validation set
iter_max_acc = max(res['accuracy'], key=res['accuracy'].get)
w_max_acc = res['weights'][iter_max_acc]

Gradient Descent(0/99): loss=447.78343902876844
Gradient Descent(1/99): loss=34659.25048998579
Gradient Descent(2/99): loss=2692563.018168814
Gradient Descent(3/99): loss=209181694.48257086
Gradient Descent(4/99): loss=16251055788.886072
Gradient Descent(5/99): loss=1262523547308.5322
Gradient Descent(6/99): loss=98083824722823.39
Gradient Descent(7/99): loss=7620005735948605.0
Gradient Descent(8/99): loss=5.91988409709504e+17
Gradient Descent(9/99): loss=4.599081541068685e+19
Gradient Descent(10/99): loss=3.572967084233661e+21
Gradient Descent(11/99): loss=2.775792007821361e+23
Gradient Descent(12/99): loss=2.1564769809060515e+25
Gradient Descent(13/99): loss=1.6753391306244195e+27
Gradient Descent(14/99): loss=1.3015493452761717e+29
Gradient Descent(15/99): loss=1.011156885924012e+31
Gradient Descent(16/99): loss=7.855547326441059e+32
Gradient Descent(17/99): loss=6.102873318373776e+34
Gradient Descent(18/99): loss=4.741243505052176e+36
Gradient Descent(19/99): loss=3.683410879023360

Gradient Descent(59/99): loss=1.5274096635855817e+114
Gradient Descent(60/99): loss=1.1867762174868195e+116
Gradient Descent(61/99): loss=9.221087334788969e+117
Gradient Descent(62/99): loss=7.164657530453889e+119
Gradient Descent(63/99): loss=5.566839968538691e+121
Gradient Descent(64/99): loss=4.3253577862718605e+123
Gradient Descent(65/99): loss=3.3607432735620363e+125
Gradient Descent(66/99): loss=2.611251116992e+127
Gradient Descent(67/99): loss=2.0289060606420338e+129
Gradient Descent(68/99): loss=1.5764319931251474e+131
Gradient Descent(69/99): loss=1.2248658906179852e+133
Gradient Descent(70/99): loss=9.517038835434712e+134
Gradient Descent(71/99): loss=7.394607759832043e+136
Gradient Descent(72/99): loss=5.7455081215154855e+138
Gradient Descent(73/99): loss=4.464180474009327e+140
Gradient Descent(74/99): loss=3.4686065850115974e+142
Gradient Descent(75/99): loss=2.6950594205660418e+144
Gradient Descent(76/99): loss=2.0940239552585207e+146
Gradient Descent(77/99): loss=1.627027

Gradient Descent(15/99): loss=1.0131358410322848e+31
Gradient Descent(16/99): loss=7.871978394981742e+32
Gradient Descent(17/99): loss=6.11645954484449e+34
Gradient Descent(18/99): loss=4.752436488846097e+36
Gradient Descent(19/99): loss=3.692602299569379e+38
Gradient Descent(20/99): loss=2.869120244907397e+40
Gradient Descent(21/99): loss=2.229281767142244e+42
Gradient Descent(22/99): loss=1.7321327700133527e+44
Gradient Descent(23/99): loss=1.345852272770454e+46
Gradient Descent(24/99): loss=1.0457156469058876e+48
Gradient Descent(25/99): loss=8.125120686037636e+49
Gradient Descent(26/99): loss=6.313148928966743e+51
Gradient Descent(27/99): loss=4.905262449553882e+53
Gradient Descent(28/99): loss=3.811346757337075e+55
Gradient Descent(29/99): loss=2.9613836678575416e+57
Gradient Descent(30/99): loss=2.300969653671894e+59
Gradient Descent(31/99): loss=1.7878336416129872e+61
Gradient Descent(32/99): loss=1.389131371194938e+63
Gradient Descent(33/99): loss=1.0793431343516684e+65
Gradien

Gradient Descent(79/99): loss=7.651133498347021e+151
Gradient Descent(80/99): loss=5.926576926016048e+153
Gradient Descent(81/99): loss=4.590733394937428e+155
Gradient Descent(82/99): loss=3.555987438698519e+157
Gradient Descent(83/99): loss=2.7544720148912254e+159
Gradient Descent(84/99): loss=2.1336172333600118e+161
Gradient Descent(85/99): loss=1.65270239591474e+163
Gradient Descent(86/99): loss=1.2801852022730738e+165
Gradient Descent(87/99): loss=9.916329498704828e+166
Gradient Descent(88/99): loss=7.681200388216165e+168
Gradient Descent(89/99): loss=5.949866773954815e+170
Gradient Descent(90/99): loss=4.608773738297512e+172
Gradient Descent(91/99): loss=3.5699615096931715e+174
Gradient Descent(92/99): loss=2.765296346571895e+176
Gradient Descent(93/99): loss=2.142001773296742e+178
Gradient Descent(94/99): loss=1.659197070322781e+180
Gradient Descent(95/99): loss=1.2852159846397787e+182
Gradient Descent(96/99): loss=9.955297997556518e+183
Gradient Descent(97/99): loss=7.7113854328

In [88]:
# Get predictions from current model
y_test_pred = predict_labels(w_max_acc, tx_test)

# Save in submission file
create_csv_submission(ids_test, y_test_pred, '../data/test_gd_submission.csv')

In [89]:
res

{'weights': {0: array([-6.51882623e+45,  6.61806111e+45,  9.54326342e+44, -2.12129413e+46,
         -2.53441436e+46, -2.45930336e+46, -2.53424448e+46,  1.19929520e+46,
         -1.01660450e+46, -2.40068198e+46, -1.54773217e+45, -1.48437282e+46,
         -2.53444675e+46, -8.70184270e+45, -1.91063115e+44, -1.31385745e+44,
         -7.49727759e+45, -4.09575238e+44,  3.34459667e+43, -1.25714311e+46,
         -2.22472410e+44, -2.25745208e+46, -2.62799025e+46, -2.14062558e+46,
         -2.04440605e+46, -2.04441732e+46, -2.54758582e+46, -2.53443595e+46,
         -2.53442253e+46, -2.46622080e+46]),
  1: array([-6.51882623e+45,  6.61806111e+45,  9.54326342e+44, -2.12129413e+46,
         -2.53441436e+46, -2.45930336e+46, -2.53424448e+46,  1.19929520e+46,
         -1.01660450e+46, -2.40068198e+46, -1.54773217e+45, -1.48437282e+46,
         -2.53444675e+46, -8.70184270e+45, -1.91063115e+44, -1.31385745e+44,
         -7.49727759e+45, -4.09575238e+44,  3.34459667e+43, -1.25714311e+46,
         -2.22