In [1]:
import numpy as np
import pandas as pd

admissions = pd.read_csv('16_binary.csv')

# Make dummy variables for rank
data = pd.concat([admissions, pd.get_dummies(admissions['rank'], prefix='rank')], axis=1)
data = data.drop('rank', axis=1)

# Standarize features
for field in ['gre', 'gpa']:
    mean, std = data[field].mean(), data[field].std()
    data.loc[:,field] = (data[field]-mean)/std
    
# Split off random 10% of the data for testing
np.random.seed(21)
sample = np.random.choice(data.index, size=int(len(data)*0.9), replace=False)
data, test_data = data.ix[sample], data.drop(sample)

# Split into features and targets
features, targets = data.drop('admit', axis=1), data['admit']
features_test, targets_test = test_data.drop('admit', axis=1), test_data['admit']

In [9]:
import numpy as np
#from data_prep import features, targets, features_test, targets_test

np.random.seed(21)

def sigmoid(x):
    """
    Calculate sigmoid
    """
    return 1 / (1 + np.exp(-x))


# Hyperparameters
n_hidden = 2  # number of hidden units
epochs = 900
learnrate = 0.005

n_records, n_features = features.shape
last_loss = None
# Initialize weights
weights_input_hidden = np.random.normal(scale=1 / n_features ** .5, size=(n_features, n_hidden))
weights_hidden_output = np.random.normal(scale=1 / n_features ** .5, size=n_hidden)

for e in range(epochs):
    del_w_input_hidden = np.zeros(weights_input_hidden.shape)
    del_w_hidden_output = np.zeros(weights_hidden_output.shape)
    for x, y in zip(features.values, targets):
        ## Forward pass ##
        # TODO: Calculate the output
        hidden_input = np.dot(x, weights_input_hidden)
        hidden_output = sigmoid(hidden_input)
        
        output_layer_in = np.dot(hidden_output, weights_hidden_output)
        output = sigmoid(output_layer_in)

        ## Backward pass ##
        # TODO: Calculate the network's prediction error
        error = y - output

        # TODO: Calculate error term for the output unit
        output_error_term = error * output * (1 - output)
        
        print(output_error_term.shape)
        print(weights_hidden_output.T.shape)
        ## propagate errors to hidden layer
        # TODO: Calculate the hidden layer's contribution to the error
        hidden_error = np.dot(output_error_term, weights_hidden_output)
        print(output_error_term)
        print()
        print(weights_hidden_output)
        print()
        print(hidden_error)
        print()
        # TODO: Calculate the error term for the hidden layer
        hidden_error_term = hidden_error * hidden_output * (1 - hidden_output)
        
        # TODO: Update the change in weights
        del_w_hidden_output += output_error_term * hidden_output
        del_w_input_hidden += hidden_error_term * x[:, None]
        break
    # TODO: Update weights
    weights_input_hidden += learnrate * del_w_input_hidden / n_records
    weights_hidden_output += learnrate * del_w_hidden_output / n_records

    # Printing out the mean square error on the training set
    if e % (epochs / 10) == 0:
        hidden_output = sigmoid(np.dot(x, weights_input_hidden))
        out = sigmoid(np.dot(hidden_output,
                             weights_hidden_output))
        loss = np.mean((out - targets) ** 2)

        if last_loss and last_loss < loss:
            print("Train loss: ", loss, "  WARNING - Loss Increasing")
        else:
            print("Train loss: ", loss)
        last_loss = loss

# Calculate accuracy on test data
hidden = sigmoid(np.dot(features_test, weights_input_hidden))
out = sigmoid(np.dot(hidden, weights_hidden_output))
predictions = out > 0.5
accuracy = np.mean(predictions == targets_test)
print("Prediction accuracy: {:.3f}".format(accuracy))


()
(2,)
0.101909740805

[ 0.65768472 -0.28137626]

[ 0.06702448 -0.02867498]

Train loss:  0.28686764284154326
()
(2,)
0.101909685678

[ 0.65768559 -0.28137587]

[ 0.06702453 -0.02867493]

()
(2,)
0.101909630552

[ 0.65768646 -0.28137549]

[ 0.06702458 -0.02867487]

()
(2,)
0.101909575426

[ 0.65768734 -0.2813751 ]

[ 0.06702464 -0.02867482]

()
(2,)
0.101909520299

[ 0.65768821 -0.28137471]

[ 0.06702469 -0.02867476]

()
(2,)
0.101909465173

[ 0.65768908 -0.28137432]

[ 0.06702474 -0.02867471]

()
(2,)
0.101909410047

[ 0.65768995 -0.28137393]

[ 0.0670248  -0.02867465]

()
(2,)
0.10190935492

[ 0.65769083 -0.28137354]

[ 0.06702485 -0.0286746 ]

()
(2,)
0.101909299794

[ 0.6576917  -0.28137316]

[ 0.0670249  -0.02867454]

()
(2,)
0.101909244668

[ 0.65769257 -0.28137277]

[ 0.06702495 -0.02867449]

()
(2,)
0.101909189541

[ 0.65769344 -0.28137238]

[ 0.06702501 -0.02867443]

()
(2,)
0.101909134415

[ 0.65769432 -0.28137199]

[ 0.06702506 -0.02867438]

()
(2,)
0.101909079288

[ 0.6576

(2,)
0.101903897392

[ 0.65777718 -0.28133508]

[ 0.06703006 -0.02866914]

()
(2,)
0.101903842265

[ 0.65777805 -0.28133469]

[ 0.06703011 -0.02866909]

()
(2,)
0.101903787139

[ 0.65777893 -0.28133431]

[ 0.06703016 -0.02866903]

()
(2,)
0.101903732012

[ 0.6577798  -0.28133392]

[ 0.06703022 -0.02866898]

()
(2,)
0.101903676885

[ 0.65778067 -0.28133353]

[ 0.06703027 -0.02866892]

()
(2,)
0.101903621758

[ 0.65778154 -0.28133314]

[ 0.06703032 -0.02866887]

()
(2,)
0.101903566632

[ 0.65778241 -0.28133275]

[ 0.06703037 -0.02866881]

()
(2,)
0.101903511505

[ 0.65778329 -0.28133236]

[ 0.06703043 -0.02866876]

()
(2,)
0.101903456378

[ 0.65778416 -0.28133198]

[ 0.06703048 -0.0286687 ]

()
(2,)
0.101903401251

[ 0.65778503 -0.28133159]

[ 0.06703053 -0.02866865]

()
(2,)
0.101903346124

[ 0.6577859 -0.2813312]

[ 0.06703058 -0.02866859]

()
(2,)
0.101903290998

[ 0.65778678 -0.28133081]

[ 0.06703064 -0.02866854]

()
(2,)
0.101903235871

[ 0.65778765 -0.28133042]

[ 0.06703069 -0.02

[ 0.65799262 -0.28123914]

[ 0.06704305 -0.02865553]

()
(2,)
0.10189022584

[ 0.65799349 -0.28123875]

[ 0.06704311 -0.02865548]

()
(2,)
0.101890170712

[ 0.65799436 -0.28123836]

[ 0.06704316 -0.02865542]

()
(2,)
0.101890115585

[ 0.65799523 -0.28123797]

[ 0.06704321 -0.02865537]

()
(2,)
0.101890060457

[ 0.65799611 -0.28123758]

[ 0.06704326 -0.02865531]

()
(2,)
0.101890005329

[ 0.65799698 -0.28123719]

[ 0.06704332 -0.02865526]

()
(2,)
0.101889950202

[ 0.65799785 -0.28123681]

[ 0.06704337 -0.0286552 ]

()
(2,)
0.101889895074

[ 0.65799872 -0.28123642]

[ 0.06704342 -0.02865515]

()
(2,)
0.101889839946

[ 0.6579996  -0.28123603]

[ 0.06704347 -0.02865509]

()
(2,)
0.101889784818

[ 0.65800047 -0.28123564]

[ 0.06704353 -0.02865504]

()
(2,)
0.101889729691

[ 0.65800134 -0.28123525]

[ 0.06704358 -0.02865498]

()
(2,)
0.101889674563

[ 0.65800221 -0.28123486]

[ 0.06704363 -0.02865493]

()
(2,)
0.101889619435

[ 0.65800308 -0.28123447]

[ 0.06704368 -0.02865487]

()
(2,)
0.1

()
(2,)
0.101880027147

[ 0.65815484 -0.28116689]

[ 0.06705283 -0.02864529]

()
(2,)
0.101879972019

[ 0.65815571 -0.28116651]

[ 0.06705289 -0.02864524]

()
(2,)
0.10187991689

[ 0.65815658 -0.28116612]

[ 0.06705294 -0.02864518]

()
(2,)
0.101879861762

[ 0.65815745 -0.28116573]

[ 0.06705299 -0.02864513]

()
(2,)
0.101879806633

[ 0.65815833 -0.28116534]

[ 0.06705304 -0.02864507]

()
(2,)
0.101879751505

[ 0.6581592  -0.28116495]

[ 0.0670531  -0.02864502]

()
(2,)
0.101879696377

[ 0.65816007 -0.28116456]

[ 0.06705315 -0.02864496]

()
(2,)
0.101879641248

[ 0.65816094 -0.28116418]

[ 0.0670532  -0.02864491]

()
(2,)
0.10187958612

[ 0.65816181 -0.28116379]

[ 0.06705325 -0.02864485]

()
(2,)
0.101879530991

[ 0.65816269 -0.2811634 ]

[ 0.06705331 -0.0286448 ]

()
(2,)
0.101879475863

[ 0.65816356 -0.28116301]

[ 0.06705336 -0.02864474]

()
(2,)
0.101879420734

[ 0.65816443 -0.28116262]

[ 0.06705341 -0.02864469]

()
(2,)
0.101879365606

[ 0.6581653  -0.28116223]

[ 0.06705346 -0

()
(2,)
0.101865032071

[ 0.65839204 -0.28106127]

[ 0.06706713 -0.02863032]

()
(2,)
0.101864976941

[ 0.65839291 -0.28106088]

[ 0.06706718 -0.02863026]

()
(2,)
0.101864921812

[ 0.65839379 -0.28106049]

[ 0.06706723 -0.0286302 ]

()
(2,)
0.101864866682

[ 0.65839466 -0.2810601 ]

[ 0.06706728 -0.02863015]

()
(2,)
0.101864811553

[ 0.65839553 -0.28105971]

[ 0.06706734 -0.02863009]

()
(2,)
0.101864756423

[ 0.6583964  -0.28105933]

[ 0.06706739 -0.02863004]

()
(2,)
0.101864701294

[ 0.65839727 -0.28105894]

[ 0.06706744 -0.02862998]

()
(2,)
0.101864646164

[ 0.65839815 -0.28105855]

[ 0.06706749 -0.02862993]

()
(2,)
0.101864591035

[ 0.65839902 -0.28105816]

[ 0.06706755 -0.02862987]

()
(2,)
0.101864535905

[ 0.65839989 -0.28105777]

[ 0.0670676  -0.02862982]

()
(2,)
0.101864480776

[ 0.65840076 -0.28105738]

[ 0.06706765 -0.02862976]

()
(2,)
0.101864425646

[ 0.65840163 -0.281057  ]

[ 0.0670677  -0.02862971]

()
(2,)
0.101864370517

[ 0.65840251 -0.28105661]

[ 0.06706776 