Skip to content

Commit

Permalink
Make ArgumentConvert compatible without the cache feature (#1818)
Browse files Browse the repository at this point in the history
  • Loading branch information
kangalio authored and arqunis committed Apr 18, 2022
1 parent 63a1000 commit cdaa70c
Show file tree
Hide file tree
Showing 8 changed files with 114 additions and 72 deletions.
1 change: 0 additions & 1 deletion src/utils/argument_convert/_template.rs
Expand Up @@ -24,7 +24,6 @@ impl fmt::Display for PLACEHOLDERParseError {
/// The lookup strategy is as follows (in order):
/// 1. Lookup by PLACEHOLDER
/// 2. [Lookup by PLACEHOLDER](`crate::utils::parse_PLACEHOLDER`).
#[cfg(feature = "cache")]
#[async_trait::async_trait]
impl ArgumentConvert for PLACEHOLDER {
type Err = PLACEHOLDERParseError;
Expand Down
25 changes: 17 additions & 8 deletions src/utils/argument_convert/channel.rs
Expand Up @@ -40,14 +40,17 @@ fn channel_belongs_to_guild(channel: &Channel, guild: GuildId) -> bool {
}
}

async fn lookup_channel_global(ctx: &Context, s: &str) -> Result<Channel, ChannelParseError> {
async fn lookup_channel_global(
ctx: &Context,
guild_id: Option<GuildId>,
s: &str,
) -> Result<Channel, ChannelParseError> {
if let Some(channel_id) = s.parse::<u64>().ok().or_else(|| crate::utils::parse_channel(s)) {
return ChannelId(channel_id).to_channel(ctx).await.map_err(ChannelParseError::Http);
}

let channels = &ctx.cache.channels;

if let Some(channel) = channels.iter().find_map(|m| {
#[cfg(feature = "cache")]
if let Some(channel) = ctx.cache.channels.iter().find_map(|m| {
let channel = m.value();
if channel.name.eq_ignore_ascii_case(s) {
Some(channel.clone())
Expand All @@ -58,6 +61,15 @@ async fn lookup_channel_global(ctx: &Context, s: &str) -> Result<Channel, Channe
return Ok(Channel::Guild(channel));
}

if let Some(guild_id) = guild_id {
let channels = ctx.http.get_channels(guild_id.0).await.map_err(ChannelParseError::Http)?;
if let Some(channel) =
channels.into_iter().find(|channel| channel.name.eq_ignore_ascii_case(s))
{
return Ok(Channel::Guild(channel));
}
}

Err(ChannelParseError::NotFoundOrMalformed)
}

Expand All @@ -71,7 +83,6 @@ async fn lookup_channel_global(ctx: &Context, s: &str) -> Result<Channel, Channe
/// 1. Lookup by ID.
/// 2. [Lookup by mention](`crate::utils::parse_channel`).
/// 3. Lookup by name.
#[cfg(feature = "cache")]
#[async_trait::async_trait]
impl ArgumentConvert for Channel {
type Err = ChannelParseError;
Expand All @@ -82,7 +93,7 @@ impl ArgumentConvert for Channel {
_channel_id: Option<ChannelId>,
s: &str,
) -> Result<Self, Self::Err> {
let channel = lookup_channel_global(ctx, s).await?;
let channel = lookup_channel_global(ctx, guild_id, s).await?;

// Don't yield for other guilds' channels
if let Some(guild_id) = guild_id {
Expand Down Expand Up @@ -133,7 +144,6 @@ impl fmt::Display for GuildChannelParseError {
/// Lookup is done by the global cache, hence the cache feature needs to be enabled.
///
/// For more information, see the ArgumentConvert implementation for [`Channel`]
#[cfg(feature = "cache")]
#[async_trait::async_trait]
impl ArgumentConvert for GuildChannel {
type Err = GuildChannelParseError;
Expand Down Expand Up @@ -193,7 +203,6 @@ impl fmt::Display for ChannelCategoryParseError {
/// Lookup is done by the global cache, hence the cache feature needs to be enabled.
///
/// For more information, see the ArgumentConvert implementation for [`Channel`]
#[cfg(feature = "cache")]
#[async_trait::async_trait]
impl ArgumentConvert for ChannelCategory {
type Err = ChannelCategoryParseError;
Expand Down
27 changes: 18 additions & 9 deletions src/utils/argument_convert/emoji.rs
Expand Up @@ -7,8 +7,12 @@ use crate::{model::prelude::*, prelude::*};
#[non_exhaustive]
#[derive(Copy, Clone, Debug, Eq, Hash, PartialEq)]
pub enum EmojiParseError {
/// Parser was invoked outside a guild.
OutsideGuild,
/// Guild was not in cache, or guild HTTP request failed.
FailedToRetrieveGuild,
/// The provided emoji string failed to parse, or the parsed result cannot be found in the
/// cache.
/// guild roles.
NotFoundOrMalformed,
}

Expand All @@ -17,6 +21,8 @@ impl std::error::Error for EmojiParseError {}
impl fmt::Display for EmojiParseError {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
match self {
Self::OutsideGuild => f.write_str("Tried to find emoji outside a guild"),
Self::FailedToRetrieveGuild => f.write_str("Could not retrieve guild data"),
Self::NotFoundOrMalformed => f.write_str("Emoji not found or unknown format"),
}
}
Expand All @@ -30,33 +36,36 @@ impl fmt::Display for EmojiParseError {
/// 1. Lookup by ID.
/// 2. [Lookup by extracting ID from the emoji](`crate::utils::parse_emoji`).
/// 3. Lookup by name.
#[cfg(feature = "cache")]
#[async_trait::async_trait]
impl ArgumentConvert for Emoji {
type Err = EmojiParseError;

async fn convert(
ctx: &Context,
_guild_id: Option<GuildId>,
guild_id: Option<GuildId>,
_channel_id: Option<ChannelId>,
s: &str,
) -> Result<Self, Self::Err> {
let guilds = &ctx.cache.guilds;
// Get Guild or PartialGuild
let guild_id = guild_id.ok_or(EmojiParseError::OutsideGuild)?;
#[cfg(feature = "cache")]
let guild = ctx.cache.guilds.get(&guild_id);
#[cfg(not(feature = "cache"))]
let guild = ctx.http.get_guild(guild_id.0).await.ok();
let guild = guild.ok_or(EmojiParseError::FailedToRetrieveGuild)?;

let direct_id = s.parse::<u64>().ok().map(EmojiId);
let id_from_mention = crate::utils::parse_emoji(s).map(|e| e.id);

if let Some(emoji_id) = direct_id.or(id_from_mention) {
if let Some(emoji) =
guilds.iter().find_map(|guild| guild.emojis.get(&emoji_id).cloned())
{
if let Some(emoji) = guild.emojis.get(&emoji_id).cloned() {
return Ok(emoji);
}
}

if let Some(emoji) = guilds.iter().find_map(|guild| {
if let Some(emoji) =
guild.emojis.values().find(|emoji| emoji.name.eq_ignore_ascii_case(s)).cloned()
}) {
{
return Ok(emoji);
}

Expand Down
4 changes: 3 additions & 1 deletion src/utils/argument_convert/guild.rs
@@ -1,3 +1,6 @@
// From HTTP you can only get PartialGuild; for Guild you need gateway and cache
#![cfg(feature = "cache")]

use std::fmt;

use super::ArgumentConvert;
Expand Down Expand Up @@ -25,7 +28,6 @@ impl fmt::Display for GuildParseError {
/// Look up a Guild, either by ID or by a string case-insensitively.
///
/// Requires the cache feature to be enabled.
#[cfg(feature = "cache")]
#[async_trait::async_trait]
impl ArgumentConvert for Guild {
type Err = GuildParseError;
Expand Down
12 changes: 4 additions & 8 deletions src/utils/argument_convert/member.rs
Expand Up @@ -38,7 +38,6 @@ impl fmt::Display for MemberParseError {
/// 3. [Lookup by name#discrim](`crate::utils::parse_user_tag`).
/// 4. Lookup by name
/// 5. Lookup by nickname
#[cfg(feature = "cache")]
#[async_trait::async_trait]
impl ArgumentConvert for Member {
type Err = MemberParseError;
Expand All @@ -49,16 +48,13 @@ impl ArgumentConvert for Member {
_channel_id: Option<ChannelId>,
s: &str,
) -> Result<Self, Self::Err> {
let guild = guild_id
.ok_or(MemberParseError::OutsideGuild)?
.to_guild_cached(ctx)
.ok_or(MemberParseError::GuildNotInCache)?;
let guild_id = guild_id.ok_or(MemberParseError::OutsideGuild)?;

// DON'T use guild.members: it's only populated when guild presences intent is enabled!

// If string is a raw user ID or a mention
if let Some(user_id) = s.parse().ok().or_else(|| crate::utils::parse_username(s)) {
if let Ok(member) = guild.member(ctx, UserId(user_id)).await {
if let Ok(member) = guild_id.member(ctx, UserId(user_id)).await {
return Ok(member);
}
}
Expand All @@ -67,7 +63,7 @@ impl ArgumentConvert for Member {

// If string is a username+discriminator
if let Some((name, discrim)) = crate::utils::parse_user_tag(s) {
if let Ok(member_results) = guild.search_members(ctx, name, Some(100)).await {
if let Ok(member_results) = guild_id.search_members(ctx, name, Some(100)).await {
if let Some(member) = member_results.into_iter().find(|m| {
m.user.name.eq_ignore_ascii_case(name) && m.user.discriminator == discrim
}) {
Expand All @@ -77,7 +73,7 @@ impl ArgumentConvert for Member {
}

// If string is username or nickname
if let Ok(member_results) = guild.search_members(ctx, s, Some(100)).await {
if let Ok(member_results) = guild_id.search_members(ctx, s, Some(100)).await {
if let Some(member) = member_results.into_iter().find(|m| {
m.user.name.eq_ignore_ascii_case(s)
|| m.nick.as_ref().map_or(false, |nick| nick.eq_ignore_ascii_case(s))
Expand Down
37 changes: 30 additions & 7 deletions src/utils/argument_convert/role.rs
Expand Up @@ -5,25 +5,37 @@ use crate::{model::prelude::*, prelude::*};

/// Error that can be returned from [`Role::convert`].
#[non_exhaustive]
#[derive(Copy, Clone, Debug, Eq, Hash, PartialEq)]
#[derive(Debug)]
#[allow(clippy::enum_variant_names)]
pub enum RoleParseError {
/// When the operation was invoked outside a guild.
NotInGuild,
/// When the guild's roles were not found in cache.
NotInCache,
/// HTTP error while retrieving guild roles.
Http(SerenityError),
/// The provided channel string failed to parse, or the parsed result cannot be found in the
/// cache.
NotFoundOrMalformed,
}

impl std::error::Error for RoleParseError {}
impl std::error::Error for RoleParseError {
fn source(&self) -> Option<&(dyn std::error::Error + 'static)> {
match self {
RoleParseError::NotInGuild => None,
RoleParseError::NotInCache => None,
RoleParseError::Http(e) => Some(e),
RoleParseError::NotFoundOrMalformed => None,
}
}
}

impl fmt::Display for RoleParseError {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
match self {
Self::NotInGuild => f.write_str("Must invoke this operation in a guild"),
Self::NotInCache => f.write_str("Guild's roles were not found in cache"),
Self::Http(_) => f.write_str("Failed to retrieve roles via HTTP"),
Self::NotFoundOrMalformed => f.write_str("Role not found or unknown format"),
}
}
Expand All @@ -37,7 +49,6 @@ impl fmt::Display for RoleParseError {
/// 1. Lookup by ID
/// 2. [Lookup by mention](`crate::utils::parse_role`).
/// 3. Lookup by name (case-insensitive)
#[cfg(feature = "cache")]
#[async_trait::async_trait]
impl ArgumentConvert for Role {
type Err = RoleParseError;
Expand All @@ -48,20 +59,32 @@ impl ArgumentConvert for Role {
_channel_id: Option<ChannelId>,
s: &str,
) -> Result<Self, Self::Err> {
let roles = ctx
.cache
.guild_roles(guild_id.ok_or(RoleParseError::NotInGuild)?)
.ok_or(RoleParseError::NotInCache)?;
let guild_id = guild_id.ok_or(RoleParseError::NotInGuild)?;

#[cfg(feature = "cache")]
let roles = ctx.cache.guild_roles(guild_id).ok_or(RoleParseError::NotInCache)?;
#[cfg(not(feature = "cache"))]
let roles = ctx.http.get_guild_roles(guild_id.0).await.map_err(RoleParseError::Http)?;

if let Some(role_id) = s.parse::<u64>().ok().or_else(|| crate::utils::parse_role(s)) {
#[cfg(feature = "cache")]
if let Some(role) = roles.get(&RoleId(role_id)) {
return Ok(role.clone());
}
#[cfg(not(feature = "cache"))]
if let Some(role) = roles.iter().find(|role| role.id.0 == role_id) {
return Ok(role.clone());
}
}

#[cfg(feature = "cache")]
if let Some(role) = roles.values().find(|role| role.name.eq_ignore_ascii_case(s)) {
return Ok(role.clone());
}
#[cfg(not(feature = "cache"))]
if let Some(role) = roles.into_iter().find(|role| role.name.eq_ignore_ascii_case(s)) {
return Ok(role);
}

Err(RoleParseError::NotFoundOrMalformed)
}
Expand Down
76 changes: 40 additions & 36 deletions src/utils/argument_convert/user.rs
Expand Up @@ -23,6 +23,44 @@ impl fmt::Display for UserParseError {
}
}

#[cfg(feature = "cache")]
fn lookup_by_global_cache(ctx: &Context, s: &str) -> Option<User> {
let users = &ctx.cache.users;

let lookup_by_id = || users.get(&UserId(s.parse().ok()?)).map(|u| u.clone());

let lookup_by_mention =
|| users.get(&UserId(crate::utils::parse_username(s)?)).map(|u| u.clone());

let lookup_by_name_and_discrim = || {
let (name, discrim) = crate::utils::parse_user_tag(s)?;
users.iter().find_map(|m| {
let user = m.value();
if user.discriminator == discrim && user.name.eq_ignore_ascii_case(name) {
Some(user.clone())
} else {
None
}
})
};

let lookup_by_name = || {
users.iter().find_map(|m| {
let user = m.value();
if user.name == s {
Some(user.clone())
} else {
None
}
})
};

lookup_by_id()
.or_else(lookup_by_mention)
.or_else(lookup_by_name_and_discrim)
.or_else(lookup_by_name)
}

/// Look up a user by a string case-insensitively.
///
/// Requires the cache feature to be enabled. If a user is not in cache, they will not be found!
Expand All @@ -32,7 +70,6 @@ impl fmt::Display for UserParseError {
/// 2. [Lookup by mention](`crate::utils::parse_username`).
/// 3. [Lookup by name#discrim](`crate::utils::parse_user_tag`).
/// 4. Lookup by name
#[cfg(feature = "cache")]
#[async_trait::async_trait]
impl ArgumentConvert for User {
type Err = UserParseError;
Expand All @@ -43,42 +80,9 @@ impl ArgumentConvert for User {
channel_id: Option<ChannelId>,
s: &str,
) -> Result<Self, Self::Err> {
let users = &ctx.cache.users;

let lookup_by_id = || users.get(&UserId(s.parse().ok()?)).map(|u| u.clone());

let lookup_by_mention =
|| users.get(&UserId(crate::utils::parse_username(s)?)).map(|u| u.clone());

let lookup_by_name_and_discrim = || {
let (name, discrim) = crate::utils::parse_user_tag(s)?;
users.iter().find_map(|m| {
let user = m.value();
if user.discriminator == discrim && user.name.eq_ignore_ascii_case(name) {
Some(user.clone())
} else {
None
}
})
};

let lookup_by_name = || {
users.iter().find_map(|m| {
let user = m.value();
if user.name == s {
Some(user.clone())
} else {
None
}
})
};

// Try to look up in global user cache via a variety of methods
if let Some(user) = lookup_by_id()
.or_else(lookup_by_mention)
.or_else(lookup_by_name_and_discrim)
.or_else(lookup_by_name)
{
#[cfg(feature = "cache")]
if let Some(user) = lookup_by_global_cache(ctx, s) {
return Ok(user);
}

Expand Down

0 comments on commit cdaa70c

Please sign in to comment.