Skip to content

Commit

Permalink
[TLS] Provide thread-safety when required to do so.
Browse files Browse the repository at this point in the history
  • Loading branch information
abh3 committed Jan 29, 2021
1 parent d185a99 commit 2431a27
Show file tree
Hide file tree
Showing 3 changed files with 88 additions and 11 deletions.
2 changes: 1 addition & 1 deletion src/Xrd/XrdLinkXeq.cc
Expand Up @@ -913,7 +913,7 @@ bool XrdLinkXeq::setTLS(bool enable, XrdTlsContext *ctx)
// We want to initialize TLS, do so now.
//
if (!ctx) ctx = tlsCtx;
eNote = tlsIO.Init(*ctx, PollInfo.FD, rwMode, hsMode, false, ID);
eNote = tlsIO.Init(*ctx, PollInfo.FD, rwMode, hsMode, false, false, ID);

// Check for errors
//
Expand Down
84 changes: 77 additions & 7 deletions src/XrdTls/XrdTlsSocket.cc
Expand Up @@ -31,6 +31,7 @@

#include "XrdNet/XrdNetAddrInfo.hh"
#include "XrdSys/XrdSysE2T.hh"
#include "XrdSys/XrdSysPthread.hh"
#include "XrdTls/XrdTlsContext.hh"
#include "XrdTls/XrdTlsNotary.hh"
#include "XrdTls/XrdTlsPeerCerts.hh"
Expand All @@ -49,6 +50,7 @@ struct XrdTlsSocketImpl
hsDone(false), fatal(0), isClient(false),
cOpts(0), cAttr(0), hsNoBlock(false) {}

XrdSysMutex sslMutex; //!< Mutex to serialize calls
XrdTlsContext *tlsctx; //!< Associated context object
SSL *ssl; //!< Associated SSL object
const char *traceID; //!< Trace identifier
Expand All @@ -60,6 +62,7 @@ struct XrdTlsSocketImpl
char cOpts; //!< Connection options
char cAttr; //!< Connection attributes
bool hsNoBlock; //!< Handshake handling nonblocking if true
bool isSerial; //!< True if calls must be serialized
};

/******************************************************************************/
Expand Down Expand Up @@ -131,13 +134,13 @@ XrdTlsSocket::XrdTlsSocket() : pImpl( new XrdTlsSocketImpl() )
XrdTlsSocket::XrdTlsSocket( XrdTlsContext &ctx, int sfd,
XrdTlsSocket::RW_Mode rwm,
XrdTlsSocket::HS_Mode hsm,
bool isClient )
bool isClient, bool serial )
: pImpl( new XrdTlsSocketImpl() )
{

// Simply initialize this object and throw an exception if it fails
//
const char *eMsg = Init(ctx, sfd, rwm, hsm, isClient);
const char *eMsg = Init(ctx, sfd, rwm, hsm, isClient, serial);
if (eMsg) throw std::invalid_argument( eMsg );
}

Expand Down Expand Up @@ -391,6 +394,12 @@ std::string XrdTlsSocket::Err2Text(int sslerr)

XrdTlsPeerCerts *XrdTlsSocket::getCerts(bool ver)
{
XrdSysMutexHelper mHelper;

// Serialize call if need be
//
if (pImpl->isSerial) mHelper.Lock(&(pImpl->sslMutex));

// If verified certs need to be returned, make sure the certs are verified
//
if (ver && SSL_get_verify_result(pImpl->ssl) != X509_V_OK) return 0;
Expand All @@ -409,7 +418,8 @@ XrdTlsPeerCerts *XrdTlsSocket::getCerts(bool ver)
const char *XrdTlsSocket::Init( XrdTlsContext &ctx, int sfd,
XrdTlsSocket::RW_Mode rwm,
XrdTlsSocket::HS_Mode hsm,
bool isClient, const char *tid )
bool isClient, bool serial,
const char *tid )
{
BIO *rbio, *wbio = 0;

Expand Down Expand Up @@ -438,6 +448,7 @@ const char *XrdTlsSocket::Init( XrdTlsContext &ctx, int sfd,
if (parms->opts & XrdTlsContext::dnsok) pImpl->cOpts |= DNSok;
pImpl->traceID = tid;
pImpl->isClient= isClient;
pImpl->isSerial= serial;

// Set the ssl object state to correspond to client or server type
//
Expand Down Expand Up @@ -518,7 +529,14 @@ const char *XrdTlsSocket::Init( XrdTlsContext &ctx, int sfd,

XrdTls::RC XrdTlsSocket::Peek( char *buffer, size_t size, int &bytesPeek )
{
int ssler;
XrdSysMutexHelper mHelper;
int ssler;

//------------------------------------------------------------------------
// Serialize call if need be
//------------------------------------------------------------------------

if (pImpl->isSerial) mHelper.Lock(&(pImpl->sslMutex));

//------------------------------------------------------------------------
// Return an error if this socket received a fatal error as OpenSSL will
Expand Down Expand Up @@ -555,7 +573,7 @@ XrdTls::RC XrdTlsSocket::Peek( char *buffer, size_t size, int &bytesPeek )

// If the caller is non-blocking, the return the issue. Otherwise, block.
//
if ((pImpl->hsNoBlock && NeedHandShake()) || !(pImpl->cAttr & rBlocking))
if ((pImpl->hsNoBlock && NeedHS()) || !(pImpl->cAttr & rBlocking))
return XrdTls::ssl2RC(ssler);

} while(Wait4OK(ssler == SSL_ERROR_WANT_READ));
Expand All @@ -571,13 +589,21 @@ XrdTls::RC XrdTlsSocket::Peek( char *buffer, size_t size, int &bytesPeek )

int XrdTlsSocket::Pending(bool any)
{
XrdSysMutexHelper mHelper;

//------------------------------------------------------------------------
// Return an error if this socket received a fatal error as OpenSSL will
// SEGV when called after such an error. So, return something reasonable.
//------------------------------------------------------------------------

if (pImpl->fatal) return 0;

//------------------------------------------------------------------------
// Serialize call if need be
//------------------------------------------------------------------------

if (pImpl->isSerial) mHelper.Lock(&(pImpl->sslMutex));

if (!any) return SSL_pending(pImpl->ssl);
#if OPENSSL_VERSION_NUMBER < 0x10100000L
return SSL_pending(pImpl->ssl) != 0;
Expand All @@ -593,8 +619,15 @@ int XrdTlsSocket::Pending(bool any)
XrdTls::RC XrdTlsSocket::Read( char *buffer, size_t size, int &bytesRead )
{
EPNAME("Read");
XrdSysMutexHelper mHelper;
int ssler;

//------------------------------------------------------------------------
// Serialize call if need be
//------------------------------------------------------------------------

if (pImpl->isSerial) mHelper.Lock(&(pImpl->sslMutex));

//------------------------------------------------------------------------
// Return an error if this socket received a fatal error as OpenSSL will
// SEGV when called after such an error.
Expand Down Expand Up @@ -636,7 +669,7 @@ XrdTls::RC XrdTlsSocket::Read( char *buffer, size_t size, int &bytesRead )
// If the caller is non-blocking for reads, return the issue. Otherwise,
// block for the caller.
//
if ((pImpl->hsNoBlock && NeedHandShake()) || !(pImpl->cAttr & rBlocking))
if ((pImpl->hsNoBlock && NeedHS()) || !(pImpl->cAttr & rBlocking))
return XrdTls::ssl2RC(ssler);

// Wait until we can read again.
Expand All @@ -662,13 +695,18 @@ void XrdTlsSocket::SetTraceID(const char *tid)
void XrdTlsSocket::Shutdown(XrdTlsSocket::SDType sdType)
{
EPNAME("Shutdown");
XrdSysMutexHelper mHelper;
const char *how;
int sdMode, rc;

// Make sure we have an ssl object.
//
if (pImpl->ssl == 0) return;

// While we do not need to technically serialize here, we're being conservative
//
if (pImpl->isSerial) mHelper.Lock(&(pImpl->sslMutex));

// Perform shutdown as needed. This is required before freeing the ssl object.
// If we previously encountered a SYSCALL or SSL error, shutdown is prohibited!
// The following code is patterned after code in the public TomCat server.
Expand Down Expand Up @@ -726,8 +764,15 @@ XrdTls::RC XrdTlsSocket::Write( const char *buffer, size_t size,
int &bytesWritten )
{
EPNAME("Write");
XrdSysMutexHelper mHelper;
int ssler;

//------------------------------------------------------------------------
// Serialize call if need be
//------------------------------------------------------------------------

if (pImpl->isSerial) mHelper.Lock(&(pImpl->sslMutex));

//------------------------------------------------------------------------
// Return an error if this socket received a fatal error as OpenSSL will
// SEGV when called after such an error.
Expand Down Expand Up @@ -769,7 +814,7 @@ XrdTls::RC XrdTlsSocket::Write( const char *buffer, size_t size,
// If the caller is non-blocking for reads, return the issue. Otherwise,
// block for the caller.
//
if ((pImpl->hsNoBlock && NeedHandShake()) || !(pImpl->cAttr & wBlocking))
if ((pImpl->hsNoBlock && NeedHS()) || !(pImpl->cAttr & wBlocking))
return XrdTls::ssl2RC(ssler);

// Wait unil the write can get restarted
Expand All @@ -785,23 +830,48 @@ XrdTls::RC XrdTlsSocket::Write( const char *buffer, size_t size,

bool XrdTlsSocket::NeedHandShake()
{
XrdSysMutexHelper mHelper;

//------------------------------------------------------------------------
// Return an error if this socket received a fatal error as OpenSSL will
// SEGV when called after such an error. So, return something reasonable.
// Technically, we don't need to serialize this because nothing get
// modified. We do so anyway out of abundance of caution.
//------------------------------------------------------------------------

if (pImpl->isSerial) mHelper.Lock(&(pImpl->sslMutex));
if (pImpl->fatal) return false;
pImpl->hsDone = bool( SSL_is_init_finished( pImpl->ssl ) );
return !pImpl->hsDone;
}

/******************************************************************************/
/* Private: N e e d H S */
/******************************************************************************/

bool XrdTlsSocket::NeedHS()
{
//------------------------------------------------------------------------
// The following code is identical to NeedHandshake() except that it does
// serialize the call because the caller already has done so. While we
// could use a recursive mutex the overhead in doing so is not worth it
// and it is only used for internal purposes.
//------------------------------------------------------------------------

if (pImpl->fatal) return false;
pImpl->hsDone = bool( SSL_is_init_finished( pImpl->ssl ) );
return !pImpl->hsDone;
}

/******************************************************************************/
/* V e r s i o n */
/******************************************************************************/

const char *XrdTlsSocket::Version()
{
// This call modifies nothing nor does it depend on modified data once the
// connection is esablished and doesn't need serialization.
//
return SSL_get_version(pImpl->ssl);
}

Expand Down
13 changes: 10 additions & 3 deletions src/XrdTls/XrdTlsSocket.hh
Expand Up @@ -70,10 +70,13 @@ enum HS_Mode
//! read/write calls should be handled.
//! @param isClient - When true initialize for client use.
//! Otherwise, initialize for server use.
//! @param serial - When true, only allows one thread to use the socket
//! at a time to prevent SSL errors (default). When false
//! does not add this protection, assuming caller does so.
//------------------------------------------------------------------------

XrdTlsSocket( XrdTlsContext &ctx, int sfd, RW_Mode rwm,
HS_Mode hsm, bool isClient );
XrdTlsSocket( XrdTlsContext &ctx, int sfd, RW_Mode rwm, HS_Mode hsm,
bool isClient, bool serial=true );

//------------------------------------------------------------------------
//! Constructor - reserves space for a TLS I/O wrapper. Use the Init()
Expand Down Expand Up @@ -147,6 +150,9 @@ XrdTlsPeerCerts *getCerts(bool ver=true);
//! read/write calls should be handled.
//! @param isClient - When true initialize for client use.
//! Otherwise, initialize for server use.
//! @param serial - When true, only allows one thread to use the socket
//! at a time to prevent SSL errors (default). When false
//! does not add this protection, assuming caller does so.
//! @param tid - Trace identifier to appear in messages. The value must
//! have the same lifetime as this object.
//!
Expand All @@ -157,7 +163,7 @@ XrdTlsPeerCerts *getCerts(bool ver=true);
//------------------------------------------------------------------------

const char *Init( XrdTlsContext &ctx, int sfd, RW_Mode rwm, HS_Mode hsm,
bool isClient, const char *tid="" );
bool isClient, bool serial=true, const char *tid="" );

//------------------------------------------------------------------------
//! Peek at the TLS connection data. If necessary, a handshake will be done.
Expand Down Expand Up @@ -251,6 +257,7 @@ private:
void AcceptEMsg(std::string *eWhy, const char *reason);
int Diagnose(const char *what, int sslrc, int tcode);
std::string Err2Text(int sslerr);
bool NeedHS();
bool Wait4OK(bool wantRead);

XrdTlsSocketImpl *pImpl;
Expand Down

0 comments on commit 2431a27

Please sign in to comment.