Skip to content

Commit

Permalink
Optional feature spec
Browse files Browse the repository at this point in the history
Bug fixed at eval initializer
  • Loading branch information
shygiants committed Jul 2, 2019
1 parent 86765f5 commit d0cbbb2
Show file tree
Hide file tree
Showing 4 changed files with 95 additions and 56 deletions.
2 changes: 1 addition & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
'tensorflow==1.13.1'
]
},
version='0.11.2',
version='0.11.3',
description='Libraries for easy bootstrapping TensorFlow project',
author='Sanghoon Yoon',
author_email='shygiants@gmail.com',
Expand Down
12 changes: 10 additions & 2 deletions tflibs/datasets/dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -175,7 +175,8 @@ def process_wrapper(coll, length, thread_idx):

for processed_e in processed:
# Build feature proto
nested_feature = map_dict(lambda k, v: (k, v.feature_proto(processed_e[k])), feature_specs)
nested_feature = {k: feature_spec.feature_proto(processed_e[k]) for k, feature_spec in
feature_specs.items() if k in processed_e}
# Flatten nested dict
feature = flatten_nested_dict(nested_feature)

Expand Down Expand Up @@ -207,13 +208,20 @@ def get_ranges(length):
else:
raise ValueError('`length` should be provided')

threads = []
for i, coll in enumerate(colls):
kwargs = coll
kwargs.update(thread_idx=i)

thread = threading.Thread(target=process_wrapper, kwargs=kwargs)
threads.append(thread)
thread.start()

for i, thread in enumerate(threads):
tf.logging.info('Waiting for joining thread {}'.format(i))
thread.join()
tf.logging.info('Thread {} is joined'.format(i))

def read(self, split=None, num_parallel_reads=16, num_parallel_calls=16, cache=True):
"""
Reads tfrecord and makes it tf.data.Dataset
Expand All @@ -229,7 +237,7 @@ def dataset_fn():
feature_specs = self.feature_specs(split=split)

def parse(record):
return map_dict(lambda k, v: (k, v.parse(k, record)), feature_specs)
return {k: feature_spec.parse(k, record) for k, feature_spec in feature_specs.items()}

fname, ext = os.path.splitext(self.tfrecord_filename)
fname_pattern = '{fname}{split}*{ext}'
Expand Down
126 changes: 81 additions & 45 deletions tflibs/datasets/feature_spec.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,9 @@
from __future__ import division
from __future__ import print_function

from functools import reduce
from operator import mul

import tensorflow as tf
import numpy as np
from enum import Enum
Expand Down Expand Up @@ -57,15 +60,8 @@ def _float_feature(values):
values = [values]
return tf.train.Feature(float_list=tf.train.FloatList(value=values))

def __init__(self, shape):
self._shape = shape

@property
def shape(self):
return self._shape

@property
def feature_proto_spec(self):
def feature_proto_specs(self):
"""
A property for specifying inner encoding spec of the feature
Expand All @@ -83,21 +79,21 @@ def feature_proto(self, value_dict):
:rtype: dict
"""

def map_fn(k, v):
dtype = v['dtype']
value = value_dict[k]
def make_feature(feature_proto_spec, value):
dtype = feature_proto_spec['dtype']

if isinstance(value, np.ndarray):
value = value.tolist()

if dtype == tf.string:
return k, FeatureSpec._bytes_feature(value)
return FeatureSpec._bytes_feature(value)
elif dtype == tf.float16 or dtype == tf.float32 or dtype == tf.float64:
return k, FeatureSpec._float_feature(value)
return FeatureSpec._float_feature(value)
elif dtype == tf.int8 or dtype == tf.int16 or dtype == tf.int32 or dtype == tf.int64:
return k, FeatureSpec._int64_feature(value)
return FeatureSpec._int64_feature(value)

return map_dict(map_fn, self.feature_proto_spec)
return {k: make_feature(feature_proto_spec, value_dict[k])
for k, feature_proto_spec in self.feature_proto_specs.items() if k in value_dict}

def parse(self, parent_key, record):
"""
Expand All @@ -109,11 +105,40 @@ def parse(self, parent_key, record):
:rtype: dict
"""

def parse(k, v):
return '{}/{}'.format(parent_key, k), tf.FixedLenFeature(v['shape'], v['dtype'])
def make_feature(feature_proto):
if 'default' in feature_proto:
default_value = feature_proto['default']
elif 'required' in feature_proto and not feature_proto['required']:
size = reduce(mul, feature_proto['shape'], 1)
dtype = feature_proto['dtype'] # type: tf.DType

if dtype.is_floating:
value = -1.
elif dtype.is_integer:
value = -1
elif dtype == tf.string:
value = ''
else:
raise ValueError('')

if size == 0:
default_value = value
else:
default_value = [value] * size
else:
default_value = None

return tf.FixedLenFeature(feature_proto['shape'], feature_proto['dtype'],
default_value=default_value)

features = {
'{}/{}'.format(parent_key, k): make_feature(feature_spec)
for k, feature_spec in self.feature_proto_specs.items()
}

features = map_dict(parse, self.feature_proto_spec)
return tf.parse_single_example(record, features)
parsed = tf.parse_single_example(record, features)

return {k.split('/')[-1]: v for k, v in parsed.items()}


class ScalarSpec(FeatureSpec):
Expand All @@ -139,7 +164,7 @@ def pydtype(self):
raise ValueError('Invalid dtype')

@property
def feature_proto_spec(self):
def feature_proto_specs(self):
return {
'value': {
'shape': (),
Expand All @@ -160,11 +185,8 @@ class IDSpec(FeatureSpec):
"""A class for specifying unique ID.
"""

def __init__(self):
super(IDSpec, self).__init__(())

@property
def feature_proto_spec(self):
def feature_proto_specs(self):
"""
A property for specifying inner encoding spec of the feature
Expand All @@ -189,7 +211,7 @@ def parse(self, parent_key, record):
"""
parsed = FeatureSpec.parse(self, parent_key, record)

return parsed['{}/_id'.format(parent_key)]
return parsed['_id']

def create_with_string(self, string):
return {
Expand All @@ -201,15 +223,22 @@ class ImageSpec(FeatureSpec):
"""
A class for specifying image spec
:param list|tuple image_size: The sizes of images
:param list|tuple image_shape: The sizes of images
"""

def __init__(self, image_size):
image_size = list(image_size)
super(ImageSpec, self).__init__(image_size)
def __init__(self, image_shape):
self._image_shape = tuple(image_shape)

@property
def feature_proto_spec(self):
def image_shape(self):
return self._image_shape

@property
def channels(self):
return self.image_shape[-1]

@property
def feature_proto_specs(self):
"""
A property for specifying inner encoding spec of the feature
Expand All @@ -233,8 +262,8 @@ def parse(self, parent_key, record):
:rtype: tf.Tensor
"""
parsed = FeatureSpec.parse(self, parent_key, record)
decoded = tf.image.decode_image(parsed['{}/{}'.format(parent_key, 'encoded')], channels=self.shape[-1])
decoded = tf.reshape(decoded, self.shape)
decoded = tf.image.decode_image(parsed['encoded'], channels=self.channels)
decoded = tf.reshape(decoded, self.image_shape)

return decoded

Expand Down Expand Up @@ -267,11 +296,10 @@ class VarImageSpec(FeatureSpec):
"""

def __init__(self, channels):
super(VarImageSpec, self).__init__(())
self._channels = channels

@property
def feature_proto_spec(self):
def feature_proto_specs(self):
"""
A property for specifying inner encoding spec of the feature
Expand Down Expand Up @@ -299,7 +327,7 @@ def parse(self, parent_key, record):
:rtype: tf.Tensor
"""
parsed = FeatureSpec.parse(self, parent_key, record)
decoded = tf.image.decode_image(parsed['{}/{}'.format(parent_key, 'encoded')], channels=self.channels)
decoded = tf.image.decode_image(parsed['encoded'], channels=self.channels)
decoded.set_shape((None, None, self.channels))

return decoded
Expand Down Expand Up @@ -334,11 +362,15 @@ class LabelSpec(FeatureSpec):
"""

def __init__(self, depth, class_names=None):
super(LabelSpec, self).__init__([depth])
self._depth = depth
self._class_names = class_names

@property
def feature_proto_spec(self):
def depth(self):
return self._depth

@property
def feature_proto_specs(self):
"""
A property for specifying inner encoding spec of the feature
Expand All @@ -363,10 +395,10 @@ def parse(self, parent_key, record):
"""
parsed = FeatureSpec.parse(self, parent_key, record)

return tf.one_hot(parsed['{}/{}'.format(parent_key, 'index')], self.shape[0])
return tf.one_hot(parsed['index'], self.depth)

def create_with_index(self, index):
assert index < self.shape[0]
assert index < self.depth

return {
'index': index
Expand Down Expand Up @@ -401,11 +433,15 @@ class MultiLabelSpec(FeatureSpec):
"""

def __init__(self, depth, class_names=None):
super(MultiLabelSpec, self).__init__([depth])
self._depth = depth
self._class_names = class_names

@property
def feature_proto_spec(self):
def depth(self):
return self._depth

@property
def feature_proto_specs(self):
"""
A property for specifying inner encoding spec of the feature
Expand All @@ -414,7 +450,7 @@ def feature_proto_spec(self):
"""
return {
'tensor': {
'shape': self.shape,
'shape': (self.depth,),
'dtype': tf.int64
}
}
Expand Down Expand Up @@ -443,7 +479,7 @@ def parse(self, parent_key, record):
"""
parsed = FeatureSpec.parse(self, parent_key, record)

return parsed['{}/{}'.format(parent_key, 'tensor')]
return parsed['tensor']

def create_with_tensor(self, tensor):
# TODO: Assert shape
Expand All @@ -468,7 +504,7 @@ def __init__(self, enum_cls):
self.enum = enum_cls

@property
def feature_proto_spec(self):
def feature_proto_specs(self):
"""
A property for specifying inner encoding spec of the feature
Expand All @@ -493,7 +529,7 @@ def parse(self, parent_key, record):
"""
parsed = FeatureSpec.parse(self, parent_key, record)

return parsed['{}/enum'.format(parent_key)]
return parsed['enum']

def create_with_string(self, string: str):
if string not in [lambda e: e.value, list(self.enum)]:
Expand Down
11 changes: 3 additions & 8 deletions tflibs/runner/initializer.py
Original file line number Diff line number Diff line change
Expand Up @@ -325,12 +325,7 @@ def handle(self, parse_args, unknown):
handled_args, unknown = ModelInitializer.handle(self, parse_args, unknown)

model_cls = handled_args['model_cls']

# Parse model-specific arguments
parser = argparse.ArgumentParser()
model_cls.add_model_args(parser, parse_args)
model_args, unknown = parser.parse_known_args(unknown)
log_parse_args(model_args, 'Model arguments')
model_args = handled_args['model_args']

# Parse model-specific eval arguments
parser = argparse.ArgumentParser()
Expand All @@ -342,7 +337,7 @@ def handle(self, parse_args, unknown):
eval_args.update({'eval_batch_size': parse_args.eval_batch_size})

model_params = {
'model_args': vars(model_args),
'model_args': model_args,
'eval_args': eval_args,
}

Expand All @@ -364,4 +359,4 @@ def handle(self, parse_args, unknown):
return {'estimator': estimator,
'eval_batch_size': parse_args.eval_batch_size,
'model_cls': model_cls,
'eval_map_fn': model_cls.make_map_fn('eval', **vars(model_args))}, unknown
'eval_map_fn': model_cls.make_map_fn('eval', **model_args)}, unknown

0 comments on commit d0cbbb2

Please sign in to comment.