Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

POC: Windows registered IO #918

Open
wants to merge 1 commit into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions bench/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@ rustls = { git = "https://github.com/ctz/rustls", rev = "fee894f7e030" }
tokio = { version = "0.2.13", features = ["rt-core"] }
tracing = "0.1.10"
tracing-subscriber = { version = "0.2.5", default-features = false, features = ["env-filter", "fmt", "ansi", "chrono"]}
winapi = { version = "0.3.9", features = ["impl-default", "ioapiset", "mswsock", "winnt", "synchapi", "mswsockdef", "processthreadsapi"] }

[[bin]]
name = "bulk"
Expand Down
82 changes: 51 additions & 31 deletions bench/src/bulk.rs
Original file line number Diff line number Diff line change
@@ -1,12 +1,22 @@
use std::net::{IpAddr, Ipv6Addr, SocketAddr};
use std::net::{IpAddr, Ipv4Addr, Ipv6Addr, SocketAddr};
use std::time::Instant;

use anyhow::{anyhow, Context, Result};
use futures::StreamExt;
use tokio::runtime::{Builder, Runtime};
use tracing::trace;
use winapi::um::winsock2;

const NR_ITERATIONS: usize = 100;
const NR_CHUNKS: usize = 2 * 1024;
const DATA_LEN: usize = 1 * 1024 * 1024;

fn main() {
let mut winsock_data = winsock2::WSADATA::default();
if unsafe { winsock2::WSAStartup(0x202, &mut winsock_data) } != 0 {
panic!("Error starting winsock");
}

tracing::subscriber::set_global_default(
tracing_subscriber::FmtSubscriber::builder()
.with_env_filter(tracing_subscriber::EnvFilter::from_default_env())
Expand All @@ -27,7 +37,7 @@ fn main() {
let mut runtime = rt();
let (endpoint, incoming) = runtime.enter(|| {
endpoint
.bind(&SocketAddr::new(IpAddr::V6(Ipv6Addr::LOCALHOST), 0))
.bind(&SocketAddr::new(IpAddr::V4(Ipv4Addr::LOCALHOST), 0))
.unwrap()
});
let server_addr = endpoint.local_addr().unwrap();
Expand All @@ -51,29 +61,31 @@ async fn server(mut incoming: quinn::Incoming) -> Result<()> {
let quinn::NewConnection {
mut uni_streams, ..
} = handshake.await.context("handshake failed")?;
let mut stream = uni_streams
.next()
.await
.ok_or_else(|| anyhow!("accepting stream failed"))??;
trace!("stream established");
let start = Instant::now();
let mut n = 0;
while let Some((data, offset)) = stream.read_unordered().await? {
n = n.max(offset + data.len() as u64);
for _ in 0..NR_ITERATIONS {
let mut stream = uni_streams
.next()
.await
.ok_or_else(|| anyhow!("accepting stream failed"))??;
trace!("stream established");
let start = Instant::now();
let mut n = 0;
while let Some((data, offset)) = stream.read_unordered().await? {
n = n.max(offset + data.len() as u64);
}
let dt = start.elapsed();
println!(
"recvd {} bytes in {:?} ({} MiB/s)",
n,
dt,
n as f32 / (dt.as_secs_f32() * 1024.0 * 1024.0)
);
}
let dt = start.elapsed();
println!(
"recvd {} bytes in {:?} ({} MiB/s)",
n,
dt,
n as f32 / (dt.as_secs_f32() * 1024.0 * 1024.0)
);
Ok(())
}

async fn client(server_addr: SocketAddr, server_cert: quinn::Certificate) -> Result<()> {
let (endpoint, _) = quinn::EndpointBuilder::default()
.bind(&SocketAddr::new(IpAddr::V6(Ipv6Addr::LOCALHOST), 0))
.bind(&SocketAddr::new(IpAddr::V4(Ipv4Addr::LOCALHOST), 0))
.unwrap();

let mut client_config = quinn::ClientConfigBuilder::default();
Expand All @@ -87,20 +99,28 @@ async fn client(server_addr: SocketAddr, server_cert: quinn::Certificate) -> Res
.context("unable to connect")?;
trace!("connected");

let mut stream = connection
.open_uni()
.await
.context("failed to open stream")?;
const DATA: &[u8] = &[0xAB; 1024 * 1024];
let start = Instant::now();
for _ in 0..1024 {
stream
.write_all(DATA)
for _ in 0..NR_ITERATIONS {
let mut stream = connection
.open_uni()
.await
.context("failed sending data")?;
.context("failed to open stream")?;
const DATA: &[u8] = &[0xAB; DATA_LEN];
let start = Instant::now();
for _ in 0..NR_CHUNKS {
stream
.write_all(DATA)
.await
.context("failed sending data")?;
}
stream.finish().await.context("failed finishing stream")?;
let dt = start.elapsed();
println!(
"sent {} bytes in {:?} ({} MiB/s)",
1024 * DATA.len(),
dt,
(NR_CHUNKS * DATA_LEN) as f32 / 1024.0 / 1024.0 / dt.as_secs_f32()
);
}
stream.finish().await.context("failed finishing stream")?;
println!("sent {} bytes in {:?}", 1024 * DATA.len(), start.elapsed());
Ok(())
}

Expand Down
1 change: 1 addition & 0 deletions quinn/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,7 @@ thiserror = "1.0.21"
tracing = "0.1.10"
tokio = { version = "0.2.6", features = ["rt-core", "io-driver", "time"] }
webpki = { version = "0.21", optional = true }
winapi = { version = "0.3.9", features = ["impl-default", "ioapiset", "mswsock", "winnt", "synchapi", "mswsockdef", "processthreadsapi"] }

[dev-dependencies]
anyhow = "1.0.22"
Expand Down
18 changes: 13 additions & 5 deletions quinn/benches/bench.rs
Original file line number Diff line number Diff line change
@@ -1,16 +1,17 @@
use std::{
net::{IpAddr, Ipv6Addr, SocketAddr, UdpSocket},
net::{IpAddr, Ipv4Addr, Ipv6Addr, SocketAddr, UdpSocket},
sync::Arc,
thread,
};
use winapi::um::winsock2;

use bencher::{benchmark_group, benchmark_main, Bencher};
use futures::StreamExt;
use tokio::runtime::{Builder, Runtime};
use tracing::error_span;
use tracing_futures::Instrument as _;

use quinn::{ClientConfigBuilder, Endpoint, ServerConfigBuilder};
use quinn::{create_wsa_socket, ClientConfigBuilder, Endpoint, ServerConfigBuilder};

benchmark_group!(benches, large_streams, small_streams);
benchmark_main!(benches);
Expand All @@ -21,7 +22,7 @@ fn large_streams(bench: &mut Bencher) {
let ctx = Context::new();
let (addr, thread) = ctx.spawn_server();
let (endpoint, client, mut runtime) = ctx.make_client(addr);
const DATA: &[u8] = &[0xAB; 128 * 1024];
const DATA: &[u8] = &[0xAB; 1024 * 1024];
bench.bytes = DATA.len() as u64;
bench.iter(|| {
runtime.block_on(async {
Expand Down Expand Up @@ -61,6 +62,11 @@ struct Context {

impl Context {
fn new() -> Self {
let mut winsock_data = winsock2::WSADATA::default();
if unsafe { winsock2::WSAStartup(0x202, &mut winsock_data) } != 0 {
panic!("Error starting winsock");
}

let cert = rcgen::generate_simple_self_signed(vec!["localhost".into()]).unwrap();
let key = quinn::PrivateKey::from_der(&cert.serialize_private_key_der()).unwrap();
let cert = quinn::Certificate::from_der(&cert.serialize_der().unwrap()).unwrap();
Expand All @@ -83,8 +89,10 @@ impl Context {
}

pub fn spawn_server(&self) -> (SocketAddr, thread::JoinHandle<()>) {
let sock = UdpSocket::bind(SocketAddr::new(IpAddr::V6(Ipv6Addr::LOCALHOST), 0)).unwrap();
let sock = create_wsa_socket(&SocketAddr::new(IpAddr::V4(Ipv4Addr::LOCALHOST), 0)).unwrap();
let addr = sock.local_addr().unwrap();
println!("Server address: {:?}", addr);

let config = self.server_config.clone();
let handle = thread::spawn(move || {
let mut endpoint = Endpoint::builder();
Expand Down Expand Up @@ -119,7 +127,7 @@ impl Context {
let mut runtime = rt();
let (endpoint, _) = runtime.enter(|| {
Endpoint::builder()
.bind(&SocketAddr::new(IpAddr::V6(Ipv6Addr::LOCALHOST), 0))
.bind(&SocketAddr::new(IpAddr::V4(Ipv4Addr::LOCALHOST), 0))
.unwrap()
});
let quinn::NewConnection { connection, .. } = runtime
Expand Down
72 changes: 63 additions & 9 deletions quinn/src/builders.rs
Original file line number Diff line number Diff line change
@@ -1,16 +1,62 @@
use std::{io, net::SocketAddr, sync::Arc};

use proto::generic::{ClientConfig, EndpointConfig, ServerConfig};
use std::os::windows::io::FromRawSocket;
use thiserror::Error;
use tracing::error;
use winapi::{
shared::{ws2def, ws2ipdef},
um::winsock2,
};

use crate::{
endpoint::{Endpoint, EndpointDriver, EndpointRef, Incoming},
udp::UdpSocket,
use crate::endpoint::{
into_c_addr, wsa_last_error, Endpoint, EndpointDriver, EndpointRef, Incoming,
};
#[cfg(feature = "rustls")]
use crate::{Certificate, CertificateChain, PrivateKey};

/// Creates a windows socket
pub fn create_wsa_socket(addr: &SocketAddr) -> Result<std::net::UdpSocket, std::io::Error> {
let raw_socket = unsafe {
winsock2::WSASocketA(
if addr.is_ipv4() {
ws2def::AF_INET
} else {
ws2def::AF_INET6
},
ws2def::SOCK_DGRAM,
ws2def::IPPROTO_UDP as i32,
std::ptr::null_mut(),
0,
winsock2::WSA_FLAG_REGISTERED_IO,
)
};
if raw_socket == winsock2::INVALID_SOCKET {
println!("Invalid socket");
return Err(wsa_last_error());
}

let sock_addr = into_c_addr(*addr);

if winsock2::SOCKET_ERROR
== unsafe {
winsock2::bind(
raw_socket,
&sock_addr as *const ws2ipdef::SOCKADDR_INET as *const _,
std::mem::size_of_val(&sock_addr) as i32,
)
}
{
eprintln!("Can not bind");
let error = Err(wsa_last_error());
let _ = unsafe { winsock2::closesocket(raw_socket) };
return error;
}

let rust_sock: std::net::UdpSocket = unsafe { FromRawSocket::from_raw_socket(raw_socket as _) };
Ok(rust_sock)
}

/// A helper for constructing an [`Endpoint`].
///
/// See [`ClientConfigBuilder`] for details on trust defaults.
Expand Down Expand Up @@ -51,7 +97,8 @@ where
/// addresses. Portable applications should bind an address that matches the family they wish to
/// communicate within.
pub fn bind(self, addr: &SocketAddr) -> Result<(Endpoint<S>, Incoming<S>), EndpointError> {
let socket = std::net::UdpSocket::bind(addr).map_err(EndpointError::Socket)?;
let socket = create_wsa_socket(addr).map_err(|e| EndpointError::Socket(e))?;

self.with_socket(socket)
}

Expand All @@ -64,18 +111,25 @@ where
socket: std::net::UdpSocket,
) -> Result<(Endpoint<S>, Incoming<S>), EndpointError> {
let addr = socket.local_addr().map_err(EndpointError::Socket)?;
let socket = UdpSocket::from_std(socket).map_err(EndpointError::Socket)?;
eprintln!("Local address is {:?}", addr);

// let socket = UdpSocket::from_std(socket).map_err(EndpointError::Socket)?;
let rc = EndpointRef::new(
tokio::runtime::Handle::current(),
socket,
proto::generic::Endpoint::new(Arc::new(self.config), self.server_config.map(Arc::new)),
addr.is_ipv6(),
);
let driver = EndpointDriver(rc.clone());
tokio::spawn(async {
if let Err(e) = driver.await {
error!("I/O error: {}", e);

let ev = rc.wakeup_event();

let mut driver = EndpointDriver(rc.clone());
std::thread::spawn(move || {
if let Err(e) = driver.run(ev) {
eprintln!("Endpoint error: {:?}", e);
}
});

Ok((
Endpoint {
inner: rc.clone(),
Expand Down
Loading