Merge pull request #25 from ragavvenkatesan/dev
Adding a pickle feature. This will enable saving the network down and…
Ragav Venkatesan committed Feb 23, 2017
2 parents cb8f868 + 089a1d2 commit 1c1bb91
Showing 7 changed files with 101 additions and 22 deletions.
1 change: 1 addition & 0 deletions docs/source/yann/utils/index.rst
@@ -13,4 +13,5 @@ their documentations.
 
    dataset
    graph
+   pickle
    raster
14 changes: 14 additions & 0 deletions docs/source/yann/utils/pickle.rst
@@ -0,0 +1,14 @@
+.. _pickle:
+
+:mod:`pickle` - provides a way to save the network's parameters as a pickle file.
+==================================================================================
+
+The file ``yann/utils/pickle.py`` contains the definitions of the pickle methods. Use the
+``pickle`` method in this file to save the parameters down as a pickle file. Note that this
+saves only the parameters, not the architecture, optimizers, or other modules. The ids of the
+layers are saved along as dictionary keys, so they can be used to re-create a network.
+
+The documentation follows:
+
+.. automodule:: yann.utils.pickle
+    :members:
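
For instance, a minimal usage sketch (assuming ``net`` is an already-cooked yann network
object and ``params.pkl`` is an arbitrary file name):

    from yann.utils.pickle import pickle

    # Saves only the parameters of every layer, keyed by layer id.
    pickle(net, 'params.pkl', verbose = 2)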
16 changes: 7 additions & 9 deletions pantry/tutorials/lenet.py
@@ -133,7 +133,7 @@ def lenet5 ( dataset= None, verbose = 1 ):
                     verbose = verbose
                     )
 
-    net.train( epochs = (40, 40 ),
+    net.train( epochs = (40, 40),
                validate_after_epochs = 1,
                training_accuracy = True,
                learning_rates = learning_rates,
@@ -207,7 +207,7 @@ def lenet_maxout ( dataset= None, verbose = 1 ):
                     filter_size = (5,5),
                     pool_size = (2,2),
                     activation = ('maxout', 'maxout', 2),
-                    batch_norm = True,
+                    # batch_norm = True,
                     regularize = True,
                     verbose = verbose
                     )
@@ -219,7 +219,7 @@ def lenet_maxout ( dataset= None, verbose = 1 ):
                     filter_size = (3,3),
                     pool_size = (2,2),
                     activation = ('maxout', 'maxout', 2),
-                    batch_norm = True,
+                    # batch_norm = True,
                     regularize = True,
                     verbose = verbose
                     )
@@ -274,7 +274,7 @@ def lenet_maxout ( dataset= None, verbose = 1 ):
                     )
     #draw_network(net.graph, filename = 'lenet.png')
     net.pretty_print()
-
+
     net.train( epochs = (40, 40),
                validate_after_epochs = 1,
                visualize_after_epochs = 1,
@@ -285,7 +285,7 @@ def lenet_maxout ( dataset= None, verbose = 1 ):
               verbose = verbose)
 
     net.test(verbose = verbose)
-
+
 ## Boiler Plate ##
 if __name__ == '__main__':
     import sys
@@ -308,7 +308,5 @@ def lenet_maxout ( dataset= None, verbose = 1 ):
         data = cook_mnist()
         dataset = data.dataset_location()
 
-    # lenet5 ( dataset, verbose = 3 )
-    lenet_maxout (dataset, verbose = 2)
-
-
+    lenet5 ( dataset, verbose = 2 )
+    # lenet_maxout (dataset, verbose = 3)
8 changes: 4 additions & 4 deletions yann/modules/optimizer.py
@@ -133,7 +133,7 @@ def calculate_gradients(self, params, objective, verbose = 1):
         self.gradients = []
         for param in params:
             if verbose >=3 :
-                print ".. Estimating gradient of parameter ",
+                print "... Estimating gradient of parameter ",
                 print param
             try:
                 gradient = T.grad( objective ,param)
@@ -165,7 +165,7 @@ def create_updates(self, params, verbose = 1):
         velocities = []
         for param in params:
             if verbose >=3 :
-                print ".. Estimating velocity of parameter ",
+                print "... Estimating velocity of parameter ",
                 print param
             velocity = theano.shared(numpy.zeros(param.get_value(borrow=True).shape,
                                      dtype=theano.config.floatX))
@@ -176,7 +176,7 @@ def create_updates(self, params, verbose = 1):
         accumulator_2 = []
         for param in params:
             if verbose >=3 :
-                print ".. Accumulating gradinent of parameter " ,
+                print "... Accumulating gradient of parameter " ,
                 print param
             eps = numpy.zeros_like(param.get_value(borrow=True), dtype=theano.config.floatX)
             accumulator_1.append(theano.shared(eps, borrow=True))
@@ -204,7 +204,7 @@ def create_updates(self, params, verbose = 1):
         for velocity, gradient, acc_1 , acc_2, param in zip(velocities, self.gradients,
                                                     accumulator_1, accumulator_2, params):
             if verbose >=3 :
-                print ".. Backprop of parameter ",
+                print "... Backprop of parameter ",
                 print param
 
             if self.optimizer_type == 'adagrad':
8 changes: 1 addition & 7 deletions yann/modules/resultor.py
@@ -18,7 +18,6 @@ class resultor(module):
                 "errors"   : "<error_file_name>.txt",
                 "costs"    : "<cost_file_name>.txt",
                 "confusion" : "<confusion_file_name>.txt",
-                "network"  : "<network_save_file_name>.pkl"
                 "learning_rate" : "<learning_rate_file_name>.txt"
                 "momentum" : <momentum_file_name>.txt
                 "visualize" : <bool>
@@ -57,9 +56,6 @@ def __init__( self, resultor_init_args, verbose = 1):
         if not "confusion" in resultor_init_args.keys():
             resultor_init_args["confusion"] = "confusion.txt"
 
-        if not "network" in resultor_init_args.keys():
-            resultor_init_args["network"] = "network.pkl"
-
         if not "learning_rate" in resultor_init_args.keys():
             resultor_init_args["learning_rate"] = "learning_rate.txt"
 
@@ -80,8 +76,6 @@ def __init__( self, resultor_init_args, verbose = 1):
                 self.cost_file = value
             elif item == "confusion":
                 self.confusion_file = value
-            elif item == "network":
-                self.network_file = value
             elif item == "learning_rate":
                 self.learning_rate = value
             elif item == "momentum":
@@ -146,4 +140,4 @@ def update_plot (self, verbose = 2):
         """
         This method should update the open plots with costs and other values.
         """
-        print "TBD"
\ No newline at end of file
+        print "TBD"
28 changes: 26 additions & 2 deletions yann/network.py
@@ -455,7 +455,6 @@ def _add_resultor(self, resultor_params = None, verbose = 2):
                 "errors"   : "errors.txt",
                 "costs"    : "costs.txt",
                 "confusion" : "confusion.txt",
-                "network"  : "network.pkl",
                 "learning_rate" : "learning_rate.txt",
                 "momentum" : "momentum.txt",
                 "visualize" : False,
@@ -2600,5 +2599,30 @@ def test(self, show_progress = True, verbose = 2):
         if verbose >= 2:
             print(".. Mean testing error : " + str(testing_accuracy))
 
+    def get_params (self, verbose = 2):
+        """
+        This method returns a dictionary of layer weights and biases in numpy format.
+
+        Args:
+            verbose: As always.
+
+        Returns:
+            OrderedDict: A dictionary of parameters, keyed by layer id.
+        """
+        if verbose >= 3:
+            print "... Collecting network parameters"
+        params = OrderedDict()
+        for lyr in self.dropout_layers.keys():
+            params_list = list()
+            if self.dropout_layers[lyr].params is not None:
+                for p in self.dropout_layers[lyr].params:
+                    if verbose >= 3:
+                        print "... Collecting parameters of layer " + lyr
+                    params_list.append(p.get_value(borrow = True))
+            params[lyr] = params_list
+
+        return params
+
 if __name__ == '__main__':
-    pass
\ No newline at end of file
+    pass
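
A quick sketch of inspecting what get_params returns (assuming net is a cooked yann network;
keys are layer ids, values are lists of numpy arrays):

    params = net.get_params(verbose = 2)
    for lyr, plist in params.iteritems():
        # e.g. a conv layer lists the shapes of its filter bank and bias.
        print lyr, [p.shape for p in plist]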
48 changes: 48 additions & 0 deletions yann/utils/pickle.py
@@ -0,0 +1,48 @@
+import cPickle
+import theano
+from collections import OrderedDict
+
+def pickle(net, name, verbose = 2):
+    """
+    This method saves the weights of all the layers of the network as a pickle file.
+
+    Args:
+        net: A yann network object.
+        name: Name of the file to pickle the network parameters into.
+        verbose: As always.
+    """
+    if verbose >= 3:
+        print "... Collecting parameters"
+
+    params = net.get_params(verbose = verbose)
+    if verbose >= 3:
+        print "... Dumping network parameters"
+
+    f = open(name, 'wb')
+    cPickle.dump(params, f, protocol = cPickle.HIGHEST_PROTOCOL)
+    f.close()
+
+def load(infile, verbose = 2):
+    """
+    This method loads a pickled network and returns its parameters.
+
+    Args:
+        infile: Filename of a network pickled by the pickle method above.
+        verbose: As always.
+
+    Returns:
+        OrderedDict: A dictionary of parameters as theano shared variables, keyed by layer id.
+    """
+    if verbose >= 2:
+        print ".. Loading the network."
+
+    params = OrderedDict()
+    if verbose >= 3:
+        print "... Loading network parameters"
+    params_np = cPickle.load( open( infile, "rb" ) )
+
+    for lyr in params_np:
+        shared_list = list()
+        for p in params_np[lyr]:
+            ps = theano.shared(value = p, borrow = True)
+            shared_list.append(ps)
+        params[lyr] = shared_list
+    return params
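
A round-trip sketch under the same assumptions (net is a cooked yann network; the file name
is arbitrary):

    from yann.utils.pickle import pickle, load

    pickle(net, 'network.pkl', verbose = 2)      # dump parameters as numpy arrays
    params = load('network.pkl', verbose = 2)    # reload them as theano shared variables
    print params.keys()                          # layer ids, usable to rebuild a network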
