diff --git a/c_src/fast_tls_drv.c b/c_src/fast_tls_drv.c index a746049..176411a 100644 --- a/c_src/fast_tls_drv.c +++ b/c_src/fast_tls_drv.c @@ -109,6 +109,7 @@ struct bucket { char *key; time_t key_mtime; time_t dh_mtime; + time_t ca_mtime; SSL_CTX *ssl_ctx; struct bucket *next; }; @@ -135,7 +136,7 @@ static void init_hash_table() } static void hash_table_insert(char *key, time_t key_mtime, time_t dh_mtime, - SSL_CTX *ssl_ctx) + time_t ca_mtime, SSL_CTX *ssl_ctx) { int level, split; uint32_t hash = str_hash(key); @@ -156,6 +157,7 @@ static void hash_table_insert(char *key, time_t key_mtime, time_t dh_mtime, if (el->hash == hash && strcmp(el->key, key) == 0) { el->key_mtime = key_mtime; el->dh_mtime = dh_mtime; + el->ca_mtime = ca_mtime; if (el->ssl_ctx != NULL) SSL_CTX_free(el->ssl_ctx); el->ssl_ctx = ssl_ctx; @@ -174,6 +176,7 @@ static void hash_table_insert(char *key, time_t key_mtime, time_t dh_mtime, strcpy(new_bucket_el->key, key); new_bucket_el->key_mtime = key_mtime; new_bucket_el->dh_mtime = dh_mtime; + new_bucket_el->ca_mtime = ca_mtime; new_bucket_el->ssl_ctx = ssl_ctx; new_bucket_el->next = ht.buckets[bucket]; ht.buckets[bucket] = new_bucket_el; @@ -211,7 +214,7 @@ static void hash_table_insert(char *key, time_t key_mtime, time_t dh_mtime, } static SSL_CTX *hash_table_lookup(char *key, time_t *key_mtime, - time_t *dh_mtime) + time_t *dh_mtime, time_t *ca_mtime) { int level, split; uint32_t hash = str_hash(key); @@ -230,6 +233,7 @@ static SSL_CTX *hash_table_lookup(char *key, time_t *key_mtime, if (el->hash == hash && strcmp(el->key, key) == 0) { *key_mtime = el->key_mtime; *dh_mtime = el->dh_mtime; + *ca_mtime = el->ca_mtime; return el->ssl_ctx; } el = el->next; @@ -526,6 +530,7 @@ static ErlDrvSSizeT tls_drv_control(ErlDrvData handle, case SET_CERTIFICATE_FILE_CONNECT: { time_t key_mtime = 0; time_t dh_mtime = 0; + time_t ca_mtime = 0; char *key_file = buf; size_t key_file_len = strlen(key_file); char *ciphers = key_file + key_file_len + 1; @@ -534,10 +539,13 @@ static ErlDrvSSizeT tls_drv_control(ErlDrvData handle, size_t protocol_options_len = strlen(protocol_options); char *dh_file = protocol_options + protocol_options_len + 1; size_t dh_file_len = strlen(dh_file); + char *ca_file = dh_file + dh_file_len + 1; + size_t ca_file_len = strlen(ca_file); char *hash_key = (char *)driver_alloc(key_file_len + ciphers_len + protocol_options_len + - dh_file_len + 1); + dh_file_len + + ca_file_len + 1); long options = 0L; if (protocol_options_len != 0) { @@ -553,20 +561,24 @@ static ErlDrvSSizeT tls_drv_control(ErlDrvData handle, free(popts); } - sprintf(hash_key, "%s%s%s%s", key_file, ciphers, protocol_options, - dh_file); - SSL_CTX *ssl_ctx = hash_table_lookup(hash_key, &key_mtime, &dh_mtime); + sprintf(hash_key, "%s%s%s%s%s", key_file, ciphers, protocol_options, + dh_file, ca_file); + SSL_CTX *ssl_ctx = hash_table_lookup(hash_key, &key_mtime, &dh_mtime, &ca_mtime); if (dh_file_len == 0) dh_file = NULL; + if (ca_file_len == 0) + ca_file = NULL; + if (is_modified(key_file, &key_mtime) || is_modified(dh_file, &dh_mtime) || + is_modified(ca_file, &ca_mtime) || ssl_ctx == NULL) { SSL_CTX *ctx; - hash_table_insert(hash_key, key_mtime, dh_mtime, NULL); + hash_table_insert(hash_key, key_mtime, dh_mtime, ca_mtime, NULL); ctx = SSL_CTX_new(SSLv23_method()); die_unless(ctx, "SSL_CTX_new failed"); @@ -593,7 +605,10 @@ static ErlDrvSSizeT tls_drv_control(ErlDrvData handle, #endif SSL_CTX_set_session_cache_mode(ctx, SSL_SESS_CACHE_OFF); - SSL_CTX_set_default_verify_paths(ctx); + if (ca_file) + SSL_CTX_load_verify_locations(ctx, ca_file, NULL); + else + SSL_CTX_set_default_verify_paths(ctx); #ifdef SSL_MODE_RELEASE_BUFFERS SSL_CTX_set_mode(ctx, SSL_MODE_RELEASE_BUFFERS); #endif @@ -611,7 +626,7 @@ static ErlDrvSSizeT tls_drv_control(ErlDrvData handle, SSL_CTX_set_info_callback(ctx, &ssl_info_callback); ssl_ctx = ctx; - hash_table_insert(hash_key, key_mtime, dh_mtime, ssl_ctx); + hash_table_insert(hash_key, key_mtime, dh_mtime, ca_mtime, ssl_ctx); } driver_free(hash_key); diff --git a/src/fast_tls.erl b/src/fast_tls.erl index 08a29ae..f24f55e 100644 --- a/src/fast_tls.erl +++ b/src/fast_tls.erl @@ -142,11 +142,17 @@ tcp_to_tls(TCPSocket, Options) -> false -> <<>> end, + CAFile = case lists:keysearch(cafile, 1, Options) of + {value, {cafile, CF}} -> + iolist_to_binary(CF); + false -> + <<>> + end, CertFile1 = iolist_to_binary(CertFile), case catch port_control(Port, Command bor Flags, <>) + 0, CAFile/binary, 0>>) of {'EXIT', {badarg, _}} -> {error, einval}; <<0>> ->