Skip to content

Commit

Permalink
Migrated !tags command to slash command /tag
Browse files Browse the repository at this point in the history
  • Loading branch information
Ibrahim2750mi committed Feb 14, 2023
1 parent 44b2b2a commit 679cb77
Show file tree
Hide file tree
Showing 2 changed files with 76 additions and 104 deletions.
160 changes: 60 additions & 100 deletions bot/exts/info/tags.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,12 +4,12 @@
import re
import time
from pathlib import Path
from typing import Callable, Iterable, Literal, NamedTuple, Optional, Union
from typing import Literal, NamedTuple, Optional, Union

import discord
import frontmatter
from discord import Embed, Member
from discord.ext.commands import Cog, Context, group
from discord import Embed, Member, app_commands
from discord.ext.commands import Cog

from bot import constants
from bot.bot import Bot
Expand All @@ -27,6 +27,8 @@
REGEX_NON_ALPHABET = re.compile(r"[^a-z]", re.MULTILINE & re.IGNORECASE)
FOOTER_TEXT = f"To show a tag, type {constants.Bot.prefix}tags <tagname>."

GUILD_ID = constants.Guild.id


class COOLDOWN(enum.Enum):
"""Sentinel value to signal that a tag is on cooldown."""
Expand Down Expand Up @@ -138,6 +140,15 @@ def __init__(self, bot: Bot):
self.bot = bot
self.tags: dict[TagIdentifier, Tag] = {}
self.initialize_tags()
self.bot.tree.copy_global_to(guild=discord.Object(id=GUILD_ID))

tag_group = app_commands.Group(name="tag", description="...")
# search_tag = app_commands.Group(name="search", description="...", parent=tag_group)

@Cog.listener()
async def on_ready(self) -> None:
"""Called when the cog is ready."""
await self.bot.tree.sync(guild=discord.Object(id=GUILD_ID))

def initialize_tags(self) -> None:
"""Load all tags from resources into `self.tags`."""
Expand Down Expand Up @@ -182,90 +193,9 @@ def get_fuzzy_matches(self, tag_identifier: TagIdentifier) -> list[tuple[TagIden

return suggestions

def _get_tags_via_content(
self,
check: Callable[[Iterable], bool],
keywords: str,
user: Member,
) -> list[tuple[TagIdentifier, Tag]]:
"""
Search for tags via contents.
`predicate` will be the built-in any, all, or a custom callable. Must return a bool.
"""
keywords_processed = []
for keyword in keywords.split(","):
keyword_sanitized = keyword.strip().casefold()
if not keyword_sanitized:
# this happens when there are leading / trailing / consecutive comma.
continue
keywords_processed.append(keyword_sanitized)

if not keywords_processed:
# after sanitizing, we can end up with an empty list, for example when keywords is ","
# in that case, we simply want to search for such keywords directly instead.
keywords_processed = [keywords]

matching_tags = []
for identifier, tag in self.tags.items():
matches = (query in tag.content.casefold() for query in keywords_processed)
if tag.accessible_by(user) and check(matches):
matching_tags.append((identifier, tag))

return matching_tags

async def _send_matching_tags(
self,
ctx: Context,
keywords: str,
matching_tags: list[tuple[TagIdentifier, Tag]],
) -> None:
"""Send the result of matching tags to user."""
if len(matching_tags) == 1:
await ctx.send(embed=matching_tags[0][1].embed)
elif matching_tags:
is_plural = keywords.strip().count(" ") > 0 or keywords.strip().count(",") > 0
embed = Embed(
title=f"Here are the tags containing the given keyword{'s' * is_plural}:",
)
await LinePaginator.paginate(
sorted(
f"**\N{RIGHT-POINTING DOUBLE ANGLE QUOTATION MARK}** {identifier.name}"
for identifier, _ in matching_tags
),
ctx,
embed,
**self.PAGINATOR_DEFAULTS,
)

@group(name="tags", aliases=("tag", "t"), invoke_without_command=True, usage="[tag_group] [tag_name]")
async def tags_group(self, ctx: Context, *, argument_string: Optional[str]) -> None:
"""Show all known tags, a single tag, or run a subcommand."""
await self.get_command(ctx, argument_string=argument_string)

@tags_group.group(name="search", invoke_without_command=True)
async def search_tag_content(self, ctx: Context, *, keywords: str) -> None:
"""
Search inside tags' contents for tags. Allow searching for multiple keywords separated by comma.
Only search for tags that has ALL the keywords.
"""
matching_tags = self._get_tags_via_content(all, keywords, ctx.author)
await self._send_matching_tags(ctx, keywords, matching_tags)

@search_tag_content.command(name="any")
async def search_tag_content_any_keyword(self, ctx: Context, *, keywords: Optional[str] = "any") -> None:
"""
Search inside tags' contents for tags. Allow searching for multiple keywords separated by comma.
Search for tags that has ANY of the keywords.
"""
matching_tags = self._get_tags_via_content(any, keywords or "any", ctx.author)
await self._send_matching_tags(ctx, keywords, matching_tags)

async def get_tag_embed(
self,
ctx: Context,
interaction: discord.Interaction,
tag_identifier: TagIdentifier,
) -> Optional[Union[Embed, Literal[COOLDOWN.obj]]]:
"""
Expand All @@ -276,7 +206,7 @@ async def get_tag_embed(
filtered_tags = [
(ident, tag) for ident, tag in
self.get_fuzzy_matches(tag_identifier)[:10]
if tag.accessible_by(ctx.author)
if tag.accessible_by(interaction.user)
]

# Try exact match, includes checking through alt names
Expand All @@ -295,10 +225,10 @@ async def get_tag_embed(
tag = filtered_tags[0][1]

if tag is not None:
if tag.on_cooldown_in(ctx.channel):
if tag.on_cooldown_in(interaction.channel):
log.debug(f"Tag {str(tag_identifier)!r} is on cooldown.")
return COOLDOWN.obj
tag.set_cooldown_for(ctx.channel)
tag.set_cooldown_for(interaction.channel)

self.bot.stats.incr(
f"tags.usages"
Expand All @@ -313,7 +243,7 @@ async def get_tag_embed(
suggested_tags_text = "\n".join(
f"**\N{RIGHT-POINTING DOUBLE ANGLE QUOTATION MARK}** {identifier}"
for identifier, tag in filtered_tags
if not tag.on_cooldown_in(ctx.channel)
if not tag.on_cooldown_in(interaction.channel)
)
return Embed(
title="Did you mean ...",
Expand Down Expand Up @@ -362,8 +292,8 @@ def accessible_tags_in_group(self, group: str, user: discord.Member) -> list[str
if identifier.group == group and tag.accessible_by(user)
)

@tags_group.command(name="get", aliases=("show", "g"), usage="[tag_group] [tag_name]")
async def get_command(self, ctx: Context, *, argument_string: Optional[str]) -> bool:
@tag_group.command(name="get")
async def get_command(self, interaction: discord.Interaction, *, tag_name: Optional[str]) -> bool:
"""
If a single argument matching a group name is given, list all accessible tags from that group
Otherwise display the tag if one was found for the given arguments, or try to display suggestions for that name.
Expand All @@ -373,38 +303,68 @@ async def get_command(self, ctx: Context, *, argument_string: Optional[str]) ->
Returns True if a message was sent, or if the tag is on cooldown.
Returns False if no message was sent.
""" # noqa: D205, D415
if not argument_string:
if not tag_name:
if self.tags:
await LinePaginator.paginate(
self.accessible_tags(ctx.author), ctx, Embed(title="Available tags"), **self.PAGINATOR_DEFAULTS
self.accessible_tags(interaction.user),
interaction, Embed(title="Available tags"),
**self.PAGINATOR_DEFAULTS,
)
else:
await ctx.send(embed=Embed(description="**There are no tags!**"))
await interaction.response.send_message(embed=Embed(description="**There are no tags!**"))
return True

identifier = TagIdentifier.from_string(argument_string)
identifier = TagIdentifier.from_string(tag_name)

if identifier.group is None:
# Try to find accessible tags from a group matching the identifier's name.
if group_tags := self.accessible_tags_in_group(identifier.name, ctx.author):
if group_tags := self.accessible_tags_in_group(identifier.name, interaction.user):
await LinePaginator.paginate(
group_tags, ctx, Embed(title=f"Tags under *{identifier.name}*"), **self.PAGINATOR_DEFAULTS
group_tags, interaction, Embed(title=f"Tags under *{identifier.name}*"), **self.PAGINATOR_DEFAULTS
)
return True

embed = await self.get_tag_embed(ctx, identifier)
embed = await self.get_tag_embed(interaction, identifier)
if embed is None:
return False

if embed is not COOLDOWN.obj:
await wait_for_deletion(
await ctx.send(embed=embed),
(ctx.author.id,)
await interaction.response.send_message(embed=embed),
(interaction.user.id,)
)
# A valid tag was found and was either sent, or is on cooldown
return True

@get_command.autocomplete("tag_name")
async def tag_name_autocomplete(
self,
interaction: discord.Interaction,
current: str
) -> list[app_commands.Choice[str]]:
"""Autocompleter for `/tag get` command."""
tag_names = [tag.name for tag in self.tags.keys()]
return [
app_commands.Choice(name=tag, value=tag)
for tag in tag_names if current.lower() in tag
]

@tag_group.command(name="list")
async def list_command(self, interaction: discord.Interaction) -> bool:
"""Lists all accessible tags."""
if self.tags:
await LinePaginator.paginate(
self.accessible_tags(interaction.user),
interaction,
Embed(title="Available tags"),
**self.PAGINATOR_DEFAULTS,
)
else:
await interaction.response.send_message(embed=Embed(description="**There are no tags!**"))
return True


async def setup(bot: Bot) -> None:
"""Load the Tags cog."""
await bot.add_cog(Tags(bot))
await bot.tree.sync(guild=discord.Object(id=GUILD_ID))
20 changes: 16 additions & 4 deletions bot/pagination.py
Original file line number Diff line number Diff line change
Expand Up @@ -191,7 +191,7 @@ def _split_remaining_words(self, line: str, max_chars: int) -> t.Tuple[str, t.Op
async def paginate(
cls,
lines: t.List[str],
ctx: Context,
ctx: t.Union[Context, discord.Interaction],
embed: discord.Embed,
prefix: str = "",
suffix: str = "",
Expand Down Expand Up @@ -228,7 +228,10 @@ async def paginate(
current_page = 0

if not restrict_to_user:
restrict_to_user = ctx.author
if isinstance(ctx, discord.Interaction):
restrict_to_user = ctx.user
else:
restrict_to_user = ctx.author

if not lines:
if exception_on_empty_embed:
Expand Down Expand Up @@ -261,6 +264,8 @@ async def paginate(
log.trace(f"Setting embed url to '{url}'")

log.debug("There's less than two pages, so we won't paginate - sending single page on its own")
if isinstance(ctx, discord.Interaction):
return await ctx.response.send_message(embed=embed)
return await ctx.send(embed=embed)
else:
if footer_text:
Expand All @@ -274,7 +279,11 @@ async def paginate(
log.trace(f"Setting embed url to '{url}'")

log.debug("Sending first page to channel...")
message = await ctx.send(embed=embed)
if isinstance(ctx, discord.Interaction):
await ctx.response.send_message(embed=embed)
message = await ctx.original_response()
else:
message = await ctx.send(embed=embed)

log.debug("Adding emoji reactions to message...")

Expand All @@ -292,7 +301,10 @@ async def paginate(

while True:
try:
reaction, user = await ctx.bot.wait_for("reaction_add", timeout=timeout, check=check)
if isinstance(ctx, discord.Interaction):
reaction, user = await ctx.client.wait_for("reaction_add", timeout=timeout, check=check)
else:
reaction, user = await ctx.bot.wait_for("reaction_add", timeout=timeout, check=check)
log.trace(f"Got reaction: {reaction}")
except asyncio.TimeoutError:
log.debug("Timed out waiting for a reaction")
Expand Down

0 comments on commit 679cb77

Please sign in to comment.