Skip to content

Commit

Permalink
adding a shared_params method.
Browse files Browse the repository at this point in the history
  • Loading branch information
Ragav Venkatesan committed Feb 28, 2017
1 parent 9c852e4 commit 686be09
Show file tree
Hide file tree
Showing 2 changed files with 18 additions and 8 deletions.
6 changes: 2 additions & 4 deletions yann/network.py
Original file line number Diff line number Diff line change
Expand Up @@ -2216,7 +2216,7 @@ def cook(self, verbose = 2, **kwargs):
if verbose >= 3:
print("... datastream not provided, assuming " + self.last_datastream_created)
datastream = self.last_datastream_created
else:
else:
if not datastream in self.datastream.keys():
raise Exception ("Datastream " + datastream + " not found.")
self.cooked_datastream = self.datastream[datastream]
Expand All @@ -2235,7 +2235,6 @@ def cook(self, verbose = 2, **kwargs):
self.cost = []

self._cook_datastream(verbose = verbose)

self._cook_optimizer(params = params,
optimizer = self.cooked_optimizer,
objective = self.dropout_cost,
Expand Down Expand Up @@ -2755,8 +2754,7 @@ def get_params (self, verbose = 2):
print "... Collecting parameters of layer " + lyr
params_list.append(p.get_value(borrow = True))
params[lyr] = params_list

return params
return params

if __name__ == '__main__':
pass
20 changes: 16 additions & 4 deletions yann/utils/pickle.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,11 +38,23 @@ def load(infile, verbose = 2):
if verbose >= 3:
print "... Loading netowrk parameters"
params_np = cPickle.load( open( infile, "rb" ) )
return shared_params (params_np)

for lyr in params_np:
def shared_params (params, verbose = 2):
"""
This will convert a loaded set of parameters to shared variables that could be
passed as ``input_params`` to the ``add_layer`` method.
Args:
params: List from ``get_params`` method.
"""
if verbose >= 3:
print "... Convering parameters to shared parameters"

for lyr in params:
shared_list = list()
for p in params_np[lyr]:
ps = theano.shared(value = p, borrow=True )
for p in params[lyr]:
ps = theano.shared( value = p, borrow=True )
shared_list.append(ps)
params [lyr] = shared_list
params [lyr] = shared_list
return params

0 comments on commit 686be09

Please sign in to comment.