diff --git a/src/test/test_wolfssl_forwarding.c b/src/test/test_wolfssl_forwarding.c index 1ac3b4d..1c9c23c 100644 --- a/src/test/test_wolfssl_forwarding.c +++ b/src/test/test_wolfssl_forwarding.c @@ -7,6 +7,7 @@ #include #include +#include #include #include #include @@ -111,10 +112,11 @@ static void wolfip_reset_io(WOLFSSL *ssl) struct mem_link { pthread_mutex_t lock; - pthread_cond_t cond[2]; - uint8_t buf[2][LINK_MTU]; - uint32_t len[2]; - int ready[2]; + uint8_t buf[2][8][LINK_MTU]; + uint32_t len[2][8]; + uint8_t head[2]; + uint8_t tail[2]; + uint8_t count[2]; const char *name[2]; }; @@ -130,10 +132,9 @@ static size_t mem_ep_count; static void mem_link_init(struct mem_link *link) { pthread_mutex_init(&link->lock, NULL); - pthread_cond_init(&link->cond[0], NULL); - pthread_cond_init(&link->cond[1], NULL); - link->ready[0] = link->ready[1] = 0; - link->len[0] = link->len[1] = 0; + link->head[0] = link->head[1] = 0; + link->tail[0] = link->tail[1] = 0; + link->count[0] = link->count[1] = 0; link->name[0] = link->name[1] = ""; } @@ -159,13 +160,14 @@ static int mem_ll_poll(struct wolfIP_ll_dev *ll, void *buf, uint32_t len) idx = ep->idx; pthread_mutex_lock(&link->lock); - if (link->ready[idx]) { - uint32_t copy = link->len[idx]; + if (link->count[idx] > 0) { + uint8_t head = link->head[idx]; + uint32_t copy = link->len[idx][head]; if (copy > len) copy = len; - memcpy(buf, link->buf[idx], copy); - link->ready[idx] = 0; - pthread_cond_signal(&link->cond[idx]); + memcpy(buf, link->buf[idx][head], copy); + link->head[idx] = (uint8_t)((head + 1U) % 8U); + link->count[idx]--; ret = (int)copy; } pthread_mutex_unlock(&link->lock); @@ -184,18 +186,31 @@ static int mem_ll_send(struct wolfIP_ll_dev *ll, void *buf, uint32_t len) dst = 1 - ep->idx; pthread_mutex_lock(&link->lock); - while (link->ready[dst]) - pthread_cond_wait(&link->cond[dst], &link->lock); + if (link->count[dst] >= 8U) { + pthread_mutex_unlock(&link->lock); + return -1; + } if (len > LINK_MTU) len = LINK_MTU; - memcpy(link->buf[dst], buf, len); - link->len[dst] = len; - link->ready[dst] = 1; - pthread_cond_signal(&link->cond[dst]); + memcpy(link->buf[dst][link->tail[dst]], buf, len); + link->len[dst][link->tail[dst]] = len; + link->tail[dst] = (uint8_t)((link->tail[dst] + 1U) % 8U); + link->count[dst]++; pthread_mutex_unlock(&link->lock); return (int)len; } +static int set_nonblocking(int fd) +{ + int flags = fcntl(fd, F_GETFL, 0); + + if (flags < 0) + return -1; + if ((flags & O_NONBLOCK) != 0) + return 0; + return fcntl(fd, F_SETFL, flags | O_NONBLOCK); +} + static void mem_link_attach(struct wolfIP_ll_dev *ll, struct mem_link *link, int idx, const char *ifname, const uint8_t mac[6]) { @@ -367,6 +382,7 @@ static int run_host_tls_client(ip4 server_ip) size_t received = 0; int connected = 0; char remote_str[16]; + uint64_t handshake_deadline; sleep(1); ip4_to_str(server_ip, remote_str, sizeof(remote_str)); @@ -390,6 +406,11 @@ static int run_host_tls_client(ip4 server_ip) perror("socket"); goto out; } + if (set_nonblocking(fd) < 0) { + perror("fcntl"); + goto out; + } + wolfSSL_set_using_nonblock(ssl, 1); wolfSSL_set_fd(ssl, fd); for (int attempt = 0; attempt < 50; attempt++) { @@ -400,9 +421,7 @@ static int run_host_tls_client(ip4 server_ip) printf("TLS client: TCP connected\n"); break; } - if (errno == EINPROGRESS) { - if (wolfSSL_get_using_nonblock(ssl) == 0) - wolfSSL_set_using_nonblock(ssl, 1); + if (errno == EINPROGRESS || errno == EALREADY || errno == EWOULDBLOCK) { if (poll(&(struct pollfd){ .fd = fd, .events = POLLOUT }, 1, 100) > 0) { socklen_t errlen = sizeof(err); if (getsockopt(fd, SOL_SOCKET, SO_ERROR, &err, &errlen) == 0 && err == 0) { @@ -426,6 +445,7 @@ static int run_host_tls_client(ip4 server_ip) goto out; } + handshake_deadline = monotonic_ms() + 30000U; while (1) { int hret = wolfSSL_connect(ssl); if (hret == SSL_SUCCESS) { @@ -434,6 +454,10 @@ static int run_host_tls_client(ip4 server_ip) } err = wolfSSL_get_error(ssl, hret); if (err == WOLFSSL_ERROR_WANT_READ || err == WOLFSSL_ERROR_WANT_WRITE) { + if (monotonic_ms() >= handshake_deadline) { + fprintf(stderr, "host client: handshake timed out\n"); + goto out; + } usleep(1000); continue; }