Skip to content

Commit

Permalink
FEAT: OAuth system supports api-key (#1168)
Browse files Browse the repository at this point in the history
  • Loading branch information
Ago327 committed Mar 28, 2024
1 parent 5a83dd5 commit b44c2ce
Show file tree
Hide file tree
Showing 8 changed files with 221 additions and 45 deletions.
2 changes: 1 addition & 1 deletion doc/source/development/contributing_environment.rst
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@ Creating a development environment
Before proceeding with any code modifications, it's essential to set up the necessary environment for Xinference development,
which includes familiarizing yourself with Git usage, establishing an isolated environment, installing Xinference, and compiling the frontend.

Getting startted with Git
Getting started with Git
-------------------------

Now that you have identified an issue you wish to resolve, an enhancement to incorporate, or documentation to enhance,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@ msgid ""
msgstr ""
"Project-Id-Version: Xinference \n"
"Report-Msgid-Bugs-To: \n"
"POT-Creation-Date: 2024-03-06 16:29+0800\n"
"POT-Creation-Date: 2024-03-21 09:59+0800\n"
"PO-Revision-Date: YEAR-MO-DA HO:MI+ZONE\n"
"Last-Translator: FULL NAME <EMAIL@ADDRESS>\n"
"Language: zh_CN\n"
Expand Down Expand Up @@ -38,7 +38,7 @@ msgstr ""
"Xinference 以及前端部分的编译。"

#: ../../source/development/contributing_environment.rst:12
msgid "Getting startted with Git"
msgid "Getting started with Git"
msgstr "Git 的使用"

#: ../../source/development/contributing_environment.rst:14
Expand Down
65 changes: 47 additions & 18 deletions xinference/api/oauth2/auth_service.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,8 +11,9 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import re
from datetime import timedelta
from typing import List, Optional
from typing import List, Optional, Tuple

from fastapi import Depends, HTTPException, status
from fastapi.security import OAuth2PasswordBearer, SecurityScopes
Expand Down Expand Up @@ -40,13 +41,30 @@ def __init__(self, auth_config_file: Optional[str]):
def config(self):
return self._config

@staticmethod
def is_legal_api_key(key: str) -> bool:
pattern = re.compile("^sk-[a-zA-Z0-9]{13}$")
return re.match(pattern, key) is not None

def init_auth_config(self):
    """Load and validate the auth startup configuration.

    When no auth config file is set, nothing is loaded and ``None`` is
    returned. Otherwise the file is parsed into an ``AuthStartupConfig``,
    every user's password is replaced by the result of
    ``get_password_hash``, and every api-key is checked.

    Raises:
        ValueError: if an api-key is malformed, or if the same api-key
            appears more than once anywhere in the configuration.
    """
    if not self._auth_config_file:
        return None

    config: AuthStartupConfig = parse_file_as(
        path=self._auth_config_file, type_=AuthStartupConfig
    )

    # Keys must be unique across all users, not merely within one user,
    # so a single set is shared by the whole loop.
    seen_keys = set()
    for user in config.user_config:
        user.password = get_password_hash(user.password)
        for candidate in user.api_keys:
            if not self.is_legal_api_key(candidate):
                raise ValueError(
                    "Api-Key should be a string started with 'sk-' with a total length of 16"
                )
            if candidate in seen_keys:
                raise ValueError(
                    "Duplicate api-keys exists, please check your configuration"
                )
            seen_keys.add(candidate)
    return config

def __call__(
Expand All @@ -67,28 +85,30 @@ def __call__(
headers={"WWW-Authenticate": authenticate_value},
)

try:
assert self._config is not None
payload = jwt.decode(
token,
self._config.auth_config.secret_key,
algorithms=[self._config.auth_config.algorithm],
options={"verify_exp": False}, # TODO: supports token expiration
)
username: str = payload.get("sub")
if username is None:
if self.is_legal_api_key(token):
user, token_scopes = self.get_user_and_scopes_with_api_key(token)
else:
try:
assert self._config is not None
payload = jwt.decode(
token,
self._config.auth_config.secret_key,
algorithms=[self._config.auth_config.algorithm],
options={"verify_exp": False}, # TODO: supports token expiration
)
username: str = payload.get("sub")
if username is None:
raise credentials_exception
token_scopes = payload.get("scopes", [])
user = self.get_user(username)
except (JWTError, ValidationError):
raise credentials_exception
token_scopes = payload.get("scopes", [])
token_data = TokenData(scopes=token_scopes, username=username)
except (JWTError, ValidationError):
raise credentials_exception
user = self.get_user(token_data.username)
if user is None:
raise credentials_exception
if "admin" in token_data.scopes:
if "admin" in token_scopes:
return user
for scope in security_scopes.scopes:
if scope not in token_data.scopes:
if scope not in token_scopes:
raise HTTPException(
status_code=status.HTTP_403_FORBIDDEN,
detail="Not enough permissions",
Expand All @@ -102,6 +122,15 @@ def get_user(self, username: str) -> Optional[User]:
return user
return None

def get_user_and_scopes_with_api_key(
    self, api_key: str
) -> Tuple[Optional["User"], List]:
    """Look up the user owning *api_key* and that user's permissions.

    Args:
        api_key: the api-key presented by the caller.

    Returns:
        ``(user, user.permissions)`` for the first configured user that
        lists *api_key* among its ``api_keys``; ``(None, [])`` when no
        user matches, when no auth configuration is loaded, or when
        *api_key* is not a string.
    """
    # Local import keeps this block self-contained.
    from hmac import compare_digest

    # Without a loaded configuration there is nothing to match against;
    # report "no such user" instead of failing on ``None.user_config``.
    if self._config is None or not isinstance(api_key, str):
        return None, []

    presented = api_key.encode("utf-8")
    for user in self._config.user_config:
        for key in user.api_keys:
            # Constant-time comparison: the presented key is untrusted
            # input being compared against a secret. Comparing bytes
            # avoids compare_digest's ASCII-only restriction on str.
            if compare_digest(presented, key.encode("utf-8")):
                return user, user.permissions
    return None, []

def authenticate_user(self, username: str, password: str):
user = self.get_user(username)
if not user:
Expand Down
1 change: 1 addition & 0 deletions xinference/api/oauth2/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@ class LoginUserForm(BaseModel):

class User(LoginUserForm):
    """A configured user: the ``LoginUserForm`` fields plus permissions and api-keys."""

    # Scope strings granted to the user, e.g. "admin" or "models:list".
    permissions: List[str]
    # Api-keys that authenticate as this user, e.g. "sk-" followed by 13
    # alphanumeric characters. Declared without a default.
    api_keys: List[str]


class AuthConfig(BaseModel):
Expand Down
6 changes: 4 additions & 2 deletions xinference/client/restful/restful_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -651,11 +651,13 @@ def translations(


class Client:
def __init__(self, base_url):
def __init__(self, base_url, api_key: Optional[str] = None):
self.base_url = base_url
self._headers = {}
self._headers: Dict[str, str] = {}
self._cluster_authed = False
self._check_cluster_authenticated()
if api_key is not None and self._cluster_authed:
self._headers["Authorization"] = f"Bearer {api_key}"

def _set_token(self, token: Optional[str]):
if not self._cluster_authed or token is None:
Expand Down
54 changes: 54 additions & 0 deletions xinference/client/tests/test_client_with_auth.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,3 +47,57 @@ def test_client_auth(setup_with_auth):
assert len(client.list_models()) == 1
client.terminate_model(model_uid=model_uid)
assert len(client.list_models()) == 0

# test with api-key
client = RESTfulClient(endpoint, api_key="sk-wrongapikey12")
with pytest.raises(RuntimeError):
client.list_models()

client = RESTfulClient(endpoint, api_key="sk-72tkvudyGLPMi")
assert len(client.list_models()) == 0

with pytest.raises(RuntimeError):
client.launch_model(model_name="bge-small-en-v1.5", model_type="embedding")

client = RESTfulClient(endpoint, api_key="sk-ZOTLIY4gt9w11")
model_uid = client.launch_model(
model_name="bge-small-en-v1.5", model_type="embedding"
)
model = client.get_model(model_uid=model_uid)
assert isinstance(model, RESTfulEmbeddingModelHandle)

completion = model.create_embedding("write a poem.")
assert len(completion["data"][0]["embedding"]) == 384

with pytest.raises(RuntimeError):
client.terminate_model(model_uid=model_uid)

client = RESTfulClient(endpoint, api_key="sk-3sjLbdwqAhhAF")
assert len(client.list_models()) == 1

# test with openai SDK
from openai import AuthenticationError, OpenAI, PermissionDeniedError

client_ai = OpenAI(base_url=endpoint + "/v1", api_key="sk-wrongapikey12")
with pytest.raises(AuthenticationError):
client_ai.models.list()

client_ai = OpenAI(base_url=endpoint + "/v1", api_key="sk-72tkvudyGLPMi")
assert len(client_ai.models.list().data) == 1
with pytest.raises(PermissionDeniedError):
chat_completion = client_ai.embeddings.create(
model="bge-small-en-v1.5",
input="write a poem.",
)

client_ai = OpenAI(base_url=endpoint + "/v1", api_key="sk-ZOTLIY4gt9w11")
chat_completion = client_ai.embeddings.create(
model="bge-small-en-v1.5",
input="write a poem.",
)
assert len(chat_completion.data[0].embedding) == 384

client_ai = OpenAI(base_url=endpoint + "/v1", api_key="sk-3sjLbdwqAhhAF")
client.terminate_model(model_uid)
assert len(client.list_models()) == 0
assert len(client_ai.models.list().data) == 0
15 changes: 13 additions & 2 deletions xinference/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -261,12 +261,23 @@ def setup_with_auth():
if not cluster_health_check(supervisor_addr, max_attempts=10, sleep_interval=3):
raise RuntimeError("Cluster is not available after multiple attempts")

user1 = User(username="user1", password="pass1", permissions=["admin"])
user2 = User(username="user2", password="pass2", permissions=["models:list"])
user1 = User(
username="user1",
password="pass1",
permissions=["admin"],
api_keys=["sk-3sjLbdwqAhhAF", "sk-0HCRO1rauFQDL"],
)
user2 = User(
username="user2",
password="pass2",
permissions=["models:list"],
api_keys=["sk-72tkvudyGLPMi"],
)
user3 = User(
username="user3",
password="pass3",
permissions=["models:list", "models:read", "models:start"],
api_keys=["sk-m6jEzEwmCc4iQ", "sk-ZOTLIY4gt9w11"],
)
auth_config = AuthConfig(
algorithm="HS256",
Expand Down
Loading

0 comments on commit b44c2ce

Please sign in to comment.