Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
19 commits
Select commit Hold shift + click to select a range
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
70 changes: 41 additions & 29 deletions src/realmd/AuthSocket.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@
#include "AuthCodes.h"
#include "PatchHandler.h"
#include "Util.h"
#include "IO/Timer/AsyncSystemTimer.h"

#ifdef USE_SENDGRID
#include "MailerService.h"
Expand Down Expand Up @@ -161,7 +162,6 @@ typedef struct AUTH_LOGON_PROOF_S

typedef struct AUTH_RECONNECT_PROOF_C
{
//uint8 cmd;
uint8 R1[16];
uint8 R2[20];
uint8 R3[20];
Expand Down Expand Up @@ -194,16 +194,33 @@ typedef struct AuthHandler
std::array<uint8, 16> VersionChallenge = { { 0xBA, 0xA3, 0x1E, 0x99, 0xA0, 0x0B, 0x21, 0x57, 0xFC, 0x37, 0x3F, 0xB3, 0x69, 0xCD, 0xD2, 0xF1 } };

// Accept the connection and set the s random value for SRP6 // TODO where is this SRP6 done?
AuthSocket::AuthSocket(SocketDescriptor const& socketDescriptor) : MaNGOS::AsyncSocket<AuthSocket>(socketDescriptor)
AuthSocket::AuthSocket(IO::Networking::SocketDescriptor const& socketDescriptor) : IO::Networking::AsyncSocket<AuthSocket>(socketDescriptor)
{
sLog.Out(LOG_BASIC, LOG_LVL_BASIC, "Accepting connection from '%s'", socketDescriptor.peerAddress.c_str());
}

void AuthSocket::Start()
{
if (int secs = sConfig.GetIntDefault("MaxSessionDuration", 300))
{
this->m_sessionDurationTimeout = sAsyncSystemTimer.ScheduleFunctionOnce(std::chrono::seconds(secs), [this]()
{
sLog.Out(LOG_BASIC, LOG_LVL_BASIC, "Connection has reached MaxSessionDuration. Closing socket...");
// It's correct that we capture _this_, since the timer will be canceled in destructor
this->CloseSocket();
});
}
ProcessIncomingData();
}

// Close patch file descriptor before leaving
AuthSocket::~AuthSocket()
{
if (m_patch != ACE_INVALID_HANDLE)
ACE_OS::close(m_patch);

if (m_sessionDurationTimeout)
m_sessionDurationTimeout->Cancel();
}

AccountTypes AuthSocket::GetSecurityOn(uint32 realmId) const
Expand All @@ -220,11 +237,11 @@ void AuthSocket::ProcessIncomingData()
std::shared_ptr<eAuthCmd> cmd = std::make_shared<eAuthCmd>();

sLog.Out(LOG_BASIC, LOG_LVL_DEBUG, "ProcessIncomingData() Reading... Ready for next opcode");
Read((char*)cmd.get(), sizeof(eAuthCmd), [self = shared_from_this(), cmd](MaNGOS::IO::NetworkError const& error) -> void
Read((char*)cmd.get(), sizeof(eAuthCmd), [self = shared_from_this(), cmd](IO::NetworkError const& error) -> void
{
if (error)
{
if (error.Error != MaNGOS::IO::NetworkError::ErrorType::SocketClosed)
if (error.Error != IO::NetworkError::ErrorType::SocketClosed)
sLog.Out(LOG_BASIC, LOG_LVL_ERROR, "[Auth] ProcessIncomingData Read(cmd) error");
return;
}
Expand Down Expand Up @@ -279,11 +296,6 @@ void AuthSocket::ProcessIncomingData()
});
}

void AuthSocket::Start()
{
ProcessIncomingData();
}

std::shared_ptr<ByteBuffer> AuthSocket::GenerateLogonProofResponse(Sha1Hash sha)
{
std::shared_ptr<ByteBuffer> pkt(new ByteBuffer());
Expand Down Expand Up @@ -334,7 +346,7 @@ void AuthSocket::_HandleLogonChallenge()
std::shared_ptr<sAuthLogonChallengeHeader> header = std::make_shared<sAuthLogonChallengeHeader>();

// Read the header first, to get the length of the remaining packet
Read((char*)header.get(), sizeof(sAuthLogonChallengeHeader), [self = shared_from_this(), header](MaNGOS::IO::NetworkError const& error) -> void
Read((char*)header.get(), sizeof(sAuthLogonChallengeHeader), [self = shared_from_this(), header](IO::NetworkError const& error) -> void
{
if (error)
{
Expand All @@ -356,7 +368,7 @@ void AuthSocket::_HandleLogonChallenge()

// Read the remaining of the packet
std::shared_ptr<sAuthLogonChallengeBody> body = std::make_shared<sAuthLogonChallengeBody>();
self->Read((char*)body.get(), actualBodySize, [self, header, body](MaNGOS::IO::NetworkError const& error)
self->Read((char*)body.get(), actualBodySize, [self, header, body](IO::NetworkError const& error)
{
if (error)
{
Expand Down Expand Up @@ -445,7 +457,7 @@ void AuthSocket::_HandleLogonChallenge()
sLog.Out(LOG_BASIC, LOG_LVL_BASIC, "[AuthChallenge] Account '%s' using IP '%s 'email address requires email verification - rejecting login", self->m_login.c_str(), self->get_remote_address().c_str());
*pkt << (uint8) WOW_FAIL_UNKNOWN_ACCOUNT;

self->Write(pkt, [self](MaNGOS::IO::NetworkError const& error) {
self->Write(pkt, [self](IO::NetworkError const& error) {
if (error)
sLog.Out(LOG_BASIC, LOG_LVL_ERROR, "_HandleLogonChallenge self->Write(): ERROR");
else
Expand Down Expand Up @@ -580,7 +592,7 @@ void AuthSocket::_HandleLogonChallenge()
}
}

self->Write(pkt, [self](MaNGOS::IO::NetworkError const& error)
self->Write(pkt, [self](IO::NetworkError const& error)
{
if (error)
sLog.Out(LOG_BASIC, LOG_LVL_ERROR, "_HandleLogonChallenge self->Write(): ERROR");
Expand All @@ -605,7 +617,7 @@ void AuthSocket::_HandleLogonProof()
expectedSize = sizeof(sAuthLogonProof_C_Pre_1_11_0);
}

Read((char*) lp.get(), expectedSize, [self = shared_from_this(), lp](MaNGOS::IO::NetworkError const& error)
Read((char*) lp.get(), expectedSize, [self = shared_from_this(), lp](IO::NetworkError const& error)
{
if (error)
{
Expand All @@ -624,7 +636,7 @@ void AuthSocket::_HandleLogonProof()
}

std::shared_ptr<PINData> pinData(new PINData());
self->Read((char*) pinData.get(), sizeof(PINData), [self, lp, pinData](MaNGOS::IO::NetworkError const& error)
self->Read((char*) pinData.get(), sizeof(PINData), [self, lp, pinData](IO::NetworkError const& error)
{
self->_HandleLogonProof__PostRecv(lp, pinData);
});
Expand Down Expand Up @@ -665,7 +677,7 @@ void AuthSocket::_HandleLogonProof__PostRecv_HandleInvalidVersion(std::shared_pt
*pkt << (uint8) WOW_FAIL_VERSION_INVALID;
sLog.Out(LOG_BASIC, LOG_LVL_DEBUG, "[AuthChallenge] %u is not a valid client version!", m_build);
sLog.Out(LOG_BASIC, LOG_LVL_DEBUG, "[AuthChallenge] Patch %s not found", tmp);
Write(pkt, [self = shared_from_this(), pkt](MaNGOS::IO::NetworkError const& error)
Write(pkt, [self = shared_from_this(), pkt](IO::NetworkError const& error)
{
if (error)
{
Expand Down Expand Up @@ -711,7 +723,7 @@ void AuthSocket::_HandleLogonProof__PostRecv_HandleInvalidVersion(std::shared_pt
// Set right status
m_status = STATUS_PATCH;

Write(pkt, [self = shared_from_this()](MaNGOS::IO::NetworkError const& error)
Write(pkt, [self = shared_from_this()](IO::NetworkError const& error)
{
self->ProcessIncomingData();
});
Expand Down Expand Up @@ -788,7 +800,7 @@ void AuthSocket::_HandleLogonProof__PostRecv(std::shared_ptr<sAuthLogonProof_C c
std::shared_ptr<ByteBuffer> pkt(new ByteBuffer());
*pkt << (uint8) CMD_AUTH_LOGON_PROOF;
*pkt << (uint8) WOW_FAIL_VERSION_INVALID;
Write(pkt, [self = shared_from_this()](MaNGOS::IO::NetworkError const& error)
Write(pkt, [self = shared_from_this()](IO::NetworkError const& error)
{
self->ProcessIncomingData();
});
Expand Down Expand Up @@ -817,7 +829,7 @@ void AuthSocket::_HandleLogonProof__PostRecv(std::shared_ptr<sAuthLogonProof_C c
std::shared_ptr<ByteBuffer> pkt(new ByteBuffer());
*pkt << (uint8) CMD_AUTH_LOGON_PROOF;
*pkt << (uint8) WOW_FAIL_DB_BUSY;
Write(pkt, [self = shared_from_this()](MaNGOS::IO::NetworkError const& error)
Write(pkt, [self = shared_from_this()](IO::NetworkError const& error)
{
self->ProcessIncomingData();
});
Expand Down Expand Up @@ -851,7 +863,7 @@ void AuthSocket::_HandleLogonProof__PostRecv(std::shared_ptr<sAuthLogonProof_C c
std::shared_ptr<ByteBuffer> pkt(new ByteBuffer());
*pkt << (uint8) CMD_AUTH_LOGON_PROOF;
*pkt << (uint8) WOW_FAIL_PARENTCONTROL;
Write(pkt, [self = shared_from_this()](MaNGOS::IO::NetworkError const& error)
Write(pkt, [self = shared_from_this()](IO::NetworkError const& error)
{
self->ProcessIncomingData();
});
Expand Down Expand Up @@ -881,7 +893,7 @@ void AuthSocket::_HandleLogonProof__PostRecv(std::shared_ptr<sAuthLogonProof_C c
std::shared_ptr<ByteBuffer> pkt = GenerateLogonProofResponse(sha);
m_status = STATUS_AUTHED;

Write(pkt, [self = shared_from_this()](MaNGOS::IO::NetworkError const& error)
Write(pkt, [self = shared_from_this()](IO::NetworkError const& error)
{
self->ProcessIncomingData();
});
Expand Down Expand Up @@ -937,7 +949,7 @@ void AuthSocket::_HandleLogonProof__PostRecv(std::shared_ptr<sAuthLogonProof_C c
*pkt << (uint8) 0;
*pkt << (uint8) 0;
}
Write(pkt, [self = shared_from_this()](MaNGOS::IO::NetworkError const& error)
Write(pkt, [self = shared_from_this()](IO::NetworkError const& error)
{
self->ProcessIncomingData();
});
Expand All @@ -952,7 +964,7 @@ void AuthSocket::_HandleReconnectChallenge()

// Read the header first, to get the length of the remaining packet
std::shared_ptr<sAuthLogonChallengeHeader> header = std::make_shared<sAuthLogonChallengeHeader>();
Read((char*)header.get(), sizeof(sAuthLogonChallengeHeader), [self = shared_from_this(), header](MaNGOS::IO::NetworkError const& error)
Read((char*)header.get(), sizeof(sAuthLogonChallengeHeader), [self = shared_from_this(), header](IO::NetworkError const& error)
{
if (error)
{
Expand All @@ -974,7 +986,7 @@ void AuthSocket::_HandleReconnectChallenge()

// Read the remaining of the packet
std::shared_ptr<sAuthLogonChallengeBody> body = std::make_shared<sAuthLogonChallengeBody>();
self->Read((char*)body.get(), actualBodySize, [self, header, body](MaNGOS::IO::NetworkError const& error)
self->Read((char*)body.get(), actualBodySize, [self, header, body](IO::NetworkError const& error)
{
if (error)
{
Expand Down Expand Up @@ -1046,7 +1058,7 @@ void AuthSocket::_HandleReconnectChallenge()
self->m_reconnectProof.SetRand(16 * 8);
pkt->append(self->m_reconnectProof.AsByteArray(16)); // 16 bytes random
pkt->append(VersionChallenge.data(), VersionChallenge.size());
self->Write(pkt, [self](MaNGOS::IO::NetworkError const& error)
self->Write(pkt, [self](IO::NetworkError const& error)
{
self->ProcessIncomingData();
});
Expand All @@ -1062,7 +1074,7 @@ void AuthSocket::_HandleReconnectProof()

// Read the packet
std::shared_ptr<sAuthReconnectProof_C> lp(new sAuthReconnectProof_C());
Read((char*) lp.get(), sizeof(sAuthReconnectProof_C), [self = shared_from_this(), lp](MaNGOS::IO::NetworkError const& error)
Read((char*) lp.get(), sizeof(sAuthReconnectProof_C), [self = shared_from_this(), lp](IO::NetworkError const& error)
{
if (error)
{
Expand Down Expand Up @@ -1098,7 +1110,7 @@ void AuthSocket::_HandleReconnectProof()
std::shared_ptr<ByteBuffer> pkt = std::make_shared<ByteBuffer>();
*pkt << uint8(CMD_AUTH_RECONNECT_PROOF);
*pkt << uint8(WOW_SUCCESS);
self->Write(pkt, [self](MaNGOS::IO::NetworkError const& error)
self->Write(pkt, [self](IO::NetworkError const& error)
{
self->ProcessIncomingData();
});
Expand All @@ -1121,7 +1133,7 @@ void AuthSocket::_HandleRealmList()
assert(this->m_accountId);

sLog.Out(LOG_BASIC, LOG_LVL_DEBUG, "Entering _HandleRealmList");
ReadSkip(4, [self = shared_from_this()](MaNGOS::IO::NetworkError const& error)
ReadSkip(4, [self = shared_from_this()](IO::NetworkError const& error)
{
if (error)
{
Expand Down Expand Up @@ -1156,7 +1168,7 @@ void AuthSocket::_HandleRealmList()
*pkt << (uint16)realmlistBuffer.size();
pkt->append(realmlistBuffer);

self->Write(pkt, [self](MaNGOS::IO::NetworkError const& error)
self->Write(pkt, [self](IO::NetworkError const& error)
{
self->ProcessIncomingData();
});
Expand Down
7 changes: 5 additions & 2 deletions src/realmd/AuthSocket.h
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@
#include "SRP6/SRP6.h"
#include "ByteBuffer.h"
#include "IO/Networking/AsyncSocket.h"
#include "IO/Timer/TimerHandle.h"

struct PINData
{
Expand All @@ -53,12 +54,12 @@ enum LockFlag
struct sAuthLogonProof_C;

// Handle login commands
class AuthSocket : public MaNGOS::AsyncSocket<AuthSocket>
class AuthSocket : public IO::Networking::AsyncSocket<AuthSocket>
{
public:
const static int s_BYTE_SIZE = 32;

explicit AuthSocket(SocketDescriptor const& clientAddress);
explicit AuthSocket(IO::Networking::SocketDescriptor const& clientAddress);
~AuthSocket();

void Start() final;
Expand Down Expand Up @@ -141,6 +142,8 @@ class AuthSocket : public MaNGOS::AsyncSocket<AuthSocket>
ACE_HANDLE m_patch = ACE_INVALID_HANDLE;

void InitPatch();

std::shared_ptr<IO::Timer::TimerHandle> m_sessionDurationTimeout;
};

#endif
Expand Down
7 changes: 7 additions & 0 deletions src/shared/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -103,6 +103,13 @@ set (shared_SRCS
IO/Networking/AsyncSocket.h
IO/Networking/NetworkError.h
IO/Networking/SocketDescriptor.h
IO/Multithreading/CreateThread.h
IO/Multithreading/CreateThread.cpp
IO/Timer/impl/windows/AsyncSystemTimer.h
IO/Timer/impl/windows/AsyncSystemTimer.cpp
IO/Timer/impl/windows/TimerHandle.h
IO/Timer/impl/windows/TimerHandle.cpp
IO/Timer/AsyncSystemTimer.h
)

if(USE_LIBCURL)
Expand Down
2 changes: 1 addition & 1 deletion src/shared/Database/Database.h
Original file line number Diff line number Diff line change
Expand Up @@ -227,7 +227,7 @@ class Database
/// Unless in Sync mode, the return value just gives you a hint whenever or not the statement was added to be async queue
bool Execute(char const* sql);
bool Execute(DbExecMode executionMode, char const* sql);
bool PExecute(DbExecMode executionMode, char const* format,...) ATTR_PRINTF(2,3);
bool PExecute(DbExecMode executionMode, char const* format,...) ATTR_PRINTF(3,4);
bool PExecute(char const* format,...) ATTR_PRINTF(2,3);

// Writes SQL commands to a LOG file (see mangosd.conf "LogSQL")
Expand Down
54 changes: 54 additions & 0 deletions src/shared/IO/Multithreading/CreateThread.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,54 @@
#include "CreateThread.h"

#if defined(WIN32)
#include <Windows.h>
#elif defined(__linux__)
#include <pthread.h>
#endif

std::thread IO::Multithreading::CreateThread(std::string const& name, std::function<void()> entryFunction)
{
return std::thread([name, entryFunction = std::move(entryFunction)]()
{
IO::Multithreading::RenameCurrentThread(name);
entryFunction();
});
}

void IO::Multithreading::RenameCurrentThread(std::string const& name)
{
#if defined(WIN32)
// Windows part taken from https://stackoverflow.com/a/23899379
// SetThreadDescription is only supported on >= Win10, that's why we are using this approach

const DWORD MS_VC_EXCEPTION=0x406D1388;
#pragma pack(push,8)
typedef struct tagTHREADNAME_INFO
{
DWORD dwType; // Must be 0x1000.
LPCSTR szName; // Pointer to name (in user addr space).
DWORD dwThreadID; // Thread ID (-1=caller thread).
DWORD dwFlags; // Reserved for future use, must be zero.
} THREADNAME_INFO;
#pragma pack(pop)

THREADNAME_INFO info;
info.dwType = 0x1000;
info.szName = name.c_str();
info.dwThreadID = GetCurrentThreadId();
info.dwFlags = 0;

__try
{
RaiseException( MS_VC_EXCEPTION, 0, sizeof(info)/sizeof(ULONG_PTR), (ULONG_PTR*)&info );
}
__except(EXCEPTION_EXECUTE_HANDLER)
{
}
#elif defined(__linux__)
pthread_setname_np(pthread_self(), name.c_str());
#else
// It's not too serisous if we cant rename a thread
#warning "IO::Multithreading::_renameThisThread not supported on your platform"
#endif
}
18 changes: 18 additions & 0 deletions src/shared/IO/Multithreading/CreateThread.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,18 @@
#ifndef MANGOS_CREATETHREAD_H
#define MANGOS_CREATETHREAD_H

#include <thread>
#include <functional>

namespace IO { namespace Multithreading {
/// Creates a new system thread that has a name attached to it.
/// Names are super useful when monitoring the utilization of each thread.
[[nodiscard("Use this return value to at least .join() or .detach() the thread")]]
std::thread CreateThread(std::string const& name, std::function<void()> entryFunction);

/// Will rename your current thread.
/// Names are super useful when monitoring the utilization of each thread.
void RenameCurrentThread(std::string const& name);
}} // namespace IO::Multithreading

#endif //MANGOS_CREATETHREAD_H
2 changes: 1 addition & 1 deletion src/shared/IO/Networking/AsyncServerListener.h
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
#ifdef WIN32
#include "./impl/windows/AsyncServerListener.h"
#else
#error "Mangos::IO::Networking not supported on your platform"
#error "IO::Networking not supported on your platform"
#endif

#endif //MANGOS_IO_NETWORKING_ASYNCSERVERLISTENER_H
Loading