In [225]:
from sklearn.metrics import r2_score
from sklearn.metrics import mean_squared_error
from sklearn import preprocessing
class LinearRegression:
    '''
    A class which implements linear regression model with gradient descent.
    '''
    def __init__(self, learning_rate=0.01, n_iterations=2000):
        self.learning_rate = learning_rate
        self.n_iterations = n_iterations
        self.weights, self.bias = None, None
        self.loss = []
        
    @staticmethod
    def _mean_squared_error(y, y_hat):
        '''
        Private method, used to evaluate loss at each iteration.
        
        :param: y - array, true values
        :param: y_hat - array, predicted values
        :return: float
        '''
        error = 0
        for i in range(len(y)):
            error += (y[i] - y_hat[i]) ** 2
        return error / len(y)
    
    def fit(self, X, y):
        '''
        Used to calculate the coefficient of the linear regression model.
        
        :param X: array, features
        :param y: array, true values
        :return: None
        '''
        # 1. Initialize weights and bias to zeros
        self.weights = np.zeros(X.shape[1])
        self.bias = 0
        
        # 2. Perform gradient descent
        for i in range(self.n_iterations):
            # Line equation
            y_hat = np.dot(X, self.weights) + self.bias
            loss = self._mean_squared_error(y, y_hat)
            
            self.loss.append(loss)
            
            # Calculate derivatives
            partial_w = (1 / X.shape[0]) * (2 * np.dot(X.T, (y_hat - y)))
#             partial_w = partial_w/100
            print("partial_w printing",partial_w)
            partial_d = (1 / X.shape[0]) * (2 * np.sum(y_hat - y)).astype(int)
#             partial_d = partial_d/100
            print("partial_d printing", partial_d)
            
            self.weights -= (self.learning_rate * partial_w).astype(int) /100
            self.bias -= (self.learning_rate * partial_d).astype(int) /100
            
            
        print("printing y_hat", y_hat)
        print("printing weights", self.weights)
        print("printing bias", self.bias)
    def predict(self, X):
        '''
        Makes predictions using the line equation.
        
        :param X: array, features
        :return: array, predictions
        '''
        return np.dot(X, self.weights) + self.bias

In [226]:
from sklearn.datasets import load_diabetes

data = load_diabetes()
X = data.data
y = data.target

In [227]:
# print(X)

In [228]:
# print(len(y))

In [229]:
from sklearn.model_selection import train_test_split
import numpy as np
import matplotlib.pyplot as plt

X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=42)

X_train.shape
# X_train = np.dot(X_train,1000).astype(int)
# X_test = np.dot(X_test,1000).astype(int)
# y_train = np.dot(y_train,1000).astype(int)
# y_test = np.dot(y_test,1000).astype(int)



(353, 10)

In [230]:
# model = LinearRegression()
# model.fit(X_train, y_train)
# # preds = model.predict(X_test)

In [231]:
# xs = np.arange(len(model.loss))
# ys = model.loss

# plt.plot(xs, ys, lw=3, c='#087E8B')
# plt.title('Loss per iteration (MSE)', size=20)
# plt.xlabel('Iteration', size=14)
# plt.ylabel('Loss', size=14)
# plt.show()

In [232]:
# losses = {}
# for lr in [0.5, 0.1, 0.01, 0.001]:
#     model = LinearRegression(learning_rate=lr)
#     model.fit(X_train, y_train)
#     losses[f'LR={str(lr)}'] = model.loss
    
    
# xs = np.arange(len(model.loss))

# plt.plot(xs, losses['LR=0.5'], lw=3, label=f"LR = 0.5, Final = {losses['LR=0.5'][-1]:.2f}")
# plt.plot(xs, losses['LR=0.1'], lw=3, label=f"LR = 0.1, Final = {losses['LR=0.1'][-1]:.2f}")
# plt.plot(xs, losses['LR=0.01'], lw=3, label=f"LR = 0.01, Final = {losses['LR=0.01'][-1]:.2f}")
# plt.plot(xs, losses['LR=0.001'], lw=3, label=f"LR = 0.001, Final = {losses['LR=0.001'][-1]:.2f}")
# # plt.title('Loss per iteration (MSE) for different learning rates with python decimal operations', size=20)
# plt.xlabel('Iteration', size=14)
# plt.ylabel('Loss', size=14)
# plt.legend()
# plt.savefig("python-mse.pdf")
# plt.show()

In [233]:
# model = LinearRegression(learning_rate=0.5)
# model.fit(X_train, y_train)
# preds = model.predict(X_test)

# model._mean_squared_error(y_test, preds)

In [234]:
# from sklearn.linear_model import LinearRegression
# from sklearn.metrics import mean_squared_error

# lr_model = LinearRegression()
# lr_model.fit(X_train, y_train)
# lr_preds = lr_model.predict(X_test)

# mean_squared_error(y_test, lr_preds)

In [235]:


X_train_num = np.dot(X_train,1000).astype(int)
X_test_num = np.dot(X_test,1000).astype(int)
y_train_num = y_train
y_test_num = y_test

In [236]:
model = LinearRegression()
model.fit(X_train_num, y_train)
preds = model.predict(X_test_num)

partial_w printing [-1826.49858357  -114.76487252 -4959.59773371 -3700.74220963
 -1317.86968839  -904.34560907  3059.09348442 -3329.86402266
 -4461.96033994 -3497.65439093]
partial_d printing -307.4730878186969
partial_w printing [ 1332.71592068  1907.02776204   882.79807365  1158.26560907
  3101.8403966   3293.76492918 -1532.16447592  3045.32770538
  1951.7929745   2133.97648725]
partial_d printing -301.20113314447593
partial_w printing [ -984.66804533  -239.22504249 -2116.76764873 -1453.22096317
 -1406.26458924 -1302.68084986  1376.51643059 -2058.19490085
 -2203.18045326 -1551.79320113]
partial_d printing -303.22946175637395
partial_w printing [ 683.31626062  852.39637394  759.33501416  851.2176204  1449.1384136
 1467.1705949  -918.73484419 1526.06073654 1175.56203966 1308.86028329]
partial_d printing -300.5665722379604
partial_w printing [ -516.67915014  -213.19082153  -969.43671388  -649.86526912
  -782.82815864  -788.63467422   680.0929745  -1125.85274788
 -1077.07320113  -732.601

partial_w printing [-12.64056657  -5.98889518 -36.10424929 -12.08770538  -0.59512748
 -46.38628895 -14.30679887 -82.51382436 -53.11677054  17.45852691]
partial_d printing -293.4220963172805
partial_w printing [-12.58515581  -5.98073654 -36.03274788 -12.03943343  -0.61495751
 -46.4172238  -14.3423796  -82.4898017  -53.06589235  17.53388102]
partial_d printing -293.38243626062325
partial_w printing [-12.52974504  -5.9725779  -35.96124646 -11.99116147  -0.63478754
 -46.44815864 -14.37796034 -82.46577904 -53.01501416  17.60923513]
partial_d printing -293.342776203966
partial_w printing [-12.47433428  -5.96441926 -35.88974504 -11.94288952  -0.65461756
 -46.47909348 -14.41354108 -82.44175637 -52.96413598  17.68458924]
partial_d printing -293.3031161473088
partial_w printing [-12.41892351  -5.95626062 -35.81824363 -11.89461756  -0.67444759
 -46.51002833 -14.44912181 -82.41773371 -52.91325779  17.75994334]
partial_d printing -293.2634560906516
partial_w printing [-12.36351275  -5.94810198 -35.

partial_w printing [ -2.05711048  -4.4305949  -22.44747875  -2.86776204  -4.38266289
 -52.29484419 -21.10271955 -77.92549575 -43.39903683  31.85116147]
partial_d printing -285.78470254957506
partial_w printing [ -2.00169972  -4.42243626 -22.37597734  -2.81949008  -4.40249292
 -52.32577904 -21.13830028 -77.90147309 -43.34815864  31.92651558]
partial_d printing -285.74220963172803
partial_w printing [ -1.94628895  -4.41427762 -22.30447592  -2.77121813  -4.42232295
 -52.35671388 -21.17388102 -77.87745042 -43.29728045  32.00186969]
partial_d printing -285.70254957507086
partial_w printing [ -1.89087819  -4.40611898 -22.2329745   -2.72294618  -4.44215297
 -52.38764873 -21.20946176 -77.85342776 -43.24640227  32.0772238 ]
partial_d printing -285.66288951841364
partial_w printing [ -1.83546742  -4.39796034 -22.16147309  -2.67467422  -4.461983
 -52.41858357 -21.24504249 -77.8294051  -43.19552408  32.1525779 ]
partial_d printing -285.6232294617564
partial_w printing [ -1.78005666  -4.3898017  -2

partial_w printing [  8.6925779   -2.8478187   -8.57620397   6.49699717  -8.22968839
 -58.29620397 -28.00538244 -73.26509915 -33.52866856  46.46985836]
partial_d printing -278.0226628895184
partial_w printing [  8.74798867  -2.83966006  -8.50470255   6.54526912  -8.24951841
 -58.32713881 -28.04096317 -73.24107649 -33.47779037  46.54521246]
partial_d printing -277.9830028328612
partial_w printing [  8.80339943  -2.83150142  -8.43320113   6.59354108  -8.26934844
 -58.35807365 -28.07654391 -73.21705382 -33.42691218  46.62056657]
partial_d printing -277.94334277620396
partial_w printing [  8.8588102   -2.82334278  -8.36169972   6.64181303  -8.28917847
 -58.3890085  -28.11212465 -73.19303116 -33.37603399  46.69592068]
partial_d printing -277.90368271954674
partial_w printing [  8.91422096  -2.81518414  -8.2901983    6.69008499  -8.3090085
 -58.41994334 -28.14770538 -73.1690085  -33.32515581  46.77127479]
partial_d printing -277.8640226628895
partial_w printing [  8.96963173  -2.8070255   -8

partial_w printing [ 18.94356941  -1.33847025   4.65155807  15.42730878 -11.89824363
 -64.01915014 -34.5878187  -68.82090652 -24.11620397  60.41036827]
partial_d printing -270.6232294617564
partial_w printing [ 18.99898017  -1.33031161   4.72305949  15.47558074 -11.91807365
 -64.05008499 -34.62339943 -68.79688385 -24.06532578  60.48572238]
partial_d printing -270.5835694050992
partial_w printing [ 19.05439093  -1.32215297   4.79456091  15.52385269 -11.93790368
 -64.08101983 -34.65898017 -68.77286119 -24.01444759  60.56107649]
partial_d printing -270.54390934844196
partial_w printing [ 19.1098017   -1.31399433   4.86606232  15.57212465 -11.95773371
 -64.11195467 -34.69456091 -68.74883853 -23.96356941  60.63643059]
partial_d printing -270.50424929178473
partial_w printing [ 19.16521246  -1.30583569   4.93756374  15.6203966  -11.97756374
 -64.14288952 -34.73014164 -68.72481586 -23.91269122  60.7117847 ]
partial_d printing -270.4645892351275
partial_w printing [ 19.22062323  -1.29767705   

partial_w printing [ 29.5270255    0.21983003  18.30832861  24.64725212 -15.68577904
 -69.92770538 -41.38373938 -64.2325779  -14.39847025  74.80300283]
partial_d printing -262.9830028328612
partial_w printing [ 29.58243626   0.22798867  18.37983003  24.69552408 -15.70560907
 -69.95864023 -41.41932011 -64.20855524 -14.34759207  74.87835694]
partial_d printing -262.94334277620396
partial_w printing [ 29.63784703   0.23614731  18.45133144  24.74379603 -15.72543909
 -69.98957507 -41.45490085 -64.18453258 -14.29671388  74.95371105]
partial_d printing -262.90368271954674
partial_w printing [ 29.69325779   0.24430595  18.52283286  24.79206799 -15.74526912
 -70.02050992 -41.49048159 -64.16050992 -14.24583569  75.02906516]
partial_d printing -262.8640226628895
partial_w printing [ 29.74866856   0.25246459  18.59433428  24.84033994 -15.76509915
 -70.05144476 -41.52606232 -64.13648725 -14.19495751  75.10441926]
partial_d printing -262.8243626062323
partial_w printing [ 29.80407932   0.26062323  1

 -75.74345609 -48.07291785 -59.71631728  -4.8333711   88.96957507]
partial_d printing -255.46458923512748
partial_w printing [ 39.99966006   1.76181303  31.82209632  33.77065156 -19.43365439
 -75.77439093 -48.10849858 -59.69229462  -4.78249292  89.04492918]
partial_d printing -255.42209631728048
partial_w printing [ 40.05507082   1.76997167  31.89359773  33.81892351 -19.45348442
 -75.80532578 -48.14407932 -59.66827195  -4.73161473  89.12028329]
partial_d printing -255.38243626062325
partial_w printing [ 40.11048159   1.77813031  31.96509915  33.86719547 -19.47331445
 -75.83626062 -48.17966006 -59.64424929  -4.68073654  89.19563739]
partial_d printing -255.34277620396603
partial_w printing [ 40.16589235   1.78628895  32.03660057  33.91546742 -19.49314448
 -75.86719547 -48.21524079 -59.62022663  -4.62985836  89.2709915 ]
partial_d printing -255.3031161473088
partial_w printing [ 40.22130312   1.79444759  32.10810198  33.96373938 -19.5129745
 -75.89813031 -48.25082153 -59.59620397  -4.578

partial_w printing [ 38.09382436  -4.15331445  27.16889518  25.6721813  -38.13694051
 -94.87014164 -41.99433428 -74.71756374 -16.86266289  57.33643059]
partial_d printing -247.86118980169974
partial_w printing [ 38.14923513  -4.14515581  27.2403966   25.72045326 -38.15677054
 -94.90107649 -42.02991501 -74.69354108 -16.8117847   57.4117847 ]
partial_d printing -247.8215297450425
partial_w printing [ 38.20464589  -4.13699717  27.31189802  25.76872521 -38.17660057
 -94.93201133 -42.06549575 -74.66951841 -16.76090652  57.48713881]
partial_d printing -247.78186968838529
partial_w printing [ 38.26005666  -4.12883853  27.38339943  25.81699717 -38.19643059
 -94.96294618 -42.10107649 -74.64549575 -16.71002833  57.56249292]
partial_d printing -247.74220963172806
partial_w printing [ 38.31546742  -4.12067989  27.45490085  25.86526912 -38.21626062
 -94.99388102 -42.13665722 -74.62147309 -16.65915014  57.63784703]
partial_d printing -247.69971671388103
partial_w printing [ 38.37087819  -4.11252125 

partial_w printing [ 57.91852691   2.71830028  52.09495751  41.9929745   -2.06543909
 -56.13898017 -57.00844193 -40.9423796    6.33439093  84.64577904]
partial_d printing -240.39660056657226
partial_w printing [ 57.97393768   2.72645892  52.16645892  42.04124646  -2.08526912
 -56.16991501 -57.04402266 -40.91835694   6.38526912  84.72113314]
partial_d printing -240.35694050991503
partial_w printing [ 58.02934844   2.73461756  52.23796034  42.08951841  -2.10509915
 -56.20084986 -57.0796034  -40.89433428   6.43614731  84.79648725]
partial_d printing -240.3172804532578
partial_w printing [ 58.08475921   2.7427762   52.30946176  42.13779037  -2.12492918
 -56.2317847  -57.11518414 -40.87031161   6.4870255   84.87184136]
partial_d printing -240.27762039660058
partial_w printing [ 58.14016997   2.75093484  52.38096317  42.18606232  -2.14475921
 -56.26271955 -57.15076487 -40.84628895   6.53790368  84.94719547]
partial_d printing -240.23512747875355
partial_w printing [ 58.19558074   2.75909348 

partial_w printing [ 68.72362606   4.30923513  66.03773371  51.40600567  -5.93229462
 -62.17127479 -63.94668555 -36.25796034  16.25563739  99.33983003]
partial_d printing -232.59773371104816
partial_w printing [ 68.77903683   4.31739377  66.10923513  51.45427762  -5.95212465
 -62.20220963 -63.98226629 -36.23393768  16.30651558  99.41518414]
partial_d printing -232.55524079320114
partial_w printing [ 68.83444759   4.32555241  66.18073654  51.50254958  -5.97195467
 -62.23314448 -64.01784703 -36.20991501  16.35739377  99.49053824]
partial_d printing -232.5155807365439
partial_w printing [ 68.88985836   4.33371105  66.25223796  51.55082153  -5.9917847
 -62.26407932 -64.05342776 -36.18589235  16.40827195  99.56589235]
partial_d printing -232.47592067988668
partial_w printing [ 68.94526912   4.34186969  66.32373938  51.59909348  -6.01161473
 -62.29501416 -64.0890085  -36.16186969  16.45915014  99.64124646]
partial_d printing -232.4362606232295
partial_w printing [ 69.00067989   4.35002833  6

partial_w printing [ 67.15025496  -1.55694051  61.81354108  43.59716714 -24.75456091
 -81.45263456 -58.04600567 -51.13909348   4.48073654  68.08345609]
partial_d printing -224.7932011331445
partial_w printing [ 67.20566572  -1.54878187  61.88504249  43.64543909 -24.77439093
 -81.48356941 -58.0815864  -51.11507082   4.53161473  68.1588102 ]
partial_d printing -224.75354107648727
partial_w printing [ 67.26107649  -1.54062323  61.95654391  43.69371105 -24.79422096
 -81.51450425 -58.11716714 -51.09104816   4.58249292  68.23416431]
partial_d printing -224.71388101983004
partial_w printing [ 67.31648725  -1.53246459  62.02804533  43.741983   -24.81405099
 -81.54543909 -58.15274788 -51.0670255    4.6333711   68.30951841]
partial_d printing -224.67422096317281
partial_w printing [ 67.37189802  -1.52430595  62.09954674  43.79025496 -24.83388102
 -81.57637394 -58.18832861 -51.04300283   4.68424929  68.38487252]
partial_d printing -224.6345609065156
partial_w printing [ 67.42730878  -1.51614731  

In [237]:
mean_squared_error(y_test, preds)

16814.644920224557