Skip to content

Commit 70f3d79

Browse files
seanzhougooglecopybara-github
authored andcommitted
chore: Add a credential refresher registry
PiperOrigin-RevId: 771637955
1 parent 55201cb commit 70f3d79

File tree

2 files changed

+233
-0
lines changed

2 files changed

+233
-0
lines changed
Lines changed: 59 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,59 @@
1+
# Copyright 2025 Google LLC
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
"""Credential refresher registry."""
16+
17+
from __future__ import annotations
18+
19+
from typing import Dict
20+
from typing import Optional
21+
22+
from google.adk.auth.auth_credential import AuthCredentialTypes
23+
from google.adk.utils.feature_decorator import experimental
24+
25+
from .base_credential_refresher import BaseCredentialRefresher
26+
27+
28+
@experimental
29+
class CredentialRefresherRegistry:
30+
"""Registry for credential refresher instances."""
31+
32+
def __init__(self):
33+
self._refreshers: Dict[AuthCredentialTypes, BaseCredentialRefresher] = {}
34+
35+
def register(
36+
self,
37+
credential_type: AuthCredentialTypes,
38+
refresher_instance: BaseCredentialRefresher,
39+
) -> None:
40+
"""Register a refresher instance for a credential type.
41+
42+
Args:
43+
credential_type: The credential type to register for.
44+
refresher_instance: The refresher instance to register.
45+
"""
46+
self._refreshers[credential_type] = refresher_instance
47+
48+
def get_refresher(
49+
self, credential_type: AuthCredentialTypes
50+
) -> Optional[BaseCredentialRefresher]:
51+
"""Get the refresher instance for a credential type.
52+
53+
Args:
54+
credential_type: The credential type to get refresher for.
55+
56+
Returns:
57+
The refresher instance if registered, None otherwise.
58+
"""
59+
return self._refreshers.get(credential_type)
Lines changed: 174 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,174 @@
1+
# Copyright 2025 Google LLC
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
"""Tests for CredentialRefresherRegistry."""
16+
17+
from unittest.mock import Mock
18+
19+
from google.adk.auth.auth_credential import AuthCredentialTypes
20+
from google.adk.auth.refresher.base_credential_refresher import BaseCredentialRefresher
21+
from google.adk.auth.refresher.credential_refresher_registry import CredentialRefresherRegistry
22+
23+
24+
class TestCredentialRefresherRegistry:
25+
"""Tests for the CredentialRefresherRegistry class."""
26+
27+
def test_init(self):
28+
"""Test that registry initializes with empty refreshers dictionary."""
29+
registry = CredentialRefresherRegistry()
30+
assert registry._refreshers == {}
31+
32+
def test_register_refresher(self):
33+
"""Test registering a refresher instance for a credential type."""
34+
registry = CredentialRefresherRegistry()
35+
mock_refresher = Mock(spec=BaseCredentialRefresher)
36+
37+
registry.register(AuthCredentialTypes.OAUTH2, mock_refresher)
38+
39+
assert registry._refreshers[AuthCredentialTypes.OAUTH2] == mock_refresher
40+
41+
def test_register_multiple_refreshers(self):
42+
"""Test registering multiple refresher instances for different credential types."""
43+
registry = CredentialRefresherRegistry()
44+
mock_oauth2_refresher = Mock(spec=BaseCredentialRefresher)
45+
mock_openid_refresher = Mock(spec=BaseCredentialRefresher)
46+
mock_service_account_refresher = Mock(spec=BaseCredentialRefresher)
47+
48+
registry.register(AuthCredentialTypes.OAUTH2, mock_oauth2_refresher)
49+
registry.register(
50+
AuthCredentialTypes.OPEN_ID_CONNECT, mock_openid_refresher
51+
)
52+
registry.register(
53+
AuthCredentialTypes.SERVICE_ACCOUNT, mock_service_account_refresher
54+
)
55+
56+
assert (
57+
registry._refreshers[AuthCredentialTypes.OAUTH2]
58+
== mock_oauth2_refresher
59+
)
60+
assert (
61+
registry._refreshers[AuthCredentialTypes.OPEN_ID_CONNECT]
62+
== mock_openid_refresher
63+
)
64+
assert (
65+
registry._refreshers[AuthCredentialTypes.SERVICE_ACCOUNT]
66+
== mock_service_account_refresher
67+
)
68+
69+
def test_register_overwrite_existing_refresher(self):
70+
"""Test that registering a refresher overwrites an existing one for the same credential type."""
71+
registry = CredentialRefresherRegistry()
72+
mock_refresher_1 = Mock(spec=BaseCredentialRefresher)
73+
mock_refresher_2 = Mock(spec=BaseCredentialRefresher)
74+
75+
# Register first refresher
76+
registry.register(AuthCredentialTypes.OAUTH2, mock_refresher_1)
77+
assert registry._refreshers[AuthCredentialTypes.OAUTH2] == mock_refresher_1
78+
79+
# Register second refresher for same credential type
80+
registry.register(AuthCredentialTypes.OAUTH2, mock_refresher_2)
81+
assert registry._refreshers[AuthCredentialTypes.OAUTH2] == mock_refresher_2
82+
83+
def test_get_refresher_existing(self):
84+
"""Test getting a refresher instance for a registered credential type."""
85+
registry = CredentialRefresherRegistry()
86+
mock_refresher = Mock(spec=BaseCredentialRefresher)
87+
88+
registry.register(AuthCredentialTypes.OAUTH2, mock_refresher)
89+
result = registry.get_refresher(AuthCredentialTypes.OAUTH2)
90+
91+
assert result == mock_refresher
92+
93+
def test_get_refresher_non_existing(self):
94+
"""Test getting a refresher instance for a non-registered credential type returns None."""
95+
registry = CredentialRefresherRegistry()
96+
97+
result = registry.get_refresher(AuthCredentialTypes.OAUTH2)
98+
99+
assert result is None
100+
101+
def test_get_refresher_after_registration(self):
102+
"""Test getting refresher instances for multiple credential types."""
103+
registry = CredentialRefresherRegistry()
104+
mock_oauth2_refresher = Mock(spec=BaseCredentialRefresher)
105+
mock_api_key_refresher = Mock(spec=BaseCredentialRefresher)
106+
107+
registry.register(AuthCredentialTypes.OAUTH2, mock_oauth2_refresher)
108+
registry.register(AuthCredentialTypes.API_KEY, mock_api_key_refresher)
109+
110+
# Get registered refreshers
111+
oauth2_result = registry.get_refresher(AuthCredentialTypes.OAUTH2)
112+
api_key_result = registry.get_refresher(AuthCredentialTypes.API_KEY)
113+
114+
assert oauth2_result == mock_oauth2_refresher
115+
assert api_key_result == mock_api_key_refresher
116+
117+
# Get non-registered refresher
118+
http_result = registry.get_refresher(AuthCredentialTypes.HTTP)
119+
assert http_result is None
120+
121+
def test_register_all_credential_types(self):
122+
"""Test registering refreshers for all available credential types."""
123+
registry = CredentialRefresherRegistry()
124+
125+
refreshers = {}
126+
for credential_type in AuthCredentialTypes:
127+
mock_refresher = Mock(spec=BaseCredentialRefresher)
128+
refreshers[credential_type] = mock_refresher
129+
registry.register(credential_type, mock_refresher)
130+
131+
# Verify all refreshers are registered correctly
132+
for credential_type in AuthCredentialTypes:
133+
result = registry.get_refresher(credential_type)
134+
assert result == refreshers[credential_type]
135+
136+
def test_empty_registry_get_refresher(self):
137+
"""Test getting refresher from empty registry returns None for any credential type."""
138+
registry = CredentialRefresherRegistry()
139+
140+
for credential_type in AuthCredentialTypes:
141+
result = registry.get_refresher(credential_type)
142+
assert result is None
143+
144+
def test_registry_independence(self):
145+
"""Test that multiple registry instances are independent."""
146+
registry1 = CredentialRefresherRegistry()
147+
registry2 = CredentialRefresherRegistry()
148+
149+
mock_refresher1 = Mock(spec=BaseCredentialRefresher)
150+
mock_refresher2 = Mock(spec=BaseCredentialRefresher)
151+
152+
registry1.register(AuthCredentialTypes.OAUTH2, mock_refresher1)
153+
registry2.register(AuthCredentialTypes.OAUTH2, mock_refresher2)
154+
155+
# Verify registries are independent
156+
assert (
157+
registry1.get_refresher(AuthCredentialTypes.OAUTH2) == mock_refresher1
158+
)
159+
assert (
160+
registry2.get_refresher(AuthCredentialTypes.OAUTH2) == mock_refresher2
161+
)
162+
assert registry1.get_refresher(
163+
AuthCredentialTypes.OAUTH2
164+
) != registry2.get_refresher(AuthCredentialTypes.OAUTH2)
165+
166+
def test_register_with_none_refresher(self):
167+
"""Test registering None as a refresher instance."""
168+
registry = CredentialRefresherRegistry()
169+
170+
# This should technically work as the registry accepts any value
171+
registry.register(AuthCredentialTypes.OAUTH2, None)
172+
result = registry.get_refresher(AuthCredentialTypes.OAUTH2)
173+
174+
assert result is None

0 commit comments

Comments
 (0)