-
Notifications
You must be signed in to change notification settings - Fork 5
/
wrapper.py
187 lines (161 loc) · 7.43 KB
/
wrapper.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
from typing import Any, AsyncIterator, Callable, Iterator, List, Tuple
import functools
import inspect
import time
from uuid import uuid4
from parea.cache.cache import Cache
from parea.helpers import date_and_time_string_to_timestamp
from parea.schemas.models import TraceLog
from parea.utils.trace_utils import to_date_and_time_string, trace_context, trace_data
class Wrapper:
    """Monkey-patch functions on ``module`` so every call is traced and optionally cached.

    Each wrapped call:

    * pushes a child trace named ``"LLM"`` onto ``trace_context``, under the current
      trace (or under a placeholder parent when there is none);
    * consults ``cache`` before invoking the real function;
    * passes the response to the supplied resolver callables; and
    * records timing and status, then logs the child and parent traces through ``log``.
    """

    def __init__(
        self,
        module: Any,
        func_names: List[str],
        resolver: Callable,
        gen_resolver: Callable,
        agen_resolver: Callable,
        cache: "Cache",  # quoted so the annotation is not evaluated at class-definition time
        convert_kwargs_to_cache_request: Callable,
        convert_cache_to_response: Callable,
        aconvert_cache_to_response: Callable,
        log: Callable,
    ) -> None:
        """Store the callbacks, then patch ``func_names`` on ``module``.

        Args:
            module: Object whose (possibly dotted) attributes are replaced.
            func_names: Attribute paths such as ``"create"`` or ``"a.b.create"``.
            resolver: Called as ``resolver(trace_id, args, kwargs, response)`` for
                non-streaming responses (``response`` is ``None`` when the call failed).
            gen_resolver: Called as ``gen_resolver(trace_id, args, kwargs, response, final_log)``
                when a sync call returns an ``Iterator``. Its return value replaces the response.
            agen_resolver: Same as ``gen_resolver``, for ``AsyncIterator`` responses.
            cache: Object providing ``get``/``set``/``invalidate``/``aget``/``ainvalidate``.
                A falsy value disables caching.
            convert_kwargs_to_cache_request: ``(args, kwargs) -> cache key``.
            convert_cache_to_response: ``(args, kwargs, cached) -> response`` for sync calls.
            aconvert_cache_to_response: Same as above, used by async calls.
            log: Called with a trace id to report that trace.
        """
        self.resolver = resolver
        self.gen_resolver = gen_resolver
        self.agen_resolver = agen_resolver
        self.log = log
        self.cache = cache
        self.convert_kwargs_to_cache_request = convert_kwargs_to_cache_request
        self.convert_cache_to_response = convert_cache_to_response
        self.aconvert_cache_to_response = aconvert_cache_to_response
        # Maps trace_id -> the placeholder parent _init_trace created for it,
        # so the placeholder can be taken off the stack when that trace finishes.
        self._synthetic_parents = {}
        # Patch last, once every attribute the wrappers read is in place.
        self.wrap_functions(module, func_names)

    def wrap_functions(self, module: Any, func_names: List[str]):
        """Replace each dotted path in ``func_names`` on ``module`` with a traced wrapper.

        ``"a.b.c"`` resolves ``module.a.b.c`` and rebinds attribute ``c`` on ``module.a.b``.
        """
        for func_name in func_names:
            *owner_path, attr_name = func_name.split(".")
            # reduce over an empty path yields ``module`` itself.
            owner = functools.reduce(getattr, owner_path, module)
            original = getattr(owner, attr_name)
            setattr(owner, attr_name, self._wrapped_func(original))

    def _wrapped_func(self, original_func: Callable) -> Callable:
        """Wrap ``original_func``. The async wrapper is chosen if the innermost function is a coroutine function."""
        # Follow the ``__wrapped__`` chain (as left by functools.wraps) to the innermost
        # callable. The previous hand-written loop re-tested ``original_func`` without
        # advancing it, so it never terminated on a decorated function.
        unwrapped_func = inspect.unwrap(original_func)
        return self._get_decorator(unwrapped_func, original_func)

    def _get_decorator(self, unwrapped_func: Callable, original_func: Callable):
        """Pick the sync or async wrapper based on ``unwrapped_func``, but wrap ``original_func``."""
        if inspect.iscoroutinefunction(unwrapped_func):
            return self.async_decorator(original_func)
        else:
            return self.sync_decorator(original_func)

    def _init_trace(self) -> Tuple[str, float]:
        """Open an ``"LLM"`` trace on ``trace_context`` and return ``(trace_id, start_time)``.

        If the new trace has no enclosing trace, a placeholder parent trace is inserted at
        the bottom of the stack. It is removed again when this trace's final log runs.
        """
        start_time = time.time()
        trace_id = str(uuid4())
        stack = trace_context.get()
        stack.append(trace_id)
        trace_data.get()[trace_id] = TraceLog(
            trace_id=trace_id,
            start_timestamp=to_date_and_time_string(start_time),
            trace_name="LLM",
            end_user_identifier=None,
            metadata=None,
            target=None,
            tags=None,
            inputs={},
        )

        parent_trace_id = stack[-2] if len(stack) > 1 else None
        if not parent_trace_id:
            # We don't have a parent trace id, so we need to create one.
            parent_trace_id = str(uuid4())
            stack.insert(0, parent_trace_id)
            trace_data.get()[parent_trace_id] = TraceLog(
                trace_id=parent_trace_id,
                start_timestamp=to_date_and_time_string(start_time),
                end_user_identifier=None,
                metadata=None,
                target=None,
                tags=None,
                inputs={},
            )
            # Remember the placeholder so it gets popped later. Otherwise it would stay
            # on the stack and adopt every later stand-alone call as its child.
            self._synthetic_parents[trace_id] = parent_trace_id
        trace_data.get()[parent_trace_id].children.append(trace_id)
        self.log(parent_trace_id)
        return trace_id, start_time

    def async_decorator(self, orig_func: Callable) -> Callable:
        """Return a traced, cache-aware ``async`` wrapper around ``orig_func``."""

        @functools.wraps(orig_func)
        async def wrapper(*args, **kwargs):
            # Compute the key before pushing a trace, so a converter failure leaves no dangling trace id.
            cache_key = self.convert_kwargs_to_cache_request(args, kwargs)
            trace_id, start_time = self._init_trace()
            response = None
            cache_hit = False
            try:
                if self.cache:
                    cache_result = await self.cache.aget(cache_key)
                    if cache_result is not None:
                        # NOTE(review): not awaited, as before. Presumably returns the response
                        # synchronously despite the "a" prefix; confirm.
                        response = self.aconvert_cache_to_response(args, kwargs, cache_result)
                        cache_hit = True
                if response is None:
                    response = await orig_func(*args, **kwargs)
            except BaseException as e:
                # Deliberately broad: the exception is always re-raised. The trace only needs
                # closing, so trace_context is not left holding this trace id
                # (this includes asyncio.CancelledError).
                # Fall back to repr() so a message-less exception still counts as an error.
                error = str(e) or repr(e)
                try:
                    if self.cache:
                        await self.cache.ainvalidate(cache_key)
                finally:
                    await self._acleanup_trace(trace_id, start_time, error, cache_hit, args, kwargs, response)
                # No `return` in a finally block: that would swallow this exception.
                raise
            return await self._acleanup_trace(trace_id, start_time, None, cache_hit, args, kwargs, response)

        return wrapper

    def sync_decorator(self, orig_func: Callable) -> Callable:
        """Return a traced, cache-aware synchronous wrapper around ``orig_func``."""

        @functools.wraps(orig_func)
        def wrapper(*args, **kwargs):
            # Compute the key before pushing a trace, so a converter failure leaves no dangling trace id.
            cache_key = self.convert_kwargs_to_cache_request(args, kwargs)
            trace_id, start_time = self._init_trace()
            response = None
            cache_hit = False
            try:
                if self.cache:
                    cache_result = self.cache.get(cache_key)
                    if cache_result is not None:
                        response = self.convert_cache_to_response(args, kwargs, cache_result)
                        cache_hit = True
                if response is None:
                    response = orig_func(*args, **kwargs)
            except BaseException as e:
                # Deliberately broad: the exception is always re-raised. The trace only needs
                # closing, so trace_context is not left holding this trace id.
                # Fall back to repr() so a message-less exception still counts as an error.
                error = str(e) or repr(e)
                try:
                    if self.cache:
                        self.cache.invalidate(cache_key)
                finally:
                    self._cleanup_trace(trace_id, start_time, error, cache_hit, args, kwargs, response)
                # No `return` in a finally block: that would swallow this exception.
                raise
            return self._cleanup_trace(trace_id, start_time, None, cache_hit, args, kwargs, response)

        return wrapper

    def _cleanup_trace_core(self, trace_id: str, start_time: float, error: str, cache_hit, args, kwargs):
        """Set cache-hit and status on ``trace_id`` immediately, and return a deferred ``final_log``.

        ``error`` is the failure message, or ``None`` on success. ``final_log`` is deferred so
        streaming resolvers can run it once the stream is exhausted.
        """
        trace_data.get()[trace_id].cache_hit = cache_hit
        if error:
            trace_data.get()[trace_id].error = error
            trace_data.get()[trace_id].status = "error"
        else:
            trace_data.get()[trace_id].status = "success"

        def final_log():
            """Stamp end times and latencies, cache a successful trace, log, and pop the stack."""
            end_time = time.time()
            trace = trace_data.get()[trace_id]
            trace.end_timestamp = to_date_and_time_string(end_time)
            trace.latency = end_time - start_time

            stack = trace_context.get()
            # Assumes this call's trace is still on top of the stack, with its parent just below.
            parent_id = stack[-2]
            parent = trace_data.get()[parent_id]
            parent.end_timestamp = to_date_and_time_string(end_time)
            parent.latency = end_time - date_and_time_string_to_timestamp(parent.start_timestamp)

            if not error and self.cache:
                self.cache.set(self.convert_kwargs_to_cache_request(args, kwargs), trace)

            self.log(trace_id)
            self.log(parent_id)
            stack.pop()
            # Drop the placeholder parent _init_trace created for this call, if any.
            placeholder = self._synthetic_parents.pop(trace_id, None)
            if placeholder is not None and placeholder in stack:
                stack.remove(placeholder)

        return final_log

    def _cleanup_trace(self, trace_id: str, start_time: float, error: str, cache_hit, args, kwargs, response):
        """Finish a sync call's trace.

        An ``Iterator`` response is handed to ``gen_resolver``, which receives ``final_log``
        and whose result is returned. Any other response goes through ``resolver``, after
        which the trace is finalized and the response is returned unchanged.
        """
        final_log = self._cleanup_trace_core(trace_id, start_time, error, cache_hit, args, kwargs)
        if isinstance(response, Iterator):
            return self.gen_resolver(trace_id, args, kwargs, response, final_log)
        else:
            self.resolver(trace_id, args, kwargs, response)
            final_log()
            return response

    async def _acleanup_trace(self, trace_id: str, start_time: float, error: str, cache_hit, args, kwargs, response):
        """Async counterpart of ``_cleanup_trace``; ``AsyncIterator`` responses go to ``agen_resolver``."""
        final_log = self._cleanup_trace_core(trace_id, start_time, error, cache_hit, args, kwargs)
        if isinstance(response, AsyncIterator):
            return self.agen_resolver(trace_id, args, kwargs, response, final_log)
        else:
            self.resolver(trace_id, args, kwargs, response)
            final_log()
            return response