Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
88 changes: 64 additions & 24 deletions dsp/modules/hf_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
import re
import shutil
import subprocess
from typing import Literal

# from dsp.modules.adapter import TurboAdapter, DavinciAdapter, LlamaAdapter
import backoff
Expand Down Expand Up @@ -114,7 +115,7 @@ def send_hftgi_request_v00(arg, **kwargs):


class HFClientVLLM(HFModel):
def __init__(self, model, port, url="http://localhost", **kwargs):
def __init__(self, model, port, model_type: Literal['chat', 'text'] = 'text', url="http://localhost", **kwargs):
super().__init__(model=model, is_client=True)

if isinstance(url, list):
Expand All @@ -126,49 +127,88 @@ def __init__(self, model, port, url="http://localhost", **kwargs):
else:
raise ValueError(f"The url provided to `HFClientVLLM` is neither a string nor a list of strings. It is of type {type(url)}.")

self.model_type = model_type
self.headers = {"Content-Type": "application/json"}
self.kwargs |= kwargs
# kwargs needs to have model, port and url for the lm.copy() to work properly
self.kwargs.update({
'port': port,
'url': url,
})


def _generate(self, prompt, **kwargs):
kwargs = {**self.kwargs, **kwargs}

payload = {
"model": self.kwargs["model"],
"prompt": prompt,
**kwargs,
}


# Round robin the urls.
url = self.urls.pop(0)
self.urls.append(url)

if self.model_type == "chat":
system_prompt = kwargs.get("system_prompt",None)
messages = [{"role": "user", "content": prompt}]
if system_prompt:
messages.insert(0, {"role": "system", "content": system_prompt})
payload = {
"model": self.kwargs["model"],
"messages": messages,
**kwargs,
}
response = send_hfvllm_chat_request_v00(
f"{url}/v1/chat/completions",
json=payload,
headers=self.headers,
)

try:
json_response = response.json()
completions = json_response["choices"]
response = {
"prompt": prompt,
"choices": [{"text": c["message"]['content']} for c in completions],
}
return response

response = send_hfvllm_request_v00(
f"{url}/v1/completions",
json=payload,
headers=self.headers,
)

try:
json_response = response.json()
completions = json_response["choices"]
response = {
except Exception:
print("Failed to parse JSON response:", response.text)
raise Exception("Received invalid JSON response from server")
else:
payload = {
"model": self.kwargs["model"],
"prompt": prompt,
"choices": [{"text": c["text"]} for c in completions],
**kwargs,
}
return response

response = send_hfvllm_request_v00(
f"{url}/v1/completions",
json=payload,
headers=self.headers,
)

try:
json_response = response.json()
completions = json_response["choices"]
response = {
"prompt": prompt,
"choices": [{"text": c["text"]} for c in completions],
}
return response

except Exception:
print("Failed to parse JSON response:", response.text)
raise Exception("Received invalid JSON response from server")
except Exception:
print("Failed to parse JSON response:", response.text)
raise Exception("Received invalid JSON response from server")


@CacheMemory.cache
def send_hfvllm_request_v00(arg, **kwargs):
    """POST to a vLLM text-completions endpoint, memoized via CacheMemory.

    `arg` is the target URL (e.g. ``{url}/v1/completions``); all keyword
    arguments (``json`` payload, ``headers``) are forwarded verbatim to
    ``requests.post``. Returns the raw ``requests.Response``.
    """
    response = requests.post(arg, **kwargs)
    return response


@CacheMemory.cache
def send_hfvllm_chat_request_v00(arg, **kwargs):
    """POST to a vLLM chat-completions endpoint, memoized via CacheMemory.

    `arg` is the target URL (e.g. ``{url}/v1/chat/completions``); all keyword
    arguments (``json`` payload, ``headers``) are forwarded verbatim to
    ``requests.post``. Returns the raw ``requests.Response``.
    """
    response = requests.post(arg, **kwargs)
    return response


class HFServerTGI:
def __init__(self, user_dir):
self.model_weights_dir = os.path.abspath(os.path.join(os.getcwd(), "text-generation-inference", user_dir))
Expand Down Expand Up @@ -435,4 +475,4 @@ def _generate(self, prompt, **kwargs):

@CacheMemory.cache
def send_hfsglang_request_v00(arg, **kwargs):
    """POST to an SGLang server endpoint, memoized via CacheMemory.

    `arg` is the target URL; keyword arguments (``json`` payload,
    ``headers``) are forwarded verbatim to ``requests.post``.
    Returns the raw ``requests.Response``.

    NOTE(review): the source span contained this return statement twice —
    a diff-rendering artifact of a trailing-newline-only change; collapsed
    to a single return here.
    """
    return requests.post(arg, **kwargs)
4 changes: 1 addition & 3 deletions dsp/modules/lm.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,13 +63,11 @@ def inspect_history(self, n: int = 1, skip: int = 0):
if len(printed) >= n:
break

printing_value = ""
for idx, (prompt, choices) in enumerate(reversed(printed)):
printing_value = ""

# skip the first `skip` prompts
if (n - idx - 1) < skip:
continue

printing_value += "\n\n\n"
printing_value += prompt

Expand Down