Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat(mpz-common): multi-threaded executor #136

Merged
merged 9 commits into from
May 29, 2024
Merged
Show file tree
Hide file tree
Changes from 4 commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 4 additions & 4 deletions crates/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -41,8 +41,8 @@ mpz-share-conversion-core = { path = "mpz-share-conversion-core" }
clmul = { path = "clmul" }
matrix-transpose = { path = "matrix-transpose" }

tlsn-utils = { git = "https://github.com/tlsnotary/tlsn-utils", rev = "1d27bd7" }
tlsn-utils-aio = { git = "https://github.com/tlsnotary/tlsn-utils", rev = "1d27bd7" }
tlsn-utils = { git = "https://github.com/tlsnotary/tlsn-utils", rev = "6e0be94" }
tlsn-utils-aio = { git = "https://github.com/tlsnotary/tlsn-utils", rev = "6e0be94" }

# rand
rand_chacha = "0.3"
Expand Down Expand Up @@ -83,10 +83,10 @@ prost-build = "0.9"
bytes = "1"
yamux = "0.10"
bytemuck = { version = "1.13", features = ["derive"] }
serio = { git = "https://github.com/tlsnotary/tlsn-utils", rev = "1d27bd7" }
serio = { git = "https://github.com/tlsnotary/tlsn-utils", rev = "6e0be94" }

# io
uid-mux = { git = "https://github.com/tlsnotary/tlsn-utils", rev = "1d27bd7" }
uid-mux = { git = "https://github.com/tlsnotary/tlsn-utils", rev = "6e0be94" }

# testing
prost = "0.9"
Expand Down
19 changes: 1 addition & 18 deletions crates/mpz-common/src/context.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@ use async_trait::async_trait;
use scoped_futures::ScopedBoxFuture;
use serio::{IoSink, IoStream};

use crate::{queue::Queue, ThreadId};
use crate::ThreadId;

/// An error for types that implement [`Context`].
#[derive(Debug, thiserror::Error)]
Expand Down Expand Up @@ -48,11 +48,6 @@ impl fmt::Display for ErrorKind {
pub trait Context: Send + Sync {
/// I/O channel used by the thread.
type Io: IoSink + IoStream + Send + Unpin + 'static;
/// Queue type.
type Queue<'a, R>: Queue<Self, R> + 'a
where
R: Send + 'static,
Self: Sized + 'a;

/// Returns the thread ID.
fn id(&self) -> &ThreadId;
Expand All @@ -63,18 +58,6 @@ pub trait Context: Send + Sync {
/// Returns a mutable reference to the thread's I/O channel.
fn io_mut(&mut self) -> &mut Self::Io;

/// Returns a new task queue.
///
/// Implementations may not be able to fork, in which case the tasks may be executed
/// sequentially.
///
/// [`max_concurrency`](Context::max_concurrency) can help determine how to best divide work
/// between tasks.
async fn queue<'a, R>(&'a mut self) -> Result<Self::Queue<'a, R>, ContextError>
where
R: Send + 'static,
Self: Sized;

/// Forks the thread and executes the provided closures concurrently.
///
/// Implementations may not be able to fork, in which case the closures are executed
Expand Down
14 changes: 1 addition & 13 deletions crates/mpz-common/src/executor/dummy.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@ use async_trait::async_trait;
use scoped_futures::ScopedBoxFuture;
use serio::{Sink, Stream};

use crate::{context::Context, queue::SimpleQueue, ContextError, ThreadId};
use crate::{context::Context, ContextError, ThreadId};

/// A dummy executor.
#[derive(Debug, Default)]
Expand Down Expand Up @@ -61,10 +61,6 @@ impl Stream for DummyIo {
#[async_trait]
impl Context for DummyExecutor {
type Io = DummyIo;
type Queue<'a, R> = SimpleQueue<'a, Self, R>
where
R: Send + 'static,
Self: Sized + 'a;

fn id(&self) -> &ThreadId {
&self.id
Expand All @@ -78,14 +74,6 @@ impl Context for DummyExecutor {
&mut self.io
}

async fn queue<R>(&mut self) -> Result<Self::Queue<'_, R>, ContextError>
where
R: Send + 'static,
Self: Sized,
{
Ok(SimpleQueue::new(self))
}

async fn join<'a, A, B, RA, RB>(&'a mut self, a: A, b: B) -> Result<(RA, RB), ContextError>
where
A: for<'b> FnOnce(&'b mut Self) -> ScopedBoxFuture<'a, 'b, RA> + Send + 'a,
Expand Down
34 changes: 10 additions & 24 deletions crates/mpz-common/src/executor/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -10,18 +10,8 @@ pub use st::STExecutor;

#[cfg(any(test, feature = "test-utils"))]
mod test_utils {
use std::future::IntoFuture;

use futures::{Future, TryFutureExt};
use serio::{
channel::{duplex, MemoryDuplex},
codec::Bincode,
};
use uid_mux::{
test_utils::test_yamux_pair_framed,
yamux::{ConnectionError, YamuxCtrl},
FramedMux,
};
use serio::channel::{duplex, MemoryDuplex};
use uid_mux::test_utils::{test_framed_mux, TestFramedMux};

use super::*;

Expand All @@ -35,24 +25,20 @@ mod test_utils {
}

/// Test multi-threaded executor.
pub type TestMTExecutor = MTExecutor<FramedMux<YamuxCtrl, Bincode>>;
pub type TestMTExecutor = MTExecutor<TestFramedMux>;

/// Creates a pair of multi-threaded executors with yamux I/O channels.
pub fn test_mt_executor(
io_buffer: usize,
) -> (
(TestMTExecutor, TestMTExecutor),
impl Future<Output = Result<(), ConnectionError>>,
) {
let ((mux_0, fut_0), (mux_1, fut_1)) = test_yamux_pair_framed(io_buffer, Bincode);

let fut_io =
futures::future::try_join(fut_0.into_future(), fut_1.into_future()).map_ok(|_| ());
///
/// # Arguments
///
/// * `io_buffer` - The size of the I/O buffer (channel capacity).
pub fn test_mt_executor(io_buffer: usize) -> (TestMTExecutor, TestMTExecutor) {
let (mux_0, mux_1) = test_framed_mux(io_buffer);

let exec_0 = MTExecutor::new(mux_0, 8);
let exec_1 = MTExecutor::new(mux_1, 8);

((exec_0, exec_1), fut_io)
(exec_0, exec_1)
}
}

Expand Down
103 changes: 20 additions & 83 deletions crates/mpz-common/src/executor/mt.rs
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,6 @@ use uid_mux::FramedUidMux;

use crate::{
context::{ContextError, ErrorKind},
queue::RRQueue,
Context, ThreadId,
};

Expand Down Expand Up @@ -91,10 +90,6 @@ where
Io: IoDuplex + Send + Sync + Unpin + 'static,
{
type Io = Io;
type Queue<'a, R> = RRQueue<'a, Self, R>
where
R: Send + 'static,
Self: Sized + 'a;

fn id(&self) -> &ThreadId {
&self.id
Expand All @@ -111,23 +106,6 @@ where
&mut self.io
}

async fn queue<R>(&mut self) -> Result<Self::Queue<'_, R>, ContextError>
where
R: Send + 'static,
Self: Sized,
{
let children = self
.children
.as_mut()
.expect("children were not left uninitialized");

children
.alloc(&self.mux, children.max_concurrency())
.await?;

Ok(RRQueue::new(children.as_slice_mut()))
}

async fn join<'a, A, B, RA, RB>(&'a mut self, a: A, b: B) -> Result<(RA, RB), ContextError>
where
A: for<'b> FnOnce(&'b mut Self) -> ScopedBoxFuture<'a, 'b, RA> + Send + 'a,
Expand All @@ -141,11 +119,16 @@ where
.take()
.expect("children were not left uninitialized");

if children.len() < 1 {
children.alloc(&self.mux, 1).await?;
if children.len() < 2 {
th4s marked this conversation as resolved.
Show resolved Hide resolved
if let Err(e) = children.alloc(&self.mux, 2).await {
self.children = Some(children);
return Err(e);
}
}

let output = futures::join!(a(self), b(children.first_mut()));
let [child_a, child_b] = children.first_n_mut();

let output = futures::join!(a(child_a), b(child_b));

self.children = Some(children);

Expand All @@ -170,11 +153,16 @@ where
.take()
.expect("children were not left uninitialized");

if children.len() < 1 {
children.alloc(&self.mux, 1).await?;
if children.len() < 2 {
if let Err(e) = children.alloc(&self.mux, 2).await {
self.children = Some(children);
return Err(e);
}
}

let output = futures::try_join!(a(self), b(children.first_mut()));
let [child_a, child_b] = children.first_n_mut();

let output = futures::try_join!(a(child_a), b(child_b));

self.children = Some(children);

Expand Down Expand Up @@ -249,24 +237,16 @@ where
Ok(())
}

fn first_mut(&mut self) -> &mut MTContext<M, Io> {
fn first_n_mut<const N: usize>(&mut self) -> &mut [MTContext<M, Io>; N] {
self.slots
.first_mut()
.first_chunk_mut()
.expect("number of threads were checked")
}

fn as_slice_mut(&mut self) -> &mut [MTContext<M, Io>] {
&mut self.slots
}
}

#[cfg(test)]
mod tests {
use std::future::IntoFuture;

use crate::{join, queue::Queue};
use serio::{codec::Bincode, stream::IoStreamExt, SinkExt};
use uid_mux::test_utils::test_yamux_pair_framed;
use crate::{executor::test_mt_executor, join};

use super::*;

Expand Down Expand Up @@ -302,14 +282,7 @@ mod tests {

#[tokio::test]
async fn test_mt_executor_join() {
let ((mux_a, fut_a), (mux_b, fut_b)) = test_yamux_pair_framed(1024, Bincode);

tokio::spawn(async move {
futures::try_join!(fut_a.into_future(), fut_b.into_future()).unwrap();
});

let mut exec_a = MTExecutor::new(mux_a, 8);
let mut exec_b = MTExecutor::new(mux_b, 8);
let (mut exec_a, mut exec_b) = test_mt_executor(8);

let (mut ctx_a, mut ctx_b) =
futures::try_join!(exec_a.new_thread(), exec_b.new_thread()).unwrap();
Expand All @@ -319,40 +292,4 @@ mod tests {

futures::join!(test_a.foo(&mut ctx_a), test_b.foo(&mut ctx_b));
}

#[tokio::test]
async fn test_mt_executor_queue() {
let ((mux_a, fut_a), (mux_b, fut_b)) = test_yamux_pair_framed(1024, Bincode);

tokio::spawn(async move {
futures::try_join!(fut_a.into_future(), fut_b.into_future()).unwrap();
});

let mut exec_a = MTExecutor::new(mux_a, 8);
let mut exec_b = MTExecutor::new(mux_b, 8);

let (mut ctx_a, mut ctx_b) =
futures::try_join!(exec_a.new_thread(), exec_b.new_thread()).unwrap();

let mut queue_a = ctx_a.queue().await.unwrap();
let mut queue_b = ctx_b.queue().await.unwrap();

queue_a.push(|ctx| {
Box::pin(async {
ctx.io_mut().send(0u8).await.unwrap();
})
});
queue_b.push(|ctx| Box::pin(async { ctx.io_mut().expect_next::<u8>().await.unwrap() }));

queue_a.push(|ctx| {
Box::pin(async {
ctx.io_mut().send(1u8).await.unwrap();
})
});
queue_b.push(|ctx| Box::pin(async { ctx.io_mut().expect_next::<u8>().await.unwrap() }));

let (_, results_b) = futures::try_join!(queue_a.wait(), queue_b.wait()).unwrap();

assert_eq!(results_b, vec![0, 1]);
}
}
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can we add a test to make sure things run concurrently

    use std::time::{Duration, Instant};
    // Need to enable tokio time feature

    // Tests that the tasks run concurrently.
    #[tokio::test]
    async fn test_mt_executor_concurrent() {
        let ((mux_a, fut_a), (mux_b, fut_b)) = test_yamux_pair_framed(1024, Bincode);

        tokio::spawn(async move {
            futures::try_join!(fut_a.into_future(), fut_b.into_future()).unwrap();
        });

        let mut exec_a = MTExecutor::new(mux_a, 8);
        let mut exec_b = MTExecutor::new(mux_b, 8);

        let (mut ctx_a, mut ctx_b) =
            futures::try_join!(exec_a.new_thread(), exec_b.new_thread()).unwrap();

        let mut queue_a = ctx_a.queue().await.unwrap();
        let mut queue_b = ctx_b.queue().await.unwrap();

        let start = Instant::now();
        let timeout = Duration::from_millis(100);

        for i in 0..8 {
            queue_a.push(move |ctx| {
                let i = i;
                let timeout = timeout;
                Box::pin(async move {
                    tokio::time::sleep(timeout).await;
                    ctx.io_mut().send(i as u8).await.unwrap();
                })
            });
            queue_b.push(|ctx| Box::pin(async { ctx.io_mut().expect_next::<u8>().await.unwrap() }));
        }

        let (_, results_b) = futures::try_join!(queue_a.wait(), queue_b.wait()).unwrap();

        let elapsed = Instant::now().duration_since(start);

        // The overall latency should be approximately that of a single task.
        assert!(elapsed < timeout + Duration::from_millis(50));

        assert_eq!(results_b, vec![0, 1, 2, 3, 4, 5, 6, 7]);
    }

10 changes: 0 additions & 10 deletions crates/mpz-common/src/executor/st.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,6 @@ use serio::{IoSink, IoStream};

use crate::{
context::{Context, ContextError},
queue::SimpleQueue,
ThreadId,
};

Expand Down Expand Up @@ -37,7 +36,6 @@ where
Io: IoSink + IoStream + Send + Sync + Unpin + 'static,
{
type Io = Io;
type Queue<'a, R> = SimpleQueue<'a, Self, R> where R: Send + 'static, Self: Sized + 'a;

fn id(&self) -> &ThreadId {
&self.id
Expand All @@ -51,14 +49,6 @@ where
&mut self.io
}

async fn queue<R>(&mut self) -> Result<Self::Queue<'_, R>, ContextError>
where
R: Send + 'static,
Self: Sized,
{
Ok(SimpleQueue::new(self))
}

async fn join<'a, A, B, RA, RB>(&'a mut self, a: A, b: B) -> Result<(RA, RB), ContextError>
where
A: for<'b> FnOnce(&'b mut Self) -> ScopedBoxFuture<'a, 'b, RA> + Send + 'a,
Expand Down
1 change: 0 additions & 1 deletion crates/mpz-common/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,6 @@ pub mod executor;
mod id;
#[cfg(any(test, feature = "ideal"))]
pub mod ideal;
pub mod queue;
#[cfg(feature = "sync")]
pub mod sync;

Expand Down
32 changes: 0 additions & 32 deletions crates/mpz-common/src/queue.rs

This file was deleted.

Loading