Skip to content

Commit

Permalink
Allow multiplexing of SSL/Non-SSL connections on the same port
Browse files Browse the repository at this point in the history
  • Loading branch information
zaneb committed Sep 27, 2011
1 parent 12a23ce commit 7ed22e8
Show file tree
Hide file tree
Showing 8 changed files with 214 additions and 195 deletions.
77 changes: 54 additions & 23 deletions qpid/cpp/src/qpid/sys/Socket.h
Expand Up @@ -33,32 +33,20 @@ namespace sys {
class Duration;
class SocketAddress;

class Socket : public IOHandle
class GenericSocket: public IOHandle
{
public:
/** Create a socket wrapper for descriptor. */
QPID_COMMON_EXTERN Socket();

/** Set timeout for read and write */
void setTimeout(const Duration& interval) const;
GenericSocket();

/** Set socket non blocking */
void setNonblocking() const;
virtual void setNonblocking() const = 0;

QPID_COMMON_EXTERN void setTcpNoDelay() const;
virtual void setTcpNoDelay(bool nd) const = 0;

QPID_COMMON_EXTERN void connect(const std::string& host, uint16_t port) const;
QPID_COMMON_EXTERN void connect(const SocketAddress&) const;
virtual void connect(const std::string& host, uint16_t port) const = 0;

QPID_COMMON_EXTERN void close() const;

/** Bind to a port and start listening.
*@param port 0 means choose an available port.
*@param backlog maximum number of pending connections.
*@return The bound port.
*/
QPID_COMMON_EXTERN int listen(uint16_t port = 0, int backlog = 10) const;
QPID_COMMON_EXTERN int listen(const SocketAddress&, int backlog = 10) const;
virtual void close() const = 0;

/** Returns the "socket name" ie the address bound to
* the near end of the socket
Expand All @@ -82,9 +70,12 @@ class Socket : public IOHandle
QPID_COMMON_EXTERN std::string getLocalAddress() const;

/**
* Returns the full address of the connection: local and remote host and port.
* Returns the full address of the connection: local and remote host and
* port.
*/
QPID_COMMON_EXTERN std::string getFullAddress() const { return getLocalAddress()+"-"+getPeerAddress(); }
QPID_COMMON_EXTERN std::string getFullAddress() const {
return getLocalAddress() + "-" + getPeerAddress();
}

QPID_COMMON_EXTERN uint16_t getLocalPort() const;
uint16_t getRemotePort() const;
Expand All @@ -98,7 +89,49 @@ class Socket : public IOHandle
/** Accept a connection from a socket that is already listening
* and has an incoming connection
*/
QPID_COMMON_EXTERN Socket* accept() const;
virtual GenericSocket* accept() const = 0;

virtual int read(void *buf, size_t count) const = 0;
virtual int write(const void *buf, size_t count) const = 0;

protected:
GenericSocket(IOHandlePrivate *h);
mutable std::string connectname;
};

class Socket : public GenericSocket
{
public:
/** Create a socket wrapper for descriptor. */
QPID_COMMON_EXTERN Socket();
Socket(IOHandlePrivate*);

/** Set timeout for read and write */
void setTimeout(const Duration& interval) const;

/** Set socket non blocking */
void setNonblocking() const;

QPID_COMMON_EXTERN void setTcpNoDelay(bool nd=true) const;

QPID_COMMON_EXTERN void connect(const std::string& host, uint16_t port) const;
QPID_COMMON_EXTERN void connect(const SocketAddress&) const;

QPID_COMMON_EXTERN void close() const;

/** Bind to a port and start listening.
*@param port 0 means choose an available port.
*@param backlog maximum number of pending connections.
*@return The bound port.
*/
QPID_COMMON_EXTERN int listen(uint16_t port = 0, int backlog = 10) const;
QPID_COMMON_EXTERN int listen(const SocketAddress&, int backlog = 10) const;


/** Accept a connection from a socket that is already listening
* and has an incoming connection
*/
QPID_COMMON_EXTERN virtual GenericSocket* accept() const;

// TODO The following are raw operations, maybe they need better wrapping?
QPID_COMMON_EXTERN int read(void *buf, size_t count) const;
Expand All @@ -108,8 +141,6 @@ class Socket : public IOHandle
/** Create socket */
void createSocket(const SocketAddress&) const;

Socket(IOHandlePrivate*);
mutable std::string connectname;
mutable bool nonblocking;
mutable bool nodelay;
};
Expand Down
102 changes: 89 additions & 13 deletions qpid/cpp/src/qpid/sys/SslPlugin.cpp
Expand Up @@ -25,6 +25,8 @@
#include "qpid/sys/ssl/check.h"
#include "qpid/sys/ssl/util.h"
#include "qpid/sys/ssl/SslHandler.h"
#include "qpid/sys/AsynchIOHandler.h"
#include "qpid/sys/AsynchIO.h"
#include "qpid/sys/ssl/SslIo.h"
#include "qpid/sys/ssl/SslSocket.h"
#include "qpid/broker/Broker.h"
Expand Down Expand Up @@ -57,28 +59,40 @@ struct SslServerOptions : ssl::SslOptions
};

class SslProtocolFactory : public ProtocolFactory {
protected:
const bool tcpNoDelay;
qpid::sys::ssl::SslSocket listener;
qpid::sys::ssl::SslSocket *listener;
const uint16_t listeningPort;
std::auto_ptr<qpid::sys::ssl::SslAcceptor> acceptor;
bool nodict;
void established(Poller::shared_ptr, const qpid::sys::GenericSocket&, ConnectionCodec::Factory*,
bool isClient);
SslProtocolFactory(const SslServerOptions&, int backlog, bool nodelay, qpid::sys::ssl::SslSocket *l);

public:
SslProtocolFactory(const SslServerOptions&, int backlog, bool nodelay);
void accept(Poller::shared_ptr, ConnectionCodec::Factory*);
virtual ~SslProtocolFactory() { delete listener; }
virtual void accept(Poller::shared_ptr, ConnectionCodec::Factory*);
void connect(Poller::shared_ptr, const std::string& host, int16_t port,
ConnectionCodec::Factory*,
boost::function2<void, int, std::string> failed);

uint16_t getPort() const;
std::string getHost() const;
bool supports(const std::string& capability);
virtual bool supports(const std::string& capability);
};

class SslOptionalProtocolFactory : public SslProtocolFactory {
public:
SslOptionalProtocolFactory(const SslServerOptions& opts, int backlog, bool nodelay): SslProtocolFactory(opts, backlog, nodelay, new qpid::sys::ssl::SslOptionalSocket()) { }
virtual void accept(Poller::shared_ptr, ConnectionCodec::Factory*);
virtual bool supports(const std::string& capability);
private:
void established(Poller::shared_ptr, const qpid::sys::ssl::SslSocket&, ConnectionCodec::Factory*,
void established(Poller::shared_ptr, const qpid::sys::GenericSocket&, ConnectionCodec::Factory*,
bool isClient);
};


// Static instance to initialise plugin
static struct SslPlugin : public Plugin {
SslServerOptions options;
Expand All @@ -91,6 +105,7 @@ static struct SslPlugin : public Plugin {
}

void initialize(Target& target) {
QPID_LOG(notice, "Initialising SSL plugin");
broker::Broker* broker = dynamic_cast<broker::Broker*>(&target);
// Only provide to a Broker
if (broker) {
Expand All @@ -101,9 +116,14 @@ static struct SslPlugin : public Plugin {
ssl::initNSS(options, true);

const broker::Broker::Options& opts = broker->getOptions();
ProtocolFactory::shared_ptr protocol(new SslProtocolFactory(options,
opts.connectionBacklog,
opts.tcpNoDelay));
ProtocolFactory::shared_ptr protocol(0 /* TODO FIXME */ ?
new SslProtocolFactory(options,
opts.connectionBacklog,
opts.tcpNoDelay)
:
new SslOptionalProtocolFactory(options,
opts.connectionBacklog,
opts.tcpNoDelay));
QPID_LOG(notice, "Listening for SSL connections on TCP port " << protocol->getPort());
broker->registerProtocolFactory("ssl", protocol);
} catch (const std::exception& e) {
Expand All @@ -115,11 +135,16 @@ static struct SslPlugin : public Plugin {
} sslPlugin;

SslProtocolFactory::SslProtocolFactory(const SslServerOptions& options, int backlog, bool nodelay) :
tcpNoDelay(nodelay), listeningPort(listener.listen(options.port, backlog, options.certName, options.clientAuth)),
tcpNoDelay(nodelay), listener(new qpid::sys::ssl::SslSocket()), listeningPort(listener->listen(options.port, backlog, options.certName, options.clientAuth)),
nodict(options.nodict)
{}

void SslProtocolFactory::established(Poller::shared_ptr poller, const qpid::sys::ssl::SslSocket& s,
SslProtocolFactory::SslProtocolFactory(const SslServerOptions& options, int backlog, bool nodelay, qpid::sys::ssl::SslSocket *l) :
tcpNoDelay(nodelay), listener(l), listeningPort(listener->listen(options.port, backlog, options.certName, options.clientAuth)),
nodict(options.nodict)
{}

void SslProtocolFactory::established(Poller::shared_ptr poller, const qpid::sys::GenericSocket& s,
ConnectionCodec::Factory* f, bool isClient) {
qpid::sys::ssl::SslHandler* async = new qpid::sys::ssl::SslHandler(s.getFullAddress(), f, nodict);

Expand All @@ -128,9 +153,13 @@ void SslProtocolFactory::established(Poller::shared_ptr poller, const qpid::sys:
QPID_LOG(info, "Set TCP_NODELAY on connection to " << s.getPeerAddress());
}

if (isClient)
if (isClient) {
async->setClient();
qpid::sys::ssl::SslIO* aio = new qpid::sys::ssl::SslIO(s,
}

const qpid::sys::ssl::SslSocket *sslSock = dynamic_cast<const qpid::sys::ssl::SslSocket *>(&s);

qpid::sys::ssl::SslIO* aio = new qpid::sys::ssl::SslIO(*sslSock,
boost::bind(&qpid::sys::ssl::SslHandler::readbuff, async, _1, _2),
boost::bind(&qpid::sys::ssl::SslHandler::eof, async, _1),
boost::bind(&qpid::sys::ssl::SslHandler::disconnect, async, _1),
Expand All @@ -147,17 +176,48 @@ uint16_t SslProtocolFactory::getPort() const {
}

std::string SslProtocolFactory::getHost() const {
return listener.getSockname();
return listener->getSockname();
}

void SslProtocolFactory::accept(Poller::shared_ptr poller,
ConnectionCodec::Factory* fact) {
acceptor.reset(
new qpid::sys::ssl::SslAcceptor(listener,
new qpid::sys::ssl::SslAcceptor(*listener,
boost::bind(&SslProtocolFactory::established, this, poller, _1, fact, false)));
acceptor->start(poller);
}

void SslOptionalProtocolFactory::established(Poller::shared_ptr poller, const qpid::sys::GenericSocket& s,
ConnectionCodec::Factory* f, bool isClient) {
const qpid::sys::Socket *plainSock = dynamic_cast<const qpid::sys::Socket*>(&s);

if (plainSock) {
AsynchIOHandler* async = new AsynchIOHandler(plainSock->getFullAddress(), f);

if (tcpNoDelay) {
plainSock->setTcpNoDelay();
QPID_LOG(info, "Set TCP_NODELAY on connection to " << plainSock->getPeerAddress());
}

if (isClient) {
async->setClient();
}
AsynchIO* aio = AsynchIO::create
(*plainSock,
boost::bind(&AsynchIOHandler::readbuff, async, _1, _2),
boost::bind(&AsynchIOHandler::eof, async, _1),
boost::bind(&AsynchIOHandler::disconnect, async, _1),
boost::bind(&AsynchIOHandler::closedSocket, async, _1, _2),
boost::bind(&AsynchIOHandler::nobuffs, async, _1),
boost::bind(&AsynchIOHandler::idle, async, _1));

async->init(aio, 4);
aio->start(poller);
} else {
SslProtocolFactory::established(poller, s, f, isClient);
}
}

void SslProtocolFactory::connect(
Poller::shared_ptr poller,
const std::string& host, int16_t port,
Expand Down Expand Up @@ -188,4 +248,20 @@ bool SslProtocolFactory::supports(const std::string& capability)
return s == SSL;
}

void SslOptionalProtocolFactory::accept(Poller::shared_ptr poller,
ConnectionCodec::Factory* fact) {
acceptor.reset(
new qpid::sys::ssl::SslAcceptor(*listener,
boost::bind(&SslOptionalProtocolFactory::established, this, poller, _1, fact, false)));
acceptor->start(poller);
}


bool SslOptionalProtocolFactory::supports(const std::string& capability)
{
std::string s = capability;
transform(s.begin(), s.end(), s.begin(), tolower);
return s == SSL || s == "tcp";
}

}} // namespace qpid::sys
2 changes: 1 addition & 1 deletion qpid/cpp/src/qpid/sys/posix/AsynchIO.cpp
Expand Up @@ -118,7 +118,7 @@ void AsynchAcceptor::readable(DispatchHandle& h) {
// TODO: Currently we ignore the peers address, perhaps we should
// log it or use it for connection acceptance.
try {
s = socket.accept();
s = dynamic_cast<Socket*>(socket.accept());
if (s) {
acceptedCallback(*s);
} else {
Expand Down

0 comments on commit 7ed22e8

Please sign in to comment.