In [1]:
import numpy as np
import mlp
from sklearn import datasets

In [2]:
iris = datasets.load_iris()
Xraw = iris.data[:,:4]
Yraw = iris.target

### Pre-processing the raw data

#### Normalizing Features

In [3]:
X = (Xraw - Xraw.mean(axis=0))/Xraw.max(axis=0)

#### Encoding the Targets

In [4]:
Y = np.zeros((Yraw.shape[0],3))
Y[:,0] = np.where(Yraw == 0, 1, 0)
Y[:,1] = np.where(Yraw == 1, 1, 0)
Y[:,2] = np.where(Yraw == 2, 1, 0)

### Split data into Training, Testing, and Validation Sets

In [5]:
np.random.seed(1)

In [6]:
# split data into train, test, validation sets
ixs = np.arange(len(X))
np.random.shuffle(ixs)
X = X[ixs,:]
Y = Y[ixs,:]

In [7]:
Xtrain, Xtest, Xval = X[::2], X[1::4], X[3::4]

In [8]:
Ytrain, Ytest, Yval = Y[::2], Y[1::4], Y[3::4]

In [9]:
mlp1 = mlp.mlp(Xtrain, Ytrain, np.array([5]))

In [10]:
mlp1.arch

array([4, 5, 3])

In [11]:
mlp1.weights

[array([[ 0.11033664, -0.43306265,  0.38410034,  0.17074339,  0.44481908],
        [-0.29306756, -0.32455565,  0.38692514,  0.17603952, -0.38818125],
        [ 0.2284931 ,  0.22707377,  0.37836465,  0.1891935 , -0.33606227],
        [-0.42943226, -0.42376978, -0.4218955 , -0.22699572,  0.32201879],
        [ 0.03473156,  0.04724541,  0.30592173, -0.33614961, -0.19750412]]),
 array([[ 0.07002215,  0.38342332,  0.04983097],
        [-0.39302284,  0.24546555, -0.21802559],
        [ 0.25075034, -0.0915614 ,  0.29683068],
        [ 0.20177398,  0.04591996, -0.29683307],
        [-0.3593257 , -0.30917177, -0.37187183],
        [-0.3204797 , -0.22395739,  0.17390477]])]

In [12]:
mlp1.train(Xtrain, Ytrain, Xval, Yval,eta=0.15, momentum=0.9)

Stopped after 4 epochs


In [13]:
mlp1.weights

[array([[  0.14964866,  -2.12507499,  -1.4729652 ,  -0.51338629,
          -4.17434064],
        [  0.01732789,  -1.67421913,  -0.48517372,  -0.12114566,
          -1.00385634],
        [  0.43029763,   2.33516367,   1.91649559,   1.81740698,
          -3.92890038],
        [ -0.98207742,  -5.16015396,  -3.95801982,  -3.08441986,
           7.62953531],
        [ -1.27020443,  -4.65622627,  -3.47715703,  -3.63002513,  10.8205909 ]]),
 array([[-4.04023765,  4.35610226, -1.77304624],
        [-0.62914144,  0.9939414 , -1.9811564 ],
        [ 4.0558924 , -5.31668286, -2.98481658],
        [ 2.9049549 , -2.90146114, -3.78708957],
        [ 2.62043286, -1.52701238, -5.18575731],
        [-5.11085577, -8.72052363,  7.04814451]])]

In [14]:
mlp1.cfnmatrix(Xtest,Ytest)

Error: 0.052631578947368474


array([[ 10.,   0.,   0.],
       [  0.,  16.,   0.],
       [  0.,   2.,  10.]])

In [15]:
mlp1.updates

[array([[ -9.13578029e-04,  -7.90476297e-04,  -7.67508489e-04,
          -1.22449000e-03,   6.52123425e-03],
        [  1.07021098e-04,   1.59272216e-04,   1.15365106e-04,
           6.76368385e-05,  -7.82821374e-04],
        [ -3.89018071e-05,  -9.67166764e-05,  -9.48606133e-05,
          -1.23959154e-04,   4.84337436e-04],
        [ -9.11364537e-05,   4.22894332e-04,   2.77919961e-04,
          -1.50750003e-06,   1.13035170e-04],
        [ -1.74140417e-04,   4.85806180e-04,   2.99253094e-04,
          -7.88874958e-05,   6.42826051e-04]]),
 array([[ -8.56542435e-06,  -1.88338321e-04,   1.84020015e-03],
        [ -1.23576247e-04,   1.93612557e-04,   8.10100281e-04],
        [ -4.40839084e-04,   9.80505835e-04,   7.17285272e-05],
        [ -3.72299323e-04,   7.85006192e-04,   1.81727704e-04],
        [ -3.31763877e-04,   6.41069806e-04,   4.12799173e-04],
        [  2.37046156e-05,   3.33021173e-04,   2.95282101e-04]])]

In [16]:
# [array([[  7.02234844e-01,  -2.14514083e+00,  -1.35849254e+00,
#            1.52521419e+00,  -2.76542928e+00],
#         [ -1.26193034e-01,  -1.84743205e+00,  -3.87687652e-01,
#            3.33906378e-01,  -4.77249539e-01],
#         [  2.49553542e-03,   2.56618250e+00,   1.71350413e+00,
#            1.79154186e+00,  -2.57369824e+00],
#         [ -8.77007855e-01,  -5.65301350e+00,  -3.60447512e+00,
#           -3.11439119e+00,   5.74884236e+00],
#         [ -1.26312603e+00,  -5.09211177e+00,  -3.17574015e+00,
#           -4.43642304e+00,   7.95662642e+00]]),
#  array([[-3.51137687,  1.81132696,  0.42201971],
#         [-0.88364363,  1.1497938 , -1.37610687],
#         [ 4.41753029, -6.80977866, -2.5214    ],
#         [ 2.68486245, -2.58314549, -3.27461225],
#         [ 1.40679533,  2.4895717 , -5.05873974],
#         [-5.35150919, -7.14516933,  6.5562209 ]])]