/
char_recognition_tasks.py
220 lines (192 loc) · 8.82 KB
/
char_recognition_tasks.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
# Copyright 2021, The TensorFlow Federated Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Library for creating character recognition tasks on EMNIST."""
import enum
from typing import Optional, Union
import tensorflow as tf
from tensorflow_federated.python.learning.models import keras_utils
from tensorflow_federated.python.learning.models import variable
from tensorflow_federated.python.simulation.baselines import baseline_task
from tensorflow_federated.python.simulation.baselines import client_spec
from tensorflow_federated.python.simulation.baselines import task_data
from tensorflow_federated.python.simulation.baselines.emnist import emnist_models
from tensorflow_federated.python.simulation.baselines.emnist import emnist_preprocessing
from tensorflow_federated.python.simulation.datasets import client_data
from tensorflow_federated.python.simulation.datasets import emnist
class CharacterRecognitionModel(enum.Enum):
"""Enum for EMNIST character recognition models."""
CNN_DROPOUT = 'cnn_dropout'
CNN = 'cnn'
TWO_LAYER_DNN = '2nn'
_CHARACTER_RECOGNITION_MODELS = [e.value for e in CharacterRecognitionModel]
def _get_character_recognition_model(
    model_id: Union[str, CharacterRecognitionModel],
    only_digits: bool,
    debug_seed: Optional[int] = None,
) -> tf.keras.Model:
  """Constructs a `tf.keras.Model` for character recognition.

  Args:
    model_id: A string identifier or `CharacterRecognitionModel` enum member
      selecting the architecture. Must be one of 'cnn_dropout', 'cnn', or
      '2nn'.
    only_digits: A boolean forwarded to the model constructors selecting the
      digit-only (`True`) or full-character (`False`) variant of the model.
    debug_seed: An optional integer seed forwarded to the model constructors
      to force deterministic model initialization. Intended for unittesting.

  Returns:
    An uncompiled `tf.keras.Model`.

  Raises:
    ValueError: If `model_id` does not correspond to a recognized model.
  """
  try:
    model_enum = CharacterRecognitionModel(model_id)
  except ValueError as e:
    raise ValueError(
        'The model argument must be one of {}, found {}'.format(
            _CHARACTER_RECOGNITION_MODELS, model_id
        )
    ) from e
  # Dispatch on the validated enum member. The enum constructor above already
  # rejected unknown ids, so every member is guaranteed to have an entry here;
  # this replaces a previously unreachable `else: raise ValueError` branch.
  model_builders = {
      CharacterRecognitionModel.CNN_DROPOUT: (
          emnist_models.create_conv_dropout_model
      ),
      CharacterRecognitionModel.CNN: (
          emnist_models.create_original_fedavg_cnn_model
      ),
      CharacterRecognitionModel.TWO_LAYER_DNN: (
          emnist_models.create_two_hidden_layer_model
      ),
  }
  return model_builders[model_enum](
      only_digits=only_digits, debug_seed=debug_seed
  )
def create_character_recognition_task_from_datasets(
    train_client_spec: client_spec.ClientSpec,
    eval_client_spec: Optional[client_spec.ClientSpec],
    model_id: Union[str, CharacterRecognitionModel],
    only_digits: bool,
    train_data: client_data.ClientData,
    test_data: client_data.ClientData,
    debug_seed: Optional[int] = None,
) -> baseline_task.BaselineTask:
  """Creates a baseline task for character recognition on EMNIST.

  Args:
    train_client_spec: A `tff.simulation.baselines.ClientSpec` specifying how
      to preprocess train client data.
    eval_client_spec: An optional `tff.simulation.baselines.ClientSpec`
      specifying how to preprocess evaluation client data. If set to `None`,
      the evaluation datasets will use a batch size of 64 with no extra
      preprocessing.
    model_id: A string identifier for a character recognition model. Must be
      one of 'cnn_dropout', 'cnn', or '2nn'. These correspond respectively to
      a CNN model with dropout, a CNN model with no dropout, and a densely
      connected network with two hidden layers of width 200.
    only_digits: A boolean indicating whether to use the smaller EMNIST-10
      dataset with only 10 numeric classes (`True`) or the full EMNIST-62
      dataset containing 62 alphanumeric classes (`False`).
    train_data: A `tff.simulation.datasets.ClientData` used for training.
    test_data: A `tff.simulation.datasets.ClientData` used for testing.
    debug_seed: An optional integer seed to force deterministic model
      initialization and dataset shuffle buffers. This is intended for
      unittesting.

  Returns:
    A `tff.simulation.baselines.BaselineTask`.
  """
  emnist_task = 'character_recognition'

  if eval_client_spec is None:
    # Default evaluation spec: one pass, batch size 64, no shuffling.
    eval_client_spec = client_spec.ClientSpec(
        num_epochs=1, batch_size=64, shuffle_buffer_size=1
    )

  train_preprocess_fn = emnist_preprocessing.create_preprocess_fn(
      train_client_spec, emnist_task=emnist_task, debug_seed=debug_seed
  )
  eval_preprocess_fn = emnist_preprocessing.create_preprocess_fn(
      eval_client_spec, emnist_task=emnist_task, debug_seed=debug_seed
  )

  task_datasets = task_data.BaselineTaskDatasets(
      train_data=train_data,
      test_data=test_data,
      validation_data=None,
      train_preprocess_fn=train_preprocess_fn,
      eval_preprocess_fn=eval_preprocess_fn,
  )

  def model_fn() -> variable.VariableModel:
    # Defined as a zero-arg callable so that each invocation builds a freshly
    # initialized Keras model, as required by TFF learning processes.
    return keras_utils.from_keras_model(
        keras_model=_get_character_recognition_model(
            model_id, only_digits, debug_seed
        ),
        loss=tf.keras.losses.SparseCategoricalCrossentropy(),
        input_spec=task_datasets.element_type_structure,
        metrics=[tf.keras.metrics.SparseCategoricalAccuracy()],
    )

  return baseline_task.BaselineTask(task_datasets, model_fn)
def create_character_recognition_task(
    train_client_spec: client_spec.ClientSpec,
    eval_client_spec: Optional[client_spec.ClientSpec] = None,
    model_id: Union[str, CharacterRecognitionModel] = 'cnn_dropout',
    only_digits: bool = False,
    cache_dir: Optional[str] = None,
    use_synthetic_data: bool = False,
    debug_seed: Optional[int] = None,
) -> baseline_task.BaselineTask:
  """Creates a baseline task for character recognition on EMNIST.

  The goal of the task is to minimize the sparse categorical crossentropy
  between the output labels of the model and the true label of the image. When
  `only_digits = True`, there are 10 possible labels (the digits 0-9), while
  when `only_digits = False`, there are 62 possible labels (both numbers and
  letters).

  This classification can be done using a number of different models, specified
  using the `model_id` argument. Below we give a list of the different models
  that can be used:

  * `model_id = cnn_dropout`: A moderately sized convolutional network. Uses
    two convolutional layers, a max pooling layer, and dropout, followed by two
    dense layers.
  * `model_id = cnn`: A moderately sized convolutional network, without any
    dropout layers. Matches the architecture of the convolutional network used
    by (McMahan et al., 2017) for the purposes of testing the FedAvg algorithm.
  * `model_id = 2nn`: A densely connected network with 2 hidden layers, each
    with 200 hidden units and ReLU activations.

  Args:
    train_client_spec: A `tff.simulation.baselines.ClientSpec` specifying how
      to preprocess train client data.
    eval_client_spec: An optional `tff.simulation.baselines.ClientSpec`
      specifying how to preprocess evaluation client data. If set to `None`,
      the evaluation datasets will use a batch size of 64 with no extra
      preprocessing.
    model_id: A string identifier for a character recognition model. Must be
      one of 'cnn_dropout', 'cnn', or '2nn'. These correspond respectively to
      a CNN model with dropout, a CNN model with no dropout, and a densely
      connected network with two hidden layers of width 200.
    only_digits: A boolean indicating whether to use the smaller EMNIST-10
      dataset with only 10 numeric classes (`True`) or the full EMNIST-62
      dataset containing 62 alphanumeric classes (`False`).
    cache_dir: An optional directory to cache the downloaded datasets. If
      `None`, they will be cached to `~/.tff/`.
    use_synthetic_data: A boolean indicating whether to use synthetic EMNIST
      data. This option should only be used for testing purposes, in order to
      avoid downloading the entire EMNIST dataset.
    debug_seed: An optional integer seed to force deterministic model
      initialization. This is intended for unittesting.

  Returns:
    A `tff.simulation.baselines.BaselineTask`.
  """
  if use_synthetic_data:
    # The synthetic dataset is tiny, so the same data doubles as both the
    # train and test split; this path exists only for testing.
    synthetic_data = emnist.get_synthetic()
    emnist_train, emnist_test = synthetic_data, synthetic_data
  else:
    emnist_train, emnist_test = emnist.load_data(
        only_digits=only_digits, cache_dir=cache_dir
    )

  return create_character_recognition_task_from_datasets(
      train_client_spec,
      eval_client_spec,
      model_id,
      only_digits,
      emnist_train,
      emnist_test,
      debug_seed,
  )