Skip to content

Commit

Permalink
Implemented Accuracy Metric for training and evaluation. Usage:
Browse files Browse the repository at this point in the history
  trainer.set_metrics([AccuracyMetric])
  • Loading branch information
recastrodiaz committed Apr 23, 2017
1 parent 879fae8 commit a35cd1d
Show file tree
Hide file tree
Showing 2 changed files with 70 additions and 7 deletions.
50 changes: 50 additions & 0 deletions torchsample/metrics.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,50 @@
from collections import OrderedDict

"""
MetricsModule that implements batch and epoch metrics such as Accuracy
"""

class MetricsModule():
    """Container that fans batch updates out to a set of Metric instances.

    Built from metric *classes* (not instances) so that each epoch gets
    freshly-zeroed metric state.
    """

    def __init__(self, metrics_classes):
        # Instantiate each metric class so accumulated state starts at zero.
        self._metrics = []
        for metric_class in metrics_classes:
            self._metrics.append(metric_class())

    def update(self, predictions, target):
        """Feed one batch of (predictions, target) to every held metric."""
        for metric in self._metrics:
            metric.update(predictions, target)

    def get_logs(self, prefix=''):
        """Merge every metric's log entries into one OrderedDict.

        Keys are prefixed with *prefix* (e.g. 'val_') by each metric;
        insertion order follows the order metrics were registered.
        """
        merged = OrderedDict()
        for metric in self._metrics:
            merged.update(metric.get_logs(prefix))
        return merged

class Metric():
    """Abstract interface for a streaming metric.

    Subclasses accumulate statistics batch-by-batch via update() and
    report their current values through get_logs().
    """

    def update(self, predictions, target):
        """Accumulate statistics from one batch; must be overridden."""
        raise NotImplementedError()

    def get_logs(self, prefix):
        """Return a dict of log entries keyed with *prefix*; must be overridden."""
        raise NotImplementedError()

class AccuracyMetric(Metric):
    """Running top-1 classification accuracy (in percent) over all batches seen."""

    def __init__(self):
        self.correct_count = 0  # correctly classified samples so far
        self.total_count = 0    # total samples so far
        self.accuracy = 0       # running accuracy in percent; 0 until first update

    def get_prediction_classes_ids(self, predictions):
        # Argmax over the class dimension: (N, C) scores -> (N,) class ids.
        _, predictions_ids = predictions.max(1)
        return predictions_ids

    def update(self, predictions, target):
        """Accumulate accuracy statistics from one batch.

        predictions: (N, C) tensor of per-class scores.
        target: (N,) tensor of ground-truth class ids.
        """
        prediction_classes_ids = self.get_prediction_classes_ids(predictions).cpu()
        target_classes_ids = target.cpu()
        # int(...) keeps the counter a plain Python number: newer PyTorch
        # returns a 0-dim tensor from .sum(), older versions an int, and
        # without the coercion `accuracy` would silently become a tensor.
        self.correct_count += int(target_classes_ids.eq(prediction_classes_ids).sum())
        self.total_count += predictions.size(0)
        # 100.0 forces float arithmetic so the ratio is a percentage,
        # not integer division.
        self.accuracy = 100.0 * self.correct_count / self.total_count

    def get_logs(self, prefix):
        """Return the running accuracy under the key '<prefix>acc'."""
        return {prefix + 'acc': self.accuracy}
27 changes: 20 additions & 7 deletions torchsample/modules/module_trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
from ..callbacks import CallbackModule, History, TQDM
from ..constraints import ConstraintModule
from ..regularizers import RegularizerModule
from ..metrics import MetricsModule


class ModuleTrainer():
Expand All @@ -38,6 +39,7 @@ def __init__(self, model, use_cuda):
self._callbacks = [self.history]
self._constraints = []
self._regularizers = []
self._metrics = []
self.stop_training = False

def set_loss(self, loss):
Expand All @@ -58,6 +60,9 @@ def set_constraints(self, constraints):

def set_callbacks(self, callbacks):
self._callbacks += callbacks

def set_metrics(self, metrics):
self._metrics = metrics

def fit(self,
x,
Expand Down Expand Up @@ -133,6 +138,7 @@ def fit_loader(self,
'nb_epoch': nb_epoch
}
callbacks.on_epoch_begin(epoch_idx, epoch_logs)
batch_metrics = MetricsModule(self._metrics)

for batch_idx,(x_batch, y_batch) in enumerate(loader):
batch_logs = {
Expand All @@ -157,6 +163,9 @@ def fit_loader(self,
loss += reg_loss
batch_logs['reg_loss'] = reg_loss
batch_logs['loss'] = loss.data[0]

batch_metrics.update(outputs.data, targets.data)
batch_logs.update(batch_metrics.get_logs())

# make backward pass
loss.backward()
Expand All @@ -167,14 +176,16 @@ def fit_loader(self,
constraints.on_batch_end(batch_idx)

if val_loader is not None:
val_loss = self.evaluate_loader(val_loader, self._loss)
epoch_logs['val_loss'] = val_loss
val_metrics = MetricsModule(self._metrics)
val_loss = self.evaluate_loader(val_loader, val_metrics)
epoch_logs.update(val_metrics.get_logs(prefix = 'val_'))

epoch_logs.update(batch_metrics.get_logs())

epoch_logs['loss'] = self.history.loss / self.history.samples_seen
if regularizers is not None:
epoch_logs['reg_loss'] = self.history.reg_loss / self.history.samples_seen

epoch_logs['val_acc'] = 0.0

callbacks.on_epoch_end(epoch_idx, epoch_logs)
constraints.on_epoch_end(epoch_idx)
if self.stop_training:
Expand Down Expand Up @@ -223,10 +234,10 @@ def evaluate(self,
verbose=1):
dataset = TensorDataset(x,y)
loader = DataLoader(dataset, batch_size=batch_size)
loss = self.evaluate_loader(loader, self._loss)
loss = self.evaluate_loader(loader)
return loss

def evaluate_loader(self, loader, loss_f):
def evaluate_loader(self, loader, metrics = MetricsModule([])):
self._model.eval()
total_loss = 0.
total_samples = 0.
Expand All @@ -238,9 +249,11 @@ def evaluate_loader(self, loader, loss_f):
y_batch = Variable(y_batch, volatile=True)

y_pred = self._model(x_batch)
loss = loss_f(y_pred, y_batch)
loss = self._loss(y_pred, y_batch)
total_loss += loss.data[0]*len(x_batch)
total_samples += len(x_batch)

metrics.update(y_pred.data, y_batch.data)
self._model.train()
return total_loss / total_samples

Expand Down

0 comments on commit a35cd1d

Please sign in to comment.