Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Testing: Add type annotations to common/logging #6714

Merged
merged 1 commit into from
Jun 11, 2024
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
91 changes: 54 additions & 37 deletions lib/rucio/common/logging.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,37 +20,56 @@
import sys
from collections.abc import Callable, Iterator, Mapping, Sequence
from traceback import format_tb
from typing import TYPE_CHECKING, Any, Optional
from typing import TYPE_CHECKING, Any, Literal, Optional, Union, get_args

from rucio.common.config import config_get, config_get_bool

if TYPE_CHECKING:
from logging import LogRecord
from logging import LogRecord, _SysExcInfoType

from _typeshed import OptExcInfo
from flask import Flask


# Mapping from ECS field paths
# https://www.elastic.co/guide/en/ecs-logging/overview/current/intro.html#_field_mapping
# https://www.elastic.co/guide/en/ecs/8.5/ecs-field-reference.html
# to python log record attributes:
# https://docs.python.org/3/library/logging.html#logrecord-attributes
BUILTIN_FIELDS = (
('@timestamp', 'asctime'),
('message', 'message'),
('log.level', 'levelname'),
('log.origin.function', 'funcName'),
('log.origin.file.line', 'lineno'),
('log.origin.file.name', 'filename'),
('log.logger', 'name'),
('process.pid', 'process'),
('process.name', 'processName'),
('process.thread.id', 'thread'),
('process.thread.name', 'threadName'),
)
ECS_TO_LOG_RECORD_MAP = dict(BUILTIN_FIELDS)
LOG_RECORD_TO_ECS_MAP = dict((f[1], f[0]) for f in BUILTIN_FIELDS)


def _json_serializable(obj: Any):
ECS_FIELDS = Literal[
'@timestamp',
'message',
'log.level',
'log.origin.function',
'log.origin.file.line',
'log.origin.file.name',
'log.logger',
'process.pid',
'process.name',
'process.thread.id',
'process.thread.name'
]

LOG_RECORDS = Literal[
'asctime',
'message',
'levelname',
'funcName',
'lineno',
'filename',
'name',
'process',
'processName',
'thread',
'threadName'
]

BUILTIN_FIELDS: tuple[tuple[ECS_FIELDS, LOG_RECORDS], ...] = tuple((x, y) for x, y in zip(get_args(ECS_FIELDS), get_args(LOG_RECORDS)))
ECS_TO_LOG_RECORD_MAP: dict[ECS_FIELDS, LOG_RECORDS] = dict(BUILTIN_FIELDS)
LOG_RECORD_TO_ECS_MAP: dict[LOG_RECORDS, ECS_FIELDS] = dict((f[1], f[0]) for f in BUILTIN_FIELDS)


def _json_serializable(obj: Any) -> Union[dict[Any, Any], str]:
try:
return obj.__dict__
except AttributeError:
Expand Down Expand Up @@ -160,11 +179,11 @@ def _timestamp_formatter(record_formatter: "LogDataSource", record: "LogRecord")
yield record_formatter.ecs_fields[0], datetime.datetime.utcfromtimestamp(record.created).isoformat(timespec='milliseconds') + 'Z'


def _ecs_field_to_record_attribute(field_name):
def _ecs_field_to_record_attribute(field_name: Union[ECS_FIELDS, str]) -> Union[LOG_RECORDS, str]:
"""
Sanitize the path-like field name into a symbol which can be the name of an object attribute.
"""
record = ECS_TO_LOG_RECORD_MAP.get(field_name)
record = ECS_TO_LOG_RECORD_MAP.get(field_name) # type: ignore
bari12 marked this conversation as resolved.
Show resolved Hide resolved
if record:
return record
return field_name.replace('-', '_').replace('.', '_')
Expand Down Expand Up @@ -195,7 +214,7 @@ def __eq__(self, other: Any):
def __str__(self):
return self.__class__.__name__ + '(' + ', '.join(self.ecs_fields) + ')'

def format(self, record: "LogRecord"):
def format(self, record: "LogRecord") -> Optional[Iterator[tuple[str, Any]]]:
if not self._formatter:
return
for field_name, field_value in self._formatter(self, record):
Expand All @@ -212,7 +231,7 @@ def __init__(self):
)

@staticmethod
def _get_exc_info(record):
def _get_exc_info(record: "LogRecord") -> Optional[Union["OptExcInfo", "_SysExcInfoType"]]:
exc_info = record.exc_info
if not exc_info:
return None
Expand All @@ -222,7 +241,7 @@ def _get_exc_info(record):
return exc_info
return None

def format(self, record: "LogRecord"):
def format(self, record: "LogRecord") -> Iterator[tuple[str, Optional[str]]]:
exc_info = self._get_exc_info(record)
message = record.getMessage()
error_type, error_message, stack_trace = None, None, None
Expand Down Expand Up @@ -253,11 +272,11 @@ class ConstantStrDataSource(LogDataSource):
Prints a constant string for the given ECS field.
"""

def __init__(self, ecs_field, _str):
def __init__(self, ecs_field: ECS_FIELDS, _str: str):
log_record = ECS_TO_LOG_RECORD_MAP.get(ecs_field, None)
self._str = _str

def _formatter(data_source: LogDataSource, record: "LogRecord"):
def _formatter(data_source: LogDataSource, record: "LogRecord") -> Iterator[tuple[str, str]]:
yield self.ecs_fields[0], self._str

super().__init__(ecs_fields=(ecs_field,), formatter=_formatter, dst_record_attr=log_record)
Expand All @@ -284,7 +303,7 @@ def __init__(
fmt: Optional[str] = None,
validate: Optional[bool] = None,
output_json: bool = False,
additional_fields: Optional[Mapping[str, str]] = None
additional_fields: Optional[Mapping[ECS_FIELDS, str]] = None
):
_kwargs = {}
if validate is not None:
Expand Down Expand Up @@ -344,15 +363,15 @@ def __init__(
self.output_json = output_json
super().__init__(fmt=fmt, style='%', **_kwargs)

def format(self, record):
json_record = dict(itertools.chain.from_iterable(f.format(record) for f in self._desired_data_sources))
def format(self, record: "LogRecord") -> str:
json_record = dict(itertools.chain.from_iterable(f.format(record) for f in self._desired_data_sources)) # type: ignore
bari12 marked this conversation as resolved.
Show resolved Hide resolved
if self.output_json:
return self._to_json(_unflatten_dict(json_record))
else:
return super().format(record)

@staticmethod
def _to_json(record):
def _to_json(record: dict[str, Any]) -> str:
try:
return json.dumps(record, default=_json_serializable)
except (TypeError, ValueError, OverflowError):
Expand All @@ -362,7 +381,7 @@ def _to_json(record):
return '{}'


def rucio_log_formatter(process_name: Optional[str] = None):
def rucio_log_formatter(process_name: Optional[str] = None) -> RucioFormatter:
config_logformat = config_get('common', 'logformat', raise_exception=False, default='%(asctime)s\t%(name)s\t%(process)d\t%(levelname)s\t%(message)s')
output_json = config_get_bool('common', 'logjson', default=False)
additional_fields = {}
Expand All @@ -371,7 +390,7 @@ def rucio_log_formatter(process_name: Optional[str] = None):
return RucioFormatter(fmt=config_logformat, output_json=output_json, additional_fields=additional_fields)


def setup_logging(application=None, process_name=None):
def setup_logging(application: Optional["Flask"] = None, process_name: Optional[str] = None) -> None:
"""
Configures the logging by setting the output stream to stdout and
configures log level and log format.
Expand All @@ -387,17 +406,15 @@ def setup_logging(application=None, process_name=None):
application.logger.addHandler(stdouthandler)


def formatted_logger(innerfunc, formatstr="%s"):
def formatted_logger(innerfunc: Callable, formatstr: str = "%s") -> Callable:
"""
Decorates the passed function, formatting log input by
the passed formatstr. The format string must always include a %s.

:param innerfunc: function to be decorated. Must take (level, msg) arguments.
:type innerfunc: Callable
:param formatstr: format string with %s as placeholder.
:type formatstr: str
"""
@functools.wraps(innerfunc)
def log_format(level, msg, *args, **kwargs):
def log_format(level: int, msg: object, *args, **kwargs) -> Callable:
return innerfunc(level, formatstr % msg, *args, **kwargs)
return log_format