diff --git a/README.md b/README.md
index 3f7b0385a..7064721e8 100644
--- a/README.md
+++ b/README.md
@@ -200,6 +200,10 @@ streamlit run torchchat.py -- browser llama3.1
 
 This mode gives a REST API that matches the OpenAI API spec for interacting with a model
 
+The server follows the [OpenAI API specification](https://platform.openai.com/docs/api-reference/chat) for chat completions.
+Since this feature is under active development, it's possible not every parameter is consumed. See api/api.py for details on
+which request parameters are implemented. If you encounter any issues, please comment on the [tracking Github issue](https://github.com/pytorch/torchchat/issues/973).
+
 To test out the REST API, **you'll need 2 terminals**: one to host the server, and one to send the request.
 
 In one terminal, start the server
@@ -213,8 +217,7 @@ python3 torchchat.py server llama3.1
 
 In another terminal, query the server using `curl`. Depending on the model configuration, this query might take a few minutes to respond.
 
-Setting `stream` to "true" in the request emits a response in chunks. Currently, this response
-is plaintext and will not be formatted to the OpenAI API specification. If `stream` is unset or not "true", then the client will await the full response from the server.
+Setting `stream` to "true" in the request emits a response in chunks. If `stream` is unset or not "true", then the client will await the full response from the server.
 
 **Example Input + Output**
 
@@ -227,6 +230,7 @@ curl http://127.0.0.1:5000/v1/chat \
   -d '{
     "model": "llama3.1",
     "stream": "true",
+    "max_tokens": 200,
     "messages": [
       {
         "role": "system",
diff --git a/api/api.py b/api/api.py
index bef0eb914..63135133b 100644
--- a/api/api.py
+++ b/api/api.py
@@ -19,6 +19,8 @@
 See https://platform.openai.com/docs/api-reference/chat for the full specification and details.
 """
 
+OPENAI_API_DEFAULT_MAX_TOKENS = 16
+
 # Message classes and associated objects - see the types of Messages under "Create Chat Completion >>> Request body >>> messages"
 
@@ -105,20 +107,20 @@ class CompletionRequest:
     logit_bias: Optional[Dict[str, float]] = None  # unimplemented
     logprobs: Optional[bool] = None  # unimplemented
     top_logprobs: Optional[int] = None  # unimplemented
-    max_tokens: Optional[int] = None  # unimplemented
+    max_tokens: Optional[int] = None
     n: int = 1
     presence_penalty: float = 0  # unimplemented
     response_format: Optional[ResponseFormat] = None  # unimplemented
-    seed: Optional[int] = None  # unimplemented
+    seed: Optional[int] = None
     service_tier: Optional[str] = None  # unimplemented
     stop: Optional[List[str]] = None  # unimplemented
     stream: bool = False
     stream_options: Optional[StreamOptions] = None  # unimplemented
-    temperature: Optional[float] = 1.0  # unimplemented
+    temperature: Optional[float] = 1.0
     top_p: Optional[float] = 1.0  # unimplemented
-    tools: Optional[List[Any]] = None  # unimplemented
-    tool_choice: Optional[Union[str, Any]] = None  # unimplemented
-    parallel_tool_calls: Optional[bool] = None  # unimplemented
+    tools: Optional[List[Any]] = None  # unimplemented - Assistant features
+    tool_choice: Optional[Union[str, Any]] = None  # unimplemented - Assistant features
+    parallel_tool_calls: Optional[bool] = None  # unimplemented - Assistant features
     user: Optional[str] = None  # unimplemented
 
@@ -229,9 +231,8 @@ def __init__(self, *args, **kwargs):
             else self.model.config.max_seq_length
         )
         # The System fingerprint is a unique identifier for the model and its configuration.
-        # Currently, this is not implemented in a
         self.system_fingerprint = (
-            self.builder_args.device + type(self.builder_args.precision).__name__
+            f"{self.builder_args.device}_{self.builder_args.precision}"
         )
 
     def chunked_completion(self, completion_request: CompletionRequest):
@@ -270,7 +271,13 @@ def chunked_completion(self, completion_request: CompletionRequest):
         )
         generator_args = GeneratorArgs(
             completion_request.messages[-1].get("content"),
+            max_new_tokens=(
+                int(completion_request.max_tokens)
+                if completion_request.max_tokens
+                else OPENAI_API_DEFAULT_MAX_TOKENS
+            ),
             encoded_prompt=encoded,
+            temperature=float(completion_request.temperature),
             chat_mode=False,
         )
 
@@ -295,6 +302,7 @@ def callback(x, *, done_generating=False):
             sequential_prefill=generator_args.sequential_prefill,
             start_pos=self.start_pos,
             max_seq_length=self.max_seq_length,
+            seed=int(completion_request.seed),
         ):
             if y is None:
                 continue
diff --git a/generate.py b/generate.py
index 5920bd656..48a77ba29 100644
--- a/generate.py
+++ b/generate.py
@@ -71,7 +71,7 @@ class GeneratorArgs:
     num_samples: int = 1
     max_new_tokens: int = 200
     top_k: int = 200
-    temperature: int = 0  # deterministic argmax
+    temperature: float = 0.0  # deterministic argmax if 0.0
     compile: bool = False
     compile_prefill: bool = False
     speculate_k: int = 5
@@ -105,9 +105,7 @@ def validate_build(
     def from_args(cls, args):
         dso_path = getattr(args, "dso_path", None)
         pte_path = getattr(args, "pte_path", None)
-        sequential_prefill = (
-            args.sequential_prefill or bool(dso_path) or bool(pte_path)
-        )
+        sequential_prefill = args.sequential_prefill or bool(dso_path) or bool(pte_path)
 
         return cls(
             prompt=getattr(args, "prompt", ""),
diff --git a/server.py b/server.py
index 2efff3304..106c30134 100644
--- a/server.py
+++ b/server.py
@@ -9,6 +9,8 @@
 from dataclasses import asdict
 from typing import Dict, List, Union
 
+import torch
+
 from api.api import CompletionRequest, OpenAiApiGenerator
 from api.models import get_model_info_list, retrieve_model_info
 
@@ -50,6 +52,8 @@ def chat_endpoint():
     """
     print(" === Completion Request ===")
 
+    if seed := request.args.get("seed"):
+        torch.manual_seed(int(seed))
     # Parse the request in to a CompletionRequest object
     data = request.get_json()
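
The diff above threads `max_tokens`, `temperature`, and `seed` from the request body into `GeneratorArgs` and the `generate()` call, with `OPENAI_API_DEFAULT_MAX_TOKENS` (16) as the fallback when `max_tokens` is omitted. As a rough illustration of how a client could exercise these parameters, here is a minimal sketch; the use of the `requests` library and the line-by-line handling of the streamed response are assumptions for illustration, not part of this change (the README example uses `curl`).

```python
# Minimal client sketch for the /v1/chat endpoint shown in the README diff.
# Assumes the torchchat server is running locally on port 5000, as in the README
# example; `requests` and the line-based reading of the streamed body are
# assumptions for illustration only.
import requests

payload = {
    "model": "llama3.1",
    "stream": "true",
    "max_tokens": 200,   # omit to fall back to OPENAI_API_DEFAULT_MAX_TOKENS (16)
    "temperature": 0.0,  # 0.0 maps to deterministic argmax decoding in GeneratorArgs
    "seed": 42,          # forwarded to generate(); server.py also accepts a ?seed= query parameter
    "messages": [
        {"role": "system", "content": "You are a helpful assistant."},
        {"role": "user", "content": "Write a haiku about REST APIs."},
    ],
}

# With stream set to "true", the server emits the completion in chunks.
with requests.post(
    "http://127.0.0.1:5000/v1/chat", json=payload, stream=True
) as response:
    for chunk in response.iter_lines():
        if chunk:
            print(chunk.decode("utf-8"))
```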