# backend.py
# (Original listing: 129 lines, 97 loc, 4.01 KB.)
import requests
from flask import Response, abort
import sys
from backend import GenericBackend
from tgi.metrics import Metrics
# host:port of the model server this backend forwards to. Backend copies it into
# `model_server_addr`; hf_tgi_wrapper builds "http://{addr}/generate_stream" from it.
# Loopback address: the server is expected on the same host (presumably TGI; confirm).
MODEL_SERVER = "127.0.0.1:5001"
class Backend(GenericBackend):
    """Backend adapter that forwards requests to the model server at MODEL_SERVER.

    Non-streaming calls (``generate`` and the ``health``/``info``/``metrics``
    GETs) are delegated to ``GenericBackend`` helpers. ``generate_stream``
    relays the server's ``/generate_stream`` output and reports the request's
    lifecycle to the ``Metrics`` instance created in ``__init__``.
    """

    def __init__(self, container_id, control_server_url, master_token, send_data):
        """Build the metrics reporter and initialise the generic backend.

        Args:
            container_id: Passed to ``Metrics`` as ``id``.
            control_server_url: Passed to ``Metrics`` unchanged.
            master_token: Passed to both ``Metrics`` and ``GenericBackend``.
            send_data: Passed to ``Metrics`` as ``send_server_data``.
        """
        metrics = Metrics(
            id=container_id,
            master_token=master_token,
            control_server_url=control_server_url,
            send_server_data=send_data,
        )
        # NOTE(review): hf_tgi_wrapper reads ``self.metrics``; presumably
        # GenericBackend.__init__ stores the ``metrics`` kwarg there - confirm.
        super().__init__(master_token=master_token, metrics=metrics)
        self.model_server_addr = MODEL_SERVER

    def generate(self, model_request, metrics=True):
        """Run a non-streaming generation against the server's ``generate`` endpoint.

        Delegates to ``GenericBackend.generate`` with a response handler that
        extracts the body text. Callers in this module unpack the result as
        ``(status_code, content, _)``.
        """
        return super().generate(
            model_request, self.model_server_addr, "generate", lambda r: r.text, metrics=metrics
        )

    def hf_tgi_wrapper(self, model_request):
        """Relay the model server's ``/generate_stream`` output as a generator.

        Each line produced by ``iter_lines()`` is yielded, followed by a newline
        string. ``start_req`` is recorded before the upstream call.
        ``finish_req`` is recorded only after a 200 response has been fully
        relayed. In every other case ``error_req`` is recorded: a non-200
        status, a ``requests`` exception, or the generator being closed before
        the stream finished.
        """
        # Only flipped to True after the whole stream was relayed successfully;
        # previously this started as True, which made error reporting dead code.
        success = False
        self.metrics.start_req(model_request)
        try:
            # stream=True keeps the connection checked out until the body is
            # consumed; the ``with`` block releases it when we are done.
            with requests.post(
                f"http://{self.model_server_addr}/generate_stream",
                json=model_request,
                stream=True,
            ) as response:
                if response.status_code == 200:
                    for byte_payload in response.iter_lines():
                        yield byte_payload
                        yield "\n"
                    self.metrics.finish_req(model_request)
                    success = True
                else:
                    print(f"[TGI-backend] generate_stream failed with code {response.status_code}")
        except requests.exceptions.RequestException as e:
            print(f"[TGI-backend] Request error: {e}")
        finally:
            # ``finally`` also runs on GeneratorExit, i.e. when the response
            # iterator is closed early, so every start_req gets a closing event.
            if not success:
                self.metrics.error_req(model_request)

    def generate_stream(self, model_request):
        """Return a Flask ``Response`` whose body is streamed from ``hf_tgi_wrapper``."""
        return Response(self.hf_tgi_wrapper(model_request))

    def health_handler(self):
        """Fetch the server's ``health`` endpoint via ``GenericBackend.get``.

        Callers in this module unpack the result as ``(status_code, content)``.
        """
        return super().get(None, self.model_server_addr, "health", lambda r: r.text)

    def info_handler(self):
        """Fetch the server's ``info`` endpoint via ``GenericBackend.get``.

        Callers in this module unpack the result as ``(status_code, content)``.
        """
        return super().get(None, self.model_server_addr, "info", lambda r: r.text)

    def metrics_handler(self):
        """Fetch the server's ``metrics`` endpoint via ``GenericBackend.get``.

        Callers in this module unpack the result as ``(status_code, content)``.
        """
        return super().get(None, self.model_server_addr, "metrics", lambda r: r.text)
######################################### FLASK HANDLER METHODS ###############################################################
# TODO: consider moving these handler functions into the Backend class.
def generate_handler(backend, request):
    """Serve a POST to the ``generate`` route.

    The JSON body is split by ``backend.format_request`` into auth fields and
    model inputs.
    - Auth present: a failing ``check_signature`` aborts with 401.
    - Auth absent: a deprecation warning is printed instead.
    - Missing model inputs: aborts with 400.
    Otherwise the inputs go to ``backend.generate``; its content is returned
    on status 200, and any other status is logged and re-raised via ``abort``.
    """
    credentials, model_payload = backend.format_request(request.json)
    if not credentials:
        print("WARNING: support for /generate requests without a signed signature will soon be deprecated")
    elif not backend.check_signature(**credentials):
        abort(401)

    if model_payload is None:
        print(f"client request: {request.json} doesn't include model inputs and parameters")
        abort(400)

    status, body, _ = backend.generate(model_payload)
    if status == 200:
        return body
    print(f"generate failed with code {status}")
    abort(status)
def generate_stream_handler(backend, request):
    """Serve a POST to the ``generate_stream`` route.

    Performs the same auth and input checks as ``generate_handler``:
    401 on a bad signature, a warning when unsigned, 400 when model inputs
    are missing. It then returns whatever ``backend.generate_stream``
    produces for the model inputs.
    """
    credentials, model_payload = backend.format_request(request.json)
    if not credentials:
        print("WARNING: support for /generate_stream requests without a signed signature will soon be deprecated")
    elif not backend.check_signature(**credentials):
        abort(401)

    if model_payload is None:
        print(f"client request: {request.json} doesn't include model inputs and parameters")
        abort(400)

    return backend.generate_stream(model_payload)
def health_handler(backend, request):
    """Serve GET ``health``.

    Returns the backend's health content on status 200; any other status is
    logged and re-raised via ``abort``. ``request`` is unused.
    """
    status, body = backend.health_handler()
    if status == 200:
        return body
    print(f"health failed with code {status}")
    abort(status)
def info_handler(backend, request):
    """Serve GET ``info``.

    Returns the backend's info content on status 200; any other status is
    logged and re-raised via ``abort``. ``request`` is unused.
    """
    status, body = backend.info_handler()
    if status == 200:
        return body
    print(f"info failed with code {status}")
    abort(status)
def metrics_handler(backend, request):
    """Serve GET ``metrics``.

    Returns the backend's metrics content on status 200; any other status is
    logged and re-raised via ``abort``. ``request`` is unused.
    """
    status, body = backend.metrics_handler()
    if status == 200:
        return body
    print(f"metrics failed with code {status}")
    abort(status)
# Dispatch table: HTTP method -> {endpoint name: handler(backend, request)}.
# NOTE(review): presumably read by the server entry point to register routes; confirm.
flask_dict = dict(
    POST=dict(
        generate=generate_handler,
        generate_stream=generate_stream_handler,
    ),
    GET=dict(
        health=health_handler,
        info=info_handler,
        metrics=metrics_handler,
    ),
)