-
Notifications
You must be signed in to change notification settings - Fork 129
/
library_support.py
362 lines (314 loc) · 13.8 KB
/
library_support.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
# Part of the Concrete Compiler Project, under the BSD3 License with Zama Exceptions.
# See https://github.com/zama-ai/concrete/blob/main/LICENSE.txt for license information.
"""LibrarySupport.
Library support provides a way to compile an MLIR program into a library that can be later loaded
to execute the compiled code.
"""
import os
from typing import Optional, Union
# pylint: disable=no-name-in-module,import-error
from mlir._mlir_libs._concretelang._compiler import (
LibrarySupport as _LibrarySupport,
)
from mlir.ir import Module as MlirModule
# pylint: enable=no-name-in-module,import-error
from .compilation_options import CompilationOptions
from .compilation_context import CompilationContext
from .library_compilation_result import LibraryCompilationResult
from .public_arguments import PublicArguments
from .library_lambda import LibraryLambda
from .public_result import PublicResult
from .client_parameters import ClientParameters
from .compilation_feedback import ProgramCompilationFeedback
from .wrapper import WrapperCpp
from .utils import lookup_runtime_lib
from .evaluation_keys import EvaluationKeys
# Directory used for compilation artifacts when the caller does not pick one.
# The relative name is resolved against the working directory at import time.
DEFAULT_OUTPUT_PATH = os.path.abspath("concrete-compiler_compilation_artifacts")
class LibrarySupport(WrapperCpp):
    """Support class for library compilation and execution.

    Python-side wrapper around the native ``_LibrarySupport`` object. Every
    method validates the types of its arguments, forwards the call to the
    wrapped object (``self.cpp()``) and wraps the native result back into the
    matching Python wrapper class.
    """

    def __init__(self, library_support: _LibrarySupport):
        """Wrap the native Cpp object.

        Args:
            library_support (_LibrarySupport): object to wrap

        Raises:
            TypeError: if library_support is not of type _LibrarySupport
        """
        if not isinstance(library_support, _LibrarySupport):
            raise TypeError(
                f"library_support must be of type _LibrarySupport, not {type(library_support)}"
            )
        super().__init__(library_support)
        # Start from the default location; `new` overrides it with the path
        # actually handed to the native object.
        self.output_dir_path = DEFAULT_OUTPUT_PATH

    @property
    def output_dir_path(self) -> str:
        """Path where to store compilation artifacts."""
        return self._output_dir_path

    @output_dir_path.setter
    def output_dir_path(self, path: str):
        """Set the artifacts path.

        Raises:
            TypeError: if path is not of type str
        """
        if not isinstance(path, str):
            raise TypeError(f"path must be of type str, not {type(path)}")
        self._output_dir_path = path

    @staticmethod
    # pylint: disable=arguments-differ
    def new(
        output_path: str = DEFAULT_OUTPUT_PATH,
        runtime_library_path: Optional[str] = None,
        generateSharedLib: bool = True,
        generateStaticLib: bool = False,
        generateClientParameters: bool = True,
        generateCompilationFeedback: bool = True,
        generateCppHeader: bool = False,
    ) -> "LibrarySupport":
        """Build a LibrarySupport.

        The output directory is created if it does not exist yet.

        Args:
            output_path (str, optional): path where to store compilation artifacts.
                Defaults to DEFAULT_OUTPUT_PATH.
            runtime_library_path (Optional[str], optional): path to the runtime library.
                Defaults to None, in which case the result of lookup_runtime_lib() is used.
            generateSharedLib (bool): whether to emit shared library or not. Default to True.
            generateStaticLib (bool): whether to emit static library or not. Default to False.
            generateClientParameters (bool): whether to emit client parameters or not.
                Default to True.
            generateCompilationFeedback (bool): whether to emit compilation feedback or not.
                Default to True.
            generateCppHeader (bool): whether to emit cpp header or not. Default to False.

        Raises:
            TypeError: if output_path is not of type str
            TypeError: if runtime_library_path is not of type str
            TypeError: if one of the generation flags is not of type bool

        Returns:
            LibrarySupport
        """
        if runtime_library_path is None:
            runtime_library_path = lookup_runtime_lib()
        if not isinstance(output_path, str):
            raise TypeError(f"output_path must be of type str, not {type(output_path)}")
        if not isinstance(runtime_library_path, str):
            raise TypeError(
                f"runtime_library_path must be of type str, not {type(runtime_library_path)}"
            )
        generation_flags = (
            ("generateSharedLib", generateSharedLib),
            ("generateStaticLib", generateStaticLib),
            ("generateClientParameters", generateClientParameters),
            ("generateCompilationFeedback", generateCompilationFeedback),
            ("generateCppHeader", generateCppHeader),
        )
        for name, value in generation_flags:
            if not isinstance(value, bool):
                raise TypeError(f"{name} must be of type bool, not {type(value)}")
        library_support = LibrarySupport.wrap(
            _LibrarySupport(
                output_path,
                runtime_library_path,
                generateSharedLib,
                generateStaticLib,
                generateClientParameters,
                generateCompilationFeedback,
                generateCppHeader,
            )
        )
        # exist_ok avoids the check-then-create race of a separate isdir()
        # test; it still raises if the path exists but is not a directory.
        os.makedirs(output_path, exist_ok=True)
        library_support.output_dir_path = output_path
        return library_support

    def compile(
        self,
        mlir_program: Union[str, MlirModule],
        options: Optional[CompilationOptions] = None,
        compilation_context: Optional[CompilationContext] = None,
    ) -> LibraryCompilationResult:
        """Compile an MLIR program using Concrete dialects into a library.

        Args:
            mlir_program (Union[str, MlirModule]): mlir program to compile (textual or in-memory)
            options (Optional[CompilationOptions]): compilation options. Defaults to None,
                in which case a fresh CompilationOptions.new() is created for this call.
            compilation_context (Optional[CompilationContext]): compilation context; required
                when mlir_program is an MlirModule, unused when it is a str.

        Raises:
            TypeError: if mlir_program is not of type str or MlirModule
            TypeError: if options is not of type CompilationOptions
            ValueError: if mlir_program is an MlirModule and compilation_context is None
            TypeError: if mlir_program is an MlirModule and compilation_context is not of
                type CompilationContext

        Returns:
            LibraryCompilationResult: the result of the library compilation
        """
        if not isinstance(mlir_program, (str, MlirModule)):
            raise TypeError(
                f"mlir_program must be of type str or MlirModule, not {type(mlir_program)}"
            )
        # Build the default per call: a default evaluated in the signature
        # would be created once at import time and shared by every caller.
        if options is None:
            options = CompilationOptions.new()
        if not isinstance(options, CompilationOptions):
            raise TypeError(
                f"options must be of type CompilationOptions, not {type(options)}"
            )
        if not isinstance(mlir_program, MlirModule):
            # Textual program: no compilation context is passed.
            return LibraryCompilationResult.wrap(
                self.cpp().compile(mlir_program, options.cpp())
            )
        # In-memory module: the native call receives the module's capsule
        # pointer together with the compilation context.
        if compilation_context is None:
            raise ValueError(
                "compilation_context must be provided when compiling a module object"
            )
        if not isinstance(compilation_context, CompilationContext):
            raise TypeError(
                f"compilation_context must be of type CompilationContext, not "
                f"{type(compilation_context)}"
            )
        # pylint: disable=protected-access
        return LibraryCompilationResult.wrap(
            self.cpp().compile(
                mlir_program._CAPIPtr, options.cpp(), compilation_context.cpp()
            )
        )
        # pylint: enable=protected-access

    def reload(self) -> LibraryCompilationResult:
        """Reload the library compilation result from the output_dir_path.

        Returns:
            LibraryCompilationResult: loaded library
        """
        return LibraryCompilationResult.new(self.output_dir_path)

    def load_client_parameters(
        self, library_compilation_result: LibraryCompilationResult
    ) -> ClientParameters:
        """Load the client parameters from the library compilation result.

        Args:
            library_compilation_result (LibraryCompilationResult): compilation result of the library

        Raises:
            TypeError: if library_compilation_result is not of type LibraryCompilationResult

        Returns:
            ClientParameters: appropriate client parameters for the compiled library
        """
        if not isinstance(library_compilation_result, LibraryCompilationResult):
            raise TypeError(
                f"library_compilation_result must be of type LibraryCompilationResult, not "
                f"{type(library_compilation_result)}"
            )
        return ClientParameters.wrap(
            self.cpp().load_client_parameters(library_compilation_result.cpp())
        )

    def load_compilation_feedback(
        self, compilation_result: LibraryCompilationResult
    ) -> ProgramCompilationFeedback:
        """Load the compilation feedback from the compilation result.

        Args:
            compilation_result (LibraryCompilationResult): result of the compilation

        Raises:
            TypeError: if compilation_result is not of type LibraryCompilationResult

        Returns:
            ProgramCompilationFeedback: the compilation feedback for the compiled program
        """
        if not isinstance(compilation_result, LibraryCompilationResult):
            raise TypeError(
                f"compilation_result must be of type LibraryCompilationResult, not {type(compilation_result)}"
            )
        return ProgramCompilationFeedback.wrap(
            self.cpp().load_compilation_feedback(compilation_result.cpp())
        )

    def load_server_lambda(
        self,
        library_compilation_result: LibraryCompilationResult,
        simulation: bool,
        circuit_name: str = "main",
    ) -> LibraryLambda:
        """Load the server lambda for a given circuit from the library compilation result.

        Args:
            library_compilation_result (LibraryCompilationResult): compilation result of the library
            simulation (bool): use simulation for execution
            circuit_name (str): name of the circuit to be loaded. Defaults to "main".

        Raises:
            TypeError: if library_compilation_result is not of type LibraryCompilationResult
            TypeError: if circuit_name is not of type str
            TypeError: if simulation is not of type bool

        Returns:
            LibraryLambda: executable reference to the library
        """
        if not isinstance(library_compilation_result, LibraryCompilationResult):
            raise TypeError(
                f"library_compilation_result must be of type LibraryCompilationResult, not "
                f"{type(library_compilation_result)}"
            )
        if not isinstance(circuit_name, str):
            raise TypeError(
                f"circuit_name must be of type str, not {type(circuit_name)}"
            )
        if not isinstance(simulation, bool):
            raise TypeError(
                f"simulation must be of type bool, not {type(simulation)}"
            )
        # The native call takes the circuit name before the simulation flag,
        # the reverse of this method's parameter order.
        return LibraryLambda.wrap(
            self.cpp().load_server_lambda(
                library_compilation_result.cpp(), circuit_name, simulation
            )
        )

    def server_call(
        self,
        library_lambda: LibraryLambda,
        public_arguments: PublicArguments,
        evaluation_keys: EvaluationKeys,
    ) -> PublicResult:
        """Call the library with public_arguments.

        Args:
            library_lambda (LibraryLambda): reference to the compiled library
            public_arguments (PublicArguments): arguments to use for execution
            evaluation_keys (EvaluationKeys): evaluation keys to use for execution

        Raises:
            TypeError: if library_lambda is not of type LibraryLambda
            TypeError: if public_arguments is not of type PublicArguments
            TypeError: if evaluation_keys is not of type EvaluationKeys

        Returns:
            PublicResult: result of the execution
        """
        if not isinstance(library_lambda, LibraryLambda):
            raise TypeError(
                f"library_lambda must be of type LibraryLambda, not {type(library_lambda)}"
            )
        if not isinstance(public_arguments, PublicArguments):
            raise TypeError(
                f"public_arguments must be of type PublicArguments, not {type(public_arguments)}"
            )
        if not isinstance(evaluation_keys, EvaluationKeys):
            raise TypeError(
                f"evaluation_keys must be of type EvaluationKeys, not {type(evaluation_keys)}"
            )
        return PublicResult.wrap(
            self.cpp().server_call(
                library_lambda.cpp(),
                public_arguments.cpp(),
                evaluation_keys.cpp(),
            )
        )

    def simulate(
        self,
        library_lambda: LibraryLambda,
        public_arguments: PublicArguments,
    ) -> PublicResult:
        """Call the library with public_arguments in simulation mode.

        Args:
            library_lambda (LibraryLambda): reference to the compiled library
            public_arguments (PublicArguments): arguments to use for execution

        Raises:
            TypeError: if library_lambda is not of type LibraryLambda
            TypeError: if public_arguments is not of type PublicArguments

        Returns:
            PublicResult: result of the execution
        """
        if not isinstance(library_lambda, LibraryLambda):
            raise TypeError(
                f"library_lambda must be of type LibraryLambda, not {type(library_lambda)}"
            )
        if not isinstance(public_arguments, PublicArguments):
            raise TypeError(
                f"public_arguments must be of type PublicArguments, not {type(public_arguments)}"
            )
        return PublicResult.wrap(
            self.cpp().simulate(
                library_lambda.cpp(),
                public_arguments.cpp(),
            )
        )

    def get_shared_lib_path(self) -> str:
        """Get the path where the shared library is expected to be.

        Returns:
            str: path to the shared library
        """
        return self.cpp().get_shared_lib_path()

    def get_program_info_path(self) -> str:
        """Get the path where the program info file is expected to be.

        Returns:
            str: path to the program info file
        """
        return self.cpp().get_program_info_path()