Commit ee2ca4c
Bug fixes for Resultor and Confusion Matrix.
Ragav Venkatesan committed Mar 6, 2017
1 parent dd5fdf0 commit ee2ca4c
Showing 2 changed files with 29 additions and 8 deletions.
yann/modules/resultor.py: 4 changes (2 additions, 2 deletions)
@@ -92,7 +92,7 @@ def __init__( self, resultor_init_args, verbose = 1):
                 print "... Creating a root directory for save files"
             os.makedirs(self.root)
 
-        for file in [self.results_file, self.error_file, self.cost_file, self.confusion_file,
+        for file in [self.results_file, self.error_file, self.cost_file,
                      self.learning_rate, self.momentum]:
             f = open(self.root + "/" + file, 'w')
             f.close()
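
The dropped self.confusion_file entry means confusion matrices are no longer logged to a single pre-created file; as the next hunk shows, they are written per epoch under a confusion directory instead. For reference, the loop above is the usual create-or-truncate idiom; a minimal standalone sketch with hypothetical directory and file names:

    import os

    root = '_resultor'                                   # hypothetical results directory
    files = ['results.txt', 'errors.txt', 'costs.txt']   # hypothetical log file names

    if not os.path.exists(root):
        os.makedirs(root)
    for name in files:
        # open in 'w' mode and close immediately to create (or empty) each log file
        open(os.path.join(root, name), 'w').close()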
@@ -163,7 +163,7 @@ def print_confusion (self, epoch=0, train = None, valid = None, verbose = 2):
                 print "... Creating a root directory for saving confusions"
             os.makedirs(self.root + '/confusion')
 
-        location = self.root + '/confusion/' + '/epoch_' + str(epoch)
+        location = self.root + '/confusion' + '/epoch_' + str(epoch)
         if not os.path.exists( location ):
             if verbose >=3 :
                 print "... Making the epoch directory"
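The one-character fix above removes a doubled path separator: the old concatenation produced paths like 'save/confusion//epoch_0'. A minimal sketch of the difference (the root value is hypothetical; os.path.join is shown as the idiomatic alternative):

    import os

    root = 'save'   # hypothetical resultor root directory
    epoch = 0

    old = root + '/confusion/' + '/epoch_' + str(epoch)
    new = root + '/confusion' + '/epoch_' + str(epoch)
    assert old == 'save/confusion//epoch_0'   # doubled separator
    assert new == 'save/confusion/epoch_0'    # what the fix produces

    # os.path.join sidesteps separator bookkeeping entirely (it uses os.sep)
    safe = os.path.join(root, 'confusion', 'epoch_' + str(epoch))
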
yann/special/gan.py: 33 changes (27 additions, 6 deletions)
@@ -171,7 +171,7 @@ def cook_generator ( self, optimizer_params, verbose = 2):
         self.generator_decay_learning_rate = self.decay_learning_rate
         self.generator_current_momentum = self.current_momentum
 
-    def cook( self,
+    def cook( self,
               objective_layers,
               discriminator_layers,
               generator_layers,
@@ -230,6 +230,11 @@ def cook( self,
         else:
             datastream = kwargs['datastream']
 
+        if not 'resultor' in kwargs.keys():
+            resultor = None
+        else:
+            resultor = kwargs['resultor']
+
         if not 'params' in kwargs.keys():
             params = None
         else:
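
The added resultor branch copies the file's existing kwargs pattern for datastream and params. An equivalent, more compact form (a sketch; the commit itself keeps the explicit if/else style) would be:

    resultor = kwargs.get('resultor', None)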
@@ -249,6 +254,20 @@
             raise Exception ("Datastream " + datastream + " not found.")
         self.cooked_datastream = self.datastream[datastream]
 
+        if resultor is None:
+            if self.last_resultor_created is None:
+                if verbose >= 3:
+                    print('... No resultor setup, creating a default one.')
+                self.add_module( type = 'resultor', verbose = verbose )
+            else:
+                if verbose >= 3:
+                    print("... resultor not provided, assuming " + self.last_resultor_created)
+            resultor = self.last_resultor_created
+        else:
+            if not resultor in self.resultor.keys():
+                raise Exception ("Resultor " + resultor + " not found.")
+        self.cooked_resultor = self.resultor[resultor]
+
         self.generator_active_params = []
         self.discriminator_active_params = []
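
This gives cook() the same resolve-or-create behavior for the resultor that it already has for the datastream: an explicit name must exist, an omitted one falls back to the most recently created module, and a default is built when none exists. The pattern in isolation (a standalone sketch with hypothetical names, not yann's API):

    def resolve_module(registry, requested, last_created, create_default):
        # explicit request: must already be registered
        if requested is not None:
            if requested not in registry:
                raise Exception("Module " + requested + " not found.")
            return requested
        # no request: reuse the last module created, or build a default one
        if last_created is None:
            last_created = create_default()
        return last_created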

@@ -326,13 +345,16 @@ def cook( self,
                                 verbose = verbose)
         self.cook_generator( optimizer_params = optimizer_params,
                                 verbose = verbose)
+
         if self.softmax_head is True:
             self._initialize_test (classifier = softmax_layer,
                                     verbose = verbose)
             self._initialize_predict ( classifier = softmax_layer,
                                     verbose = verbose)
             self._initialize_posterior (classifier = softmax_layer,
-                                    verbose = verbose)
+                                        verbose = verbose)
+            self._initialize_confusion (classifier = softmax_layer,
+                                        verbose = verbose)
 
         self.initialize_train ( verbose = verbose )
         self.validation_accuracy = []
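
_initialize_confusion is the piece that makes the per-epoch confusion output in resultor.py possible when the network has a softmax head. For readers unfamiliar with the object being computed, a standalone numpy sketch of a confusion matrix (an illustration, not yann's implementation):

    import numpy

    def confusion_matrix(labels, predictions, n_classes):
        # rows index the true class, columns the predicted class
        cm = numpy.zeros((n_classes, n_classes), dtype=int)
        for t, p in zip(labels, predictions):
            cm[t, p] += 1
        return cm

    print(confusion_matrix([0, 1, 1, 2], [0, 1, 2, 2], 3))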
@@ -356,6 +378,7 @@ def cook( self,
         self.disc_cost = []
         self.softmax_cost = []
         self.cooked_visualizer = self.visualizer[visualizer]
+        self._cook_resultor(resultor = self.cooked_resultor, verbose = verbose)
         self._cook_visualizer(verbose = verbose) # always cook visualizer last.
         self.visualize (epoch = 0, verbose = verbose)
         # Cook Resultor.
@@ -375,11 +398,9 @@ def _new_era ( self, new_learning_rate = 0.01, verbose = 2):
         if self.softmax_head is True:
             self.softmax_learning_rate.set_value(numpy.asarray(new_learning_rate,
                                                         dtype = theano.config.floatX))
-        self.real_learning_rate.set_value(numpy.asarray(new_learning_rate,
-                                                        dtype = theano.config.floatX))
-        self.fake_learning_rate.set_value(numpy.asarray(new_learning_rate,
+        self.generator_learning_rate.set_value(numpy.asarray(new_learning_rate,
                                                         dtype = theano.config.floatX))
-        self.gen_learning_rate.set_value(numpy.asarray(new_learning_rate,
+        self.discriminator_learning_rate.set_value(numpy.asarray(new_learning_rate,
                                                         dtype = theano.config.floatX))
         # copying and removing only active_params. Is that a problem?
         copy_params ( source = self.best_params, destination = self.active_params ,
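
The rename fixes _new_era, which still referenced real_, fake_, and gen_ learning rates, presumably stale after the move to separate generator and discriminator learning rates. The set_value idiom itself, in isolation (a minimal Theano sketch, assuming Theano is installed):

    import numpy
    import theano

    learning_rate = theano.shared(numpy.asarray(0.1, dtype=theano.config.floatX))
    # start a new era of training with a smaller step size
    learning_rate.set_value(numpy.asarray(0.01, dtype=theano.config.floatX))
    assert abs(learning_rate.get_value() - 0.01) < 1e-7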
