Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Support RFC 5077 TLS session ticket reuse #166

Open
wants to merge 1 commit into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
6 changes: 4 additions & 2 deletions Cargo.toml
Expand Up @@ -16,17 +16,19 @@ vendored = ["openssl/vendored"]
alpn = ["security-framework/alpn"]

[target.'cfg(any(target_os = "macos", target_os = "ios"))'.dependencies]
security-framework = "2.0.0"
security-framework = { version = "2.0.0", features = ["session-tickets"] }
security-framework-sys = "2.0.0"
lazy_static = "1.4.0"
libc = "0.2"
tempfile = "3.1.0"

[target.'cfg(target_os = "windows")'.dependencies]
schannel = "0.1.16"
schannel = "0.1.18"

[target.'cfg(not(any(target_os = "windows", target_os = "macos", target_os = "ios")))'.dependencies]
linked_hash_set = "0.1"
log = "0.4.5"
once_cell = "1.0"
openssl = "0.10.29"
openssl-sys = "0.9.55"
openssl-probe = "0.1"
Expand Down
3 changes: 3 additions & 0 deletions build.rs
Expand Up @@ -7,6 +7,9 @@ fn main() {
if version >= 0x1_01_00_00_0 {
println!("cargo:rustc-cfg=have_min_max_version");
}
if version >= 0x1_01_01_00_0 {
println!("cargo:rustc-cfg=ossl111");
}
}

if let Ok(version) = env::var("DEP_OPENSSL_LIBRESSL_VERSION_NUMBER") {
Expand Down
207 changes: 204 additions & 3 deletions src/imp/openssl.rs
@@ -1,20 +1,28 @@
extern crate linked_hash_set;
extern crate once_cell;
extern crate openssl;
extern crate openssl_probe;

use self::linked_hash_set::LinkedHashSet;
use self::once_cell::sync::OnceCell;
use self::openssl::error::ErrorStack;
use self::openssl::ex_data::Index;
use self::openssl::hash::MessageDigest;
use self::openssl::nid::Nid;
use self::openssl::pkcs12::Pkcs12;
use self::openssl::pkey::PKey;
use self::openssl::ssl::{
self, MidHandshakeSslStream, SslAcceptor, SslConnector, SslContextBuilder, SslMethod,
SslVerifyMode,
self, MidHandshakeSslStream, Ssl, SslAcceptor, SslConnector, SslContextBuilder, SslMethod,
SslSession, SslSessionCacheMode, SslSessionRef, SslVerifyMode,
};
use self::openssl::x509::{store::X509StoreBuilder, X509VerifyResult, X509};
use std::borrow::Borrow;
use std::collections::hash_map::{Entry, HashMap};
use std::error;
use std::fmt;
use std::hash::{Hash, Hasher};
use std::io;
use std::sync::Once;
use std::sync::{Arc, Mutex, Once};

use self::openssl::pkey::Private;
use {Protocol, TlsAcceptorBuilder, TlsConnectorBuilder};
Expand Down Expand Up @@ -248,6 +256,8 @@ pub struct TlsConnector {
use_sni: bool,
accept_invalid_hostnames: bool,
accept_invalid_certs: bool,
session_tickets_enabled: bool,
session_cache: Arc<Mutex<SessionCache>>,
}

impl TlsConnector {
Expand Down Expand Up @@ -297,11 +307,37 @@ impl TlsConnector {
#[cfg(target_os = "android")]
load_android_root_certs(&mut connector)?;

let session_cache = Arc::new(Mutex::new(SessionCache::new()));
if builder.session_tickets_enabled {
connector.set_session_cache_mode(SslSessionCacheMode::CLIENT);

connector.set_new_session_callback({
let session_cache = session_cache.clone();
move |ssl, session| {
if let Some(key) = key_index().ok().and_then(|idx| ssl.ex_data(idx)) {
if let Ok(mut session_cache) = session_cache.lock() {
session_cache.insert(key.clone(), session);
}
}
}
});
connector.set_remove_session_callback({
let session_cache = session_cache.clone();
move |_, session| {
if let Ok(mut session_cache) = session_cache.lock() {
session_cache.remove(session);
}
}
});
}

Ok(TlsConnector {
connector: connector.build(),
use_sni: builder.use_sni,
accept_invalid_hostnames: builder.accept_invalid_hostnames,
accept_invalid_certs: builder.accept_invalid_certs,
session_tickets_enabled: builder.session_tickets_enabled,
session_cache,
})
}

Expand All @@ -317,6 +353,23 @@ impl TlsConnector {
if self.accept_invalid_certs {
ssl.set_verify(SslVerifyMode::NONE);
}
if self.session_tickets_enabled {
let key = SessionKey {
host: domain.to_string(),
};

if let Ok(mut session_cache) = self.session_cache.lock() {
if let Some(session) = session_cache.get(&key) {
// Note: the `unsafe`-ty here is because the `session` is required to come from the
// same SSL_CTX that the ssl object (`ssl`) is from, since it maintains internal
// pointers and refcounts. Here, we only have one SSL_CTX, so this is safe.
unsafe { ssl.set_session(&session)? };
}
}

let idx = key_index()?;
ssl.set_ex_data(idx, key);
}

let s = ssl.connect(domain, stream)?;
Ok(TlsStream(s))
Expand Down Expand Up @@ -452,3 +505,151 @@ impl<S: io::Read + io::Write> io::Write for TlsStream<S> {
self.0.flush()
}
}

fn key_index() -> Result<Index<Ssl, SessionKey>, ErrorStack> {
static IDX: OnceCell<Index<Ssl, SessionKey>> = OnceCell::new();
IDX.get_or_try_init(|| Ssl::new_ex_index()).map(|v| *v)
}

#[derive(Hash, PartialEq, Eq, Clone)]
pub struct SessionKey {
pub host: String,
}

#[derive(Clone)]
struct HashSession(SslSession);

impl PartialEq for HashSession {
fn eq(&self, other: &HashSession) -> bool {
self.0.id() == other.0.id()
}
}

impl Eq for HashSession {}

impl Hash for HashSession {
fn hash<H>(&self, state: &mut H)
where
H: Hasher,
{
self.0.id().hash(state);
}
}

impl Borrow<[u8]> for HashSession {
fn borrow(&self) -> &[u8] {
self.0.id()
}
}

pub struct SessionCache {
sessions: HashMap<SessionKey, LinkedHashSet<HashSession>>,
reverse: HashMap<HashSession, SessionKey>,
}

impl SessionCache {
pub fn new() -> SessionCache {
SessionCache {
sessions: HashMap::new(),
reverse: HashMap::new(),
}
}

pub fn insert(&mut self, key: SessionKey, session: SslSession) {
let session = HashSession(session);

self.sessions
.entry(key.clone())
.or_insert_with(LinkedHashSet::new)
.insert(session.clone());
self.reverse.insert(session.clone(), key);
}

pub fn get(&mut self, key: &SessionKey) -> Option<SslSession> {
let session = {
let sessions = self.sessions.get_mut(key)?;
sessions.front().cloned()?.0
};

#[cfg(ossl111)]
{
use self::openssl::ssl::SslVersion;

// https://tools.ietf.org/html/rfc8446#appendix-C.4
// OpenSSL will remove the session from its cache after the handshake completes anyway, but this ensures
// that concurrent handshakes don't end up with the same session.
if session.protocol_version() == SslVersion::TLS1_3 {
self.remove(&session);
}
}

Some(session)
}

pub fn remove(&mut self, session: &SslSessionRef) {
let key = match self.reverse.remove(session.id()) {
Some(key) => key,
None => return,
};

if let Entry::Occupied(mut sessions) = self.sessions.entry(key) {
sessions.get_mut().remove(session.id());
if sessions.get().is_empty() {
sessions.remove();
}
}
}
}

#[cfg(test)]
mod tests {
use std::io::{Read, Write};
use std::net::TcpStream;

use crate::TlsConnector;

fn connect_and_assert(tls: &TlsConnector, domain: &str, port: u16, should_resume: bool) {
let s = TcpStream::connect((domain, port)).unwrap();
let mut stream = tls.connect(domain, s).unwrap();

// Must write to the stream, as OpenSSL doesn't appear to call the
// session callback until we do.
stream.write_all(b"GET / HTTP/1.0\r\n\r\n").unwrap();
let mut result = vec![];
stream.read_to_end(&mut result).unwrap();

assert_eq!((stream.0).0.ssl().session_reused(), should_resume);

// Must shut down properly, or OpenSSL will invalidate the session.
stream.shutdown().unwrap();
}

#[test]
fn connect_no_session_ticket_resumption() {
let tls = TlsConnector::new().unwrap();
connect_and_assert(&tls, "google.com", 443, false);
connect_and_assert(&tls, "google.com", 443, false);
}

#[test]
fn connect_session_ticket_resumption() {
let mut builder = TlsConnector::builder();
builder.session_tickets_enabled(true);
let tls = builder.build().unwrap();

connect_and_assert(&tls, "google.com", 443, false);
connect_and_assert(&tls, "google.com", 443, true);
}

#[test]
fn connect_session_ticket_resumption_two_sites() {
let mut builder = TlsConnector::builder();
builder.session_tickets_enabled(true);
let tls = builder.build().unwrap();

connect_and_assert(&tls, "google.com", 443, false);
connect_and_assert(&tls, "mozilla.org", 443, false);
connect_and_assert(&tls, "google.com", 443, true);
connect_and_assert(&tls, "mozilla.org", 443, true);
}
}