/
tls.rs
147 lines (125 loc) · 4.98 KB
/
tls.rs
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
use anyhow::{anyhow, Ok, Result};
use openssl::ssl::Ssl;
use openssl::ssl::{SslAcceptor, SslConnector, SslFiletype, SslMethod};
use serde::{Deserialize, Serialize};
use std::path::Path;
use std::pin::Pin;
use std::sync::Arc;
use tokio::io::{AsyncRead, AsyncWrite};
use tokio::net::TcpStream;
use tokio_openssl::SslStream;
/// Configuration for [`TlsAcceptor`]: paths to the PEM files used on the server side of TLS.
///
/// All three paths are required; `TlsAcceptor::new` returns an error if any of them does not
/// exist on disk.
#[derive(Serialize, Deserialize, Debug, Clone)]
pub struct TlsAcceptorConfig {
    /// Path to the certificate authority in PEM format
    /// (loaded into the acceptor as its CA file)
    pub certificate_authority_path: String,
    /// Path to the certificate in PEM format
    /// (loaded as the certificate chain the server presents)
    pub certificate_path: String,
    /// Path to the private key in PEM format
    /// (must match the certificate at `certificate_path`)
    pub private_key_path: String,
}
/// Server-side TLS acceptor built from a [`TlsAcceptorConfig`].
///
/// Cloning is cheap: clones share the same underlying OpenSSL acceptor through an `Arc`.
#[derive(Clone)]
pub struct TlsAcceptor {
    acceptor: Arc<SslAcceptor>,
}
/// Checks that the file configured for `field_name` is present at `file_path`.
///
/// OpenSSL's own error for a missing file does not say which setting was wrong, so this is run
/// before a configured path is handed to OpenSSL.
///
/// # Errors
///
/// Returns an error naming both the field and the path when [`Path::exists`] reports `false`.
/// Note that `Path::exists` also reports `false` when the path's metadata cannot be read (for
/// example because of missing permissions), so that case is reported as "does not exist" too.
pub fn check_file_field(field_name: &str, file_path: &str) -> Result<()> {
    let present = Path::new(file_path).exists();
    if !present {
        return Err(anyhow!(
            "configured {field_name} does not exist {file_path}"
        ));
    }
    Ok(())
}
impl TlsAcceptor {
pub fn new(tls_config: TlsAcceptorConfig) -> Result<TlsAcceptor> {
// openssl's errors are really bad so we do our own checks so we can provide reasonable errors
check_file_field(
"certificate_authority_path",
&tls_config.certificate_authority_path,
)?;
check_file_field("private_key_path", &tls_config.private_key_path)?;
check_file_field("certificate_path", &tls_config.certificate_path)?;
let mut builder = SslAcceptor::mozilla_intermediate(SslMethod::tls())?;
builder.set_ca_file(tls_config.certificate_authority_path)?;
builder.set_private_key_file(tls_config.private_key_path, SslFiletype::PEM)?;
builder.set_certificate_chain_file(tls_config.certificate_path)?;
builder.check_private_key()?;
Ok(TlsAcceptor {
acceptor: Arc::new(builder.build()),
})
}
pub async fn accept(&self, tcp_stream: TcpStream) -> Result<SslStream<TcpStream>> {
let ssl = Ssl::new(self.acceptor.context())?;
let mut ssl_stream = SslStream::new(ssl, tcp_stream)?;
Pin::new(&mut ssl_stream)
.accept()
.await
.map_err(|e| anyhow!(e).context("Failed to accept TLS connection"))?;
Ok(ssl_stream)
}
}
/// Configuration for [`TlsConnector`]: paths to the PEM files used on the client side of TLS.
///
/// `certificate_authority_path` is required. `certificate_path` and `private_key_path` are
/// optional. When set, they are loaded into the connector as its own certificate chain and
/// key (presumably for client-certificate authentication — confirm against the server setup).
#[derive(Serialize, Deserialize, Debug, Clone)]
pub struct TlsConnectorConfig {
    /// Path to the certificate authority in PEM format
    /// (used to verify the certificate presented by the server)
    pub certificate_authority_path: String,
    /// Path to the certificate in PEM format
    /// (optional; loaded as the connector's own certificate chain)
    pub certificate_path: Option<String>,
    /// Path to the private key in PEM format
    /// (optional; loaded as the connector's own private key)
    pub private_key_path: Option<String>,
}
/// Client-side TLS connector built from a [`TlsConnectorConfig`].
///
/// Cloning is cheap: clones share the same underlying OpenSSL connector through an `Arc`.
#[derive(Clone, Debug)]
pub struct TlsConnector {
    connector: Arc<SslConnector>,
}
impl TlsConnector {
pub fn new(tls_config: TlsConnectorConfig) -> Result<TlsConnector> {
check_file_field(
"certificate_authority_path",
&tls_config.certificate_authority_path,
)?;
let mut builder = SslConnector::builder(SslMethod::tls())?;
builder.set_ca_file(tls_config.certificate_authority_path)?;
if let Some(private_key_path) = tls_config.private_key_path {
check_file_field("private_key_path", &private_key_path)?;
builder.set_private_key_file(private_key_path, SslFiletype::PEM)?;
}
if let Some(certificate_path) = tls_config.certificate_path {
check_file_field("certificate_path", &certificate_path)?;
builder.set_certificate_chain_file(certificate_path)?;
}
Ok(TlsConnector {
connector: Arc::new(builder.build()),
})
}
pub async fn connect_unverified_hostname(
&self,
tcp_stream: TcpStream,
) -> Result<SslStream<TcpStream>> {
let ssl = self
.connector
.configure()?
.verify_hostname(false)
.into_ssl("localhost")?;
let mut ssl_stream = SslStream::new(ssl, tcp_stream)?;
Pin::new(&mut ssl_stream)
.connect()
.await
.map_err(|e| anyhow!(e).context("Failed to establish TLS connection to destination"))?;
Ok(ssl_stream)
}
pub async fn connect(&self, tcp_stream: TcpStream) -> Result<SslStream<TcpStream>> {
let ssl = self.connector.configure()?.into_ssl("localhost")?;
let mut ssl_stream = SslStream::new(ssl, tcp_stream)?;
Pin::new(&mut ssl_stream)
.connect()
.await
.map_err(|e| anyhow!(e).context("Failed to establish TLS connection to destination"))?;
Ok(ssl_stream)
}
}
/// Combined `AsyncRead + AsyncWrite` bound, usable as a single trait object.
///
/// A trait object may name only one non-auto trait, plus auto traits such as `Send`, `Sync`
/// or `Unpin`. So `dyn AsyncRead + AsyncWrite` is rejected by the compiler, and
/// `dyn AsyncStream` is used instead wherever a stream must support both reading and writing.
pub trait AsyncStream: AsyncRead + AsyncWrite {}

// Implementing the supertraits does not implement this trait automatically. Each concrete
// stream type must opt in explicitly, even though it already implements both
// `AsyncRead` and `AsyncWrite`.
impl AsyncStream for TcpStream {}
impl AsyncStream for tokio_openssl::SslStream<TcpStream> {}