diff --git a/include/net/net_context.h b/include/net/net_context.h index 00a2b23e514ec..08ed3033c0d40 100644 --- a/include/net/net_context.h +++ b/include/net/net_context.h @@ -184,6 +184,8 @@ struct net_tcp; struct net_conn_handle; +struct tls_context; + /** * Note that we do not store the actual source IP address in the context * because the address is already be set in the network interface struct. @@ -275,6 +277,11 @@ struct net_context { struct k_fifo recv_q; struct k_fifo accept_q; }; + +#if defined(CONFIG_NET_SOCKETS_SOCKOPT_TLS) + /** TLS context information */ + struct tls_context *tls; +#endif /* CONFIG_NET_SOCKETS_SOCKOPT_TLS */ #endif /* CONFIG_NET_SOCKETS */ }; diff --git a/include/net/net_ip.h b/include/net/net_ip.h index cc972b3f5251b..98bf46314477b 100644 --- a/include/net/net_ip.h +++ b/include/net/net_ip.h @@ -54,6 +54,13 @@ enum net_ip_protocol { IPPROTO_ICMPV6 = 58, }; +/* Protocol numbers for TLS protocols */ +enum net_ip_protocol_secure { + IPPROTO_TLS_1_0 = 256, + IPPROTO_TLS_1_1 = 257, + IPPROTO_TLS_1_2 = 258, +}; + /** Socket type */ enum net_sock_type { SOCK_STREAM = 1, diff --git a/include/net/socket.h b/include/net/socket.h index f7890ab1cc0e6..55d4c88184a6a 100644 --- a/include/net/socket.h +++ b/include/net/socket.h @@ -76,67 +76,134 @@ int zsock_getaddrinfo(const char *host, const char *service, const struct zsock_addrinfo *hints, struct zsock_addrinfo **res); +#if defined(CONFIG_NET_SOCKETS_SOCKOPT_TLS) + +int ztls_socket(int family, int type, int proto); +int ztls_close(int sock); +int ztls_bind(int sock, const struct sockaddr *addr, socklen_t addrlen); +int ztls_connect(int sock, const struct sockaddr *addr, socklen_t addrlen); +int ztls_listen(int sock, int backlog); +int ztls_accept(int sock, struct sockaddr *addr, socklen_t *addrlen); +ssize_t ztls_send(int sock, const void *buf, size_t len, int flags); +ssize_t ztls_recv(int sock, void *buf, size_t max_len, int flags); +ssize_t ztls_sendto(int sock, const void *buf, size_t len, int flags, + const struct sockaddr *dest_addr, socklen_t addrlen); +ssize_t ztls_recvfrom(int sock, void *buf, size_t max_len, int flags, + struct sockaddr *src_addr, socklen_t *addrlen); +int ztls_fcntl(int sock, int cmd, int flags); +int ztls_poll(struct zsock_pollfd *fds, int nfds, int timeout); + +#endif /* defined(CONFIG_NET_SOCKETS_SOCKOPT_TLS) */ + #if defined(CONFIG_NET_SOCKETS_POSIX_NAMES) static inline int socket(int family, int type, int proto) { +#if defined(CONFIG_NET_SOCKETS_SOCKOPT_TLS) + return ztls_socket(family, type, proto); +#else return zsock_socket(family, type, proto); +#endif /* defined(CONFIG_NET_SOCKETS_SOCKOPT_TLS) */ } static inline int close(int sock) { +#if defined(CONFIG_NET_SOCKETS_SOCKOPT_TLS) + return ztls_close(sock); +#else return zsock_close(sock); +#endif /* defined(CONFIG_NET_SOCKETS_SOCKOPT_TLS) */ } static inline int bind(int sock, const struct sockaddr *addr, socklen_t addrlen) { +#if defined(CONFIG_NET_SOCKETS_SOCKOPT_TLS) + return ztls_bind(sock, addr, addrlen); +#else return zsock_bind(sock, addr, addrlen); +#endif /* defined(CONFIG_NET_SOCKETS_SOCKOPT_TLS) */ } static inline int connect(int sock, const struct sockaddr *addr, socklen_t addrlen) { +#if defined(CONFIG_NET_SOCKETS_SOCKOPT_TLS) + return ztls_connect(sock, addr, addrlen); +#else return zsock_connect(sock, addr, addrlen); +#endif /* defined(CONFIG_NET_SOCKETS_SOCKOPT_TLS) */ } static inline int listen(int sock, int backlog) { +#if defined(CONFIG_NET_SOCKETS_SOCKOPT_TLS) + return ztls_listen(sock, backlog); +#else return zsock_listen(sock, backlog); +#endif /* defined(CONFIG_NET_SOCKETS_SOCKOPT_TLS) */ } static inline int accept(int sock, struct sockaddr *addr, socklen_t *addrlen) { +#if defined(CONFIG_NET_SOCKETS_SOCKOPT_TLS) + return ztls_accept(sock, addr, addrlen); +#else return zsock_accept(sock, addr, addrlen); +#endif /* defined(CONFIG_NET_SOCKETS_SOCKOPT_TLS) */ } static inline ssize_t send(int sock, const void *buf, size_t len, int flags) { +#if defined(CONFIG_NET_SOCKETS_SOCKOPT_TLS) + return ztls_send(sock, buf, len, flags); +#else return zsock_send(sock, buf, len, flags); +#endif /* defined(CONFIG_NET_SOCKETS_SOCKOPT_TLS) */ } static inline ssize_t recv(int sock, void *buf, size_t max_len, int flags) { +#if defined(CONFIG_NET_SOCKETS_SOCKOPT_TLS) + return ztls_recv(sock, buf, max_len, flags); +#else return zsock_recv(sock, buf, max_len, flags); +#endif /* defined(CONFIG_NET_SOCKETS_SOCKOPT_TLS) */ } /* This conflicts with fcntl.h, so code must include fcntl.h before socket.h: */ +#if defined(CONFIG_NET_SOCKETS_SOCKOPT_TLS) +#define fcntl ztls_fcntl +#else #define fcntl zsock_fcntl +#endif /* defined(CONFIG_NET_SOCKETS_SOCKOPT_TLS) */ static inline ssize_t sendto(int sock, const void *buf, size_t len, int flags, const struct sockaddr *dest_addr, socklen_t addrlen) { +#if defined(CONFIG_NET_SOCKETS_SOCKOPT_TLS) + return ztls_sendto(sock, buf, len, flags, dest_addr, addrlen); +#else return zsock_sendto(sock, buf, len, flags, dest_addr, addrlen); +#endif /* defined(CONFIG_NET_SOCKETS_SOCKOPT_TLS) */ } static inline ssize_t recvfrom(int sock, void *buf, size_t max_len, int flags, struct sockaddr *src_addr, socklen_t *addrlen) { +#if defined(CONFIG_NET_SOCKETS_SOCKOPT_TLS) + return ztls_recvfrom(sock, buf, max_len, flags, src_addr, addrlen); +#else return zsock_recvfrom(sock, buf, max_len, flags, src_addr, addrlen); +#endif /* defined(CONFIG_NET_SOCKETS_SOCKOPT_TLS) */ } static inline int poll(struct zsock_pollfd *fds, int nfds, int timeout) { +#if defined(CONFIG_NET_SOCKETS_SOCKOPT_TLS) + return ztls_poll(fds, nfds, timeout); +#else return zsock_poll(fds, nfds, timeout); +#endif /* defined(CONFIG_NET_SOCKETS_SOCKOPT_TLS) */ } #define pollfd zsock_pollfd diff --git a/samples/net/sockets/big_http_download/prj_tls.conf b/samples/net/sockets/big_http_download/prj_tls.conf new file mode 100644 index 0000000000000..a6b0afb9da471 --- /dev/null +++ b/samples/net/sockets/big_http_download/prj_tls.conf @@ -0,0 +1,46 @@ +# General config +CONFIG_NEWLIB_LIBC=y + +# Networking config +CONFIG_NETWORKING=y +CONFIG_NET_IPV4=y +CONFIG_NET_IPV6=y +CONFIG_NET_TCP=y +CONFIG_NET_SOCKETS=y +CONFIG_NET_SOCKETS_POSIX_NAMES=y + +CONFIG_NET_PKT_TX_COUNT=10 + +CONFIG_DNS_RESOLVER=y +CONFIG_DNS_SERVER_IP_ADDRESSES=y +CONFIG_DNS_SERVER1="192.0.2.2" + +# Network driver config +CONFIG_TEST_RANDOM_GENERATOR=y + +# Network address config +CONFIG_NET_APP_SETTINGS=y +CONFIG_NET_APP_NEED_IPV4=y +CONFIG_NET_APP_MY_IPV4_ADDR="192.0.2.1" +CONFIG_NET_APP_PEER_IPV4_ADDR="192.0.2.2" +CONFIG_NET_APP_MY_IPV4_GW="192.0.2.2" +# DHCP configuration. Until DHCP address is assigned, +# static configuration above is used instead. +CONFIG_NET_DHCPV4=y + +# Network debug config +CONFIG_NET_LOG=y +CONFIG_NET_LOG_GLOBAL=y +CONFIG_SYS_LOG_NET_LEVEL=2 +#CONFIG_NET_DEBUG_SOCKETS=y +#CONFIG_NET_SHELL=y + +# TLS configuration +CONFIG_MBEDTLS=y +CONFIG_MBEDTLS_BUILTIN=y +CONFIG_MBEDTLS_ENABLE_HEAP=y +CONFIG_MBEDTLS_HEAP_SIZE=60000 +CONFIG_MBEDTLS_SSL_MAX_CONTENT_LEN=16384 + +CONFIG_MAIN_STACK_SIZE=4096 +CONFIG_NET_SOCKETS_SOCKOPT_TLS=y diff --git a/samples/net/sockets/big_http_download/src/big_http_download.c b/samples/net/sockets/big_http_download/src/big_http_download.c index ccd4a9da4d567..2ba3ff3985ab3 100644 --- a/samples/net/sockets/big_http_download/src/big_http_download.c +++ b/samples/net/sockets/big_http_download/src/big_http_download.c @@ -4,6 +4,7 @@ * SPDX-License-Identifier: Apache-2.0 */ +#include #include #include #include @@ -23,24 +24,34 @@ #include #include #include + #define sleep(x) k_sleep(x * 1000) #endif /* This URL is parsed in-place, so buffer must be non-const. */ static char download_url[] = +#if !defined(CONFIG_NET_SOCKETS_SOCKOPT_TLS) "http://archive.ubuntu.com/ubuntu/dists/xenial/main/installer-amd64/current/images/hd-media/vmlinuz"; +#else + "https://www.7-zip.org/a/7z1805.exe"; +#endif /* !defined(CONFIG_NET_SOCKETS_SOCKOPT_TLS) */ /* Quick testing. */ /* "http://google.com/foo";*/ /* print("".join(["\\x%02x" % x for x in list(binascii.unhexlify("hash"))])) */ -static uint8_t download_hash[32] = "\x33\x7c\x37\xd7\xec\x00\x34\x84\x14\x22\x4b\xaa\x6b\xdb\x2d\x43\xf2\xa3\x4e\xf5\x67\x6b\xaf\xcd\xca\xd9\x16\xf1\x48\xb5\xb3\x17"; +static uint8_t download_hash[32] = +#if !defined(CONFIG_NET_SOCKETS_SOCKOPT_TLS) +"\x33\x7c\x37\xd7\xec\x00\x34\x84\x14\x22\x4b\xaa\x6b\xdb\x2d\x43\xf2\xa3\x4e\xf5\x67\x6b\xaf\xcd\xca\xd9\x16\xf1\x48\xb5\xb3\x17"; +#else +"\x64\x7a\x9a\x62\x11\x62\xcd\x7a\x50\x08\x93\x4a\x08\xe2\x3f\xf7\xc1\x13\x5d\x6f\x12\x61\x68\x9f\xd9\x54\xaa\x17\xd5\x0f\x97\x29"; +#endif /* !defined(CONFIG_NET_SOCKETS_SOCKOPT_TLS) */ #define SSTRLEN(s) (sizeof(s) - 1) #define CHECK(r) { if (r == -1) { printf("Error: " #r "\n"); exit(1); } } const char *host; -const char *port = "80"; +const char *port; const char *uri_path = ""; static char response[1024]; static char response_hash[32]; @@ -113,15 +124,26 @@ void print_hex(const unsigned char *p, int len) } } -void download(struct addrinfo *ai) +void download(struct addrinfo *ai, bool is_tls) { int sock; cur_bytes = 0; - sock = socket(ai->ai_family, ai->ai_socktype, ai->ai_protocol); + if (is_tls) { +#if defined(CONFIG_NET_SOCKETS_SOCKOPT_TLS) + sock = socket(ai->ai_family, ai->ai_socktype, IPPROTO_TLS_1_2); +# else + printf("TLS not supported\n"); + return; +#endif + } else { + sock = socket(ai->ai_family, ai->ai_socktype, ai->ai_protocol); + } + CHECK(sock); printf("sock = %d\n", sock); + CHECK(connect(sock, ai->ai_addr, ai->ai_addrlen)); sendall(sock, "GET /", SSTRLEN("GET /")); sendall(sock, uri_path, strlen(uri_path)); @@ -183,15 +205,28 @@ int main(void) char *p; unsigned int total_bytes = 0; int resolve_attempts = 10; + bool is_tls = false; setbuf(stdout, NULL); - if (strncmp(download_url, "http://", SSTRLEN("http://")) != 0) { - fatal("Only http: URLs are supported"); + if (strncmp(download_url, "http://", SSTRLEN("http://")) == 0) { + port = "80"; + p = download_url + SSTRLEN("http://"); +#if defined(CONFIG_NET_SOCKETS_SOCKOPT_TLS) + } else if (strncmp(download_url, "https://", + SSTRLEN("https://")) == 0) { + is_tls = true; + port = "443"; + p = download_url + SSTRLEN("https://"); +#endif /* defined(CONFIG_NET_SOCKETS_SOCKOPT_TLS) */ + } else { + fatal("Only http: " +#if defined(CONFIG_NET_SOCKETS_SOCKOPT_TLS) + "and https: " +#endif + "URLs are supported"); } - p = download_url + SSTRLEN("http://"); - /* Parse host part */ host = p; while (*p && *p != ':' && *p != '/') { @@ -214,8 +249,8 @@ int main(void) uri_path = p; } - printf("Preparing HTTP GET request for http://%s:%s/%s\n", - host, port, uri_path); + printf("Preparing HTTP GET request for http%s://%s:%s/%s\n", + (is_tls ? "s" : ""), host, port, uri_path); hints.ai_family = AF_INET; hints.ai_socktype = SOCK_STREAM; @@ -248,7 +283,7 @@ int main(void) } while (1) { - download(res); + download(res, is_tls); total_bytes += cur_bytes; printf("Total downloaded so far: %uMB\n", total_bytes / (1024 * 1024)); diff --git a/samples/net/sockets/http_get/prj_tls.conf b/samples/net/sockets/http_get/prj_tls.conf new file mode 100644 index 0000000000000..eccd8b0f33e00 --- /dev/null +++ b/samples/net/sockets/http_get/prj_tls.conf @@ -0,0 +1,40 @@ +# General config +CONFIG_NEWLIB_LIBC=y + +# Networking config +CONFIG_NETWORKING=y +CONFIG_NET_IPV4=y +CONFIG_NET_IPV6=y +CONFIG_NET_TCP=y +CONFIG_NET_SOCKETS=y +CONFIG_NET_SOCKETS_POSIX_NAMES=y + +CONFIG_DNS_RESOLVER=y +CONFIG_DNS_SERVER_IP_ADDRESSES=y +CONFIG_DNS_SERVER1="192.0.2.2" + +# Network driver config +CONFIG_TEST_RANDOM_GENERATOR=y + +# Network address config +CONFIG_NET_APP_SETTINGS=y +CONFIG_NET_APP_NEED_IPV4=y +CONFIG_NET_APP_MY_IPV4_ADDR="192.0.2.1" +CONFIG_NET_APP_PEER_IPV4_ADDR="192.0.2.2" +CONFIG_NET_APP_MY_IPV4_GW="192.0.2.2" + +# Network debug config +CONFIG_NET_LOG=y +CONFIG_NET_LOG_GLOBAL=y +CONFIG_SYS_LOG_NET_LEVEL=2 +#CONFIG_NET_DEBUG_SOCKETS=y + +# TLS configuration +CONFIG_MBEDTLS=y +CONFIG_MBEDTLS_BUILTIN=y +CONFIG_MBEDTLS_ENABLE_HEAP=y +CONFIG_MBEDTLS_HEAP_SIZE=30000 +CONFIG_MBEDTLS_SSL_MAX_CONTENT_LEN=4096 + +CONFIG_MAIN_STACK_SIZE=4096 +CONFIG_NET_SOCKETS_SOCKOPT_TLS=y diff --git a/samples/net/sockets/http_get/src/http_get.c b/samples/net/sockets/http_get/src/http_get.c index 8785eb8a554a8..22a2d6e1e5e8f 100644 --- a/samples/net/sockets/http_get/src/http_get.c +++ b/samples/net/sockets/http_get/src/http_get.c @@ -26,7 +26,11 @@ /* HTTP server to connect to */ #define HTTP_HOST "google.com" /* Port to connect to, as string */ +#if defined(CONFIG_NET_SOCKETS_SOCKOPT_TLS) +#define HTTP_PORT "443" +#else #define HTTP_PORT "80" +#endif /* HTTP path to request */ #define HTTP_PATH "/" @@ -74,7 +78,11 @@ int main(void) dump_addrinfo(res); +#if defined(CONFIG_NET_SOCKETS_SOCKOPT_TLS) + sock = socket(res->ai_family, res->ai_socktype, IPPROTO_TLS_1_2); +#else sock = socket(res->ai_family, res->ai_socktype, res->ai_protocol); +#endif CHECK(sock); printf("sock = %d\n", sock); CHECK(connect(sock, res->ai_addr, res->ai_addrlen)); @@ -95,9 +103,11 @@ int main(void) } response[len] = 0; - printf("%s\n", response); + printf("%s", response); } + printf("\n"); + (void)close(sock); return 0; diff --git a/subsys/net/lib/sockets/CMakeLists.txt b/subsys/net/lib/sockets/CMakeLists.txt index 94ed549b22131..a2f1b942d9eab 100644 --- a/subsys/net/lib/sockets/CMakeLists.txt +++ b/subsys/net/lib/sockets/CMakeLists.txt @@ -3,3 +3,5 @@ zephyr_sources( getaddrinfo.c sockets.c ) + +zephyr_sources_ifdef(CONFIG_NET_SOCKETS_SOCKOPT_TLS sockets_tls.c) diff --git a/subsys/net/lib/sockets/Kconfig b/subsys/net/lib/sockets/Kconfig index 561af3de53376..3391f8ca567ba 100644 --- a/subsys/net/lib/sockets/Kconfig +++ b/subsys/net/lib/sockets/Kconfig @@ -32,6 +32,21 @@ config NET_SOCKETS_POLL_MAX help Maximum number of entries supported for poll() call. +config NET_SOCKETS_SOCKOPT_TLS + bool "Enable TCP TLS socket option support [EXPERIMENTAL]" + default n + help + Enable TLS socket option support which automatically establishes + a TLS connection to the remote host. + +config NET_SOCKETS_TLS_MAX_CONTEXTS + int "Maximum number of TLS/DTLS contexts" + default 1 + depends on NET_SOCKETS_SOCKOPT_TLS + help + "This variable specifies maximum number of TLS/DTLS contexts that can + be allocated at the same time." + config NET_DEBUG_SOCKETS bool "Debug BSD Sockets compatible API calls" default n diff --git a/subsys/net/lib/sockets/sockets_tls.c b/subsys/net/lib/sockets/sockets_tls.c new file mode 100644 index 0000000000000..9fe9713b36a22 --- /dev/null +++ b/subsys/net/lib/sockets/sockets_tls.c @@ -0,0 +1,670 @@ +/* + * Copyright (c) 2018 Intel Corporation + * Copyright (c) 2018 Nordic Semiconductor ASA + * + * SPDX-License-Identifier: Apache-2.0 + */ + +#if defined(CONFIG_NET_DEBUG_SOCKETS) +#define SYS_LOG_DOMAIN "net/tls" +#define NET_LOG_ENABLED 1 +#endif + +#include + +#include +#include +#include + +#if defined(CONFIG_MBEDTLS) +#if !defined(CONFIG_MBEDTLS_CFG_FILE) +#include "mbedtls/config.h" +#else +#include CONFIG_MBEDTLS_CFG_FILE +#endif /* CONFIG_MBEDTLS_CFG_FILE */ + +#include +#include +#include +#include +#include +#include +#include +#endif /* CONFIG_MBEDTLS */ + +/** TLS context information. */ +struct tls_context { + /** Information whether TLS context is used. */ + bool is_used; + + /** Secure protocol version running on TLS context. */ + enum net_ip_protocol_secure tls_version; + + /** Socket flags passed to a socket call. */ + int flags; + +#if defined(CONFIG_MBEDTLS) + /** mbedTLS context. */ + mbedtls_ssl_context ssl; + + /** mbedTLS configuration. */ + mbedtls_ssl_config config; +#endif /* CONFIG_MBEDTLS */ +}; + +static mbedtls_ctr_drbg_context tls_ctr_drbg; + +/* A global pool of TLS contexts. */ +static struct tls_context tls_contexts[CONFIG_NET_SOCKETS_TLS_MAX_CONTEXTS]; + +/* A mutex for protecting TLS context allocation. */ +static struct k_mutex context_lock; + +#if defined(MBEDTLS_DEBUG_C) && defined(CONFIG_NET_DEBUG_SOCKETS) +static void tls_debug(void *ctx, int level, const char *file, + int line, const char *str) +{ + const char *p, *basename; + + ARG_UNUSED(ctx); + + if (!file || !str) { + return; + } + + /* Extract basename from file */ + for (p = basename = file; *p != '\0'; p++) { + if (*p == '/' || *p == '\\') { + basename = p + 1; + } + } + + NET_DBG("%s:%04d: |%d| %s", basename, line, level, str); +} +#endif /* defined(MBEDTLS_DEBUG_C) && defined(CONFIG_NET_TLS_DEBUG) */ + +#if defined(CONFIG_ENTROPY_HAS_DRIVER) +static int tls_entropy_func(void *ctx, unsigned char *buf, size_t len) +{ + return entropy_get_entropy(ctx, buf, len); +} +#else +static int tls_entropy_func(void *ctx, unsigned char *buf, size_t len) +{ + ARG_UNUSED(ctx); + + size_t i = len / 4; + u32_t val; + + while (i--) { + val = sys_rand32_get(); + UNALIGNED_PUT(val, (u32_t *)buf); + buf += 4; + } + + i = len & 0x3; + val = sys_rand32_get(); + while (i--) { + *buf++ = val; + val >>= 8; + } + + return 0; +} +#endif /* defined(CONFIG_ENTROPY_HAS_DRIVER) */ + +/* Initialize TLS internals. */ +static int tls_init(struct device *unused) +{ + ARG_UNUSED(unused); + + int ret; + static const unsigned char drbg_seed[] = "zephyr"; + struct device *dev = NULL; + +#if defined(CONFIG_ENTROPY_HAS_DRIVER) + dev = device_get_binding(CONFIG_ENTROPY_NAME); + + if (!dev) { + NET_ERR("Failed to obtain entropy device"); + return -ENODEV; + } +#else + NET_WARN("No entropy device on the system, " + "TLS communication may be insecure!"); +#endif /* defined(CONFIG_ENTROPY_HAS_DRIVER) */ + + memset(tls_contexts, 0, sizeof(tls_contexts)); + + k_mutex_init(&context_lock); + + mbedtls_ctr_drbg_init(&tls_ctr_drbg); + + ret = mbedtls_ctr_drbg_seed(&tls_ctr_drbg, tls_entropy_func, dev, + drbg_seed, sizeof(drbg_seed)); + if (ret != 0) { + mbedtls_ctr_drbg_free(&tls_ctr_drbg); + NET_ERR("TLS entropy source initialization failed"); + return -EFAULT; + } + +#if defined(MBEDTLS_DEBUG_C) && defined(CONFIG_NET_DEBUG_SOCKETS) + mbedtls_debug_set_threshold(CONFIG_MBEDTLS_DEBUG_LEVEL); +#endif + + return 0; +} + +SYS_INIT(tls_init, APPLICATION, CONFIG_KERNEL_INIT_PRIORITY_DEFAULT); + +/* Allocate TLS context. */ +static struct tls_context *tls_alloc(void) +{ + int i; + struct tls_context *tls = NULL; + + k_mutex_lock(&context_lock, K_FOREVER); + + for (i = 0; i < ARRAY_SIZE(tls_contexts); i++) { + if (!tls_contexts[i].is_used) { + tls = &tls_contexts[i]; + memset(tls, 0, sizeof(*tls)); + tls->is_used = true; + + NET_DBG("Allocated TLS context, %p", tls); + break; + } + } + + k_mutex_unlock(&context_lock); + + if (tls) { + mbedtls_ssl_init(&tls->ssl); + mbedtls_ssl_config_init(&tls->config); + +#if defined(MBEDTLS_DEBUG_C) && defined(CONFIG_NET_DEBUG_SOCKETS) + mbedtls_ssl_conf_dbg(&tls->config, tls_debug, NULL); +#endif + } else { + NET_WARN("Failed to allocate TLS context"); + } + + return tls; +} + +/* Allocate new TLS context and copy the content from the source context. */ +static struct tls_context *tls_clone(struct tls_context *source_tls) +{ + struct tls_context *target_tls; + + target_tls = tls_alloc(); + if (!target_tls) { + return NULL; + } + + target_tls->tls_version = source_tls->tls_version; + + return target_tls; +} + +/* Release TLS context. */ +static int tls_release(struct tls_context *tls) +{ + if (!PART_OF_ARRAY(tls_contexts, tls)) { + NET_ERR("Invalid TLS context"); + return -EBADF; + } + + if (!tls->is_used) { + NET_ERR("Deallocating unused TLS context"); + return -EBADF; + } + + mbedtls_ssl_config_free(&tls->config); + mbedtls_ssl_free(&tls->ssl); + + tls->is_used = false; + + return 0; +} + +static int tls_tx(void *ctx, const unsigned char *buf, size_t len) +{ + int sock = POINTER_TO_INT(ctx); + ssize_t sent; + + sent = zsock_sendto(sock, buf, len, + ((struct net_context *)ctx)->tls->flags, + NULL, 0); + if (sent < 0) { + if (errno == EAGAIN) { + return MBEDTLS_ERR_SSL_WANT_WRITE; + } + + return MBEDTLS_ERR_NET_SEND_FAILED; + } + + return sent; +} + +static int tls_rx(void *ctx, unsigned char *buf, size_t len) +{ + int sock = POINTER_TO_INT(ctx); + ssize_t received; + + received = zsock_recvfrom(sock, buf, len, + ((struct net_context *)ctx)->tls->flags, + NULL, 0); + if (received < 0) { + if (errno == EAGAIN) { + return MBEDTLS_ERR_SSL_WANT_READ; + } + + return MBEDTLS_ERR_NET_RECV_FAILED; + } + + return received; +} + +static int tls_mbedtls_set_credentials(struct tls_context *tls) +{ + /* TODO Temporary solution to verify communication */ + mbedtls_ssl_conf_authmode(&tls->config, MBEDTLS_SSL_VERIFY_NONE); + + return 0; +} + +static int tls_mbedtls_handshake(struct net_context *context) +{ + int ret; + + /* We do not want to use any socket flags during the handshake. */ + context->tls->flags = 0; + + /* TODO For simplicity, TLS handshake blocks the socket even for + * non-blocking socket. Non-blocking behavior for handshake can + * be implemented later. + */ + while ((ret = mbedtls_ssl_handshake(&context->tls->ssl)) != 0) { + if (ret == MBEDTLS_ERR_SSL_WANT_READ || + ret == MBEDTLS_ERR_SSL_WANT_WRITE) { + continue; + } + + NET_ERR("TLS handshake error: -%x", -ret); + ret = -ECONNABORTED; + break; + } + + return ret; +} + +static int tls_mbedtls_init(struct net_context *context, bool is_server) +{ + int role, type, ret; + + role = is_server ? MBEDTLS_SSL_IS_SERVER : MBEDTLS_SSL_IS_CLIENT; + + type = (net_context_get_type(context) == SOCK_STREAM) ? + MBEDTLS_SSL_TRANSPORT_STREAM : + MBEDTLS_SSL_TRANSPORT_DATAGRAM; + + mbedtls_ssl_set_bio(&context->tls->ssl, context, tls_tx, tls_rx, NULL); + + ret = mbedtls_ssl_config_defaults(&context->tls->config, role, type, + MBEDTLS_SSL_PRESET_DEFAULT); + if (ret != 0) { + /* According to mbedTLS API documentation, + * mbedtls_ssl_config_defaults can fail due to memory + * allocation failure + */ + return -ENOMEM; + } + + mbedtls_ssl_conf_rng(&context->tls->config, + mbedtls_ctr_drbg_random, + &tls_ctr_drbg); + + ret = tls_mbedtls_set_credentials(context->tls); + if (ret != 0) { + return ret; + } + + ret = mbedtls_ssl_setup(&context->tls->ssl, + &context->tls->config); + if (ret != 0) { + /* According to mbedTLS API documentation, + * mbedtls_ssl_setup can fail due to memory allocation failure + */ + return -ENOMEM; + } + + return 0; +} + +int ztls_socket(int family, int type, int proto) +{ + enum net_ip_protocol_secure tls_proto = 0; + int sock, ret, err; + + if (proto >= IPPROTO_TLS_1_0 && proto <= IPPROTO_TLS_1_2) { + /* Currently DTLS is not supported, + * so do not allow to create datagram socket + */ + if (type == SOCK_DGRAM) { + errno = ENOTSUP; + return -1; + } + + tls_proto = proto; + proto = (type == SOCK_STREAM) ? IPPROTO_TCP : IPPROTO_UDP; + } + + sock = zsock_socket(family, type, proto); + if (sock < 0) { + /* errno will be propagated */ + return -1; + } + + if (tls_proto != 0) { + /* If TLS protocol is used, allocate TLS context */ + struct net_context *context = INT_TO_POINTER(sock); + + context->tls = tls_alloc(); + + if (!context->tls) { + ret = -ENOMEM; + goto error; + } + + context->tls->tls_version = tls_proto; + } + + return sock; + +error: + err = zsock_close(sock); + __ASSERT(err == 0, "Socket close failed"); + + errno = -ret; + return -1; +} + +int ztls_close(int sock) +{ + struct net_context *context = INT_TO_POINTER(sock); + int ret, err = 0; + + if (context->tls) { + /* Try to send close notification. */ + context->tls->flags = 0; + (void)mbedtls_ssl_close_notify(&context->tls->ssl); + + err = tls_release(context->tls); + } + + ret = zsock_close(sock); + + /* In case zsock_close fails, we propagate errno value set by + * zsock_close. + * In case zsock_close succeeds, but tls_release fails, set errno + * according to tls_release return value. + */ + if (ret == 0 && err < 0) { + errno = -err; + ret = -1; + } + + return ret; +} + +int ztls_bind(int sock, const struct sockaddr *addr, socklen_t addrlen) +{ + /* No extra action needed here. */ + return zsock_bind(sock, addr, addrlen); +} + +int ztls_connect(int sock, const struct sockaddr *addr, socklen_t addrlen) +{ + int ret; + struct net_context *context = INT_TO_POINTER(sock); + + ret = zsock_connect(sock, addr, addrlen); + if (ret < 0) { + /* errno will be propagated */ + return -1; + } + + if (context->tls) { + ret = tls_mbedtls_init(context, false); + if (ret < 0) { + goto error; + } + + ret = tls_mbedtls_handshake(context); + if (ret < 0) { + goto error; + } + } + + return 0; + +error: + errno = -ret; + return -1; +} + +int ztls_listen(int sock, int backlog) +{ + /* No extra action needed here. */ + return zsock_listen(sock, backlog); +} + +int ztls_accept(int sock, struct sockaddr *addr, socklen_t *addrlen) +{ + int child_sock, ret, err; + struct net_context *parent_context = INT_TO_POINTER(sock); + struct net_context *child_context = NULL; + + child_sock = zsock_accept(sock, addr, addrlen); + if (child_sock < 0) { + /* errno will be propagated */ + return -1; + } + + if (parent_context->tls) { + child_context = INT_TO_POINTER(child_sock); + + child_context->tls = tls_clone(parent_context->tls); + if (!child_context->tls) { + ret = -ENOMEM; + goto error; + } + + ret = tls_mbedtls_init(child_context, true); + if (ret < 0) { + goto error; + } + + ret = tls_mbedtls_handshake(child_context); + if (ret < 0) { + goto error; + } + } + + return child_sock; + +error: + if (child_context && child_context->tls) { + err = tls_release(child_context->tls); + __ASSERT(err == 0, "TLS context release failed"); + } + + err = zsock_close(child_sock); + __ASSERT(err == 0, "Child socket close failed"); + + errno = -ret; + return -1; +} + +ssize_t ztls_send(int sock, const void *buf, size_t len, int flags) +{ + return ztls_sendto(sock, buf, len, flags, NULL, 0); +} + +ssize_t ztls_recv(int sock, void *buf, size_t max_len, int flags) +{ + return ztls_recvfrom(sock, buf, max_len, flags, NULL, 0); +} + +ssize_t ztls_sendto(int sock, const void *buf, size_t len, int flags, + const struct sockaddr *dest_addr, socklen_t addrlen) +{ + struct net_context *context = INT_TO_POINTER(sock); + int ret; + + if (!context->tls) { + return zsock_sendto(sock, buf, len, flags, dest_addr, addrlen); + } + + context->tls->flags = flags; + + ret = mbedtls_ssl_write(&context->tls->ssl, buf, len); + if (ret >= 0) { + return ret; + } + + if (ret == MBEDTLS_ERR_SSL_WANT_READ || + ret == MBEDTLS_ERR_SSL_WANT_WRITE) { + errno = EAGAIN; + return -1; + } + + errno = EIO; + return -1; +} + +ssize_t ztls_recvfrom(int sock, void *buf, size_t max_len, int flags, + struct sockaddr *src_addr, socklen_t *addrlen) +{ + struct net_context *context = INT_TO_POINTER(sock); + int ret; + + if (!context->tls) { + return zsock_recvfrom(sock, buf, max_len, flags, + src_addr, addrlen); + } + + if (flags & ZSOCK_MSG_PEEK) { + /* TODO mbedTLS does not support 'peeking' This could be + * bypassed by having intermediate buffer for peeking + */ + errno = ENOTSUP; + return -1; + } + + context->tls->flags = flags; + + ret = mbedtls_ssl_read(&context->tls->ssl, buf, max_len); + if (ret >= 0) { + return ret; + } + + if (ret == MBEDTLS_ERR_SSL_PEER_CLOSE_NOTIFY) { + /* Peer notified that it's closing the connection. */ + return 0; + } + + if (ret == MBEDTLS_ERR_SSL_CLIENT_RECONNECT) { + /* Client reconnect on the same socket is not + * supported. See mbedtls_ssl_read API documentation. + */ + return 0; + } + + if (ret == MBEDTLS_ERR_SSL_WANT_READ || + ret == MBEDTLS_ERR_SSL_WANT_WRITE) { + errno = EAGAIN; + return -1; + } + + errno = EIO; + return -1; +} + +int ztls_fcntl(int sock, int cmd, int flags) +{ + /* No extra action needed here. */ + return zsock_fcntl(sock, cmd, flags); +} + +int ztls_poll(struct zsock_pollfd *fds, int nfds, int timeout) +{ + bool has_mbedtls_data = false; + struct zsock_pollfd *pfd; + struct net_context *context; + int i, ret; + + /* There might be some decrypted but unread data pending on mbedTLS, + * check for that. + */ + for (pfd = fds, i = nfds; i--; pfd++) { + /* Per POSIX, negative fd's are just ignored */ + if (pfd->fd < 0) { + continue; + } + + if (pfd->events & ZSOCK_POLLIN) { + context = INT_TO_POINTER(pfd->fd); + if (!context->tls) { + continue; + } + + if (mbedtls_ssl_get_bytes_avail( + &context->tls->ssl) > 0) { + has_mbedtls_data = true; + break; + } + } + } + + /* If there is no data waiting on any of mbedTLS contexts, + * just do regular poll. + */ + if (!has_mbedtls_data) { + return zsock_poll(fds, nfds, timeout); + } + + /* Otherwise, poll with no timeout, and update respective revents. */ + ret = zsock_poll(fds, nfds, K_NO_WAIT); + if (ret < 0) { + /* errno will be propagated */ + return -1; + } + + /* Another pass, this time updating revents. */ + for (pfd = fds, i = nfds; i--; pfd++) { + /* Per POSIX, negative fd's are just ignored */ + if (pfd->fd < 0) { + continue; + } + + if (pfd->events & ZSOCK_POLLIN) { + context = INT_TO_POINTER(pfd->fd); + if (!context->tls) { + continue; + } + + if (mbedtls_ssl_get_bytes_avail( + &context->tls->ssl) > 0) { + if (pfd->revents == 0) { + ret++; + } + + pfd->revents |= ZSOCK_POLLIN; + } + } + } + + return ret; +}