-
Notifications
You must be signed in to change notification settings - Fork 476
/
Copy pathhub_model_server.py
133 lines (113 loc) · 3.99 KB
/
hub_model_server.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import json
from typing import List, Optional

import requests

from fastdeploy.utils.hub_config import config
class ServerConnectionError(Exception):
    '''
    Error reported when the FastDeploy Model Server cannot be connected to.

    Args:
        url(str) : Url of the server the connection was attempted against
    '''

    def __init__(self, url: str):
        # Forward the url to Exception so that ``args`` is populated; an
        # exception with empty ``args`` cannot be re-created as cls(*args),
        # which is what copy/pickle do.
        super().__init__(url)
        self.url = url

    def __str__(self):
        tips = 'Can\'t connect to FastDeploy Model Server: {}'.format(self.url)
        return tips
class ModelServer(object):
    '''
    FastDeploy server source

    Args:
        url(str) : Url of the server
        timeout(int) : Request timeout, passed to every ``requests.get`` call
    '''

    def __init__(self, url: str, timeout: int = 10):
        self._url = url
        self._timeout = timeout

    def search_model(self, name: str, format: str = None,
                     version: str = None) -> Optional[List[dict]]:
        '''
        Search model from model server.

        Args:
            name(str) : FastDeploy model name
            format(str): FastDeploy model format, omitted from the query if empty
            version(str) : FastDeploy model version, omitted from the query if empty
        Return:
            result(list): search results, or None when the server reports a
                non-zero status or an empty result set
        '''
        params = {'name': name}
        if format:
            params['format'] = format
        if version:
            params['version'] = version
        result = self.request(path='fastdeploy_search', params=params)
        if result['status'] == 0 and len(result['data']) > 0:
            return result['data']
        return None

    def stat_model(self, name: str, format: str, version: str) -> bool:
        '''
        Note a record when download a model for statistics.

        This is best-effort: any failure (connection problem, malformed
        response, ...) is reported as False instead of being raised.

        Args:
            name(str) : FastDeploy model name
            format(str): FastDeploy model format
            version(str) : FastDeploy model version
        Return:
            is_successful(bool): True if successful, False otherwise
        '''
        params = {
            'name': name,
            'format': format,
            'version': version,
            'from': 'fastdeploy',
        }
        try:
            result = self.request(path='stat', params=params)
            # Kept inside the try block so that a response without a
            # 'status' field yields False rather than an exception.
            return result['status'] == 0
        except Exception:
            return False

    def _get_json(self, path: str, params: dict = None) -> dict:
        '''
        Send a GET request to ``<url>/<path>`` and return the decoded JSON body.

        Only ``requests.exceptions.ConnectionError`` is translated into
        ServerConnectionError; any other error propagates unchanged.
        '''
        api = '{}/{}'.format(self._url, path)
        try:
            response = requests.get(api, params=params, timeout=self._timeout)
            return response.json()
        except requests.exceptions.ConnectionError as e:
            # Chain the original error so the root cause stays in the traceback.
            raise ServerConnectionError(self._url) from e

    def request(self, path: str, params: dict) -> dict:
        '''
        Request server.

        Args:
            path(str) : Path appended to the server url
            params(dict) : Query parameters of the request
        Return:
            result(dict): decoded JSON response
        Raises:
            ServerConnectionError: if the connection to the server fails
        '''
        return self._get_json(path, params)

    def get_model_list(self):
        '''
        Get all pre-trained models information in dataset.

        Return:
            result(dict): key is category name, value is a list which contains models
                information such as name, format and version.
        Raises:
            ServerConnectionError: if the connection to the server fails
        '''
        return self._get_json('fastdeploy_listmodels')

    def is_connected(self):
        '''Return True if this server's url passes ``check``.'''
        return self.check(self._url, timeout=self._timeout)

    @classmethod
    def check(cls, url: str, timeout: int = 10) -> bool:
        '''
        Check if the specified url is a valid model server

        Args:
            url(str) : Url to check
            timeout(int) : Request timeout, so that an unresponsive host
                cannot block the caller indefinitely
        Return:
            is_valid(bool): True if ``<url>/search`` answers with HTTP 200,
                False on any other status or on any error
        '''
        # NOTE(review): this probes '/search' while search_model queries
        # 'fastdeploy_search' - confirm the intended endpoint.
        try:
            r = requests.get(url + '/search', timeout=timeout)
            return r.status_code == 200
        except Exception:
            # Deliberately broad (invalid url, connection failure, timeout),
            # but no longer swallows KeyboardInterrupt/SystemExit.
            return False
# Module-level ModelServer instance, created at import time from the server
# url held in the hub config (config.server); uses the default timeout.
model_server = ModelServer(config.server)