-
-
Notifications
You must be signed in to change notification settings - Fork 5
/
Copy pathsklearn_ann.py
31 lines (24 loc) · 992 Bytes
/
sklearn_ann.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
# Train a neural network in just 3 lines of code!
#
# the notes for this class can be found at:
# https://deeplearningcourses.com/c/data-science-deep-learning-in-python
# https://www.udemy.com/data-science-deep-learning-in-python
from __future__ import print_function, division
from builtins import range
# Note: you may need to update your version of future
# sudo pip install -U future
import sys
sys.path.append('../ann_logistic_extra')
from process import get_data
from sklearn.neural_network import MLPClassifier
from sklearn.utils import shuffle
# Load the pre-split training and test sets from the shared `process` helper.
Xtrain, Ytrain, Xtest, Ytest = get_data()

# Build and train in a single expression: `fit` returns the estimator itself,
# so `model` ends up bound to the trained classifier.
# hidden_layer_sizes=(20, 20) -> two hidden layers with 20 units each
# max_iter=2000               -> upper bound on the optimizer's iterations
model = MLPClassifier(
    hidden_layer_sizes=(20, 20),
    max_iter=2000,
).fit(Xtrain, Ytrain)

# `score` reports mean classification accuracy on the given inputs/labels.
train_accuracy = model.score(Xtrain, Ytrain)
test_accuracy = model.score(Xtest, Ytest)

# Report both numbers on one line.
print("train accuracy:", train_accuracy, "test accuracy:", test_accuracy)