Skip to content

Commit

Permalink
fix everything
Browse files Browse the repository at this point in the history
  • Loading branch information
satyarohith committed Nov 12, 2021
1 parent c4fdc72 commit 1010f9f
Show file tree
Hide file tree
Showing 5 changed files with 100 additions and 76 deletions.
2 changes: 1 addition & 1 deletion Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@ features = [
]

[features]
default = ["__rustls"]
default = ["__rustls", "rustls-tls-webpki-roots"]

# Note: this doesn't enable the 'native-tls' feature, which adds specific
# functionality for it.
Expand Down
108 changes: 63 additions & 45 deletions src/async_impl/client.rs
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@ use hyper::client::ResponseFuture;
#[cfg(feature = "native-tls-crate")]
use native_tls_crate::TlsConnector;
use pin_project_lite::pin_project;
use rustls::OwnedTrustAnchor;
#[cfg(feature = "rustls-tls-native-roots")]
use rustls::RootCertStore;
use std::future::Future;
Expand Down Expand Up @@ -322,40 +323,81 @@ impl ClientBuilder {
TlsBackend::Rustls => {
use crate::tls::NoVerifier;

let mut tls = rustls::ClientConfig::new();
match config.http_version_pref {
HttpVersionPref::Http1 => {
tls.set_protocols(&["http/1.1".into()]);
}
HttpVersionPref::Http2 => {
tls.set_protocols(&["h2".into()]);
}
HttpVersionPref::All => {
tls.set_protocols(&["h2".into(), "http/1.1".into()]);
}
// Set root certificates.
let mut root_store = rustls::RootCertStore::empty();
for cert in config.root_certs {
cert.add_to_rustls(&mut root_store)?;
}
#[cfg(feature = "rustls-tls-webpki-roots")]
if config.tls_built_in_root_certs {
tls.root_store
.add_server_trust_anchors(&webpki_roots::TLS_SERVER_ROOTS);
let mut trust_anchors = Vec::with_capacity(webpki_roots::TLS_SERVER_ROOTS.0.len());
for cert in webpki_roots::TLS_SERVER_ROOTS.0 {
trust_anchors.push(
OwnedTrustAnchor::from_subject_spki_name_constraints(
cert.subject,
cert.spki,
cert.name_constraints
)
);
}
root_store.add_server_trust_anchors(trust_anchors.into_iter());
}
#[cfg(feature = "rustls-tls-native-roots")]
if config.tls_built_in_root_certs {
let roots_slice = NATIVE_ROOTS.as_ref().unwrap().roots.as_slice();
tls.root_store.roots.extend_from_slice(roots_slice);
for cert in rustls_native_certs::load_native_certs().unwrap() {
root_store.add(&rustls::Certificate(cert.0))
.map_err(|e| crate::error::builder(e))?
}
}

// Set supported TLS versions.
let mut versions = rustls::ALL_VERSIONS.to_vec();
if let Some(min_tls_version) = config.min_tls_version {
versions
.retain(|&supported_version| match tls::Version::from_rustls(supported_version.version) {
Some(version) => version >= min_tls_version,
// Assume it's so new we don't know about it, allow it
// (as of writing this is unreachable)
None => true,
});
}
if let Some(max_tls_version) = config.max_tls_version {
versions
.retain(|&supported_version| match tls::Version::from_rustls(supported_version.version) {
Some(version) => version <= max_tls_version,
None => false,
});
}

let config_builder = rustls::ClientConfig::builder()
.with_safe_default_cipher_suites()
.with_safe_default_kx_groups()
.with_protocol_versions(&versions)
.unwrap()
.with_root_certificates(root_store);

let mut tls = if let Some(id) = config.identity {
let (key, certs) = id.get_pem()?;
config_builder.with_single_cert(certs, key).unwrap()
} else {
config_builder.with_no_client_auth()
};

if !config.certs_verification {
tls.dangerous()
.set_certificate_verifier(Arc::new(NoVerifier));
}

for cert in config.root_certs {
cert.add_to_rustls(&mut tls)?;
}

if let Some(id) = config.identity {
id.add_to_rustls(&mut tls)?;
match config.http_version_pref {
HttpVersionPref::Http1 => {
tls.alpn_protocols = vec!["http/1.1".into()];
}
HttpVersionPref::Http2 => {
tls.alpn_protocols = vec!["h2".into()];
}
HttpVersionPref::All => {
tls.alpn_protocols = vec!["h2".into(), "http/1.1".into()];
}
}

// rustls does not support TLS versions <1.2 and this is unlikely to change.
Expand All @@ -368,24 +410,6 @@ impl ClientBuilder {
// For now we assume the default tls.versions matches the future ALL_VERSIONS and
// act based on that.

if let Some(min_tls_version) = config.min_tls_version {
tls.versions
.retain(|&version| match tls::Version::from_rustls(version) {
Some(version) => version >= min_tls_version,
// Assume it's so new we don't know about it, allow it
// (as of writing this is unreachable)
None => true,
});
}

if let Some(max_tls_version) = config.max_tls_version {
tls.versions
.retain(|&version| match tls::Version::from_rustls(version) {
Some(version) => version <= max_tls_version,
None => false,
});
}

Connector::new_rustls_tls(
http,
tls,
Expand Down Expand Up @@ -1848,12 +1872,6 @@ fn add_cookie_header(headers: &mut HeaderMap, cookie_store: &dyn cookie::CookieS
}
}

#[cfg(feature = "rustls-tls-native-roots")]
lazy_static! {
static ref NATIVE_ROOTS: std::io::Result<RootCertStore> =
rustls_native_certs::load_native_certs().map_err(|e| e.1);
}

#[cfg(test)]
mod tests {
#[tokio::test]
Expand Down
2 changes: 1 addition & 1 deletion src/connect.rs
Original file line number Diff line number Diff line change
Expand Up @@ -490,7 +490,7 @@ impl Connector {
let conn = http.call(proxy_dst).await?;
log::trace!("tunneling HTTPS over proxy");
let maybe_server_name = rustls::ServerName::try_from(host)
.map(|serve_name| serve_name.to_owned())
.map(|serve_name| serve_name)
.map_err(|_| "Invalid DNS Name");
let tunneled = tunnel(conn, host.to_string(), port, self.user_agent.clone(), auth).await?;
let serve_name = maybe_server_name?;
Expand Down
58 changes: 30 additions & 28 deletions src/tls.rs
Original file line number Diff line number Diff line change
Expand Up @@ -12,11 +12,9 @@
//! `ClientBuilder`.

#[cfg(feature = "__rustls")]
use rustls::client::{HandshakeSignatureValid, ServerCertVerified, ServerCertVerifier};
use rustls::{internal::msgs::handshake::DigitallySignedStruct, RootCertStore};
use rustls::client::HandshakeSignatureValid;
use rustls::{client::ServerCertVerifier, internal::msgs::handshake::DigitallySignedStruct};
use std::fmt;
#[cfg(feature = "__rustls")]
use tokio_rustls::webpki::DnsNameRef;

/// Represents a server X509 certificate.
#[derive(Clone)]
Expand Down Expand Up @@ -109,26 +107,24 @@ impl Certificate {
}

#[cfg(feature = "__rustls")]
pub(crate) fn add_to_rustls(self, tls: &mut rustls::ClientConfig) -> crate::Result<()> {

pub(crate) fn add_to_rustls(self, root_store: &mut rustls::RootCertStore) -> crate::Result<()> {
use std::io::Cursor;

match self.original {
Cert::Der(buf) => tls
.root_store
.add(&::rustls::Certificate(buf))
.map_err(|e| crate::error::builder(TLSError::WebPKIError(e)))?,
Cert::Pem(buf) => {
Cert::Der(buf) => {
root_store
.add(&rustls::Certificate(buf))
.map_err(|e| crate::error::builder(e))?
},
Cert::Pem(buf) => {
let mut pem = Cursor::new(buf);
let certs = rustls_pemfile::certs(&mut pem).map_err(|_| {
crate::error::builder(TLSError::General(String::from(
crate::error::builder(rustls::Error::General(String::from(
"No valid certificate was found",
)))
})?;
for c in certs {
tls.root_store
.add(&c)
.map_err(|e| crate::error::builder(TLSError::WebPKIError(e)))?;
root_store.add(&rustls::Certificate(c)).map_err(|e| crate::error::builder(e.to_string()))?;
}
}
}
Expand Down Expand Up @@ -209,14 +205,18 @@ impl Identity {

let (key, certs) = {
let mut pem = Cursor::new(buf);
let certs = rustls_pemfile::certs(&mut pem)
.map_err(|_| TLSError::General(String::from("No valid certificate was found")))
.map_err(crate::error::builder)?;
let mut certs = Vec::new();
for cert in rustls_pemfile::certs(&mut pem)
.map_err(|_| rustls::Error::General(String::from("No valid certificate was found")))
.map_err(crate::error::builder)? {
certs.push(rustls::Certificate(cert));
}
pem.set_position(0);
let mut sk = rustls_pemfile::pkcs8_private_keys(&mut pem)
let mut sks = Vec::new();
for sk in rustls_pemfile::pkcs8_private_keys(&mut pem)
.and_then(|pkcs8_keys| {
if pkcs8_keys.is_empty() {
Err(())
Err(std::io::Error::new(std::io::ErrorKind::NotFound, "No valid private key was found"))
} else {
Ok(pkcs8_keys)
}
Expand All @@ -225,12 +225,15 @@ impl Identity {
pem.set_position(0);
rustls_pemfile::rsa_private_keys(&mut pem)
})
.map_err(|_| TLSError::General(String::from("No valid private key was found")))
.map_err(crate::error::builder)?;
if let (Some(sk), false) = (sk.pop(), certs.is_empty()) {
.map_err(|_| rustls::Error::General(String::from("No valid private key was found")))
.map_err(crate::error::builder)? {
sks.push(rustls::PrivateKey(sk));
}

if let (Some(sk), false) = (sks.pop(), certs.is_empty()) {
(sk, certs)
} else {
return Err(crate::error::builder(TLSError::General(String::from(
return Err(crate::error::builder(rustls::Error::General(String::from(
"private key or certificate not found",
))));
}
Expand All @@ -241,6 +244,7 @@ impl Identity {
})
}


#[cfg(feature = "native-tls")]
pub(crate) fn add_to_native_tls(
self,
Expand All @@ -257,12 +261,10 @@ impl Identity {
}

#[cfg(feature = "__rustls")]
pub(crate) fn add_to_rustls(self, tls: &mut rustls::ClientConfig) -> crate::Result<()> {
pub(crate) fn get_pem(self) -> crate::Result<(rustls::PrivateKey, Vec<rustls::Certificate>)> {
match self.inner {
ClientCert::Pem { key, certs } => {
tls.set_single_client_cert(certs, key)
.map_err(|e| crate::error::builder(e))?;
Ok(())
Ok((key, certs))
}
#[cfg(feature = "native-tls")]
ClientCert::Pkcs12(..) => Err(crate::error::builder("incompatible TLS identity type")),
Expand Down
6 changes: 5 additions & 1 deletion tests/client.rs
Original file line number Diff line number Diff line change
Expand Up @@ -246,7 +246,11 @@ fn use_preconfigured_native_tls_default() {
fn use_preconfigured_rustls_default() {
extern crate rustls;

let tls = rustls::ClientConfig::new();
let root_store = rustls::RootCertStore::empty();
let tls = rustls::ClientConfig::builder()
.with_safe_defaults()
.with_root_certificates(root_store)
.with_no_client_auth();

reqwest::Client::builder()
.use_preconfigured_tls(tls)
Expand Down

0 comments on commit 1010f9f

Please sign in to comment.