Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Ensure Redis clusters and multiplexed connections are supported #4

Merged
merged 1 commit into from
Jun 12, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 5 additions & 4 deletions omniqueue/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -13,12 +13,12 @@ edition = "2021"
async-trait = "0.1"
aws-config = { version = "0.55", optional = true }
aws-sdk-sqs = { version = "0.25", optional = true }
bb8 = { version = "0.7.1", optional = true }
bb8-redis = { version = "0.10.1", optional = true }
bb8 = { version = "0.8", optional = true }
bb8-redis = { version = "0.13", optional = true }
futures = { version = "0.3", default-features = false, features = ["async-await", "std"] }
lapin = { version = "2", optional = true }
rdkafka = { version = "0.29", features = ["cmake-build", "ssl", "tracing"] }
redis = { version = "0.21.5", features = ["tokio-comp", "tokio-native-tls-comp", "streams"], optional = true }
redis = { version = "0.23", features = ["tokio-comp", "tokio-native-tls-comp", "streams"], optional = true }
serde = { version = "1", features = ["derive", "rc"] }
serde_json = "1"
thiserror = "1"
Expand All @@ -31,8 +31,9 @@ tokio-executor-trait = "2.1"
tokio-reactor-trait = "1.1"

[features]
default = ["memory_queue", "rabbitmq", "redis", "sqs"]
default = ["memory_queue", "rabbitmq", "redis", "redis_cluster", "sqs"]
memory_queue = []
rabbitmq = ["dep:lapin"]
redis = ["dep:bb8", "dep:bb8-redis", "dep:redis"]
redis_cluster = ["redis", "redis/cluster-async"]
sqs = ["dep:aws-config", "dep:aws-sdk-sqs"]
45 changes: 45 additions & 0 deletions omniqueue/src/backends/redis/cluster.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,45 @@
use async_trait::async_trait;

use redis::{
cluster::{ClusterClient, ClusterClientBuilder},
ErrorKind, IntoConnectionInfo, RedisError,
};

/// ConnectionManager that implements `bb8::ManageConnection` and supports asynchronous clustered
/// connections via `redis::cluster::ClusterClient`.
///
/// `Clone` is derived so the manager can be handed to a `bb8::Pool`, which clones it
/// internally when opening connections.
#[derive(Clone)]
pub struct RedisClusterConnectionManager {
    // Cluster client used by `connect` to open new async connections.
    client: ClusterClient,
}

impl RedisClusterConnectionManager {
    /// Creates a connection manager from anything convertible to Redis connection info
    /// (e.g. a DSN string).
    ///
    /// # Errors
    ///
    /// Returns a [`RedisError`] if the cluster client cannot be built from the given info.
    pub fn new<T: IntoConnectionInfo>(
        info: T,
    ) -> Result<RedisClusterConnectionManager, RedisError> {
        let client = ClusterClientBuilder::new(vec![info]).build()?;
        Ok(RedisClusterConnectionManager { client })
    }
}

#[async_trait]
impl bb8::ManageConnection for RedisClusterConnectionManager {
    type Connection = redis::cluster_async::ClusterConnection;
    type Error = RedisError;

    /// Opens a new async connection to the cluster.
    async fn connect(&self) -> Result<Self::Connection, Self::Error> {
        self.client.get_async_connection().await
    }

    /// Checks a pooled connection by round-tripping a `PING` command.
    async fn is_valid(&self, conn: &mut Self::Connection) -> Result<(), Self::Error> {
        let reply: String = redis::cmd("PING").query_async(&mut *conn).await?;
        if reply == "PONG" {
            Ok(())
        } else {
            Err((ErrorKind::ResponseError, "ping request").into())
        }
    }

    /// Always `false`: liveness is established via `is_valid` instead of
    /// eagerly discarding connections.
    fn has_broken(&self, _: &mut Self::Connection) -> bool {
        false
    }
}
Original file line number Diff line number Diff line change
@@ -1,7 +1,8 @@
use std::{any::TypeId, collections::HashMap};
use std::{any::TypeId, collections::HashMap, marker::PhantomData};

use async_trait::async_trait;
use bb8_redis::RedisConnectionManager;
use bb8::ManageConnection;
use bb8_redis::RedisMultiplexedConnectionManager;
use redis::streams::{StreamReadOptions, StreamReadReply};

use crate::{
Expand All @@ -11,6 +12,30 @@ use crate::{
QueueError,
};

mod cluster;
use cluster::RedisClusterConnectionManager;

/// Abstraction over the concrete bb8 connection manager used by the Redis queue backend,
/// so the same backend code can drive both single-node (multiplexed) and clustered
/// Redis deployments.
pub trait RedisConnection
where
    Self: ManageConnection + Sized,
    Self::Connection: redis::aio::ConnectionLike,
    Self::Error: 'static + std::error::Error + Send + Sync,
{
    /// Builds a connection manager from a Redis DSN (connection string),
    /// wrapping any underlying error as a [`QueueError`].
    fn from_dsn(dsn: &str) -> Result<Self, QueueError>;
}

impl RedisConnection for RedisMultiplexedConnectionManager {
    /// Builds a multiplexed (single-node) manager from the DSN.
    fn from_dsn(dsn: &str) -> Result<Self, QueueError> {
        match Self::new(dsn) {
            Ok(manager) => Ok(manager),
            Err(e) => Err(QueueError::generic(e)),
        }
    }
}

impl RedisConnection for RedisClusterConnectionManager {
    /// Builds a cluster-aware manager from the DSN.
    fn from_dsn(dsn: &str) -> Result<Self, QueueError> {
        match Self::new(dsn) {
            Ok(manager) => Ok(manager),
            Err(e) => Err(QueueError::generic(e)),
        }
    }
}

pub struct RedisConfig {
pub dsn: String,
pub max_connections: u16,
Expand All @@ -21,25 +46,31 @@ pub struct RedisConfig {
pub payload_key: String,
}

pub struct RedisQueueBackend;
pub struct RedisQueueBackend<R = RedisMultiplexedConnectionManager>(PhantomData<R>);
pub type RedisClusterQueueBackend = RedisQueueBackend<RedisClusterConnectionManager>;

#[async_trait]
impl QueueBackend for RedisQueueBackend {
impl<R> QueueBackend for RedisQueueBackend<R>
where
R: RedisConnection,
R::Connection: redis::aio::ConnectionLike + Send + Sync,
R::Error: 'static + std::error::Error + Send + Sync,
{
type Config = RedisConfig;

// FIXME: Is it possible to use the types Redis actually uses?
type PayloadIn = Vec<u8>;
type PayloadOut = Vec<u8>;

type Producer = RedisStreamProducer;
type Consumer = RedisStreamConsumer;
type Producer = RedisStreamProducer<R>;
type Consumer = RedisStreamConsumer<R>;

async fn new_pair(
cfg: RedisConfig,
custom_encoders: EncoderRegistry<Vec<u8>>,
custom_decoders: DecoderRegistry<Vec<u8>>,
) -> Result<(RedisStreamProducer, RedisStreamConsumer), QueueError> {
let redis = RedisConnectionManager::new(cfg.dsn).map_err(QueueError::generic)?;
) -> Result<(RedisStreamProducer<R>, RedisStreamConsumer<R>), QueueError> {
let redis = R::from_dsn(&cfg.dsn)?;
let redis = bb8::Pool::builder()
.max_size(cfg.max_connections.into())
.build(redis)
Expand Down Expand Up @@ -67,8 +98,8 @@ impl QueueBackend for RedisQueueBackend {
async fn producing_half(
cfg: RedisConfig,
custom_encoders: EncoderRegistry<Vec<u8>>,
) -> Result<RedisStreamProducer, QueueError> {
let redis = RedisConnectionManager::new(cfg.dsn).map_err(QueueError::generic)?;
) -> Result<RedisStreamProducer<R>, QueueError> {
let redis = R::from_dsn(&cfg.dsn)?;
let redis = bb8::Pool::builder()
.max_size(cfg.max_connections.into())
.build(redis)
Expand All @@ -86,8 +117,8 @@ impl QueueBackend for RedisQueueBackend {
async fn consuming_half(
cfg: RedisConfig,
custom_decoders: DecoderRegistry<Vec<u8>>,
) -> Result<RedisStreamConsumer, QueueError> {
let redis = RedisConnectionManager::new(cfg.dsn).map_err(QueueError::generic)?;
) -> Result<RedisStreamConsumer<R>, QueueError> {
let redis = R::from_dsn(&cfg.dsn)?;
let redis = bb8::Pool::builder()
.max_size(cfg.max_connections.into())
.build(redis)
Expand All @@ -105,8 +136,8 @@ impl QueueBackend for RedisQueueBackend {
}
}

pub struct RedisStreamAcker {
redis: bb8::Pool<RedisConnectionManager>,
pub struct RedisStreamAcker<M: ManageConnection> {
redis: bb8::Pool<M>,
queue_key: String,
consumer_group: String,
entry_id: String,
Expand All @@ -115,7 +146,12 @@ pub struct RedisStreamAcker {
}

#[async_trait]
impl Acker for RedisStreamAcker {
impl<M> Acker for RedisStreamAcker<M>
where
M: ManageConnection,
M::Connection: redis::aio::ConnectionLike + Send + Sync,
M::Error: 'static + std::error::Error + Send + Sync,
{
async fn ack(&mut self) -> Result<(), QueueError> {
if self.already_acked_or_nacked {
return Err(QueueError::CannotAckOrNackTwice);
Expand Down Expand Up @@ -143,15 +179,20 @@ impl Acker for RedisStreamAcker {
}
}

pub struct RedisStreamProducer {
pub struct RedisStreamProducer<M: ManageConnection> {
registry: EncoderRegistry<Vec<u8>>,
redis: bb8::Pool<RedisConnectionManager>,
redis: bb8::Pool<M>,
queue_key: String,
payload_key: String,
}

#[async_trait]
impl QueueProducer for RedisStreamProducer {
impl<M> QueueProducer for RedisStreamProducer<M>
where
M: ManageConnection,
M::Connection: redis::aio::ConnectionLike + Send + Sync,
M::Error: 'static + std::error::Error + Send + Sync,
{
type Payload = Vec<u8>;

fn get_custom_encoders(&self) -> &HashMap<TypeId, Box<dyn CustomEncoder<Self::Payload>>> {
Expand All @@ -169,17 +210,22 @@ impl QueueProducer for RedisStreamProducer {
}
}

pub struct RedisStreamConsumer {
pub struct RedisStreamConsumer<M: ManageConnection> {
registry: DecoderRegistry<Vec<u8>>,
redis: bb8::Pool<RedisConnectionManager>,
redis: bb8::Pool<M>,
queue_key: String,
consumer_group: String,
consumer_name: String,
payload_key: String,
}

#[async_trait]
impl QueueConsumer for RedisStreamConsumer {
impl<M> QueueConsumer for RedisStreamConsumer<M>
where
M: ManageConnection,
M::Connection: redis::aio::ConnectionLike + Send + Sync,
M::Error: 'static + std::error::Error + Send + Sync,
{
type Payload = Vec<u8>;

async fn receive(&mut self) -> Result<Delivery, QueueError> {
Expand Down
125 changes: 125 additions & 0 deletions omniqueue/tests/redis.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,125 @@
use omniqueue::{
backends::redis::{RedisConfig, RedisQueueBackend},
queue::{consumer::QueueConsumer, producer::QueueProducer, QueueBackend, QueueBuilder, Static},
};
use redis::{AsyncCommands, Client, Commands};
use serde::{Deserialize, Serialize};

/// Connection string for the Redis instance the tests expect to be running locally.
const ROOT_URL: &str = "redis://localhost";

/// Guard that deletes the temporary Redis stream named by its field when dropped,
/// so test streams do not pile up on the shared instance.
pub struct RedisStreamDrop(String);

impl Drop for RedisStreamDrop {
    fn drop(&mut self) {
        // A blocking (non-async) connection is required here since Drop is synchronous.
        let mut conn = Client::open(ROOT_URL).unwrap().get_connection().unwrap();
        let _: () = conn.del(&self.0).unwrap();
    }
}

/// Returns a [`QueueBuilder`] configured to connect to the Redis instance spawned by the file
/// `testing-docker-compose.yaml` in the root of the repository.
///
/// A randomly named stream (plus its consumer group) is created up front so that
/// concurrently running tests cannot steal each other's messages.
///
/// Also returns a [`RedisStreamDrop`] that cleans up the stream after the test ends.
async fn make_test_queue() -> (QueueBuilder<RedisQueueBackend, Static>, RedisStreamDrop) {
    // Eight random alphanumeric characters keep stream names unique across test runs.
    let stream_name: String = (0..8).map(|_| fastrand::alphanumeric()).collect();

    let client = Client::open(ROOT_URL).unwrap();
    let mut conn = client.get_async_connection().await.unwrap();

    // MKSTREAM creates the stream and the consumer group in one call.
    let _: () = conn
        .xgroup_create_mkstream(&stream_name, "test_cg", 0i8)
        .await
        .unwrap();

    let config = RedisConfig {
        dsn: ROOT_URL.to_owned(),
        max_connections: 8,
        reinsert_on_nack: false,
        queue_key: stream_name.clone(),
        consumer_group: "test_cg".to_owned(),
        consumer_name: "test_cn".to_owned(),
        payload_key: "payload".to_owned(),
    };

    let builder = RedisQueueBackend::builder(config);
    (builder, RedisStreamDrop(stream_name))
}

/// Round-trips a raw byte payload through the queue and verifies it arrives intact.
#[tokio::test]
async fn test_raw_send_recv() {
    let (builder, _drop) = make_test_queue().await;
    let payload = b"{\"test\": \"data\"}";
    let (p, mut c) = builder.build_pair().await.unwrap();

    p.send_raw(&payload.to_vec()).await.unwrap();

    let d = c.receive().await.unwrap();
    assert_eq!(d.borrow_payload().unwrap(), payload);
    // Ack so the entry does not linger in the consumer group's pending list —
    // consistent with the other send/recv tests in this file.
    d.ack().await.unwrap();
}

/// Round-trips a plain byte-slice payload via `send_bytes` and acks the delivery.
#[tokio::test]
async fn test_bytes_send_recv() {
    let (builder, _drop) = make_test_queue().await;
    let (p, mut c) = builder.build_pair().await.unwrap();

    let payload = b"hello";
    p.send_bytes(payload).await.unwrap();

    let delivery = c.receive().await.unwrap();
    assert_eq!(delivery.borrow_payload().unwrap(), payload);
    delivery.ack().await.unwrap();
}

/// Minimal serde-enabled payload type used by the serde and custom-codec round-trip tests.
#[derive(Debug, Deserialize, Serialize, PartialEq)]
pub struct ExType {
    a: u8,
}

/// Round-trips a struct as JSON via `send_serde_json` and deserializes it on receipt.
#[tokio::test]
async fn test_serde_send_recv() {
    let (builder, _drop) = make_test_queue().await;
    let (p, mut c) = builder.build_pair().await.unwrap();

    let payload = ExType { a: 2 };
    p.send_serde_json(&payload).await.unwrap();

    let delivery = c.receive().await.unwrap();
    assert_eq!(
        delivery.payload_serde_json::<ExType>().unwrap().unwrap(),
        payload
    );
    delivery.ack().await.unwrap();
}

/// Round-trips a struct through user-supplied encoder/decoder closures and verifies
/// that the non-JSON payload cannot be decoded as JSON.
#[tokio::test]
async fn test_custom_send_recv() {
    let (builder, _drop) = make_test_queue().await;

    // Encode an ExType as a single byte; decode by reading the first byte back.
    let encoder = |p: &ExType| Ok(vec![p.a]);
    let decoder = |p: &Vec<u8>| {
        Ok(ExType {
            a: p.first().copied().unwrap_or(0),
        })
    };

    let (p, mut c) = builder
        .with_encoder(encoder)
        .with_decoder(decoder)
        .build_pair()
        .await
        .unwrap();

    let payload = ExType { a: 3 };
    p.send_custom(&payload).await.unwrap();

    let delivery = c.receive().await.unwrap();
    assert_eq!(
        delivery.payload_custom::<ExType>().unwrap().unwrap(),
        payload
    );

    // Because it doesn't use JSON, this should fail:
    delivery.payload_serde_json::<ExType>().unwrap_err();
    delivery.ack().await.unwrap();
}
Loading
Loading