Commit

Implementing local OpenAI API-style chat completions on any given inference server (EleutherAI#1174)

* LocalChatCompletionsLM add

* clean up completions class

* clean up completions class

* update tokens

* README

* fix constructor

* eos token

* folding local-chat-completions into OpenAIChatCompletions

* refactoring to include gen_kwargs as passable option

* add todo on chat completion kwarg validation

* Ruff and README fix

* generalize to **kwargs

* remove unnecessary kwargs

* README and remove kwargs

* README
veekaybee authored and wx-zhang committed Dec 24, 2023
1 parent affa41b commit 4fa7466
Showing 2 changed files with 60 additions and 40 deletions.
25 changes: 15 additions & 10 deletions README.md
@@ -155,19 +155,24 @@ lm_eval --model openai-completions \
--tasks lambada_openai,hellaswag
```

We also support using your own local inference server, provided it implements a version of the OpenAI ChatCompletions endpoint, and passing it trained HuggingFace artifacts and tokenizers.

```bash
lm_eval --model local-chat-completions --tasks gsm8k --model_args model=facebook/opt-125m,base_url=http://{yourip}:8000/v1
```
Note that for externally hosted models, arguments such as `--device` and `--batch_size` should not be used, as they have no effect. Just like you can use `--model_args` to pass arbitrary arguments to the model constructor for local models, you can use it to pass arbitrary arguments to the model API for hosted models. See the documentation of the hosting service for information on what arguments they support.
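
For reference, here is a minimal sketch of the request shape such a server needs to handle, using the `openai` Python client: any endpoint that mirrors OpenAI's `/v1/chat/completions` route should work. The base URL, the `api_key="EMPTY"` placeholder, and the model name below are illustrative assumptions.

```python
# Minimal sketch: query an OpenAI-compatible local server the same way the
# harness's chat-completions client does (openai>=1.0 client API).
import openai

client = openai.OpenAI(
    base_url="http://localhost:8000/v1",  # assumed local server address
    api_key="EMPTY",                      # most local servers ignore the key
)

response = client.chat.completions.create(
    model="facebook/opt-125m",  # HF model path served locally (illustrative)
    messages=[{"role": "user", "content": "What is 2 + 2?"}],
    max_tokens=32,
    temperature=0.0,
)
print(response.choices[0].message.content)
```

If a call like this succeeds against your server, `local-chat-completions` should be able to evaluate against it as well.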


| API or Inference Server | Implemented? | `--model <xxx>` name | Models supported: | Request Types: |
|-----------------------------|---------------------------------|--------------------------------------------------------------------------------|-----------------------------------------------------------------------------------------------|----------------------------------------------------------|
| OpenAI Completions | :heavy_check_mark: | `openai-completions` | up to `code-davinci-002` | `generate_until`, `loglikelihood`, `loglikelihood_rolling` |
| OpenAI ChatCompletions | :x: Not yet - needs testing! | N/A | [All ChatCompletions API models](https://platform.openai.com/docs/guides/gpt) | `generate_until` (no logprobs) |
| Anthropic | :heavy_check_mark: | `anthropic` | [Supported Anthropic Engines](https://docs.anthropic.com/claude/reference/selecting-a-model) | `generate_until` (no logprobs) |
| Textsynth | :heavy_check_mark: | `textsynth` | [All supported engines](https://textsynth.com/documentation.html#engines) | `generate_until`, `loglikelihood`, `loglikelihood_rolling` |
| Cohere | [:hourglass: - blocked on Cohere API bug](https://github.com/EleutherAI/lm-evaluation-harness/pull/395) | N/A | [All `cohere.generate()` engines](https://docs.cohere.com/docs/models) | `generate_until`, `loglikelihood`, `loglikelihood_rolling` |
| [Llama.cpp](https://github.com/ggerganov/llama.cpp) (via [llama-cpp-python](https://github.com/abetlen/llama-cpp-python)) | :heavy_check_mark: | `gguf`, `ggml` | [All models supported by llama.cpp](https://github.com/ggerganov/llama.cpp) | `generate_until`, `loglikelihood`, `loglikelihood_rolling` |
| vLLM | :heavy_check_mark: | `vllm` | [Most HF Causal Language Models](https://docs.vllm.ai/en/latest/models/supported_models.html) | `generate_until`, `loglikelihood`, `loglikelihood_rolling` |
| Your inference server here! | ... | ... | ... | ... |
| API or Inference Server | Implemented? | `--model <xxx>` name | Models supported: | Request Types: |
|---------------------------------------------------------------------------------------------------------------------------|---------------------------------|---------------------------------------------------------------------|-----------------------------------------------------------------------------------------------|------------------------------------------------------------|
| OpenAI Completions | :heavy_check_mark: | `openai-completions` | up to `code-davinci-002` | `generate_until`, `loglikelihood`, `loglikelihood_rolling` |
| OpenAI ChatCompletions | :heavy_check_mark: | `openai-chat-completions`, `local-chat-completions` | [All ChatCompletions API models](https://platform.openai.com/docs/guides/gpt) | `generate_until` (no logprobs) |
| Anthropic | :heavy_check_mark: | `anthropic` | [Supported Anthropic Engines](https://docs.anthropic.com/claude/reference/selecting-a-model) | `generate_until` (no logprobs) |
| Textsynth | :heavy_check_mark: | `textsynth` | [All supported engines](https://textsynth.com/documentation.html#engines) | `generate_until`, `loglikelihood`, `loglikelihood_rolling` |
| Cohere | [:hourglass: - blocked on Cohere API bug](https://github.com/EleutherAI/lm-evaluation-harness/pull/395) | N/A | [All `cohere.generate()` engines](https://docs.cohere.com/docs/models) | `generate_until`, `loglikelihood`, `loglikelihood_rolling` |
| [Llama.cpp](https://github.com/ggerganov/llama.cpp) (via [llama-cpp-python](https://github.com/abetlen/llama-cpp-python)) | :heavy_check_mark: | `gguf`, `ggml` | [All models supported by llama.cpp](https://github.com/ggerganov/llama.cpp) | `generate_until`, `loglikelihood`, `loglikelihood_rolling` |
| vLLM | :heavy_check_mark: | `vllm` | [Most HF Causal Language Models](https://docs.vllm.ai/en/latest/models/supported_models.html) | `generate_until`, `loglikelihood`, `loglikelihood_rolling` |
| Your local inference server! | :heavy_check_mark: | `local-chat-completions` (using `openai-chat-completions` model type) | Any server address that accepts POST requests using HF models and mirrors OpenAI's ChatCompletions interface | `generate_until` |

It is on our roadmap to create task variants designed to enable models that do not serve logprobs/loglikelihoods to be compared against the generation performance of open-source models.

75 changes: 45 additions & 30 deletions lm_eval/models/openai_completions.py
@@ -5,6 +5,7 @@
from importlib.util import find_spec
from typing import List, Optional, Tuple

import transformers
from tqdm import tqdm

from lm_eval import utils
@@ -104,7 +105,7 @@ def __init__(
self._max_gen_toks = max_gen_toks
self._max_length = max_length

# Read from environment variable OPENAI_API_SECRET_KEY
# Read from environment variable OPENAI_API_KEY
openai.api_key = os.environ["OPENAI_API_KEY"]

@property
@@ -353,15 +354,26 @@ async def _get_completions(**kwargs):
backoff_time *= 1.5


@register_model("openai-chat-completions")
@register_model("openai-chat-completions", "local-chat-completions")
class OpenaiChatCompletionsLM(LM):
def __init__(
self, model: str = "gpt-3.5-turbo", truncate: bool = False, batch_size: int = 1
self,
model: str = "gpt-3.5-turbo", # GPT model or Local model using HuggingFace model paths
base_url: str = None,
truncate: bool = False,
revision: Optional[str] = "main",
trust_remote_code: Optional[bool] = False,
use_fast_tokenizer: Optional[bool] = True,
**kwargs,
) -> None:
"""
:param model: str
Implements an OpenAI-style chat completion API for
accessing both OpenAI OR locally-hosted models using
HuggingFace Tokenizer
OpenAI API model (e.g. gpt-3.5-turbo)
using the **gen_kwargs passed on init
:param truncate: bool
Truncate input if too long (if False and input is too long, throw error)
"""
@@ -375,19 +387,34 @@ def __init__(
please install these via `pip install lm-eval[openai]` or `pip install -e .[openai]`",
)
self.model = model
self.frequency_penalty = 0
self.logit_bias = None
self.n = 1
self.presence_penalty = 0
self.temperature = 1
self.top_p = 1
self.tokenizer = tiktoken.encoding_for_model(self.model)
self.vocab_size = self.tokenizer.n_vocab
self.base_url = base_url
self.truncate = truncate
self.end_of_text_token_id = self.tokenizer.eot_token

# if we have a local model, use HF tokenizer over tiktoken
if self.base_url:
self.revision = revision
self.trust_remote_code = trust_remote_code
self.use_fast_tokenizer = use_fast_tokenizer

self.tokenizer = transformers.AutoTokenizer.from_pretrained(
self.model,
revision=self.revision,
trust_remote_code=self.trust_remote_code,
use_fast=self.use_fast_tokenizer,
)
self.vocab_size = self.tokenizer.vocab_size
self.end_of_text_token_id = self.tokenizer.eos_token_id
else:
self.tokenizer = tiktoken.encoding_for_model(self.model)
self.vocab_size = self.tokenizer.n_vocab
self.end_of_text_token_id = self.tokenizer.eot_token

# Read from environment variable OPENAI_API_KEY
self.client = openai.OpenAI() # openai.AsyncOpenAI()
# Set to EMPTY for local
if self.base_url:
self.client = openai.OpenAI(base_url=self.base_url)
else:
self.client = openai.OpenAI() # openai.AsyncOpenAI()

@property
def eot_token_id(self):
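
For context on the branch added above, a standalone sketch of the tokenizer selection, assuming `transformers` and `tiktoken` are installed and using illustrative model names: locally served models use their HuggingFace tokenizer, while hosted OpenAI models keep using `tiktoken`.

```python
# Illustrative sketch of the tokenizer selection logic above.
from typing import Optional

import tiktoken
import transformers


def build_tokenizer(model: str, base_url: Optional[str] = None):
    """Return (tokenizer, vocab_size, end_of_text_token_id) for a model name."""
    if base_url:
        # Locally hosted model: load its HuggingFace tokenizer.
        tok = transformers.AutoTokenizer.from_pretrained(model)
        return tok, tok.vocab_size, tok.eos_token_id
    # Hosted OpenAI model: use the matching tiktoken encoding.
    enc = tiktoken.encoding_for_model(model)
    return enc, enc.n_vocab, enc.eot_token


# e.g. build_tokenizer("facebook/opt-125m", base_url="http://localhost:8000/v1")
#      build_tokenizer("gpt-3.5-turbo")
```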
@@ -474,35 +501,23 @@ def sameuntil_chunks(xs, size):
until = None
if isinstance(gen_kwargs, dict):
kwargs = copy.deepcopy(gen_kwargs) # edge case for repeats > 1
if "do_sample" in kwargs.keys():
kwargs.pop("do_sample")
if "until" in kwargs.keys():
until = kwargs.pop("until")
if isinstance(until, str):
until = [until]
elif not isinstance(until, list):
raise ValueError(
f"Expected `kwargs['until']` to be of type Union[str,list] but got {until}"
f"Expected repr(kwargs['until']) to be of type Union[str, list] but got {until}"
)
else:
raise ValueError(
f"Expected `kwargs` to be of type `dict` but got {kwargs}"
f"Expected repr(kwargs) to be of type repr(dict) but got {kwargs}"
)

if "max_gen_toks" in kwargs.keys():
max_gen_toks = kwargs.pop("max_gen_toks")
else:
max_gen_toks = self.max_gen_toks

response = oa_chat_completion(
client=self.client,
messages=inps,
model=self.model,
frequency_penalty=self.frequency_penalty,
# logit_bias=self.logit_bias,
max_tokens=max_gen_toks,
n=self.n,
presence_penalty=self.presence_penalty,
temperature=self.temperature,
top_p=self.top_p,
client=self.client, messages=inps, model=self.model, **kwargs
)

for resp, (context, args_) in zip(response.choices, chunk):
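
To summarize the `generate_until` change, here is a sketch of how per-request `gen_kwargs` are normalized before being forwarded (the helper name is illustrative): `until`, `max_gen_toks`, and the HF-only `do_sample` flag are consumed by the harness, and everything else is passed straight through to `client.chat.completions.create`.

```python
# Illustrative helper mirroring the kwargs handling in generate_until above.
import copy


def normalize_gen_kwargs(gen_kwargs: dict, default_max_gen_toks: int = 256):
    kwargs = copy.deepcopy(gen_kwargs)
    kwargs.pop("do_sample", None)  # HF-only flag, not a ChatCompletions argument
    until = kwargs.pop("until", None)
    if isinstance(until, str):
        until = [until]  # normalize a single stop string to a list
    max_gen_toks = kwargs.pop("max_gen_toks", default_max_gen_toks)
    # Remaining kwargs (temperature, top_p, ...) are forwarded verbatim, e.g.:
    #   client.chat.completions.create(model=..., messages=..., **kwargs)
    return kwargs, until, max_gen_toks
```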
