In [16]:
# -*- endoding: utf-8 -*-
import numpy as np
import matplotlib as mpl
import matplotlib.pyplot as plt

def mse(pred_y, y):
    return np.linalg.norm(pred_y-y)

In [22]:
class LogisticRegressor:
    def __init__(self, learning_rate=1e-3, n_epoch=2000):
        self.learning_rate = learning_rate
        self.n_epoch = n_epoch

    def _sigma(self, t):
        return 1/(1+np.exp(-t))

    def _cost(self, X, y):
        hat_p = self._sigma(X.dot(self.theta))
        sum = 0
        for pred_y, real_y in zip(hat_p, y):
            sum += real_y*np.log(pred_y) + (1-real_y)*np.log(1-pred_y)
        return -sum / X.shape[0]

    def fit(self, X, y):
        # theta: randomly initialized: [-1/sqrt(n), 1/sqrt(n)]
        limit = np.sqrt(X.shape[1])
        self.theta = np.random.uniform(-1/limit, 1/limit, (X.shape[1], 1))

        for epoch in range(self.n_epoch):
            print('epoch', epoch, ':', X.T.dot(self._sigma(X.dot(self.theta))-y))
            print('cost:', self._cost(X, y))
            self.theta -= self.learning_rate * 1/X.shape[0] * X.T.dot(self._sigma(X.dot(self.theta))-y)

    def predict(self, X):
        return np.around(self._sigma(X.dot(self.theta))).astype(int)

    def predict_proba(self, X):
        return [[x[0], y[0]] for x, y in \
                zip(self._sigma(X.dot(self.theta)), 1-self._sigma(X.dot(self.theta)))]

In [23]:
def test_logistic_regression():
    from sklearn import datasets
    iris = datasets.load_iris()
    X = iris['data'][:, 3:]
    y = (iris['target']==2).astype(np.int).reshape(-1, 1)

    print(X.shape, y.shape)

    lr = LogisticRegressor(learning_rate=0.001, n_epoch=1000)
    lr.fit(X, y)
    print(lr.predict_proba(X[:10]))
    print(mse(lr.predict(X), y))

    # from sklearn.linear_model import LogisticRegression

    # slr = LogisticRegression()
    # slr.fit(X, y)
    # print(slr.predict_proba(X[:10]))
    # print(mse(slr.predict(X), y))
test_logistic_regression()

(150, 1) (150, 1)
epoch 0 : [[-4.00646318]]
cost: [0.68816281]
epoch 1 : [[-3.80637712]]
cost: [0.68809325]
epoch 2 : [[-3.61637384]]
cost: [0.68803047]
epoch 3 : [[-3.43593852]]
cost: [0.68797379]
epoch 4 : [[-3.26458296]]
cost: [0.68792263]
epoch 5 : [[-3.10184426]]
cost: [0.68787645]
epoch 6 : [[-2.94728349]]
cost: [0.68783475]
epoch 7 : [[-2.80048444]]
cost: [0.68779711]
epoch 8 : [[-2.66105243]]
cost: [0.68776312]
epoch 9 : [[-2.5286132]]
cost: [0.68773243]
epoch 10 : [[-2.40281183]]
cost: [0.68770472]
epoch 11 : [[-2.28331171]]
cost: [0.6876797]
epoch 12 : [[-2.16979362]]
cost: [0.6876571]
epoch 13 : [[-2.06195477]]
cost: [0.6876367]
epoch 14 : [[-1.95950797]]
cost: [0.68761827]
epoch 15 : [[-1.86218078]]
cost: [0.68760163]
epoch 16 : [[-1.76971477]]
cost: [0.6875866]
epoch 17 : [[-1.68186475]]
cost: [0.68757302]
epoch 18 : [[-1.59839812]]
cost: [0.68756076]
epoch 19 : [[-1.51909416]]
cost: [0.68754969]
epoch 20 : [[-1.44374344]]
cost: [0.68753969]
epoch 21 : [[-1.37214722]]
cost

epoch 193 : [[-0.00022637]]
cost: [0.68744615]
epoch 194 : [[-0.0002152]]
cost: [0.68744615]
epoch 195 : [[-0.00020458]]
cost: [0.68744615]
epoch 196 : [[-0.00019448]]
cost: [0.68744615]
epoch 197 : [[-0.00018488]]
cost: [0.68744615]
epoch 198 : [[-0.00017576]]
cost: [0.68744615]
epoch 199 : [[-0.00016709]]
cost: [0.68744615]
epoch 200 : [[-0.00015884]]
cost: [0.68744615]
epoch 201 : [[-0.000151]]
cost: [0.68744615]
epoch 202 : [[-0.00014355]]
cost: [0.68744615]
epoch 203 : [[-0.00013647]]
cost: [0.68744615]
epoch 204 : [[-0.00012973]]
cost: [0.68744615]
epoch 205 : [[-0.00012333]]
cost: [0.68744615]
epoch 206 : [[-0.00011724]]
cost: [0.68744615]
epoch 207 : [[-0.00011146]]
cost: [0.68744615]
epoch 208 : [[-0.00010596]]
cost: [0.68744615]
epoch 209 : [[-0.00010073]]
cost: [0.68744615]
epoch 210 : [[-9.57570089e-05]]
cost: [0.68744615]
epoch 211 : [[-9.1031461e-05]]
cost: [0.68744615]
epoch 212 : [[-8.65391159e-05]]
cost: [0.68744615]
epoch 213 : [[-8.22684653e-05]]
cost: [0.68744615]
e

cost: [0.68744615]
epoch 363 : [[-4.15313011e-08]]
cost: [0.68744615]
epoch 364 : [[-3.94817534e-08]]
cost: [0.68744615]
epoch 365 : [[-3.75333529e-08]]
cost: [0.68744615]
epoch 366 : [[-3.56811194e-08]]
cost: [0.68744615]
epoch 367 : [[-3.39202657e-08]]
cost: [0.68744615]
epoch 368 : [[-3.22463263e-08]]
cost: [0.68744615]
epoch 369 : [[-3.0655001e-08]]
cost: [0.68744615]
epoch 370 : [[-2.91421887e-08]]
cost: [0.68744615]
epoch 371 : [[-2.77040291e-08]]
cost: [0.68744615]
epoch 372 : [[-2.63368636e-08]]
cost: [0.68744615]
epoch 373 : [[-2.50371617e-08]]
cost: [0.68744615]
epoch 374 : [[-2.38015887e-08]]
cost: [0.68744615]
epoch 375 : [[-2.26269994e-08]]
cost: [0.68744615]
epoch 376 : [[-2.1510365e-08]]
cost: [0.68744615]
epoch 377 : [[-2.04488506e-08]]
cost: [0.68744615]
epoch 378 : [[-1.94397072e-08]]
cost: [0.68744615]
epoch 379 : [[-1.84803788e-08]]
cost: [0.68744615]
epoch 380 : [[-1.75683725e-08]]
cost: [0.68744615]
epoch 381 : [[-1.67013798e-08]]
cost: [0.68744615]
epoch 382 : [[

cost: [0.68744615]
epoch 530 : [[-8.87023788e-12]]
cost: [0.68744615]
epoch 531 : [[-8.42637071e-12]]
cost: [0.68744615]
epoch 532 : [[-8.01603228e-12]]
cost: [0.68744615]
epoch 533 : [[-7.61835039e-12]]
cost: [0.68744615]
epoch 534 : [[-7.23976434e-12]]
cost: [0.68744615]
epoch 535 : [[-6.88138435e-12]]
cost: [0.68744615]
epoch 536 : [[-6.54321042e-12]]
cost: [0.68744615]
epoch 537 : [[-6.21991347e-12]]
cost: [0.68744615]
epoch 538 : [[-5.91149352e-12]]
cost: [0.68744615]
epoch 539 : [[-5.61639624e-12]]
cost: [0.68744615]
epoch 540 : [[-5.34061684e-12]]
cost: [0.68744615]
epoch 541 : [[-5.07438536e-12]]
cost: [0.68744615]
epoch 542 : [[-4.83790785e-12]]
cost: [0.68744615]
epoch 543 : [[-4.59632332e-12]]
cost: [0.68744615]
epoch 544 : [[-4.36894965e-12]]
cost: [0.68744615]
epoch 545 : [[-4.15800727e-12]]
cost: [0.68744615]
epoch 546 : [[-3.95772304e-12]]
cost: [0.68744615]
epoch 547 : [[-3.75299791e-12]]
cost: [0.68744615]
epoch 548 : [[-3.57491814e-12]]
cost: [0.68744615]
epoch 549 : 

cost: [0.68744615]
epoch 703 : [[-1.99840144e-14]]
cost: [0.68744615]
epoch 704 : [[-1.99840144e-14]]
cost: [0.68744615]
epoch 705 : [[-1.99840144e-14]]
cost: [0.68744615]
epoch 706 : [[-1.99840144e-14]]
cost: [0.68744615]
epoch 707 : [[-1.99840144e-14]]
cost: [0.68744615]
epoch 708 : [[-1.99840144e-14]]
cost: [0.68744615]
epoch 709 : [[-1.99840144e-14]]
cost: [0.68744615]
epoch 710 : [[-1.99840144e-14]]
cost: [0.68744615]
epoch 711 : [[-1.99840144e-14]]
cost: [0.68744615]
epoch 712 : [[-1.99840144e-14]]
cost: [0.68744615]
epoch 713 : [[-1.99840144e-14]]
cost: [0.68744615]
epoch 714 : [[-1.99840144e-14]]
cost: [0.68744615]
epoch 715 : [[-1.99840144e-14]]
cost: [0.68744615]
epoch 716 : [[-1.99840144e-14]]
cost: [0.68744615]
epoch 717 : [[-1.99840144e-14]]
cost: [0.68744615]
epoch 718 : [[-1.99840144e-14]]
cost: [0.68744615]
epoch 719 : [[-1.99840144e-14]]
cost: [0.68744615]
epoch 720 : [[-1.99840144e-14]]
cost: [0.68744615]
epoch 721 : [[-1.99840144e-14]]
cost: [0.68744615]
epoch 722 : 

cost: [0.68744615]
epoch 881 : [[-1.99840144e-14]]
cost: [0.68744615]
epoch 882 : [[-1.99840144e-14]]
cost: [0.68744615]
epoch 883 : [[-1.99840144e-14]]
cost: [0.68744615]
epoch 884 : [[-1.99840144e-14]]
cost: [0.68744615]
epoch 885 : [[-1.99840144e-14]]
cost: [0.68744615]
epoch 886 : [[-1.99840144e-14]]
cost: [0.68744615]
epoch 887 : [[-1.99840144e-14]]
cost: [0.68744615]
epoch 888 : [[-1.99840144e-14]]
cost: [0.68744615]
epoch 889 : [[-1.99840144e-14]]
cost: [0.68744615]
epoch 890 : [[-1.99840144e-14]]
cost: [0.68744615]
epoch 891 : [[-1.99840144e-14]]
cost: [0.68744615]
epoch 892 : [[-1.99840144e-14]]
cost: [0.68744615]
epoch 893 : [[-1.99840144e-14]]
cost: [0.68744615]
epoch 894 : [[-1.99840144e-14]]
cost: [0.68744615]
epoch 895 : [[-1.99840144e-14]]
cost: [0.68744615]
epoch 896 : [[-1.99840144e-14]]
cost: [0.68744615]
epoch 897 : [[-1.99840144e-14]]
cost: [0.68744615]
epoch 898 : [[-1.99840144e-14]]
cost: [0.68744615]
epoch 899 : [[-1.99840144e-14]]
cost: [0.68744615]
epoch 900 : 

In [3]:
from sklearn import datasets
iris = datasets.load_iris()
X = iris['data'][:, 3:]
y = (iris['target']==2).astype(np.int).reshape(-1, 1)

In [10]:
test_logistic_regression()

(150, 1) (150, 1)
epoch 0 : [[-70.48676233]]
epoch 1 : [[-48.13825723]]
epoch 2 : [[-26.7019363]]
epoch 3 : [[-13.44426538]]
epoch 4 : [[-6.67588723]]
epoch 5 : [[-3.33424401]]
epoch 6 : [[-1.67529773]]
epoch 7 : [[-0.84491397]]
epoch 8 : [[-0.42700134]]
epoch 9 : [[-0.21603186]]
epoch 10 : [[-0.10935782]]
epoch 11 : [[-0.05537406]]
epoch 12 : [[-0.02804311]]
epoch 13 : [[-0.01420294]]
epoch 14 : [[-0.0071936]]
epoch 15 : [[-0.00364353]]
epoch 16 : [[-0.00184546]]
epoch 17 : [[-0.00093473]]
epoch 18 : [[-0.00047345]]
epoch 19 : [[-0.0002398]]
epoch 20 : [[-0.00012146]]
epoch 21 : [[-6.15211415e-05]]
epoch 22 : [[-3.11608519e-05]]
epoch 23 : [[-1.57831722e-05]]
epoch 24 : [[-7.9942787e-06]]
epoch 25 : [[-4.0491539e-06]]
epoch 26 : [[-2.05092268e-06]]
epoch 27 : [[-1.03880563e-06]]
epoch 28 : [[-5.26161791e-07]]
epoch 29 : [[-2.66504352e-07]]
epoch 30 : [[-1.34986193e-07]]
epoch 31 : [[-6.83713755e-08]]
epoch 32 : [[-3.46305429e-08]]
epoch 33 : [[-1.75405994e-08]]
epoch 34 : [[-8.8844216

epoch 433 : [[-1.77635684e-15]]
epoch 434 : [[-1.77635684e-15]]
epoch 435 : [[-1.77635684e-15]]
epoch 436 : [[-1.77635684e-15]]
epoch 437 : [[-1.77635684e-15]]
epoch 438 : [[-1.77635684e-15]]
epoch 439 : [[-1.77635684e-15]]
epoch 440 : [[-1.77635684e-15]]
epoch 441 : [[-1.77635684e-15]]
epoch 442 : [[-1.77635684e-15]]
epoch 443 : [[-1.77635684e-15]]
epoch 444 : [[-1.77635684e-15]]
epoch 445 : [[-1.77635684e-15]]
epoch 446 : [[-1.77635684e-15]]
epoch 447 : [[-1.77635684e-15]]
epoch 448 : [[-1.77635684e-15]]
epoch 449 : [[-1.77635684e-15]]
epoch 450 : [[-1.77635684e-15]]
epoch 451 : [[-1.77635684e-15]]
epoch 452 : [[-1.77635684e-15]]
epoch 453 : [[-1.77635684e-15]]
epoch 454 : [[-1.77635684e-15]]
epoch 455 : [[-1.77635684e-15]]
epoch 456 : [[-1.77635684e-15]]
epoch 457 : [[-1.77635684e-15]]
epoch 458 : [[-1.77635684e-15]]
epoch 459 : [[-1.77635684e-15]]
epoch 460 : [[-1.77635684e-15]]
epoch 461 : [[-1.77635684e-15]]
epoch 462 : [[-1.77635684e-15]]
epoch 463 : [[-1.77635684e-15]]
epoch 46

epoch 728 : [[-1.77635684e-15]]
epoch 729 : [[-1.77635684e-15]]
epoch 730 : [[-1.77635684e-15]]
epoch 731 : [[-1.77635684e-15]]
epoch 732 : [[-1.77635684e-15]]
epoch 733 : [[-1.77635684e-15]]
epoch 734 : [[-1.77635684e-15]]
epoch 735 : [[-1.77635684e-15]]
epoch 736 : [[-1.77635684e-15]]
epoch 737 : [[-1.77635684e-15]]
epoch 738 : [[-1.77635684e-15]]
epoch 739 : [[-1.77635684e-15]]
epoch 740 : [[-1.77635684e-15]]
epoch 741 : [[-1.77635684e-15]]
epoch 742 : [[-1.77635684e-15]]
epoch 743 : [[-1.77635684e-15]]
epoch 744 : [[-1.77635684e-15]]
epoch 745 : [[-1.77635684e-15]]
epoch 746 : [[-1.77635684e-15]]
epoch 747 : [[-1.77635684e-15]]
epoch 748 : [[-1.77635684e-15]]
epoch 749 : [[-1.77635684e-15]]
epoch 750 : [[-1.77635684e-15]]
epoch 751 : [[-1.77635684e-15]]
epoch 752 : [[-1.77635684e-15]]
epoch 753 : [[-1.77635684e-15]]
epoch 754 : [[-1.77635684e-15]]
epoch 755 : [[-1.77635684e-15]]
epoch 756 : [[-1.77635684e-15]]
epoch 757 : [[-1.77635684e-15]]
epoch 758 : [[-1.77635684e-15]]
epoch 75