Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
17 changes: 12 additions & 5 deletions bot/exts/filtering/_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
import inspect
import pkgutil
import types
import warnings
from abc import ABC, abstractmethod
from collections import defaultdict
from collections.abc import Callable, Iterable
Expand All @@ -15,6 +16,7 @@
import discord
import regex
from discord.ext.commands import Command
from pydantic import PydanticDeprecatedSince20
from pydantic_core import core_schema

import bot
Expand Down Expand Up @@ -181,11 +183,16 @@ def inherited(attr: str) -> bool:

# If a new attribute with the value MUST_SET_UNIQUE was defined in an abstract class, record it.
if inspect.isabstract(cls):
for attribute in dir(cls):
if getattr(cls, attribute, None) is FieldRequiring.MUST_SET_UNIQUE:
if not inherited(attribute):
# A new attribute with the value MUST_SET_UNIQUE.
FieldRequiring.__unique_attributes[cls][attribute] = set()
with warnings.catch_warnings():
# The code below will raise a warning about the use the __fields__ attr on a pydantic model
# This will continue to be warned about until removed in pydantic 3.0
# This warning is a false-positive as only the custom MUST_SET_UNIQUE attr is used here
warnings.simplefilter("ignore", category=PydanticDeprecatedSince20)
for attribute in dir(cls):
if getattr(cls, attribute, None) is FieldRequiring.MUST_SET_UNIQUE:
if not inherited(attribute):
# A new attribute with the value MUST_SET_UNIQUE.
FieldRequiring.__unique_attributes[cls][attribute] = set()
return

for attribute in dir(cls):
Expand Down
3 changes: 2 additions & 1 deletion bot/exts/moderation/watchchannels/_watchchannel.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,14 +12,15 @@
from pydis_core.site_api import ResponseCodeError
from pydis_core.utils import scheduling
from pydis_core.utils.channel import get_or_fetch_channel
from pydis_core.utils.logging import CustomLogger
from pydis_core.utils.members import get_or_fetch_member

from bot.bot import Bot
from bot.constants import BigBrother as BigBrotherConfig, Guild as GuildConfig, Icons
from bot.exts.filtering._filters.unique.discord_token import DiscordTokenFilter
from bot.exts.filtering._filters.unique.webhook import WEBHOOK_URL_RE
from bot.exts.moderation.modlog import ModLog
from bot.log import CustomLogger, get_logger
from bot.log import get_logger
from bot.pagination import LinePaginator
from bot.utils import CogABCMeta, messages, time

Expand Down
59 changes: 9 additions & 50 deletions bot/log.py
Original file line number Diff line number Diff line change
@@ -1,63 +1,29 @@
import logging
import os
import sys
from logging import Logger, handlers
from logging import handlers
from pathlib import Path
from typing import TYPE_CHECKING, cast

import coloredlogs
import sentry_sdk
from pydis_core.utils import logging as core_logging
from sentry_sdk.integrations.logging import LoggingIntegration
from sentry_sdk.integrations.redis import RedisIntegration

from bot import constants

TRACE_LEVEL = 5


if TYPE_CHECKING:
LoggerClass = Logger
else:
LoggerClass = logging.getLoggerClass()


class CustomLogger(LoggerClass):
"""Custom implementation of the `Logger` class with an added `trace` method."""

def trace(self, msg: str, *args, **kwargs) -> None:
"""
Log 'msg % args' with severity 'TRACE'.

To pass exception information, use the keyword argument exc_info with
a true value, e.g.

logger.trace("Houston, we have an %s", "interesting problem", exc_info=1)
"""
if self.isEnabledFor(TRACE_LEVEL):
self.log(TRACE_LEVEL, msg, *args, **kwargs)


def get_logger(name: str | None = None) -> CustomLogger:
"""Utility to make mypy recognise that logger is of type `CustomLogger`."""
return cast(CustomLogger, logging.getLogger(name))
get_logger = core_logging.get_logger


def setup() -> None:
"""Set up loggers."""
logging.TRACE = TRACE_LEVEL
logging.addLevelName(TRACE_LEVEL, "TRACE")
logging.setLoggerClass(CustomLogger)

root_log = get_logger()

format_string = "%(asctime)s | %(name)s | %(levelname)s | %(message)s"
log_format = logging.Formatter(format_string)

if constants.FILE_LOGS:
log_file = Path("logs", "bot.log")
log_file.parent.mkdir(exist_ok=True)
file_handler = handlers.RotatingFileHandler(log_file, maxBytes=5242880, backupCount=7, encoding="utf8")
file_handler.setFormatter(log_format)
file_handler.setFormatter(core_logging.log_format)
root_log.addHandler(file_handler)

if "COLOREDLOGS_LEVEL_STYLES" not in os.environ:
Expand All @@ -69,18 +35,11 @@ def setup() -> None:
}

if "COLOREDLOGS_LOG_FORMAT" not in os.environ:
coloredlogs.DEFAULT_LOG_FORMAT = format_string
coloredlogs.DEFAULT_LOG_FORMAT = core_logging.log_format._fmt
Copy link
Copy Markdown
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is the only thing I'm slightly sad about. Since I didn't surface the logging string used by bot core, this is the only way to get the raw logging string.

I think it's fine, but if anyone feels strongly about it I can update bot core to make this available in a variable.


coloredlogs.install(level=TRACE_LEVEL, logger=root_log, stream=sys.stdout)
coloredlogs.install(level=core_logging.TRACE_LEVEL, logger=root_log, stream=sys.stdout)

root_log.setLevel(logging.DEBUG if constants.DEBUG_MODE else logging.INFO)
get_logger("discord").setLevel(logging.WARNING)
get_logger("websockets").setLevel(logging.WARNING)
get_logger("chardet").setLevel(logging.WARNING)
get_logger("async_rediscache").setLevel(logging.WARNING)

# Set back to the default of INFO even if asyncio's debug mode is enabled.
get_logger("asyncio").setLevel(logging.INFO)

_set_trace_loggers()

Expand Down Expand Up @@ -121,13 +80,13 @@ def _set_trace_loggers() -> None:
level_filter = constants.Bot.trace_loggers
if level_filter:
if level_filter.startswith("*"):
get_logger().setLevel(TRACE_LEVEL)
get_logger().setLevel(core_logging.TRACE_LEVEL)

elif level_filter.startswith("!"):
get_logger().setLevel(TRACE_LEVEL)
get_logger().setLevel(core_logging.TRACE_LEVEL)
for logger_name in level_filter.strip("!,").split(","):
get_logger(logger_name).setLevel(logging.DEBUG)

else:
for logger_name in level_filter.strip(",").split(","):
get_logger(logger_name).setLevel(TRACE_LEVEL)
get_logger(logger_name).setLevel(core_logging.TRACE_LEVEL)