From 679cb77a5ad1eefa8c027f3f318d3fb3466cf74d Mon Sep 17 00:00:00 2001 From: Ibrahim2750mi Date: Tue, 14 Feb 2023 19:30:54 +0530 Subject: [PATCH 1/6] Migrated `!tags` command to slash command `/tag` --- bot/exts/info/tags.py | 160 ++++++++++++++++-------------------------- bot/pagination.py | 20 ++++-- 2 files changed, 76 insertions(+), 104 deletions(-) diff --git a/bot/exts/info/tags.py b/bot/exts/info/tags.py index 83d3a9d939..25c51def9e 100644 --- a/bot/exts/info/tags.py +++ b/bot/exts/info/tags.py @@ -4,12 +4,12 @@ import re import time from pathlib import Path -from typing import Callable, Iterable, Literal, NamedTuple, Optional, Union +from typing import Literal, NamedTuple, Optional, Union import discord import frontmatter -from discord import Embed, Member -from discord.ext.commands import Cog, Context, group +from discord import Embed, Member, app_commands +from discord.ext.commands import Cog from bot import constants from bot.bot import Bot @@ -27,6 +27,8 @@ REGEX_NON_ALPHABET = re.compile(r"[^a-z]", re.MULTILINE & re.IGNORECASE) FOOTER_TEXT = f"To show a tag, type {constants.Bot.prefix}tags ." +GUILD_ID = constants.Guild.id + class COOLDOWN(enum.Enum): """Sentinel value to signal that a tag is on cooldown.""" @@ -138,6 +140,15 @@ def __init__(self, bot: Bot): self.bot = bot self.tags: dict[TagIdentifier, Tag] = {} self.initialize_tags() + self.bot.tree.copy_global_to(guild=discord.Object(id=GUILD_ID)) + + tag_group = app_commands.Group(name="tag", description="...") + # search_tag = app_commands.Group(name="search", description="...", parent=tag_group) + + @Cog.listener() + async def on_ready(self) -> None: + """Called when the cog is ready.""" + await self.bot.tree.sync(guild=discord.Object(id=GUILD_ID)) def initialize_tags(self) -> None: """Load all tags from resources into `self.tags`.""" @@ -182,90 +193,9 @@ def get_fuzzy_matches(self, tag_identifier: TagIdentifier) -> list[tuple[TagIden return suggestions - def _get_tags_via_content( - self, - check: Callable[[Iterable], bool], - keywords: str, - user: Member, - ) -> list[tuple[TagIdentifier, Tag]]: - """ - Search for tags via contents. - - `predicate` will be the built-in any, all, or a custom callable. Must return a bool. - """ - keywords_processed = [] - for keyword in keywords.split(","): - keyword_sanitized = keyword.strip().casefold() - if not keyword_sanitized: - # this happens when there are leading / trailing / consecutive comma. - continue - keywords_processed.append(keyword_sanitized) - - if not keywords_processed: - # after sanitizing, we can end up with an empty list, for example when keywords is "," - # in that case, we simply want to search for such keywords directly instead. - keywords_processed = [keywords] - - matching_tags = [] - for identifier, tag in self.tags.items(): - matches = (query in tag.content.casefold() for query in keywords_processed) - if tag.accessible_by(user) and check(matches): - matching_tags.append((identifier, tag)) - - return matching_tags - - async def _send_matching_tags( - self, - ctx: Context, - keywords: str, - matching_tags: list[tuple[TagIdentifier, Tag]], - ) -> None: - """Send the result of matching tags to user.""" - if len(matching_tags) == 1: - await ctx.send(embed=matching_tags[0][1].embed) - elif matching_tags: - is_plural = keywords.strip().count(" ") > 0 or keywords.strip().count(",") > 0 - embed = Embed( - title=f"Here are the tags containing the given keyword{'s' * is_plural}:", - ) - await LinePaginator.paginate( - sorted( - f"**\N{RIGHT-POINTING DOUBLE ANGLE QUOTATION MARK}** {identifier.name}" - for identifier, _ in matching_tags - ), - ctx, - embed, - **self.PAGINATOR_DEFAULTS, - ) - - @group(name="tags", aliases=("tag", "t"), invoke_without_command=True, usage="[tag_group] [tag_name]") - async def tags_group(self, ctx: Context, *, argument_string: Optional[str]) -> None: - """Show all known tags, a single tag, or run a subcommand.""" - await self.get_command(ctx, argument_string=argument_string) - - @tags_group.group(name="search", invoke_without_command=True) - async def search_tag_content(self, ctx: Context, *, keywords: str) -> None: - """ - Search inside tags' contents for tags. Allow searching for multiple keywords separated by comma. - - Only search for tags that has ALL the keywords. - """ - matching_tags = self._get_tags_via_content(all, keywords, ctx.author) - await self._send_matching_tags(ctx, keywords, matching_tags) - - @search_tag_content.command(name="any") - async def search_tag_content_any_keyword(self, ctx: Context, *, keywords: Optional[str] = "any") -> None: - """ - Search inside tags' contents for tags. Allow searching for multiple keywords separated by comma. - - Search for tags that has ANY of the keywords. - """ - matching_tags = self._get_tags_via_content(any, keywords or "any", ctx.author) - await self._send_matching_tags(ctx, keywords, matching_tags) - async def get_tag_embed( self, - ctx: Context, + interaction: discord.Interaction, tag_identifier: TagIdentifier, ) -> Optional[Union[Embed, Literal[COOLDOWN.obj]]]: """ @@ -276,7 +206,7 @@ async def get_tag_embed( filtered_tags = [ (ident, tag) for ident, tag in self.get_fuzzy_matches(tag_identifier)[:10] - if tag.accessible_by(ctx.author) + if tag.accessible_by(interaction.user) ] # Try exact match, includes checking through alt names @@ -295,10 +225,10 @@ async def get_tag_embed( tag = filtered_tags[0][1] if tag is not None: - if tag.on_cooldown_in(ctx.channel): + if tag.on_cooldown_in(interaction.channel): log.debug(f"Tag {str(tag_identifier)!r} is on cooldown.") return COOLDOWN.obj - tag.set_cooldown_for(ctx.channel) + tag.set_cooldown_for(interaction.channel) self.bot.stats.incr( f"tags.usages" @@ -313,7 +243,7 @@ async def get_tag_embed( suggested_tags_text = "\n".join( f"**\N{RIGHT-POINTING DOUBLE ANGLE QUOTATION MARK}** {identifier}" for identifier, tag in filtered_tags - if not tag.on_cooldown_in(ctx.channel) + if not tag.on_cooldown_in(interaction.channel) ) return Embed( title="Did you mean ...", @@ -362,8 +292,8 @@ def accessible_tags_in_group(self, group: str, user: discord.Member) -> list[str if identifier.group == group and tag.accessible_by(user) ) - @tags_group.command(name="get", aliases=("show", "g"), usage="[tag_group] [tag_name]") - async def get_command(self, ctx: Context, *, argument_string: Optional[str]) -> bool: + @tag_group.command(name="get") + async def get_command(self, interaction: discord.Interaction, *, tag_name: Optional[str]) -> bool: """ If a single argument matching a group name is given, list all accessible tags from that group Otherwise display the tag if one was found for the given arguments, or try to display suggestions for that name. @@ -373,38 +303,68 @@ async def get_command(self, ctx: Context, *, argument_string: Optional[str]) -> Returns True if a message was sent, or if the tag is on cooldown. Returns False if no message was sent. """ # noqa: D205, D415 - if not argument_string: + if not tag_name: if self.tags: await LinePaginator.paginate( - self.accessible_tags(ctx.author), ctx, Embed(title="Available tags"), **self.PAGINATOR_DEFAULTS + self.accessible_tags(interaction.user), + interaction, Embed(title="Available tags"), + **self.PAGINATOR_DEFAULTS, ) else: - await ctx.send(embed=Embed(description="**There are no tags!**")) + await interaction.response.send_message(embed=Embed(description="**There are no tags!**")) return True - identifier = TagIdentifier.from_string(argument_string) + identifier = TagIdentifier.from_string(tag_name) if identifier.group is None: # Try to find accessible tags from a group matching the identifier's name. - if group_tags := self.accessible_tags_in_group(identifier.name, ctx.author): + if group_tags := self.accessible_tags_in_group(identifier.name, interaction.user): await LinePaginator.paginate( - group_tags, ctx, Embed(title=f"Tags under *{identifier.name}*"), **self.PAGINATOR_DEFAULTS + group_tags, interaction, Embed(title=f"Tags under *{identifier.name}*"), **self.PAGINATOR_DEFAULTS ) return True - embed = await self.get_tag_embed(ctx, identifier) + embed = await self.get_tag_embed(interaction, identifier) if embed is None: return False if embed is not COOLDOWN.obj: await wait_for_deletion( - await ctx.send(embed=embed), - (ctx.author.id,) + await interaction.response.send_message(embed=embed), + (interaction.user.id,) ) # A valid tag was found and was either sent, or is on cooldown return True + @get_command.autocomplete("tag_name") + async def tag_name_autocomplete( + self, + interaction: discord.Interaction, + current: str + ) -> list[app_commands.Choice[str]]: + """Autocompleter for `/tag get` command.""" + tag_names = [tag.name for tag in self.tags.keys()] + return [ + app_commands.Choice(name=tag, value=tag) + for tag in tag_names if current.lower() in tag + ] + + @tag_group.command(name="list") + async def list_command(self, interaction: discord.Interaction) -> bool: + """Lists all accessible tags.""" + if self.tags: + await LinePaginator.paginate( + self.accessible_tags(interaction.user), + interaction, + Embed(title="Available tags"), + **self.PAGINATOR_DEFAULTS, + ) + else: + await interaction.response.send_message(embed=Embed(description="**There are no tags!**")) + return True + async def setup(bot: Bot) -> None: """Load the Tags cog.""" await bot.add_cog(Tags(bot)) + await bot.tree.sync(guild=discord.Object(id=GUILD_ID)) diff --git a/bot/pagination.py b/bot/pagination.py index 0ef5808ccc..1c63a47688 100644 --- a/bot/pagination.py +++ b/bot/pagination.py @@ -191,7 +191,7 @@ def _split_remaining_words(self, line: str, max_chars: int) -> t.Tuple[str, t.Op async def paginate( cls, lines: t.List[str], - ctx: Context, + ctx: t.Union[Context, discord.Interaction], embed: discord.Embed, prefix: str = "", suffix: str = "", @@ -228,7 +228,10 @@ async def paginate( current_page = 0 if not restrict_to_user: - restrict_to_user = ctx.author + if isinstance(ctx, discord.Interaction): + restrict_to_user = ctx.user + else: + restrict_to_user = ctx.author if not lines: if exception_on_empty_embed: @@ -261,6 +264,8 @@ async def paginate( log.trace(f"Setting embed url to '{url}'") log.debug("There's less than two pages, so we won't paginate - sending single page on its own") + if isinstance(ctx, discord.Interaction): + return await ctx.response.send_message(embed=embed) return await ctx.send(embed=embed) else: if footer_text: @@ -274,7 +279,11 @@ async def paginate( log.trace(f"Setting embed url to '{url}'") log.debug("Sending first page to channel...") - message = await ctx.send(embed=embed) + if isinstance(ctx, discord.Interaction): + await ctx.response.send_message(embed=embed) + message = await ctx.original_response() + else: + message = await ctx.send(embed=embed) log.debug("Adding emoji reactions to message...") @@ -292,7 +301,10 @@ async def paginate( while True: try: - reaction, user = await ctx.bot.wait_for("reaction_add", timeout=timeout, check=check) + if isinstance(ctx, discord.Interaction): + reaction, user = await ctx.client.wait_for("reaction_add", timeout=timeout, check=check) + else: + reaction, user = await ctx.bot.wait_for("reaction_add", timeout=timeout, check=check) log.trace(f"Got reaction: {reaction}") except asyncio.TimeoutError: log.debug("Timed out waiting for a reaction") From 1cff5bf589a848576d3d1f4a9c1ab71633406caf Mon Sep 17 00:00:00 2001 From: Ibrahim2750mi Date: Tue, 14 Feb 2023 21:08:09 +0530 Subject: [PATCH 2/6] Update tests for `/tag` as of migration to slash commands --- bot/exts/backend/error_handler.py | 22 ++++++---- tests/bot/exts/backend/test_error_handler.py | 44 ++++++++++---------- tests/helpers.py | 20 +++++++++ 3 files changed, 56 insertions(+), 30 deletions(-) diff --git a/bot/exts/backend/error_handler.py b/bot/exts/backend/error_handler.py index cc2b5ef567..561bf80688 100644 --- a/bot/exts/backend/error_handler.py +++ b/bot/exts/backend/error_handler.py @@ -1,7 +1,8 @@ import copy import difflib +import typing as t -from discord import Embed +from discord import Embed, Interaction from discord.ext.commands import ChannelNotFound, Cog, Context, TextChannelConverter, VoiceChannelConverter, errors from pydis_core.site_api import ResponseCodeError from sentry_sdk import push_scope @@ -21,6 +22,10 @@ class ErrorHandler(Cog): def __init__(self, bot: Bot): self.bot = bot + @staticmethod + async def _can_run(_: Interaction) -> bool: + return False + def _get_error_embed(self, title: str, body: str) -> Embed: """Return an embed that contains the exception.""" return Embed( @@ -159,7 +164,7 @@ async def try_silence(self, ctx: Context) -> bool: return True return False - async def try_get_tag(self, ctx: Context) -> None: + async def try_get_tag(self, interaction: Interaction, can_run: t.Callable[[Interaction], bool] = False) -> None: """ Attempt to display a tag by interpreting the command name as a tag name. @@ -168,27 +173,28 @@ async def try_get_tag(self, ctx: Context) -> None: the context to prevent infinite recursion in the case of a CommandNotFound exception. """ tags_get_command = self.bot.get_command("tags get") + tags_get_command.can_run = can_run if can_run else self._can_run if not tags_get_command: log.debug("Not attempting to parse message as a tag as could not find `tags get` command.") return - ctx.invoked_from_error_handler = True + interaction.invoked_from_error_handler = True log_msg = "Cancelling attempt to fall back to a tag due to failed checks." try: - if not await tags_get_command.can_run(ctx): + if not await tags_get_command.can_run(interaction): log.debug(log_msg) return except errors.CommandError as tag_error: log.debug(log_msg) - await self.on_command_error(ctx, tag_error) + await self.on_command_error(interaction, tag_error) return - if await ctx.invoke(tags_get_command, argument_string=ctx.message.content): + if await interaction.invoke(tags_get_command, tag_name=interaction.message.content): return - if not any(role.id in MODERATION_ROLES for role in ctx.author.roles): - await self.send_command_suggestion(ctx, ctx.invoked_with) + if not any(role.id in MODERATION_ROLES for role in interaction.user.roles): + await self.send_command_suggestion(interaction, interaction.invoked_with) async def try_run_eval(self, ctx: Context) -> bool: """ diff --git a/tests/bot/exts/backend/test_error_handler.py b/tests/bot/exts/backend/test_error_handler.py index adb0252a5e..83bc3c4a10 100644 --- a/tests/bot/exts/backend/test_error_handler.py +++ b/tests/bot/exts/backend/test_error_handler.py @@ -9,7 +9,7 @@ from bot.exts.info.tags import Tags from bot.exts.moderation.silence import Silence from bot.utils.checks import InWhitelistCheckFailure -from tests.helpers import MockBot, MockContext, MockGuild, MockRole, MockTextChannel, MockVoiceChannel +from tests.helpers import MockBot, MockContext, MockGuild, MockInteraction, MockRole, MockTextChannel, MockVoiceChannel class ErrorHandlerTests(unittest.IsolatedAsyncioTestCase): @@ -331,7 +331,7 @@ class TryGetTagTests(unittest.IsolatedAsyncioTestCase): def setUp(self): self.bot = MockBot() - self.ctx = MockContext() + self.interaction = MockInteraction() self.tag = Tags(self.bot) self.cog = error_handler.ErrorHandler(self.bot) self.bot.get_command.return_value = self.tag.get_command @@ -339,57 +339,57 @@ def setUp(self): async def test_try_get_tag_get_command(self): """Should call `Bot.get_command` with `tags get` argument.""" self.bot.get_command.reset_mock() - await self.cog.try_get_tag(self.ctx) + await self.cog.try_get_tag(self.interaction) self.bot.get_command.assert_called_once_with("tags get") async def test_try_get_tag_invoked_from_error_handler(self): - """`self.ctx` should have `invoked_from_error_handler` `True`.""" - self.ctx.invoked_from_error_handler = False - await self.cog.try_get_tag(self.ctx) - self.assertTrue(self.ctx.invoked_from_error_handler) + """`self.interaction` should have `invoked_from_error_handler` `True`.""" + self.interaction.invoked_from_error_handler = False + await self.cog.try_get_tag(self.interaction) + self.assertTrue(self.interaction.invoked_from_error_handler) async def test_try_get_tag_no_permissions(self): """Test how to handle checks failing.""" self.tag.get_command.can_run = AsyncMock(return_value=False) - self.ctx.invoked_with = "foo" - self.assertIsNone(await self.cog.try_get_tag(self.ctx)) + self.interaction.invoked_with = "foo" + self.assertIsNone(await self.cog.try_get_tag(self.interaction, AsyncMock(return_value=False))) async def test_try_get_tag_command_error(self): """Should call `on_command_error` when `CommandError` raised.""" err = errors.CommandError() self.tag.get_command.can_run = AsyncMock(side_effect=err) self.cog.on_command_error = AsyncMock() - self.assertIsNone(await self.cog.try_get_tag(self.ctx)) - self.cog.on_command_error.assert_awaited_once_with(self.ctx, err) + self.assertIsNone(await self.cog.try_get_tag(self.interaction, AsyncMock(side_effect=err))) + self.cog.on_command_error.assert_awaited_once_with(self.interaction, err) async def test_dont_call_suggestion_tag_sent(self): """Should never call command suggestion if tag is already sent.""" - self.ctx.message = MagicMock(content="foo") - self.ctx.invoke = AsyncMock(return_value=True) + self.interaction.message = MagicMock(content="foo") + self.interaction.invoke = AsyncMock(return_value=True) self.cog.send_command_suggestion = AsyncMock() - await self.cog.try_get_tag(self.ctx) + await self.cog.try_get_tag(self.interaction, AsyncMock()) self.cog.send_command_suggestion.assert_not_awaited() @patch("bot.exts.backend.error_handler.MODERATION_ROLES", new=[1234]) async def test_dont_call_suggestion_if_user_mod(self): """Should not call command suggestion if user is a mod.""" - self.ctx.invoked_with = "foo" - self.ctx.invoke = AsyncMock(return_value=False) - self.ctx.author.roles = [MockRole(id=1234)] + self.interaction.invoked_with = "foo" + self.interaction.invoke = AsyncMock(return_value=False) + self.interaction.user.roles = [MockRole(id=1234)] self.cog.send_command_suggestion = AsyncMock() - await self.cog.try_get_tag(self.ctx) + await self.cog.try_get_tag(self.interaction, AsyncMock()) self.cog.send_command_suggestion.assert_not_awaited() async def test_call_suggestion(self): """Should call command suggestion if user is not a mod.""" - self.ctx.invoked_with = "foo" - self.ctx.invoke = AsyncMock(return_value=False) + self.interaction.invoked_with = "foo" + self.interaction.invoke = AsyncMock(return_value=False) self.cog.send_command_suggestion = AsyncMock() - await self.cog.try_get_tag(self.ctx) - self.cog.send_command_suggestion.assert_awaited_once_with(self.ctx, "foo") + await self.cog.try_get_tag(self.interaction, AsyncMock()) + self.cog.send_command_suggestion.assert_awaited_once_with(self.interaction, "foo") class IndividualErrorHandlerTests(unittest.IsolatedAsyncioTestCase): diff --git a/tests/helpers.py b/tests/helpers.py index 4b980ac217..2d20b4d079 100644 --- a/tests/helpers.py +++ b/tests/helpers.py @@ -479,6 +479,26 @@ def __init__(self, **kwargs) -> None: self.invoked_from_error_handler = kwargs.get('invoked_from_error_handler', False) +class MockInteraction(CustomMockMixin, unittest.mock.MagicMock): + """ + A MagicMock subclass to mock Interaction objects. + + Instances of this class will follow the specifications of `discord.Interaction` + instances. For more information, see the `MockGuild` docstring. + """ + # spec_set = context_instance + + def __init__(self, **kwargs) -> None: + super().__init__(**kwargs) + self.me = kwargs.get('me', MockMember()) + self.client = kwargs.get('client', MockBot()) + self.guild = kwargs.get('guild', MockGuild()) + self.user = kwargs.get('user', MockMember()) + self.channel = kwargs.get('channel', MockTextChannel()) + self.message = kwargs.get('message', MockMessage()) + self.invoked_from_error_handler = kwargs.get('invoked_from_error_handler', False) + + attachment_instance = discord.Attachment(data=unittest.mock.MagicMock(id=1), state=unittest.mock.MagicMock()) From 9b98dfe78bb226e26a8d9cb6e8a0e8f8504286dd Mon Sep 17 00:00:00 2001 From: Ibrahim Date: Thu, 23 Feb 2023 04:08:57 +0530 Subject: [PATCH 3/6] Implement all reviews + Remove commented code + Remove unecessarily syncting the bot + Handle direct tag commads + 3.10 type hinting in concerned functions + Add `MockInteractionMessage` + Fix tests for `try_get_tag` --- bot/exts/backend/error_handler.py | 40 ++++++--- bot/exts/info/tags.py | 90 +++++++++++++------- bot/pagination.py | 4 +- bot/utils/messages.py | 2 +- tests/bot/exts/backend/test_error_handler.py | 50 +++++------ tests/helpers.py | 11 ++- 6 files changed, 125 insertions(+), 72 deletions(-) diff --git a/bot/exts/backend/error_handler.py b/bot/exts/backend/error_handler.py index 561bf80688..6561f84e4f 100644 --- a/bot/exts/backend/error_handler.py +++ b/bot/exts/backend/error_handler.py @@ -2,7 +2,7 @@ import difflib import typing as t -from discord import Embed, Interaction +from discord import Embed, Interaction, utils from discord.ext.commands import ChannelNotFound, Cog, Context, TextChannelConverter, VoiceChannelConverter, errors from pydis_core.site_api import ResponseCodeError from sentry_sdk import push_scope @@ -23,8 +23,19 @@ def __init__(self, bot: Bot): self.bot = bot @staticmethod - async def _can_run(_: Interaction) -> bool: - return False + async def _can_run(ctx: Context) -> bool: + """ + Add checks for the `get_command_ctx` function here. + + Use discord.utils to run the checks. + """ + checks = [] + predicates = checks + if not predicates: + # Since we have no checks, then we just return True. + return True + + return await utils.async_all(predicate(ctx) for predicate in predicates) def _get_error_embed(self, title: str, body: str) -> Embed: """Return an embed that contains the exception.""" @@ -164,7 +175,7 @@ async def try_silence(self, ctx: Context) -> bool: return True return False - async def try_get_tag(self, interaction: Interaction, can_run: t.Callable[[Interaction], bool] = False) -> None: + async def try_get_tag(self, ctx: Context, can_run: t.Callable[[Interaction], bool] = False) -> None: """ Attempt to display a tag by interpreting the command name as a tag name. @@ -172,29 +183,30 @@ async def try_get_tag(self, interaction: Interaction, can_run: t.Callable[[Inter by `on_command_error`, but the `invoked_from_error_handler` attribute will be added to the context to prevent infinite recursion in the case of a CommandNotFound exception. """ - tags_get_command = self.bot.get_command("tags get") - tags_get_command.can_run = can_run if can_run else self._can_run - if not tags_get_command: - log.debug("Not attempting to parse message as a tag as could not find `tags get` command.") + tags_cog = self.bot.get_cog("Tags") + if not tags_cog: + log.debug("Not attempting to parse message as a tag as could not find `Tags` cog.") return + tags_get_command = tags_cog.get_command_ctx + can_run = can_run if can_run else self._can_run - interaction.invoked_from_error_handler = True + ctx.invoked_from_error_handler = True log_msg = "Cancelling attempt to fall back to a tag due to failed checks." try: - if not await tags_get_command.can_run(interaction): + if not await can_run(ctx): log.debug(log_msg) return except errors.CommandError as tag_error: log.debug(log_msg) - await self.on_command_error(interaction, tag_error) + await self.on_command_error(ctx, tag_error) return - if await interaction.invoke(tags_get_command, tag_name=interaction.message.content): + if await tags_get_command(ctx, ctx.message.content): return - if not any(role.id in MODERATION_ROLES for role in interaction.user.roles): - await self.send_command_suggestion(interaction, interaction.invoked_with) + if not any(role.id in MODERATION_ROLES for role in ctx.author.roles): + await self.send_command_suggestion(ctx, ctx.invoked_with) async def try_run_eval(self, ctx: Context) -> bool: """ diff --git a/bot/exts/info/tags.py b/bot/exts/info/tags.py index 25c51def9e..60f7305865 100644 --- a/bot/exts/info/tags.py +++ b/bot/exts/info/tags.py @@ -8,7 +8,7 @@ import discord import frontmatter -from discord import Embed, Member, app_commands +from discord import Embed, Interaction, Member, app_commands from discord.ext.commands import Cog from bot import constants @@ -140,15 +140,8 @@ def __init__(self, bot: Bot): self.bot = bot self.tags: dict[TagIdentifier, Tag] = {} self.initialize_tags() - self.bot.tree.copy_global_to(guild=discord.Object(id=GUILD_ID)) tag_group = app_commands.Group(name="tag", description="...") - # search_tag = app_commands.Group(name="search", description="...", parent=tag_group) - - @Cog.listener() - async def on_ready(self) -> None: - """Called when the cog is ready.""" - await self.bot.tree.sync(guild=discord.Object(id=GUILD_ID)) def initialize_tags(self) -> None: """Load all tags from resources into `self.tags`.""" @@ -195,7 +188,8 @@ def get_fuzzy_matches(self, tag_identifier: TagIdentifier) -> list[tuple[TagIden async def get_tag_embed( self, - interaction: discord.Interaction, + author: discord.Member, + channel: discord.TextChannel | discord.Thread, tag_identifier: TagIdentifier, ) -> Optional[Union[Embed, Literal[COOLDOWN.obj]]]: """ @@ -206,7 +200,7 @@ async def get_tag_embed( filtered_tags = [ (ident, tag) for ident, tag in self.get_fuzzy_matches(tag_identifier)[:10] - if tag.accessible_by(interaction.user) + if tag.accessible_by(author) ] # Try exact match, includes checking through alt names @@ -225,10 +219,10 @@ async def get_tag_embed( tag = filtered_tags[0][1] if tag is not None: - if tag.on_cooldown_in(interaction.channel): + if tag.on_cooldown_in(channel): log.debug(f"Tag {str(tag_identifier)!r} is on cooldown.") return COOLDOWN.obj - tag.set_cooldown_for(interaction.channel) + tag.set_cooldown_for(channel) self.bot.stats.incr( f"tags.usages" @@ -243,7 +237,7 @@ async def get_tag_embed( suggested_tags_text = "\n".join( f"**\N{RIGHT-POINTING DOUBLE ANGLE QUOTATION MARK}** {identifier}" for identifier, tag in filtered_tags - if not tag.on_cooldown_in(interaction.channel) + if not tag.on_cooldown_in(channel) ) return Embed( title="Did you mean ...", @@ -292,8 +286,37 @@ def accessible_tags_in_group(self, group: str, user: discord.Member) -> list[str if identifier.group == group and tag.accessible_by(user) ) + async def get_command_ctx( + self, + ctx: discord.Context, + name: str + ) -> bool: + """Made specifically for `error_handler.py`, See `get_command` for more info.""" + identifier = TagIdentifier.from_string(name) + + if identifier.group is None: + # Try to find accessible tags from a group matching the identifier's name. + if group_tags := self.accessible_tags_in_group(identifier.name, ctx.author): + await LinePaginator.paginate( + group_tags, ctx, Embed(title=f"Tags under *{identifier.name}*"), **self.PAGINATOR_DEFAULTS + ) + return True + + embed = await self.get_tag_embed(ctx.author, ctx.channel, identifier) + if embed is None: + return False + + if embed is not COOLDOWN.obj: + + await wait_for_deletion( + await ctx.send(embed=embed), + (ctx.author.id,) + ) + # A valid tag was found and was either sent, or is on cooldown + return True + @tag_group.command(name="get") - async def get_command(self, interaction: discord.Interaction, *, tag_name: Optional[str]) -> bool: + async def get_command(self, interaction: Interaction, *, name: Optional[str]) -> bool: """ If a single argument matching a group name is given, list all accessible tags from that group Otherwise display the tag if one was found for the given arguments, or try to display suggestions for that name. @@ -303,7 +326,7 @@ async def get_command(self, interaction: discord.Interaction, *, tag_name: Optio Returns True if a message was sent, or if the tag is on cooldown. Returns False if no message was sent. """ # noqa: D205, D415 - if not tag_name: + if not name: if self.tags: await LinePaginator.paginate( self.accessible_tags(interaction.user), @@ -314,7 +337,7 @@ async def get_command(self, interaction: discord.Interaction, *, tag_name: Optio await interaction.response.send_message(embed=Embed(description="**There are no tags!**")) return True - identifier = TagIdentifier.from_string(tag_name) + identifier = TagIdentifier.from_string(name) if identifier.group is None: # Try to find accessible tags from a group matching the identifier's name. @@ -324,33 +347,43 @@ async def get_command(self, interaction: discord.Interaction, *, tag_name: Optio ) return True - embed = await self.get_tag_embed(interaction, identifier) + embed = await self.get_tag_embed(interaction.user, interaction.channel, identifier) + ephemeral = False if embed is None: - return False - - if embed is not COOLDOWN.obj: + description = f"**There are no tags matching the name {name!r}!**" + embed = Embed(description=description) + ephemeral = True + elif embed is COOLDOWN.obj: + description = f"Tag {name!r} is on cooldown." + embed = Embed(description=description) + ephemeral = True + + await interaction.response.send_message(embed=embed, ephemeral=ephemeral) + if not ephemeral: await wait_for_deletion( - await interaction.response.send_message(embed=embed), + await interaction.original_response(), (interaction.user.id,) ) + # A valid tag was found and was either sent, or is on cooldown return True - @get_command.autocomplete("tag_name") - async def tag_name_autocomplete( + @get_command.autocomplete("name") + async def name_autocomplete( self, - interaction: discord.Interaction, + interaction: Interaction, current: str ) -> list[app_commands.Choice[str]]: """Autocompleter for `/tag get` command.""" - tag_names = [tag.name for tag in self.tags.keys()] - return [ + names = [tag.name for tag in self.tags.keys()] + choices = [ app_commands.Choice(name=tag, value=tag) - for tag in tag_names if current.lower() in tag + for tag in names if current.lower() in tag ] + return choices[:25] if len(choices) > 25 else choices @tag_group.command(name="list") - async def list_command(self, interaction: discord.Interaction) -> bool: + async def list_command(self, interaction: Interaction) -> bool: """Lists all accessible tags.""" if self.tags: await LinePaginator.paginate( @@ -367,4 +400,3 @@ async def list_command(self, interaction: discord.Interaction) -> bool: async def setup(bot: Bot) -> None: """Load the Tags cog.""" await bot.add_cog(Tags(bot)) - await bot.tree.sync(guild=discord.Object(id=GUILD_ID)) diff --git a/bot/pagination.py b/bot/pagination.py index 1c63a47688..c39ce211b1 100644 --- a/bot/pagination.py +++ b/bot/pagination.py @@ -190,8 +190,8 @@ def _split_remaining_words(self, line: str, max_chars: int) -> t.Tuple[str, t.Op @classmethod async def paginate( cls, - lines: t.List[str], - ctx: t.Union[Context, discord.Interaction], + lines: list[str], + ctx: Context | discord.Interaction, embed: discord.Embed, prefix: str = "", suffix: str = "", diff --git a/bot/utils/messages.py b/bot/utils/messages.py index 27f2eac974..f6bdceaefe 100644 --- a/bot/utils/messages.py +++ b/bot/utils/messages.py @@ -58,7 +58,7 @@ def reaction_check( async def wait_for_deletion( - message: discord.Message, + message: discord.Message | discord.InteractionMessage, user_ids: Sequence[int], deletion_emojis: Sequence[str] = (Emojis.trashcan,), timeout: float = 60 * 5, diff --git a/tests/bot/exts/backend/test_error_handler.py b/tests/bot/exts/backend/test_error_handler.py index 83bc3c4a10..14e7a41254 100644 --- a/tests/bot/exts/backend/test_error_handler.py +++ b/tests/bot/exts/backend/test_error_handler.py @@ -9,7 +9,7 @@ from bot.exts.info.tags import Tags from bot.exts.moderation.silence import Silence from bot.utils.checks import InWhitelistCheckFailure -from tests.helpers import MockBot, MockContext, MockGuild, MockInteraction, MockRole, MockTextChannel, MockVoiceChannel +from tests.helpers import MockBot, MockContext, MockGuild, MockRole, MockTextChannel, MockVoiceChannel class ErrorHandlerTests(unittest.IsolatedAsyncioTestCase): @@ -331,65 +331,65 @@ class TryGetTagTests(unittest.IsolatedAsyncioTestCase): def setUp(self): self.bot = MockBot() - self.interaction = MockInteraction() + self.ctx = MockContext() self.tag = Tags(self.bot) self.cog = error_handler.ErrorHandler(self.bot) - self.bot.get_command.return_value = self.tag.get_command + self.bot.get_cog.return_value = self.tag async def test_try_get_tag_get_command(self): """Should call `Bot.get_command` with `tags get` argument.""" - self.bot.get_command.reset_mock() - await self.cog.try_get_tag(self.interaction) - self.bot.get_command.assert_called_once_with("tags get") + self.bot.get_cog.reset_mock() + await self.cog.try_get_tag(self.ctx) + self.bot.get_cog.assert_called_once_with("Tags") async def test_try_get_tag_invoked_from_error_handler(self): - """`self.interaction` should have `invoked_from_error_handler` `True`.""" - self.interaction.invoked_from_error_handler = False - await self.cog.try_get_tag(self.interaction) - self.assertTrue(self.interaction.invoked_from_error_handler) + """`self.ctx` should have `invoked_from_error_handler` `True`.""" + self.ctx.invoked_from_error_handler = False + await self.cog.try_get_tag(self.ctx) + self.assertTrue(self.ctx.invoked_from_error_handler) async def test_try_get_tag_no_permissions(self): """Test how to handle checks failing.""" self.tag.get_command.can_run = AsyncMock(return_value=False) - self.interaction.invoked_with = "foo" - self.assertIsNone(await self.cog.try_get_tag(self.interaction, AsyncMock(return_value=False))) + self.ctx.invoked_with = "foo" + self.assertIsNone(await self.cog.try_get_tag(self.ctx, AsyncMock(return_value=False))) async def test_try_get_tag_command_error(self): """Should call `on_command_error` when `CommandError` raised.""" err = errors.CommandError() self.tag.get_command.can_run = AsyncMock(side_effect=err) self.cog.on_command_error = AsyncMock() - self.assertIsNone(await self.cog.try_get_tag(self.interaction, AsyncMock(side_effect=err))) - self.cog.on_command_error.assert_awaited_once_with(self.interaction, err) + self.assertIsNone(await self.cog.try_get_tag(self.ctx, AsyncMock(side_effect=err))) + self.cog.on_command_error.assert_awaited_once_with(self.ctx, err) async def test_dont_call_suggestion_tag_sent(self): """Should never call command suggestion if tag is already sent.""" - self.interaction.message = MagicMock(content="foo") - self.interaction.invoke = AsyncMock(return_value=True) + self.ctx.message = MagicMock(content="foo") + self.tag.get_command_ctx = AsyncMock(return_value=True) self.cog.send_command_suggestion = AsyncMock() - await self.cog.try_get_tag(self.interaction, AsyncMock()) + await self.cog.try_get_tag(self.ctx) self.cog.send_command_suggestion.assert_not_awaited() @patch("bot.exts.backend.error_handler.MODERATION_ROLES", new=[1234]) async def test_dont_call_suggestion_if_user_mod(self): """Should not call command suggestion if user is a mod.""" - self.interaction.invoked_with = "foo" - self.interaction.invoke = AsyncMock(return_value=False) - self.interaction.user.roles = [MockRole(id=1234)] + self.ctx.invoked_with = "foo" + self.ctx.invoke = AsyncMock(return_value=False) + self.ctx.author.roles = [MockRole(id=1234)] self.cog.send_command_suggestion = AsyncMock() - await self.cog.try_get_tag(self.interaction, AsyncMock()) + await self.cog.try_get_tag(self.ctx) self.cog.send_command_suggestion.assert_not_awaited() async def test_call_suggestion(self): """Should call command suggestion if user is not a mod.""" - self.interaction.invoked_with = "foo" - self.interaction.invoke = AsyncMock(return_value=False) + self.ctx.invoked_with = "foo" + self.ctx.invoke = AsyncMock(return_value=False) self.cog.send_command_suggestion = AsyncMock() - await self.cog.try_get_tag(self.interaction, AsyncMock()) - self.cog.send_command_suggestion.assert_awaited_once_with(self.interaction, "foo") + await self.cog.try_get_tag(self.ctx) + self.cog.send_command_suggestion.assert_awaited_once_with(self.ctx, "foo") class IndividualErrorHandlerTests(unittest.IsolatedAsyncioTestCase): diff --git a/tests/helpers.py b/tests/helpers.py index 2d20b4d079..0d955b5216 100644 --- a/tests/helpers.py +++ b/tests/helpers.py @@ -486,7 +486,6 @@ class MockInteraction(CustomMockMixin, unittest.mock.MagicMock): Instances of this class will follow the specifications of `discord.Interaction` instances. For more information, see the `MockGuild` docstring. """ - # spec_set = context_instance def __init__(self, **kwargs) -> None: super().__init__(**kwargs) @@ -550,6 +549,16 @@ def __init__(self, **kwargs) -> None: self.channel = kwargs.get('channel', MockTextChannel()) +class MockInteractionMessage(MockMessage): + """ + A MagicMock subclass to mock InteractionMessage objects. + + Instances of this class will follow the specifications of `discord.InteractionMessage` instances. For more + information, see the `MockGuild` docstring. + """ + pass + + emoji_data = {'require_colons': True, 'managed': True, 'id': 1, 'name': 'hyperlemon'} emoji_instance = discord.Emoji(guild=MockGuild(), state=unittest.mock.MagicMock(), data=emoji_data) From dabf67d92620cf0772e8443b5be57207280544c9 Mon Sep 17 00:00:00 2001 From: Ibrahim Date: Thu, 23 Feb 2023 16:05:37 +0530 Subject: [PATCH 4/6] Upadte docstring for `ErrorHandler()._can_run` --- bot/exts/backend/error_handler.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/bot/exts/backend/error_handler.py b/bot/exts/backend/error_handler.py index 6561f84e4f..839d882de3 100644 --- a/bot/exts/backend/error_handler.py +++ b/bot/exts/backend/error_handler.py @@ -27,7 +27,8 @@ async def _can_run(ctx: Context) -> bool: """ Add checks for the `get_command_ctx` function here. - Use discord.utils to run the checks. + The command code style is copied from discord.ext.commands.Command.can_run itself. + Append checks in the checks list. """ checks = [] predicates = checks From 03614c313341497e61c45bbb2a364b969d2bb163 Mon Sep 17 00:00:00 2001 From: Ibrahim Date: Sun, 26 Feb 2023 17:44:47 +0530 Subject: [PATCH 5/6] Implement reviews + used both `discord.User` and `discord.Member` in typehinting as `InteractionResponse.user` returns `discord.User` object + removed `ErrorHandler()._can_run` + edited `try_get_tag` to use `bot.can_run` + removed `/tag list` + change `/tag get ` to `/tag ` + remove redundant `GUILD_ID` in `tags.py` + using `discord.abc.Messageable` because `ctx.channel` returns that instead of `Channel` Object --- bot/exts/backend/error_handler.py | 50 +++++++------------- bot/exts/info/tags.py | 36 ++++---------- tests/bot/exts/backend/test_error_handler.py | 10 ++-- 3 files changed, 32 insertions(+), 64 deletions(-) diff --git a/bot/exts/backend/error_handler.py b/bot/exts/backend/error_handler.py index 839d882de3..e274e337a9 100644 --- a/bot/exts/backend/error_handler.py +++ b/bot/exts/backend/error_handler.py @@ -1,8 +1,7 @@ import copy import difflib -import typing as t -from discord import Embed, Interaction, utils +from discord import Embed, Member from discord.ext.commands import ChannelNotFound, Cog, Context, TextChannelConverter, VoiceChannelConverter, errors from pydis_core.site_api import ResponseCodeError from sentry_sdk import push_scope @@ -22,22 +21,6 @@ class ErrorHandler(Cog): def __init__(self, bot: Bot): self.bot = bot - @staticmethod - async def _can_run(ctx: Context) -> bool: - """ - Add checks for the `get_command_ctx` function here. - - The command code style is copied from discord.ext.commands.Command.can_run itself. - Append checks in the checks list. - """ - checks = [] - predicates = checks - if not predicates: - # Since we have no checks, then we just return True. - return True - - return await utils.async_all(predicate(ctx) for predicate in predicates) - def _get_error_embed(self, title: str, body: str) -> Embed: """Return an embed that contains the exception.""" return Embed( @@ -176,7 +159,7 @@ async def try_silence(self, ctx: Context) -> bool: return True return False - async def try_get_tag(self, ctx: Context, can_run: t.Callable[[Interaction], bool] = False) -> None: + async def try_get_tag(self, ctx: Context) -> None: """ Attempt to display a tag by interpreting the command name as a tag name. @@ -189,25 +172,28 @@ async def try_get_tag(self, ctx: Context, can_run: t.Callable[[Interaction], boo log.debug("Not attempting to parse message as a tag as could not find `Tags` cog.") return tags_get_command = tags_cog.get_command_ctx - can_run = can_run if can_run else self._can_run - ctx.invoked_from_error_handler = True + maybe_tag_name = ctx.invoked_with + if not maybe_tag_name or not isinstance(ctx.author, Member): + return - log_msg = "Cancelling attempt to fall back to a tag due to failed checks." + ctx.invoked_from_error_handler = True try: - if not await can_run(ctx): - log.debug(log_msg) + if not await self.bot.can_run(ctx): + log.debug("Cancelling attempt to fall back to a tag due to failed checks.") return - except errors.CommandError as tag_error: - log.debug(log_msg) - await self.on_command_error(ctx, tag_error) - return - if await tags_get_command(ctx, ctx.message.content): - return + if await tags_get_command(ctx, maybe_tag_name): + return - if not any(role.id in MODERATION_ROLES for role in ctx.author.roles): - await self.send_command_suggestion(ctx, ctx.invoked_with) + if not any(role.id in MODERATION_ROLES for role in ctx.author.roles): + await self.send_command_suggestion(ctx, maybe_tag_name) + except Exception as err: + log.debug("Error while attempting to invoke tag fallback.") + if isinstance(err, errors.CommandError): + await self.on_command_error(ctx, err) + else: + await self.on_command_error(ctx, errors.CommandInvokeError(err)) async def try_run_eval(self, ctx: Context) -> bool: """ diff --git a/bot/exts/info/tags.py b/bot/exts/info/tags.py index 60f7305865..0c244ff37c 100644 --- a/bot/exts/info/tags.py +++ b/bot/exts/info/tags.py @@ -8,8 +8,8 @@ import discord import frontmatter -from discord import Embed, Interaction, Member, app_commands -from discord.ext.commands import Cog +from discord import Embed, Interaction, Member, User, app_commands +from discord.ext.commands import Cog, Context from bot import constants from bot.bot import Bot @@ -27,8 +27,6 @@ REGEX_NON_ALPHABET = re.compile(r"[^a-z]", re.MULTILINE & re.IGNORECASE) FOOTER_TEXT = f"To show a tag, type {constants.Bot.prefix}tags ." -GUILD_ID = constants.Guild.id - class COOLDOWN(enum.Enum): """Sentinel value to signal that a tag is on cooldown.""" @@ -93,7 +91,7 @@ def embed(self) -> Embed: embed.description = self.content return embed - def accessible_by(self, member: discord.Member) -> bool: + def accessible_by(self, member: Member | User) -> bool: """Check whether `member` can access the tag.""" return bool( not self._restricted_to @@ -141,8 +139,6 @@ def __init__(self, bot: Bot): self.tags: dict[TagIdentifier, Tag] = {} self.initialize_tags() - tag_group = app_commands.Group(name="tag", description="...") - def initialize_tags(self) -> None: """Load all tags from resources into `self.tags`.""" base_path = Path("bot", "resources", "tags") @@ -188,8 +184,8 @@ def get_fuzzy_matches(self, tag_identifier: TagIdentifier) -> list[tuple[TagIden async def get_tag_embed( self, - author: discord.Member, - channel: discord.TextChannel | discord.Thread, + author: Member | User, + channel: discord.abc.Messageable, tag_identifier: TagIdentifier, ) -> Optional[Union[Embed, Literal[COOLDOWN.obj]]]: """ @@ -244,7 +240,7 @@ async def get_tag_embed( description=suggested_tags_text ) - def accessible_tags(self, user: Member) -> list[str]: + def accessible_tags(self, user: Member | User) -> list[str]: """Return a formatted list of tags that are accessible by `user`; groups first, and alphabetically sorted.""" def tag_sort_key(tag_item: tuple[TagIdentifier, Tag]) -> str: group, name = tag_item[0] @@ -278,7 +274,7 @@ def tag_sort_key(tag_item: tuple[TagIdentifier, Tag]) -> str: return result_lines - def accessible_tags_in_group(self, group: str, user: discord.Member) -> list[str]: + def accessible_tags_in_group(self, group: str, user: Member | User) -> list[str]: """Return a formatted list of tags in `group`, that are accessible by `user`.""" return sorted( f"**\N{RIGHT-POINTING DOUBLE ANGLE QUOTATION MARK}** {identifier}" @@ -288,7 +284,7 @@ def accessible_tags_in_group(self, group: str, user: discord.Member) -> list[str async def get_command_ctx( self, - ctx: discord.Context, + ctx: Context, name: str ) -> bool: """Made specifically for `error_handler.py`, See `get_command` for more info.""" @@ -315,7 +311,7 @@ async def get_command_ctx( # A valid tag was found and was either sent, or is on cooldown return True - @tag_group.command(name="get") + @app_commands.command(name="tag") async def get_command(self, interaction: Interaction, *, name: Optional[str]) -> bool: """ If a single argument matching a group name is given, list all accessible tags from that group @@ -382,20 +378,6 @@ async def name_autocomplete( ] return choices[:25] if len(choices) > 25 else choices - @tag_group.command(name="list") - async def list_command(self, interaction: Interaction) -> bool: - """Lists all accessible tags.""" - if self.tags: - await LinePaginator.paginate( - self.accessible_tags(interaction.user), - interaction, - Embed(title="Available tags"), - **self.PAGINATOR_DEFAULTS, - ) - else: - await interaction.response.send_message(embed=Embed(description="**There are no tags!**")) - return True - async def setup(bot: Bot) -> None: """Load the Tags cog.""" diff --git a/tests/bot/exts/backend/test_error_handler.py b/tests/bot/exts/backend/test_error_handler.py index 14e7a41254..533eaeda6c 100644 --- a/tests/bot/exts/backend/test_error_handler.py +++ b/tests/bot/exts/backend/test_error_handler.py @@ -350,16 +350,16 @@ async def test_try_get_tag_invoked_from_error_handler(self): async def test_try_get_tag_no_permissions(self): """Test how to handle checks failing.""" - self.tag.get_command.can_run = AsyncMock(return_value=False) + self.bot.can_run = AsyncMock(return_value=False) self.ctx.invoked_with = "foo" - self.assertIsNone(await self.cog.try_get_tag(self.ctx, AsyncMock(return_value=False))) + self.assertIsNone(await self.cog.try_get_tag(self.ctx)) async def test_try_get_tag_command_error(self): """Should call `on_command_error` when `CommandError` raised.""" err = errors.CommandError() - self.tag.get_command.can_run = AsyncMock(side_effect=err) + self.bot.can_run = AsyncMock(side_effect=err) self.cog.on_command_error = AsyncMock() - self.assertIsNone(await self.cog.try_get_tag(self.ctx, AsyncMock(side_effect=err))) + self.assertIsNone(await self.cog.try_get_tag(self.ctx)) self.cog.on_command_error.assert_awaited_once_with(self.ctx, err) async def test_dont_call_suggestion_tag_sent(self): @@ -385,7 +385,7 @@ async def test_dont_call_suggestion_if_user_mod(self): async def test_call_suggestion(self): """Should call command suggestion if user is not a mod.""" self.ctx.invoked_with = "foo" - self.ctx.invoke = AsyncMock(return_value=False) + self.tag.get_command_ctx = AsyncMock(return_value=False) self.cog.send_command_suggestion = AsyncMock() await self.cog.try_get_tag(self.ctx) From c5b3a55697c9e3dcaab23d124e7692bf552093f7 Mon Sep 17 00:00:00 2001 From: Ibrahim Date: Mon, 27 Feb 2023 17:10:40 +0530 Subject: [PATCH 6/6] Implement shtlrs reviews + removed using of discord.User + update docstring for `get_command_ctx` + renamed user variables to member --- bot/exts/info/tags.py | 33 +++++++++++++++++++-------------- 1 file changed, 19 insertions(+), 14 deletions(-) diff --git a/bot/exts/info/tags.py b/bot/exts/info/tags.py index 0c244ff37c..309f22cad5 100644 --- a/bot/exts/info/tags.py +++ b/bot/exts/info/tags.py @@ -8,7 +8,7 @@ import discord import frontmatter -from discord import Embed, Interaction, Member, User, app_commands +from discord import Embed, Interaction, Member, app_commands from discord.ext.commands import Cog, Context from bot import constants @@ -91,7 +91,7 @@ def embed(self) -> Embed: embed.description = self.content return embed - def accessible_by(self, member: Member | User) -> bool: + def accessible_by(self, member: Member) -> bool: """Check whether `member` can access the tag.""" return bool( not self._restricted_to @@ -184,19 +184,20 @@ def get_fuzzy_matches(self, tag_identifier: TagIdentifier) -> list[tuple[TagIden async def get_tag_embed( self, - author: Member | User, + member: Member, channel: discord.abc.Messageable, tag_identifier: TagIdentifier, ) -> Optional[Union[Embed, Literal[COOLDOWN.obj]]]: """ - Generate an embed of the requested tag or of suggestions if the tag doesn't exist/isn't accessible by the user. + Generate an embed of the requested tag or of suggestions if the tag doesn't exist + or isn't accessible by the member. If the requested tag is on cooldown return `COOLDOWN.obj`, otherwise if no suggestions were found return None. - """ + """ # noqa: D205, D415 filtered_tags = [ (ident, tag) for ident, tag in self.get_fuzzy_matches(tag_identifier)[:10] - if tag.accessible_by(author) + if tag.accessible_by(member) ] # Try exact match, includes checking through alt names @@ -240,8 +241,8 @@ async def get_tag_embed( description=suggested_tags_text ) - def accessible_tags(self, user: Member | User) -> list[str]: - """Return a formatted list of tags that are accessible by `user`; groups first, and alphabetically sorted.""" + def accessible_tags(self, member: Member) -> list[str]: + """Return a formatted list of tags that are accessible by `member`; groups first, and alphabetically sorted.""" def tag_sort_key(tag_item: tuple[TagIdentifier, Tag]) -> str: group, name = tag_item[0] if group is None: @@ -258,7 +259,7 @@ def tag_sort_key(tag_item: tuple[TagIdentifier, Tag]) -> str: if identifier.group != current_group: if not group_accessible: - # Remove group separator line if no tags in the previous group were accessible by the user. + # Remove group separator line if no tags in the previous group were accessible by the member. result_lines.pop() # A new group began, add a separator with the group name. current_group = identifier.group @@ -268,18 +269,18 @@ def tag_sort_key(tag_item: tuple[TagIdentifier, Tag]) -> str: else: result_lines.append("\n\N{BULLET}") - if tag.accessible_by(user): + if tag.accessible_by(member): result_lines.append(f"**\N{RIGHT-POINTING DOUBLE ANGLE QUOTATION MARK}** {identifier.name}") group_accessible = True return result_lines - def accessible_tags_in_group(self, group: str, user: Member | User) -> list[str]: - """Return a formatted list of tags in `group`, that are accessible by `user`.""" + def accessible_tags_in_group(self, group: str, member: Member) -> list[str]: + """Return a formatted list of tags in `group`, that are accessible by `member`.""" return sorted( f"**\N{RIGHT-POINTING DOUBLE ANGLE QUOTATION MARK}** {identifier}" for identifier, tag in self.tags.items() - if identifier.group == group and tag.accessible_by(user) + if identifier.group == group and tag.accessible_by(member) ) async def get_command_ctx( @@ -287,7 +288,11 @@ async def get_command_ctx( ctx: Context, name: str ) -> bool: - """Made specifically for `error_handler.py`, See `get_command` for more info.""" + """ + Made specifically for `ErrorHandler().try_get_tag` to handle sending tags through ctx. + + See `get_command` for more info, but here name is not optional unlike `get_command`. + """ identifier = TagIdentifier.from_string(name) if identifier.group is None: