Skip to content

Commit

Permalink
Refactor OAuthManager to inherit from base Daemon class; rucio#6478
Browse files Browse the repository at this point in the history
  • Loading branch information
rdimaio committed Feb 6, 2024
1 parent 9ac49a0 commit 70bb68e
Show file tree
Hide file tree
Showing 4 changed files with 107 additions and 158 deletions.
15 changes: 8 additions & 7 deletions bin/rucio-oauth-manager
Expand Up @@ -24,7 +24,7 @@ OAuth Manager is a daemon which is responsible for:
import argparse
import signal

from rucio.daemons.oauthmanager.oauthmanager import run, stop
from rucio.daemons.oauthmanager.oauthmanager import OAuthManager


def get_parser():
Expand Down Expand Up @@ -54,11 +54,12 @@ can be specified by 'maxrows' parameter.


if __name__ == "__main__":
signal.signal(signal.SIGTERM, stop)
PARSER = get_parser()
ARGS = PARSER.parse_args()
parser = get_parser()
args = parser.parse_args()
oauthmanager = OAuthManager(once=args.run_once, max_rows=args.max_rows, total_workers=args.threads,
sleep_time=args.sleep_time)
signal.signal(signal.SIGTERM, oauthmanager.stop)
try:
run(once=ARGS.run_once, max_rows=ARGS.max_rows, threads=ARGS.threads,
sleep_time=ARGS.sleep_time)
oauthmanager.run()
except KeyboardInterrupt:
stop()
oauthmanager.stop()
241 changes: 95 additions & 146 deletions lib/rucio/daemons/oauthmanager/oauthmanager.py
Expand Up @@ -26,175 +26,124 @@
"""

import functools

import logging
import threading
import traceback
from re import match
from typing import TYPE_CHECKING
from typing import Any
from rucio.db.sqla.constants import ORACLE_CONNECTION_LOST_CONTACT_REGEX

from sqlalchemy.exc import DatabaseError

import rucio.db.sqla.util
from rucio.common.exception import DatabaseException
from rucio.common.logging import setup_logging
from rucio.common.stopwatch import Stopwatch
from rucio.core.authentication import delete_expired_tokens
from rucio.core.monitor import MetricManager
from rucio.core.oidc import delete_expired_oauthrequests, refresh_jwt_tokens
from rucio.daemons.common import HeartbeatHandler
from rucio.daemons.common import run_daemon

if TYPE_CHECKING:
from types import FrameType
from typing import Optional
from rucio.daemons.common import Daemon, HeartbeatHandler

METRICS = MetricManager(module=__name__)
graceful_stop = threading.Event()
DAEMON_NAME = 'oauth-manager'


def OAuthManager(once: bool = False, max_rows: int = 100, sleep_time: int = 300) -> None:
class OAuthManager(Daemon):
"""
Main loop to delete all expired tokens, refresh tokens eligible
Daemon to delete all expired tokens, refresh tokens eligible
for refresh and delete all expired OAuth session parameters.
It was decided to have only 1 daemon for all 3 of these cleanup activities.
"""
def __init__(self, sleep_time: int = 300, max_rows: int = 100, **_kwargs) -> None:
"""
:param max_rows: Max number of DB rows to deal with per operation.
"""
super().__init__(daemon_name="oauth-manager", sleep_time=sleep_time, **_kwargs)
self.max_rows = max_rows
self.paused_dids = {}

:param once: If True, the loop is run just once, otherwise the daemon continues looping until stopped.
:param max_rows: Max number of DB rows to deal with per operation.
:param sleep_time: The number of seconds the daemon will wait before running next loop of operations.
def _run_once(self, heartbeat_handler: "HeartbeatHandler", **_kwargs) -> tuple[bool, Any]:
# make an initial heartbeat
heartbeat_handler.live()

:returns: None
"""
must_sleep = False

run_daemon(
once=once,
graceful_stop=graceful_stop,
executable=DAEMON_NAME,
partition_wait_time=1,
sleep_time=sleep_time,
run_once_fnc=functools.partial(
run_once,
max_rows=max_rows,
sleep_time=sleep_time
),
)


def run_once(heartbeat_handler: HeartbeatHandler, max_rows: int, sleep_time: int, **_kwargs) -> None:

# make an initial heartbeat
heartbeat_handler.live()

stopwatch = Stopwatch()

ndeleted = 0
ndeletedreq = 0
nrefreshed = 0

# make a heartbeat
worker_number, total_workers, logger = heartbeat_handler.live()
try:
# ACCESS TOKEN REFRESH - better to run first (in case some of the refreshed tokens needed deletion after this step)
logger(logging.INFO, '----- START ----- ACCESS TOKEN REFRESH ----- ')
logger(logging.INFO, 'starting to query tokens for automatic refresh')
nrefreshed = refresh_jwt_tokens(total_workers, worker_number, refreshrate=int(sleep_time), limit=max_rows)
logger(logging.INFO, 'successfully refreshed %i tokens', nrefreshed)
logger(logging.INFO, '----- END ----- ACCESS TOKEN REFRESH ----- ')
METRICS.counter(name='oauth_manager.tokens.refreshed').inc(nrefreshed)

except (DatabaseException, DatabaseError) as err:
if match('.*QueuePool.*', str(err.args[0])):
logger(logging.WARNING, traceback.format_exc())
METRICS.counter('exceptions.{exception}').labels(exception=err.__class__.__name__).inc()
elif match(ORACLE_CONNECTION_LOST_CONTACT_REGEX, str(err.args[0])):
logger(logging.WARNING, traceback.format_exc())
METRICS.counter('exceptions.{exception}').labels(exception=err.__class__.__name__).inc()
else:
logger(logging.CRITICAL, traceback.format_exc())
METRICS.counter('exceptions.{exception}').labels(exception=err.__class__.__name__).inc()

try:
# waiting 1 sec as DBs do not store milliseconds, so tokens
# eligible for deletion after refresh might not get deleted otherwise
graceful_stop.wait(1)
stopwatch = Stopwatch()

# make a heartbeat
worker_number, total_workers, logger = heartbeat_handler.live()
ndeleted = 0
ndeletedreq = 0
nrefreshed = 0

# EXPIRED TOKEN DELETION
logger(logging.INFO, '----- START ----- DELETION OF EXPIRED TOKENS ----- ')
logger(logging.INFO, 'starting to delete expired tokens')
ndeleted += delete_expired_tokens(total_workers, worker_number, limit=max_rows)
logger(logging.INFO, 'deleted %i expired tokens', ndeleted)
logger(logging.INFO, '----- END ----- DELETION OF EXPIRED TOKENS ----- ')
METRICS.counter(name='oauth_manager.tokens.deleted').inc(ndeleted)

except (DatabaseException, DatabaseError) as err:
if match('.*QueuePool.*', str(err.args[0])):
logger(logging.WARNING, traceback.format_exc())
METRICS.counter('exceptions.{exception}').labels(exception=err.__class__.__name__).inc()
elif match(ORACLE_CONNECTION_LOST_CONTACT_REGEX, str(err.args[0])):
logger(logging.WARNING, traceback.format_exc())
METRICS.counter('exceptions.{exception}').labels(exception=err.__class__.__name__).inc()
else:
logger(logging.CRITICAL, traceback.format_exc())
METRICS.counter('exceptions.{exception}').labels(exception=err.__class__.__name__).inc()

try:
# make a heartbeat
worker_number, total_workers, logger = heartbeat_handler.live()

# DELETING EXPIRED OAUTH SESSION PARAMETERS
logger(logging.INFO, '----- START ----- DELETION OF EXPIRED OAUTH SESSION REQUESTS ----- ')
logger(logging.INFO, 'starting deletion of expired OAuth session requests')
ndeletedreq += delete_expired_oauthrequests(total_workers, worker_number, limit=max_rows)
logger(logging.INFO, 'expired parameters of %i authentication requests were deleted', ndeletedreq)
logger(logging.INFO, '----- END ----- DELETION OF EXPIRED OAUTH SESSION REQUESTS ----- ')
METRICS.counter(name='oauth_manager.oauthreq.deleted').inc(ndeletedreq)

except (DatabaseException, DatabaseError) as err:
if match('.*QueuePool.*', str(err.args[0])):
logger(logging.WARNING, traceback.format_exc())
METRICS.counter('exceptions.{exception}').labels(exception=err.__class__.__name__).inc()
elif match(ORACLE_CONNECTION_LOST_CONTACT_REGEX, str(err.args[0])):
logger(logging.WARNING, traceback.format_exc())
METRICS.counter('exceptions.{exception}').labels(exception=err.__class__.__name__).inc()
else:
logger(logging.CRITICAL, traceback.format_exc())
METRICS.counter('exceptions.{exception}').labels(exception=err.__class__.__name__).inc()

stopwatch.stop()
logger(logging.INFO, 'took %f seconds to delete %i tokens, %i session parameters and refreshed %i tokens', stopwatch.elapsed, ndeleted, ndeletedreq, nrefreshed)
METRICS.timer('duration').observe(stopwatch.elapsed)


def run(once: bool = False, threads: int = 1, max_rows: int = 100, sleep_time: int = 300) -> None:
"""
Starts up the OAuth Manager threads.
"""
setup_logging(process_name=DAEMON_NAME)

if rucio.db.sqla.util.is_old_db():
raise DatabaseException('Database was not updated, daemon won\'t start')

if once:
OAuthManager(once, max_rows, sleep_time)
else:
logging.info('OAuth Manager starting %s threads', str(threads))
threads = [threading.Thread(target=OAuthManager,
kwargs={'once': once,
'max_rows': max_rows,
'sleep_time': sleep_time}) for i in range(0, threads)]
_ = [t.start() for t in threads]
# Interruptible joins require a timeout.
while threads[0].is_alive():
_ = [t.join(timeout=3.14) for t in threads]


def stop(signum: "Optional[int]" = None, frame: "Optional[FrameType]" = None) -> None:
"""
Graceful exit.
"""
graceful_stop.set()
try:
# ACCESS TOKEN REFRESH - better to run first (in case some of the refreshed tokens needed deletion after this step)
logger(logging.INFO, '----- START ----- ACCESS TOKEN REFRESH ----- ')
logger(logging.INFO, 'starting to query tokens for automatic refresh')
nrefreshed = refresh_jwt_tokens(total_workers, worker_number, refreshrate=int(self.sleep_time), limit=self.max_rows)
logger(logging.INFO, 'successfully refreshed %i tokens', nrefreshed)
logger(logging.INFO, '----- END ----- ACCESS TOKEN REFRESH ----- ')
METRICS.counter(name='oauth_manager.tokens.refreshed').inc(nrefreshed)

except (DatabaseException, DatabaseError) as err:
if match('.*QueuePool.*', str(err.args[0])):
logger(logging.WARNING, traceback.format_exc())
METRICS.counter('exceptions.{exception}').labels(exception=err.__class__.__name__).inc()
elif match(ORACLE_CONNECTION_LOST_CONTACT_REGEX, str(err.args[0])):
logger(logging.WARNING, traceback.format_exc())
METRICS.counter('exceptions.{exception}').labels(exception=err.__class__.__name__).inc()
else:
logger(logging.CRITICAL, traceback.format_exc())
METRICS.counter('exceptions.{exception}').labels(exception=err.__class__.__name__).inc()

try:
# waiting 1 sec as DBs do not store milliseconds, so tokens
# eligible for deletion after refresh might not get deleted otherwise
self.graceful_stop.wait(1)

# make a heartbeat
worker_number, total_workers, logger = heartbeat_handler.live()

# EXPIRED TOKEN DELETION
logger(logging.INFO, '----- START ----- DELETION OF EXPIRED TOKENS ----- ')
logger(logging.INFO, 'starting to delete expired tokens')
ndeleted += delete_expired_tokens(total_workers, worker_number, limit=self.max_rows)
logger(logging.INFO, 'deleted %i expired tokens', ndeleted)
logger(logging.INFO, '----- END ----- DELETION OF EXPIRED TOKENS ----- ')
METRICS.counter(name='oauth_manager.tokens.deleted').inc(ndeleted)

except (DatabaseException, DatabaseError) as err:
if match('.*QueuePool.*', str(err.args[0])):
logger(logging.WARNING, traceback.format_exc())
METRICS.counter('exceptions.{exception}').labels(exception=err.__class__.__name__).inc()
elif match(ORACLE_CONNECTION_LOST_CONTACT_REGEX, str(err.args[0])):
logger(logging.WARNING, traceback.format_exc())
METRICS.counter('exceptions.{exception}').labels(exception=err.__class__.__name__).inc()
else:
logger(logging.CRITICAL, traceback.format_exc())
METRICS.counter('exceptions.{exception}').labels(exception=err.__class__.__name__).inc()

try:
# make a heartbeat
worker_number, total_workers, logger = heartbeat_handler.live()

# DELETING EXPIRED OAUTH SESSION PARAMETERS
logger(logging.INFO, '----- START ----- DELETION OF EXPIRED OAUTH SESSION REQUESTS ----- ')
logger(logging.INFO, 'starting deletion of expired OAuth session requests')
ndeletedreq += delete_expired_oauthrequests(total_workers, worker_number, limit=self.max_rows)
logger(logging.INFO, 'expired parameters of %i authentication requests were deleted', ndeletedreq)
logger(logging.INFO, '----- END ----- DELETION OF EXPIRED OAUTH SESSION REQUESTS ----- ')
METRICS.counter(name='oauth_manager.oauthreq.deleted').inc(ndeletedreq)

except (DatabaseException, DatabaseError) as err:
if match('.*QueuePool.*', str(err.args[0])):
logger(logging.WARNING, traceback.format_exc())
METRICS.counter('exceptions.{exception}').labels(exception=err.__class__.__name__).inc()
elif match(ORACLE_CONNECTION_LOST_CONTACT_REGEX, str(err.args[0])):
logger(logging.WARNING, traceback.format_exc())
METRICS.counter('exceptions.{exception}').labels(exception=err.__class__.__name__).inc()
else:
logger(logging.CRITICAL, traceback.format_exc())
METRICS.counter('exceptions.{exception}').labels(exception=err.__class__.__name__).inc()

stopwatch.stop()
logger(logging.INFO, 'took %f seconds to delete %i tokens, %i session parameters and refreshed %i tokens', stopwatch.elapsed, ndeleted, ndeletedreq, nrefreshed)
METRICS.timer('duration').observe(stopwatch.elapsed)
return must_sleep, None
2 changes: 0 additions & 2 deletions tests/test_daemons.py
Expand Up @@ -26,7 +26,6 @@
from rucio.daemons.follower import follower
from rucio.daemons.hermes import hermes
from rucio.daemons.judge import cleaner, evaluator, injector
from rucio.daemons.oauthmanager import oauthmanager
from rucio.daemons.reaper import dark_reaper
from rucio.daemons.replicarecoverer import suspicious_replica_recoverer
from rucio.daemons.tracer import kronos
Expand All @@ -50,7 +49,6 @@
cleaner,
evaluator,
injector,
oauthmanager,
dark_reaper,
suspicious_replica_recoverer,
kronos,
Expand Down
7 changes: 4 additions & 3 deletions tests/test_oauthmanager.py
Expand Up @@ -22,7 +22,7 @@
from sqlalchemy import and_, or_
from sqlalchemy.sql.expression import true

from rucio.daemons.oauthmanager.oauthmanager import run, stop
from rucio.daemons.oauthmanager.oauthmanager import OAuthManager
from rucio.db.sqla import models
from rucio.db.sqla.session import get_session

Expand Down Expand Up @@ -236,10 +236,11 @@ def test_oauthmanager(self, mock_oidc_client, random_account):
assert get_token_count(account) == 21

# Run the OAuth manager once
oauthmanager = OAuthManager(once=True, max_rows=100)
try:
run(once=True, max_rows=100)
oauthmanager.run()
except KeyboardInterrupt:
stop()
oauthmanager.stop()

# Checking the outcome
assert get_oauth_session_param_count(account) == 2
Expand Down

0 comments on commit 70bb68e

Please sign in to comment.