Skip to content

Commit

Permalink
Use SocketSession/TlsSession for test servers.
Browse files Browse the repository at this point in the history
A session looks much the same whether it is initiated from the client or the server, so use the session objects to implement the TLS, HTTP, and S3 test servers.

For TLS, at least, there are some differences between client and server sessions so add a client/server type to SocketSession to determine how the session was initiated.

Aside from reducing code duplication, the main advantage is that the test server will now timeout rather than hanging indefinitely when less input that expected is received.
  • Loading branch information
dwsteele committed Apr 14, 2020
1 parent 71fb28b commit 9ffa2c6
Show file tree
Hide file tree
Showing 8 changed files with 69 additions and 38 deletions.
8 changes: 8 additions & 0 deletions doc/xml/release.xml
Expand Up @@ -50,6 +50,14 @@
<p>Split session functionality of <code>TlsClient</code> out into <code>TlsSession</code>.</p>
</release-item>

<release-item>
<release-item-contributor-list>
<release-item-reviewer id="cynthia.shang"/>
</release-item-contributor-list>

<p>Use <code>SocketSession</code>/<code>TlsSession</code> for test servers.</p>
</release-item>

<release-item>
<release-item-contributor-list>
<release-item-reviewer id="cynthia.shang"/>
Expand Down
2 changes: 1 addition & 1 deletion src/common/io/socket/client.c
Expand Up @@ -141,7 +141,7 @@ sckClientOpen(SocketClient *this)
// Create the session
MEM_CONTEXT_PRIOR_BEGIN()
{
result = sckSessionNew(fd, this->host, this->port, this->timeout);
result = sckSessionNew(sckSessionTypeClient, fd, this->host, this->port, this->timeout);
}
MEM_CONTEXT_PRIOR_END();

Expand Down
10 changes: 8 additions & 2 deletions src/common/io/socket/session.c
Expand Up @@ -23,6 +23,7 @@ Object type
struct SocketSession
{
MemContext *memContext; // Mem context
SocketSessionType type; // Type (server or client)
int fd; // File descriptor
String *host; // Hostname or IP address
unsigned int port; // Port to connect to host on
Expand All @@ -32,6 +33,7 @@ struct SocketSession
OBJECT_DEFINE_MOVE(SOCKET_SESSION);

OBJECT_DEFINE_GET(Fd, , SOCKET_SESSION, int, fd);
OBJECT_DEFINE_GET(Type, const, SOCKET_SESSION, SocketSessionType, type);

OBJECT_DEFINE_FREE(SOCKET_SESSION);

Expand All @@ -46,9 +48,10 @@ OBJECT_DEFINE_FREE_RESOURCE_END(LOG);

/**********************************************************************************************************************************/
SocketSession *
sckSessionNew(int fd, const String *host, unsigned int port, TimeMSec timeout)
sckSessionNew(SocketSessionType type, int fd, const String *host, unsigned int port, TimeMSec timeout)
{
FUNCTION_LOG_BEGIN(logLevelDebug)
FUNCTION_LOG_PARAM(ENUM, type);
FUNCTION_LOG_PARAM(INT, fd);
FUNCTION_LOG_PARAM(STRING, host);
FUNCTION_LOG_PARAM(UINT, port);
Expand All @@ -67,6 +70,7 @@ sckSessionNew(int fd, const String *host, unsigned int port, TimeMSec timeout)
*this = (SocketSession)
{
.memContext = MEM_CONTEXT_NEW(),
.type = type,
.fd = fd,
.host = strDup(host),
.port = port,
Expand Down Expand Up @@ -122,5 +126,7 @@ sckSessionReadWait(SocketSession *this)
String *
sckSessionToLog(const SocketSession *this)
{
return strNewFmt("{fd: %d, host: %s, port: %u, timeout: %" PRIu64 "}", this->fd, strPtr(this->host), this->port, this->timeout);
return strNewFmt(
"{type: %s, fd %d, host: %s, port: %u, timeout: %" PRIu64 "}", this->type == sckSessionTypeClient ? "client" : "server",
this->fd, strPtr(this->host), this->port, this->timeout);
}
14 changes: 13 additions & 1 deletion src/common/io/socket/session.h
Expand Up @@ -8,6 +8,15 @@ Currently this is not a full-featured session and is only intended to isolate so
#ifndef COMMON_IO_SOCKET_SESSION_H
#define COMMON_IO_SOCKET_SESSION_H

/***********************************************************************************************************************************
Test result operations
***********************************************************************************************************************************/
typedef enum
{
sckSessionTypeClient,
sckSessionTypeServer,
} SocketSessionType;

/***********************************************************************************************************************************
Object type
***********************************************************************************************************************************/
Expand All @@ -22,7 +31,7 @@ typedef struct SocketSession SocketSession;
/***********************************************************************************************************************************
Constructors
***********************************************************************************************************************************/
SocketSession *sckSessionNew(int fd, const String *host, unsigned int port, TimeMSec timeout);
SocketSession *sckSessionNew(SocketSessionType type, int fd, const String *host, unsigned int port, TimeMSec timeout);

/***********************************************************************************************************************************
Functions
Expand All @@ -39,6 +48,9 @@ Getters/Setters
// Socket file descriptor
int sckSessionFd(SocketSession *this);

// Socket type
SocketSessionType sckSessionType(const SocketSession *this);

/***********************************************************************************************************************************
Destructor
***********************************************************************************************************************************/
Expand Down
28 changes: 18 additions & 10 deletions src/common/io/tls/session.c
Expand Up @@ -43,21 +43,24 @@ OBJECT_DEFINE_FREE_RESOURCE_BEGIN(TLS_SESSION, LOG, logLevelTrace)
}
OBJECT_DEFINE_FREE_RESOURCE_END(LOG);

/***********************************************************************************************************************************
Close the connection
***********************************************************************************************************************************/
static void
tlsSessionClose(TlsSession *this)
/**********************************************************************************************************************************/
void
tlsSessionClose(TlsSession *this, bool shutdown)
{
FUNCTION_LOG_BEGIN(logLevelTrace);
FUNCTION_LOG_PARAM(TLS_SESSION, this);
FUNCTION_LOG_PARAM(BOOL, shutdown);
FUNCTION_LOG_END();

ASSERT(this != NULL);

// If not already closed
if (this->session != NULL)
{
// Shutdown on request
if (shutdown)
SSL_shutdown(this->session);

// Free the socket session
sckSessionFree(this->socketSession);
this->socketSession = NULL;
Expand Down Expand Up @@ -89,7 +92,7 @@ tlsSessionError(TlsSession *this, int code)
// The connection was closed
case SSL_ERROR_ZERO_RETURN:
{
tlsSessionClose(this);
tlsSessionClose(this, false);
break;
}

Expand All @@ -106,7 +109,7 @@ tlsSessionError(TlsSession *this, int code)
{
// Get the error before closing so it is not cleared
int errNo = errno;
tlsSessionClose(this);
tlsSessionClose(this, false);

// Throw the sys error
THROW_SYS_ERROR_CODE(errNo, KernelError, "tls failed syscall");
Expand Down Expand Up @@ -282,12 +285,17 @@ tlsSessionNew(SSL *session, SocketSession *socketSession, TimeMSec timeout)
.timeout = timeout,
};

// Initiate TLS connection
// Ensure session is freed
memContextCallbackSet(this->memContext, tlsSessionFreeResource, this);

// Negotiate TLS session
cryptoError(
SSL_set_fd(this->session, sckSessionFd(this->socketSession)) != 1, "unable to add socket to TLS session");
cryptoError(SSL_connect(this->session) != 1, "unable to negotiate TLS connection");

memContextCallbackSet(this->memContext, tlsSessionFreeResource, this);
if (sckSessionType(this->socketSession) == sckSessionTypeClient)
cryptoError(SSL_connect(this->session) != 1, "unable to negotiate client TLS session");
else
cryptoError(SSL_accept(this->session) != 1, "unable to negotiate server TLS session");

// Create read and write interfaces
this->write = ioWriteNewP(this, .write = tlsSessionWrite);
Expand Down
4 changes: 4 additions & 0 deletions src/common/io/tls/session.h
Expand Up @@ -25,6 +25,10 @@ typedef struct TlsSession TlsSession;
/***********************************************************************************************************************************
Functions
***********************************************************************************************************************************/
// Close the session. Shutdown should not be attempted after an error, which means the client never has the oppottunity to do a
// shutdown since the connection is held open until it is disconnected by the server.
void tlsSessionClose(TlsSession *this, bool shutdown);

// Move to a new parent mem context
TlsSession *tlsSessionMove(TlsSession *this, MemContext *parentNew);

Expand Down
39 changes: 16 additions & 23 deletions test/src/common/harnessTls.c
Expand Up @@ -11,6 +11,8 @@ Tls Test Harness

#include "common/crypto/common.h"
#include "common/error.h"
#include "common/io/socket/session.h"
#include "common/io/tls/session.intern.h"
#include "common/type/buffer.h"
#include "common/wait.h"

Expand All @@ -24,8 +26,7 @@ Test defaults

static int testServerSocket = 0;
static SSL_CTX *testServerContext = NULL;
static int testClientSocket = 0;
static SSL *testClientSSL = NULL;
static TlsSession *testServerSession = NULL;

/***********************************************************************************************************************************
Initialize TLS and listen on the specified port for TLS connections
Expand Down Expand Up @@ -104,23 +105,13 @@ harnessTlsServerInitDefault(void)

/***********************************************************************************************************************************
Expect an exact string from the client
This is a very unforgiving function and short input will leave the server hanging. Definitely room for improvement here.
***********************************************************************************************************************************/
void
harnessTlsServerExpect(const char *expected)
{
Buffer *buffer = bufNew(strlen(expected));
int readBytes = 0;

// Read expected bytes
do
{
int lastBytes = SSL_read(testClientSSL, bufRemainsPtr(buffer), (int)bufRemains(buffer));
readBytes += lastBytes;
bufUsedSet(buffer, (size_t)readBytes);
}
while (bufRemains(buffer));
ioRead(tlsSessionIoRead(testServerSession), buffer);

// Treat and ? characters as wildcards so variable elements (e.g. auth hashes) can be ignored
String *actual = strNewBuf(buffer);
Expand All @@ -142,7 +133,8 @@ Send a reply to the client
void
harnessTlsServerReply(const char *reply)
{
SSL_write(testClientSSL, reply, (int)strlen(reply));
ioWrite(tlsSessionIoWrite(testServerSession), BUF((unsigned char *)reply, strlen(reply)));
ioWriteFlush(tlsSessionIoWrite(testServerSession));
}

/***********************************************************************************************************************************
Expand All @@ -154,15 +146,15 @@ harnessTlsServerAccept(void)
struct sockaddr_in addr;
unsigned int len = sizeof(addr);

testClientSocket = accept(testServerSocket, (struct sockaddr *)&addr, &len);
int testClientSocket = accept(testServerSocket, (struct sockaddr *)&addr, &len);

if (testClientSocket < 0)
THROW_SYS_ERROR(AssertError, "unable to accept socket");

testClientSSL = SSL_new(testServerContext);
SSL_set_fd(testClientSSL, testClientSocket);
SSL *testClientSSL = SSL_new(testServerContext);

cryptoError(SSL_accept(testClientSSL) <= 0, "unable to accept TLS connection");
testServerSession = tlsSessionNew(
testClientSSL, sckSessionNew(sckSessionTypeServer, testClientSocket, STRDEF("client"), 0, 5000), 5000);
}

/***********************************************************************************************************************************
Expand All @@ -171,17 +163,18 @@ Close the connection
void
harnessTlsServerClose(void)
{
SSL_shutdown(testClientSSL);
SSL_free(testClientSSL);
close(testClientSocket);
tlsSessionClose(testServerSession, true);
tlsSessionFree(testServerSession);
testServerSession = NULL;
}

/**********************************************************************************************************************************/
void
harnessTlsServerAbort(void)
{
SSL_free(testClientSSL);
close(testClientSocket);
tlsSessionClose(testServerSession, false);
tlsSessionFree(testServerSession);
testServerSession = NULL;
}

/**********************************************************************************************************************************/
Expand Down
2 changes: 1 addition & 1 deletion test/src/module/common/ioTlsTest.c
Expand Up @@ -398,7 +398,7 @@ testRun(void)
TEST_RESULT_UINT(ioRead(tlsSessionIoRead(session), output), 0, "read no output after eof");
TEST_RESULT_BOOL(ioReadEof(tlsSessionIoRead(session)), true, " check eof = true");

TEST_RESULT_VOID(tlsSessionClose(session), "close again");
TEST_RESULT_VOID(tlsSessionClose(session, false), "close again");
TEST_ERROR(tlsSessionError(session, SSL_ERROR_WANT_X509_LOOKUP), ServiceError, "tls error [4]");

// -----------------------------------------------------------------------------------------------------------------
Expand Down

0 comments on commit 9ffa2c6

Please sign in to comment.