Merge pull request PaddlePaddle#15 from guoshengCS/fix-data-train
Reorganize data from data_loader into inputs and labels.
guoshengCS committed Apr 1, 2020
2 parents 4d22fee + 863897c commit 7d1ea67
Showing 2 changed files with 28 additions and 18 deletions.
42 changes: 26 additions & 16 deletions model.py
@@ -29,6 +29,7 @@
 from paddle.fluid.io import is_belong_to_optimizer
 from paddle.fluid.dygraph.base import to_variable
 from paddle.fluid.dygraph.parallel import ParallelEnv
+from paddle.fluid.layers.utils import flatten
 from paddle.fluid.incubate.fleet.collective import fleet, DistributedStrategy
 from paddle.fluid.incubate.fleet.base import role_maker
 from paddle.fluid.io import DataLoader, Dataset
@@ -414,13 +415,7 @@ def _make_program(self, mode):
         losses = []
         metrics = []
         with fluid.program_guard(prog, self._startup_prog):
-            if isinstance(self.model._inputs, dict):
-                ins = [
-                    self.model._inputs[n]
-                    for n in extract_args(self.model.forward) if n != 'self'
-                ]
-            else:
-                ins = self.model._inputs
+            ins = self.model._inputs
             lbls = self.model._labels if self.model._labels else []
             inputs = [k.forward() for k in to_list(ins)]
             labels = [k.forward() for k in to_list(lbls)]
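Here `ins` holds the model's declared input specs, and each `k.forward()` materializes a static-graph feed variable. A minimal sketch of that mechanism, assuming a hapi-style `Input` spec (the class below is a rough stand-in and the shapes/dtypes are made up; `fluid.data` is Paddle's standard static-graph feed-variable API):

    import paddle.fluid as fluid

    class Input:
        # Rough stand-in for the Input spec class used by this repo.
        def __init__(self, shape, dtype, name):
            self.shape, self.dtype, self.name = shape, dtype, name

        def forward(self):
            # Declare a feed variable in the current default program.
            return fluid.data(self.name, shape=self.shape, dtype=self.dtype)

    ins = [Input([None, 784], 'float32', 'image')]
    lbls = [Input([None, 1], 'int64', 'label')]
    inputs = [k.forward() for k in ins]   # feed variables for inputs
    labels = [k.forward() for k in lbls]  # feed variables for labels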
@@ -867,8 +862,10 @@ def prepare(self,
                     metric.__class__.__name__)
         self._metrics = to_list(metrics)
 
-        self._inputs = inputs
-        self._labels = labels
+        self._inputs = to_list(inputs) if not isinstance(inputs, dict) else [
+            inputs[n] for n in extract_args(self.forward) if n != 'self'
+        ]
+        self._labels = to_list(labels)
 
         if not in_dygraph_mode():
             self._adapter.prepare()
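The dict branch added here makes `prepare` responsible for normalizing `inputs`, ordering dict values by the argument order of `forward`. A minimal sketch of that ordering, with `extract_args` approximated via `inspect` (the `MLP` class and the spec strings are hypothetical):

    import inspect

    def extract_args(func):
        # Approximation of the extract_args helper: argument names of
        # `func`, in declaration order.
        return list(inspect.signature(func).parameters)

    class MLP:
        def forward(self, image, length):
            pass

    inputs = {'length': 'length_spec', 'image': 'image_spec'}
    # Reorder dict values to match forward()'s signature, skipping 'self'.
    ordered = [inputs[n] for n in extract_args(MLP.forward) if n != 'self']
    print(ordered)  # ['image_spec', 'length_spec']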
@@ -1174,17 +1171,30 @@ def _run_one_epoch(self,
         callbacks.on_epoch_begin(epoch)
 
         for step, data in enumerate(data_loader):
-            if not fluid.in_dygraph_mode():
-                data = data[0]
-                batch_size = data[0].shape()[0]
-            else:
-                batch_size = data[0].shape[0]
+            # Data might come from different types of data_loader and have
+            # different formats, as follows:
+            # 1. DataLoader in static graph:
+            #    [[input1, input2, ..., label1, label2, ...]]
+            # 2. DataLoader in dygraph:
+            #    [input1, input2, ..., label1, label2, ...]
+            # 3. custom iterator yielding concatenated inputs and labels:
+            #    [input1, input2, ..., label1, label2, ...]
+            # 4. custom iterator yielding separated inputs and labels:
+            #    ([input1, input2, ...], [label1, label2, ...])
+            # To handle all of these, flatten the (nested) list to a flat list.
+            data = flatten(data)
+            # LoDTensor.shape is callable, where LoDTensor comes from
+            # DataLoader in static graph mode.
+            batch_size = data[0].shape()[0] if callable(data[
+                0].shape) else data[0].shape[0]
 
             callbacks.on_batch_begin(mode, step, logs)
             if mode == 'train':
-                outs = self.train(*data)
+                outs = self.train(data[:len(self._inputs)],
+                                  data[len(self._inputs):])
             else:
-                outs = self.eval(*data)
+                outs = self.eval(data[:len(self._inputs)],
+                                 data[len(self._inputs):])
 
             # losses
             loss = outs[0] if self._metrics else outs
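To see what the new `flatten` call normalizes, a small sketch covering the four formats listed in the comment (using `flatten` as imported above; the numpy arrays are placeholder batches and `n_in` plays the role of `len(self._inputs)`):

    import numpy as np
    from paddle.fluid.layers.utils import flatten

    x = np.zeros((32, 784))  # fake input batch
    y = np.zeros((32, 1))    # fake label batch

    flat1 = flatten([[x, y]])    # format 1: static-graph DataLoader
    flat2 = flatten([x, y])      # formats 2/3: already a flat list
    flat4 = flatten(([x], [y]))  # format 4: separated inputs and labels
    # All three yield the same flat list [x, y].
    assert all(a is b for f in (flat1, flat2, flat4)
               for a, b in zip(f, [x, y]))

    n_in = 1  # stands in for len(self._inputs)
    inputs, labels = flat4[:n_in], flat4[n_in:]
    batch_size = flat4[0].shape[0]  # ndarray .shape is a tuple, not callable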
4 changes: 2 additions & 2 deletions progressbar.py
@@ -107,7 +107,7 @@ def update(self, current_num, values=None):
             eta = time_per_unit * (self._num - current_num)
             if eta > 3600:
                 eta_format = '%d:%02d:%02d' % (eta // 3600, (eta % 3600) //
-                60, eta % 60)
+                                               60, eta % 60)
             elif eta > 60:
                 eta_format = '%d:%02d' % (eta // 60, eta % 60)
             else:
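As a quick worked check of the branch above (a throwaway snippet, not part of the diff; the final `else` branch is assumed to print plain seconds): 4000 seconds gives `4000 // 3600 == 1`, `(4000 % 3600) // 60 == 6`, and `4000 % 60 == 40`:

    eta = 4000
    if eta > 3600:
        eta_format = '%d:%02d:%02d' % (eta // 3600, (eta % 3600) // 60,
                                       eta % 60)
    elif eta > 60:
        eta_format = '%d:%02d' % (eta // 60, eta % 60)
    else:
        eta_format = '%ds' % eta  # assumed fallback for short ETAs
    print(eta_format)  # 1:06:40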
@@ -148,7 +148,7 @@ def update(self, current_num, values=None):
                 else:
                     info += ' %.4e' % v
             elif isinstance(v, np.ndarray) and \
-                    isinstance(v.size, 1) and \
+                    v.size == 1 and \
                     isinstance(v.dtype, (np.float32, np.float64)):
                 if abs(v[0]) > 1e-3:
                     info += ' %.4f' % v[0]
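The old condition was a latent bug, not a style issue: `isinstance` requires a type (or tuple of types) as its second argument, so `isinstance(v.size, 1)` raises `TypeError` as soon as a numpy value reaches this branch, while `v.size == 1` expresses the actual intent of accepting only single-element arrays. A quick demonstration:

    import numpy as np

    v = np.array([0.25], dtype=np.float32)

    try:
        isinstance(v.size, 1)  # the old check
    except TypeError as e:
        print(e)  # isinstance() arg 2 must be a type or tuple of types

    print(v.size == 1)                 # True: one-element array
    print(np.zeros((3, 2)).size == 1)  # False: multi-element array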
