Permalink
Cannot retrieve contributors at this time
Name already in use
A tag already exists with the provided branch name. Many Git commands accept both tag and branch names, so creating this branch may cause unexpected behavior. Are you sure you want to create this branch?
yann/pantry/tutorials/log_reg.py
Go to file. This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
73 lines (57 sloc)
2.07 KB
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
""" | |
Notes:
    This code contains one method that explains how to build a
    logistic regression classifier for the MNIST dataset using
    the yann toolbox.

    For a more interactive tutorial, refer to the notebook at
    yann/pantry/tutorials/notebooks/Logistic Regression.ipynb
""" | |
from yann.network import network | |
from yann.utils.graph import draw_network | |
def log_reg(dataset, n_classes=10):
    """
    Build, train and test a logistic regression classifier with yann.

    The network consists of an input layer fed from the datastream, a single
    softmax classifier layer, and an objective layer.

    Args:
        dataset: location of a cooked yann dataset. It is passed unchanged to
            the datastream module as its ``"dataset"`` parameter.
        n_classes: number of output classes. The same value is used for the
            datastream (``"n_classes"``) and for the classifier layer
            (``num_classes``), so the two can never disagree. Defaults to 10,
            as used for MNIST.

    Returns:
        None. Output comes from yann's ``pretty_print``, ``train`` and
        ``test`` calls.
    """
    # Create the yann network class with empty layers.
    net = network()

    # Set up the datastream module and add it to the network.
    dataset_params = {"dataset": dataset,
                      "svm": False,
                      "n_classes": n_classes}
    net.add_module(type='datastream', params=dataset_params)

    # Create an input layer that feeds from the datastream module.
    # NOTE(review): assumes yann registers the datastream added above under
    # the id 'data' when no explicit id is given -- confirm.
    net.add_layer(type="input", datastream_origin='data')

    # Create a logistic regression layer, i.e. a softmax classifier.
    net.add_layer(type="classifier", num_classes=n_classes)

    # Create an objective layer. Default is negative log likelihood.
    # Whatever the objective is, it is always minimized.
    net.add_layer(type="objective")

    # Cook the network.
    net.cook()

    # Show what the network looks like.
    net.pretty_print()

    # Train the network.
    net.train()

    # Test for accuracy.
    net.test()
## Boiler Plate ## | |
if __name__ == '__main__': | |
dataset = None | |
import sys | |
if len(sys.argv) > 1: | |
if sys.argv[1] == 'create_dataset': | |
from yann.special.datasets import cook_mnist | |
data = cook_mnist (verbose = 3) | |
dataset = data.dataset_location() | |
else: | |
dataset = sys.argv[1] | |
else: | |
print "provide dataset" | |
if dataset is None: | |
print " creating a new dataset to run through" | |
from yann.special.datasets import cook_mnist | |
data = cook_mnist (verbose = 3) | |
dataset = data.dataset_location() | |
log_reg ( dataset ) |