Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
38 changes: 36 additions & 2 deletions Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

19 changes: 13 additions & 6 deletions crates/factor-key-value/src/host.rs
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,9 @@ pub trait StoreManager: Sync + Send {

#[async_trait]
pub trait Store: Sync + Send {
async fn after_open(&self) -> Result<(), Error> {
Ok(())
}
async fn get(&self, key: &str) -> Result<Option<Vec<u8>>, Error>;
async fn set(&self, key: &str, value: &[u8]) -> Result<(), Error>;
async fn delete(&self, key: &str) -> Result<(), Error>;
Expand Down Expand Up @@ -109,11 +112,13 @@ impl key_value::HostStore for KeyValueDispatch {
async fn open(&mut self, name: String) -> Result<Result<Resource<key_value::Store>, Error>> {
Ok(async {
if self.allowed_stores.contains(&name) {
let store = self
let store = self.manager.get(&name).await?;
store.after_open().await?;
let store_idx = self
.stores
.push(self.manager.get(&name).await?)
.push(store)
.map_err(|()| Error::StoreTableFull)?;
Ok(Resource::new_own(store))
Ok(Resource::new_own(store_idx))
} else {
Err(Error::AccessDenied)
}
Expand Down Expand Up @@ -193,11 +198,13 @@ impl wasi_keyvalue::store::Host for KeyValueDispatch {
identifier: String,
) -> Result<Resource<wasi_keyvalue::store::Bucket>, wasi_keyvalue::store::Error> {
if self.allowed_stores.contains(&identifier) {
let store = self
let store = self.manager.get(&identifier).await.map_err(to_wasi_err)?;
store.after_open().await.map_err(to_wasi_err)?;
let store_idx = self
.stores
.push(self.manager.get(&identifier).await.map_err(to_wasi_err)?)
.push(store)
.map_err(|()| wasi_keyvalue::store::Error::Other("store table full".to_string()))?;
Ok(Resource::new_own(store))
Ok(Resource::new_own(store_idx))
} else {
Err(wasi_keyvalue::store::Error::AccessDenied)
}
Expand Down
4 changes: 4 additions & 0 deletions crates/factor-key-value/src/util.rs
Original file line number Diff line number Diff line change
Expand Up @@ -156,6 +156,10 @@ struct CachingStore {

#[async_trait]
impl Store for CachingStore {
async fn after_open(&self) -> Result<(), Error> {
self.inner.after_open().await
}

async fn get(&self, key: &str) -> Result<Option<Vec<u8>>, Error> {
// Retrieve the specified value from the cache, lazily populating the cache as necessary.

Expand Down
2 changes: 1 addition & 1 deletion crates/key-value-redis/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@ edition = { workspace = true }

[dependencies]
anyhow = { workspace = true }
redis = { version = "0.27", features = ["tokio-comp", "tokio-native-tls-comp"] }
redis = { version = "0.28", features = ["tokio-comp", "tokio-native-tls-comp", "connection-manager"] }
serde = { workspace = true }
spin-core = { path = "../core" }
spin-factor-key-value = { path = "../factor-key-value" }
Expand Down
83 changes: 28 additions & 55 deletions crates/key-value-redis/src/store.rs
Original file line number Diff line number Diff line change
@@ -1,15 +1,14 @@
use anyhow::{Context, Result};
use redis::{aio::MultiplexedConnection, parse_redis_url, AsyncCommands, Client, RedisError};
use redis::{aio::ConnectionManager, parse_redis_url, AsyncCommands, Client, RedisError};
use spin_core::async_trait;
use spin_factor_key_value::{log_error, Cas, Error, Store, StoreManager, SwapError};
use std::ops::DerefMut;
use std::sync::Arc;
use tokio::sync::{Mutex, OnceCell};
use tokio::sync::OnceCell;
use url::Url;

pub struct KeyValueRedis {
database_url: Url,
connection: OnceCell<Arc<Mutex<MultiplexedConnection>>>,
connection: OnceCell<ConnectionManager>,
}

impl KeyValueRedis {
Expand All @@ -30,10 +29,8 @@ impl StoreManager for KeyValueRedis {
.connection
.get_or_try_init(|| async {
Client::open(self.database_url.clone())?
.get_multiplexed_async_connection()
.get_connection_manager()
.await
.map(Mutex::new)
.map(Arc::new)
})
.await
.map_err(log_error)?;
Expand All @@ -55,90 +52,69 @@ impl StoreManager for KeyValueRedis {
}

struct RedisStore {
connection: Arc<Mutex<MultiplexedConnection>>,
connection: ConnectionManager,
database_url: Url,
}

struct CompareAndSwap {
key: String,
connection: Arc<Mutex<MultiplexedConnection>>,
connection: ConnectionManager,
bucket_rep: u32,
}

#[async_trait]
impl Store for RedisStore {
async fn after_open(&self) -> Result<(), Error> {
if let Err(_error) = self.connection.clone().ping::<()>().await {
// If an IO error happens, ConnectionManager will start reconnection in the background
// so we do not take any action and just pray re-connection will be successful.
}
Ok(())
}

async fn get(&self, key: &str) -> Result<Option<Vec<u8>>, Error> {
let mut conn = self.connection.lock().await;
conn.get(key).await.map_err(log_error)
self.connection.clone().get(key).await.map_err(log_error)
}

async fn set(&self, key: &str, value: &[u8]) -> Result<(), Error> {
self.connection
.lock()
.await
.clone()
.set(key, value)
.await
.map_err(log_error)
}

async fn delete(&self, key: &str) -> Result<(), Error> {
self.connection
.lock()
.await
.del(key)
.await
.map_err(log_error)
self.connection.clone().del(key).await.map_err(log_error)
}

async fn exists(&self, key: &str) -> Result<bool, Error> {
self.connection
.lock()
.await
.exists(key)
.await
.map_err(log_error)
self.connection.clone().exists(key).await.map_err(log_error)
}

async fn get_keys(&self) -> Result<Vec<String>, Error> {
self.connection
.lock()
.await
.keys("*")
.await
.map_err(log_error)
self.connection.clone().keys("*").await.map_err(log_error)
}

async fn get_many(&self, keys: Vec<String>) -> Result<Vec<(String, Option<Vec<u8>>)>, Error> {
self.connection
.lock()
.await
.keys(keys)
.await
.map_err(log_error)
self.connection.clone().keys(keys).await.map_err(log_error)
}

async fn set_many(&self, key_values: Vec<(String, Vec<u8>)>) -> Result<(), Error> {
self.connection
.lock()
.await
.clone()
.mset(&key_values)
.await
.map_err(log_error)
}

async fn delete_many(&self, keys: Vec<String>) -> Result<(), Error> {
self.connection
.lock()
.await
.del(keys)
.await
.map_err(log_error)
self.connection.clone().del(keys).await.map_err(log_error)
}

async fn increment(&self, key: String, delta: i64) -> Result<i64, Error> {
self.connection
.lock()
.await
.clone()
.incr(key, delta)
.await
.map_err(log_error)
Expand All @@ -154,10 +130,8 @@ impl Store for RedisStore {
) -> Result<Arc<dyn Cas>, Error> {
let cx = Client::open(self.database_url.clone())
.map_err(log_error)?
.get_multiplexed_async_connection()
.get_connection_manager()
.await
.map(Mutex::new)
.map(Arc::new)
.map_err(log_error)?;

Ok(Arc::new(CompareAndSwap {
Expand All @@ -175,12 +149,11 @@ impl Cas for CompareAndSwap {
async fn current(&self) -> Result<Option<Vec<u8>>, Error> {
redis::cmd("WATCH")
.arg(&self.key)
.exec_async(self.connection.lock().await.deref_mut())
.exec_async(&mut self.connection.clone())
.await
.map_err(log_error)?;
self.connection
.lock()
.await
.clone()
.get(&self.key)
.await
.map_err(log_error)
Expand All @@ -194,12 +167,12 @@ impl Cas for CompareAndSwap {
let res: Result<(), RedisError> = transaction
.atomic()
.set(&self.key, value)
.query_async(self.connection.lock().await.deref_mut())
.query_async(&mut self.connection.clone())
.await;

redis::cmd("UNWATCH")
.arg(&self.key)
.exec_async(self.connection.lock().await.deref_mut())
.exec_async(&mut self.connection.clone())
.await
.map_err(|err| SwapError::CasFailed(format!("{err:?}")))?;

Expand Down
Loading