diff --git a/doc/xml/release.xml b/doc/xml/release.xml
index ec4ba6fb8a..2f7473995f 100644
--- a/doc/xml/release.xml
+++ b/doc/xml/release.xml
@@ -50,6 +50,14 @@
Split session functionality of TlsClient
out into TlsSession
.
+
+
+
+
+
+ Use SocketSession
/TlsSession
for test servers.
+
+
diff --git a/src/common/io/socket/client.c b/src/common/io/socket/client.c
index 29cbb7c14e..7e664d5687 100644
--- a/src/common/io/socket/client.c
+++ b/src/common/io/socket/client.c
@@ -141,7 +141,7 @@ sckClientOpen(SocketClient *this)
// Create the session
MEM_CONTEXT_PRIOR_BEGIN()
{
- result = sckSessionNew(fd, this->host, this->port, this->timeout);
+ result = sckSessionNew(sckSessionTypeClient, fd, this->host, this->port, this->timeout);
}
MEM_CONTEXT_PRIOR_END();
diff --git a/src/common/io/socket/session.c b/src/common/io/socket/session.c
index ed05706cd0..86053b8ae6 100644
--- a/src/common/io/socket/session.c
+++ b/src/common/io/socket/session.c
@@ -23,6 +23,7 @@ Object type
struct SocketSession
{
MemContext *memContext; // Mem context
+ SocketSessionType type; // Type (server or client)
int fd; // File descriptor
String *host; // Hostname or IP address
unsigned int port; // Port to connect to host on
@@ -32,6 +33,7 @@ struct SocketSession
OBJECT_DEFINE_MOVE(SOCKET_SESSION);
OBJECT_DEFINE_GET(Fd, , SOCKET_SESSION, int, fd);
+OBJECT_DEFINE_GET(Type, const, SOCKET_SESSION, SocketSessionType, type);
OBJECT_DEFINE_FREE(SOCKET_SESSION);
@@ -46,9 +48,10 @@ OBJECT_DEFINE_FREE_RESOURCE_END(LOG);
/**********************************************************************************************************************************/
SocketSession *
-sckSessionNew(int fd, const String *host, unsigned int port, TimeMSec timeout)
+sckSessionNew(SocketSessionType type, int fd, const String *host, unsigned int port, TimeMSec timeout)
{
FUNCTION_LOG_BEGIN(logLevelDebug)
+ FUNCTION_LOG_PARAM(ENUM, type);
FUNCTION_LOG_PARAM(INT, fd);
FUNCTION_LOG_PARAM(STRING, host);
FUNCTION_LOG_PARAM(UINT, port);
@@ -67,6 +70,7 @@ sckSessionNew(int fd, const String *host, unsigned int port, TimeMSec timeout)
*this = (SocketSession)
{
.memContext = MEM_CONTEXT_NEW(),
+ .type = type,
.fd = fd,
.host = strDup(host),
.port = port,
@@ -122,5 +126,7 @@ sckSessionReadWait(SocketSession *this)
String *
sckSessionToLog(const SocketSession *this)
{
- return strNewFmt("{fd: %d, host: %s, port: %u, timeout: %" PRIu64 "}", this->fd, strPtr(this->host), this->port, this->timeout);
+ return strNewFmt(
+ "{type: %s, fd %d, host: %s, port: %u, timeout: %" PRIu64 "}", this->type == sckSessionTypeClient ? "client" : "server",
+ this->fd, strPtr(this->host), this->port, this->timeout);
}
diff --git a/src/common/io/socket/session.h b/src/common/io/socket/session.h
index e4ef0fe096..7a30e96cb5 100644
--- a/src/common/io/socket/session.h
+++ b/src/common/io/socket/session.h
@@ -8,6 +8,15 @@ Currently this is not a full-featured session and is only intended to isolate so
#ifndef COMMON_IO_SOCKET_SESSION_H
#define COMMON_IO_SOCKET_SESSION_H
+/***********************************************************************************************************************************
+Test result operations
+***********************************************************************************************************************************/
+typedef enum
+{
+ sckSessionTypeClient,
+ sckSessionTypeServer,
+} SocketSessionType;
+
/***********************************************************************************************************************************
Object type
***********************************************************************************************************************************/
@@ -22,7 +31,7 @@ typedef struct SocketSession SocketSession;
/***********************************************************************************************************************************
Constructors
***********************************************************************************************************************************/
-SocketSession *sckSessionNew(int fd, const String *host, unsigned int port, TimeMSec timeout);
+SocketSession *sckSessionNew(SocketSessionType type, int fd, const String *host, unsigned int port, TimeMSec timeout);
/***********************************************************************************************************************************
Functions
@@ -39,6 +48,9 @@ Getters/Setters
// Socket file descriptor
int sckSessionFd(SocketSession *this);
+// Socket type
+SocketSessionType sckSessionType(const SocketSession *this);
+
/***********************************************************************************************************************************
Destructor
***********************************************************************************************************************************/
diff --git a/src/common/io/tls/session.c b/src/common/io/tls/session.c
index b1ee08488a..a89b456f40 100644
--- a/src/common/io/tls/session.c
+++ b/src/common/io/tls/session.c
@@ -43,14 +43,13 @@ OBJECT_DEFINE_FREE_RESOURCE_BEGIN(TLS_SESSION, LOG, logLevelTrace)
}
OBJECT_DEFINE_FREE_RESOURCE_END(LOG);
-/***********************************************************************************************************************************
-Close the connection
-***********************************************************************************************************************************/
-static void
-tlsSessionClose(TlsSession *this)
+/**********************************************************************************************************************************/
+void
+tlsSessionClose(TlsSession *this, bool shutdown)
{
FUNCTION_LOG_BEGIN(logLevelTrace);
FUNCTION_LOG_PARAM(TLS_SESSION, this);
+ FUNCTION_LOG_PARAM(BOOL, shutdown);
FUNCTION_LOG_END();
ASSERT(this != NULL);
@@ -58,6 +57,10 @@ tlsSessionClose(TlsSession *this)
// If not already closed
if (this->session != NULL)
{
+ // Shutdown on request
+ if (shutdown)
+ SSL_shutdown(this->session);
+
// Free the socket session
sckSessionFree(this->socketSession);
this->socketSession = NULL;
@@ -89,7 +92,7 @@ tlsSessionError(TlsSession *this, int code)
// The connection was closed
case SSL_ERROR_ZERO_RETURN:
{
- tlsSessionClose(this);
+ tlsSessionClose(this, false);
break;
}
@@ -106,7 +109,7 @@ tlsSessionError(TlsSession *this, int code)
{
// Get the error before closing so it is not cleared
int errNo = errno;
- tlsSessionClose(this);
+ tlsSessionClose(this, false);
// Throw the sys error
THROW_SYS_ERROR_CODE(errNo, KernelError, "tls failed syscall");
@@ -282,12 +285,17 @@ tlsSessionNew(SSL *session, SocketSession *socketSession, TimeMSec timeout)
.timeout = timeout,
};
- // Initiate TLS connection
+ // Ensure session is freed
+ memContextCallbackSet(this->memContext, tlsSessionFreeResource, this);
+
+ // Negotiate TLS session
cryptoError(
SSL_set_fd(this->session, sckSessionFd(this->socketSession)) != 1, "unable to add socket to TLS session");
- cryptoError(SSL_connect(this->session) != 1, "unable to negotiate TLS connection");
- memContextCallbackSet(this->memContext, tlsSessionFreeResource, this);
+ if (sckSessionType(this->socketSession) == sckSessionTypeClient)
+ cryptoError(SSL_connect(this->session) != 1, "unable to negotiate client TLS session");
+ else
+ cryptoError(SSL_accept(this->session) != 1, "unable to negotiate server TLS session");
// Create read and write interfaces
this->write = ioWriteNewP(this, .write = tlsSessionWrite);
diff --git a/src/common/io/tls/session.h b/src/common/io/tls/session.h
index 32af6ea08d..acb335176b 100644
--- a/src/common/io/tls/session.h
+++ b/src/common/io/tls/session.h
@@ -25,6 +25,10 @@ typedef struct TlsSession TlsSession;
/***********************************************************************************************************************************
Functions
***********************************************************************************************************************************/
+// Close the session. Shutdown should not be attempted after an error, which means the client never has the oppottunity to do a
+// shutdown since the connection is held open until it is disconnected by the server.
+void tlsSessionClose(TlsSession *this, bool shutdown);
+
// Move to a new parent mem context
TlsSession *tlsSessionMove(TlsSession *this, MemContext *parentNew);
diff --git a/test/src/common/harnessTls.c b/test/src/common/harnessTls.c
index e974911d15..e02c83dbdb 100644
--- a/test/src/common/harnessTls.c
+++ b/test/src/common/harnessTls.c
@@ -11,6 +11,8 @@ Tls Test Harness
#include "common/crypto/common.h"
#include "common/error.h"
+#include "common/io/socket/session.h"
+#include "common/io/tls/session.intern.h"
#include "common/type/buffer.h"
#include "common/wait.h"
@@ -24,8 +26,7 @@ Test defaults
static int testServerSocket = 0;
static SSL_CTX *testServerContext = NULL;
-static int testClientSocket = 0;
-static SSL *testClientSSL = NULL;
+static TlsSession *testServerSession = NULL;
/***********************************************************************************************************************************
Initialize TLS and listen on the specified port for TLS connections
@@ -104,23 +105,13 @@ harnessTlsServerInitDefault(void)
/***********************************************************************************************************************************
Expect an exact string from the client
-
-This is a very unforgiving function and short input will leave the server hanging. Definitely room for improvement here.
***********************************************************************************************************************************/
void
harnessTlsServerExpect(const char *expected)
{
Buffer *buffer = bufNew(strlen(expected));
- int readBytes = 0;
- // Read expected bytes
- do
- {
- int lastBytes = SSL_read(testClientSSL, bufRemainsPtr(buffer), (int)bufRemains(buffer));
- readBytes += lastBytes;
- bufUsedSet(buffer, (size_t)readBytes);
- }
- while (bufRemains(buffer));
+ ioRead(tlsSessionIoRead(testServerSession), buffer);
// Treat and ? characters as wildcards so variable elements (e.g. auth hashes) can be ignored
String *actual = strNewBuf(buffer);
@@ -142,7 +133,8 @@ Send a reply to the client
void
harnessTlsServerReply(const char *reply)
{
- SSL_write(testClientSSL, reply, (int)strlen(reply));
+ ioWrite(tlsSessionIoWrite(testServerSession), BUF((unsigned char *)reply, strlen(reply)));
+ ioWriteFlush(tlsSessionIoWrite(testServerSession));
}
/***********************************************************************************************************************************
@@ -154,15 +146,15 @@ harnessTlsServerAccept(void)
struct sockaddr_in addr;
unsigned int len = sizeof(addr);
- testClientSocket = accept(testServerSocket, (struct sockaddr *)&addr, &len);
+ int testClientSocket = accept(testServerSocket, (struct sockaddr *)&addr, &len);
if (testClientSocket < 0)
THROW_SYS_ERROR(AssertError, "unable to accept socket");
- testClientSSL = SSL_new(testServerContext);
- SSL_set_fd(testClientSSL, testClientSocket);
+ SSL *testClientSSL = SSL_new(testServerContext);
- cryptoError(SSL_accept(testClientSSL) <= 0, "unable to accept TLS connection");
+ testServerSession = tlsSessionNew(
+ testClientSSL, sckSessionNew(sckSessionTypeServer, testClientSocket, STRDEF("client"), 0, 5000), 5000);
}
/***********************************************************************************************************************************
@@ -171,17 +163,18 @@ Close the connection
void
harnessTlsServerClose(void)
{
- SSL_shutdown(testClientSSL);
- SSL_free(testClientSSL);
- close(testClientSocket);
+ tlsSessionClose(testServerSession, true);
+ tlsSessionFree(testServerSession);
+ testServerSession = NULL;
}
/**********************************************************************************************************************************/
void
harnessTlsServerAbort(void)
{
- SSL_free(testClientSSL);
- close(testClientSocket);
+ tlsSessionClose(testServerSession, false);
+ tlsSessionFree(testServerSession);
+ testServerSession = NULL;
}
/**********************************************************************************************************************************/
diff --git a/test/src/module/common/ioTlsTest.c b/test/src/module/common/ioTlsTest.c
index 2eb8375c2c..5cf2f5a772 100644
--- a/test/src/module/common/ioTlsTest.c
+++ b/test/src/module/common/ioTlsTest.c
@@ -398,7 +398,7 @@ testRun(void)
TEST_RESULT_UINT(ioRead(tlsSessionIoRead(session), output), 0, "read no output after eof");
TEST_RESULT_BOOL(ioReadEof(tlsSessionIoRead(session)), true, " check eof = true");
- TEST_RESULT_VOID(tlsSessionClose(session), "close again");
+ TEST_RESULT_VOID(tlsSessionClose(session, false), "close again");
TEST_ERROR(tlsSessionError(session, SSL_ERROR_WANT_X509_LOOKUP), ServiceError, "tls error [4]");
// -----------------------------------------------------------------------------------------------------------------