Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@ libc = "0.2"
rustls-native-certs = "0.5.0"
sct = "0.6.0"
rustls-pemfile = "0.2.0"
log = "0.4.14"

[dev_dependencies]
cbindgen = "*"
Expand Down
2 changes: 1 addition & 1 deletion Makefile
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@ target/crustls-demo: target/main.o target/$(PROFILE)/libcrustls.a
$(CC) -o $@ $^ $(LDFLAGS)

target/$(PROFILE)/libcrustls.a: src/*.rs Cargo.toml
cargo build $(CARGOFLAGS)
RUSTFLAGS="-C metadata=rustls-ffi" cargo build $(CARGOFLAGS)

target/main.o: src/main.c src/crustls.h | target
$(CC) -o $@ -c $< $(CFLAGS)
Expand Down
19 changes: 18 additions & 1 deletion src/connection.rs
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ use rustls::{Certificate, ClientSession, ServerSession, Session, SupportedCipher

use crate::io::{CallbackReader, CallbackWriter, ReadCallback, WriteCallback};
use crate::is_close_notify;
use crate::log::{ensure_log_registered, rustls_log_callback};
use crate::{
cipher::{rustls_certificate, rustls_supported_ciphersuite},
error::{map_error, rustls_io_result, rustls_result},
Expand All @@ -19,6 +20,7 @@ use rustls_result::NullParameter;
pub(crate) struct Connection {
conn: Inner,
userdata: *mut c_void,
log_callback: rustls_log_callback,
peer_certs: Option<Vec<Certificate>>,
}

Expand All @@ -32,6 +34,7 @@ impl Connection {
Connection {
conn: Inner::Client(s),
userdata: null_mut(),
log_callback: None,
peer_certs: None,
}
}
Expand All @@ -40,6 +43,7 @@ impl Connection {
Connection {
conn: Inner::Server(s),
userdata: null_mut(),
log_callback: None,
peer_certs: None,
}
}
Expand Down Expand Up @@ -114,6 +118,19 @@ pub extern "C" fn rustls_connection_set_userdata(
conn.userdata = userdata;
}

/// Set the logging callback for this connection. The log callback will be invoked
/// with the userdata parameter previously set by rustls_connection_set_userdata, or
/// NULL if no userdata was set.
#[no_mangle]
pub extern "C" fn rustls_connection_set_log_callback(
conn: *mut rustls_connection,
cb: rustls_log_callback,
) {
let conn: &mut Connection = try_mut_from_ptr!(conn);
ensure_log_registered();
conn.log_callback = cb;
}

/// Read some TLS bytes from the network into internal buffers. The actual network
/// I/O is performed by `callback`, which you provide. Rustls will invoke your
/// callback with a suitable buffer to store the read bytes into. You don't have
Expand Down Expand Up @@ -186,7 +203,7 @@ pub extern "C" fn rustls_connection_process_new_packets(
) -> rustls_result {
ffi_panic_boundary! {
let conn: &mut Connection = try_mut_from_ptr!(conn);
let guard = match userdata_push(conn.userdata) {
let guard = match userdata_push(conn.userdata, conn.log_callback) {
Ok(g) => g,
Err(_) => return rustls_result::Panic,
};
Expand Down
21 changes: 21 additions & 0 deletions src/crustls.h
Original file line number Diff line number Diff line change
Expand Up @@ -330,6 +330,15 @@ typedef enum rustls_result (*rustls_session_store_get_callback)(rustls_session_s
*/
typedef enum rustls_result (*rustls_session_store_put_callback)(rustls_session_store_userdata userdata, const struct rustls_slice_bytes *key, const struct rustls_slice_bytes *val);

typedef size_t rustls_log_level;

typedef struct rustls_log_params {
rustls_log_level level;
struct rustls_str message;
} rustls_log_params;

typedef void (*rustls_log_callback)(void *userdata, const struct rustls_log_params *params);
Copy link
Collaborator

@tgeoghegan tgeoghegan Jun 8, 2021

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Wrapping rustls_str in another struct seems like overkill to me. Is the idea here that it would be easier to add members to rustls_log_params than it would be to add arguments to rustls_log_callback? Also I think this merits documentation to explain which userdata is provided to the callback (since rustls_client_connection_set_log_callback doesn't take opaque userdata) (though discussion of doccomments is probably premature for a draft PR!)

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yeah, the idea is to add other arguments, like log level. This follows a pattern for our other callbacks where they take a pointer to a struct rather than a list of params. One could argue is that that pattern is useful for large lists of params, but with small lists (like level and message), it's better to just pass the params directly. The flip side is that it's more consistent for our callbacks to always take a struct.


/**
* A return value for a function that may return either success (0) or a
* non-zero value representing an error.
Expand Down Expand Up @@ -724,6 +733,13 @@ enum rustls_result rustls_client_config_builder_set_persistence(struct rustls_cl
*/
void rustls_connection_set_userdata(struct rustls_connection *conn, void *userdata);

/**
* Set the logging callback for this connection. The log callback will be invoked
* with the userdata parameter previously set by rustls_connection_set_userdata, or
* NULL if no userdata was set.
*/
void rustls_connection_set_log_callback(struct rustls_connection *conn, rustls_log_callback cb);

/**
* Read some TLS bytes from the network into internal buffers. The actual network
* I/O is performed by `callback`, which you provide. Rustls will invoke your
Expand Down Expand Up @@ -865,6 +881,11 @@ void rustls_error(enum rustls_result result, char *buf, size_t len, size_t *out_

bool rustls_result_is_cert_error(enum rustls_result result);

/**
* Return a rustls_str containing the stringified version of a log level.
*/
struct rustls_str rustls_log_level_str(rustls_log_level level);

/**
* Return the length of the outer slice. If the input pointer is NULL,
* returns 0.
Expand Down
72 changes: 54 additions & 18 deletions src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -14,11 +14,13 @@ mod connection;
mod enums;
mod error;
mod io;
mod log;
mod panic;
mod rslice;
mod server;
mod session;

use crate::log::rustls_log_callback;
use crate::panic::PanicOrDefault;

// For C callbacks, we need to offer a `void *userdata` parameter, so the
Expand All @@ -32,7 +34,12 @@ use crate::panic::PanicOrDefault;
// Rust code, we model these thread locals as a stack, so we can always
// restore the previous version.
thread_local! {
pub static USERDATA: RefCell<Vec<*mut c_void>> = RefCell::new(Vec::new());
pub static USERDATA: RefCell<Vec<Userdata>> = RefCell::new(Vec::new());
}

pub struct Userdata {
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

One might consider simply storing the rustls_connection here. It seems to contain the application context which is exactly what we want to push/pop with this.

If my Rust were better, I'd also knew if a Arc<dyn AppContextTrait> could be used here.

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I considered this, but this would require us to violate the borrowing rules for Connection in ways we don't already. Specifically in the case where we are in a call that takes a *mut rustls_connection and turns it into a &mut Connection, we would then have to put a *mut Connection in Userdata and use unsafe to dereference it. I think this winds up being sound due to the threading properties of rustls (everything happens on the thread you called on), but it still seems like unnecessary risk.

Of course we are taking the same risk (having multiple mutable pointers) with userdata: *mut c_void, but that is unavoidable risk, if we want a general-purpose callback.

userdata: *mut c_void,
log_callback: rustls_log_callback,
}

/// UserdataGuard pops an entry off the USERDATA stack, restoring the
Expand All @@ -46,12 +53,17 @@ pub struct UserdataGuard {
// Keep a copy of the data we expect to be popping off the stack. This allows
// us to check for consistency, and also serves to make this type !Send:
// https://doc.rust-lang.org/nightly/std/primitive.pointer.html#impl-Send-1
data: Option<*mut c_void>,
data: Option<Userdata>,
}

impl UserdataGuard {
fn new(u: *mut c_void) -> Self {
UserdataGuard { data: Some(u) }
UserdataGuard {
data: Some(Userdata {
userdata: u,
log_callback: None,
}),
}
}

/// Even though we have a Drop impl on this guard, when possible it's
Expand All @@ -63,15 +75,19 @@ impl UserdataGuard {
}

fn try_pop(&mut self) -> Result<(), UserdataError> {
let expected_data = self.data.ok_or(UserdataError::AlreadyPopped)?;
let expected_data = self
.data
.as_ref()
.ok_or(UserdataError::AlreadyPopped)?
.userdata;
USERDATA
.try_with(|userdata| {
userdata.try_borrow_mut().map_or_else(
|_| Err(UserdataError::AlreadyBorrowed),
|mut v| {
let u = v.pop().ok_or(UserdataError::EmptyStack)?;
self.data = None;
if u == expected_data {
if u.userdata == expected_data {
Ok(())
} else {
Err(UserdataError::WrongData)
Expand Down Expand Up @@ -105,13 +121,19 @@ pub enum UserdataError {
}

#[must_use = "If you drop the guard, userdata will be immediately cleared"]
pub fn userdata_push(u: *mut c_void) -> Result<UserdataGuard, UserdataError> {
pub fn userdata_push(
u: *mut c_void,
cb: rustls_log_callback,
) -> Result<UserdataGuard, UserdataError> {
USERDATA
.try_with(|userdata| {
userdata.try_borrow_mut().map_or_else(
|_| Err(UserdataError::AlreadyBorrowed),
|mut v| {
v.push(u);
v.push(Userdata {
userdata: u,
log_callback: cb,
});
Ok(())
},
)
Expand All @@ -126,7 +148,21 @@ pub fn userdata_get() -> Result<*mut c_void, UserdataError> {
userdata.try_borrow_mut().map_or_else(
|_| Err(UserdataError::AlreadyBorrowed),
|v| match v.last() {
Some(u) => Ok(*u),
Some(u) => Ok(u.userdata),
None => Err(UserdataError::EmptyStack),
},
)
})
.unwrap_or(Err(UserdataError::AccessError))
}

pub fn log_callback_get() -> Result<(rustls_log_callback, *mut c_void), UserdataError> {
USERDATA
.try_with(|userdata| {
userdata.try_borrow_mut().map_or_else(
|_| Err(UserdataError::AlreadyBorrowed),
|v| match v.last() {
Some(u) => Ok((u.log_callback, u.userdata)),
None => Err(UserdataError::EmptyStack),
},
)
Expand All @@ -143,7 +179,7 @@ mod tests {
fn guard_try_pop() {
let data = "hello";
let data_ptr: *mut c_void = data as *const _ as _;
let mut guard = userdata_push(data_ptr).unwrap();
let mut guard = userdata_push(data_ptr, None).unwrap();
assert_eq!(userdata_get().unwrap(), data_ptr);
guard.try_pop().unwrap();
assert!(matches!(guard.try_pop(), Err(_)));
Expand All @@ -153,7 +189,7 @@ mod tests {
fn guard_try_drop() {
let data = "hello";
let data_ptr: *mut c_void = data as *const _ as _;
let guard = userdata_push(data_ptr).unwrap();
let guard = userdata_push(data_ptr, None).unwrap();
assert_eq!(userdata_get().unwrap(), data_ptr);
guard.try_drop().unwrap();
assert!(matches!(userdata_get(), Err(_)));
Expand All @@ -164,7 +200,7 @@ mod tests {
let data = "hello";
let data_ptr: *mut c_void = data as *const _ as _;
{
let _guard = userdata_push(data_ptr).unwrap();
let _guard = userdata_push(data_ptr, None).unwrap();
assert_eq!(userdata_get().unwrap(), data_ptr);
}
assert!(matches!(userdata_get(), Err(_)));
Expand All @@ -175,12 +211,12 @@ mod tests {
let hello = "hello";
let hello_ptr: *mut c_void = hello as *const _ as _;
{
let guard = userdata_push(hello_ptr).unwrap();
let guard = userdata_push(hello_ptr, None).unwrap();
assert_eq!(userdata_get().unwrap(), hello_ptr);
{
let yo = "yo";
let yo_ptr: *mut c_void = yo as *const _ as _;
let guard2 = userdata_push(yo_ptr).unwrap();
let guard2 = userdata_push(yo_ptr, None).unwrap();
assert_eq!(userdata_get().unwrap(), yo_ptr);
guard2.try_drop().unwrap();
}
Expand All @@ -194,12 +230,12 @@ mod tests {
fn out_of_order_drop() {
let hello = "hello";
let hello_ptr: *mut c_void = hello as *const _ as _;
let guard = userdata_push(hello_ptr).unwrap();
let guard = userdata_push(hello_ptr, None).unwrap();
assert_eq!(userdata_get().unwrap(), hello_ptr);

let yo = "yo";
let yo_ptr: *mut c_void = yo as *const _ as _;
let guard2 = userdata_push(yo_ptr).unwrap();
let guard2 = userdata_push(yo_ptr, None).unwrap();
assert_eq!(userdata_get().unwrap(), yo_ptr);

assert!(matches!(guard.try_drop(), Err(UserdataError::WrongData)));
Expand All @@ -210,19 +246,19 @@ mod tests {
fn userdata_multi_threads() {
let hello = "hello";
let hello_ptr: *mut c_void = hello as *const _ as _;
let guard = userdata_push(hello_ptr).unwrap();
let guard = userdata_push(hello_ptr, None).unwrap();
assert_eq!(userdata_get().unwrap(), hello_ptr);

let thread1 = thread::spawn(|| {
let yo = "yo";
let yo_ptr: *mut c_void = yo as *const _ as _;
let guard2 = userdata_push(yo_ptr).unwrap();
let guard2 = userdata_push(yo_ptr, None).unwrap();
assert_eq!(userdata_get().unwrap(), yo_ptr);

let greetz = "greetz";
let greetz_ptr: *mut c_void = greetz as *const _ as _;

let guard3 = userdata_push(greetz_ptr).unwrap();
let guard3 = userdata_push(greetz_ptr, None).unwrap();

assert_eq!(userdata_get().unwrap(), greetz_ptr);
guard3.try_drop().unwrap();
Expand Down
62 changes: 62 additions & 0 deletions src/log.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,62 @@
use std::convert::TryInto;

use libc::c_void;
use log::Level;

use crate::{log_callback_get, rslice::rustls_str};

struct Logger {}

impl log::Log for Logger {
fn enabled(&self, _metadata: &log::Metadata<'_>) -> bool {
true
}
fn log(&self, record: &log::Record<'_>) {
if let Ok((Some(cb), userdata)) = log_callback_get() {
let message = format!("{} {}", record.target(), record.args());
if let Ok(message) = message.as_str().try_into() {
unsafe {
cb(
userdata,
&rustls_log_params {
level: record.level() as rustls_log_level,
message,
},
);
}
}
}
}
fn flush(&self) {}
}

pub(crate) fn ensure_log_registered() {
log::set_logger(&Logger {}).ok();
log::set_max_level(log::LevelFilter::Debug)
}

type rustls_log_level = usize;

/// Return a rustls_str containing the stringified version of a log level.
#[no_mangle]
pub extern "C" fn rustls_log_level_str(level: rustls_log_level) -> rustls_str<'static> {
let s = match level {
1 => Level::Error.as_str(),
2 => Level::Warn.as_str(),
3 => Level::Info.as_str(),
4 => Level::Debug.as_str(),
5 => Level::Trace.as_str(),
_ => "INVALID",
};
rustls_str::from_str_unchecked(s)
}

#[repr(C)]
pub struct rustls_log_params<'a> {
level: rustls_log_level,
message: rustls_str<'a>,
}

#[allow(non_camel_case_types)]
pub type rustls_log_callback =
Option<unsafe extern "C" fn(userdata: *mut c_void, params: *const rustls_log_params)>;
Loading