Skip to content

Commit

Permalink
map/filter dataflow
Browse files Browse the repository at this point in the history
  • Loading branch information
ppwwyyxx committed Apr 19, 2016
1 parent 174c3fc commit 5476b48
Show file tree
Hide file tree
Showing 2 changed files with 15 additions and 9 deletions.
2 changes: 1 addition & 1 deletion examples/ResNet/svhn_resnet.py
Original file line number Diff line number Diff line change
Expand Up @@ -171,7 +171,7 @@ def get_config():
InferenceRunner(dataset_test,
[ScalarStats('cost'), ClassificationError() ]),
ScheduledHyperParamSetter('learning_rate',
[(1, 0.1), (20, 0.01), (33, 0.001), (60, 0.0001)])
[(1, 0.1), (20, 0.01), (28, 0.001), (50, 0.0001)])
]),
session_config=sess_config,
model=Model(n=18),
Expand Down
22 changes: 14 additions & 8 deletions tensorpack/dataflow/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -144,35 +144,41 @@ def get_data(self):
yield [self.rng.random_sample(k) for k in self.shapes]

class MapData(ProxyDataFlow):
""" Map a function on the datapoint"""
""" Apply map/filter a function on the datapoint"""
def __init__(self, ds, func):
"""
:param ds: a :mod:`DataFlow` instance.
:param func: a function that takes a original datapoint, returns a new datapoint
:param func: a function that takes a original datapoint, returns a new
datapoint. return None to skip this data point.
"""
super(MapData, self).__init__(ds)
self.func = func

def get_data(self):
for dp in self.ds.get_data():
yield self.func(dp)
ret = self.func(dp)
if ret is not None:
yield ret

class MapDataComponent(ProxyDataFlow):
""" Apply a function to the given index in the datapoint"""
""" Apply map/filter on the given index in the datapoint"""
def __init__(self, ds, func, index=0):
"""
:param ds: a :mod:`DataFlow` instance.
:param func: a function that takes a datapoint dp[index], returns a new value of dp[index]
:param func: a function that takes a datapoint dp[index], returns a
new value of dp[index]. return None to skip this datapoint.
"""
super(MapDataComponent, self).__init__(ds)
self.func = func
self.index = index

def get_data(self):
for dp in self.ds.get_data():
dp = copy.deepcopy(dp) # avoid modifying the original dp
dp[self.index] = self.func(dp[self.index])
yield dp
repl = self.func(dp[self.index])
if repl is not None:
dp = copy.deepcopy(dp) # avoid modifying the original dp
dp[self.index] = repl
yield dp

class RandomChooseData(DataFlow):
"""
Expand Down

0 comments on commit 5476b48

Please sign in to comment.