In [50]:
# import necessary libraries
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
from mpl_toolkits.mplot3d import Axes3D

In [71]:
# normalize data
def normalize(data):
    """Mean-normalize each feature (column) of ``data``.

    Every column is shifted by its mean and divided by its range
    (max - min), so the result is centred on 0 with a spread of at most 1.

    Parameters
    ----------
    data : array-like
        Presumably shaped (n_samples, n_features) -- TODO confirm with callers.
        A 1-D input is treated as a single feature.

    Returns
    -------
    np.ndarray
        A new float array with the same shape as ``data``. The input is NOT
        modified, and integer input is accepted (it is converted to float).
        Constant columns (range 0) are centred to 0 rather than becoming NaN.
    """
    values = np.asarray(data, dtype=float)
    column_mean = values.mean(axis=0)
    column_range = values.max(axis=0) - values.min(axis=0)
    # A constant feature has range 0; dividing by 1 instead avoids 0/0 -> NaN
    # and leaves that (already centred) column at zero.
    column_range = np.where(column_range == 0, 1.0, column_range)
    return (values - column_mean) / column_range

# sigmoid activation function
def sigmoid(x):
    """Logistic sigmoid 1 / (1 + e^(-x)), applied element-wise to arrays.

    Maps any real input into the open interval (0, 1).
    """
    denominator = 1 + np.exp(-x)
    return 1 / denominator

# predict by applying the sigmoid to the linear combination x·w + b
def predict(x, weights, bias):
    """Return sigmoid(x · weights + bias): a score in (0, 1) per sample.

    ``x`` is combined with ``weights`` via ``np.dot``, then ``bias`` is added
    (broadcast) before the sigmoid squashes the result.
    """
    logits = np.dot(x, weights)
    logits = logits + bias
    return sigmoid(logits)

# use cross entropy instead of MSE: with a sigmoid output, MSE is non-convex (it can have many local minima), while cross entropy is convex in the weights
def cross_entropy(y_hat, y, eps=1e-12):
    """Mean binary cross-entropy between predicted probabilities and labels.

    Parameters
    ----------
    y_hat : array-like
        Predicted probabilities, broadcast-compatible with ``y``.
    y : np.ndarray
        Labels (0 or 1). ``y.shape[0]`` is taken as the number of samples n.
    eps : float, optional
        Predictions are clipped to ``[eps, 1 - eps]`` before taking logs, so a
        saturated sigmoid (exactly 0.0 or 1.0) yields a large but finite loss
        instead of ``-inf``/``nan``.

    Returns
    -------
    numpy float
        ``-(1/n) * sum(y * log(y_hat) + (1 - y) * log(1 - y_hat))``.

    Raises
    ------
    ZeroDivisionError
        If ``y`` has no samples (n == 0).
    """
    n = y.shape[0]

    # Clip so log() never sees exactly 0; 0 * log(0) would otherwise be nan.
    p = np.clip(y_hat, eps, 1 - eps)
    cost = y * np.log(p) + (1 - y) * np.log(1 - p)

    # Python-level 1 / n keeps the original ZeroDivisionError for empty input.
    return (-1 / n) * cost.sum()

# update weights and bias
def update_parameters(x, y, y_hat, weights, bias, learning_rate=0.05):
    """Apply one batch gradient-descent step for logistic regression.

    The gradients are the cross-entropy gradients for a sigmoid output:
    ``dw = x.T · (y_hat - y) / n`` and ``db = sum(y_hat - y) / n``.

    Parameters
    ----------
    x : np.ndarray
        Feature matrix; presumably (n_samples, n_features) -- TODO confirm.
    y, y_hat : np.ndarray
        Labels and current predictions; ``y.shape[0]`` is taken as n.
    weights : np.ndarray
        Updated IN PLACE (must be a float array for ``-=`` to succeed).
    bias : np.ndarray or float
        Updated in place when it is an array. A plain Python number cannot be
        mutated, so use the returned value in that case.
    learning_rate : float, optional
        Gradient-descent step size.

    Returns
    -------
    tuple
        ``(weights, bias)`` after the step. Array arguments are the same
        objects that were passed in.
    """
    n = y.shape[0]

    error = y_hat - y

    dw = np.dot(x.T, error) / n
    db = error.sum() / n

    weights -= dw * learning_rate
    bias -= db * learning_rate

    # Returning the values lets scalar-bias callers observe the update, which
    # in-place mutation alone silently dropped.
    return weights, bias

# fit
def train(x, y, iterations=200, learning_rate=0.05, log_every=10):
    """Fit logistic-regression weights and bias by batch gradient descent.

    Parameters
    ----------
    x : np.ndarray
        Feature matrix; ``x.shape[1]`` is the number of features.
    y : np.ndarray
        Labels, shaped like ``predict``'s output; (n_samples, 1) in the usage
        example below.
    iterations : int, optional
        Number of gradient-descent steps.
    learning_rate : float, optional
        Step size forwarded to ``update_parameters``. Before, it was fixed at
        that function's default of 0.05.
    log_every : int, optional
        Print the loss every ``log_every`` iterations; 0 or None disables it.

    Returns
    -------
    tuple
        ``(w, b)``: weights of shape (n_features, 1) and a (1, 1) bias.
    """
    # Initial weights come from the global NumPy RNG; call np.random.seed(...)
    # beforehand for reproducible runs.
    w = np.random.rand(x.shape[1], 1)
    b = np.ones((1, 1))

    for n in range(iterations):
        y_hat = predict(x, w, b)
        # update_parameters mutates w and b in place.
        update_parameters(x, y, y_hat, w, b, learning_rate=learning_rate)

        # The logged loss uses y_hat from BEFORE this step's update.
        if log_every and (n + 1) % log_every == 0:
            print("%ith iteration error: %s" % (n + 1, cross_entropy(y_hat, y)))

    return w, b

In [81]:
# Fit the AND truth table: only the input (1, 1) is labelled 1.
# train() draws its initial weights from the global NumPy RNG, so seed it
# for a reproducible run. (The separate w/b initialisation that used to be
# here was dead code: train() creates its own and its return overwrote them.)
SEED = 0
np.random.seed(SEED)

x = np.array([[0, 0], [1, 0], [0, 1], [1, 1]])
y = np.array([[0], [0], [0], [1]])

w, b = train(x, y, iterations=400)

y_hat = predict(x, w, b)
y_hat

10th iteration error: 1.005350216658127
20th iteration error: 0.8928319155534907
30th iteration error: 0.8096351078194921
40th iteration error: 0.7486534603178534
50th iteration error: 0.7036140870350333
60th iteration error: 0.6696642738282635
70th iteration error: 0.6433242044944267
80th iteration error: 0.6221997433419623
90th iteration error: 0.6046767816200218
100th iteration error: 0.5896753480190486
110th iteration error: 0.5764729403312535
120th iteration error: 0.5645841057836171
130th iteration error: 0.5536803312591946
140th iteration error: 0.5435372575590268
150th iteration error: 0.5339999577330564
160th iteration error: 0.5249600431578945
170th iteration error: 0.5163405023018077
180th iteration error: 0.5080856088107867
190th iteration error: 0.5001541689196971
200th iteration error: 0.49251498094241797
210th iteration error: 0.4851437684111809
220th iteration error: 0.47802110011638166
230th iteration error: 0.4711309741599912
240th iteration error: 0.46445985055276473