Skip to content

Commit

Permalink
Save parameters after epochs.
Browse files Browse the repository at this point in the history
  • Loading branch information
Ragav Venkatesan committed Mar 13, 2017
1 parent 6caadfc commit df81a41
Show file tree
Hide file tree
Showing 4 changed files with 36 additions and 6 deletions.
3 changes: 2 additions & 1 deletion yann/modules/resultor.py
Original file line number Diff line number Diff line change
Expand Up @@ -104,7 +104,8 @@ def __init__( self, resultor_init_args, verbose = 1):
def process_results( self,
cost,
lr,
mom,
mom,
params = None,
verbose = 2 ):
"""
This method will print results and also write them down in the appropriate files.
Expand Down
24 changes: 23 additions & 1 deletion yann/network.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
if nx_installed is True:
import networkx as nx

import time
import time, os
from collections import OrderedDict

import numpy
Expand Down Expand Up @@ -2831,6 +2831,7 @@ def train(self, verbose = 2, **kwargs):
Args:
epochs: ``(num_epochs for each learning rate... )`` to train Default is ``(20, 20)``
validate_after_epochs: 1, after how many epochs do you want to validate ?
save_after_epochs: 1, Save network after that many epochs of training.
show_progress: default is ``True``, will display a clean progressbar.
If ``verbose`` is ``3`` or more - False
early_terminate: ``True`` will allow early termination.
Expand Down Expand Up @@ -2861,6 +2862,11 @@ def train(self, verbose = 2, **kwargs):
else:
self.visualize_after_epochs = kwargs['visualize_after_epochs']

if not 'save_after_epochs' in kwargs.keys():
self.save_after_epochs = self.validate_after_epochs
else:
self.save_after_epochs = kwargs['save_after_epochs']

if not 'show_progress' in kwargs.keys():
show_progress = True
else:
Expand Down Expand Up @@ -2995,6 +3001,7 @@ def train(self, verbose = 2, **kwargs):
verbose = verbose )
self.visualize ( epoch = epoch_counter , verbose = verbose )
self.print_status ( epoch = epoch_counter, verbose=verbose )
self.save_params ( epoch = epoch_counter, verbose = verbose )

if best is True:
copy_params(source = self.active_params, destination= nan_insurance ,
Expand Down Expand Up @@ -3127,5 +3134,20 @@ def get_params (self, verbose = 2):
params[lyr] = params_list
return params

def save_params (self, epoch = 0, verbose = 2):
"""
This method will save down a list of network parameters
Args:
verbose: As usual
epoch: epoch.
"""

from yann.utils.pickle import pickle
if not os.path.exists (self.cooked_resultor.root + '/params'):
os.makedirs (self.cooked_resultor.root + '/params')

filename = self.cooked_resultor.root + '/params/epoch_' + str(epoch) + '.pkl'
pickle(net = self, filename = filename, verbose=verbose)
if __name__ == '__main__':
pass
9 changes: 8 additions & 1 deletion yann/special/gan.py
Original file line number Diff line number Diff line change
Expand Up @@ -530,6 +530,7 @@ def train ( self, verbose, **kwargs):
k : how many discriminator updates for every generator update.
learning_rates: (annealing_rate, learning_rates ... ) length must be one more than
``epochs`` Default is ``(0.05, 0.01, 0.001)``
save_after_epochs: 1, Save network after that many epochs of training.
pre_train_discriminator: If you want to pre-train the discriminator to make it stay
ahead of the generator for making predictions. This will only
train the softmax layer loss and not the fake or real loss.
Expand Down Expand Up @@ -564,6 +565,11 @@ def train ( self, verbose, **kwargs):
else:
self.visualize_after_epochs = kwargs['visualize_after_epochs']

if not 'save_after_epochs' in kwargs.keys():
self.save_after_epochs = self.validate_after_epochs
else:
self.save_after_epochs = kwargs['save_after_epochs']

if not 'show_progress' in kwargs.keys():
show_progress = True
else:
Expand Down Expand Up @@ -606,7 +612,7 @@ def train ( self, verbose, **kwargs):

if self.softmax_head is True:
self.softmax_learning_rate.set_value(learning_rates[1])

patience_increase = 2
improvement_threshold = 0.995
best_iteration = 0
Expand Down Expand Up @@ -849,6 +855,7 @@ def train ( self, verbose, **kwargs):
training_accuracy = training_accuracy,
show_progress = show_progress,
verbose = verbose )
self.save_params ( epoch = epoch_counter, verbose = verbose )
self.visualize ( epoch = epoch_counter , verbose = verbose)

if best is True:
Expand Down
6 changes: 3 additions & 3 deletions yann/utils/pickle.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,13 +2,13 @@
import theano
from collections import OrderedDict

def pickle(net, name, verbose = 2):
def pickle(net, filename, verbose = 2):
"""
This method saves the weights of all the layers.
Args:
net: A yann network object
name: What is the name of the file to pickle the network as.
filename: What is the name of the file to pickle the network as.
verbose: Blah..
"""
if verbose >= 3:
Expand All @@ -18,7 +18,7 @@ def pickle(net, name, verbose = 2):
if verbose >= 3:
print "... Dumping netowrk parameters"

f = open(name, 'wb')
f = open(filename, 'wb')
cPickle.dump(params, f, protocol = cPickle.HIGHEST_PROTOCOL)
f.close()

Expand Down

0 comments on commit df81a41

Please sign in to comment.