-
Notifications
You must be signed in to change notification settings - Fork 27
/
Copy pathlog_reg.py
73 lines (57 loc) · 2.07 KB
/
log_reg.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
"""
Notes:
This code contains one method that explains how to build a
logistic regression classifier for the MNIST dataset using
the yann toolbox.
For a more interactive tutorial, refer to the notebook at
yann/pantry/tutorials/notebooks/Logistic Regression.ipynb
"""
from yann.network import network
from yann.utils.graph import draw_network
def log_reg(dataset):
    """
    Build, train and test a logistic regression classifier with yann.

    Args:
        dataset: Value passed to the datastream module as its ``"dataset"``
            parameter. The script's ``__main__`` section supplies either a
            command-line argument or ``cook_mnist(...).dataset_location()``.
    """
    # The same class count is used by the datastream and the classifier.
    class_count = 10
    stream_settings = dict(dataset=dataset, svm=False, n_classes=class_count)

    # Start from an empty network and attach the datastream module to it.
    model = network()
    model.add_module(type='datastream', params=stream_settings)

    # Input layer, fed by the datastream origin named 'data'.
    model.add_layer(type="input", datastream_origin='data')

    # Classifier layer: this is the logistic regression (softmax) layer.
    model.add_layer(type="classifier", num_classes=class_count)

    # Objective layer, left on its defaults (negative log likelihood,
    # according to the original notes). The objective is always minimized.
    model.add_layer(type="objective")

    # Cook the network, show its structure, then train and test it.
    model.cook()
    model.pretty_print()
    model.train()
    model.test()
## Boiler Plate ##
if __name__ == '__main__':
    # Usage:
    #   log_reg.py <dataset>        run on an existing dataset
    #   log_reg.py create_dataset   cook a fresh MNIST dataset, then run
    #   log_reg.py                  same as above, after printing a notice
    import sys

    cli_args = sys.argv[1:]
    if cli_args and cli_args[0] != 'create_dataset':
        # Anything other than the 'create_dataset' keyword is taken to be
        # the dataset itself.
        dataset = cli_args[0]
    else:
        if not cli_args:
            # Single-argument print(...) emits the same text on Python 2
            # and Python 3, unlike the Python 2-only print statement.
            print("provide dataset")
            print(" creating a new dataset to run through")
        # Imported lazily so the dataset tools are only loaded when a new
        # dataset actually has to be cooked.
        from yann.special.datasets import cook_mnist
        data = cook_mnist(verbose=3)
        dataset = data.dataset_location()

    log_reg(dataset)