diff --git a/dsp/modules/hf_client.py b/dsp/modules/hf_client.py index cc4ded6237..810c9d3a00 100644 --- a/dsp/modules/hf_client.py +++ b/dsp/modules/hf_client.py @@ -130,13 +130,15 @@ def __init__(self, model, port, model_type: Literal['chat', 'text'] = 'text', ur else: raise ValueError(f"The url provided to `HFClientVLLM` is neither a string nor a list of strings. It is of type {type(url)}.") + self.urls_const = tuple(self.urls) + self.port = port self.model_type = model_type self.headers = {"Content-Type": "application/json"} self.kwargs |= kwargs # kwargs needs to have model, port and url for the lm.copy() to work properly self.kwargs.update({ 'port': port, - 'url': url, + 'url': self.urls_const, }) @@ -157,8 +159,10 @@ def _generate(self, prompt, **kwargs): "messages": messages, **kwargs, } - response = send_hfvllm_chat_request_v00( + response = send_hfvllm_request_v01_wrapped( f"{url}/v1/chat/completions", + url=self.urls_const, + port=self.port, json=payload, headers=self.headers, ) @@ -181,9 +185,11 @@ def _generate(self, prompt, **kwargs): "prompt": prompt, **kwargs, } - - response = send_hfvllm_request_v00( + + response = send_hfvllm_request_v01_wrapped( f"{url}/v1/completions", + url=self.urls_const, + port=self.port, json=payload, headers=self.headers, ) @@ -201,13 +207,20 @@ def _generate(self, prompt, **kwargs): print("Failed to parse JSON response:", response.text) raise Exception("Received invalid JSON response from server") - @CacheMemory.cache(ignore=['arg']) -def send_hfvllm_request_v00(arg, **kwargs): +def send_hfvllm_request_v01(arg, url, port, **kwargs): return requests.post(arg, **kwargs) +# @functools.lru_cache(maxsize=None if cache_turn_on else 0) +@NotebookCacheMemory.cache(ignore=['arg']) +def send_hfvllm_request_v01_wrapped(arg, url, port, **kwargs): + return send_hftgi_request_v01(arg, url, port, **kwargs) -@CacheMemory.cache(ignore=['arg']) +@CacheMemory.cache +def send_hfvllm_request_v00(arg, **kwargs): + return requests.post(arg, **kwargs) + +@CacheMemory.cache def send_hfvllm_chat_request_v00(arg, **kwargs): return requests.post(arg, **kwargs)