Skip to content

Commit

Permalink
testing accuracy
Browse files Browse the repository at this point in the history
  • Loading branch information
Ragav Venkatesan committed Mar 13, 2017
1 parent 8041309 commit 6caadfc
Show file tree
Hide file tree
Showing 2 changed files with 29 additions and 7 deletions.
24 changes: 18 additions & 6 deletions yann/modules/resultor.py
Original file line number Diff line number Diff line change
Expand Up @@ -147,14 +147,15 @@ def update_plot (self, verbose = 2):
"""
print "TBD"

def print_confusion (self, epoch=0, train = None, valid = None, verbose = 2):
def print_confusion (self, epoch=0, train = None, valid = None, test = None, verbose = 2):
"""
This method will print the confusion matrix down in files.
Args:
epoch: This is used merely to create a directory for each epoch so that there is a copy.
train: training confusion matrix as gained by the validate method.
valid: validation confusion amtrix as gained by the validate method.
test: testing confusion matrix as gained by the test method.
verbose: As usual.
"""
if verbose >=3:
Expand All @@ -173,14 +174,20 @@ def print_confusion (self, epoch=0, train = None, valid = None, verbose = 2):
if verbose >=3 :
print ("... Saving down the confusion matrix")

self._store_confusion_img (confusion = train,
if not train is None:
self._store_confusion_img (confusion = train,
filename = location + '/train_confusion.eps',
verbose = 2)

self._store_confusion_img (confusion = valid,
if not valid is None:
self._store_confusion_img (confusion = valid,
filename = location + '/valid_confusion.eps',
verbose = 2)

if not test is None:
self._store_confusion_img (confusion = test,
filename = location + '/test_confusion.eps',
verbose = 2)

def _store_confusion_img (self, confusion, filename, verbose = 2):
"""
Convert a normalized confusion matrix into an image and save it down.
Expand All @@ -190,14 +197,19 @@ def _store_confusion_img (self, confusion, filename, verbose = 2):
filename: save the image at the location as a file.
verbose: as usual.
"""
corrects = numpy.trace(confusion)
total_samples = numpy.sum(confusion)
accuracy = 100 * corrects / float(total_samples)
if verbose >= 3:
print ("... Saving the file down")
confusion = confusion / confusion.sum(axis = 1)[:,None]
fig = plt.figure(figsize=(4, 4), dpi=1200)
plt.matshow(confusion)
for (i, j), z in numpy.ndenumerate(confusion):
plt.text(j, i, '{:0.2f}'.format(z), ha='center', va='center', fontsize=10, color = 'm')
plt.title('Confusion matrix')
plt.text(j, i, '{:0.2f}'.format(z), ha='center', va='center', fontsize=10, color = 'm')

plt.title('Accuracy: ' + str(int(corrects)) + '/' + str(int(total_samples)) + \
' = ' + str(round(accuracy,2)) + '%')
plt.set_cmap('GnBu')
plt.colorbar()
plt.ylabel('True labels')
Expand Down
12 changes: 11 additions & 1 deletion yann/network.py
Original file line number Diff line number Diff line change
Expand Up @@ -3041,6 +3041,11 @@ def test(self, show_progress = True, verbose = 2):
labels = []
total_mini_batches = self.batches2test * self.mini_batches_per_batch[2]

if self.network_type == 'classifier':
test_confusion_matrix = numpy.zeros((self.num_classes_to_classify,
self.num_classes_to_classify),
dtype = theano.config.floatX)

if show_progress is True:
bar = progressbar.ProgressBar(maxval=total_mini_batches, \
widgets=[progressbar.AnimatedMarker(), \
Expand All @@ -3056,6 +3061,8 @@ def test(self, show_progress = True, verbose = 2):
predictions = predictions + self.mini_batch_predictions(minibatch).tolist()
if self.network_type == 'classifier':
posteriors = posteriors + self.mini_batch_posterior(minibatch).tolist()
test_confusion_matrix = test_confusion_matrix + \
self.mini_batch_confusion (minibatch)
if verbose >= 3:
print("... testing error after mini batch " + str(batch_counter) + \
" is " + str(wrong))
Expand All @@ -3066,10 +3073,13 @@ def test(self, show_progress = True, verbose = 2):
if show_progress is True:
bar.finish()

self.cooked_resultor.print_confusion (epoch = 'fin',
test = test_confusion_matrix,
verbose = verbose)

total_samples = total_mini_batches * self.mini_batch_size
if self.network_type == 'classifier':
testing_accuracy = (total_samples - wrong)*100. / total_samples

if verbose >= 2:
print(".. Testing accuracy : " + str(testing_accuracy))
elif self.network_type == 'generator':
Expand Down

0 comments on commit 6caadfc

Please sign in to comment.