From 8969a1063d4848a651de11bacf66c18f11a620b4 Mon Sep 17 00:00:00 2001 From: Dirk Kulawiak Date: Thu, 3 Apr 2025 12:00:19 +0200 Subject: [PATCH 1/3] Fix kwargs vs args --- weaviate/users/sync.pyi | 12 ++++++------ weaviate/users/users.py | 12 ++++++------ 2 files changed, 12 insertions(+), 12 deletions(-) diff --git a/weaviate/users/sync.pyi b/weaviate/users/sync.pyi index f93915e96..240a790ce 100644 --- a/weaviate/users/sync.pyi +++ b/weaviate/users/sync.pyi @@ -7,15 +7,15 @@ from typing_extensions import deprecated class _UsersOIDC(_UsersInit): @overload def get_assigned_roles( - self, user_id: str, include_permissions: Literal[False] = False + self, *, user_id: str, include_permissions: Literal[False] = False ) -> Dict[str, RoleBase]: ... @overload def get_assigned_roles( - self, user_id: str, include_permissions: Literal[True] + self, *, user_id: str, include_permissions: Literal[True] ) -> Dict[str, Role]: ... @overload def get_assigned_roles( - self, user_id: str, include_permissions: bool = False + self, *, user_id: str, include_permissions: bool = False ) -> Union[Dict[str, Role], Dict[str, RoleBase]]: ... def assign_roles(self, *, user_id: str, role_names: Union[str, List[str]]) -> None: ... def revoke_roles(self, *, user_id: str, role_names: Union[str, List[str]]) -> None: ... @@ -23,15 +23,15 @@ class _UsersOIDC(_UsersInit): class _UsersDB(_UsersInit): @overload def get_assigned_roles( - self, user_id: str, include_permissions: Literal[False] = False + self, *, user_id: str, include_permissions: Literal[False] = False ) -> Dict[str, RoleBase]: ... @overload def get_assigned_roles( - self, user_id: str, include_permissions: Literal[True] + self, *, user_id: str, include_permissions: Literal[True] ) -> Dict[str, Role]: ... @overload def get_assigned_roles( - self, user_id: str, include_permissions: bool = False + self, *, user_id: str, include_permissions: bool = False ) -> Union[Dict[str, Role], Dict[str, RoleBase]]: ... def assign_roles(self, *, user_id: str, role_names: Union[str, List[str]]) -> None: ... def revoke_roles(self, *, user_id: str, role_names: Union[str, List[str]]) -> None: ... diff --git a/weaviate/users/users.py b/weaviate/users/users.py index 6b31560f1..64e7ed27e 100644 --- a/weaviate/users/users.py +++ b/weaviate/users/users.py @@ -173,16 +173,16 @@ async def _list_all_users(self) -> List[WeaviateDBUserRoleNames]: class _UserDBAsync(_UsersBase): @overload async def get_assigned_roles( - self, user_id: str, include_permissions: Literal[False] = ... + self, *, user_id: str, include_permissions: Literal[False] = ... ) -> Dict[str, RoleBase]: ... @overload async def get_assigned_roles( - self, user_id: str, include_permissions: Literal[True] = ... + self, *, user_id: str, include_permissions: Literal[True] = ... ) -> Dict[str, Role]: ... async def get_assigned_roles( - self, user_id: str, include_permissions: bool = False + self, *, user_id: str, include_permissions: bool = False ) -> Union[Dict[str, Role], Dict[str, RoleBase]]: """Get the roles assigned to a user. @@ -292,16 +292,16 @@ async def list_all(self) -> List[UserDB]: class _UserOIDCAsync(_UsersBase): @overload async def get_assigned_roles( - self, user_id: str, include_permissions: Literal[False] = ... + self, *, user_id: str, include_permissions: Literal[False] = ... ) -> Dict[str, RoleBase]: ... @overload async def get_assigned_roles( - self, user_id: str, include_permissions: Literal[True] = ... + self, *, user_id: str, include_permissions: Literal[True] = ... ) -> Dict[str, Role]: ... async def get_assigned_roles( - self, user_id: str, include_permissions: bool = False + self, *, user_id: str, include_permissions: bool = False ) -> Union[Dict[str, Role], Dict[str, RoleBase]]: """Get the roles assigned to a user. From f5ed8416f81a972a7066edfafc75cefd2335925a Mon Sep 17 00:00:00 2001 From: Dirk Kulawiak Date: Thu, 3 Apr 2025 12:14:20 +0200 Subject: [PATCH 2/3] Add missing revoke key --- integration/test_users.py | 36 ++++++++++++++++++++++++++++++++++++ weaviate/users/sync.pyi | 2 +- weaviate/users/users.py | 9 +++++---- 3 files changed, 42 insertions(+), 5 deletions(-) diff --git a/integration/test_users.py b/integration/test_users.py index 602f14f2d..a8a69e93a 100644 --- a/integration/test_users.py +++ b/integration/test_users.py @@ -154,6 +154,42 @@ def test_de_activate(client_factory: ClientFactory) -> None: client.users.db.delete(user_id=randomUserName) +def test_deactivate_and_revoke(client_factory: ClientFactory) -> None: + with client_factory(ports=RBAC_PORTS, auth_credentials=Auth.api_key("admin-key")) as client: + if client._connection._weaviate_version.is_lower_than(1, 30, 0): + pytest.skip("This test requires Weaviate 1.30.0 or higher") + + randomUserName = "new-user" + str(random.randint(1, 1000)) + apiKeyOld = client.users.db.create(user_id=randomUserName) + assert client.users.db.deactivate(user_id=randomUserName, revoke_key=True) + + with pytest.raises(weaviate.exceptions.UnexpectedStatusCodeError): + weaviate.connect_to_local( + port=RBAC_PORTS[0], + grpc_port=RBAC_PORTS[1], + auth_credentials=Auth.api_key(apiKeyOld), + ) + + # re-activating is not enough + assert client.users.db.activate(user_id=randomUserName) + with pytest.raises(weaviate.exceptions.UnexpectedStatusCodeError): + weaviate.connect_to_local( + port=RBAC_PORTS[0], + grpc_port=RBAC_PORTS[1], + auth_credentials=Auth.api_key(apiKeyOld), + ) + + apiKeyNew = client.users.db.rotate_key(user_id=randomUserName) + + with weaviate.connect_to_local( + port=RBAC_PORTS[0], grpc_port=RBAC_PORTS[1], auth_credentials=Auth.api_key(apiKeyNew) + ) as client2: + user = client2.users.get_my_user() + assert user.user_id == randomUserName + + client.users.db.delete(user_id=randomUserName) + + def test_deprecated_syntax(client_factory: ClientFactory) -> None: with client_factory(ports=RBAC_PORTS, auth_credentials=Auth.api_key("admin-key")) as client: if client._connection._weaviate_version.is_lower_than(1, 30, 0): diff --git a/weaviate/users/sync.pyi b/weaviate/users/sync.pyi index 240a790ce..8ac7f9e43 100644 --- a/weaviate/users/sync.pyi +++ b/weaviate/users/sync.pyi @@ -38,7 +38,7 @@ class _UsersDB(_UsersInit): def create(self, *, user_id: str) -> str: ... def delete(self, *, user_id: str) -> bool: ... def rotate_key(self, *, user_id: str) -> str: ... - def deactivate(self, *, user_id: str) -> bool: ... + def deactivate(self, *, user_id: str, revoke_key: bool = False) -> bool: ... def activate(self, *, user_id: str) -> bool: ... def get(self, *, user_id: str) -> UserDB: ... def list_all(self) -> List[UserDB]: ... diff --git a/weaviate/users/users.py b/weaviate/users/users.py index 64e7ed27e..c7b720bd8 100644 --- a/weaviate/users/users.py +++ b/weaviate/users/users.py @@ -127,11 +127,11 @@ async def _rotate_key(self, user_id: str) -> str: assert resp_typed is not None return str(resp_typed["apikey"]) - async def _deactivate(self, user_id: str) -> bool: + async def _deactivate(self, user_id: str, revoke_key: bool) -> bool: path = f"/users/db/{user_id}/deactivate" resp = await self._connection.post( path, - weaviate_object={}, + weaviate_object={"revoke_key": revoke_key}, error_msg=f"Could not deactivate user '{user_id}'", status_codes=_ExpectedStatusCodes(ok_in=[200, 409], error="deactivate key"), ) @@ -251,13 +251,14 @@ async def activate(self, *, user_id: str) -> bool: """ return await self._activate(user_id) - async def deactivate(self, *, user_id: str) -> bool: + async def deactivate(self, *, user_id: str, revoke_key: bool = False) -> bool: """Deactivate an active user. Args: user_id: The id of the user. + revoke_key: If True, the old key will be revoked and needs to be rotated. """ - return await self._deactivate(user_id) + return await self._deactivate(user_id, revoke_key) async def get(self, *, user_id: str) -> UserDB: """Get all information about an user. From 8e9ac982eab6aa3228d8797069bcbeb2da9a4e59 Mon Sep 17 00:00:00 2001 From: Tommy Smith Date: Thu, 3 Apr 2025 15:23:34 +0100 Subject: [PATCH 3/3] Fix tests using new kwarg syntax --- integration/test_users.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/integration/test_users.py b/integration/test_users.py index a8a69e93a..6cecf43e6 100644 --- a/integration/test_users.py +++ b/integration/test_users.py @@ -34,12 +34,12 @@ def test_get_user_roles_db(client_factory: ClientFactory) -> None: with client_factory(ports=RBAC_PORTS, auth_credentials=RBAC_AUTH_CREDS) as client: if client._connection._weaviate_version.is_lower_than(1, 30, 0): pytest.skip("This test requires Weaviate 1.30.0 or higher") - roles_base = client.users.db.get_assigned_roles("admin-user") + roles_base = client.users.db.get_assigned_roles(user_id="admin-user") names = list(roles_base.keys()) assert len(roles_base) > 0 assert isinstance(roles_base[names[0]], RoleBase) - roles = client.users.db.get_assigned_roles("admin-user", include_permissions=True) + roles = client.users.db.get_assigned_roles(user_id="admin-user", include_permissions=True) assert len(roles) > 0 assert isinstance(roles[names[0]], Role) @@ -48,12 +48,12 @@ def test_get_user_roles_oidc(client_factory: ClientFactory) -> None: with client_factory(ports=RBAC_PORTS, auth_credentials=RBAC_AUTH_CREDS) as client: if client._connection._weaviate_version.is_lower_than(1, 30, 0): pytest.skip("This test requires Weaviate 1.30.0 or higher") - roles_base = client.users.oidc.get_assigned_roles("admin-user") + roles_base = client.users.oidc.get_assigned_roles(user_id="admin-user") names = list(roles_base.keys()) assert len(roles_base) > 0 assert isinstance(roles_base[names[0]], RoleBase) - roles = client.users.oidc.get_assigned_roles("admin-user", include_permissions=True) + roles = client.users.oidc.get_assigned_roles(user_id="admin-user", include_permissions=True) assert len(roles) > 0 assert isinstance(roles[names[0]], Role)