Skip to content

Commit

Permalink
Rewrite collectors
Browse files Browse the repository at this point in the history
  • Loading branch information
kangalio committed Nov 9, 2022
1 parent 60c5ab3 commit 3838436
Show file tree
Hide file tree
Showing 40 changed files with 463 additions and 1,415 deletions.
3 changes: 1 addition & 2 deletions Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,6 @@ futures = { version = "0.3", default-features = false, features = ["std"] }
dep_time = { version = "0.3.6", package = "time", features = ["formatting", "parsing", "serde-well-known"] }
# Optional dependencies
fxhash = { version = "0.2.1", optional = true }
derivative = { version = "2.2.0", optional = true }
simd-json = { version = "0.6", optional = true }
uwl = { version = "0.6.0", optional = true }
base64 = { version = "0.13", optional = true }
Expand Down Expand Up @@ -82,7 +81,7 @@ builder = ["base64"]
cache = ["fxhash", "dashmap", "parking_lot"]
# Enables collectors, a utility feature that lets you await interaction events in code with
# zero setup, without needing to setup an InteractionCreate event listener.
collector = ["gateway", "model", "derivative"]
collector = ["gateway", "model"]
# Wraps the gateway and http functionality into a single interface
# TODO: should this require "gateway"?
client = ["http", "typemap_rev"]
Expand Down
42 changes: 17 additions & 25 deletions examples/e10_collectors/src/main.rs
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@ use std::env;
use std::time::Duration;

use serenity::async_trait;
use serenity::collector::{EventCollectorBuilder, MessageCollectorBuilder};
use serenity::collector::MessageCollector;
use serenity::framework::standard::macros::{command, group, help};
use serenity::framework::standard::{
help_commands,
Expand Down Expand Up @@ -115,11 +115,7 @@ async fn challenge(ctx: &Context, msg: &Message, _: Args) -> CommandResult {
.author_id(msg.author.id);

if let Some(reaction) = collector.collect_single().await {
// By default, the collector will collect only added reactions.
// We could also pattern-match the reaction in case we want
// to handle added or removed reactions.
// In this case we will just get the inner reaction.
let _ = if reaction.as_inner_ref().emoji.as_data() == "1️⃣" {
let _ = if reaction.emoji.as_data() == "1️⃣" {
score += 1;
msg.reply(ctx, "That's correct!").await
} else {
Expand All @@ -132,14 +128,14 @@ async fn challenge(ctx: &Context, msg: &Message, _: Args) -> CommandResult {
let _ = msg.reply(ctx, "Write 5 messages in 10 seconds").await;

// We can create a collector from scratch too using this builder future.
let collector = MessageCollectorBuilder::new(&ctx.shard)
let collector = MessageCollector::new(&ctx.shard)
// Only collect messages by this user.
.author_id(msg.author.id)
.channel_id(msg.channel_id)
.collect_limit(5u32)
.timeout(Duration::from_secs(10))
// Build the collector.
.build();
// Build the collector.
.collect_stream()
.take(5);

// Let's acquire borrow HTTP to send a message inside the `async move`.
let http = &ctx.http;
Expand All @@ -164,26 +160,22 @@ async fn challenge(ctx: &Context, msg: &Message, _: Args) -> CommandResult {
score += 1;
}

// We can also collect arbitrary events using the generic EventCollector. For example, here we
// We can also collect arbitrary events using the collect() function. For example, here we
// collect updates to the messages that the user sent above and check for them updating all 5 of
// them.
let builder = EventCollectorBuilder::new(&ctx.shard)
.add_event_type(EventType::MessageUpdate)
.timeout(Duration::from_secs(20));

// Only collect MessageUpdate events for the 5 MessageIds we're interested in.
let mut collector =
collected.iter().try_fold(builder, |b, msg| b.add_message_id(msg.id))?.build();
let mut collector = serenity::collector::collect(&ctx.shard, move |event| match event {
// Only collect MessageUpdate events for the 5 MessageIds we're interested in.
Event::MessageUpdate(event) if collected.iter().any(|msg| event.id == msg.id) => {
Some(event.id)
},
_ => None,
})
.take_until(Box::pin(tokio::time::sleep(Duration::from_secs(20))));

let _ = msg.reply(ctx, "Edit each of those 5 messages in 20 seconds").await;
let mut edited = HashSet::new();
while let Some(event) = collector.next().await {
match event.as_ref() {
Event::MessageUpdate(e) => {
edited.insert(e.id);
},
e => panic!("Unexpected event type received: {:?}", e.event_type()),
}
while let Some(edited_message_id) = collector.next().await {
edited.insert(edited_message_id);
if edited.len() >= 5 {
break;
}
Expand Down
2 changes: 1 addition & 1 deletion examples/e17_message_components/src/main.rs
Original file line number Diff line number Diff line change
Expand Up @@ -111,7 +111,7 @@ impl EventHandler for Handler {
let mut interaction_stream = m
.component_interaction_collector(&ctx.shard)
.timeout(Duration::from_secs(60 * 3))
.build();
.collect_stream();

while let Some(interaction) = interaction_stream.next().await {
let sound = &interaction.data.custom_id;
Expand Down
22 changes: 22 additions & 0 deletions examples/testing/src/main.rs
Original file line number Diff line number Diff line change
Expand Up @@ -117,6 +117,28 @@ async fn message(ctx: &Context, msg: Message) -> Result<(), serenity::Error> {
})),
)
.await?;
} else if msg.content == "manybuttons" {
let mut custom_id = msg.id.to_string();
loop {
let msg = channel_id
.send_message(
ctx,
CreateMessage::new()
.button(CreateButton::new(custom_id.clone()).label(custom_id)),
)
.await?;
let button_press = msg
.component_interaction_collector(&ctx.shard)
.timeout(std::time::Duration::from_secs(10))
.collect_single()
.await;
match button_press {
Some(x) => x.defer(ctx).await?,
None => break,
}

custom_id = msg.id.to_string();
}
} else if msg.content == "reactionremoveemoji" {
// Test new ReactionRemoveEmoji gateway event: https://github.com/serenity-rs/serenity/issues/2248
msg.react(ctx, '👍').await?;
Expand Down
2 changes: 1 addition & 1 deletion src/builder/bot_auth_parameters.rs
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
use url::Url;

#[cfg(feature = "http")]
use crate::http::Http;
use crate::http::client::Http;
#[cfg(feature = "http")]
use crate::internal::prelude::*;
use crate::model::application::oauth::Scope;
Expand Down
2 changes: 1 addition & 1 deletion src/builder/create_message.rs
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,7 @@ use crate::utils::check_overflow;
///
/// [`ChannelId::say`]: crate::model::id::ChannelId::say
/// [`ChannelId::send_message`]: crate::model::id::ChannelId::send_message
/// [`Http::send_message`]: crate::http::Http::send_message
/// [`Http::send_message`]: crate::http::client::Http::send_message
#[derive(Clone, Debug, Default, Serialize)]
#[must_use]
pub struct CreateMessage {
Expand Down
6 changes: 3 additions & 3 deletions src/builder/quick_modal.rs
Original file line number Diff line number Diff line change
@@ -1,13 +1,13 @@
use super::{CreateActionRow, CreateInputText, CreateInteractionResponse, CreateModal};
use crate::client::Context;
use crate::collector::ModalInteractionCollectorBuilder;
use crate::collector::ModalInteractionCollector;
use crate::model::id::InteractionId;
use crate::model::prelude::component::{ActionRowComponent, InputTextStyle};
use crate::model::prelude::ModalInteraction;

#[cfg(feature = "collector")]
pub struct QuickModalResponse {
pub interaction: std::sync::Arc<ModalInteraction>,
pub interaction: ModalInteraction,
pub inputs: Vec<String>,
}

Expand Down Expand Up @@ -99,7 +99,7 @@ impl CreateQuickModal {
);
builder.execute(ctx, interaction_id, token).await?;

let modal_interaction = ModalInteractionCollectorBuilder::new(&ctx.shard)
let modal_interaction = ModalInteractionCollector::new(&ctx.shard)
.custom_ids(vec![modal_custom_id])
.collect_single()
.await;
Expand Down
80 changes: 54 additions & 26 deletions src/cache/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -31,16 +31,18 @@
//! [`http`]: crate::http

use std::collections::{HashMap, HashSet, VecDeque};
use std::hash::Hash;
use std::hash::{BuildHasher, Hash};
use std::str::FromStr;
#[cfg(feature = "temp_cache")]
use std::sync::Arc;
#[cfg(feature = "temp_cache")]
use std::time::Duration;

use dashmap::mapref::entry::Entry;
use dashmap::mapref::one::Ref;
use dashmap::mapref::multiple::RefMulti;
use dashmap::mapref::one::{Ref, RefMut};
use dashmap::DashMap;
use fxhash::FxBuildHasher;
#[cfg(feature = "temp_cache")]
use moka::dash::Cache as DashCache;
use parking_lot::RwLock;
Expand All @@ -53,18 +55,15 @@ use crate::model::prelude::*;
mod cache_update;
mod event;
mod settings;
mod wrappers;

use wrappers::{BuildHasher, MaybeMap, ReadOnlyMapRef};

type MessageCache = DashMap<ChannelId, HashMap<MessageId, Message>, BuildHasher>;
type MessageCache = DashMap<ChannelId, HashMap<MessageId, Message>, FxBuildHasher>;

struct NotSend;

enum CacheRefInner<'a, K, V> {
#[cfg(feature = "temp_cache")]
Arc(Arc<V>),
DashRef(Ref<'a, K, V, BuildHasher>),
DashRef(Ref<'a, K, V, FxBuildHasher>),
ReadGuard(parking_lot::RwLockReadGuard<'a, V>),
}

Expand All @@ -86,7 +85,7 @@ impl<'a, K, V> CacheRef<'a, K, V> {
Self::new(CacheRefInner::Arc(inner))
}

fn from_ref(inner: Ref<'a, K, V, BuildHasher>) -> Self {
fn from_ref(inner: Ref<'a, K, V, FxBuildHasher>) -> Self {
Self::new(CacheRefInner::DashRef(inner))
}

Expand Down Expand Up @@ -160,6 +159,34 @@ pub(crate) struct CachedShardData {
pub has_sent_shards_ready: bool,
}

#[derive(Debug)]
pub(crate) struct MaybeMap<K: Eq + Hash, V>(Option<DashMap<K, V, FxBuildHasher>>);
impl<K: Eq + Hash, V> MaybeMap<K, V> {
pub fn iter(&self) -> impl Iterator<Item = RefMulti<'_, K, V, FxBuildHasher>> {
Option::iter(&self.0).flat_map(DashMap::iter)
}

pub fn get(&self, k: &K) -> Option<Ref<'_, K, V, FxBuildHasher>> {
self.0.as_ref()?.get(k)
}

pub fn get_mut(&self, k: &K) -> Option<RefMut<'_, K, V, FxBuildHasher>> {
self.0.as_ref()?.get_mut(k)
}

pub fn insert(&self, k: K, v: V) -> Option<V> {
self.0.as_ref()?.insert(k, v)
}

pub fn remove(&self, k: &K) -> Option<(K, V)> {
self.0.as_ref()?.remove(k)
}

pub fn len(&self) -> usize {
self.0.as_ref().map_or(0, |map| map.len())
}
}

/// A cache containing data received from [`Shard`]s.
///
/// Using the cache allows to avoid REST API requests via the [`http`] module
Expand Down Expand Up @@ -187,12 +214,12 @@ pub struct Cache {
///
/// The TTL for each value is configured in CacheSettings.
#[cfg(feature = "temp_cache")]
pub(crate) temp_channels: DashCache<ChannelId, GuildChannel, BuildHasher>,
pub(crate) temp_channels: DashCache<ChannelId, GuildChannel, FxBuildHasher>,
/// Cache of users who have been fetched from `to_user`.
///
/// The TTL for each value is configured in CacheSettings.
#[cfg(feature = "temp_cache")]
pub(crate) temp_users: DashCache<UserId, Arc<User>, BuildHasher>,
pub(crate) temp_users: DashCache<UserId, Arc<User>, FxBuildHasher>,

// Channels cache:
// ---
Expand Down Expand Up @@ -248,7 +275,7 @@ pub struct Cache {
/// This is simply a vecdeque so we can keep track of the order of messages
/// inserted into the cache. When a maximum number of messages are in a
/// channel's cache, we can pop the front and remove that ID from the cache.
pub(crate) message_queue: DashMap<ChannelId, VecDeque<MessageId>, BuildHasher>,
pub(crate) message_queue: DashMap<ChannelId, VecDeque<MessageId>, FxBuildHasher>,

// Miscellanous fixed-size data
// ---
Expand Down Expand Up @@ -289,12 +316,12 @@ impl Cache {
#[instrument]
pub fn new_with_settings(settings: Settings) -> Self {
#[cfg(feature = "temp_cache")]
fn temp_cache<K, V>(ttl: Duration) -> DashCache<K, V, BuildHasher>
fn temp_cache<K, V>(ttl: Duration) -> DashCache<K, V, FxBuildHasher>
where
K: Hash + Eq + Send + Sync + 'static,
V: Clone + Send + Sync + 'static,
{
DashCache::builder().time_to_live(ttl).build_with_hasher(BuildHasher::default())
DashCache::builder().time_to_live(ttl).build_with_hasher(FxBuildHasher::default())
}

Self {
Expand Down Expand Up @@ -395,8 +422,10 @@ impl Cache {
///
/// println!("There are {} private channels", amount);
/// ```
pub fn private_channels(&self) -> ReadOnlyMapRef<'_, ChannelId, PrivateChannel> {
self.private_channels.as_read_only()
pub fn private_channels(
&self,
) -> DashMap<ChannelId, PrivateChannel, impl BuildHasher + Send + Sync + Clone> {
self.private_channels.0.clone().unwrap_or_default()
}

/// Fetches a vector of all [`Guild`]s' Ids that are stored in the cache.
Expand Down Expand Up @@ -428,11 +457,8 @@ impl Cache {
/// [`Context`]: crate::client::Context
/// [`Shard`]: crate::gateway::Shard
pub fn guilds(&self) -> Vec<GuildId> {
let unavailable_guilds = self.unavailable_guilds();

let unavailable_guild_ids = unavailable_guilds.iter().map(|i| *i.key());

self.guilds.iter().map(|i| *i.key()).chain(unavailable_guild_ids).collect()
let chain = self.unavailable_guilds().into_iter().map(|(k, _)| k);
self.guilds.iter().map(|i| *i.key()).chain(chain).collect()
}

/// Retrieves a [`Channel`] from the cache based on the given Id.
Expand Down Expand Up @@ -697,23 +723,25 @@ impl Cache {

/// This method clones and returns all unavailable guilds.
#[inline]
pub fn unavailable_guilds(&self) -> ReadOnlyMapRef<'_, GuildId, ()> {
self.unavailable_guilds.as_read_only()
pub fn unavailable_guilds(
&self,
) -> DashMap<GuildId, (), impl BuildHasher + Send + Sync + Clone> {
self.unavailable_guilds.0.clone().unwrap_or_default()
}

/// This method returns all channels from a guild of with the given `guild_id`.
#[inline]
pub fn guild_channels(
&self,
guild_id: impl Into<GuildId>,
) -> Option<DashMap<ChannelId, GuildChannel, BuildHasher>> {
) -> Option<DashMap<ChannelId, GuildChannel, FxBuildHasher>> {
self._guild_channels(guild_id.into())
}

fn _guild_channels(
&self,
guild_id: GuildId,
) -> Option<DashMap<ChannelId, GuildChannel, BuildHasher>> {
) -> Option<DashMap<ChannelId, GuildChannel, FxBuildHasher>> {
self.guilds.get(&guild_id).map(|g| g.channels.clone().into_iter().collect())
}

Expand Down Expand Up @@ -899,8 +927,8 @@ impl Cache {

/// Clones all users and returns them.
#[inline]
pub fn users(&self) -> ReadOnlyMapRef<'_, UserId, User> {
self.users.as_read_only()
pub fn users(&self) -> DashMap<UserId, User, impl BuildHasher + Send + Sync + Clone> {
self.users.0.clone().unwrap_or_default()
}

/// Returns the amount of cached users.
Expand Down
Loading

0 comments on commit 3838436

Please sign in to comment.