Skip to content

Commit

Permalink
fix prefetch bug
Browse files Browse the repository at this point in the history
  • Loading branch information
ppwwyyxx committed Apr 17, 2016
1 parent b81c226 commit d04661e
Show file tree
Hide file tree
Showing 3 changed files with 7 additions and 26 deletions.
7 changes: 1 addition & 6 deletions examples/cifar10_convnet.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,17 +16,12 @@
from tensorpack.tfutils.symbolic_functions import *
from tensorpack.tfutils.summary import *
from tensorpack.dataflow import *
from tensorpack.dataflow import imgaug

"""
A small cifar10 convnet model.
90% validation accuracy after 40k step.
"""

BATCH_SIZE = 128
MIN_AFTER_DEQUEUE = int(50000 * 0.4)
CAPACITY = MIN_AFTER_DEQUEUE + 3 * BATCH_SIZE

class Model(ModelDesc):
def _get_input_vars(self):
return [InputVar(tf.float32, [None, 30, 30, 3], 'input'),
Expand Down Expand Up @@ -134,7 +129,7 @@ def get_config():
session_config=sess_config,
model=Model(),
step_per_epoch=step_per_epoch,
max_epoch=200,
max_epoch=3,
)

if __name__ == '__main__':
Expand Down
11 changes: 4 additions & 7 deletions examples/cifar10_resnet.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,6 @@
from tensorpack.tfutils.symbolic_functions import *
from tensorpack.tfutils.summary import *
from tensorpack.dataflow import *
from tensorpack.dataflow import imgaug

"""
CIFAR10-resnet example.
Expand Down Expand Up @@ -45,7 +44,7 @@ def _get_input_vars(self):

def _get_cost(self, input_vars, is_training):
image, label = input_vars
image = image / 255.0
image = image / 128.0 - 1

def conv(name, l, channel, stride):
return Conv2D(name, l, channel, 3, stride=stride,
Expand Down Expand Up @@ -117,10 +116,10 @@ def residual(name, l, increase_dim=False, first=False):
# weight decay on all W of fc layers
wd_w = tf.train.exponential_decay(0.0002, get_global_step_var(),
480000, 0.2, True)
wd_cost = wd_w * regularize_cost('.*/W', tf.nn.l2_loss)
wd_cost = tf.mul(wd_w, regularize_cost('.*/W', tf.nn.l2_loss), name='wd_cost')
tf.add_to_collection(MOVING_SUMMARY_VARS_KEY, wd_cost)

add_param_summary([('.*/W', ['histogram', 'sparsity'])]) # monitor W
add_param_summary([('.*/W', ['histogram'])]) # monitor W
return tf.add_n([cost, wd_cost], name='cost')

def get_data(train_or_test):
Expand All @@ -146,8 +145,6 @@ def get_data(train_or_test):
ds = PrefetchData(ds, 3, 2)
return ds



def get_config():
# prepare dataset
dataset_train = get_data('train')
Expand All @@ -170,7 +167,7 @@ def get_config():
[(1, 0.1), (82, 0.01), (123, 0.001), (300, 0.0002)])
]),
session_config=sess_config,
model=Model(n=18),
model=Model(n=30),
step_per_epoch=step_per_epoch,
max_epoch=500,
)
Expand Down
15 changes: 2 additions & 13 deletions tensorpack/dataflow/prefetch.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,9 +9,6 @@

__all__ = ['PrefetchData']

class Sentinel:
pass

class PrefetchProcess(multiprocessing.Process):
def __init__(self, ds, queue):
"""
Expand All @@ -24,11 +21,9 @@ def __init__(self, ds, queue):

def run(self):
self.ds.reset_state()
try:
while True:
for dp in self.ds.get_data():
self.queue.put(dp)
finally:
self.queue.put(Sentinel())

class PrefetchData(ProxyDataFlow):
"""
Expand All @@ -52,17 +47,11 @@ def __init__(self, ds, nr_prefetch, nr_proc=1):
x.start()

def get_data(self):
end_cnt = 0
tot_cnt = 0
while True:
dp = self.queue.get()
if isinstance(dp, Sentinel):
end_cnt += 1
if end_cnt == self.nr_proc:
break
continue
tot_cnt += 1
yield dp
tot_cnt += 1
if tot_cnt == self._size:
break

Expand Down

0 comments on commit d04661e

Please sign in to comment.