diff --git a/dsp/modules/hf_client.py b/dsp/modules/hf_client.py index caa4b62652..6c4aaef20b 100644 --- a/dsp/modules/hf_client.py +++ b/dsp/modules/hf_client.py @@ -3,6 +3,7 @@ import re import shutil import subprocess +from typing import Literal # from dsp.modules.adapter import TurboAdapter, DavinciAdapter, LlamaAdapter import backoff @@ -114,7 +115,7 @@ def send_hftgi_request_v00(arg, **kwargs): class HFClientVLLM(HFModel): - def __init__(self, model, port, url="http://localhost", **kwargs): + def __init__(self, model, port, model_type: Literal['chat', 'text'] = 'text', url="http://localhost", **kwargs): super().__init__(model=model, is_client=True) if isinstance(url, list): @@ -126,42 +127,76 @@ def __init__(self, model, port, url="http://localhost", **kwargs): else: raise ValueError(f"The url provided to `HFClientVLLM` is neither a string nor a list of strings. It is of type {type(url)}.") + self.model_type = model_type self.headers = {"Content-Type": "application/json"} self.kwargs |= kwargs + # kwargs needs to have model, port and url for the lm.copy() to work properly + self.kwargs.update({ + 'port': port, + 'url': url, + }) def _generate(self, prompt, **kwargs): kwargs = {**self.kwargs, **kwargs} - - payload = { - "model": self.kwargs["model"], - "prompt": prompt, - **kwargs, - } - # Round robin the urls. url = self.urls.pop(0) self.urls.append(url) + + if self.model_type == "chat": + system_prompt = kwargs.get("system_prompt",None) + messages = [{"role": "user", "content": prompt}] + if system_prompt: + messages.insert(0, {"role": "system", "content": system_prompt}) + payload = { + "model": self.kwargs["model"], + "messages": messages, + **kwargs, + } + response = send_hfvllm_chat_request_v00( + f"{url}/v1/chat/completions", + json=payload, + headers=self.headers, + ) + + try: + json_response = response.json() + completions = json_response["choices"] + response = { + "prompt": prompt, + "choices": [{"text": c["message"]['content']} for c in completions], + } + return response - response = send_hfvllm_request_v00( - f"{url}/v1/completions", - json=payload, - headers=self.headers, - ) - - try: - json_response = response.json() - completions = json_response["choices"] - response = { + except Exception: + print("Failed to parse JSON response:", response.text) + raise Exception("Received invalid JSON response from server") + else: + payload = { + "model": self.kwargs["model"], "prompt": prompt, - "choices": [{"text": c["text"]} for c in completions], + **kwargs, } - return response + + response = send_hfvllm_request_v00( + f"{url}/v1/completions", + json=payload, + headers=self.headers, + ) + + try: + json_response = response.json() + completions = json_response["choices"] + response = { + "prompt": prompt, + "choices": [{"text": c["text"]} for c in completions], + } + return response - except Exception: - print("Failed to parse JSON response:", response.text) - raise Exception("Received invalid JSON response from server") + except Exception: + print("Failed to parse JSON response:", response.text) + raise Exception("Received invalid JSON response from server") @CacheMemory.cache @@ -169,6 +204,11 @@ def send_hfvllm_request_v00(arg, **kwargs): return requests.post(arg, **kwargs) +@CacheMemory.cache +def send_hfvllm_chat_request_v00(arg, **kwargs): + return requests.post(arg, **kwargs) + + class HFServerTGI: def __init__(self, user_dir): self.model_weights_dir = os.path.abspath(os.path.join(os.getcwd(), "text-generation-inference", user_dir)) @@ -435,4 +475,4 @@ def _generate(self, prompt, **kwargs): @CacheMemory.cache def send_hfsglang_request_v00(arg, **kwargs): - return requests.post(arg, **kwargs) \ No newline at end of file + return requests.post(arg, **kwargs) diff --git a/dsp/modules/lm.py b/dsp/modules/lm.py index b1c26dc8da..cfe8b88d05 100644 --- a/dsp/modules/lm.py +++ b/dsp/modules/lm.py @@ -63,13 +63,11 @@ def inspect_history(self, n: int = 1, skip: int = 0): if len(printed) >= n: break + printing_value = "" for idx, (prompt, choices) in enumerate(reversed(printed)): - printing_value = "" - # skip the first `skip` prompts if (n - idx - 1) < skip: continue - printing_value += "\n\n\n" printing_value += prompt