Skip to content

Commit 3f24c62

Browse files
silehtmergify[bot]
authored andcommitted
fix(onpremise): user_tokens setup
On on-premise setup, we should not query the dashboard. Change-Id: Ic66e046455a59f711dcb8e58438b58fc149313e0
1 parent bac70e3 commit 3f24c62

File tree

4 files changed

+122
-50
lines changed

4 files changed

+122
-50
lines changed

mergify_engine/debug.py

Lines changed: 11 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -232,9 +232,17 @@ async def report(
232232
)
233233

234234
cached_tokens = await user_tokens.UserTokens.get(redis_cache, client.auth.owner_id)
235-
db_tokens = await user_tokens.UserTokens._retrieve_from_db(
236-
redis_cache, client.auth.owner_id
237-
)
235+
if issubclass(user_tokens.UserTokens, user_tokens.UserTokensGitHubCom):
236+
db_tokens = typing.cast(
237+
user_tokens.UserTokens,
238+
(
239+
await user_tokens.UserTokensGitHubCom._retrieve_from_db(
240+
redis_cache, client.auth.owner_id
241+
)
242+
),
243+
)
244+
else:
245+
db_tokens = cached_tokens
238246

239247
print(f"* SUBSCRIBED (cache/db): {cached_sub.active} / {db_sub.active}")
240248
print("* Features (cache):")

mergify_engine/tests/functional/base.py

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -445,7 +445,9 @@ def github_aclient(owner_name=None, owner_id=None, auth=None):
445445
},
446446
],
447447
)
448-
await self.user_tokens.save_to_cache()
448+
await typing.cast(
449+
user_tokens.UserTokensGitHubCom, self.user_tokens
450+
).save_to_cache()
449451

450452
# Let's start recording
451453
cassette = self.recorder.use_cassette("http.json")
@@ -539,12 +541,12 @@ async def fake_user_tokens(redis_cache, owner_id):
539541
return user_tokens.UserTokens(redis_cache, owner_id, {})
540542

541543
mock.patch(
542-
"mergify_engine.user_tokens.UserTokens._retrieve_from_db",
544+
"mergify_engine.user_tokens.UserTokensGitHubCom._retrieve_from_db",
543545
side_effect=fake_retrieve_user_tokens_from_db,
544546
).start()
545547

546548
mock.patch(
547-
"mergify_engine.user_tokens.UserTokens.get",
549+
"mergify_engine.user_tokens.UserTokensGitHubCom.get",
548550
side_effect=fake_user_tokens,
549551
).start()
550552

mergify_engine/tests/unit/test_user_tokens.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -99,7 +99,7 @@ async def test_unknown_ut(redis_cache):
9999

100100
@pytest.mark.asyncio
101101
async def test_user_tokens_tokens_via_env(monkeypatch, redis_cache):
102-
ut = user_tokens.UserTokens(redis_cache, 123, [])
102+
ut = await user_tokens.UserTokensOnPremise.get(redis_cache, 123)
103103

104104
assert ut.get_token_for("foo") is None
105105
assert ut.get_token_for("login") is None
@@ -109,6 +109,7 @@ async def test_user_tokens_tokens_via_env(monkeypatch, redis_cache):
109109
config, "ACCOUNT_TOKENS", config.AccountTokens("foo:bar,login:token")
110110
)
111111

112+
ut = await user_tokens.UserTokensOnPremise.get(redis_cache, 123)
112113
assert ut.get_token_for("foo")["oauth_access_token"] == "bar"
113114
assert ut.get_token_for("login")["oauth_access_token"] == "token"
114115
assert ut.get_token_for("nop") is None

mergify_engine/user_tokens.py

Lines changed: 104 additions & 43 deletions
Original file line numberDiff line numberDiff line change
@@ -45,15 +45,14 @@ class UserTokensUser(typing.TypedDict):
4545
email: typing.Optional[str]
4646

4747

48+
UserTokensT = typing.TypeVar("UserTokensT", bound="UserTokensBase")
49+
50+
4851
@dataclasses.dataclass
49-
class UserTokens:
52+
class UserTokensBase:
5053
redis: utils.RedisCache
5154
owner_id: int
5255
users: typing.List[UserTokensUser]
53-
ttl: int = -2
54-
55-
RETENTION_SECONDS = 60 * 60 * 24 * 3 # 3 days
56-
VALIDITY_SECONDS = 3600
5756

5857
@staticmethod
5958
async def select_users_for(
@@ -76,45 +75,64 @@ async def select_users_for(
7675
users = sorted(users, key=lambda x: x["login"] != ctxt.pull["user"]["login"])
7776
return users
7877

79-
@staticmethod
80-
def _cache_key(owner_id: int) -> str:
81-
return f"user-tokens-cache-owner-{owner_id}"
82-
83-
@staticmethod
84-
def _get_users_from_config() -> typing.List[UserTokensUser]:
85-
return [
86-
{
87-
"login": github_types.GitHubLogin(login),
88-
"oauth_access_token": github_types.GitHubOAuthToken(oauth_access_token),
89-
"email": None,
90-
"name": None,
91-
}
92-
for login, oauth_access_token in config.ACCOUNT_TOKENS.items()
93-
]
94-
9578
def get_token_for(
9679
self, wanted_login: github_types.GitHubLogin
9780
) -> typing.Optional[UserTokensUser]:
9881
wanted_login_lower = wanted_login.lower()
99-
for user in self.users + self._get_users_from_config():
82+
for user in self.users:
10083
if user["login"].lower() == wanted_login_lower:
10184
return user
10285
return None
10386

87+
@classmethod
88+
async def delete(
89+
cls: typing.Type[UserTokensT], redis: utils.RedisCache, owner_id: int
90+
) -> None:
91+
raise NotImplementedError
92+
93+
@classmethod
94+
async def get(
95+
cls: typing.Type[UserTokensT], redis: utils.RedisCache, owner_id: int
96+
) -> UserTokensT:
97+
raise NotImplementedError
98+
99+
100+
@dataclasses.dataclass
101+
class UserTokensGitHubCom(UserTokensBase):
102+
ttl: int = -2
103+
104+
RETENTION_SECONDS = 60 * 60 * 24 * 3 # 3 days
105+
VALIDITY_SECONDS = 3600
106+
107+
@staticmethod
108+
def _cache_key(owner_id: int) -> str:
109+
return f"user-tokens-cache-owner-{owner_id}"
110+
104111
async def _has_expired(self) -> bool:
105112
if self.ttl < 0: # not cached
106113
return True
107114
elapsed_since_stored = self.RETENTION_SECONDS - self.ttl
108115
return elapsed_since_stored > self.VALIDITY_SECONDS
109116

110117
@classmethod
111-
async def delete(cls, redis: utils.RedisCache, owner_id: int) -> None:
112-
await redis.delete(cls._cache_key(owner_id))
118+
async def delete(
119+
cls: typing.Type[UserTokensT], redis: utils.RedisCache, owner_id: int
120+
) -> None:
121+
await redis.delete(typing.cast(UserTokensGitHubCom, cls)._cache_key(owner_id))
113122

114123
@classmethod
115-
async def get(cls, redis: utils.RedisCache, owner_id: int) -> "UserTokens":
116-
"""Get a tokens."""
124+
async def get(
125+
cls: typing.Type[UserTokensT], redis: utils.RedisCache, owner_id: int
126+
) -> UserTokensT:
127+
return typing.cast(
128+
UserTokensT,
129+
await typing.cast(UserTokensGitHubCom, cls)._get(redis, owner_id),
130+
)
117131

132+
@classmethod
133+
async def _get(
134+
cls, redis: utils.RedisCache, owner_id: int
135+
) -> "UserTokensGitHubCom":
118136
cached_tokens = await cls._retrieve_from_cache(redis, owner_id)
119137
if cached_tokens is None or await cached_tokens._has_expired():
120138
try:
@@ -140,26 +158,10 @@ async def save_to_cache(self) -> None:
140158
)
141159
self.ttl = self.RETENTION_SECONDS
142160

143-
@classmethod
144-
async def _retrieve_from_db(
145-
cls, redis: utils.RedisCache, owner_id: int
146-
) -> "UserTokens":
147-
async with http.AsyncClient() as client:
148-
try:
149-
resp = await client.get(
150-
f"{config.SUBSCRIPTION_BASE_URL}/engine/user_tokens/{owner_id}",
151-
auth=(config.OAUTH_CLIENT_ID, config.OAUTH_CLIENT_SECRET),
152-
)
153-
except http.HTTPNotFound:
154-
return cls(redis, owner_id, [])
155-
else:
156-
tokens = resp.json()
157-
return cls(redis, owner_id, tokens["user_tokens"])
158-
159161
@classmethod
160162
async def _retrieve_from_cache(
161163
cls, redis: utils.RedisCache, owner_id: int
162-
) -> typing.Optional["UserTokens"]:
164+
) -> typing.Optional["UserTokensGitHubCom"]:
163165
async with await redis.pipeline() as pipe:
164166
await pipe.get(cls._cache_key(owner_id))
165167
await pipe.ttl(cls._cache_key(owner_id))
@@ -177,3 +179,62 @@ async def _retrieve_from_cache(
177179

178180
return cls(redis, owner_id, decrypted_tokens["user_tokens"], ttl)
179181
return None
182+
183+
@classmethod
184+
async def _retrieve_from_db(
185+
cls, redis: utils.RedisCache, owner_id: int
186+
) -> "UserTokensGitHubCom":
187+
async with http.AsyncClient() as client:
188+
try:
189+
resp = await client.get(
190+
f"{config.SUBSCRIPTION_BASE_URL}/engine/user_tokens/{owner_id}",
191+
auth=(config.OAUTH_CLIENT_ID, config.OAUTH_CLIENT_SECRET),
192+
)
193+
except http.HTTPNotFound:
194+
return cls(redis, owner_id, [])
195+
else:
196+
tokens = resp.json()
197+
return cls(redis, owner_id, tokens["user_tokens"])
198+
199+
200+
@dataclasses.dataclass
201+
class UserTokensOnPremise(UserTokensBase):
202+
@classmethod
203+
async def delete(
204+
cls: typing.Type[UserTokensT], redis: utils.RedisCache, owner_id: int
205+
) -> None:
206+
pass
207+
208+
@classmethod
209+
async def get(
210+
cls: typing.Type[UserTokensT], redis: utils.RedisCache, owner_id: int
211+
) -> UserTokensT:
212+
return cls(
213+
redis,
214+
owner_id,
215+
[
216+
{
217+
"login": github_types.GitHubLogin(login),
218+
"oauth_access_token": github_types.GitHubOAuthToken(
219+
oauth_access_token
220+
),
221+
"email": None,
222+
"name": None,
223+
}
224+
for login, oauth_access_token in config.ACCOUNT_TOKENS.items()
225+
],
226+
)
227+
228+
229+
if config.SUBSCRIPTION_TOKEN is not None:
230+
231+
@dataclasses.dataclass
232+
class UserTokens(UserTokensOnPremise):
233+
pass
234+
235+
236+
else:
237+
238+
@dataclasses.dataclass
239+
class UserTokens(UserTokensGitHubCom): # type: ignore [no-redef]
240+
pass

0 commit comments

Comments
 (0)