Skip to content

Commit

Permalink
Benchmark bulk performance via unbuffered API
Browse files Browse the repository at this point in the history
  • Loading branch information
ctz committed Jun 21, 2024
1 parent 7424e40 commit 49fed2b
Showing 1 changed file with 280 additions and 4 deletions.
284 changes: 280 additions & 4 deletions rustls/examples/internal/bench_impl.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,15 +3,15 @@
// Note: we don't use any of the standard 'cargo bench', 'test::Bencher',
// etc. because it's unstable at the time of writing.

use std::fs;
use std::io::{self, Read, Write};
use std::ops::{Deref, DerefMut};
use std::sync::Arc;
use std::time::{Duration, Instant};
use std::{fs, mem};

use clap::{value_parser, Arg, Command};
use pki_types::{CertificateDer, PrivateKeyDer};
use rustls::client::Resumption;
use rustls::client::{Resumption, UnbufferedClientConnection};
#[cfg(all(not(feature = "ring"), feature = "aws_lc_rs"))]
use rustls::crypto::aws_lc_rs as provider;
#[cfg(all(not(feature = "ring"), feature = "aws_lc_rs"))]
Expand All @@ -21,7 +21,11 @@ use rustls::crypto::ring as provider;
#[cfg(feature = "ring")]
use rustls::crypto::ring::{cipher_suite, Ticketer};
use rustls::crypto::CryptoProvider;
use rustls::server::{NoServerSessionStorage, ServerSessionMemoryCache, WebPkiClientVerifier};
use rustls::server::{
NoServerSessionStorage, ServerSessionMemoryCache, UnbufferedServerConnection,
WebPkiClientVerifier,
};
use rustls::unbuffered::{ConnectionState, EncryptError, InsufficientSizeError, UnbufferedStatus};
use rustls::{
ClientConfig, ClientConnection, ConnectionCommon, RootCertStore, ServerConfig,
ServerConnection, SideData,
Expand Down Expand Up @@ -253,7 +257,21 @@ fn bench_bulk(

report_bulk_result(
"bulk",
bench_bulk_buffered(client_config, server_config, plaintext_size, rounds),
bench_bulk_buffered(
client_config.clone(),
server_config.clone(),
plaintext_size,
rounds,
),
plaintext_size,
rounds,
max_fragment_size,
params,
);

report_bulk_result(
"bulk-unbuffered",
bench_bulk_unbuffered(client_config, server_config, plaintext_size, rounds),
plaintext_size,
rounds,
max_fragment_size,
Expand Down Expand Up @@ -290,6 +308,40 @@ fn bench_bulk_buffered(
(time_send, time_recv)
}

fn bench_bulk_unbuffered(
client_config: Arc<ClientConfig>,
server_config: Arc<ServerConfig>,
plaintext_size: u64,
rounds: u64,
) -> (f64, f64) {
let server_name = "localhost".try_into().unwrap();
let mut client = Unbuffered::new_client(
UnbufferedClientConnection::new(client_config, server_name).unwrap(),
);
let mut server =
Unbuffered::new_server(UnbufferedServerConnection::new(server_config).unwrap());

client.handshake(&mut server);

let mut time_send = 0f64;
let mut time_recv = 0f64;

let buf = vec![0; plaintext_size as usize];
for _ in 0..rounds {
time_send += time(|| {
server.write(&buf);
});

server.swap_buffers(&mut client);

time_recv += time(|| {
client.read_and_discard(buf.len());
});
}

(time_send, time_recv)
}

fn report_bulk_result(
variant: &str,
(time_send, time_recv): (f64, f64),
Expand Down Expand Up @@ -588,6 +640,230 @@ impl KeyType {
}
}

struct Unbuffered {
conn: UnbufferedConnection,
input: Vec<u8>,
input_used: usize,
output: Vec<u8>,
output_used: usize,
}

impl Unbuffered {
fn new_client(client: UnbufferedClientConnection) -> Self {
Self {
conn: UnbufferedConnection::Client(client),
input: vec![0u8; 16384],
input_used: 0,
output: vec![0u8; 16384],
output_used: 0,
}
}

fn new_server(server: UnbufferedServerConnection) -> Self {
Self {
conn: UnbufferedConnection::Server(server),
input: vec![0u8; 16384],
input_used: 0,
output: vec![0u8; 16384],
output_used: 0,
}
}

fn swap_buffers(&mut self, peer: &mut Unbuffered) {
// our output becomes peer's input, and peer's input
// becomes our output.
mem::swap(&mut self.input, &mut peer.output);
mem::swap(&mut self.input_used, &mut peer.output_used);
mem::swap(&mut self.output, &mut peer.input);
mem::swap(&mut self.output_used, &mut peer.input_used);
}

fn handshake(&mut self, peer: &mut Unbuffered) {
loop {
let mut progress = false;

if self.communicate() {
self.swap_buffers(peer);
progress = true;
}

if peer.communicate() {
peer.swap_buffers(self);
progress = true;
}

if !progress {
return;
}
}
}

fn communicate(&mut self) -> bool {
let (input_used, output_added) = self.conn.communicate(
&mut self.input[..self.input_used],
&mut self.output[self.output_used..],
);
assert_eq!(input_used, self.input_used);
self.input_used = 0;
self.output_used += output_added;
self.output_used > 0
}

fn write(&mut self, data: &[u8]) {
assert_eq!(self.input_used, 0);
let output_added = match self
.conn
.write(data, &mut self.output[self.output_used..])
{
Ok(output_added) => output_added,
Err(EncryptError::InsufficientSize(InsufficientSizeError { required_size })) => {
self.output
.resize(self.output_used + required_size, 0);
self.conn
.write(data, &mut self.output[self.output_used..])
.unwrap()
}
Err(other) => panic!("unexpected write error {other:?}"),
};
self.output_used += output_added;
}

fn read_and_discard(&mut self, len: usize) {
assert!(self.input_used > 0);
let input_used = self
.conn
.read_and_discard(len, &mut self.input[..self.input_used]);
assert_eq!(input_used, self.input_used);
self.input_used = 0;
}
}

enum UnbufferedConnection {
Client(UnbufferedClientConnection),
Server(UnbufferedServerConnection),
}

impl UnbufferedConnection {
fn communicate(&mut self, input: &mut [u8], output: &mut [u8]) -> (usize, usize) {
let mut input_used = 0;
let mut output_added = 0;

loop {
match self {
Self::Client(client) => {
match client.process_tls_records(&mut input[input_used..]) {
UnbufferedStatus {
state: Ok(ConnectionState::EncodeTlsData(mut etd)),
discard,
} => {
input_used += discard;
output_added += etd
.encode(&mut output[output_added..])
.unwrap();
}
UnbufferedStatus {
state: Ok(ConnectionState::TransmitTlsData(ttd)),
discard,
} => {
input_used += discard;
ttd.done();
return (input_used, output_added);
}
UnbufferedStatus {
state: Ok(ConnectionState::WriteTraffic(_)),
discard,
} => {
input_used += discard;
return (input_used, output_added);
}
st => {
println!("unexpected client {st:?}");
return (input_used, output_added);
}
}
}
Self::Server(server) => {
match server.process_tls_records(&mut input[input_used..]) {
UnbufferedStatus {
state: Ok(ConnectionState::EncodeTlsData(mut etd)),
discard,
} => {
input_used += discard;
output_added += etd
.encode(&mut output[output_added..])
.unwrap();
}
UnbufferedStatus {
state: Ok(ConnectionState::TransmitTlsData(ttd)),
discard,
} => {
input_used += discard;
ttd.done();
return (input_used, output_added);
}
UnbufferedStatus {
state: Ok(ConnectionState::WriteTraffic(_)),
discard,
} => {
input_used += discard;
return (input_used, output_added);
}
st => {
println!("unexpected server {st:?}");
return (input_used, output_added);
}
}
}
}
}
}

fn write(&mut self, data: &[u8], output: &mut [u8]) -> Result<usize, EncryptError> {
match self {
Self::Client(client) => match client.process_tls_records(&mut []) {
UnbufferedStatus {
state: Ok(ConnectionState::WriteTraffic(mut wt)),
..
} => wt.encrypt(data, output),
st => panic!("unexpected write state: {st:?}"),
},
Self::Server(server) => match server.process_tls_records(&mut []) {
UnbufferedStatus {
state: Ok(ConnectionState::WriteTraffic(mut wt)),
..
} => wt.encrypt(data, output),
st => panic!("unexpected write state: {st:?}"),
},
}
}

fn read_and_discard(&mut self, mut expected: usize, input: &mut [u8]) -> usize {
let mut input_used = 0;

while expected > 0 {
match self {
Self::Client(client) => {
match client.process_tls_records(&mut input[input_used..]) {
UnbufferedStatus {
state: Ok(ConnectionState::ReadTraffic(mut rt)),
discard,
} => {
input_used += discard;
let record = rt.next_record().unwrap().unwrap();
input_used += record.discard;
expected -= record.payload.len();
}
st => panic!("unexpected read state: {st:?}"),
}
}
Self::Server(_) => panic!("server read"),
}
}

input_used
}
}

fn do_handshake_step(client: &mut ClientConnection, server: &mut ServerConnection) -> bool {
if server.is_handshaking() || client.is_handshaking() {
transfer(client, server, None);
Expand Down

0 comments on commit 49fed2b

Please sign in to comment.