fix swig name. add some df
ppwwyyxx committed Apr 22, 2016
1 parent 3efce3a commit b7766fc
Showing 4 changed files with 46 additions and 17 deletions.
56 changes: 42 additions & 14 deletions tensorpack/dataflow/common.py
@@ -9,7 +9,8 @@
 from ..utils import *
 
 __all__ = ['BatchData', 'FixedSizeData', 'FakeData', 'MapData',
-           'MapDataComponent', 'RandomChooseData', 'RandomMixData', 'JoinData']
+           'MapDataComponent', 'RandomChooseData', 'RandomMixData',
+           'JoinData', 'ConcatData', 'SelectComponent']
 
 class BatchData(ProxyDataFlow):
     def __init__(self, ds, batch_size, remainder=False):
@@ -249,7 +250,7 @@ def get_data(self):
         for k in idxs:
             yield next(itrs[k])
 
-class JoinData(DataFlow):
+class ConcatData(DataFlow):
     """
     Concatenate several dataflows.
     """
@@ -271,21 +272,48 @@ def get_data(self):
             for dp in d.get_data():
                 yield dp
 
-class SelectComponent(ProxyDataFlow):
+class JoinData(DataFlow):
     """
-    Select component from a datapoint.
+    Join the components from each DataFlow.
+    e.g.: df1: [dp1, dp2]
+          df2: [dp3, dp4]
+          join: [dp1, dp2, dp3, dp4]
     """
-    def __init__(self, ds, idxs):
+    def __init__(self, df_lists):
         """
-        :param ds: a :mod:`DataFlow` instance
-        :param idxs: a list of datapoint component index of the original dataflow
+        :param df_lists: list of :mod:`DataFlow` instances
         """
-        super(SelectComponent, self).__init__(ds)
-        self.idxs = idxs
+        self.df_lists = df_lists
+        self._size = self.df_lists[0].size()
+        for d in self.df_lists:
+            assert d.size() == self._size, \
+                "All DataFlow must have the same size! {} != {}".format(d.size(), self._size)
+
+    def reset_state(self):
+        for d in self.df_lists:
+            d.reset_state()
+
+    def size(self):
+        return self._size
 
     def get_data(self):
-        for dp in self.ds.get_data():
-            newdp = []
-            for idx in self.idxs:
-                newdp.append(dp[idx])
-            yield newdp
+        itrs = [k.get_data() for k in self.df_lists]
+        try:
+            while True:
+                dp = []
+                for itr in itrs:
+                    dp.extend(next(itr))
+                yield dp
+        except StopIteration:
+            pass
+        finally:
+            for itr in itrs:
+                del itr
+
+def SelectComponent(ds, idxs):
+    """
+    :param ds: a :mod:`DataFlow` instance
+    :param idxs: a list of datapoint component index of the original dataflow
+    """
+    return MapData(ds, lambda dp: [dp[i] for i in idxs])
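
To make the new semantics concrete: JoinData draws one datapoint from every dataflow in lockstep and concatenates their components, while SelectComponent is now a thin MapData wrapper that keeps only the chosen component indices. A hedged sketch under the same FakeData assumption as above:

    from tensorpack.dataflow.common import FakeData, JoinData, SelectComponent

    # equal-sized dataflows; df1 datapoints have 2 components, df2 datapoints have 1
    df1 = FakeData([[2], [3]], 5)
    df2 = FakeData([[4]], 5)

    joined = JoinData([df1, df2])
    joined.reset_state()
    dp = next(joined.get_data())
    assert len(dp) == 3        # [df1 comp0, df1 comp1, df2 comp0]

    # keep only the first and last components of every joined datapoint
    selected = SelectComponent(joined, [0, 2])
    selected.reset_state()
    assert len(next(selected.get_data())) == 2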

3 changes: 2 additions & 1 deletion tensorpack/predict.py
@@ -79,7 +79,8 @@ def get_predict_func(config):
 
     # check output_var_names against output_vars
     if output_var_names is not None:
-        output_vars = [tf.get_default_graph().get_tensor_by_name(n) for n in output_var_names]
+        output_vars = [tf.get_default_graph().get_tensor_by_name(get_op_var_name(n)[1])
+                       for n in output_var_names]
     else:
         output_vars = []
 
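
The change routes every requested name through get_op_var_name before the graph lookup: get_tensor_by_name needs a tensor name such as 'prob:0', so a bare op name in output_var_names would previously raise. A hedged sketch of the naming convention involved, using a hypothetical stand-in rather than tensorpack's actual helper:

    def op_var_name(name):
        # hypothetical stand-in for get_op_var_name:
        # 'prob'   -> ('prob', 'prob:0')
        # 'prob:0' -> ('prob', 'prob:0')
        if ':' in name:
            return name[:name.find(':')], name
        return name, name + ':0'

    assert op_var_name('prob')[1] == 'prob:0'
    assert op_var_name('prob:0')[1] == 'prob:0'
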
2 changes: 1 addition & 1 deletion tensorpack/tfutils/common.py
@@ -6,7 +6,7 @@
 from ..utils.naming import *
 import tensorflow as tf
 
-def get_default_sess_config(mem_fraction=0.5):
+def get_default_sess_config(mem_fraction=0.99):
     """
     Return a better session config to use as default.
     Tensorflow default session config consume too much resources.
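
The default GPU memory fraction handed out by the session config rises from 0.5 to 0.99. For reference, a hedged sketch of what such a config amounts to in the TensorFlow API of that era; this is not necessarily the exact body of get_default_sess_config:

    import tensorflow as tf

    def make_sess_config(mem_fraction=0.99):
        # cap per-process GPU memory at the given fraction instead of
        # letting TensorFlow reserve the whole GPU up front
        conf = tf.ConfigProto()
        conf.gpu_options.per_process_gpu_memory_fraction = mem_fraction
        conf.allow_soft_placement = True
        return conf

    # sess = tf.Session(config=make_sess_config())
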
2 changes: 1 addition & 1 deletion tensorpack/tfutils/sessinit.py
@@ -85,7 +85,7 @@ def _produce_restore_dict(vars_multimap):
     @staticmethod
     def _read_checkpoint_vars(model_path):
         reader = tf.train.NewCheckpointReader(model_path)
-        return set(reader.GetVariableToShapeMap().keys())
+        return set(reader.get_variable_to_shape_map().keys())
 
     @staticmethod
     def _get_vars_to_restore_multimap(vars_available):
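
This appears to be the "fix swig name" part of the commit title: the swig-generated GetVariableToShapeMap accessor on TensorFlow's CheckpointReader was renamed to the snake_case get_variable_to_shape_map. A small usage sketch; the checkpoint path is a placeholder:

    import tensorflow as tf

    def checkpoint_var_names(model_path):
        # names of all variables stored in a checkpoint, useful for
        # intersecting with the current graph's variables before restoring
        reader = tf.train.NewCheckpointReader(model_path)
        return set(reader.get_variable_to_shape_map().keys())

    # print(checkpoint_var_names('/path/to/model-10000'))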