From 0b7289fca05988b2863cc46a3564e8edbc9b7021 Mon Sep 17 00:00:00 2001
From: tjtanaa
Date: Sat, 13 Apr 2024 10:30:27 +0800
Subject: [PATCH 1/2] added vllm chat completion endpoints and fix inspect_history

---
 dsp/modules/hf_client.py | 88 +++++++++++++++++++++++++++++-----------
 dsp/modules/lm.py        |  8 ++--
 2 files changed, 68 insertions(+), 28 deletions(-)

diff --git a/dsp/modules/hf_client.py b/dsp/modules/hf_client.py
index caa4b62652..fcda7dd8af 100644
--- a/dsp/modules/hf_client.py
+++ b/dsp/modules/hf_client.py
@@ -128,40 +128,77 @@ def __init__(self, model, port, url="http://localhost", **kwargs):
         self.headers = {"Content-Type": "application/json"}
 
         self.kwargs |= kwargs
+        # kwargs needs to have model, port and url for the lm.copy() to work properly
+        self.kwargs.update({
+            'model': model,
+            'port': port,
+            'url': url
+        })
 
     def _generate(self, prompt, **kwargs):
         kwargs = {**self.kwargs, **kwargs}
-
-        payload = {
-            "model": self.kwargs["model"],
-            "prompt": prompt,
-            **kwargs,
-        }
-
+        # get model_type
+        model_type = kwargs.get("model_type",None)
+
         # Round robin the urls.
         url = self.urls.pop(0)
         self.urls.append(url)
+
+        if model_type == "chat":
+            system_prompt = kwargs.get("system_prompt",None)
+            messages = [{"role": "user", "content": prompt}]
+            if system_prompt:
+                messages.insert(0, {"role": "system", "content": system_prompt})
+            payload = {
+                "model": self.kwargs["model"],
+                "messages": messages,
+                **kwargs,
+            }
+            response = send_hfvllm_chat_request_v00(
+                f"{url}/v1/chat/completions",
+                json=payload,
+                headers=self.headers,
+            )
+
+            try:
+                json_response = response.json()
+                completions = json_response["choices"]
+                response = {
+                    "prompt": prompt,
+                    "choices": [{"text": c["message"]['content']} for c in completions],
+                }
+                return response
 
-        response = send_hfvllm_request_v00(
-            f"{url}/v1/completions",
-            json=payload,
-            headers=self.headers,
-        )
-
-        try:
-            json_response = response.json()
-            completions = json_response["choices"]
-            response = {
+            except Exception:
+                print("Failed to parse JSON response:", response.text)
+                raise Exception("Received invalid JSON response from server")
+        else:
+            payload = {
+                "model": self.kwargs["model"],
                 "prompt": prompt,
-                "choices": [{"text": c["text"]} for c in completions],
+                **kwargs,
             }
-            return response
+
+            response = send_hfvllm_request_v00(
+                f"{url}/v1/completions",
+                json=payload,
+                headers=self.headers,
+            )
+
+            try:
+                json_response = response.json()
+                completions = json_response["choices"]
+                response = {
+                    "prompt": prompt,
+                    "choices": [{"text": c["text"]} for c in completions],
+                }
+                return response
 
-        except Exception:
-            print("Failed to parse JSON response:", response.text)
-            raise Exception("Received invalid JSON response from server")
+            except Exception:
+                print("Failed to parse JSON response:", response.text)
+                raise Exception("Received invalid JSON response from server")
 
 
 @CacheMemory.cache
@@ -169,6 +206,11 @@ def send_hfvllm_request_v00(arg, **kwargs):
     return requests.post(arg, **kwargs)
 
 
+@CacheMemory.cache
+def send_hfvllm_chat_request_v00(arg, **kwargs):
+    return requests.post(arg, **kwargs)
+
+
 class HFServerTGI:
     def __init__(self, user_dir):
         self.model_weights_dir = os.path.abspath(os.path.join(os.getcwd(), "text-generation-inference", user_dir))
@@ -435,4 +477,4 @@ def _generate(self, prompt, **kwargs):
 
 @CacheMemory.cache
 def send_hfsglang_request_v00(arg, **kwargs):
-    return requests.post(arg, **kwargs)
\ No newline at end of file
+    return requests.post(arg, **kwargs)
diff --git a/dsp/modules/lm.py b/dsp/modules/lm.py
index b1c26dc8da..ac65343b47 100644
--- a/dsp/modules/lm.py
+++ b/dsp/modules/lm.py
@@ -63,13 +63,11 @@ def inspect_history(self, n: int = 1, skip: int = 0):
             if len(printed) >= n:
                 break
 
-        for idx, (prompt, choices) in enumerate(reversed(printed)):
-            printing_value = ""
-
+        for idx, (prompt, choices) in enumerate(printed):
             # skip the first `skip` prompts
-            if (n - idx - 1) < skip:
+            if (n - idx - 1) > skip:
                 continue
-
+            printing_value = ""
             printing_value += "\n\n\n"
             printing_value += prompt
 

From ca4efef8aa3b59ad2a25709bd3d719f92eb08ebb Mon Sep 17 00:00:00 2001
From: tjtanaa
Date: Mon, 15 Apr 2024 21:58:47 +0800
Subject: [PATCH 2/2] update vllm initialization and fix inspect history

---
 dsp/modules/hf_client.py | 12 +++++-------
 dsp/modules/lm.py        |  6 +++---
 2 files changed, 8 insertions(+), 10 deletions(-)

diff --git a/dsp/modules/hf_client.py b/dsp/modules/hf_client.py
index fcda7dd8af..6c4aaef20b 100644
--- a/dsp/modules/hf_client.py
+++ b/dsp/modules/hf_client.py
@@ -3,6 +3,7 @@
 import re
 import shutil
 import subprocess
+from typing import Literal
 
 # from dsp.modules.adapter import TurboAdapter, DavinciAdapter, LlamaAdapter
 import backoff
@@ -114,7 +115,7 @@ def send_hftgi_request_v00(arg, **kwargs):
 
 
 class HFClientVLLM(HFModel):
-    def __init__(self, model, port, url="http://localhost", **kwargs):
+    def __init__(self, model, port, model_type: Literal['chat', 'text'] = 'text', url="http://localhost", **kwargs):
         super().__init__(model=model, is_client=True)
 
         if isinstance(url, list):
@@ -126,27 +127,24 @@ def __init__(self, model, port, url="http://localhost", **kwargs):
         else:
             raise ValueError(f"The url provided to `HFClientVLLM` is neither a string nor a list of strings. It is of type {type(url)}.")
 
+        self.model_type = model_type
         self.headers = {"Content-Type": "application/json"}
 
         self.kwargs |= kwargs
         # kwargs needs to have model, port and url for the lm.copy() to work properly
         self.kwargs.update({
-            'model': model,
             'port': port,
-            'url': url
+            'url': url,
         })
 
     def _generate(self, prompt, **kwargs):
         kwargs = {**self.kwargs, **kwargs}
-        # get model_type
-        model_type = kwargs.get("model_type",None)
-
         # Round robin the urls.
         url = self.urls.pop(0)
         self.urls.append(url)
 
-        if model_type == "chat":
+        if self.model_type == "chat":
             system_prompt = kwargs.get("system_prompt",None)
             messages = [{"role": "user", "content": prompt}]
             if system_prompt:
diff --git a/dsp/modules/lm.py b/dsp/modules/lm.py
index ac65343b47..cfe8b88d05 100644
--- a/dsp/modules/lm.py
+++ b/dsp/modules/lm.py
@@ -63,11 +63,11 @@ def inspect_history(self, n: int = 1, skip: int = 0):
             if len(printed) >= n:
                 break
 
-        for idx, (prompt, choices) in enumerate(printed):
+        printing_value = ""
+        for idx, (prompt, choices) in enumerate(reversed(printed)):
             # skip the first `skip` prompts
-            if (n - idx - 1) > skip:
+            if (n - idx - 1) < skip:
                 continue
-            printing_value = ""
             printing_value += "\n\n\n"
             printing_value += prompt
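
Usage note (illustrative, not part of the patch series): a minimal sketch of how the chat-mode client added by these two commits might be used. The model name, port, and temperature below are placeholder assumptions; only HFClientVLLM, its new model_type parameter, the system_prompt keyword read by _generate, and lm.copy() come from the patches themselves. The sketch also assumes a vLLM OpenAI-compatible server is already running at the given URL and that keyword arguments passed to the client call are forwarded to _generate.

    from dsp.modules.hf_client import HFClientVLLM

    # Placeholder model name and port; the server behind this URL is assumed to be
    # a vLLM instance exposing the OpenAI-compatible /v1/chat/completions endpoint.
    lm = HFClientVLLM(
        model="meta-llama/Llama-2-7b-chat-hf",
        port=8000,
        url="http://localhost",
        model_type="chat",  # PATCH 2/2: routes requests to /v1/chat/completions
    )

    # PATCH 1/2 stores port and url in self.kwargs, so copies keep the connection settings.
    lm_copy = lm.copy(temperature=0.3)

    # _generate inserts the optional system_prompt as the first chat message;
    # the call is expected to return a list of completion strings.
    completions = lm("Say hello in one short sentence.", system_prompt="You are a terse assistant.")
    print(completions)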