-
Notifications
You must be signed in to change notification settings - Fork 120
/
server.py
91 lines (70 loc) · 2.53 KB
/
server.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
"""Deployment server.
Routes:
- Get client.zip
- Add a key
- Compute
"""
import io
import os
import uuid
from pathlib import Path
from typing import Dict
import uvicorn
from fastapi import FastAPI, Form, HTTPException, UploadFile
from fastapi.responses import FileResponse, StreamingResponse
# No relative import here because when not used in the package itself
from concrete.ml.deployment import FHEModelServer
if __name__ == "__main__":
app = FastAPI(debug=False)
FILE_FOLDER = Path(__file__).parent
KEY_PATH = Path(os.environ.get("KEY_PATH", FILE_FOLDER / Path("server_keys")))
CLIENT_SERVER_PATH = Path(os.environ.get("PATH_TO_MODEL", FILE_FOLDER / Path("dev")))
PORT = os.environ.get("PORT", "5000")
fhe = FHEModelServer(str(CLIENT_SERVER_PATH.resolve()))
KEYS: Dict[str, bytes] = {}
PATH_TO_CLIENT = (CLIENT_SERVER_PATH / "client.zip").resolve()
PATH_TO_SERVER = (CLIENT_SERVER_PATH / "server.zip").resolve()
assert PATH_TO_CLIENT.exists()
assert PATH_TO_SERVER.exists()
@app.get("/get_client")
def get_client():
"""Get client.
Returns:
FileResponse: client.zip
Raises:
HTTPException: if the file can't be find locally
"""
path_to_client = (CLIENT_SERVER_PATH / "client.zip").resolve()
if not path_to_client.exists():
raise HTTPException(status_code=500, detail="Could not find client.")
return FileResponse(path_to_client, media_type="application/zip")
@app.post("/add_key")
async def add_key(key: UploadFile):
"""Add public key.
Arguments:
key (UploadFile): public key
Returns:
Dict[str, str]
- uid: uid a personal uid
"""
uid = str(uuid.uuid4())
KEYS[uid] = await key.read()
return {"uid": uid}
@app.post("/compute")
async def compute(model_input: UploadFile, uid: str = Form()): # noqa: B008
"""Compute the circuit over encrypted input.
Arguments:
model_input (UploadFile): input of the circuit
uid (str): uid of the public key to use
Returns:
StreamingResponse: the result of the circuit
"""
key = KEYS[uid]
encrypted_results = fhe.run(
serialized_encrypted_quantized_data=await model_input.read(),
serialized_evaluation_keys=key,
)
return StreamingResponse(
io.BytesIO(encrypted_results),
)
uvicorn.run(app, host="0.0.0.0", port=int(PORT))