Skip to content

Commit

Permalink
Merge pull request #2 from heavengate/add_metric
Browse files Browse the repository at this point in the history
add metric for mnist.
  • Loading branch information
heavengate committed Mar 17, 2020
2 parents e8d52f6 + 38186a0 commit 903d0f7
Show file tree
Hide file tree
Showing 7 changed files with 405 additions and 88 deletions.
4 changes: 4 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@
*.pyc
*.json
output*
*checkpoint*
105 changes: 105 additions & 0 deletions metrics.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,105 @@
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from __future__ import absolute_import

import six
import abc
import numpy as np
import paddle.fluid as fluid

import logging
FORMAT = '%(asctime)s-%(levelname)s: %(message)s'
logging.basicConfig(level=logging.INFO, format=FORMAT)
logger = logging.getLogger(__name__)

__all__ = ['Metric', 'Accuracy']


@six.add_metaclass(abc.ABCMeta)
class Metric(object):
"""
Base class for metric, encapsulates metric logic and APIs
Usage:
m = SomeMetric()
for prediction, label in ...:
m.update(prediction, label)
m.accumulate()
"""

@abc.abstractmethod
def reset(self):
"""
Reset states and result
"""
raise NotImplementedError("function 'reset' not implemented in {}.".format(self.__class__.__name__))

@abc.abstractmethod
def update(self, *args, **kwargs):
"""
Update states for metric
"""
raise NotImplementedError("function 'update' not implemented in {}.".format(self.__class__.__name__))

@abc.abstractmethod
def accumulate(self):
"""
Accumulates statistics, computes and returns the metric value
"""
raise NotImplementedError("function 'accumulate' not implemented in {}.".format(self.__class__.__name__))

def add_metric_op(self, pred, label):
"""
Add process op for metric in program
"""
return pred, label


class Accuracy(Metric):
"""
Encapsulates accuracy metric logic
"""

def __init__(self, topk=(1, ), *args, **kwargs):
super(Accuracy, self).__init__(*args, **kwargs)
self.topk = topk
self.maxk = max(topk)
self.reset()

def add_metric_op(self, pred, label, *args, **kwargs):
pred = fluid.layers.argsort(pred[0], descending=True)[1][:, :self.maxk]
correct = pred == label[0]
return correct

def update(self, correct, *args, **kwargs):
accs = []
for i, k in enumerate(self.topk):
num_corrects = correct[:, :k].sum()
num_samples = len(correct)
accs.append(float(num_corrects) / num_samples)
self.total[i] += num_corrects
self.count[i] += num_samples
return accs

def reset(self):
self.total = [0.] * len(self.topk)
self.count = [0] * len(self.topk)

def accumulate(self):
res = []
for t, c in zip(self.total, self.count):
res.append(float(t) / c)
return res

29 changes: 16 additions & 13 deletions mnist.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@
from paddle.fluid.dygraph.nn import Conv2D, Pool2D, Linear

from model import Model, CrossEntropy, Input
from metrics import Accuracy


class SimpleImgConvPool(fluid.dygraph.Layer):
Expand Down Expand Up @@ -145,36 +146,38 @@ def null_guard():
parameter_list=model.parameters())
inputs = [Input([None, 1, 28, 28], 'float32', name='image')]
labels = [Input([None, 1], 'int64', name='label')]
model.prepare(optim, CrossEntropy(), inputs, labels)
model.prepare(optim, CrossEntropy(), Accuracy(topk=(1, 2)), inputs, labels)
if FLAGS.resume is not None:
model.load(FLAGS.resume)

for e in range(FLAGS.epoch):
train_loss = 0.0
train_acc = 0.0
val_loss = 0.0
val_acc = 0.0
print("======== train epoch {} ========".format(e))
for idx, batch in enumerate(train_loader()):
outputs, losses = model.train(batch[0], batch[1])
losses, metrics = model.train(batch[0], batch[1])

acc = accuracy(outputs[0], batch[1])[0]
train_loss += np.sum(losses)
train_acc += acc
if idx % 10 == 0:
print("{:04d}: loss {:0.3f} top1: {:0.3f}%".format(
idx, train_loss / (idx + 1), train_acc / (idx + 1)))
print("{:04d}: loss {:0.3f} top1: {:0.3f}% top2: {:0.3f}%".format(
idx, train_loss / (idx + 1), metrics[0][0], metrics[0][1]))
for metric in model._metrics:
res = metric.accumulate()
print("train epoch {:03d}: top1: {:0.3f}%, top2: {:0.3f}".format(e, res[0], res[1]))
metric.reset()

print("======== eval epoch {} ========".format(e))
for idx, batch in enumerate(val_loader()):
outputs, losses = model.eval(batch[0], batch[1])
losses, metrics = model.eval(batch[0], batch[1])

acc = accuracy(outputs[0], batch[1])[0]
val_loss += np.sum(losses)
val_acc += acc
if idx % 10 == 0:
print("{:04d}: loss {:0.3f} top1: {:0.3f}%".format(
idx, val_loss / (idx + 1), val_acc / (idx + 1)))
print("{:04d}: loss {:0.3f} top1: {:0.3f}% top2: {:0.3f}%".format(
idx, val_loss / (idx + 1), metrics[0][0], metrics[0][1]))
for metric in model._metrics:
res = metric.accumulate()
print("eval epoch {:03d}: top1: {:0.3f}%, top2: {:0.3f}".format(e, res[0], res[1]))
metric.reset()
model.save('mnist_checkpoints/{:02d}'.format(e))


Expand Down
88 changes: 72 additions & 16 deletions model.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@
from paddle.fluid.executor import global_scope
from paddle.fluid.io import is_belong_to_optimizer
from paddle.fluid.dygraph.base import to_variable
from metrics import Metric

__all__ = ['Model', 'Loss', 'CrossEntropy', 'Input']

Expand All @@ -46,6 +47,26 @@ def to_numpy(var):
return np.array(t)


def flatten_list(l):
assert isinstance(l, list), "not a list"
outl = []
splits = []
for sl in l:
assert isinstance(sl, list), "sub content not a list"
splits.append(len(sl))
outl += sl
return outl, splits


def restore_flatten_list(l, splits):
outl = []
for split in splits:
assert len(l) >= split, "list length invalid"
sl, l = l[:split], l[split:]
outl.append(sl)
return outl


def extract_args(func):
if hasattr(inspect, 'getfullargspec'):
return inspect.getfullargspec(func)[0]
Expand Down Expand Up @@ -309,15 +330,26 @@ def _run(self, inputs, labels=None):
feed[v.name] = labels[idx]

endpoints = self._endpoints[self.mode]
fetch_list = endpoints['output'] + endpoints['loss']
num_output = len(endpoints['output'])
out = self._executor.run(compiled_prog,
feed=feed,
fetch_list=fetch_list)
if self.mode == 'test':
return out[:num_output]
fetch_list = endpoints['output']
else:
return out[:num_output], out[num_output:]
metric_list, metric_splits = flatten_list(endpoints['metric'])
fetch_list = endpoints['loss'] + metric_list
num_loss = len(endpoints['loss'])
rets = self._executor.run(
compiled_prog, feed=feed,
fetch_list=fetch_list,
return_numpy=False)
# LoDTensor cannot be fetch as numpy directly
rets = [np.array(v) for v in rets]
if self.mode == 'test':
return rets[:]
losses = rets[:num_loss]
metric_states = restore_flatten_list(rets[num_loss:], metric_splits)
metrics = []
for metric, state in zip(self.model._metrics, metric_states):
metrics.append(metric.update(*state))
return (losses, metrics) if len(metrics) > 0 else losses

def prepare(self):
modes = ['train', 'eval', 'test']
Expand Down Expand Up @@ -345,6 +377,7 @@ def _make_program(self, mode):
lr_var = self.model._optimizer._learning_rate_map[self._orig_prog]
self.model._optimizer._learning_rate_map[prog] = lr_var
losses = []
metrics = []
with fluid.program_guard(prog, self._startup_prog):
if isinstance(self.model._inputs, dict):
ins = [self.model._inputs[n] \
Expand All @@ -358,6 +391,8 @@ def _make_program(self, mode):
if mode != 'test':
if self.model._loss_function:
losses = self.model._loss_function(outputs, labels)
for metric in self.model._metrics:
metrics.append(to_list(metric.add_metric_op(outputs, labels)))
if mode == 'train' and self.model._optimizer:
self._loss_endpoint = fluid.layers.sum(losses)
self.model._optimizer.minimize(self._loss_endpoint)
Expand All @@ -367,7 +402,7 @@ def _make_program(self, mode):
self._input_vars[mode] = inputs
self._label_vars[mode] = labels
self._progs[mode] = prog
self._endpoints[mode] = {"output": outputs, "loss": losses}
self._endpoints[mode] = {"output": outputs, "loss": losses, "metric": metrics}

def _compile_and_initialize(self, prog, mode):
compiled_prog = self._compiled_progs.get(mode, None)
Expand Down Expand Up @@ -429,33 +464,44 @@ def train(self, inputs, labels=None):
self.mode = 'train'
inputs = to_list(inputs)
if labels is not None:
labels = to_list(labels)
outputs = self.model.forward(* [to_variable(x) for x in inputs])
labels = [to_variable(l) for l in to_list(labels)]
outputs = to_list(self.model.forward(*[to_variable(x) for x in inputs]))
losses = self.model._loss_function(outputs, labels)
final_loss = fluid.layers.sum(losses)
final_loss.backward()
self.model._optimizer.minimize(final_loss)
self.model.clear_gradients()
return [to_numpy(o) for o in to_list(outputs)], \
[to_numpy(l) for l in losses]
metrics = []
for metric in self.model._metrics:
metric_outs = metric.add_metric_op(outputs, to_list(labels))
m = metric.update(*[to_numpy(m) for m in to_list(metric_outs)])
metrics.append(m)
return ([to_numpy(l) for l in losses], metrics) \
if len(metrics) > 0 else [to_numpy(l) for l in losses]

def eval(self, inputs, labels=None):
super(Model, self.model).eval()
self.mode = 'eval'
inputs = to_list(inputs)
if labels is not None:
labels = to_list(labels)
outputs = self.model.forward(* [to_variable(x) for x in inputs])
labels = [to_variable(l) for l in to_list(labels)]
outputs = to_list(self.model.forward(*[to_variable(x) for x in inputs]))

if self.model._loss_function:
losses = self.model._loss_function(outputs, labels)
else:
losses = []

metrics = []
for metric in self.model._metrics:
metric_outs = metric.add_metric_op(outputs, labels)
m = metric.update(*[to_numpy(m) for m in to_list(metric_outs)])
metrics.append(m)

# To be consistent with static graph
# return empty loss if loss_function is None
return [to_numpy(o) for o in to_list(outputs)], \
[to_numpy(l) for l in losses]
return ([to_numpy(l) for l in losses], metrics) \
if len(metrics) > 0 else [to_numpy(l) for l in losses]

def test(self, inputs):
super(Model, self.model).eval()
Expand Down Expand Up @@ -567,6 +613,7 @@ def load(self, *args, **kwargs):
def prepare(self,
optimizer=None,
loss_function=None,
metrics=None,
inputs=None,
labels=None,
device=None,
Expand All @@ -580,6 +627,8 @@ def prepare(self,
loss_function (Loss|None): loss function must be set in training
and should be a Loss instance. It can be None when there is
no loss.
metrics (Metric|list of Metric|None): if metrics is set, all
metric will be calculate and output in train/eval mode.
inputs (Input|list|dict|None): inputs, entry points of network,
could be a Input layer, or lits of Input layers,
or dict (name: Input), or None. For static graph,
Expand Down Expand Up @@ -615,6 +664,13 @@ def prepare(self,
"'inputs' must be list or dict in static graph mode")
if loss_function and not isinstance(labels, (list, Input)):
raise TypeError("'labels' must be list in static graph mode")

metrics = metrics or []
for metric in to_list(metrics):
assert isinstance(metric, Metric), \
"{} is not sub class of Metric".format(metric.__class__.__name__)
self._metrics = to_list(metrics)

self._inputs = inputs
self._labels = labels
self._device = device
Expand Down
Loading

0 comments on commit 903d0f7

Please sign in to comment.