-
Notifications
You must be signed in to change notification settings - Fork 7.4k
/
image_dataloader.py
119 lines (96 loc) · 4.08 KB
/
image_dataloader.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Image dataloader."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import os
import random
import tensorflow as tf
import tensorflow_datasets as tfds
from tensorflow_examples.lite.model_maker.core.api import mm_export
from tensorflow_examples.lite.model_maker.core.data_util import dataloader
def load_image(path):
  """Reads an image file and decodes it into a 3-channel image tensor.

  The raw file contents are decoded as JPEG when `tf.image.is_jpeg` reports a
  JPEG payload, and as PNG otherwise.

  Args:
    path: Path of the image file, in a form accepted by `tf.io.read_file`.

  Returns:
    The decoded image tensor, decoded with `channels=3`.
  """
  contents = tf.io.read_file(path)

  def _decode_as_jpeg():
    return tf.image.decode_jpeg(contents, channels=3)

  def _decode_as_png():
    return tf.image.decode_png(contents, channels=3)

  # The format is only known at graph execution time, so the branch has to be
  # selected with tf.cond rather than a Python `if`.
  return tf.cond(tf.image.is_jpeg(contents), _decode_as_jpeg, _decode_as_png)
def create_data(name, data, info, label_names):
  """Builds an ImageClassifierDataLoader for one split of a tfds dataset.

  Args:
    name: Split name, used as the key into both `data` and `info.splits`.
    data: Mapping from split name to a dataset whose elements expose 'image'
      and 'label' entries.
    info: Dataset info object; `info.splits[name].num_examples` provides the
      size of the split.
    label_names: Label names forwarded unchanged to the data loader.

  Returns:
    An ImageClassifierDataLoader over (image, label) pairs of the requested
    split, or None when `name` is not present in `data`.
  """
  if name in data:
    # Reduce every example dict to an (image, label) tuple.
    pairs_ds = data[name].map(
        lambda example: (example['image'], example['label']))
    num_examples = info.splits[name].num_examples
    return ImageClassifierDataLoader(pairs_ds, num_examples, label_names)
  return None
@mm_export('image_classifier.DataLoader')
class ImageClassifierDataLoader(dataloader.ClassificationDataLoader):
  """DataLoader for image classifier."""

  @classmethod
  def from_folder(cls, filename, shuffle=True):
    """Loads labeled images from a folder with one subdirectory per label.

    Assumes the image data of the same label are in the same subdirectory.
    The names of the subdirectories directly under the root folder are used
    as label names, and a label's index is its position in the sorted list
    of those names.

    Args:
      filename: Path of the root folder holding the label subdirectories
        (despite the parameter name, this is a directory, not a single file).
      shuffle: boolean, if True the image paths are randomly shuffled once
        before the dataset is built.

    Returns:
      An instance of `cls` holding the (image, label) dataset, the number of
      images and the sorted label names.

    Raises:
      ValueError: If nothing matches `<filename>/*/*`.
    """
    data_root = os.path.abspath(filename)

    # Assumes the image data of the same label are in the same subdirectory,
    # gets image path and label names.
    all_image_paths = list(tf.io.gfile.glob(data_root + r'/*/*'))
    if not all_image_paths:
      raise ValueError('Image size is zero')
    all_image_size = len(all_image_paths)

    if shuffle:
      # Shuffle the paths before labels are derived from them, so images and
      # labels stay aligned.
      random.shuffle(all_image_paths)

    label_names = sorted(
        name for name in os.listdir(data_root)
        if os.path.isdir(os.path.join(data_root, name)))
    all_label_size = len(label_names)
    label_to_index = {name: index for index, name in enumerate(label_names)}
    # The label of an image is the name of the directory that contains it.
    all_image_labels = [
        label_to_index[os.path.basename(os.path.dirname(path))]
        for path in all_image_paths
    ]

    path_ds = tf.data.Dataset.from_tensor_slices(all_image_paths)
    autotune = tf.data.AUTOTUNE
    image_ds = path_ds.map(load_image, num_parallel_calls=autotune)

    # Loads label.
    label_ds = tf.data.Dataset.from_tensor_slices(
        tf.cast(all_image_labels, tf.int64))

    # Creates a dataset of (image, label) pairs.
    image_label_ds = tf.data.Dataset.zip((image_ds, label_ds))

    tf.compat.v1.logging.info(
        'Load image with size: %d, num_label: %d, labels: %s.', all_image_size,
        all_label_size, ', '.join(label_names))
    # Instantiate through `cls` so that a subclass calling this alternate
    # constructor gets an instance of itself rather than of the base class.
    return cls(image_label_ds, all_image_size, label_names)

  @classmethod
  def from_tfds(cls, name):
    """Loads data from tensorflow_datasets.

    Args:
      name: Dataset name passed to `tfds.load`.

    Returns:
      A (train_data, validation_data, test_data) tuple. Each element is the
      data loader built by `create_data` for the 'train', 'validation' and
      'test' split respectively, or None when the dataset lacks that split.

    Raises:
      ValueError: If the dataset info has no 'label' feature.
    """
    data, info = tfds.load(name, with_info=True)
    if 'label' not in info.features:
      raise ValueError('info.features need to contain \'label\' key.')
    label_names = info.features['label'].names

    # NOTE: `create_data` instantiates ImageClassifierDataLoader directly, so
    # the loaders returned here are not instances of a subclass `cls`.
    train_data = create_data('train', data, info, label_names)
    validation_data = create_data('validation', data, info, label_names)
    test_data = create_data('test', data, info, label_names)
    return train_data, validation_data, test_data