From eb14984ad02f1e2cdca4251a957e3e6f59d13358 Mon Sep 17 00:00:00 2001 From: kangalioo Date: Mon, 9 Aug 2021 20:12:54 +0200 Subject: [PATCH] Redesign the `Parse` trait and add support for most applicable model types (#1380) This commit redesigns and renames the `Parse` trait - which is now called `ArgumentConvert` - according to the ideas from https://github.com/serenity-rs/serenity/issues/1327. This is not a breaking change, since a dummy version of the `Parse` trait is still present which internally delegates to `ArgumentConvert`. Additionally, there is now `ArgumentConvert` support for most applicable model types: - `Channel` - `GuildChannel` - `ChannelCategory` - `Emoji` - `GuildChannel` - `Member` (already supported) - `Message` (already supported) - `Role` - `User` I oriented myself at [discord.py's converters.][0] [0]: https://discordpy.readthedocs.io/en/latest/ext/commands/commands.html#discord-converters --- src/utils/argument_convert/_template.rs | 45 +++++ src/utils/argument_convert/channel.rs | 209 ++++++++++++++++++++++++ src/utils/argument_convert/emoji.rs | 63 +++++++ src/utils/argument_convert/guild.rs | 45 +++++ src/utils/argument_convert/member.rs | 90 ++++++++++ src/utils/argument_convert/message.rs | 80 +++++++++ src/utils/argument_convert/mod.rs | 158 ++++++++++++++++++ src/utils/argument_convert/role.rs | 66 ++++++++ src/utils/argument_convert/user.rs | 65 ++++++++ src/utils/mod.rs | 8 +- src/utils/parse.rs | 187 --------------------- 11 files changed, 825 insertions(+), 191 deletions(-) create mode 100644 src/utils/argument_convert/_template.rs create mode 100644 src/utils/argument_convert/channel.rs create mode 100644 src/utils/argument_convert/emoji.rs create mode 100644 src/utils/argument_convert/guild.rs create mode 100644 src/utils/argument_convert/member.rs create mode 100644 src/utils/argument_convert/message.rs create mode 100644 src/utils/argument_convert/mod.rs create mode 100644 src/utils/argument_convert/role.rs create mode 100644 src/utils/argument_convert/user.rs diff --git a/src/utils/argument_convert/_template.rs b/src/utils/argument_convert/_template.rs new file mode 100644 index 00000000000..f979497c6f9 --- /dev/null +++ b/src/utils/argument_convert/_template.rs @@ -0,0 +1,45 @@ +use super::ArgumentConvert; +use crate::{model::prelude::*, prelude::*}; + +/// Error that can be returned from [`PLACEHOLDER::convert`]. +#[non_exhaustive] +#[derive(Copy, Clone, Debug, Eq, Hash, PartialEq)] +pub enum PLACEHOLDERParseError { +} + +impl std::error::Error for PLACEHOLDERParseError {} + +impl std::fmt::Display for PLACEHOLDERParseError { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + match self { + } + } +} + +/// Look up a [`PLACEHOLDER`] by a string case-insensitively. +/// +/// Requires the cache feature to be enabled. +/// +/// The lookup strategy is as follows (in order): +/// 1. Lookup by PLACEHOLDER +/// 2. [Lookup by PLACEHOLDER](`crate::utils::parse_PLACEHOLDER`). +#[cfg(feature = "cache")] +#[async_trait::async_trait] +impl ArgumentConvert for PLACEHOLDER { + type Err = PLACEHOLDERParseError; + + async fn convert( + ctx: &Context, + guild_id: Option, + _channel_id: Option, + s: &str, + ) -> Result { + let lookup_by_PLACEHOLDER = || PLACEHOLDER; + + lookup_by_PLACEHOLDER() + .or_else(lookup_by_PLACEHOLDER) + .or_else(lookup_by_PLACEHOLDER) + .cloned() + .ok_or(PLACEHOLDERParseError::NotFoundOrMalformed) + } +} diff --git a/src/utils/argument_convert/channel.rs b/src/utils/argument_convert/channel.rs new file mode 100644 index 00000000000..802e54645ef --- /dev/null +++ b/src/utils/argument_convert/channel.rs @@ -0,0 +1,209 @@ +use super::ArgumentConvert; +use crate::{model::prelude::*, prelude::*}; + +/// Error that can be returned from [`Channel::convert`]. +#[non_exhaustive] +#[derive(Debug)] +pub enum ChannelParseError { + /// When channel retrieval via HTTP failed + Http(SerenityError), + /// The provided channel string failed to parse, or the parsed result cannot be found in the + /// cache. + NotFoundOrMalformed, +} + +impl std::error::Error for ChannelParseError { + fn source(&self) -> Option<&(dyn std::error::Error + 'static)> { + match self { + Self::Http(e) => Some(e), + Self::NotFoundOrMalformed => None, + } + } +} + +impl std::fmt::Display for ChannelParseError { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + match self { + Self::Http(_) => write!(f, "Failed to request channel via HTTP"), + Self::NotFoundOrMalformed => write!(f, "Channel not found or unknown format"), + } + } +} + +fn channel_belongs_to_guild(channel: &Channel, guild: GuildId) -> bool { + match channel { + Channel::Guild(channel) => channel.guild_id == guild, + Channel::Category(channel) => channel.guild_id == guild, + Channel::Private(_channel) => false, + } +} + +async fn lookup_channel_global(ctx: &Context, s: &str) -> Result { + if let Some(channel_id) = s.parse::().ok().or_else(|| crate::utils::parse_channel(s)) { + return ChannelId(channel_id).to_channel(ctx).await.map_err(ChannelParseError::Http); + } + + let channels = ctx.cache.channels.read().await; + if let Some(channel) = + channels.values().find(|channel| channel.name.eq_ignore_ascii_case(s)).cloned() + { + return Ok(Channel::Guild(channel)); + } + + Err(ChannelParseError::NotFoundOrMalformed) +} + +/// Look up a Channel by a string case-insensitively. +/// +/// Lookup are done via local guild. If in DMs, the global cache is used instead. +/// +/// The cache feature needs to be enabled. +/// +/// The lookup strategy is as follows (in order): +/// 1. Lookup by ID. +/// 2. [Lookup by mention](`crate::utils::parse_channel`). +/// 3. Lookup by name. +#[cfg(feature = "cache")] +#[async_trait::async_trait] +impl ArgumentConvert for Channel { + type Err = ChannelParseError; + + async fn convert( + ctx: &Context, + guild_id: Option, + _channel_id: Option, + s: &str, + ) -> Result { + let channel = lookup_channel_global(ctx, s).await?; + + // Don't yield for other guilds' channels + if let Some(guild_id) = guild_id { + if !channel_belongs_to_guild(&channel, guild_id) { + return Err(ChannelParseError::NotFoundOrMalformed); + } + }; + + Ok(channel) + } +} + +/// Error that can be returned from [`GuildChannel::convert`]. +#[non_exhaustive] +#[derive(Debug)] +pub enum GuildChannelParseError { + /// When channel retrieval via HTTP failed + Http(SerenityError), + /// The provided channel string failed to parse, or the parsed result cannot be found in the + /// cache. + NotFoundOrMalformed, + /// When the referenced channel is not a guild channel + NotAGuildChannel, +} + +impl std::error::Error for GuildChannelParseError { + fn source(&self) -> Option<&(dyn std::error::Error + 'static)> { + match self { + Self::Http(e) => Some(e), + Self::NotFoundOrMalformed => None, + Self::NotAGuildChannel => None, + } + } +} + +impl std::fmt::Display for GuildChannelParseError { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + match self { + Self::Http(_) => write!(f, "Failed to request channel via HTTP"), + Self::NotFoundOrMalformed => write!(f, "Channel not found or unknown format"), + Self::NotAGuildChannel => write!(f, "Channel is not a guild channel"), + } + } +} + +/// Look up a GuildChannel by a string case-insensitively. +/// +/// Lookup is done by the global cache, hence the cache feature needs to be enabled. +/// +/// For more information, see the ArgumentConvert implementation for [`Channel`] +#[cfg(feature = "cache")] +#[async_trait::async_trait] +impl ArgumentConvert for GuildChannel { + type Err = GuildChannelParseError; + + async fn convert( + ctx: &Context, + guild_id: Option, + channel_id: Option, + s: &str, + ) -> Result { + match Channel::convert(ctx, guild_id, channel_id, s).await { + Ok(Channel::Guild(channel)) => Ok(channel), + Ok(_) => Err(GuildChannelParseError::NotAGuildChannel), + Err(ChannelParseError::Http(e)) => Err(GuildChannelParseError::Http(e)), + Err(ChannelParseError::NotFoundOrMalformed) => { + Err(GuildChannelParseError::NotFoundOrMalformed) + }, + } + } +} + +/// Error that can be returned from [`ChannelCategory::convert`]. +#[non_exhaustive] +#[derive(Debug)] +pub enum ChannelCategoryParseError { + /// When channel retrieval via HTTP failed + Http(SerenityError), + /// The provided channel string failed to parse, or the parsed result cannot be found in the + /// cache. + NotFoundOrMalformed, + /// When the referenced channel is not a channel category + NotAChannelCategory, +} + +impl std::error::Error for ChannelCategoryParseError { + fn source(&self) -> Option<&(dyn std::error::Error + 'static)> { + match self { + Self::Http(e) => Some(e), + Self::NotFoundOrMalformed => None, + Self::NotAChannelCategory => None, + } + } +} + +impl std::fmt::Display for ChannelCategoryParseError { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + match self { + Self::Http(_) => write!(f, "Failed to request channel via HTTP"), + Self::NotFoundOrMalformed => write!(f, "Channel not found or unknown format"), + Self::NotAChannelCategory => write!(f, "Channel is not a channel category"), + } + } +} + +/// Look up a ChannelCategory by a string case-insensitively. +/// +/// Lookup is done by the global cache, hence the cache feature needs to be enabled. +/// +/// For more information, see the ArgumentConvert implementation for [`Channel`] +#[cfg(feature = "cache")] +#[async_trait::async_trait] +impl ArgumentConvert for ChannelCategory { + type Err = ChannelCategoryParseError; + + async fn convert( + ctx: &Context, + guild_id: Option, + channel_id: Option, + s: &str, + ) -> Result { + match Channel::convert(ctx, guild_id, channel_id, s).await { + Ok(Channel::Category(channel)) => Ok(channel), + // TODO: accomodate issue #1352 somehow + Ok(_) => Err(ChannelCategoryParseError::NotAChannelCategory), + Err(ChannelParseError::Http(e)) => Err(ChannelCategoryParseError::Http(e)), + Err(ChannelParseError::NotFoundOrMalformed) => { + Err(ChannelCategoryParseError::NotFoundOrMalformed) + }, + } + } +} diff --git a/src/utils/argument_convert/emoji.rs b/src/utils/argument_convert/emoji.rs new file mode 100644 index 00000000000..22dd6f44752 --- /dev/null +++ b/src/utils/argument_convert/emoji.rs @@ -0,0 +1,63 @@ +use super::ArgumentConvert; +use crate::{model::prelude::*, prelude::*}; + +/// Error that can be returned from [`Emoji::convert`]. +#[non_exhaustive] +#[derive(Copy, Clone, Debug, Eq, Hash, PartialEq)] +pub enum EmojiParseError { + /// The provided emoji string failed to parse, or the parsed result cannot be found in the + /// cache. + NotFoundOrMalformed, +} + +impl std::error::Error for EmojiParseError {} + +impl std::fmt::Display for EmojiParseError { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + match self { + Self::NotFoundOrMalformed => write!(f, "Emoji not found or unknown format"), + } + } +} + +/// Look up a [`Emoji`]. +/// +/// Requires the cache feature to be enabled. +/// +/// The lookup strategy is as follows (in order): +/// 1. Lookup by ID. +/// 2. [Lookup by extracting ID from the emoji](`crate::utils::parse_emoji`). +/// 3. Lookup by name. +#[cfg(feature = "cache")] +#[async_trait::async_trait] +impl ArgumentConvert for Emoji { + type Err = EmojiParseError; + + async fn convert( + ctx: &Context, + _guild_id: Option, + _channel_id: Option, + s: &str, + ) -> Result { + let guilds = ctx.cache.guilds.read().await; + + let direct_id = s.parse::().ok().map(EmojiId); + let id_from_mention = crate::utils::parse_emoji(s).map(|e| e.id); + + if let Some(emoji_id) = direct_id.or(id_from_mention) { + if let Some(emoji) = guilds.values().find_map(|guild| guild.emojis.get(&emoji_id)) { + return Ok(emoji.clone()); + } + } + + if let Some(emoji) = guilds + .values() + .flat_map(|guild| guild.emojis.values()) + .find(|emoji| emoji.name.eq_ignore_ascii_case(s)) + { + return Ok(emoji.clone()); + } + + Err(EmojiParseError::NotFoundOrMalformed) + } +} diff --git a/src/utils/argument_convert/guild.rs b/src/utils/argument_convert/guild.rs new file mode 100644 index 00000000000..f420639986d --- /dev/null +++ b/src/utils/argument_convert/guild.rs @@ -0,0 +1,45 @@ +use super::ArgumentConvert; +use crate::{model::prelude::*, prelude::*}; + +/// Error that can be returned from [`Guild::convert`]. +#[non_exhaustive] +#[derive(Copy, Clone, Debug, Eq, Hash, PartialEq)] +pub enum GuildParseError { + /// The provided guild string failed to parse, or the parsed result cannot be found in the + /// cache. + NotFoundOrMalformed, +} + +impl std::error::Error for GuildParseError {} + +impl std::fmt::Display for GuildParseError { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + match self { + Self::NotFoundOrMalformed => write!(f, "Guild not found or unknown format"), + } + } +} + +/// Look up a Guild, either by ID or by a string case-insensitively. +/// +/// Requires the cache feature to be enabled. +#[cfg(feature = "cache")] +#[async_trait::async_trait] +impl ArgumentConvert for Guild { + type Err = GuildParseError; + + async fn convert( + ctx: &Context, + _guild_id: Option, + _channel_id: Option, + s: &str, + ) -> Result { + let guilds = ctx.cache.guilds.read().await; + + let lookup_by_id = || guilds.get(&GuildId(s.parse().ok()?)); + + let lookup_by_name = || guilds.values().find(|guild| guild.name.eq_ignore_ascii_case(s)); + + lookup_by_id().or_else(lookup_by_name).cloned().ok_or(GuildParseError::NotFoundOrMalformed) + } +} diff --git a/src/utils/argument_convert/member.rs b/src/utils/argument_convert/member.rs new file mode 100644 index 00000000000..37a90925a3f --- /dev/null +++ b/src/utils/argument_convert/member.rs @@ -0,0 +1,90 @@ +use super::ArgumentConvert; +use crate::{model::prelude::*, prelude::*}; + +/// Error that can be returned from [`Member::convert`]. +#[non_exhaustive] +#[derive(Copy, Clone, Debug, Eq, Hash, PartialEq)] +pub enum MemberParseError { + /// Parser was invoked outside a guild. + OutsideGuild, + /// The guild in which the parser was invoked is not in cache. + GuildNotInCache, + /// The provided member string failed to parse, or the parsed result cannot be found in the + /// guild cache data. + NotFoundOrMalformed, +} + +impl std::error::Error for MemberParseError {} + +impl std::fmt::Display for MemberParseError { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + match self { + Self::OutsideGuild => write!(f, "Tried to find member outside a guild"), + Self::GuildNotInCache => write!(f, "Guild is not in cache"), + Self::NotFoundOrMalformed => write!(f, "Member not found or unknown format"), + } + } +} + +/// Look up a guild member by a string case-insensitively. +/// +/// Requires the cache feature to be enabled. +/// +/// The lookup strategy is as follows (in order): +/// 1. Lookup by ID. +/// 2. [Lookup by mention](`crate::utils::parse_username`). +/// 3. [Lookup by name#discrim](`crate::utils::parse_user_tag`). +/// 4. Lookup by name +/// 5. Lookup by nickname +#[cfg(feature = "cache")] +#[async_trait::async_trait] +impl ArgumentConvert for Member { + type Err = MemberParseError; + + async fn convert( + ctx: &Context, + guild_id: Option, + _channel_id: Option, + s: &str, + ) -> Result { + let guild = guild_id + .ok_or(MemberParseError::OutsideGuild)? + .to_guild_cached(ctx) + .await + .ok_or(MemberParseError::GuildNotInCache)?; + + // DON'T use guild.members: it's only populated when guild presences intent is enabled! + + // If string is a raw user ID or a mention + if let Some(user_id) = s.parse().ok().or_else(|| crate::utils::parse_username(s)) { + if let Ok(member) = guild.member(ctx, UserId(user_id)).await { + return Ok(member); + } + } + + // Following code is inspired by discord.py's MemberConvert::query_member_named + + // If string is a username+discriminator + if let Some((name, discrim)) = crate::utils::parse_user_tag(s) { + if let Ok(member_results) = guild.search_members(ctx, name, Some(100)).await { + if let Some(member) = member_results.into_iter().find(|m| { + m.user.name.eq_ignore_ascii_case(name) && m.user.discriminator == discrim + }) { + return Ok(member); + } + } + } + + // If string is username or nickname + if let Ok(member_results) = guild.search_members(ctx, s, Some(100)).await { + if let Some(member) = member_results.into_iter().find(|m| { + m.user.name.eq_ignore_ascii_case(s) + || m.nick.as_ref().map_or(false, |nick| nick.eq_ignore_ascii_case(s)) + }) { + return Ok(member); + } + } + + Err(MemberParseError::NotFoundOrMalformed) + } +} diff --git a/src/utils/argument_convert/message.rs b/src/utils/argument_convert/message.rs new file mode 100644 index 00000000000..e9b90918387 --- /dev/null +++ b/src/utils/argument_convert/message.rs @@ -0,0 +1,80 @@ +use super::ArgumentConvert; +use crate::{model::prelude::*, prelude::*}; + +/// Error that can be returned from [`Message::convert`]. +#[non_exhaustive] +#[derive(Debug)] +pub enum MessageParseError { + /// When the provided string does not adhere to any known guild message format + Malformed, + /// When message data retrieval via HTTP failed + Http(SerenityError), + /// When the `gateway` feature is disabled and the required information was not in cache. + HttpNotAvailable, +} + +impl std::error::Error for MessageParseError { + fn source(&self) -> Option<&(dyn std::error::Error + 'static)> { + match self { + Self::Malformed => None, + Self::Http(e) => Some(e), + Self::HttpNotAvailable => None, + } + } +} + +impl std::fmt::Display for MessageParseError { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + match self { + Self::Malformed => { + write!(f, "Provided string did not adhere to any known guild message format") + }, + Self::Http(_) => write!(f, "Failed to request message data via HTTP"), + Self::HttpNotAvailable => write!( + f, + "Gateway feature is disabled and the required information was not in cache" + ), + } + } +} + +/// Look up a message by a string. +/// +/// The lookup strategy is as follows (in order): +/// 1. [Lookup by "{channel ID}-{message ID}"](`crate::utils::parse_message_id_pair`) (retrieved by shift-clicking on "Copy ID") +/// 2. Lookup by message ID (the message must be in the context channel) +/// 3. [Lookup by message URL](`crate::utils::parse_message_url`) +#[async_trait::async_trait] +impl ArgumentConvert for Message { + type Err = MessageParseError; + + async fn convert( + ctx: &Context, + _guild_id: Option, + channel_id: Option, + s: &str, + ) -> Result { + let extract_from_message_id = || Some((channel_id?, MessageId(s.parse().ok()?))); + + let extract_from_message_url = || { + let (_guild_id, channel_id, message_id) = crate::utils::parse_message_url(s)?; + Some((channel_id, message_id)) + }; + + let (channel_id, message_id) = crate::utils::parse_message_id_pair(s) + .or_else(extract_from_message_id) + .or_else(extract_from_message_url) + .ok_or(MessageParseError::Malformed)?; + + #[cfg(feature = "cache")] + if let Some(msg) = ctx.cache.message(channel_id, message_id).await { + return Ok(msg); + } + + if cfg!(feature = "http") { + ctx.http.get_message(channel_id.0, message_id.0).await.map_err(MessageParseError::Http) + } else { + Err(MessageParseError::HttpNotAvailable) + } + } +} diff --git a/src/utils/argument_convert/mod.rs b/src/utils/argument_convert/mod.rs new file mode 100644 index 00000000000..3fa4b23364a --- /dev/null +++ b/src/utils/argument_convert/mod.rs @@ -0,0 +1,158 @@ +mod member; +pub use member::*; + +mod message; +pub use message::*; + +mod user; +pub use user::*; + +mod channel; +pub use channel::*; + +mod guild; +pub use guild::*; + +mod role; +pub use role::*; + +mod emoji; +pub use emoji::*; + +use crate::model::prelude::*; +use crate::prelude::*; + +#[deprecated(note = "Superseded by ArgumentConvert trait")] +#[async_trait::async_trait] +pub trait Parse: Sized { + /// The associated error which can be returned from parsing. + type Err; + + /// Parses a string `s` as a command parameter of this type. + async fn parse(ctx: &Context, msg: &Message, s: &str) -> Result; +} + +#[allow(deprecated)] +#[async_trait::async_trait] +impl Parse for T { + type Err = ::Err; + + async fn parse(ctx: &Context, msg: &Message, s: &str) -> Result { + Self::convert(ctx, msg.guild_id, Some(msg.channel_id), s).await + } +} + +/// Parse a value from a string in context of a received message. +/// +/// This trait is a superset of [`std::str::FromStr`]. The +/// difference is that this trait aims to support serenity-specific Discord types like [`Member`] +/// or [`Message`]. +/// +/// Trait implementations may do network requests as part of their parsing procedure. +/// +/// Useful for implementing argument parsing in command frameworks. +#[async_trait::async_trait] +pub trait ArgumentConvert: Sized { + /// The associated error which can be returned from parsing. + type Err; + + /// Parses a string `s` as a command parameter of this type. + async fn convert( + ctx: &Context, + guild_id: Option, + channel_id: Option, + s: &str, + ) -> Result; +} + +#[async_trait::async_trait] +impl ArgumentConvert for T { + type Err = ::Err; + + async fn convert( + _: &Context, + _: Option, + _: Option, + s: &str, + ) -> Result { + T::from_str(s) + } +} + +// The following few parse_XXX methods are in here (parse.rs) because they need to be gated +// behind the model feature and it's just convenient to put them here for that + +/// Retrieves IDs from "{channel ID}-{message ID}" (retrieved by shift-clicking on "Copy ID"). +/// +/// If the string is invalid, None is returned. +/// +/// # Examples +/// ```rust +/// use serenity::model::prelude::*; +/// use serenity::utils::parse_message_id_pair; +/// +/// assert_eq!( +/// parse_message_id_pair("673965002805477386-842482646604972082"), +/// Some((ChannelId(673965002805477386), MessageId(842482646604972082))), +/// ); +/// assert_eq!( +/// parse_message_id_pair("673965002805477386-842482646604972082-472029906943868929"), +/// None, +/// ); +/// ``` +pub fn parse_message_id_pair(s: &str) -> Option<(ChannelId, MessageId)> { + let mut parts = s.splitn(2, '-'); + let channel_id = ChannelId(parts.next()?.parse().ok()?); + let message_id = MessageId(parts.next()?.parse().ok()?); + Some((channel_id, message_id)) +} + +/// Retrieves guild, channel, and message ID from a message URL. +/// +/// If the URL is malformed, None is returned. +/// +/// # Examples +/// ```rust +/// use serenity::model::prelude::*; +/// use serenity::utils::parse_message_url; +/// +/// assert_eq!( +/// parse_message_url( +/// "https://discord.com/channels/381880193251409931/381880193700069377/806164913558781963" +/// ), +/// Some(( +/// GuildId(381880193251409931), +/// ChannelId(381880193700069377), +/// MessageId(806164913558781963), +/// )), +/// ); +/// assert_eq!(parse_message_url("https://google.com"), None); +/// ``` +pub fn parse_message_url(s: &str) -> Option<(GuildId, ChannelId, MessageId)> { + let mut parts = s.strip_prefix("https://discord.com/channels/")?.splitn(3, '/'); + let guild_id = GuildId(parts.next()?.parse().ok()?); + let channel_id = ChannelId(parts.next()?.parse().ok()?); + let message_id = MessageId(parts.next()?.parse().ok()?); + Some((guild_id, channel_id, message_id)) +} + +/// Retrieves the username and discriminator out of a user tag (`name#discrim`). +/// +/// If the user tag is invalid, None is returned. +/// +/// # Examples +/// ```rust +/// use serenity::utils::parse_user_tag; +/// +/// assert_eq!(parse_user_tag("kangalioo#9108"), Some(("kangalioo", 9108))); +/// assert_eq!(parse_user_tag("kangalioo#10108"), None); +/// ``` +pub fn parse_user_tag(s: &str) -> Option<(&str, u16)> { + let pound_sign = s.find('#')?; + let name = &s[..pound_sign]; + let discrim = s[(pound_sign + 1)..].parse::().ok()?; + if discrim > 9999 { + return None; + } + Some((name, discrim)) +} diff --git a/src/utils/argument_convert/role.rs b/src/utils/argument_convert/role.rs new file mode 100644 index 00000000000..b8f1e640e88 --- /dev/null +++ b/src/utils/argument_convert/role.rs @@ -0,0 +1,66 @@ +use super::ArgumentConvert; +use crate::{model::prelude::*, prelude::*}; + +/// Error that can be returned from [`Role::convert`]. +#[non_exhaustive] +#[derive(Copy, Clone, Debug, Eq, Hash, PartialEq)] +pub enum RoleParseError { + /// When the operation was invoked outside a guild. + NotInGuild, + /// When the guild's roles were not found in cache. + NotInCache, + /// The provided channel string failed to parse, or the parsed result cannot be found in the + /// cache. + NotFoundOrMalformed, +} + +impl std::error::Error for RoleParseError {} + +impl std::fmt::Display for RoleParseError { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + match self { + Self::NotInGuild => f.write_str("Must invoke this operation in a guild"), + Self::NotInCache => f.write_str("Guild's roles were not found in cache"), + Self::NotFoundOrMalformed => f.write_str("Role not found or unknown format"), + } + } +} + +/// Look up a [`Role`] by a string case-insensitively. +/// +/// Requires the cache feature to be enabled. +/// +/// The lookup strategy is as follows (in order): +/// 1. Lookup by ID +/// 2. [Lookup by mention](`crate::utils::parse_role`). +/// 3. Lookup by name (case-insensitive) +#[cfg(feature = "cache")] +#[async_trait::async_trait] +impl ArgumentConvert for Role { + type Err = RoleParseError; + + async fn convert( + ctx: &Context, + guild_id: Option, + _channel_id: Option, + s: &str, + ) -> Result { + let roles = ctx + .cache + .guild_roles(guild_id.ok_or(RoleParseError::NotInGuild)?) + .await + .ok_or(RoleParseError::NotInCache)?; + + if let Some(role_id) = s.parse::().ok().or_else(|| crate::utils::parse_role(s)) { + if let Some(role) = roles.get(&RoleId(role_id)) { + return Ok(role.clone()); + } + } + + if let Some(role) = roles.values().find(|role| role.name.eq_ignore_ascii_case(s)) { + return Ok(role.clone()); + } + + Err(RoleParseError::NotFoundOrMalformed) + } +} diff --git a/src/utils/argument_convert/user.rs b/src/utils/argument_convert/user.rs new file mode 100644 index 00000000000..d92451eb63b --- /dev/null +++ b/src/utils/argument_convert/user.rs @@ -0,0 +1,65 @@ +use super::ArgumentConvert; +use crate::{model::prelude::*, prelude::*}; + +/// Error that can be returned from [`User::convert`]. +#[non_exhaustive] +#[derive(Copy, Clone, Debug, Eq, Hash, PartialEq)] +pub enum UserParseError { + /// The provided user string failed to parse, or the parsed result cannot be found in the + /// guild cache data. + NotFoundOrMalformed, +} + +impl std::error::Error for UserParseError {} + +impl std::fmt::Display for UserParseError { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + match self { + Self::NotFoundOrMalformed => write!(f, "User not found or unknown format"), + } + } +} + +/// Look up a user by a string case-insensitively. +/// +/// Requires the cache feature to be enabled. If a user is not in cache, they will not be found! +/// +/// The lookup strategy is as follows (in order): +/// 1. Lookup by ID. +/// 2. [Lookup by mention](`crate::utils::parse_username`). +/// 3. [Lookup by name#discrim](`crate::utils::parse_user_tag`). +/// 4. Lookup by name +#[cfg(feature = "cache")] +#[async_trait::async_trait] +impl ArgumentConvert for User { + type Err = UserParseError; + + async fn convert( + ctx: &Context, + _guild_id: Option, + _channel_id: Option, + s: &str, + ) -> Result { + let users = ctx.cache.users.read().await; + + let lookup_by_id = || users.get(&UserId(s.parse().ok()?)); + + let lookup_by_mention = || users.get(&UserId(crate::utils::parse_username(s)?)); + + let lookup_by_name_and_discrim = || { + let (name, discrim) = crate::utils::parse_user_tag(s)?; + users + .values() + .find(|user| user.discriminator == discrim && user.name.eq_ignore_ascii_case(name)) + }; + + let lookup_by_name = || users.values().find(|user| user.name == s); + + lookup_by_id() + .or_else(lookup_by_mention) + .or_else(lookup_by_name_and_discrim) + .or_else(lookup_by_name) + .cloned() + .ok_or(UserParseError::NotFoundOrMalformed) + } +} diff --git a/src/utils/mod.rs b/src/utils/mod.rs index 9aa0eda4019..4869f46c00c 100644 --- a/src/utils/mod.rs +++ b/src/utils/mod.rs @@ -1,14 +1,14 @@ //! A set of utilities to help with common use cases that are not required to //! fully use the library. +#[cfg(all(feature = "client", feature = "cache"))] +mod argument_convert; mod colour; mod custom_message; mod message_builder; -#[cfg(feature = "client")] -mod parse; -#[cfg(feature = "client")] -pub use parse::*; +#[cfg(all(feature = "client", feature = "cache"))] +pub use argument_convert::*; pub use self::{ colour::Colour, diff --git a/src/utils/parse.rs b/src/utils/parse.rs index 2d343a05e96..8b137891791 100644 --- a/src/utils/parse.rs +++ b/src/utils/parse.rs @@ -1,188 +1 @@ -use crate::model::prelude::*; -use crate::prelude::*; -/// Parse a value from a string in context of a received message. -/// -/// This trait is a superset of [`std::str::FromStr`]. The -/// difference is that this trait aims to support serenity-specific Discord types like [`Member`] -/// or [`Message`]. -/// -/// Trait implementations may do network requests as part of their parsing procedure. -/// -/// Useful for implementing argument parsing in command frameworks. -#[async_trait::async_trait] -pub trait Parse: Sized { - /// The associated error which can be returned from parsing. - type Err; - - /// Parses a string `s` as a command parameter of this type. - async fn parse(ctx: &Context, msg: &Message, s: &str) -> Result; -} - -#[async_trait::async_trait] -impl Parse for T { - type Err = ::Err; - - async fn parse(_: &Context, _: &Message, s: &str) -> Result { - T::from_str(s) - } -} - -/// Error that can be returned from [`Member::parse`]. -#[non_exhaustive] -#[derive(Copy, Clone, Debug, Eq, Hash, PartialEq)] -pub enum MemberParseError { - /// The guild in which the parser was invoked is not in cache. - GuildNotInCache, - /// The provided member string failed to parse, or the parsed result cannot be found in the - /// guild cache data. - NotFoundOrMalformed, -} - -impl std::error::Error for MemberParseError {} - -impl std::fmt::Display for MemberParseError { - fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - match self { - Self::GuildNotInCache => write!(f, "Guild is not in cache"), - Self::NotFoundOrMalformed => write!(f, "Provided member was not found or provided string did not adhere to any known guild member format"), - } - } -} - -/// Look up a guild member by a string case-insensitively. -/// -/// Requires the cache feature to be enabled. -/// -/// The lookup strategy is as follows (in order): -/// 1. Lookup by ID. -/// 2. Lookup by mention. -/// 3. Lookup by name#discrim -/// 4. Lookup by name -/// 5. Lookup by nickname -#[cfg(feature = "cache")] -#[async_trait::async_trait] -impl Parse for Member { - type Err = MemberParseError; - - async fn parse(ctx: &Context, msg: &Message, s: &str) -> Result { - let guild = msg.guild(&ctx.cache).await.ok_or(MemberParseError::GuildNotInCache)?; - - let lookup_by_id = || guild.members.get(&UserId(s.parse().ok()?)); - - let lookup_by_mention = || { - guild.members.get(&UserId( - s.strip_prefix("<@")?.trim_start_matches('!').strip_suffix('>')?.parse().ok()?, - )) - }; - - let lookup_by_name_and_discrim = || { - let pound_sign = s.find('#')?; - let name = &s[..pound_sign]; - let discrim = s[(pound_sign + 1)..].parse::().ok()?; - guild.members.values().find(|member| { - member.user.discriminator == discrim && member.user.name.eq_ignore_ascii_case(name) - }) - }; - - let lookup_by_name = || guild.members.values().find(|member| member.user.name == s); - - let lookup_by_nickname = || { - guild.members.values().find(|member| match &member.nick { - Some(nick) => nick.eq_ignore_ascii_case(s), - None => false, - }) - }; - - lookup_by_id() - .or_else(lookup_by_mention) - .or_else(lookup_by_name_and_discrim) - .or_else(lookup_by_name) - .or_else(lookup_by_nickname) - .cloned() - .ok_or(MemberParseError::NotFoundOrMalformed) - } -} - -/// Error that can be returned from [`Message::parse`]. -#[non_exhaustive] -#[derive(Debug)] -pub enum MessageParseError { - /// When the provided string does not adhere to any known guild message format - Malformed, - /// When message data retrieval via HTTP failed - Http(SerenityError), - /// When the `gateway` feature is disabled and the required information was not in cache. - HttpNotAvailable, -} - -impl std::error::Error for MessageParseError { - fn source(&self) -> Option<&(dyn std::error::Error + 'static)> { - match self { - Self::Malformed => None, - Self::Http(e) => Some(e), - Self::HttpNotAvailable => None, - } - } -} - -impl std::fmt::Display for MessageParseError { - fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - match self { - Self::Malformed => { - write!(f, "Provided string did not adhere to any known guild message format") - }, - Self::Http(e) => write!(f, "Failed to request message data via HTTP: {}", e), - Self::HttpNotAvailable => write!( - f, - "Gateway feature is disabled and the required information was not in cache" - ), - } - } -} - -/// Look up a message by a string. -/// -/// The lookup strategy is as follows (in order): -/// 1. Lookup by "{channel ID}-{message ID}" (retrieved by shift-clicking on "Copy ID") -/// 2. Lookup by message ID (the message must be in the context channel) -/// 3. Lookup by message URL -#[async_trait::async_trait] -impl Parse for Message { - type Err = MessageParseError; - - async fn parse(ctx: &Context, msg: &Message, s: &str) -> Result { - let extract_from_id_pair = || { - let mut parts = s.splitn(2, '-'); - let channel_id = ChannelId(parts.next()?.parse().ok()?); - let message_id = MessageId(parts.next()?.parse().ok()?); - Some((channel_id, message_id)) - }; - - let extract_from_message_id = || Some((msg.channel_id, MessageId(s.parse().ok()?))); - - let extract_from_message_url = || { - let mut parts = s.strip_prefix("https://discord.com/channels/")?.splitn(3, '/'); - let _guild_id = GuildId(parts.next()?.parse().ok()?); - let channel_id = ChannelId(parts.next()?.parse().ok()?); - let message_id = MessageId(parts.next()?.parse().ok()?); - Some((channel_id, message_id)) - }; - - let (channel_id, message_id) = extract_from_id_pair() - .or_else(extract_from_message_id) - .or_else(extract_from_message_url) - .ok_or(MessageParseError::Malformed)?; - - #[cfg(feature = "cache")] - if let Some(msg) = ctx.cache.message(channel_id, message_id).await { - return Ok(msg); - } - - if cfg!(feature = "http") { - ctx.http.get_message(channel_id.0, message_id.0).await.map_err(MessageParseError::Http) - } else { - Err(MessageParseError::HttpNotAvailable) - } - } -}