diff --git a/Cargo.toml b/Cargo.toml index 8dd004a8..6e689f29 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -14,6 +14,7 @@ libc = "0.2" rustls-native-certs = "0.5.0" sct = "0.6.0" rustls-pemfile = "0.2.0" +log = "0.4.14" [dev_dependencies] cbindgen = "*" diff --git a/Makefile b/Makefile index 5aa33c74..85241b6f 100644 --- a/Makefile +++ b/Makefile @@ -37,7 +37,7 @@ target/crustls-demo: target/main.o target/$(PROFILE)/libcrustls.a $(CC) -o $@ $^ $(LDFLAGS) target/$(PROFILE)/libcrustls.a: src/*.rs Cargo.toml - cargo build $(CARGOFLAGS) + RUSTFLAGS="-C metadata=rustls-ffi" cargo build $(CARGOFLAGS) target/main.o: src/main.c src/crustls.h | target $(CC) -o $@ -c $< $(CFLAGS) diff --git a/src/connection.rs b/src/connection.rs index 5909cac7..59a4db85 100644 --- a/src/connection.rs +++ b/src/connection.rs @@ -6,6 +6,7 @@ use rustls::{Certificate, ClientSession, ServerSession, Session, SupportedCipher use crate::io::{CallbackReader, CallbackWriter, ReadCallback, WriteCallback}; use crate::is_close_notify; +use crate::log::{ensure_log_registered, rustls_log_callback}; use crate::{ cipher::{rustls_certificate, rustls_supported_ciphersuite}, error::{map_error, rustls_io_result, rustls_result}, @@ -19,6 +20,7 @@ use rustls_result::NullParameter; pub(crate) struct Connection { conn: Inner, userdata: *mut c_void, + log_callback: rustls_log_callback, peer_certs: Option>, } @@ -32,6 +34,7 @@ impl Connection { Connection { conn: Inner::Client(s), userdata: null_mut(), + log_callback: None, peer_certs: None, } } @@ -40,6 +43,7 @@ impl Connection { Connection { conn: Inner::Server(s), userdata: null_mut(), + log_callback: None, peer_certs: None, } } @@ -114,6 +118,19 @@ pub extern "C" fn rustls_connection_set_userdata( conn.userdata = userdata; } +/// Set the logging callback for this connection. The log callback will be invoked +/// with the userdata parameter previously set by rustls_connection_set_userdata, or +/// NULL if no userdata was set. +#[no_mangle] +pub extern "C" fn rustls_connection_set_log_callback( + conn: *mut rustls_connection, + cb: rustls_log_callback, +) { + let conn: &mut Connection = try_mut_from_ptr!(conn); + ensure_log_registered(); + conn.log_callback = cb; +} + /// Read some TLS bytes from the network into internal buffers. The actual network /// I/O is performed by `callback`, which you provide. Rustls will invoke your /// callback with a suitable buffer to store the read bytes into. You don't have @@ -186,7 +203,7 @@ pub extern "C" fn rustls_connection_process_new_packets( ) -> rustls_result { ffi_panic_boundary! { let conn: &mut Connection = try_mut_from_ptr!(conn); - let guard = match userdata_push(conn.userdata) { + let guard = match userdata_push(conn.userdata, conn.log_callback) { Ok(g) => g, Err(_) => return rustls_result::Panic, }; diff --git a/src/crustls.h b/src/crustls.h index b5bc8f3a..2beebd17 100644 --- a/src/crustls.h +++ b/src/crustls.h @@ -330,6 +330,15 @@ typedef enum rustls_result (*rustls_session_store_get_callback)(rustls_session_s */ typedef enum rustls_result (*rustls_session_store_put_callback)(rustls_session_store_userdata userdata, const struct rustls_slice_bytes *key, const struct rustls_slice_bytes *val); +typedef size_t rustls_log_level; + +typedef struct rustls_log_params { + rustls_log_level level; + struct rustls_str message; +} rustls_log_params; + +typedef void (*rustls_log_callback)(void *userdata, const struct rustls_log_params *params); + /** * A return value for a function that may return either success (0) or a * non-zero value representing an error. @@ -724,6 +733,13 @@ enum rustls_result rustls_client_config_builder_set_persistence(struct rustls_cl */ void rustls_connection_set_userdata(struct rustls_connection *conn, void *userdata); +/** + * Set the logging callback for this connection. The log callback will be invoked + * with the userdata parameter previously set by rustls_connection_set_userdata, or + * NULL if no userdata was set. + */ +void rustls_connection_set_log_callback(struct rustls_connection *conn, rustls_log_callback cb); + /** * Read some TLS bytes from the network into internal buffers. The actual network * I/O is performed by `callback`, which you provide. Rustls will invoke your @@ -865,6 +881,11 @@ void rustls_error(enum rustls_result result, char *buf, size_t len, size_t *out_ bool rustls_result_is_cert_error(enum rustls_result result); +/** + * Return a rustls_str containing the stringified version of a log level. + */ +struct rustls_str rustls_log_level_str(rustls_log_level level); + /** * Return the length of the outer slice. If the input pointer is NULL, * returns 0. diff --git a/src/lib.rs b/src/lib.rs index 1055e83e..cfb1a933 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -14,11 +14,13 @@ mod connection; mod enums; mod error; mod io; +mod log; mod panic; mod rslice; mod server; mod session; +use crate::log::rustls_log_callback; use crate::panic::PanicOrDefault; // For C callbacks, we need to offer a `void *userdata` parameter, so the @@ -32,7 +34,12 @@ use crate::panic::PanicOrDefault; // Rust code, we model these thread locals as a stack, so we can always // restore the previous version. thread_local! { - pub static USERDATA: RefCell> = RefCell::new(Vec::new()); + pub static USERDATA: RefCell> = RefCell::new(Vec::new()); +} + +pub struct Userdata { + userdata: *mut c_void, + log_callback: rustls_log_callback, } /// UserdataGuard pops an entry off the USERDATA stack, restoring the @@ -46,12 +53,17 @@ pub struct UserdataGuard { // Keep a copy of the data we expect to be popping off the stack. This allows // us to check for consistency, and also serves to make this type !Send: // https://doc.rust-lang.org/nightly/std/primitive.pointer.html#impl-Send-1 - data: Option<*mut c_void>, + data: Option, } impl UserdataGuard { fn new(u: *mut c_void) -> Self { - UserdataGuard { data: Some(u) } + UserdataGuard { + data: Some(Userdata { + userdata: u, + log_callback: None, + }), + } } /// Even though we have a Drop impl on this guard, when possible it's @@ -63,7 +75,11 @@ impl UserdataGuard { } fn try_pop(&mut self) -> Result<(), UserdataError> { - let expected_data = self.data.ok_or(UserdataError::AlreadyPopped)?; + let expected_data = self + .data + .as_ref() + .ok_or(UserdataError::AlreadyPopped)? + .userdata; USERDATA .try_with(|userdata| { userdata.try_borrow_mut().map_or_else( @@ -71,7 +87,7 @@ impl UserdataGuard { |mut v| { let u = v.pop().ok_or(UserdataError::EmptyStack)?; self.data = None; - if u == expected_data { + if u.userdata == expected_data { Ok(()) } else { Err(UserdataError::WrongData) @@ -105,13 +121,19 @@ pub enum UserdataError { } #[must_use = "If you drop the guard, userdata will be immediately cleared"] -pub fn userdata_push(u: *mut c_void) -> Result { +pub fn userdata_push( + u: *mut c_void, + cb: rustls_log_callback, +) -> Result { USERDATA .try_with(|userdata| { userdata.try_borrow_mut().map_or_else( |_| Err(UserdataError::AlreadyBorrowed), |mut v| { - v.push(u); + v.push(Userdata { + userdata: u, + log_callback: cb, + }); Ok(()) }, ) @@ -126,7 +148,21 @@ pub fn userdata_get() -> Result<*mut c_void, UserdataError> { userdata.try_borrow_mut().map_or_else( |_| Err(UserdataError::AlreadyBorrowed), |v| match v.last() { - Some(u) => Ok(*u), + Some(u) => Ok(u.userdata), + None => Err(UserdataError::EmptyStack), + }, + ) + }) + .unwrap_or(Err(UserdataError::AccessError)) +} + +pub fn log_callback_get() -> Result<(rustls_log_callback, *mut c_void), UserdataError> { + USERDATA + .try_with(|userdata| { + userdata.try_borrow_mut().map_or_else( + |_| Err(UserdataError::AlreadyBorrowed), + |v| match v.last() { + Some(u) => Ok((u.log_callback, u.userdata)), None => Err(UserdataError::EmptyStack), }, ) @@ -143,7 +179,7 @@ mod tests { fn guard_try_pop() { let data = "hello"; let data_ptr: *mut c_void = data as *const _ as _; - let mut guard = userdata_push(data_ptr).unwrap(); + let mut guard = userdata_push(data_ptr, None).unwrap(); assert_eq!(userdata_get().unwrap(), data_ptr); guard.try_pop().unwrap(); assert!(matches!(guard.try_pop(), Err(_))); @@ -153,7 +189,7 @@ mod tests { fn guard_try_drop() { let data = "hello"; let data_ptr: *mut c_void = data as *const _ as _; - let guard = userdata_push(data_ptr).unwrap(); + let guard = userdata_push(data_ptr, None).unwrap(); assert_eq!(userdata_get().unwrap(), data_ptr); guard.try_drop().unwrap(); assert!(matches!(userdata_get(), Err(_))); @@ -164,7 +200,7 @@ mod tests { let data = "hello"; let data_ptr: *mut c_void = data as *const _ as _; { - let _guard = userdata_push(data_ptr).unwrap(); + let _guard = userdata_push(data_ptr, None).unwrap(); assert_eq!(userdata_get().unwrap(), data_ptr); } assert!(matches!(userdata_get(), Err(_))); @@ -175,12 +211,12 @@ mod tests { let hello = "hello"; let hello_ptr: *mut c_void = hello as *const _ as _; { - let guard = userdata_push(hello_ptr).unwrap(); + let guard = userdata_push(hello_ptr, None).unwrap(); assert_eq!(userdata_get().unwrap(), hello_ptr); { let yo = "yo"; let yo_ptr: *mut c_void = yo as *const _ as _; - let guard2 = userdata_push(yo_ptr).unwrap(); + let guard2 = userdata_push(yo_ptr, None).unwrap(); assert_eq!(userdata_get().unwrap(), yo_ptr); guard2.try_drop().unwrap(); } @@ -194,12 +230,12 @@ mod tests { fn out_of_order_drop() { let hello = "hello"; let hello_ptr: *mut c_void = hello as *const _ as _; - let guard = userdata_push(hello_ptr).unwrap(); + let guard = userdata_push(hello_ptr, None).unwrap(); assert_eq!(userdata_get().unwrap(), hello_ptr); let yo = "yo"; let yo_ptr: *mut c_void = yo as *const _ as _; - let guard2 = userdata_push(yo_ptr).unwrap(); + let guard2 = userdata_push(yo_ptr, None).unwrap(); assert_eq!(userdata_get().unwrap(), yo_ptr); assert!(matches!(guard.try_drop(), Err(UserdataError::WrongData))); @@ -210,19 +246,19 @@ mod tests { fn userdata_multi_threads() { let hello = "hello"; let hello_ptr: *mut c_void = hello as *const _ as _; - let guard = userdata_push(hello_ptr).unwrap(); + let guard = userdata_push(hello_ptr, None).unwrap(); assert_eq!(userdata_get().unwrap(), hello_ptr); let thread1 = thread::spawn(|| { let yo = "yo"; let yo_ptr: *mut c_void = yo as *const _ as _; - let guard2 = userdata_push(yo_ptr).unwrap(); + let guard2 = userdata_push(yo_ptr, None).unwrap(); assert_eq!(userdata_get().unwrap(), yo_ptr); let greetz = "greetz"; let greetz_ptr: *mut c_void = greetz as *const _ as _; - let guard3 = userdata_push(greetz_ptr).unwrap(); + let guard3 = userdata_push(greetz_ptr, None).unwrap(); assert_eq!(userdata_get().unwrap(), greetz_ptr); guard3.try_drop().unwrap(); diff --git a/src/log.rs b/src/log.rs new file mode 100644 index 00000000..2f91a93d --- /dev/null +++ b/src/log.rs @@ -0,0 +1,62 @@ +use std::convert::TryInto; + +use libc::c_void; +use log::Level; + +use crate::{log_callback_get, rslice::rustls_str}; + +struct Logger {} + +impl log::Log for Logger { + fn enabled(&self, _metadata: &log::Metadata<'_>) -> bool { + true + } + fn log(&self, record: &log::Record<'_>) { + if let Ok((Some(cb), userdata)) = log_callback_get() { + let message = format!("{} {}", record.target(), record.args()); + if let Ok(message) = message.as_str().try_into() { + unsafe { + cb( + userdata, + &rustls_log_params { + level: record.level() as rustls_log_level, + message, + }, + ); + } + } + } + } + fn flush(&self) {} +} + +pub(crate) fn ensure_log_registered() { + log::set_logger(&Logger {}).ok(); + log::set_max_level(log::LevelFilter::Debug) +} + +type rustls_log_level = usize; + +/// Return a rustls_str containing the stringified version of a log level. +#[no_mangle] +pub extern "C" fn rustls_log_level_str(level: rustls_log_level) -> rustls_str<'static> { + let s = match level { + 1 => Level::Error.as_str(), + 2 => Level::Warn.as_str(), + 3 => Level::Info.as_str(), + 4 => Level::Debug.as_str(), + 5 => Level::Trace.as_str(), + _ => "INVALID", + }; + rustls_str::from_str_unchecked(s) +} + +#[repr(C)] +pub struct rustls_log_params<'a> { + level: rustls_log_level, + message: rustls_str<'a>, +} + +#[allow(non_camel_case_types)] +pub type rustls_log_callback = + Option; diff --git a/src/main.c b/src/main.c index ec1cc55b..645d18ed 100644 --- a/src/main.c +++ b/src/main.c @@ -414,6 +414,15 @@ send_request_and_read_response(struct demo_conn *conn, return ret; } +void +log_cb(void *userdata, const struct rustls_log_params *params) +{ + struct demo_conn *conn = (struct demo_conn*)userdata; + struct rustls_str level_str = rustls_log_level_str(params->level); + fprintf(stderr, "rustls[fd %d][%.*s]: %.*s\n", conn->fd, + (int)level_str.len, level_str.data, (int)params->message.len, params->message.data); +} + int do_request(const struct rustls_client_config *client_config, const char *hostname, const char *path) @@ -442,6 +451,7 @@ do_request(const struct rustls_client_config *client_config, } rustls_connection_set_userdata(client_conn, conn); + rustls_connection_set_log_callback(client_conn, log_cb); ret = send_request_and_read_response(conn, client_conn, hostname, path); if(ret != RUSTLS_RESULT_OK) { diff --git a/src/rslice.rs b/src/rslice.rs index fc8f9ce6..04589e76 100644 --- a/src/rslice.rs +++ b/src/rslice.rs @@ -163,6 +163,16 @@ impl<'a> TryFrom<&'a str> for rustls_str<'a> { } } +impl<'a> rustls_str<'a> { + pub fn from_str_unchecked(s: &'static str) -> rustls_str<'static> { + rustls_str { + data: s.as_ptr() as *const _, + len: s.len(), + phantom: PhantomData, + } + } +} + #[test] fn test_rustls_str() { let s = "abcd";