-
Notifications
You must be signed in to change notification settings - Fork 583
/
transforming_client_data.py
136 lines (113 loc) · 5.6 KB
/
transforming_client_data.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
# Copyright 2019, The TensorFlow Federated Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Expands ClientData by performing transformations."""
import bisect
import re
import tensorflow as tf
from tensorflow_federated.python.common_libs import py_typecheck
from tensorflow_federated.python.simulation import client_data
# Parses a pseudo-client id of the form '<raw_client_id>_<index>'. The greedy
# `(.*)` splits at the last underscore, so raw client ids may themselves contain
# underscores. `re.DOTALL` lets the raw id contain newlines, and `\Z` (unlike
# `$`) rejects a trailing newline after the index.
CLIENT_ID_REGEX = re.compile(r'^(.*)_(\d+)\Z', re.DOTALL)


def split_client_id(client_id):
  """Splits pseudo-client id into raw client id and index components.

  The id is split at its last underscore; everything after that underscore
  must be a non-empty run of digits.

  Args:
    client_id: The pseudo-client id, for example 'client_a_03'. Must be a
      `str` (checked with `py_typecheck.check_type`).

  Returns:
    A tuple (raw_client_id, index) where raw_client_id is the string of the raw
    client_id, and index is the integer index of the pseudo-client.

  Raises:
    ValueError: If `client_id` does not have the form '<raw>_<digits>'.
  """
  py_typecheck.check_type(client_id, str)
  match = CLIENT_ID_REGEX.match(client_id)
  if match is None:
    raise ValueError('client_id must be a valid string from client_ids.')
  raw_client_id, index_str = match.groups()
  return raw_client_id, int(index_str)
class TransformingClientData(client_data.ClientData):
  """Transforms client data, potentially expanding by adding pseudo-clients.

  Each client of the raw_client_data is "expanded" into some number of
  pseudo-clients. Each pseudo-client ID is a string consisting of the original
  client ID, an underscore, and a zero-padded integer index. For example, the
  raw client id "client_a" might be expanded into pseudo-client ids
  "client_a_0", "client_a_1" and "client_a_2". A function fn(x) maps datapoint
  x to a new datapoint, where fn is built by calling
  make_transform_fn(raw_client_id, i). For example, if x is an image, then
  make_transform_fn("client_a", 0)(x) might be the identity, while
  make_transform_fn("client_a", 1)(x) could be a random rotation of the image
  with the angle determined by a hash of "client_a" and "1". Typically by
  convention the index 0 corresponds to the identity function if the identity
  is supported.
  """

  def __init__(self, raw_client_data, make_transform_fn,
               num_transformed_clients):
    """Initializes the TransformingClientData.

    Args:
      raw_client_data: A ClientData to expand. Must contain at least one
        client.
      make_transform_fn: A function that returns a callable that maps datapoint
        x to a new datapoint x'. make_transform_fn will be called as
        make_transform_fn(raw_client_id, i) where i is an integer index, and
        should return a function fn(x)->x'. For example if x is an image, then
        make_transform_fn("client_a", 0)(x) might be the identity, while
        make_transform_fn("client_a", 1)(x) could be a random rotation of the
        image with the angle determined by a hash of "client_a" and "1". If
        make_transform_fn returns `None`, no transformation is performed.
        Typically by convention the index 0 corresponds to the identity
        function if the identity is supported.
      num_transformed_clients: The total number of transformed clients to
        produce. If it is an integer multiple k of the number of real clients,
        there will be exactly k pseudo-clients per real client, with indices
        0...k-1. Any remainder g will be generated from the first g real
        clients and will be given index k.

    Raises:
      ValueError: If `num_transformed_clients` is not positive, or if
        `raw_client_data` has no clients.
    """
    py_typecheck.check_type(raw_client_data, client_data.ClientData)
    py_typecheck.check_callable(make_transform_fn)
    py_typecheck.check_type(num_transformed_clients, int)
    if num_transformed_clients <= 0:
      raise ValueError('num_transformed_clients must be positive and finite.')

    raw_client_ids = raw_client_data.client_ids
    num_raw_clients = len(raw_client_ids)
    if num_raw_clients == 0:
      # Without this check the division below raises ZeroDivisionError.
      raise ValueError('raw_client_data must contain at least one client.')

    self._raw_client_data = raw_client_data
    self._make_transform_fn = make_transform_fn

    # Every generated index is at most num_transformed_clients - 1, so padding
    # to that many digits gives all ids a fixed-width index suffix, and ids of
    # the same raw client sort lexicographically in index order.
    num_digits = len(str(num_transformed_clients - 1))
    format_str = '{}_{:0' + str(num_digits) + '}'

    k = num_transformed_clients // num_raw_clients
    num_extra_client_ids = num_transformed_clients - k * num_raw_clients

    client_ids = []
    for raw_client_id in raw_client_ids:
      client_ids.extend(format_str.format(raw_client_id, i) for i in range(k))
    # The remainder goes, one pseudo-client each, to the first
    # num_extra_client_ids raw clients, using the next index k.
    client_ids.extend(
        format_str.format(raw_client_ids[c], k)
        for c in range(num_extra_client_ids))

    # Sort explicitly: the list above is not sorted in general, even when
    # raw_client_ids is (e.g. raw 'a' < 'a0', but 'a0_0' < 'a_0'), and
    # create_tf_dataset_for_client relies on sorted order for bisection.
    self._client_ids = sorted(client_ids)

  @property
  def client_ids(self):
    """The sorted list of pseudo-client ids.

    NOTE: this is the internal list, not a copy. Callers should not mutate it,
    because create_tf_dataset_for_client binary-searches it.
    """
    return self._client_ids

  def create_tf_dataset_for_client(self, client_id):
    """Creates the dataset for one pseudo-client.

    Args:
      client_id: A string from `client_ids`.

    Returns:
      The raw client's dataset mapped through
      make_transform_fn(raw_client_id, index), or the raw dataset unchanged if
      make_transform_fn returns `None`.

    Raises:
      ValueError: If `client_id` is not one of `client_ids`.
    """
    py_typecheck.check_type(client_id, str)
    # Binary search over the sorted id list: O(log n) membership test.
    i = bisect.bisect_left(self._client_ids, client_id)
    if i == len(self._client_ids) or self._client_ids[i] != client_id:
      raise ValueError('client_id must be a valid string from client_ids.')

    raw_client_id, index = split_client_id(client_id)
    raw_dataset = self._raw_client_data.create_tf_dataset_for_client(
        raw_client_id)
    transform_fn = self._make_transform_fn(raw_client_id, index)
    # Compare against None explicitly (the documented "no transform" value) so
    # a callable that happens to be falsy is still applied.
    if transform_fn is None:
      return raw_dataset
    py_typecheck.check_callable(transform_fn)
    return raw_dataset.map(transform_fn, tf.data.experimental.AUTOTUNE)

  @property
  def element_type_structure(self):
    """The element type structure of the underlying raw client data.

    NOTE(review): this is returned unchanged, so it assumes the transform
    functions preserve the element structure; confirm for transforms that
    reshape or retype elements.
    """
    return self._raw_client_data.element_type_structure

  @property
  def dataset_computation(self):
    """Not implemented for TransformingClientData."""
    # Placeholder tracking-bug id left in the error message.
    raise NotImplementedError('b/XXXXXXXXXXXX')