In [1]:
## Imports

import numpy as np
from sklearn import datasets, linear_model, metrics

In [2]:
## Load the diabetes dataset
diabetes = datasets.load_diabetes()
diabetes_X = diabetes.data # matrix of dimensions 442x10

# Number of trailing rows held out for testing; everything before them
# is used for training.
n_test = 20

# Feature matrix: training rows first, held-out rows last
diabetes_X_train, diabetes_X_test = diabetes_X[:-n_test], diabetes_X[-n_test:]

# Targets, split at the same row boundary as the features
diabetes_y_train, diabetes_y_test = diabetes.target[:-n_test], diabetes.target[-n_test:]


In [3]:
## Linear Regression using Gradient Descent
## Our own implementation
#
# Fits the linear model y_predict = X.dot(W) + b on the training split
# using full-batch gradient descent.

# train
X = diabetes_X_train
y = diabetes_y_train

# train: init
# Initialize weights from a seeded generator so that a fresh
# "Restart & Run All" starts from the same point every time and the
# printed loss trace (and the downstream test error) is reproducible.
rng = np.random.default_rng(0)
# One weight per feature column of the matrix actually being trained on.
W = rng.uniform(low=-0.1, high=0.1, size=X.shape[1])
b = 0.0

learning_rate = 0.1
epochs = 100000

# train: gradient descent

# Loop-invariant scaling applied to both gradients. Note the factor is
# 1/n rather than the 2/n of the exact derivative of the mean squared
# error; the missing constant only rescales the step and is effectively
# absorbed by learning_rate.
gradient_scale = -(1.0/len(X))

for i in range(epochs):

    # calculate predictions
    y_predict = X.dot(W) + b

    # calculate error and cost (mean squared error)
    error = y - y_predict

    mean_squared_error = np.mean(np.power(error, 2))

    # calculate gradients
    W_gradient = gradient_scale * error.dot(X)
    b_gradient = gradient_scale * np.sum(error)

    # update parameters: step against the gradient

    W = W - (learning_rate * W_gradient)
    b = b - (learning_rate * b_gradient)

    # diagnostic output: report the cost every 5000 epochs
    if i % 5000 == 0:
        print("Epoch %d: %f" % (i, mean_squared_error))



Epoch 0: 29468.773393
Epoch 5000: 3048.214979
Epoch 10000: 2941.415977
Epoch 15000: 2927.458036
Epoch 20000: 2924.752853
Epoch 25000: 2923.795426
Epoch 30000: 2923.195711
Epoch 35000: 2922.694455
Epoch 40000: 2922.231313
Epoch 45000: 2921.789482
Epoch 50000: 2921.363355
Epoch 55000: 2920.950549
Epoch 60000: 2920.549720
Epoch 65000: 2920.159922
Epoch 70000: 2919.780402
Epoch 75000: 2919.410529
Epoch 80000: 2919.049756
Epoch 85000: 2918.697611
Epoch 90000: 2918.353675
Epoch 95000: 2918.017580


In [5]:
# test
# Score the parameters learned during training on the held-out rows.

X = diabetes_X_test
y = diabetes_y_test

# Predictions of the linear model for the held-out rows
y_predict = np.dot(X, W) + b

# Residuals and their mean square
error = y - y_predict
mean_squared_error = np.power(error, 2).mean()

print("Mean squared error: %.2f" % mean_squared_error)

print("="*80)

Mean squared error: 1993.53
