Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

FEAT: OAuth system supports api-key #1168

Merged
merged 34 commits into from
Mar 28, 2024
Merged
Show file tree
Hide file tree
Changes from 25 commits
Commits
Show all changes
34 commits
Select commit Hold shift + click to select a range
8846be1
Setup for Contributing Doc
Ago327 Mar 1, 2024
0911f23
Merge branch 'xorbitsai:main' into main
Ago327 Mar 4, 2024
2c1a96d
Merge branch 'xorbitsai:main' into main
Ago327 Mar 5, 2024
dead463
Merge branch 'xorbitsai:main' into main
Ago327 Mar 6, 2024
a764092
Merge branch 'xorbitsai:main' into main
Ago327 Mar 6, 2024
d421b37
Merge branch 'xorbitsai:main' into main
Ago327 Mar 6, 2024
eec0c74
Merge branch 'xorbitsai:main' into main
Ago327 Mar 7, 2024
cc9f55c
Merge branch 'xorbitsai:main' into main
Ago327 Mar 11, 2024
147e819
doc_development
Ago327 Mar 11, 2024
0339462
Merge branch 'xorbitsai:main' into main
Ago327 Mar 12, 2024
c4337eb
Merge branch 'xorbitsai:main' into main
Ago327 Mar 13, 2024
ab43e12
Merge branch 'xorbitsai:main' into main
Ago327 Mar 18, 2024
fd15809
Merge branch 'xorbitsai:main' into main
Ago327 Mar 19, 2024
8298504
init api-key check
Ago327 Mar 19, 2024
b299cb5
set api-key non-positional
Ago327 Mar 19, 2024
58ff640
Merge branch 'xorbitsai:main' into api-key-feature
Ago327 Mar 21, 2024
9c35307
client with api-key
Ago327 Mar 21, 2024
09f6f31
fix doc
Ago327 Mar 21, 2024
bc7754a
Merge branch 'xorbitsai:main' into api-key-feature
Ago327 Mar 21, 2024
2d0d980
Merge branch 'xorbitsai:main' into api-key-feature
Ago327 Mar 25, 2024
6dd83ac
compatible with both client and curl
Ago327 Mar 25, 2024
62ea196
Merge branch 'xorbitsai:main' into api-key-feature
Ago327 Mar 25, 2024
d6ee499
compatible with cmdline
Ago327 Mar 25, 2024
b4cb69f
fix debug output
Ago327 Mar 25, 2024
f3ada96
bug fix
Ago327 Mar 25, 2024
750ae2b
fix
Ago327 Mar 25, 2024
2e0d594
fix test
Ago327 Mar 26, 2024
b95b756
Merge branch 'xorbitsai:main' into api-key-feature
Ago327 Mar 28, 2024
aedb26c
test with openaiSDK
Ago327 Mar 28, 2024
edf6ffb
fix SDK
Ago327 Mar 28, 2024
aea8dcc
fix import and test key
Ago327 Mar 28, 2024
f50a5b2
fix embedding
Ago327 Mar 28, 2024
1daa967
embedding
Ago327 Mar 28, 2024
5acf72f
fix terminate
Ago327 Mar 28, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion doc/source/development/contributing_environment.rst
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@ Creating a development environment
Before proceeding with any code modifications, it's essential to set up the necessary environment for Xinference development,
which includes familiarizing yourself with Git usage, establishing an isolated environment, installing Xinference, and compiling the frontend.

Getting startted with Git
Getting started with Git
-------------------------

Now that you have identified an issue you wish to resolve, an enhancement to incorporate, or documentation to enhance,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@ msgid ""
msgstr ""
"Project-Id-Version: Xinference \n"
"Report-Msgid-Bugs-To: \n"
"POT-Creation-Date: 2024-03-06 16:29+0800\n"
"POT-Creation-Date: 2024-03-21 09:59+0800\n"
"PO-Revision-Date: YEAR-MO-DA HO:MI+ZONE\n"
"Last-Translator: FULL NAME <EMAIL@ADDRESS>\n"
"Language: zh_CN\n"
Expand Down Expand Up @@ -38,7 +38,7 @@ msgstr ""
"Xinference 以及前端部分的编译。"

#: ../../source/development/contributing_environment.rst:12
msgid "Getting startted with Git"
msgid "Getting started with Git"
msgstr "Git 的使用"

#: ../../source/development/contributing_environment.rst:14
Expand Down
63 changes: 51 additions & 12 deletions xinference/api/oauth2/auth_service.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import re
from datetime import timedelta
from typing import List, Optional

Expand Down Expand Up @@ -40,13 +41,33 @@ def __init__(self, auth_config_file: Optional[str]):
def config(self):
return self._config

@staticmethod
def is_legal_api_key(key: str):
ChengjieLi28 marked this conversation as resolved.
Show resolved Hide resolved
pattern = re.compile("^[sk]{2}-[a-zA-Z0-9]{48}$")
ChengjieLi28 marked this conversation as resolved.
Show resolved Hide resolved
if re.match(pattern, key):
return True
else:
return False
ChengjieLi28 marked this conversation as resolved.
Show resolved Hide resolved

def init_auth_config(self):
if self._auth_config_file:
config: AuthStartupConfig = parse_file_as(
path=self._auth_config_file, type_=AuthStartupConfig
)
total_keys = set()
ChengjieLi28 marked this conversation as resolved.
Show resolved Hide resolved
for user in config.user_config:
user.password = get_password_hash(user.password)
if len(set(user.api_keys)) != len(user.api_keys):
raise ValueError("User has duplicate Api-Keys")
ChengjieLi28 marked this conversation as resolved.
Show resolved Hide resolved
for api_key in user.api_keys:
if not self.is_legal_api_key(api_key):
raise ValueError(
"Api-Key should be a string started with 'sk-' with a total length of 51"
)
ChengjieLi28 marked this conversation as resolved.
Show resolved Hide resolved
if api_key in total_keys:
raise ValueError("Api-Keys of different users have conflict")
ChengjieLi28 marked this conversation as resolved.
Show resolved Hide resolved
else:
total_keys.add(api_key)
return config

def __call__(
Expand All @@ -67,22 +88,33 @@ def __call__(
headers={"WWW-Authenticate": authenticate_value},
)

through_api_key = False
ChengjieLi28 marked this conversation as resolved.
Show resolved Hide resolved

try:
assert self._config is not None
payload = jwt.decode(
token,
self._config.auth_config.secret_key,
algorithms=[self._config.auth_config.algorithm],
options={"verify_exp": False}, # TODO: supports token expiration
)
username: str = payload.get("sub")
if username is None:
raise credentials_exception
token_scopes = payload.get("scopes", [])
token_data = TokenData(scopes=token_scopes, username=username)
if self.is_legal_api_key(token):
through_api_key = True
else:
payload = jwt.decode(
token,
self._config.auth_config.secret_key,
algorithms=[self._config.auth_config.algorithm],
options={"verify_exp": False}, # TODO: supports token expiration
)
username: str = payload.get("sub")
if username is None:
raise credentials_exception
token_scopes = payload.get("scopes", [])
token_data = TokenData(scopes=token_scopes, username=username)
except (JWTError, ValidationError):
raise credentials_exception
user = self.get_user(token_data.username)
if not through_api_key:
user = self.get_user(token_data.username)
else:
user = self.get_user_with_api_key(token)
ChengjieLi28 marked this conversation as resolved.
Show resolved Hide resolved
if user is None:
raise credentials_exception
token_data = TokenData(scopes=user.permissions, username=user.username)
ChengjieLi28 marked this conversation as resolved.
Show resolved Hide resolved
if user is None:
raise credentials_exception
if "admin" in token_data.scopes:
Expand All @@ -102,6 +134,13 @@ def get_user(self, username: str) -> Optional[User]:
return user
return None

def get_user_with_api_key(self, api_key: str) -> Optional[User]:
for user in self._config.user_config:
for key in user.api_keys:
if api_key == key:
return user
return None

def authenticate_user(self, username: str, password: str):
user = self.get_user(username)
if not user:
Expand Down
1 change: 1 addition & 0 deletions xinference/api/oauth2/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@ class LoginUserForm(BaseModel):

class User(LoginUserForm):
permissions: List[str]
api_keys: List[str]


class AuthConfig(BaseModel):
Expand Down
6 changes: 4 additions & 2 deletions xinference/client/restful/restful_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -651,11 +651,13 @@ def translations(


class Client:
def __init__(self, base_url):
def __init__(self, base_url, api_key: Optional[str] = None):
self.base_url = base_url
self._headers = {}
self._headers: Dict[str, str] = {}
self._cluster_authed = False
self._check_cluster_authenticated()
if api_key is not None:
ChengjieLi28 marked this conversation as resolved.
Show resolved Hide resolved
self._headers["Authorization"] = f"Bearer {api_key}"

def _set_token(self, token: Optional[str]):
if not self._cluster_authed or token is None:
Expand Down
119 changes: 99 additions & 20 deletions xinference/deploy/cmdline.py
Original file line number Diff line number Diff line change
Expand Up @@ -376,18 +376,27 @@ def worker(
is_flag=True,
help="Persist the model configuration to the filesystem, retains the model registration after server restarts.",
)
@click.option(
"--api-key",
"-ak",
default=None,
type=str,
help="Api-Key for access xinference api with authorization.",
)
def register_model(
endpoint: Optional[str],
model_type: str,
file: str,
persist: bool,
api_key: Optional[str],
):
endpoint = get_endpoint(endpoint)
with open(file) as fd:
model = fd.read()

client = RESTfulClient(base_url=endpoint)
client._set_token(get_stored_token(endpoint, client))
client = RESTfulClient(base_url=endpoint, api_key=api_key)
if client._get_token() is None:
ChengjieLi28 marked this conversation as resolved.
Show resolved Hide resolved
client._set_token(get_stored_token(endpoint, client))
client.register_model(
model_type=model_type,
model=model,
Expand All @@ -408,15 +417,24 @@ def register_model(
help="Type of model to unregister (default is 'LLM').",
)
@click.option("--model-name", "-n", type=str, help="Name of the model to unregister.")
@click.option(
"--api-key",
"-ak",
default=None,
type=str,
help="Api-Key for access xinference api with authorization.",
)
def unregister_model(
endpoint: Optional[str],
model_type: str,
model_name: str,
api_key: Optional[str],
):
endpoint = get_endpoint(endpoint)

client = RESTfulClient(base_url=endpoint)
client._set_token(get_stored_token(endpoint, client))
client = RESTfulClient(base_url=endpoint, api_key=api_key)
if client._get_token() is None:
client._set_token(get_stored_token(endpoint, client))
client.unregister_model(
model_type=model_type,
model_name=model_name,
Expand All @@ -437,15 +455,24 @@ def unregister_model(
type=str,
help="Filter by model type (default is 'LLM').",
)
@click.option(
"--api-key",
"-ak",
default=None,
type=str,
help="Api-Key for access xinference api with authorization.",
)
def list_model_registrations(
endpoint: Optional[str],
model_type: str,
api_key: Optional[str],
):
from tabulate import tabulate

endpoint = get_endpoint(endpoint)
client = RESTfulClient(base_url=endpoint)
client._set_token(get_stored_token(endpoint, client))
client = RESTfulClient(base_url=endpoint, api_key=api_key)
if client._get_token() is None:
client._set_token(get_stored_token(endpoint, client))

registrations = client.list_model_registrations(model_type=model_type)

Expand Down Expand Up @@ -638,6 +665,13 @@ def list_model_registrations(
type=bool,
help="Whether or not to allow for custom models defined on the Hub in their own modeling files.",
)
@click.option(
"--api-key",
"-ak",
default=None,
type=str,
help="Api-Key for access xinference api with authorization.",
)
@click.pass_context
def model_launch(
ctx,
Expand All @@ -654,6 +688,7 @@ def model_launch(
image_lora_load_kwargs: Optional[Tuple],
image_lora_fuse_kwargs: Optional[Tuple],
trust_remote_code: bool,
api_key: Optional[str],
):
kwargs = {}
for i in range(0, len(ctx.args), 2):
Expand Down Expand Up @@ -686,8 +721,9 @@ def model_launch(
if size_in_billions is None or "_" in size_in_billions
else int(size_in_billions)
)
client = RESTfulClient(base_url=endpoint)
client._set_token(get_stored_token(endpoint, client))
client = RESTfulClient(base_url=endpoint, api_key=api_key)
if client._get_token() is None:
client._set_token(get_stored_token(endpoint, client))

model_uid = client.launch_model(
model_name=model_name,
Expand Down Expand Up @@ -718,12 +754,20 @@ def model_launch(
type=str,
help="Xinference endpoint.",
)
def model_list(endpoint: Optional[str]):
@click.option(
"--api-key",
"-ak",
default=None,
type=str,
help="Api-Key for access xinference api with authorization.",
)
def model_list(endpoint: Optional[str], api_key: Optional[str]):
from tabulate import tabulate

endpoint = get_endpoint(endpoint)
client = RESTfulClient(base_url=endpoint)
client._set_token(get_stored_token(endpoint, client))
client = RESTfulClient(base_url=endpoint, api_key=api_key)
if client._get_token() is None:
client._set_token(get_stored_token(endpoint, client))

llm_table = []
embedding_table = []
Expand Down Expand Up @@ -844,13 +888,22 @@ def model_list(endpoint: Optional[str]):
required=True,
help="The unique identifier (UID) of the model.",
)
@click.option(
"--api-key",
"-ak",
default=None,
type=str,
help="Api-Key for access xinference api with authorization.",
)
def model_terminate(
endpoint: Optional[str],
model_uid: str,
api_key: Optional[str],
):
endpoint = get_endpoint(endpoint)
client = RESTfulClient(base_url=endpoint)
client._set_token(get_stored_token(endpoint, client))
client = RESTfulClient(base_url=endpoint, api_key=api_key)
if client._get_token() is None:
client._set_token(get_stored_token(endpoint, client))
client.terminate_model(model_uid=model_uid)


Expand All @@ -873,15 +926,24 @@ def model_terminate(
type=bool,
help="Whether to stream the generated text. Use 'True' for streaming (default is True).",
)
@click.option(
"--api-key",
"-ak",
default=None,
type=str,
help="Api-Key for access xinference api with authorization.",
)
def model_generate(
endpoint: Optional[str],
model_uid: str,
max_tokens: int,
stream: bool,
api_key: Optional[str],
):
endpoint = get_endpoint(endpoint)
client = RESTfulClient(base_url=endpoint)
client._set_token(get_stored_token(endpoint, client))
client = RESTfulClient(base_url=endpoint, api_key=api_key)
if client._get_token() is None:
client._set_token(get_stored_token(endpoint, client))
if stream:
# TODO: when stream=True, RestfulClient cannot generate words one by one.
# So use Client in temporary. The implementation needs to be changed to
Expand Down Expand Up @@ -959,16 +1021,25 @@ async def generate_internal():
type=bool,
help="Whether to stream the chat messages. Use 'True' for streaming (default is True).",
)
@click.option(
"--api-key",
"-ak",
default=None,
type=str,
help="Api-Key for access xinference api with authorization.",
)
def model_chat(
endpoint: Optional[str],
model_uid: str,
max_tokens: int,
stream: bool,
api_key: Optional[str],
):
# TODO: chat model roles may not be user and assistant.
endpoint = get_endpoint(endpoint)
client = RESTfulClient(base_url=endpoint)
client._set_token(get_stored_token(endpoint, client))
client = RESTfulClient(base_url=endpoint, api_key=api_key)
if client._get_token() is None:
client._set_token(get_stored_token(endpoint, client))

chat_history: "List[ChatCompletionMessage]" = []
if stream:
Expand Down Expand Up @@ -1048,10 +1119,18 @@ async def chat_internal():

@cli.command("vllm-models", help="Query and display models compatible with vLLM.")
@click.option("--endpoint", "-e", type=str, help="Xinference endpoint.")
def vllm_models(endpoint: Optional[str]):
@click.option(
"--api-key",
"-ak",
default=None,
type=str,
help="Api-Key for access xinference api with authorization.",
)
def vllm_models(endpoint: Optional[str], api_key: Optional[str]):
endpoint = get_endpoint(endpoint)
client = RESTfulClient(base_url=endpoint)
client._set_token(get_stored_token(endpoint, client))
client = RESTfulClient(base_url=endpoint, api_key=api_key)
if client._get_token() is None:
client._set_token(get_stored_token(endpoint, client))
vllm_models_dict = client.vllm_models()
print("VLLM supported model families:")
chat_models = vllm_models_dict["chat"]
Expand Down
Loading