diff --git a/doc/xml/release.xml b/doc/xml/release.xml index ec4ba6fb8a..2f7473995f 100644 --- a/doc/xml/release.xml +++ b/doc/xml/release.xml @@ -50,6 +50,14 @@

Split session functionality of TlsClient out into TlsSession.

+ + + + + +

Use SocketSession/TlsSession for test servers.

+
+ diff --git a/src/common/io/socket/client.c b/src/common/io/socket/client.c index 29cbb7c14e..7e664d5687 100644 --- a/src/common/io/socket/client.c +++ b/src/common/io/socket/client.c @@ -141,7 +141,7 @@ sckClientOpen(SocketClient *this) // Create the session MEM_CONTEXT_PRIOR_BEGIN() { - result = sckSessionNew(fd, this->host, this->port, this->timeout); + result = sckSessionNew(sckSessionTypeClient, fd, this->host, this->port, this->timeout); } MEM_CONTEXT_PRIOR_END(); diff --git a/src/common/io/socket/session.c b/src/common/io/socket/session.c index ed05706cd0..86053b8ae6 100644 --- a/src/common/io/socket/session.c +++ b/src/common/io/socket/session.c @@ -23,6 +23,7 @@ Object type struct SocketSession { MemContext *memContext; // Mem context + SocketSessionType type; // Type (server or client) int fd; // File descriptor String *host; // Hostname or IP address unsigned int port; // Port to connect to host on @@ -32,6 +33,7 @@ struct SocketSession OBJECT_DEFINE_MOVE(SOCKET_SESSION); OBJECT_DEFINE_GET(Fd, , SOCKET_SESSION, int, fd); +OBJECT_DEFINE_GET(Type, const, SOCKET_SESSION, SocketSessionType, type); OBJECT_DEFINE_FREE(SOCKET_SESSION); @@ -46,9 +48,10 @@ OBJECT_DEFINE_FREE_RESOURCE_END(LOG); /**********************************************************************************************************************************/ SocketSession * -sckSessionNew(int fd, const String *host, unsigned int port, TimeMSec timeout) +sckSessionNew(SocketSessionType type, int fd, const String *host, unsigned int port, TimeMSec timeout) { FUNCTION_LOG_BEGIN(logLevelDebug) + FUNCTION_LOG_PARAM(ENUM, type); FUNCTION_LOG_PARAM(INT, fd); FUNCTION_LOG_PARAM(STRING, host); FUNCTION_LOG_PARAM(UINT, port); @@ -67,6 +70,7 @@ sckSessionNew(int fd, const String *host, unsigned int port, TimeMSec timeout) *this = (SocketSession) { .memContext = MEM_CONTEXT_NEW(), + .type = type, .fd = fd, .host = strDup(host), .port = port, @@ -122,5 +126,7 @@ sckSessionReadWait(SocketSession *this) String * sckSessionToLog(const SocketSession *this) { - return strNewFmt("{fd: %d, host: %s, port: %u, timeout: %" PRIu64 "}", this->fd, strPtr(this->host), this->port, this->timeout); + return strNewFmt( + "{type: %s, fd %d, host: %s, port: %u, timeout: %" PRIu64 "}", this->type == sckSessionTypeClient ? "client" : "server", + this->fd, strPtr(this->host), this->port, this->timeout); } diff --git a/src/common/io/socket/session.h b/src/common/io/socket/session.h index e4ef0fe096..7a30e96cb5 100644 --- a/src/common/io/socket/session.h +++ b/src/common/io/socket/session.h @@ -8,6 +8,15 @@ Currently this is not a full-featured session and is only intended to isolate so #ifndef COMMON_IO_SOCKET_SESSION_H #define COMMON_IO_SOCKET_SESSION_H +/*********************************************************************************************************************************** +Test result operations +***********************************************************************************************************************************/ +typedef enum +{ + sckSessionTypeClient, + sckSessionTypeServer, +} SocketSessionType; + /*********************************************************************************************************************************** Object type ***********************************************************************************************************************************/ @@ -22,7 +31,7 @@ typedef struct SocketSession SocketSession; /*********************************************************************************************************************************** Constructors ***********************************************************************************************************************************/ -SocketSession *sckSessionNew(int fd, const String *host, unsigned int port, TimeMSec timeout); +SocketSession *sckSessionNew(SocketSessionType type, int fd, const String *host, unsigned int port, TimeMSec timeout); /*********************************************************************************************************************************** Functions @@ -39,6 +48,9 @@ Getters/Setters // Socket file descriptor int sckSessionFd(SocketSession *this); +// Socket type +SocketSessionType sckSessionType(const SocketSession *this); + /*********************************************************************************************************************************** Destructor ***********************************************************************************************************************************/ diff --git a/src/common/io/tls/session.c b/src/common/io/tls/session.c index b1ee08488a..a89b456f40 100644 --- a/src/common/io/tls/session.c +++ b/src/common/io/tls/session.c @@ -43,14 +43,13 @@ OBJECT_DEFINE_FREE_RESOURCE_BEGIN(TLS_SESSION, LOG, logLevelTrace) } OBJECT_DEFINE_FREE_RESOURCE_END(LOG); -/*********************************************************************************************************************************** -Close the connection -***********************************************************************************************************************************/ -static void -tlsSessionClose(TlsSession *this) +/**********************************************************************************************************************************/ +void +tlsSessionClose(TlsSession *this, bool shutdown) { FUNCTION_LOG_BEGIN(logLevelTrace); FUNCTION_LOG_PARAM(TLS_SESSION, this); + FUNCTION_LOG_PARAM(BOOL, shutdown); FUNCTION_LOG_END(); ASSERT(this != NULL); @@ -58,6 +57,10 @@ tlsSessionClose(TlsSession *this) // If not already closed if (this->session != NULL) { + // Shutdown on request + if (shutdown) + SSL_shutdown(this->session); + // Free the socket session sckSessionFree(this->socketSession); this->socketSession = NULL; @@ -89,7 +92,7 @@ tlsSessionError(TlsSession *this, int code) // The connection was closed case SSL_ERROR_ZERO_RETURN: { - tlsSessionClose(this); + tlsSessionClose(this, false); break; } @@ -106,7 +109,7 @@ tlsSessionError(TlsSession *this, int code) { // Get the error before closing so it is not cleared int errNo = errno; - tlsSessionClose(this); + tlsSessionClose(this, false); // Throw the sys error THROW_SYS_ERROR_CODE(errNo, KernelError, "tls failed syscall"); @@ -282,12 +285,17 @@ tlsSessionNew(SSL *session, SocketSession *socketSession, TimeMSec timeout) .timeout = timeout, }; - // Initiate TLS connection + // Ensure session is freed + memContextCallbackSet(this->memContext, tlsSessionFreeResource, this); + + // Negotiate TLS session cryptoError( SSL_set_fd(this->session, sckSessionFd(this->socketSession)) != 1, "unable to add socket to TLS session"); - cryptoError(SSL_connect(this->session) != 1, "unable to negotiate TLS connection"); - memContextCallbackSet(this->memContext, tlsSessionFreeResource, this); + if (sckSessionType(this->socketSession) == sckSessionTypeClient) + cryptoError(SSL_connect(this->session) != 1, "unable to negotiate client TLS session"); + else + cryptoError(SSL_accept(this->session) != 1, "unable to negotiate server TLS session"); // Create read and write interfaces this->write = ioWriteNewP(this, .write = tlsSessionWrite); diff --git a/src/common/io/tls/session.h b/src/common/io/tls/session.h index 32af6ea08d..acb335176b 100644 --- a/src/common/io/tls/session.h +++ b/src/common/io/tls/session.h @@ -25,6 +25,10 @@ typedef struct TlsSession TlsSession; /*********************************************************************************************************************************** Functions ***********************************************************************************************************************************/ +// Close the session. Shutdown should not be attempted after an error, which means the client never has the oppottunity to do a +// shutdown since the connection is held open until it is disconnected by the server. +void tlsSessionClose(TlsSession *this, bool shutdown); + // Move to a new parent mem context TlsSession *tlsSessionMove(TlsSession *this, MemContext *parentNew); diff --git a/test/src/common/harnessTls.c b/test/src/common/harnessTls.c index e974911d15..e02c83dbdb 100644 --- a/test/src/common/harnessTls.c +++ b/test/src/common/harnessTls.c @@ -11,6 +11,8 @@ Tls Test Harness #include "common/crypto/common.h" #include "common/error.h" +#include "common/io/socket/session.h" +#include "common/io/tls/session.intern.h" #include "common/type/buffer.h" #include "common/wait.h" @@ -24,8 +26,7 @@ Test defaults static int testServerSocket = 0; static SSL_CTX *testServerContext = NULL; -static int testClientSocket = 0; -static SSL *testClientSSL = NULL; +static TlsSession *testServerSession = NULL; /*********************************************************************************************************************************** Initialize TLS and listen on the specified port for TLS connections @@ -104,23 +105,13 @@ harnessTlsServerInitDefault(void) /*********************************************************************************************************************************** Expect an exact string from the client - -This is a very unforgiving function and short input will leave the server hanging. Definitely room for improvement here. ***********************************************************************************************************************************/ void harnessTlsServerExpect(const char *expected) { Buffer *buffer = bufNew(strlen(expected)); - int readBytes = 0; - // Read expected bytes - do - { - int lastBytes = SSL_read(testClientSSL, bufRemainsPtr(buffer), (int)bufRemains(buffer)); - readBytes += lastBytes; - bufUsedSet(buffer, (size_t)readBytes); - } - while (bufRemains(buffer)); + ioRead(tlsSessionIoRead(testServerSession), buffer); // Treat and ? characters as wildcards so variable elements (e.g. auth hashes) can be ignored String *actual = strNewBuf(buffer); @@ -142,7 +133,8 @@ Send a reply to the client void harnessTlsServerReply(const char *reply) { - SSL_write(testClientSSL, reply, (int)strlen(reply)); + ioWrite(tlsSessionIoWrite(testServerSession), BUF((unsigned char *)reply, strlen(reply))); + ioWriteFlush(tlsSessionIoWrite(testServerSession)); } /*********************************************************************************************************************************** @@ -154,15 +146,15 @@ harnessTlsServerAccept(void) struct sockaddr_in addr; unsigned int len = sizeof(addr); - testClientSocket = accept(testServerSocket, (struct sockaddr *)&addr, &len); + int testClientSocket = accept(testServerSocket, (struct sockaddr *)&addr, &len); if (testClientSocket < 0) THROW_SYS_ERROR(AssertError, "unable to accept socket"); - testClientSSL = SSL_new(testServerContext); - SSL_set_fd(testClientSSL, testClientSocket); + SSL *testClientSSL = SSL_new(testServerContext); - cryptoError(SSL_accept(testClientSSL) <= 0, "unable to accept TLS connection"); + testServerSession = tlsSessionNew( + testClientSSL, sckSessionNew(sckSessionTypeServer, testClientSocket, STRDEF("client"), 0, 5000), 5000); } /*********************************************************************************************************************************** @@ -171,17 +163,18 @@ Close the connection void harnessTlsServerClose(void) { - SSL_shutdown(testClientSSL); - SSL_free(testClientSSL); - close(testClientSocket); + tlsSessionClose(testServerSession, true); + tlsSessionFree(testServerSession); + testServerSession = NULL; } /**********************************************************************************************************************************/ void harnessTlsServerAbort(void) { - SSL_free(testClientSSL); - close(testClientSocket); + tlsSessionClose(testServerSession, false); + tlsSessionFree(testServerSession); + testServerSession = NULL; } /**********************************************************************************************************************************/ diff --git a/test/src/module/common/ioTlsTest.c b/test/src/module/common/ioTlsTest.c index 2eb8375c2c..5cf2f5a772 100644 --- a/test/src/module/common/ioTlsTest.c +++ b/test/src/module/common/ioTlsTest.c @@ -398,7 +398,7 @@ testRun(void) TEST_RESULT_UINT(ioRead(tlsSessionIoRead(session), output), 0, "read no output after eof"); TEST_RESULT_BOOL(ioReadEof(tlsSessionIoRead(session)), true, " check eof = true"); - TEST_RESULT_VOID(tlsSessionClose(session), "close again"); + TEST_RESULT_VOID(tlsSessionClose(session, false), "close again"); TEST_ERROR(tlsSessionError(session, SSL_ERROR_WANT_X509_LOOKUP), ServiceError, "tls error [4]"); // -----------------------------------------------------------------------------------------------------------------