Skip to content

Commit

Permalink
fix(onpremise): user_tokens setup
Browse files Browse the repository at this point in the history
On on-premise setup, we should not query the dashboard.

Change-Id: Ic66e046455a59f711dcb8e58438b58fc149313e0
  • Loading branch information
sileht authored and mergify[bot] committed Aug 10, 2021
1 parent bac70e3 commit 3f24c62
Showing 4 changed files with 122 additions and 50 deletions.
14 changes: 11 additions & 3 deletions mergify_engine/debug.py
Original file line number Diff line number Diff line change
@@ -232,9 +232,17 @@ async def report(
)

cached_tokens = await user_tokens.UserTokens.get(redis_cache, client.auth.owner_id)
db_tokens = await user_tokens.UserTokens._retrieve_from_db(
redis_cache, client.auth.owner_id
)
if issubclass(user_tokens.UserTokens, user_tokens.UserTokensGitHubCom):
db_tokens = typing.cast(
user_tokens.UserTokens,
(
await user_tokens.UserTokensGitHubCom._retrieve_from_db(
redis_cache, client.auth.owner_id
)
),
)
else:
db_tokens = cached_tokens

print(f"* SUBSCRIBED (cache/db): {cached_sub.active} / {db_sub.active}")
print("* Features (cache):")
8 changes: 5 additions & 3 deletions mergify_engine/tests/functional/base.py
Original file line number Diff line number Diff line change
@@ -445,7 +445,9 @@ def github_aclient(owner_name=None, owner_id=None, auth=None):
},
],
)
await self.user_tokens.save_to_cache()
await typing.cast(
user_tokens.UserTokensGitHubCom, self.user_tokens
).save_to_cache()

# Let's start recording
cassette = self.recorder.use_cassette("http.json")
@@ -539,12 +541,12 @@ async def fake_user_tokens(redis_cache, owner_id):
return user_tokens.UserTokens(redis_cache, owner_id, {})

mock.patch(
"mergify_engine.user_tokens.UserTokens._retrieve_from_db",
"mergify_engine.user_tokens.UserTokensGitHubCom._retrieve_from_db",
side_effect=fake_retrieve_user_tokens_from_db,
).start()

mock.patch(
"mergify_engine.user_tokens.UserTokens.get",
"mergify_engine.user_tokens.UserTokensGitHubCom.get",
side_effect=fake_user_tokens,
).start()

3 changes: 2 additions & 1 deletion mergify_engine/tests/unit/test_user_tokens.py
Original file line number Diff line number Diff line change
@@ -99,7 +99,7 @@ async def test_unknown_ut(redis_cache):

@pytest.mark.asyncio
async def test_user_tokens_tokens_via_env(monkeypatch, redis_cache):
ut = user_tokens.UserTokens(redis_cache, 123, [])
ut = await user_tokens.UserTokensOnPremise.get(redis_cache, 123)

assert ut.get_token_for("foo") is None
assert ut.get_token_for("login") is None
@@ -109,6 +109,7 @@ async def test_user_tokens_tokens_via_env(monkeypatch, redis_cache):
config, "ACCOUNT_TOKENS", config.AccountTokens("foo:bar,login:token")
)

ut = await user_tokens.UserTokensOnPremise.get(redis_cache, 123)
assert ut.get_token_for("foo")["oauth_access_token"] == "bar"
assert ut.get_token_for("login")["oauth_access_token"] == "token"
assert ut.get_token_for("nop") is None
147 changes: 104 additions & 43 deletions mergify_engine/user_tokens.py
Original file line number Diff line number Diff line change
@@ -45,15 +45,14 @@ class UserTokensUser(typing.TypedDict):
email: typing.Optional[str]


UserTokensT = typing.TypeVar("UserTokensT", bound="UserTokensBase")


@dataclasses.dataclass
class UserTokens:
class UserTokensBase:
redis: utils.RedisCache
owner_id: int
users: typing.List[UserTokensUser]
ttl: int = -2

RETENTION_SECONDS = 60 * 60 * 24 * 3 # 3 days
VALIDITY_SECONDS = 3600

@staticmethod
async def select_users_for(
@@ -76,45 +75,64 @@ async def select_users_for(
users = sorted(users, key=lambda x: x["login"] != ctxt.pull["user"]["login"])
return users

@staticmethod
def _cache_key(owner_id: int) -> str:
return f"user-tokens-cache-owner-{owner_id}"

@staticmethod
def _get_users_from_config() -> typing.List[UserTokensUser]:
return [
{
"login": github_types.GitHubLogin(login),
"oauth_access_token": github_types.GitHubOAuthToken(oauth_access_token),
"email": None,
"name": None,
}
for login, oauth_access_token in config.ACCOUNT_TOKENS.items()
]

def get_token_for(
self, wanted_login: github_types.GitHubLogin
) -> typing.Optional[UserTokensUser]:
wanted_login_lower = wanted_login.lower()
for user in self.users + self._get_users_from_config():
for user in self.users:
if user["login"].lower() == wanted_login_lower:
return user
return None

@classmethod
async def delete(
cls: typing.Type[UserTokensT], redis: utils.RedisCache, owner_id: int
) -> None:
raise NotImplementedError

@classmethod
async def get(
cls: typing.Type[UserTokensT], redis: utils.RedisCache, owner_id: int
) -> UserTokensT:
raise NotImplementedError


@dataclasses.dataclass
class UserTokensGitHubCom(UserTokensBase):
ttl: int = -2

RETENTION_SECONDS = 60 * 60 * 24 * 3 # 3 days
VALIDITY_SECONDS = 3600

@staticmethod
def _cache_key(owner_id: int) -> str:
return f"user-tokens-cache-owner-{owner_id}"

async def _has_expired(self) -> bool:
if self.ttl < 0: # not cached
return True
elapsed_since_stored = self.RETENTION_SECONDS - self.ttl
return elapsed_since_stored > self.VALIDITY_SECONDS

@classmethod
async def delete(cls, redis: utils.RedisCache, owner_id: int) -> None:
await redis.delete(cls._cache_key(owner_id))
async def delete(
cls: typing.Type[UserTokensT], redis: utils.RedisCache, owner_id: int
) -> None:
await redis.delete(typing.cast(UserTokensGitHubCom, cls)._cache_key(owner_id))

@classmethod
async def get(cls, redis: utils.RedisCache, owner_id: int) -> "UserTokens":
"""Get a tokens."""
async def get(
cls: typing.Type[UserTokensT], redis: utils.RedisCache, owner_id: int
) -> UserTokensT:
return typing.cast(
UserTokensT,
await typing.cast(UserTokensGitHubCom, cls)._get(redis, owner_id),
)

@classmethod
async def _get(
cls, redis: utils.RedisCache, owner_id: int
) -> "UserTokensGitHubCom":
cached_tokens = await cls._retrieve_from_cache(redis, owner_id)
if cached_tokens is None or await cached_tokens._has_expired():
try:
@@ -140,26 +158,10 @@ async def save_to_cache(self) -> None:
)
self.ttl = self.RETENTION_SECONDS

@classmethod
async def _retrieve_from_db(
cls, redis: utils.RedisCache, owner_id: int
) -> "UserTokens":
async with http.AsyncClient() as client:
try:
resp = await client.get(
f"{config.SUBSCRIPTION_BASE_URL}/engine/user_tokens/{owner_id}",
auth=(config.OAUTH_CLIENT_ID, config.OAUTH_CLIENT_SECRET),
)
except http.HTTPNotFound:
return cls(redis, owner_id, [])
else:
tokens = resp.json()
return cls(redis, owner_id, tokens["user_tokens"])

@classmethod
async def _retrieve_from_cache(
cls, redis: utils.RedisCache, owner_id: int
) -> typing.Optional["UserTokens"]:
) -> typing.Optional["UserTokensGitHubCom"]:
async with await redis.pipeline() as pipe:
await pipe.get(cls._cache_key(owner_id))
await pipe.ttl(cls._cache_key(owner_id))
@@ -177,3 +179,62 @@ async def _retrieve_from_cache(

return cls(redis, owner_id, decrypted_tokens["user_tokens"], ttl)
return None

@classmethod
async def _retrieve_from_db(
cls, redis: utils.RedisCache, owner_id: int
) -> "UserTokensGitHubCom":
async with http.AsyncClient() as client:
try:
resp = await client.get(
f"{config.SUBSCRIPTION_BASE_URL}/engine/user_tokens/{owner_id}",
auth=(config.OAUTH_CLIENT_ID, config.OAUTH_CLIENT_SECRET),
)
except http.HTTPNotFound:
return cls(redis, owner_id, [])
else:
tokens = resp.json()
return cls(redis, owner_id, tokens["user_tokens"])


@dataclasses.dataclass
class UserTokensOnPremise(UserTokensBase):
@classmethod
async def delete(
cls: typing.Type[UserTokensT], redis: utils.RedisCache, owner_id: int
) -> None:
pass

@classmethod
async def get(
cls: typing.Type[UserTokensT], redis: utils.RedisCache, owner_id: int
) -> UserTokensT:
return cls(
redis,
owner_id,
[
{
"login": github_types.GitHubLogin(login),
"oauth_access_token": github_types.GitHubOAuthToken(
oauth_access_token
),
"email": None,
"name": None,
}
for login, oauth_access_token in config.ACCOUNT_TOKENS.items()
],
)


if config.SUBSCRIPTION_TOKEN is not None:

@dataclasses.dataclass
class UserTokens(UserTokensOnPremise):
pass


else:

@dataclasses.dataclass
class UserTokens(UserTokensGitHubCom): # type: ignore [no-redef]
pass

0 comments on commit 3f24c62

Please sign in to comment.