/
client_support.py
278 lines (237 loc) · 10.3 KB
/
client_support.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
# Part of the Concrete Compiler Project, under the BSD3 License with Zama Exceptions.
# See https://github.com/zama-ai/concrete-compiler-internal/blob/main/LICENSE.txt for license information.
"""Client support."""
from typing import List, Optional, Union
import numpy as np
# pylint: disable=no-name-in-module,import-error
from mlir._mlir_libs._concretelang._compiler import ClientSupport as _ClientSupport
# pylint: enable=no-name-in-module,import-error
from .public_result import PublicResult
from .key_set import KeySet
from .key_set_cache import KeySetCache
from .client_parameters import ClientParameters
from .public_arguments import PublicArguments
from .lambda_argument import LambdaArgument
from .wrapper import WrapperCpp
from .utils import ACCEPTED_INTS, ACCEPTED_NUMPY_UINTS, ACCEPTED_TYPES
class ClientSupport(WrapperCpp):
    """Client interface for doing key generation and encryption.

    It provides features that are needed on the client side:
    - Generation of public and private keys required for the encrypted computation
    - Encryption and preparation of public arguments, used later as input to the computation
    - Decryption of public result returned after execution
    """

    def __init__(self, client_support: _ClientSupport):
        """Wrap the native Cpp object.

        Args:
            client_support (_ClientSupport): object to wrap

        Raises:
            TypeError: if client_support is not of type _ClientSupport
        """
        if not isinstance(client_support, _ClientSupport):
            raise TypeError(
                f"client_support must be of type _ClientSupport, not {type(client_support)}"
            )
        super().__init__(client_support)

    @staticmethod
    # pylint: disable=arguments-differ
    def new() -> "ClientSupport":
        """Build a ClientSupport.

        Returns:
            ClientSupport: a wrapper around a freshly constructed native object
        """
        return ClientSupport.wrap(_ClientSupport())

    # pylint: enable=arguments-differ

    @staticmethod
    def key_set(
        client_parameters: ClientParameters,
        keyset_cache: Optional[KeySetCache] = None,
        seed_msb: int = 0,
        seed_lsb: int = 0,
    ) -> KeySet:
        """Generate a key set according to the client parameters.

        If the cache is set, and include equivalent keys as specified by the client parameters,
        the keyset is loaded, otherwise, a new keyset is generated and saved in the cache.

        Args:
            client_parameters (ClientParameters): client parameters specification
            keyset_cache (Optional[KeySetCache], optional): keyset cache. Defaults to None.
            seed_msb (int): most significant 64 bits of the seed
            seed_lsb (int): least significant 64 bits of the seed

        Raises:
            TypeError: if client_parameters is not of type ClientParameters
            TypeError: if keyset_cache is not of type KeySetCache
            AssertionError: if seed components is not uint64

        Returns:
            KeySet: generated or loaded keyset
        """
        # Each seed half must fit in an unsigned 64-bit integer.
        assert 0 <= seed_msb < 2**64, "seed_msb must be in the range [0, 2**64 - 1]"
        assert 0 <= seed_lsb < 2**64, "seed_lsb must be in the range [0, 2**64 - 1]"
        if not isinstance(client_parameters, ClientParameters):
            raise TypeError(
                f"client_parameters must be of type ClientParameters, not {type(client_parameters)}"
            )
        if keyset_cache is not None and not isinstance(keyset_cache, KeySetCache):
            raise TypeError(
                f"keyset_cache must be None or of type KeySetCache, not {type(keyset_cache)}"
            )
        cpp_cache = None if keyset_cache is None else keyset_cache.cpp()
        return KeySet.wrap(
            _ClientSupport.key_set(
                client_parameters.cpp(),
                cpp_cache,
                seed_msb,
                seed_lsb,
            ),
        )

    @staticmethod
    def encrypt_arguments(
        client_parameters: ClientParameters,
        keyset: KeySet,
        args: List[Union[int, np.ndarray]],
    ) -> PublicArguments:
        """Prepare arguments for encrypted computation.

        Pack public arguments by encrypting the ones that requires encryption, and leaving the rest as plain.
        It also pack public materials (public keys) that are required during the computation.

        Args:
            client_parameters (ClientParameters): client parameters specification
            keyset (KeySet): keyset used to encrypt arguments that require encryption
            args (List[Union[int, np.ndarray]]): list of scalar or tensor arguments

        Raises:
            TypeError: if client_parameters is not of type ClientParameters
            TypeError: if keyset is not of type KeySet
            TypeError: if an argument has an unsupported type, dtype or range
            RuntimeError: if the number of arguments doesn't match the function arity

        Returns:
            PublicArguments: public arguments for execution
        """
        if not isinstance(client_parameters, ClientParameters):
            raise TypeError(
                f"client_parameters must be of type ClientParameters, not {type(client_parameters)}"
            )
        if not isinstance(keyset, KeySet):
            raise TypeError(f"keyset must be of type KeySet, not {type(keyset)}")

        # One signedness flag per expected input; its length is the function arity.
        signs = client_parameters.input_signs()
        if len(args) != len(signs):
            problem = "too many" if len(args) > len(signs) else "too few"
            raise RuntimeError(
                f"function has arity {len(signs)} but is applied to {problem} arguments"
            )

        lambda_arguments = [
            ClientSupport._create_lambda_argument(arg, signed)
            for arg, signed in zip(args, signs)
        ]
        return PublicArguments.wrap(
            _ClientSupport.encrypt_arguments(
                client_parameters.cpp(),
                keyset.cpp(),
                [arg.cpp() for arg in lambda_arguments],
            )
        )

    @staticmethod
    def decrypt_result(
        client_parameters: ClientParameters,
        keyset: KeySet,
        public_result: PublicResult,
    ) -> Union[int, np.ndarray]:
        """Decrypt a public result using the keyset.

        Args:
            client_parameters (ClientParameters): client parameters for decryption
            keyset (KeySet): keyset used for decryption
            public_result (PublicResult): public result to decrypt

        Raises:
            TypeError: if client_parameters is not of type ClientParameters
            TypeError: if keyset is not of type KeySet
            TypeError: if public_result is not of type PublicResult
            AssertionError: if the function does not have exactly one output
            RuntimeError: if the result is of an unknown type

        Returns:
            Union[int, np.ndarray]: plain result; tensors are returned as int64
            arrays when signed and uint64 arrays otherwise
        """
        if not isinstance(client_parameters, ClientParameters):
            raise TypeError(
                f"client_parameters must be of type ClientParameters, not {type(client_parameters)}"
            )
        if not isinstance(keyset, KeySet):
            raise TypeError(f"keyset must be of type KeySet, not {type(keyset)}")
        if not isinstance(public_result, PublicResult):
            raise TypeError(
                f"public_result must be of type PublicResult, not {type(public_result)}"
            )

        lambda_arg = LambdaArgument.wrap(
            _ClientSupport.decrypt_result(keyset.cpp(), public_result.cpp())
        )

        # This decryption path only handles single-output functions.
        output_signs = client_parameters.output_signs()
        assert (
            len(output_signs) == 1
        ), f"expected a single output, got {len(output_signs)}"

        is_signed = lambda_arg.is_signed()
        if lambda_arg.is_scalar():
            return (
                lambda_arg.get_signed_scalar() if is_signed else lambda_arg.get_scalar()
            )
        if lambda_arg.is_tensor():
            data = (
                lambda_arg.get_signed_tensor_data()
                if is_signed
                else lambda_arg.get_tensor_data()
            )
            dtype = np.int64 if is_signed else np.uint64
            return np.array(data, dtype=dtype).reshape(lambda_arg.get_tensor_shape())
        raise RuntimeError("unknown return type")

    @staticmethod
    def _create_lambda_argument(
        value: Union[int, np.ndarray], signed: bool
    ) -> LambdaArgument:
        """Create a lambda argument holding either an int or tensor value.

        Args:
            value (Union[int, numpy.array]): value of the argument, either an int, or a numpy array
            signed (bool): whether a scalar value is built with the signed constructor

        Raises:
            TypeError: if the values aren't in the expected range, or using a wrong type

        Returns:
            LambdaArgument: lambda argument holding the appropriate value
        """
        # pylint: disable=too-many-return-statements,too-many-branches
        if not isinstance(value, ACCEPTED_TYPES):
            raise TypeError(
                "value of lambda argument must be either int, numpy.array or numpy.(u)int{8,16,32,64}"
            )

        if isinstance(value, ACCEPTED_INTS):
            # Python ints are unbounded; both ends of the advertised range are inclusive.
            if isinstance(value, int) and not (
                np.iinfo(np.int64).min <= value <= np.iinfo(np.uint64).max
            ):
                raise TypeError(
                    "single integer must be in the range [-2**63, 2**64 - 1]"
                )
            if signed:
                return LambdaArgument.from_signed_scalar(value)
            return LambdaArgument.from_scalar(value)

        assert isinstance(value, np.ndarray)
        if value.dtype not in ACCEPTED_NUMPY_UINTS:
            raise TypeError("numpy.array must be of dtype (u)int{8,16,32,64}")

        if value.shape == ():
            # 0-d array: extract its single element and pass it as a scalar.
            scalar = value.max()
            if signed:
                return LambdaArgument.from_signed_scalar(scalar)
            return LambdaArgument.from_scalar(scalar)

        # Tensors are passed flattened together with their shape; the native
        # constructor is chosen by element dtype.
        tensor_factories = (
            (np.uint8, LambdaArgument.from_tensor_u8),
            (np.uint16, LambdaArgument.from_tensor_u16),
            (np.uint32, LambdaArgument.from_tensor_u32),
            (np.uint64, LambdaArgument.from_tensor_u64),
            (np.int8, LambdaArgument.from_tensor_i8),
            (np.int16, LambdaArgument.from_tensor_i16),
            (np.int32, LambdaArgument.from_tensor_i32),
            (np.int64, LambdaArgument.from_tensor_i64),
        )
        for dtype, factory in tensor_factories:
            if value.dtype == dtype:
                return factory(value.flatten().tolist(), list(value.shape))
        raise TypeError("numpy.array must be of dtype (u)int{8,16,32,64}")