Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

FEAT: OAuth system supports api-key #1168

Merged
merged 34 commits into from
Mar 28, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
34 commits
Select commit Hold shift + click to select a range
8846be1
Setup for Contributing Doc
Ago327 Mar 1, 2024
0911f23
Merge branch 'xorbitsai:main' into main
Ago327 Mar 4, 2024
2c1a96d
Merge branch 'xorbitsai:main' into main
Ago327 Mar 5, 2024
dead463
Merge branch 'xorbitsai:main' into main
Ago327 Mar 6, 2024
a764092
Merge branch 'xorbitsai:main' into main
Ago327 Mar 6, 2024
d421b37
Merge branch 'xorbitsai:main' into main
Ago327 Mar 6, 2024
eec0c74
Merge branch 'xorbitsai:main' into main
Ago327 Mar 7, 2024
cc9f55c
Merge branch 'xorbitsai:main' into main
Ago327 Mar 11, 2024
147e819
doc_development
Ago327 Mar 11, 2024
0339462
Merge branch 'xorbitsai:main' into main
Ago327 Mar 12, 2024
c4337eb
Merge branch 'xorbitsai:main' into main
Ago327 Mar 13, 2024
ab43e12
Merge branch 'xorbitsai:main' into main
Ago327 Mar 18, 2024
fd15809
Merge branch 'xorbitsai:main' into main
Ago327 Mar 19, 2024
8298504
init api-key check
Ago327 Mar 19, 2024
b299cb5
set api-key non-positional
Ago327 Mar 19, 2024
58ff640
Merge branch 'xorbitsai:main' into api-key-feature
Ago327 Mar 21, 2024
9c35307
client with api-key
Ago327 Mar 21, 2024
09f6f31
fix doc
Ago327 Mar 21, 2024
bc7754a
Merge branch 'xorbitsai:main' into api-key-feature
Ago327 Mar 21, 2024
2d0d980
Merge branch 'xorbitsai:main' into api-key-feature
Ago327 Mar 25, 2024
6dd83ac
compatible with both client and curl
Ago327 Mar 25, 2024
62ea196
Merge branch 'xorbitsai:main' into api-key-feature
Ago327 Mar 25, 2024
d6ee499
compatible with cmdline
Ago327 Mar 25, 2024
b4cb69f
fix debug output
Ago327 Mar 25, 2024
f3ada96
bug fix
Ago327 Mar 25, 2024
750ae2b
fix
Ago327 Mar 25, 2024
2e0d594
fix test
Ago327 Mar 26, 2024
b95b756
Merge branch 'xorbitsai:main' into api-key-feature
Ago327 Mar 28, 2024
aedb26c
test with openaiSDK
Ago327 Mar 28, 2024
edf6ffb
fix SDK
Ago327 Mar 28, 2024
aea8dcc
fix import and test key
Ago327 Mar 28, 2024
f50a5b2
fix embedding
Ago327 Mar 28, 2024
1daa967
embedding
Ago327 Mar 28, 2024
5acf72f
fix terminate
Ago327 Mar 28, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion doc/source/development/contributing_environment.rst
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@ Creating a development environment
Before proceeding with any code modifications, it's essential to set up the necessary environment for Xinference development,
which includes familiarizing yourself with Git usage, establishing an isolated environment, installing Xinference, and compiling the frontend.

Getting startted with Git
Getting started with Git
-------------------------

Now that you have identified an issue you wish to resolve, an enhancement to incorporate, or documentation to enhance,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@ msgid ""
msgstr ""
"Project-Id-Version: Xinference \n"
"Report-Msgid-Bugs-To: \n"
"POT-Creation-Date: 2024-03-06 16:29+0800\n"
"POT-Creation-Date: 2024-03-21 09:59+0800\n"
"PO-Revision-Date: YEAR-MO-DA HO:MI+ZONE\n"
"Last-Translator: FULL NAME <EMAIL@ADDRESS>\n"
"Language: zh_CN\n"
Expand Down Expand Up @@ -38,7 +38,7 @@ msgstr ""
"Xinference 以及前端部分的编译。"

#: ../../source/development/contributing_environment.rst:12
msgid "Getting startted with Git"
msgid "Getting started with Git"
msgstr "Git 的使用"

#: ../../source/development/contributing_environment.rst:14
Expand Down
65 changes: 47 additions & 18 deletions xinference/api/oauth2/auth_service.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,8 +11,9 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import re
from datetime import timedelta
from typing import List, Optional
from typing import List, Optional, Tuple

from fastapi import Depends, HTTPException, status
from fastapi.security import OAuth2PasswordBearer, SecurityScopes
Expand Down Expand Up @@ -40,13 +41,30 @@ def __init__(self, auth_config_file: Optional[str]):
def config(self):
return self._config

@staticmethod
def is_legal_api_key(key: str) -> bool:
pattern = re.compile("^sk-[a-zA-Z0-9]{13}$")
return re.match(pattern, key) is not None

def init_auth_config(self):
    """Load the auth startup configuration and validate its api-keys.

    When no auth config file was given, nothing is loaded and None is
    returned. Otherwise the file is parsed into an ``AuthStartupConfig``,
    every user's password is replaced by its hash, and every api-key is
    checked for legal shape and for uniqueness across all users.

    Returns:
        The parsed configuration, or None when no config file is set.

    Raises:
        ValueError: If an api-key is malformed or appears more than once.
    """
    if not self._auth_config_file:
        return None

    startup_config: AuthStartupConfig = parse_file_as(
        path=self._auth_config_file, type_=AuthStartupConfig
    )
    # Keys must be unique across users, so one set is shared by the whole loop.
    seen_keys = set()
    for account in startup_config.user_config:
        account.password = get_password_hash(account.password)
        for candidate in account.api_keys:
            if not self.is_legal_api_key(candidate):
                raise ValueError(
                    "Api-Key should be a string started with 'sk-' with a total length of 16"
                )
            if candidate in seen_keys:
                raise ValueError(
                    "Duplicate api-keys exists, please check your configuration"
                )
            seen_keys.add(candidate)
    return startup_config

def __call__(
Expand All @@ -67,28 +85,30 @@ def __call__(
headers={"WWW-Authenticate": authenticate_value},
)

try:
assert self._config is not None
payload = jwt.decode(
token,
self._config.auth_config.secret_key,
algorithms=[self._config.auth_config.algorithm],
options={"verify_exp": False}, # TODO: supports token expiration
)
username: str = payload.get("sub")
if username is None:
if self.is_legal_api_key(token):
user, token_scopes = self.get_user_and_scopes_with_api_key(token)
else:
try:
assert self._config is not None
payload = jwt.decode(
token,
self._config.auth_config.secret_key,
algorithms=[self._config.auth_config.algorithm],
options={"verify_exp": False}, # TODO: supports token expiration
)
username: str = payload.get("sub")
if username is None:
raise credentials_exception
token_scopes = payload.get("scopes", [])
user = self.get_user(username)
except (JWTError, ValidationError):
raise credentials_exception
token_scopes = payload.get("scopes", [])
token_data = TokenData(scopes=token_scopes, username=username)
except (JWTError, ValidationError):
raise credentials_exception
user = self.get_user(token_data.username)
if user is None:
raise credentials_exception
if "admin" in token_data.scopes:
if "admin" in token_scopes:
return user
for scope in security_scopes.scopes:
if scope not in token_data.scopes:
if scope not in token_scopes:
raise HTTPException(
status_code=status.HTTP_403_FORBIDDEN,
detail="Not enough permissions",
Expand All @@ -102,6 +122,15 @@ def get_user(self, username: str) -> Optional[User]:
return user
return None

def get_user_and_scopes_with_api_key(
    self, api_key: str
) -> Tuple[Optional[User], List]:
    """Look up the user that owns *api_key* together with that user's scopes.

    Users are searched in configuration order and the first one whose
    ``api_keys`` contains the given key wins.

    Args:
        api_key: The api-key presented by the caller.

    Returns:
        ``(user, user.permissions)`` for the owning user, or ``(None, [])``
        when no configured user has this key.
    """
    owner = next(
        (
            candidate
            for candidate in self._config.user_config
            if api_key in candidate.api_keys
        ),
        None,
    )
    if owner is None:
        return None, []
    return owner, owner.permissions

def authenticate_user(self, username: str, password: str):
user = self.get_user(username)
if not user:
Expand Down
1 change: 1 addition & 0 deletions xinference/api/oauth2/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@ class LoginUserForm(BaseModel):

class User(LoginUserForm):
permissions: List[str]
api_keys: List[str]


class AuthConfig(BaseModel):
Expand Down
6 changes: 4 additions & 2 deletions xinference/client/restful/restful_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -651,11 +651,13 @@ def translations(


class Client:
def __init__(self, base_url):
def __init__(self, base_url, api_key: Optional[str] = None):
self.base_url = base_url
self._headers = {}
self._headers: Dict[str, str] = {}
self._cluster_authed = False
self._check_cluster_authenticated()
if api_key is not None and self._cluster_authed:
self._headers["Authorization"] = f"Bearer {api_key}"

def _set_token(self, token: Optional[str]):
if not self._cluster_authed or token is None:
Expand Down
54 changes: 54 additions & 0 deletions xinference/client/tests/test_client_with_auth.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,3 +47,57 @@ def test_client_auth(setup_with_auth):
assert len(client.list_models()) == 1
client.terminate_model(model_uid=model_uid)
assert len(client.list_models()) == 0

# test with api-key
client = RESTfulClient(endpoint, api_key="sk-wrongapikey12")
with pytest.raises(RuntimeError):
client.list_models()

client = RESTfulClient(endpoint, api_key="sk-72tkvudyGLPMi")
assert len(client.list_models()) == 0

with pytest.raises(RuntimeError):
client.launch_model(model_name="bge-small-en-v1.5", model_type="embedding")

client = RESTfulClient(endpoint, api_key="sk-ZOTLIY4gt9w11")
model_uid = client.launch_model(
model_name="bge-small-en-v1.5", model_type="embedding"
)
model = client.get_model(model_uid=model_uid)
assert isinstance(model, RESTfulEmbeddingModelHandle)

completion = model.create_embedding("write a poem.")
assert len(completion["data"][0]["embedding"]) == 384

with pytest.raises(RuntimeError):
client.terminate_model(model_uid=model_uid)

client = RESTfulClient(endpoint, api_key="sk-3sjLbdwqAhhAF")
assert len(client.list_models()) == 1

# test with openai SDK
from openai import AuthenticationError, OpenAI, PermissionDeniedError

client_ai = OpenAI(base_url=endpoint + "/v1", api_key="sk-wrongapikey12")
with pytest.raises(AuthenticationError):
client_ai.models.list()

client_ai = OpenAI(base_url=endpoint + "/v1", api_key="sk-72tkvudyGLPMi")
assert len(client_ai.models.list().data) == 1
with pytest.raises(PermissionDeniedError):
chat_completion = client_ai.embeddings.create(
model="bge-small-en-v1.5",
input="write a poem.",
)

client_ai = OpenAI(base_url=endpoint + "/v1", api_key="sk-ZOTLIY4gt9w11")
chat_completion = client_ai.embeddings.create(
model="bge-small-en-v1.5",
input="write a poem.",
)
assert len(chat_completion.data[0].embedding) == 384

client_ai = OpenAI(base_url=endpoint + "/v1", api_key="sk-3sjLbdwqAhhAF")
client.terminate_model(model_uid)
assert len(client.list_models()) == 0
ChengjieLi28 marked this conversation as resolved.
Show resolved Hide resolved
assert len(client_ai.models.list().data) == 0
15 changes: 13 additions & 2 deletions xinference/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -261,12 +261,23 @@ def setup_with_auth():
if not cluster_health_check(supervisor_addr, max_attempts=10, sleep_interval=3):
raise RuntimeError("Cluster is not available after multiple attempts")

user1 = User(username="user1", password="pass1", permissions=["admin"])
user2 = User(username="user2", password="pass2", permissions=["models:list"])
user1 = User(
username="user1",
password="pass1",
permissions=["admin"],
api_keys=["sk-3sjLbdwqAhhAF", "sk-0HCRO1rauFQDL"],
)
user2 = User(
username="user2",
password="pass2",
permissions=["models:list"],
api_keys=["sk-72tkvudyGLPMi"],
)
user3 = User(
username="user3",
password="pass3",
permissions=["models:list", "models:read", "models:start"],
api_keys=["sk-m6jEzEwmCc4iQ", "sk-ZOTLIY4gt9w11"],
)
auth_config = AuthConfig(
algorithm="HS256",
Expand Down
Loading
Loading