From 43ecf3de208cb5ee2213e182f6ae1a7c23f21276 Mon Sep 17 00:00:00 2001 From: Mega-JC <65417594+Mega-JC@users.noreply.github.com> Date: Fri, 26 Jul 2024 21:49:38 +0200 Subject: [PATCH] Generalize message validation system for showcase messages in 'showcase' extension, fix deploy-to-vps.yml --- .github/workflows/deploy-to-vps.yml | 3 +- pcbot/exts/op.py | 1 - pcbot/exts/showcase/__init__.py | 61 +++++ pcbot/exts/{showcase.py => showcase/cogs.py} | 145 ++++------ pcbot/exts/showcase/utils/__init__.py | 20 ++ pcbot/exts/showcase/utils/rules.py | 271 +++++++++++++++++++ pcbot/exts/showcase/utils/utils.py | 76 ++++++ pcbot/exts/showcase/utils/validators.py | 234 ++++++++++++++++ 8 files changed, 720 insertions(+), 91 deletions(-) create mode 100644 pcbot/exts/showcase/__init__.py rename pcbot/exts/{showcase.py => showcase/cogs.py} (87%) create mode 100644 pcbot/exts/showcase/utils/__init__.py create mode 100644 pcbot/exts/showcase/utils/rules.py create mode 100644 pcbot/exts/showcase/utils/utils.py create mode 100644 pcbot/exts/showcase/utils/validators.py diff --git a/.github/workflows/deploy-to-vps.yml b/.github/workflows/deploy-to-vps.yml index 221c9da..29935c1 100644 --- a/.github/workflows/deploy-to-vps.yml +++ b/.github/workflows/deploy-to-vps.yml @@ -19,11 +19,10 @@ jobs: SSH_PRIVATE_KEY: ${{ secrets.VPS_SSH_PRIVATE_KEY }} REMOTE_HOST: ${{ secrets.VPS_HOST }} REMOTE_USER: ${{ secrets.VPS_USER }} - SOURCE: "." - TARGET: ~/PygameCommunityBot/ SCRIPT_AFTER: | cp ~/config.py ~/PygameCommunityBot/config.py cp ~/.env ~/PygameCommunityBot/.env + cd ~/PygameCommunityBot docker compose stop docker compose rm -f sleep 60 && docker compose up -d --build diff --git a/pcbot/exts/op.py b/pcbot/exts/op.py index c1a0db3..6f1c9b8 100644 --- a/pcbot/exts/op.py +++ b/pcbot/exts/op.py @@ -151,7 +151,6 @@ async def op_pin_func( ), timeout=2, ) - print("HERE") except asyncio.TimeoutError: pass else: diff --git a/pcbot/exts/showcase/__init__.py b/pcbot/exts/showcase/__init__.py new file mode 100644 index 0000000..12cf62b --- /dev/null +++ b/pcbot/exts/showcase/__init__.py @@ -0,0 +1,61 @@ +from typing import Collection +import discord +import snakecore + +from .utils import ShowcaseChannelConfig + +BotT = snakecore.commands.Bot | snakecore.commands.AutoShardedBot + + +@snakecore.commands.decorators.with_config_kwargs +async def setup( + bot: BotT, + showcase_channels_config: Collection[ShowcaseChannelConfig], + theme_color: int | discord.Color = 0, +): + # validate showcase channels config + for i, showcase_channel_config in enumerate(showcase_channels_config): + if "channel_id" not in showcase_channel_config: + raise ValueError("Showcase channel config must have a 'channel_id' key") + elif ( + "default_auto_archive_duration" in showcase_channel_config + and not isinstance( + showcase_channel_config["default_auto_archive_duration"], int + ) + ): + raise ValueError( + "Showcase channel config 'default_auto_archive_duration' must be an integer" + ) + elif ( + "default_thread_slowmode_delay" in showcase_channel_config + and not isinstance( + showcase_channel_config["default_thread_slowmode_delay"], int + ) + ): + raise ValueError( + "Showcase channel config 'default_thread_slowmode_delay' must be an integer" + ) + elif "showcase_message_rules" not in showcase_channel_config: + raise ValueError( + "Showcase channel config must have a 'showcase_message_rules' key" + ) + + from .utils import dispatch_rule_specifier_dict_validator, BadRuleSpecifier + + specifier_dict_validator = dispatch_rule_specifier_dict_validator( + showcase_channel_config["showcase_message_rules"] + ) + + # validate 'showcase_message_rules' value + try: + specifier_dict_validator( + showcase_channel_config["showcase_message_rules"] # type: ignore + ) + except BadRuleSpecifier as e: + raise ValueError( + f"Error while parsing config.{i}.showcase_message_rules field: {e}" + ) from e + + from .cogs import Showcasing + + await bot.add_cog(Showcasing(bot, showcase_channels_config, theme_color)) diff --git a/pcbot/exts/showcase.py b/pcbot/exts/showcase/cogs.py similarity index 87% rename from pcbot/exts/showcase.py rename to pcbot/exts/showcase/cogs.py index 0fb4ffe..42fa93c 100644 --- a/pcbot/exts/showcase.py +++ b/pcbot/exts/showcase/cogs.py @@ -3,13 +3,15 @@ Copyright (c) 2022-present pygame-community. """ +import abc import asyncio from collections.abc import Collection import datetime +import enum import itertools import re import time -from typing import NotRequired, TypedDict +from typing import Any, Callable, Literal, NotRequired, Protocol, TypedDict import discord from discord.ext import commands @@ -18,17 +20,11 @@ from snakecore.commands import UnicodeEmoji from snakecore.commands.converters import DateTime -from ..base import BaseExtensionCog - -BotT = snakecore.commands.Bot | snakecore.commands.AutoShardedBot +from .utils import ShowcaseChannelConfig, validate_message +from ...base import BaseExtensionCog -class ShowcaseChannelConfig(TypedDict): - """A typed dict for specifying showcase channel configurations.""" - - channel_id: int - default_auto_archive_duration: NotRequired[int] - default_thread_slowmode_delay: NotRequired[int] +BotT = snakecore.commands.Bot | snakecore.commands.AutoShardedBot class Showcasing(BaseExtensionCog, name="showcasing"): @@ -372,39 +368,21 @@ async def delete_bad_message_with_thread( # don't error here if thread and/or message were already deleted pass - @staticmethod def showcase_message_validity_check( - message: discord.Message, min_chars=32, max_chars=float("inf") - ): - """Checks if a thread's starter message has the right format. + self, + message: discord.Message, + ) -> tuple[bool, str | None]: + """Checks if a showcase message has the right format. Returns ------- - bool: - True/False + tuple[bool, str | None]: + A tuple containing a boolean indicating whether the message is valid or not, and a string describing the reason why it is invalid if it is not valid. """ - - search_obj = re.search( - snakecore.utils.regex_patterns.URL, message.content or "" + return validate_message( + message, + self.showcase_channels_config[message.channel.id]["showcase_message_rules"], ) - link_in_msg = bool(search_obj) - first_link_str = search_obj.group() if link_in_msg else "" - - char_length = len(message.content) + len( - message.channel.name if isinstance(message.channel, discord.Thread) else "" - ) - - if ( - message.content - and (link_in_msg and char_length > len(first_link_str)) - and min_chars <= char_length < max_chars - ): - return True - - elif message.content and message.attachments: - return True - - return False @commands.Cog.listener() async def on_thread_create(self, thread: discord.Thread): @@ -419,15 +397,15 @@ async def on_thread_create(self, thread: discord.Thread): except discord.NotFound: return - if not self.showcase_message_validity_check(message): + is_valid, reason = self.showcase_message_validity_check(message) + + if not is_valid: deletion_datetime = datetime.datetime.now( datetime.timezone.utc ) + datetime.timedelta(minutes=5) warn_msg = await message.reply( - "Your message must contain an attachment or text and safe links to be valid.\n\n" - "- Attachment-only entries must be in reference to a previous post of yours.\n" - "- Text-only posts must contain at least 32 characters (including their title " - "and including links, but not links alone).\n\n" + "### Invalid showcase message\n\n" + f"{reason}\n\n" " If no changes are made, your message (and its thread/post) will be " f"deleted {snakecore.utils.create_markdown_timestamp(deletion_datetime, 'R')}." ) @@ -493,28 +471,31 @@ async def prompt_author_for_feedback_thread(self, message: discord.Message): pass else: if snakecore.utils.is_emoji_equal(event.emoji, "✅"): - await message.create_thread( - name=( - f"Feedback for " - + f"@{message.author.name} | {str(message.author.id)[-6:]}" - )[:100], - auto_archive_duration=( - self.showcase_channels_config[message.channel.id].get( - "default_auto_archive_duration", 60 - ) - if bot_perms.manage_threads - else discord.utils.MISSING - ), # type: ignore - slowmode_delay=( - self.showcase_channels_config[message.channel.id].get( - "default_thread_slowmode_delay", - ) - if bot_perms.manage_threads - else None - ), # type: ignore - reason=f"A '#{message.channel.name}' message " - "author requested a feedback thread.", - ) + try: + await message.create_thread( + name=( + f"Feedback for " + + f"@{message.author.name} | {str(message.author.id)[-6:]}" + )[:100], + auto_archive_duration=( + self.showcase_channels_config[message.channel.id].get( + "default_auto_archive_duration", 60 + ) + if bot_perms.manage_threads + else discord.utils.MISSING + ), # type: ignore + slowmode_delay=( + self.showcase_channels_config[message.channel.id].get( + "default_thread_slowmode_delay", + ) + if bot_perms.manage_threads + else None + ), # type: ignore + reason=f"A '#{message.channel.name}' message " + "author requested a feedback thread.", + ) + except discord.HTTPException: + pass try: await alert_msg.delete() @@ -533,17 +514,17 @@ async def on_message(self, message: discord.Message): ): return - if self.showcase_message_validity_check(message): + is_valid, reason = self.showcase_message_validity_check(message) + + if is_valid: await self.prompt_author_for_feedback_thread(message) else: deletion_datetime = datetime.datetime.now( datetime.timezone.utc ) + datetime.timedelta(minutes=5) warn_msg = await message.reply( - "Your message must contain an attachment or text and safe links to be valid.\n\n" - "- Attachment-only entries must be in reference to a previous post of yours.\n" - "- Text-only posts must contain at least 32 characters (including their title " - "and including links, but not links alone).\n\n" + "### Invalid showcase message\n\n" + f"{reason}\n\n" " If no changes are made, your message (and its thread/post) will be " f"deleted {snakecore.utils.create_markdown_timestamp(deletion_datetime, 'R')}." ) @@ -575,7 +556,9 @@ async def on_message_edit(self, old: discord.Message, new: discord.Message): ): return - if not self.showcase_message_validity_check(new): + is_valid, reason = self.showcase_message_validity_check(new) + + if not is_valid: if new.id in self.entry_message_deletion_dict: deletion_data_tuple = self.entry_message_deletion_dict[new.id] deletion_task = deletion_data_tuple[0] @@ -594,14 +577,9 @@ async def on_message_edit(self, old: discord.Message, new: discord.Message): ) + datetime.timedelta(minutes=5) await warn_msg.edit( content=( - "I noticed your edit. However:\n\n" - "Your post must contain an attachment or text and safe " - "links to be valid.\n\n" - "- Attachment-only entries must be in reference to a " - "previous post of yours.\n" - "- Text-only posts must contain at least 32 " - "characters (including their title " - "and including links, but not links alone).\n\n" + "### Invalid showcase message\n\n" + "Your edited showcase message is invalid.\n\n" + f"{reason}\n\n" " If no changes are made, your post will be " f"deleted " + snakecore.utils.create_markdown_timestamp( @@ -647,7 +625,7 @@ async def on_message_edit(self, old: discord.Message, new: discord.Message): ) elif ( - self.showcase_message_validity_check(new) + is_valid ) and new.id in self.entry_message_deletion_dict: # an invalid entry was corrected deletion_data_tuple = self.entry_message_deletion_dict[new.id] deletion_task = deletion_data_tuple[0] @@ -787,12 +765,3 @@ async def on_raw_thread_delete(self, payload: discord.RawThreadDeleteEvent): deletion_task.cancel() del self.entry_message_deletion_dict[payload.thread_id] - - -@snakecore.commands.decorators.with_config_kwargs -async def setup( - bot: BotT, - showcase_channels_config: Collection[ShowcaseChannelConfig], - theme_color: int | discord.Color = 0, -): - await bot.add_cog(Showcasing(bot, showcase_channels_config, theme_color)) diff --git a/pcbot/exts/showcase/utils/__init__.py b/pcbot/exts/showcase/utils/__init__.py new file mode 100644 index 0000000..dcb8b11 --- /dev/null +++ b/pcbot/exts/showcase/utils/__init__.py @@ -0,0 +1,20 @@ +from abc import ABC +import re +from typing import Any, Callable, Collection, Literal, NotRequired, TypedDict +import discord +import snakecore + +from .utils import * +from .validators import * + +BotT = snakecore.commands.Bot | snakecore.commands.AutoShardedBot + + +class ShowcaseChannelConfig(TypedDict): + """A typed dict for specifying showcase channel configurations.""" + + channel_id: int + default_auto_archive_duration: NotRequired[int] + default_thread_slowmode_delay: NotRequired[int] + showcase_message_rules: RuleSpecifier | RuleSpecifierPair | RuleSpecifierList + "A rule specifier dict for validating messages posted to the showcase channel" diff --git a/pcbot/exts/showcase/utils/rules.py b/pcbot/exts/showcase/utils/rules.py new file mode 100644 index 0000000..a91c8c6 --- /dev/null +++ b/pcbot/exts/showcase/utils/rules.py @@ -0,0 +1,271 @@ +# Base class for common message validation logic +import re +from typing import Literal +import discord +from .utils import MISSING, URL_PATTERN, DiscordMessageRule, EnforceType, is_vcs_url + + +class ContentRule(DiscordMessageRule, name="content"): + """A rule for validating if a Discord message contains only, any (the default), or no content.""" + + @staticmethod + def validate( + enforce_type: EnforceType, + message: discord.Message, + arg: Literal["any", "only", "none"] = "any", + ) -> tuple[Literal[False], str] | tuple[Literal[True], None]: + """Validate a message for the presence of content according to the specified arguments.""" + + has_content = bool(message.content) + only_content = not has_content and not (message.attachments or message.embeds) + + if enforce_type == "always" and arg == "only" and not only_content: + return (False, "Message must always contain only text content") + + if enforce_type == "always" and arg == "any" and not has_content: + return (False, "Message must always contain text content") + + if enforce_type == "always" and arg == "none" and has_content: + return (False, "Message must always contain no text content") + + if enforce_type == "never" and arg == "only" and only_content: + return (False, "Message must never contain only text content") + + if enforce_type == "never" and arg == "any" and has_content: + return (False, "Message must never contain text content") + + if enforce_type == "never" and arg == "none" and not has_content: + return (False, "Message must never contain no text content") + + return (True, None) + + @staticmethod + def validate_arg(arg: Literal["any", "only", "none"]) -> str | None: + if arg not in (MISSING, "any", "only", "none"): + return "Argument must be one of 'any', 'only', or 'none'" + + +class ContentLengthRule(DiscordMessageRule, name="content-length"): + """A rule for validating if a Discord message contains text content within the specified length range.""" + + @staticmethod + def validate( + enforce_type: EnforceType, + message: discord.Message, + arg: tuple[int, int], + ): + """Validate a message for the presence of text content within the specified length range.""" + + if not isinstance(arg, tuple) or len(arg) != 2: + raise ValueError("Argument must be a tuple of two integers") + + min_length, max_length = arg + + min_length = min_length or 0 + max_length = max_length or 4096 + + if min_length > max_length: + raise ValueError( + "Minimum length must be less than or equal to maximum length" + ) + + content_length = len(message.content) + + if enforce_type == "always" and not ( + min_length <= content_length <= max_length + ): + return ( + False, + f"Message must always contain text content within {min_length}-{max_length} characters", + ) + + if enforce_type == "never" and (min_length <= content_length <= max_length): + return ( + False, + f"Message must never contain text content within {min_length}-{max_length} characters", + ) + + return (True, None) + + @staticmethod + def validate_arg(arg: tuple[int | None, int | None]) -> str | None: + if (not isinstance(arg, (list, tuple))) or ( + isinstance(arg, (list, tuple)) and len(arg) != 2 + ): + return "Argument must be a list/tuple of two integers" + + if arg[0] is not None and arg[1] is not None: + if arg[0] > arg[1]: + return "Minimum length must be less than or equal to maximum length" + elif arg[0] is not None: + if arg[0] < 0: + return "Minimum length must be greater than or equal to 0" + elif arg[1] is not None: + if arg[1] < 0: + return "Maximum length must be greater than or equal to 0" + + +class URLsRule(DiscordMessageRule, name="urls"): + """A rule for validating if a Discord message contains only, at least one or no URLs.""" + + @staticmethod + def validate( + enforce_type: EnforceType, + message: discord.Message, + arg: Literal["any", "only", "none"], + ) -> tuple[Literal[False], str] | tuple[Literal[True], None]: + """Validate a message for the presence of URLs according to the specified arguments.""" + + search_obj = tuple(re.finditer(URL_PATTERN, message.content)) + links = tuple(match.group() for match in search_obj if match) + any_urls = bool(links) + only_urls = any_urls and sum(len(link) for link in links) == len( + re.sub(r"\s", "", message.content) + ) + no_urls = not any_urls + + if enforce_type == "always" and arg == "only" and not only_urls: + return (False, "Message must always contain only URLs") + + if enforce_type == "always" and arg == "any" and not any_urls: + return (False, "Message must always contain at least one URL") + + if enforce_type == "always" and arg == "none" and not no_urls: + return (False, "Message must always contain no URLs") + + if enforce_type == "never" and arg == "only" and only_urls: + return (False, "Message must never contain only URLs") + + if enforce_type == "never" and arg == "any" and any_urls: + return (False, "Message must never contain at least one URL") + + if enforce_type == "never" and arg == "none" and no_urls: + return (False, "Message must never contain no URLs") + + return (True, None) + + +# Rule for validating VCS URLs +class VCSURLsRule(DiscordMessageRule, name="vcs-urls"): + """A rule for validating if a Discord message contains only, at least one (the default), or no valid VCS URLs.""" + + @staticmethod + def validate( + enforce_type: EnforceType, + message: discord.Message, + arg: Literal["any", "all", "none"] = "any", + ) -> tuple[Literal[False], str] | tuple[Literal[True], None]: + """Validate a message for the presence of VCS URLs according to the specified arguments.""" + + search_obj = tuple(re.finditer(URL_PATTERN, message.content or "")) + links = tuple(match.group() for match in search_obj if match) + any_vcs_urls = links and any(is_vcs_url(link) for link in links) + no_vcs_urls = not any_vcs_urls + all_vcs_urls = not any(not is_vcs_url(link) for link in links) + + if enforce_type == "always" and arg == "all" and not all_vcs_urls: + return (False, "Message must always contain only valid VCS URLs") + + if enforce_type == "always" and arg == "any" and not any_vcs_urls: + return (False, "Message must always contain at least one valid VCS URL") + + if enforce_type == "always" and arg == "none" and not no_vcs_urls: + return (False, "Message must always contain no valid VCS URLs") + + if enforce_type == "never" and arg == "all" and all_vcs_urls: + return (False, "Message must never contain only valid VCS URLs") + + if enforce_type == "never" and arg == "any" and any_vcs_urls: + return (False, "Message must never contain at least one valid VCS URL") + + if enforce_type == "never" and arg == "none" and no_vcs_urls: + return (False, "Message must never contain no valid VCS URLs") + + return (True, None) + + @staticmethod + def validate_arg(arg: Literal["any", "all", "none"]) -> str | None: + if arg not in (MISSING, "any", "all", "none"): + return "Argument must be one of 'any', 'all', or 'none'" + + +class AttachmentsRule(DiscordMessageRule, name="attachments"): + """A rule for validating if a Discord message contains only, at least one or no attachments.""" + + @staticmethod + def validate( + enforce_type: EnforceType, + message: discord.Message, + arg: Literal["any", "only", "none"], + ): + """Validate a message for the presence of attachments according to the specified arguments.""" + + any_attachments = bool(message.attachments) + only_attachments = any_attachments and not (message.content or message.embeds) + no_attachments = not any_attachments + + if enforce_type == "always" and arg == "only" and not only_attachments: + return (False, "Message must always contain only attachments") + + if enforce_type == "always" and arg == "any" and not any_attachments: + return (False, "Message must always contain at least one attachment") + + if enforce_type == "always" and arg == "none" and not no_attachments: + return (False, "Message must always contain no attachments") + + if enforce_type == "never" and arg == "only" and only_attachments: + return (False, "Message must never contain only attachments") + + if enforce_type == "never" and arg == "any" and any_attachments: + return (False, "Message must never contain at least one attachment") + + if enforce_type == "never" and arg == "none" and no_attachments: + return (False, "Message must never contain no attachments") + + return (True, None) + + +class EmbedsRule(DiscordMessageRule, name="embeds"): + """A rule for validating if a Discord message contains only, at least one or no embeds.""" + + @staticmethod + def validate( + enforce_type: EnforceType, + message: discord.Message, + arg: Literal["any", "only", "none"], + ) -> tuple[Literal[False], str] | tuple[Literal[True], None]: + """Validate a message for the presence of embeds according to the specified arguments.""" + + any_embeds = bool(message.embeds) + only_embeds = any_embeds and not (message.content or message.attachments) + no_embeds = not any_embeds + + if enforce_type == "always" and arg == "only" and not only_embeds: + return (False, "Message must always contain only embeds") + + if enforce_type == "always" and arg == "any" and not any_embeds: + return (False, "Message must always contain at least one embed") + + if enforce_type == "always" and arg == "none" and not no_embeds: + return (False, "Message must always contain no embeds") + + if enforce_type == "never" and arg == "only" and only_embeds: + return (False, "Message must never contain only embeds") + + if enforce_type == "never" and arg == "any" and any_embeds: + return (False, "Message must never contain at least one embed") + + if enforce_type == "never" and arg == "none" and no_embeds: + return (False, "Message must never contain no embeds") + + return (True, None) + + +RULE_MAPPING: dict[str, type[DiscordMessageRule]] = { + "content": ContentRule, + "content-length": ContentLengthRule, + "urls": URLsRule, + "vcs-urls": VCSURLsRule, + "attachments": AttachmentsRule, + "embeds": EmbedsRule, +} diff --git a/pcbot/exts/showcase/utils/utils.py b/pcbot/exts/showcase/utils/utils.py new file mode 100644 index 0000000..5127bb3 --- /dev/null +++ b/pcbot/exts/showcase/utils/utils.py @@ -0,0 +1,76 @@ +# ABC for rules +from abc import ABC, abstractmethod +import re +from typing import Any, Literal, NotRequired, TypedDict +import discord + +EnforceType = Literal["always", "never"] + +MISSING: Any = object() + + +URL_PATTERN = re.compile( + r"(?P\w+):\/\/(?:(?P[\w_.-]+(?::[\w_.-]+)?)@)?(?P(?:(?P[\w_-]+(?:\.[\w_-]+)*)\.)?(?P(?P[\w_-]+)\.(?P\w+))|(?P[\w_-]+))(?:\:(?P\d+))?(?P\/[\w.,@?^=%&:\/~+-]*)?(?:\?(?P[\w.,@?^=%&:\/~+-]*))?(?:#(?P[\w@?^=%&\/~+#-]*))?" +) + + +def is_vcs_url(url: str) -> bool: + """Check if a URL points to a known VCS SaaS (e.g. GitHub, GitLab, Bitbucket).""" + return bool( + (match_ := (re.match(URL_PATTERN, url))) + and match_.group("scheme") in ("https", "http") + and match_.group("domain") in ("github.com", "gitlab.com", "bitbucket.org") + ) + + +class RuleSpecifier(TypedDict): + name: str + enforce_type: EnforceType + arg: NotRequired[Any] + description: NotRequired[str] + + +class RuleSpecifierPair(TypedDict): + mode: Literal["and", "or"] + clause1: "RuleSpecifier | RuleSpecifierPair | RuleSpecifierList" + clause2: "RuleSpecifier | RuleSpecifierPair | RuleSpecifierList" + description: NotRequired[str] + + +class RuleSpecifierList(TypedDict): + mode: Literal["any", "all"] + clauses: list["RuleSpecifier | RuleSpecifierPair | RuleSpecifierList"] + description: NotRequired[str] + + +class BadRuleSpecifier(Exception): + """Exception raised when a rule specifier is invalid.""" + + pass + + +class DiscordMessageRule(ABC): + name: str + + def __init_subclass__(cls, name: str) -> None: + cls.name = name + + @staticmethod + @abstractmethod + def validate( + enforce_type: EnforceType, message: discord.Message, arg: Any = None + ) -> tuple[Literal[False], str] | tuple[Literal[True], None]: + ... + + @staticmethod + def validate_arg(arg: Any) -> str | None: + ... + + +class AsyncDiscordMessageRule(DiscordMessageRule, name="AsyncDiscordMessageRule"): + @staticmethod + @abstractmethod + async def validate( + enforce_type: EnforceType, message: discord.Message, arg: Any = None + ) -> tuple[Literal[False], str] | tuple[Literal[True], None]: + ... diff --git a/pcbot/exts/showcase/utils/validators.py b/pcbot/exts/showcase/utils/validators.py new file mode 100644 index 0000000..330c614 --- /dev/null +++ b/pcbot/exts/showcase/utils/validators.py @@ -0,0 +1,234 @@ +from typing import Callable, overload + +import discord + +from .rules import RULE_MAPPING +from .utils import ( + MISSING, + BadRuleSpecifier, + RuleSpecifier, + RuleSpecifierList, + RuleSpecifierPair, +) + + +def dispatch_rule_specifier_dict_validator( + specifier: RuleSpecifier | RuleSpecifierPair | RuleSpecifierList, +) -> ( + Callable[[RuleSpecifier], None] + | Callable[[RuleSpecifierPair], None] + | Callable[[RuleSpecifierList], None] + | None +): + """Dispatch the appropriate validator to use to validate the structure of a rule specifier.""" + + if "mode" in specifier: + if specifier["mode"] in ("and", "or"): + return validate_rule_specifier_dict_pair + elif specifier["mode"] in ("any", "all"): + return validate_rule_specifier_dict_list + else: + return validate_rule_specifier_dict_single + + return None + + +def validate_rule_specifier_dict_single( + specifier: RuleSpecifier, + depth_viz: str = "RuleSpecifier", +) -> None: + """Validate a single rule specifier's structure.""" + + if specifier["name"] not in RULE_MAPPING: + raise BadRuleSpecifier( + f"{depth_viz}.name: Unknown rule '{specifier['name']}'" + ) # type + elif "enforce_type" not in specifier or specifier["enforce_type"].lower() not in ( + "always", + "never", + ): + raise BadRuleSpecifier( + f"{depth_viz}.RuleSpecifier 'enforce_type' field must be set to 'always' or 'never'" + ) + + error_string = RULE_MAPPING[specifier["name"]].validate_arg( + specifier.get("arg", MISSING) + ) + + if error_string is not None: + raise BadRuleSpecifier( + f"{depth_viz}.RuleSpecifier 'arg' field validation failed: {error_string}" + ) + + +def validate_rule_specifier_dict_pair( + specifier: RuleSpecifierPair, + depth_viz: str = "RuleSpecifierPair", +) -> None: + """Validate a rule specifier pair's structure.""" + + if "mode" not in specifier or specifier["mode"] not in ("and", "or"): + raise BadRuleSpecifier( + f"{depth_viz}.RuleSpecifierPair 'mode' field must be 'and' or 'or'" + ) + + if "clause1" not in specifier or "clause2" not in specifier: + raise BadRuleSpecifier( + f"{depth_viz}.RuleSpecifierPair must have 'clause1' " + "and 'clause2' fields pointing to RuleSpecifier or RuleSpecifierPair or RuleSpecifierList dicts" + ) + + dict_validator1 = dispatch_rule_specifier_dict_validator(specifier["clause1"]) + dict_validator2 = dispatch_rule_specifier_dict_validator(specifier["clause2"]) + + if dict_validator1 is None: + raise BadRuleSpecifier( + f"{depth_viz}.RuleSpecifierPair 'clause1' field " + "must be a RuleSpecifier or RuleSpecifierPair or RuleSpecifierList dict" + ) + + if dict_validator2 is None: + raise BadRuleSpecifier( + f"{depth_viz}.RuleSpecifierPair 'clause2' field " + "must be a RuleSpecifier or RuleSpecifierPair or RuleSpecifierList dict" + ) + + dict_validator1(specifier["clause1"], depth_viz=f"{depth_viz}.clause1") # type: ignore + dict_validator2(specifier["clause2"], depth_viz=f"{depth_viz}.clause2") # type: ignore + + +def validate_rule_specifier_dict_list( + specifier: RuleSpecifierList, + depth_viz: str = "RuleSpecifierList", +) -> None: + """Validate a rule specifier list's structure.""" + + if "mode" not in specifier or specifier["mode"] not in ("any", "all"): + raise BadRuleSpecifier( + f"{depth_viz}.RuleSpecifierList 'mode' field must be 'any' or 'all'" + ) + + if "clauses" not in specifier or not specifier["clauses"]: + raise BadRuleSpecifier( + f"{depth_viz}.RuleSpecifierList must have 'clauses' " + "field pointing to a list of RuleSpecifier or RuleSpecifierPair or RuleSpecifierList dicts" + ) + + for i, clause in enumerate(specifier["clauses"]): + dict_validator = dispatch_rule_specifier_dict_validator(clause) + if dict_validator is None: + raise BadRuleSpecifier( + f"{depth_viz}.clauses.{i} field " + "must be a RuleSpecifier or RuleSpecifierPair or RuleSpecifierList dict" + ) + + dict_validator(clause, depth_viz=f"{depth_viz}.clauses.{i}") # type: ignore + + +def dispatch_rule_specifier_message_validator( + specifier: RuleSpecifier | RuleSpecifierPair | RuleSpecifierList, +): + """Dispatch the appropriate validator to use to enforce a rule specifier on a Discord message.""" + + if "mode" in specifier: + if specifier["mode"] in ("and", "or"): + return rule_specifier_pair_validate_message + elif specifier["mode"] in ("any", "all"): + return rule_specifier_list_validate_message + return rule_specifier_single_validate_message + + +def rule_specifier_single_validate_message( + specifier: RuleSpecifier, + message: discord.Message, + depth_viz: str = "", +) -> tuple[bool, str | None]: + """Validate a message according to a single rule specifier.""" + + rule = RULE_MAPPING[specifier["name"]] + + if "arg" in specifier: + result = rule.validate(specifier["enforce_type"], message, specifier["arg"]) + else: + result = rule.validate(specifier["enforce_type"], message) + + if "description" in specifier: + # insert description of rule specifier if present + return (result[0], specifier["description"] if not result[0] else None) + + return result + + +def rule_specifier_pair_validate_message( + specifier: RuleSpecifierPair, + message: discord.Message, +) -> tuple[bool, str | None]: + """Validate a message according to a rule specifier pair.""" + + success = True + failure_description = specifier.get("description") + + print(specifier) + + validator1 = dispatch_rule_specifier_message_validator(specifier["clause1"]) + validator2 = dispatch_rule_specifier_message_validator(specifier["clause2"]) + + result1 = validator1(specifier["clause1"], message) # type: ignore + result2 = None + + success = result1[0] + if (specifier["mode"] == "and" and success) or ( + specifier["mode"] == "or" and not success + ): + result2 = validator2(specifier["clause2"], message) # type: ignore + success = bool(result2[0]) + + if not result1[0] and failure_description is None: + failure_description = result1[1] + elif result2 and not result2[0] and failure_description is None: + failure_description = result2[1] + + return (success, failure_description if not success else None) + + +def rule_specifier_list_validate_message( + specifier: RuleSpecifierList, + message: discord.Message, +) -> tuple[bool, str | None]: + """Validate a message according to a rule specifier list.""" + + success = True + failure_description = specifier.get("description") + + if specifier["mode"] == "all": + for i, clause in enumerate(specifier["clauses"]): + validator = dispatch_rule_specifier_message_validator(clause) + result = validator(clause, message) # type: ignore + if not result[0]: + success = False + if failure_description is None: + failure_description = result[1] + break + + elif specifier["mode"] == "any": + for i, clause in enumerate(specifier["clauses"]): + validator = dispatch_rule_specifier_message_validator(clause) + result = validator(clause, message) # type: ignore + success = success or result[0] + + if not success and failure_description is None: + failure_description = result[1] + + return (success, failure_description if not success else None) + + +def validate_message( + message: discord.Message, + specifier: RuleSpecifier | RuleSpecifierPair | RuleSpecifierList, +) -> tuple[bool, str | None]: + """Validate a message according to a rule specifier.""" + + validator = dispatch_rule_specifier_message_validator(specifier) + result = validator(specifier, message) # type: ignore + + return result