In [1]:
from simple_nn.nn import NeuralNetwork
from sklearn.preprocessing import LabelBinarizer
from sklearn.model_selection import train_test_split
from sklearn.metrics import classification_report
from sklearn import datasets

# load scikit-learn's "digits" dataset -- a small MNIST-like sample of
# 8x8 grayscale digit images flattened to 64 features each
print("[INFO] loading MNIST (sample) dataset...")
digits = datasets.load_digits()
data = digits.data.astype("float")
# min-max scale all features into [0, 1] using the global min/max so the
# network's inputs live in a small, uniform range
data = (data - data.min()) / (data.max() - data.min())
print("[INFO] samples: {}, dim: {}".format(data.shape[0], data.shape[1]))

# split the data: 75% training, 25% testing; keep the integer labels
(trainX, testX, trainLabels, testLabels) = train_test_split(
    data, digits.target, test_size=0.25)

# one-hot encode the labels. Fit the binarizer on the TRAINING labels only
# and reuse it for the test labels so both matrices share the same
# column -> class mapping. Fitting a separate binarizer on the test labels
# would silently drop/shift columns whenever a class is absent from the
# test split, mislabelling every row in the evaluation report.
lb = LabelBinarizer()
trainY = lb.fit_transform(trainLabels)
testY = lb.transform(testLabels)

# train the network: 64 inputs -> 32 -> 16 hidden units -> 10 outputs
print("[INFO] training network...")
nn = NeuralNetwork([trainX.shape[1], 32, 16, 10])
print("[INFO] {}".format(nn))
nn.fit(trainX, trainY, epochs=1000)

# evaluate the network
print("[INFO] evaluating network...")
predictions = nn.predict(testX)
# index of the highest-scoring output column for each test sample
predictions = predictions.argmax(axis=1)
# translate column indices back to class labels through the fitted
# binarizer, and compare against the untouched integer test labels
print(classification_report(testLabels, lb.classes_[predictions]))


[INFO] loading MNIST (sample) dataset...
[INFO] samples: 1797, dim: 64
[INFO] training network...
[INFO] NeuralNetwork: 64-32-16-10
[INFO] epoch=1, loss=605.0888471
[INFO] epoch=100, loss=6.7824465
[INFO] epoch=200, loss=2.5541157
[INFO] epoch=300, loss=1.9931933
[INFO] epoch=400, loss=1.3813137
[INFO] epoch=500, loss=0.9087222
[INFO] epoch=600, loss=0.7226877
[INFO] epoch=700, loss=0.2215744
[INFO] epoch=800, loss=0.1676359
[INFO] epoch=900, loss=0.1377962
[INFO] epoch=1000, loss=0.1175171
[INFO] evaluating network...
              precision    recall  f1-score   support

           0       1.00      1.00      1.00        39
           1       0.97      1.00      0.98        60
           2       0.96      0.98      0.97        47
           3       1.00      0.94      0.97        52
           4       0.97      1.00      0.98        31
           5       0.93      1.00      0.96        50
           6       1.00      0.98      0.99        47
           7       1.00      0.97      0.9