diff --git a/src/realmd/AuthSocket.cpp b/src/realmd/AuthSocket.cpp index 221df4912f7..87881027d3c 100644 --- a/src/realmd/AuthSocket.cpp +++ b/src/realmd/AuthSocket.cpp @@ -34,6 +34,7 @@ #include "AuthCodes.h" #include "PatchHandler.h" #include "Util.h" +#include "IO/Timer/AsyncSystemTimer.h" #ifdef USE_SENDGRID #include "MailerService.h" @@ -161,7 +162,6 @@ typedef struct AUTH_LOGON_PROOF_S typedef struct AUTH_RECONNECT_PROOF_C { - //uint8 cmd; uint8 R1[16]; uint8 R2[20]; uint8 R3[20]; @@ -194,16 +194,33 @@ typedef struct AuthHandler std::array VersionChallenge = { { 0xBA, 0xA3, 0x1E, 0x99, 0xA0, 0x0B, 0x21, 0x57, 0xFC, 0x37, 0x3F, 0xB3, 0x69, 0xCD, 0xD2, 0xF1 } }; // Accept the connection and set the s random value for SRP6 // TODO where is this SRP6 done? -AuthSocket::AuthSocket(SocketDescriptor const& socketDescriptor) : MaNGOS::AsyncSocket(socketDescriptor) +AuthSocket::AuthSocket(IO::Networking::SocketDescriptor const& socketDescriptor) : IO::Networking::AsyncSocket(socketDescriptor) { sLog.Out(LOG_BASIC, LOG_LVL_BASIC, "Accepting connection from '%s'", socketDescriptor.peerAddress.c_str()); } +void AuthSocket::Start() +{ + if (int secs = sConfig.GetIntDefault("MaxSessionDuration", 300)) + { + this->m_sessionDurationTimeout = sAsyncSystemTimer.ScheduleFunctionOnce(std::chrono::seconds(secs), [this]() + { + sLog.Out(LOG_BASIC, LOG_LVL_BASIC, "Connection has reached MaxSessionDuration. Closing socket..."); + // It's correct that we capture _this_, since the timer will be canceled in destructor + this->CloseSocket(); + }); + } + ProcessIncomingData(); +} + // Close patch file descriptor before leaving AuthSocket::~AuthSocket() { if (m_patch != ACE_INVALID_HANDLE) ACE_OS::close(m_patch); + + if (m_sessionDurationTimeout) + m_sessionDurationTimeout->Cancel(); } AccountTypes AuthSocket::GetSecurityOn(uint32 realmId) const @@ -220,11 +237,11 @@ void AuthSocket::ProcessIncomingData() std::shared_ptr cmd = std::make_shared(); sLog.Out(LOG_BASIC, LOG_LVL_DEBUG, "ProcessIncomingData() Reading... Ready for next opcode"); - Read((char*)cmd.get(), sizeof(eAuthCmd), [self = shared_from_this(), cmd](MaNGOS::IO::NetworkError const& error) -> void + Read((char*)cmd.get(), sizeof(eAuthCmd), [self = shared_from_this(), cmd](IO::NetworkError const& error) -> void { if (error) { - if (error.Error != MaNGOS::IO::NetworkError::ErrorType::SocketClosed) + if (error.Error != IO::NetworkError::ErrorType::SocketClosed) sLog.Out(LOG_BASIC, LOG_LVL_ERROR, "[Auth] ProcessIncomingData Read(cmd) error"); return; } @@ -279,11 +296,6 @@ void AuthSocket::ProcessIncomingData() }); } -void AuthSocket::Start() -{ - ProcessIncomingData(); -} - std::shared_ptr AuthSocket::GenerateLogonProofResponse(Sha1Hash sha) { std::shared_ptr pkt(new ByteBuffer()); @@ -334,7 +346,7 @@ void AuthSocket::_HandleLogonChallenge() std::shared_ptr header = std::make_shared(); // Read the header first, to get the length of the remaining packet - Read((char*)header.get(), sizeof(sAuthLogonChallengeHeader), [self = shared_from_this(), header](MaNGOS::IO::NetworkError const& error) -> void + Read((char*)header.get(), sizeof(sAuthLogonChallengeHeader), [self = shared_from_this(), header](IO::NetworkError const& error) -> void { if (error) { @@ -356,7 +368,7 @@ void AuthSocket::_HandleLogonChallenge() // Read the remaining of the packet std::shared_ptr body = std::make_shared(); - self->Read((char*)body.get(), actualBodySize, [self, header, body](MaNGOS::IO::NetworkError const& error) + self->Read((char*)body.get(), actualBodySize, [self, header, body](IO::NetworkError const& error) { if (error) { @@ -445,7 +457,7 @@ void AuthSocket::_HandleLogonChallenge() sLog.Out(LOG_BASIC, LOG_LVL_BASIC, "[AuthChallenge] Account '%s' using IP '%s 'email address requires email verification - rejecting login", self->m_login.c_str(), self->get_remote_address().c_str()); *pkt << (uint8) WOW_FAIL_UNKNOWN_ACCOUNT; - self->Write(pkt, [self](MaNGOS::IO::NetworkError const& error) { + self->Write(pkt, [self](IO::NetworkError const& error) { if (error) sLog.Out(LOG_BASIC, LOG_LVL_ERROR, "_HandleLogonChallenge self->Write(): ERROR"); else @@ -580,7 +592,7 @@ void AuthSocket::_HandleLogonChallenge() } } - self->Write(pkt, [self](MaNGOS::IO::NetworkError const& error) + self->Write(pkt, [self](IO::NetworkError const& error) { if (error) sLog.Out(LOG_BASIC, LOG_LVL_ERROR, "_HandleLogonChallenge self->Write(): ERROR"); @@ -605,7 +617,7 @@ void AuthSocket::_HandleLogonProof() expectedSize = sizeof(sAuthLogonProof_C_Pre_1_11_0); } - Read((char*) lp.get(), expectedSize, [self = shared_from_this(), lp](MaNGOS::IO::NetworkError const& error) + Read((char*) lp.get(), expectedSize, [self = shared_from_this(), lp](IO::NetworkError const& error) { if (error) { @@ -624,7 +636,7 @@ void AuthSocket::_HandleLogonProof() } std::shared_ptr pinData(new PINData()); - self->Read((char*) pinData.get(), sizeof(PINData), [self, lp, pinData](MaNGOS::IO::NetworkError const& error) + self->Read((char*) pinData.get(), sizeof(PINData), [self, lp, pinData](IO::NetworkError const& error) { self->_HandleLogonProof__PostRecv(lp, pinData); }); @@ -665,7 +677,7 @@ void AuthSocket::_HandleLogonProof__PostRecv_HandleInvalidVersion(std::shared_pt *pkt << (uint8) WOW_FAIL_VERSION_INVALID; sLog.Out(LOG_BASIC, LOG_LVL_DEBUG, "[AuthChallenge] %u is not a valid client version!", m_build); sLog.Out(LOG_BASIC, LOG_LVL_DEBUG, "[AuthChallenge] Patch %s not found", tmp); - Write(pkt, [self = shared_from_this(), pkt](MaNGOS::IO::NetworkError const& error) + Write(pkt, [self = shared_from_this(), pkt](IO::NetworkError const& error) { if (error) { @@ -711,7 +723,7 @@ void AuthSocket::_HandleLogonProof__PostRecv_HandleInvalidVersion(std::shared_pt // Set right status m_status = STATUS_PATCH; - Write(pkt, [self = shared_from_this()](MaNGOS::IO::NetworkError const& error) + Write(pkt, [self = shared_from_this()](IO::NetworkError const& error) { self->ProcessIncomingData(); }); @@ -788,7 +800,7 @@ void AuthSocket::_HandleLogonProof__PostRecv(std::shared_ptr pkt(new ByteBuffer()); *pkt << (uint8) CMD_AUTH_LOGON_PROOF; *pkt << (uint8) WOW_FAIL_VERSION_INVALID; - Write(pkt, [self = shared_from_this()](MaNGOS::IO::NetworkError const& error) + Write(pkt, [self = shared_from_this()](IO::NetworkError const& error) { self->ProcessIncomingData(); }); @@ -817,7 +829,7 @@ void AuthSocket::_HandleLogonProof__PostRecv(std::shared_ptr pkt(new ByteBuffer()); *pkt << (uint8) CMD_AUTH_LOGON_PROOF; *pkt << (uint8) WOW_FAIL_DB_BUSY; - Write(pkt, [self = shared_from_this()](MaNGOS::IO::NetworkError const& error) + Write(pkt, [self = shared_from_this()](IO::NetworkError const& error) { self->ProcessIncomingData(); }); @@ -851,7 +863,7 @@ void AuthSocket::_HandleLogonProof__PostRecv(std::shared_ptr pkt(new ByteBuffer()); *pkt << (uint8) CMD_AUTH_LOGON_PROOF; *pkt << (uint8) WOW_FAIL_PARENTCONTROL; - Write(pkt, [self = shared_from_this()](MaNGOS::IO::NetworkError const& error) + Write(pkt, [self = shared_from_this()](IO::NetworkError const& error) { self->ProcessIncomingData(); }); @@ -881,7 +893,7 @@ void AuthSocket::_HandleLogonProof__PostRecv(std::shared_ptr pkt = GenerateLogonProofResponse(sha); m_status = STATUS_AUTHED; - Write(pkt, [self = shared_from_this()](MaNGOS::IO::NetworkError const& error) + Write(pkt, [self = shared_from_this()](IO::NetworkError const& error) { self->ProcessIncomingData(); }); @@ -937,7 +949,7 @@ void AuthSocket::_HandleLogonProof__PostRecv(std::shared_ptrProcessIncomingData(); }); @@ -952,7 +964,7 @@ void AuthSocket::_HandleReconnectChallenge() // Read the header first, to get the length of the remaining packet std::shared_ptr header = std::make_shared(); - Read((char*)header.get(), sizeof(sAuthLogonChallengeHeader), [self = shared_from_this(), header](MaNGOS::IO::NetworkError const& error) + Read((char*)header.get(), sizeof(sAuthLogonChallengeHeader), [self = shared_from_this(), header](IO::NetworkError const& error) { if (error) { @@ -974,7 +986,7 @@ void AuthSocket::_HandleReconnectChallenge() // Read the remaining of the packet std::shared_ptr body = std::make_shared(); - self->Read((char*)body.get(), actualBodySize, [self, header, body](MaNGOS::IO::NetworkError const& error) + self->Read((char*)body.get(), actualBodySize, [self, header, body](IO::NetworkError const& error) { if (error) { @@ -1046,7 +1058,7 @@ void AuthSocket::_HandleReconnectChallenge() self->m_reconnectProof.SetRand(16 * 8); pkt->append(self->m_reconnectProof.AsByteArray(16)); // 16 bytes random pkt->append(VersionChallenge.data(), VersionChallenge.size()); - self->Write(pkt, [self](MaNGOS::IO::NetworkError const& error) + self->Write(pkt, [self](IO::NetworkError const& error) { self->ProcessIncomingData(); }); @@ -1062,7 +1074,7 @@ void AuthSocket::_HandleReconnectProof() // Read the packet std::shared_ptr lp(new sAuthReconnectProof_C()); - Read((char*) lp.get(), sizeof(sAuthReconnectProof_C), [self = shared_from_this(), lp](MaNGOS::IO::NetworkError const& error) + Read((char*) lp.get(), sizeof(sAuthReconnectProof_C), [self = shared_from_this(), lp](IO::NetworkError const& error) { if (error) { @@ -1098,7 +1110,7 @@ void AuthSocket::_HandleReconnectProof() std::shared_ptr pkt = std::make_shared(); *pkt << uint8(CMD_AUTH_RECONNECT_PROOF); *pkt << uint8(WOW_SUCCESS); - self->Write(pkt, [self](MaNGOS::IO::NetworkError const& error) + self->Write(pkt, [self](IO::NetworkError const& error) { self->ProcessIncomingData(); }); @@ -1121,7 +1133,7 @@ void AuthSocket::_HandleRealmList() assert(this->m_accountId); sLog.Out(LOG_BASIC, LOG_LVL_DEBUG, "Entering _HandleRealmList"); - ReadSkip(4, [self = shared_from_this()](MaNGOS::IO::NetworkError const& error) + ReadSkip(4, [self = shared_from_this()](IO::NetworkError const& error) { if (error) { @@ -1156,7 +1168,7 @@ void AuthSocket::_HandleRealmList() *pkt << (uint16)realmlistBuffer.size(); pkt->append(realmlistBuffer); - self->Write(pkt, [self](MaNGOS::IO::NetworkError const& error) + self->Write(pkt, [self](IO::NetworkError const& error) { self->ProcessIncomingData(); }); diff --git a/src/realmd/AuthSocket.h b/src/realmd/AuthSocket.h index d87c97cd27d..a5f03265eee 100644 --- a/src/realmd/AuthSocket.h +++ b/src/realmd/AuthSocket.h @@ -32,6 +32,7 @@ #include "SRP6/SRP6.h" #include "ByteBuffer.h" #include "IO/Networking/AsyncSocket.h" +#include "IO/Timer/TimerHandle.h" struct PINData { @@ -53,12 +54,12 @@ enum LockFlag struct sAuthLogonProof_C; // Handle login commands -class AuthSocket : public MaNGOS::AsyncSocket +class AuthSocket : public IO::Networking::AsyncSocket { public: const static int s_BYTE_SIZE = 32; - explicit AuthSocket(SocketDescriptor const& clientAddress); + explicit AuthSocket(IO::Networking::SocketDescriptor const& clientAddress); ~AuthSocket(); void Start() final; @@ -141,6 +142,8 @@ class AuthSocket : public MaNGOS::AsyncSocket ACE_HANDLE m_patch = ACE_INVALID_HANDLE; void InitPatch(); + + std::shared_ptr m_sessionDurationTimeout; }; #endif diff --git a/src/shared/CMakeLists.txt b/src/shared/CMakeLists.txt index f976d44e342..27d16950d15 100644 --- a/src/shared/CMakeLists.txt +++ b/src/shared/CMakeLists.txt @@ -103,6 +103,13 @@ set (shared_SRCS IO/Networking/AsyncSocket.h IO/Networking/NetworkError.h IO/Networking/SocketDescriptor.h + IO/Multithreading/CreateThread.h + IO/Multithreading/CreateThread.cpp + IO/Timer/impl/windows/AsyncSystemTimer.h + IO/Timer/impl/windows/AsyncSystemTimer.cpp + IO/Timer/impl/windows/TimerHandle.h + IO/Timer/impl/windows/TimerHandle.cpp + IO/Timer/AsyncSystemTimer.h ) if(USE_LIBCURL) diff --git a/src/shared/Database/Database.h b/src/shared/Database/Database.h index 4d764c52bd5..0f9d197a659 100644 --- a/src/shared/Database/Database.h +++ b/src/shared/Database/Database.h @@ -227,7 +227,7 @@ class Database /// Unless in Sync mode, the return value just gives you a hint whenever or not the statement was added to be async queue bool Execute(char const* sql); bool Execute(DbExecMode executionMode, char const* sql); - bool PExecute(DbExecMode executionMode, char const* format,...) ATTR_PRINTF(2,3); + bool PExecute(DbExecMode executionMode, char const* format,...) ATTR_PRINTF(3,4); bool PExecute(char const* format,...) ATTR_PRINTF(2,3); // Writes SQL commands to a LOG file (see mangosd.conf "LogSQL") diff --git a/src/shared/IO/Multithreading/CreateThread.cpp b/src/shared/IO/Multithreading/CreateThread.cpp new file mode 100644 index 00000000000..0f1cec05eee --- /dev/null +++ b/src/shared/IO/Multithreading/CreateThread.cpp @@ -0,0 +1,54 @@ +#include "CreateThread.h" + +#if defined(WIN32) +#include +#elif defined(__linux__) +#include +#endif + +std::thread IO::Multithreading::CreateThread(std::string const& name, std::function entryFunction) +{ + return std::thread([name, entryFunction = std::move(entryFunction)]() + { + IO::Multithreading::RenameCurrentThread(name); + entryFunction(); + }); +} + +void IO::Multithreading::RenameCurrentThread(std::string const& name) +{ +#if defined(WIN32) + // Windows part taken from https://stackoverflow.com/a/23899379 + // SetThreadDescription is only supported on >= Win10, that's why we are using this approach + + const DWORD MS_VC_EXCEPTION=0x406D1388; +#pragma pack(push,8) + typedef struct tagTHREADNAME_INFO + { + DWORD dwType; // Must be 0x1000. + LPCSTR szName; // Pointer to name (in user addr space). + DWORD dwThreadID; // Thread ID (-1=caller thread). + DWORD dwFlags; // Reserved for future use, must be zero. + } THREADNAME_INFO; +#pragma pack(pop) + + THREADNAME_INFO info; + info.dwType = 0x1000; + info.szName = name.c_str(); + info.dwThreadID = GetCurrentThreadId(); + info.dwFlags = 0; + + __try + { + RaiseException( MS_VC_EXCEPTION, 0, sizeof(info)/sizeof(ULONG_PTR), (ULONG_PTR*)&info ); + } + __except(EXCEPTION_EXECUTE_HANDLER) + { + } +#elif defined(__linux__) + pthread_setname_np(pthread_self(), name.c_str()); +#else + // It's not too serisous if we cant rename a thread + #warning "IO::Multithreading::_renameThisThread not supported on your platform" +#endif +} diff --git a/src/shared/IO/Multithreading/CreateThread.h b/src/shared/IO/Multithreading/CreateThread.h new file mode 100644 index 00000000000..e64dfd3a57b --- /dev/null +++ b/src/shared/IO/Multithreading/CreateThread.h @@ -0,0 +1,18 @@ +#ifndef MANGOS_CREATETHREAD_H +#define MANGOS_CREATETHREAD_H + +#include +#include + +namespace IO { namespace Multithreading { + /// Creates a new system thread that has a name attached to it. + /// Names are super useful when monitoring the utilization of each thread. + [[nodiscard("Use this return value to at least .join() or .detach() the thread")]] + std::thread CreateThread(std::string const& name, std::function entryFunction); + + /// Will rename your current thread. + /// Names are super useful when monitoring the utilization of each thread. + void RenameCurrentThread(std::string const& name); +}} // namespace IO::Multithreading + +#endif //MANGOS_CREATETHREAD_H diff --git a/src/shared/IO/Networking/AsyncServerListener.h b/src/shared/IO/Networking/AsyncServerListener.h index dde9f0cb480..b9444215072 100644 --- a/src/shared/IO/Networking/AsyncServerListener.h +++ b/src/shared/IO/Networking/AsyncServerListener.h @@ -4,7 +4,7 @@ #ifdef WIN32 #include "./impl/windows/AsyncServerListener.h" #else -#error "Mangos::IO::Networking not supported on your platform" +#error "IO::Networking not supported on your platform" #endif #endif //MANGOS_IO_NETWORKING_ASYNCSERVERLISTENER_H diff --git a/src/shared/IO/Networking/AsyncSocket.h b/src/shared/IO/Networking/AsyncSocket.h index f8ef196faeb..bdc841b5c63 100644 --- a/src/shared/IO/Networking/AsyncSocket.h +++ b/src/shared/IO/Networking/AsyncSocket.h @@ -4,7 +4,7 @@ #ifdef WIN32 #include "./impl/windows/AsyncSocket.h" #else -#error "Mangos::IO::Networking not supported on your platform" +#error "IO::Networking not supported on your platform" #endif #endif //MANGOS_IO_NETWORKING_ASYNCSOCKET_H diff --git a/src/shared/IO/Networking/NetworkError.h b/src/shared/IO/Networking/NetworkError.h index 3620c53df71..44978433327 100644 --- a/src/shared/IO/Networking/NetworkError.h +++ b/src/shared/IO/Networking/NetworkError.h @@ -3,7 +3,7 @@ #include -namespace MaNGOS { namespace IO +namespace IO { struct NetworkError { enum class ErrorType { @@ -23,6 +23,6 @@ namespace MaNGOS { namespace IO return "TODO, Error to String"; } }; -}} // namespace MaNGOS::IO +} // namespace IO #endif //MANGOS_IO_NETWORKING_NETWORKERROR_H diff --git a/src/shared/IO/Networking/SocketDescriptor.h b/src/shared/IO/Networking/SocketDescriptor.h index 495237612de..38a47144272 100644 --- a/src/shared/IO/Networking/SocketDescriptor.h +++ b/src/shared/IO/Networking/SocketDescriptor.h @@ -4,7 +4,7 @@ #ifdef WIN32 #include "./impl/windows/SocketDescriptor.h" #else -#error "Mangos::IO::Networking not supported on your platform" +#error "IO::Networking not supported on your platform" #endif #endif //MANGOS_IO_NETWORKING_SOCKETDESCRIPTOR_H diff --git a/src/shared/IO/Networking/impl/windows/AsyncServerListener.h b/src/shared/IO/Networking/impl/windows/AsyncServerListener.h index 5c3c439561e..4437cb72fb8 100644 --- a/src/shared/IO/Networking/impl/windows/AsyncServerListener.h +++ b/src/shared/IO/Networking/impl/windows/AsyncServerListener.h @@ -47,16 +47,9 @@ void AsyncServerListener::RunEventLoop(std::chrono::milliseconds bool booleanOkay = ::GetQueuedCompletionStatus(m_completionPort, &bytesWritten, &completionKey, reinterpret_cast(&task), maxBlockingDuration.count()); DWORD errorCode = ::GetLastError(); if (task) - { - task->OnComplete(errorCode); - } - - if (!booleanOkay) - { - if (errorCode != WAIT_TIMEOUT) - sLog.Out(LOG_NETWORK, LOG_LVL_ERROR, "[ERROR] ::GetQueuedCompletionStatus(...) Error: %u", errorCode); - return; - } + task->OnComplete(booleanOkay ? 0 : errorCode); + else if (errorCode != WAIT_TIMEOUT) + sLog.Out(LOG_NETWORK, LOG_LVL_ERROR, "[ERROR] ::GetQueuedCompletionStatus(...) Has no TASK!!! Error: %u", errorCode); } template @@ -151,7 +144,7 @@ void AsyncServerListener::StartAcceptOperation() std::string peerAddress(inet_ntoa(addrBuffer->peerAddress.sin_addr)); // inet_ntoa will "free" (reuse) the char* on its own delete addrBuffer; - SocketDescriptor socketDescriptor { peerAddress, peerSocket }; + IO::Networking::SocketDescriptor socketDescriptor { peerAddress, peerSocket }; std::shared_ptr client = std::make_shared(socketDescriptor); HandleAccept(client); diff --git a/src/shared/IO/Networking/impl/windows/AsyncSocket.h b/src/shared/IO/Networking/impl/windows/AsyncSocket.h index 690183e3e5f..fc07f982f15 100644 --- a/src/shared/IO/Networking/impl/windows/AsyncSocket.h +++ b/src/shared/IO/Networking/impl/windows/AsyncSocket.h @@ -11,7 +11,7 @@ #include "./SocketDescriptor.h" #include "./IocpOperationTask.h" -namespace MaNGOS { +namespace IO { namespace Networking { template class AsyncSocketListener; @@ -30,11 +30,11 @@ namespace MaNGOS { virtual void Start() = 0; - void Read(char* target, std::size_t size, std::function const& callback); - void ReadSkip(std::size_t skipSize, std::function const& callback); + void Read(char* target, std::size_t size, std::function const& callback); + void ReadSkip(std::size_t skipSize, std::function const& callback); - void Write(std::shared_ptr const> const& source, std::function const& callback); - void Write(std::shared_ptr const& source, std::function const& callback); + void Write(std::shared_ptr const> const& source, std::function const& callback); + void Write(std::shared_ptr const& source, std::function const& callback); void CloseSocket(); @@ -45,10 +45,10 @@ namespace MaNGOS { bool m_disconnectRequest = false; // Read = the target buffer to write the network stream to - std::function m_readCallback = nullptr; + std::function m_readCallback = nullptr; // Write = the source buffer from where to read to be able to write to the network stream - std::function m_writeCallback = nullptr; + std::function m_writeCallback = nullptr; std::shared_ptr m_writeSrcBufferDummyHolder_ByteBuffer = nullptr; // Optional. To keep the shared_ptr for the lifetime of the transfer std::shared_ptr const> m_writeSrcBufferDummyHolder_u8Vector = nullptr; // Optional. To keep the shared_ptr for the lifetime of the transfer }; @@ -68,16 +68,16 @@ namespace MaNGOS { } template - void AsyncSocket::Read(char* target, std::size_t size, std::function const& callback) + void AsyncSocket::Read(char* target, std::size_t size, std::function const& callback) { if (m_disconnectRequest) { - callback(MaNGOS::IO::NetworkError{MaNGOS::IO::NetworkError::ErrorType::SocketClosed}); + callback(IO::NetworkError{IO::NetworkError::ErrorType::SocketClosed}); return; } if (m_readCallback != nullptr) { // We already have a buffer. Just like ASIO, only one Read can be queued at the same time - callback(MaNGOS::IO::NetworkError{MaNGOS::IO::NetworkError::ErrorType::OnlyOneTransferPerDirectionAllowed}); + callback(IO::NetworkError{IO::NetworkError::ErrorType::OnlyOneTransferPerDirectionAllowed}); return; } m_readCallback = callback; @@ -100,7 +100,7 @@ namespace MaNGOS { self->CloseSocket(); auto tmpCallback = std::move(self->m_readCallback); delete task; - tmpCallback(MaNGOS::IO::NetworkError{MaNGOS::IO::NetworkError::ErrorType::SocketClosed}); + tmpCallback(IO::NetworkError{IO::NetworkError::ErrorType::SocketClosed}); return; } @@ -120,7 +120,7 @@ namespace MaNGOS { sLog.Out(LOG_NETWORK, LOG_LVL_ERROR, "[ERROR] ::WSARecv(...) Error: %u", err); auto tmpCallback = std::move(self->m_readCallback); delete task; - tmpCallback(MaNGOS::IO::NetworkError{MaNGOS::IO::NetworkError::ErrorType::InternalError}); + tmpCallback(IO::NetworkError{IO::NetworkError::ErrorType::InternalError}); return; } } @@ -129,7 +129,7 @@ namespace MaNGOS { { auto tmpCallback = std::move(self->m_readCallback); delete task; - tmpCallback(MaNGOS::IO::NetworkError{MaNGOS::IO::NetworkError::ErrorType::NoError}); + tmpCallback(IO::NetworkError{IO::NetworkError::ErrorType::NoError}); } }); @@ -143,18 +143,18 @@ namespace MaNGOS { sLog.Out(LOG_NETWORK, LOG_LVL_ERROR, "[ERROR] ::WSARecv(...) Error: %u", err); auto tmpCallback = std::move(this->m_readCallback); delete task; - tmpCallback(MaNGOS::IO::NetworkError{MaNGOS::IO::NetworkError::ErrorType::InternalError}); + tmpCallback(IO::NetworkError{IO::NetworkError::ErrorType::InternalError}); return; } } } template - void AsyncSocket::ReadSkip(std::size_t skipSize, std::function const& callback) + void AsyncSocket::ReadSkip(std::size_t skipSize, std::function const& callback) { std::shared_ptr> skipBuffer(new std::vector()); skipBuffer->resize(skipSize); - Read((char*)skipBuffer->data(), skipSize, [skipBuffer, callback](MaNGOS::IO::NetworkError const& error) + Read((char*)skipBuffer->data(), skipSize, [skipBuffer, callback](IO::NetworkError const& error) { // KEEP skipBuffer in scope! // Do not remove skipBuffer before Read() is done, since we are transferring into it via async IO @@ -166,16 +166,16 @@ namespace MaNGOS { /// Warning using this function will NOT copy the buffer, dont overwrite it unless callback is triggered! template - void AsyncSocket::Write(std::shared_ptr const> const& source, std::function const& callback) + void AsyncSocket::Write(std::shared_ptr const> const& source, std::function const& callback) { if (m_disconnectRequest) { - callback(MaNGOS::IO::NetworkError{MaNGOS::IO::NetworkError::ErrorType::SocketClosed}); + callback(IO::NetworkError{IO::NetworkError::ErrorType::SocketClosed}); return; } if (m_writeCallback != nullptr) { // We already have a buffer. Just like ASIO, only one Write can be queued at the same time - callback(MaNGOS::IO::NetworkError{MaNGOS::IO::NetworkError::ErrorType::OnlyOneTransferPerDirectionAllowed}); + callback(IO::NetworkError{IO::NetworkError::ErrorType::OnlyOneTransferPerDirectionAllowed}); return; } m_writeCallback = callback; @@ -199,7 +199,7 @@ namespace MaNGOS { auto tmpCallback = std::move(self->m_writeCallback); self->m_writeSrcBufferDummyHolder_u8Vector = nullptr; delete task; - tmpCallback(MaNGOS::IO::NetworkError{MaNGOS::IO::NetworkError::ErrorType::SocketClosed}); + tmpCallback(IO::NetworkError{IO::NetworkError::ErrorType::SocketClosed}); return; } @@ -220,7 +220,7 @@ namespace MaNGOS { auto tmpCallback = std::move(self->m_writeCallback); self->m_writeSrcBufferDummyHolder_u8Vector = nullptr; delete task; - tmpCallback(MaNGOS::IO::NetworkError{MaNGOS::IO::NetworkError::ErrorType::InternalError}); + tmpCallback(IO::NetworkError{IO::NetworkError::ErrorType::InternalError}); return; } } @@ -230,7 +230,7 @@ namespace MaNGOS { auto tmpCallback = std::move(self->m_writeCallback); self->m_writeSrcBufferDummyHolder_u8Vector = nullptr; delete task; - tmpCallback(MaNGOS::IO::NetworkError{MaNGOS::IO::NetworkError::ErrorType::NoError}); + tmpCallback(IO::NetworkError{IO::NetworkError::ErrorType::NoError}); } }); @@ -245,7 +245,7 @@ namespace MaNGOS { auto tmpCallback = std::move(this->m_writeCallback); this->m_writeSrcBufferDummyHolder_u8Vector = nullptr; delete task; - tmpCallback(MaNGOS::IO::NetworkError{MaNGOS::IO::NetworkError::ErrorType::InternalError}); + tmpCallback(IO::NetworkError{IO::NetworkError::ErrorType::InternalError}); return; } } @@ -253,16 +253,16 @@ namespace MaNGOS { /// Warning using this function will NOT copy the buffer, dont overwrite it unless callback is triggered! template - void AsyncSocket::Write(std::shared_ptr const& source, std::function const& callback) + void AsyncSocket::Write(std::shared_ptr const& source, std::function const& callback) { if (m_disconnectRequest) { - callback(MaNGOS::IO::NetworkError{MaNGOS::IO::NetworkError::ErrorType::SocketClosed}); + callback(IO::NetworkError{IO::NetworkError::ErrorType::SocketClosed}); return; } if (m_writeCallback != nullptr) { // We already have a buffer. Just like ASIO, only one Write can be queued at the same time - callback(MaNGOS::IO::NetworkError{MaNGOS::IO::NetworkError::ErrorType::OnlyOneTransferPerDirectionAllowed}); + callback(IO::NetworkError{IO::NetworkError::ErrorType::OnlyOneTransferPerDirectionAllowed}); return; } m_writeCallback = callback; @@ -286,7 +286,7 @@ namespace MaNGOS { auto tmpCallback = std::move(self->m_writeCallback); self->m_writeSrcBufferDummyHolder_ByteBuffer = nullptr; delete task; - tmpCallback(MaNGOS::IO::NetworkError{MaNGOS::IO::NetworkError::ErrorType::SocketClosed}); + tmpCallback(IO::NetworkError{IO::NetworkError::ErrorType::SocketClosed}); return; } @@ -307,7 +307,7 @@ namespace MaNGOS { auto tmpCallback = std::move(self->m_writeCallback); self->m_writeSrcBufferDummyHolder_ByteBuffer = nullptr; delete task; - tmpCallback(MaNGOS::IO::NetworkError{MaNGOS::IO::NetworkError::ErrorType::InternalError}); + tmpCallback(IO::NetworkError{IO::NetworkError::ErrorType::InternalError}); return; } } @@ -317,7 +317,7 @@ namespace MaNGOS { auto tmpCallback = std::move(self->m_writeCallback); self->m_writeSrcBufferDummyHolder_ByteBuffer = nullptr; delete task; - tmpCallback(MaNGOS::IO::NetworkError{MaNGOS::IO::NetworkError::ErrorType::NoError}); + tmpCallback(IO::NetworkError{IO::NetworkError::ErrorType::NoError}); } }); @@ -332,7 +332,7 @@ namespace MaNGOS { auto tmpCallback = std::move(this->m_writeCallback); this->m_writeSrcBufferDummyHolder_u8Vector = nullptr; delete task; - tmpCallback(MaNGOS::IO::NetworkError{MaNGOS::IO::NetworkError::ErrorType::InternalError}); + tmpCallback(IO::NetworkError{IO::NetworkError::ErrorType::InternalError}); return; } } @@ -349,7 +349,7 @@ namespace MaNGOS { ::closesocket(m_socket.nativeSocket); } -} +}} // namespace IO::Networking #endif //MANGOS_IO_NETWORKING_WIN32_ASYNCSOCKET_H diff --git a/src/shared/IO/Networking/impl/windows/SocketDescriptor.h b/src/shared/IO/Networking/impl/windows/SocketDescriptor.h index 66baaf53ba1..02fd0a8464d 100644 --- a/src/shared/IO/Networking/impl/windows/SocketDescriptor.h +++ b/src/shared/IO/Networking/impl/windows/SocketDescriptor.h @@ -5,6 +5,8 @@ #include "WinSock2.h" +namespace IO { namespace Networking { + struct SocketDescriptor { public: /// IP address without port. @@ -15,4 +17,6 @@ struct SocketDescriptor { SOCKET nativeSocket; }; +}} // namespace IO::Networking + #endif //MANGOS_IO_NETWORKING_WIN32_SOCKETDESCIRPTOR_H diff --git a/src/shared/IO/Timer/AsyncSystemTimer.h b/src/shared/IO/Timer/AsyncSystemTimer.h new file mode 100644 index 00000000000..e178f42dccd --- /dev/null +++ b/src/shared/IO/Timer/AsyncSystemTimer.h @@ -0,0 +1,15 @@ +#ifndef MANGOS_IO_TIMER_ASYNCSYSTEMTIMER_H +#define MANGOS_IO_TIMER_ASYNCSYSTEMTIMER_H + +#include "Common.h" +#include "Log.h" +#include "Policies/Singleton.h" + +#ifdef WIN32 +#include "./impl/windows/AsyncSystemTimer.h" +#else +#error "IO::Timer not supported on your platform" +#endif + + +#endif //MANGOS_IO_TIMER_ASYNCSYSTEMTIMER_H diff --git a/src/shared/IO/Timer/TimerHandle.h b/src/shared/IO/Timer/TimerHandle.h new file mode 100644 index 00000000000..73ba3a4db03 --- /dev/null +++ b/src/shared/IO/Timer/TimerHandle.h @@ -0,0 +1,11 @@ +#ifndef MANGOS_IO_TIMER_TIMERHANDLE_H +#define MANGOS_IO_TIMER_TIMERHANDLE_H + +#ifdef WIN32 +#include "./impl/windows/TimerHandle.h" +#else +#error "IO::Timer not supported on your platform" +#endif + + +#endif //MANGOS_IO_TIMER_TIMERHANDLE_H diff --git a/src/shared/IO/Timer/impl/unix/TODO b/src/shared/IO/Timer/impl/unix/TODO new file mode 100644 index 00000000000..f5cf8f0c27a --- /dev/null +++ b/src/shared/IO/Timer/impl/unix/TODO @@ -0,0 +1 @@ +Use timerfd diff --git a/src/shared/IO/Timer/impl/windows/AsyncSystemTimer.cpp b/src/shared/IO/Timer/impl/windows/AsyncSystemTimer.cpp new file mode 100644 index 00000000000..ad97553372c --- /dev/null +++ b/src/shared/IO/Timer/impl/windows/AsyncSystemTimer.cpp @@ -0,0 +1,87 @@ +#include "./AsyncSystemTimer.h" +#include "IO/Multithreading/CreateThread.h"; +#include +#include "Log.h" + +INSTANTIATE_SINGLETON_1(IO::Timer::impl::windows::AsyncSystemTimer); + +IO::Timer::impl::windows::AsyncSystemTimer::AsyncSystemTimer() +{ + m_nativeTimerQueueHandle = ::CreateTimerQueue(); + if (!m_nativeTimerQueueHandle) + sLog.Out(LOG_BASIC, LOG_LVL_ERROR, "::CreateTimerQueue() failed: %d", GetLastError()); + MANGOS_ASSERT(m_nativeTimerQueueHandle); + + ScheduleFunctionOnce(std::chrono::seconds(0), []() { + // Since we are single threaded, we can rename this Thread, so we know what's up + IO::Multithreading::RenameCurrentThread("SystemTimer"); + }); +} + +void IO::Timer::impl::windows::AsyncSystemTimer::RemoveAllTimersAndStopThread() +{ + HANDLE timerQueueHandle = m_nativeTimerQueueHandle; + m_nativeTimerQueueHandle = nullptr; + if (timerQueueHandle) + { + ::DeleteTimerQueueEx( + timerQueueHandle, + INVALID_HANDLE_VALUE // MSDN: If this parameter is INVALID_HANDLE_VALUE, the function waits for all callback functions to complete before returning. + ); + } + + m_pendingTimers_mutex.lock(); + m_pendingTimers.clear(); + m_pendingTimers_mutex.unlock(); +} + +void IO::Timer::impl::windows::AsyncSystemTimer::_timerQueueTimeoutCallback(PVOID opaquePointer, BOOLEAN _thisVariableIsNotUsedInTimers) +{ + (void)_thisVariableIsNotUsedInTimers; + + auto handleRawSharedPtr = (std::shared_ptr*)opaquePointer; + std::shared_ptr timerHandle = *handleRawSharedPtr; + delete handleRawSharedPtr; + + timerHandle->m_asyncSystemTimer->m_pendingTimers_mutex.lock(); + bool wasRemoved = timerHandle->m_asyncSystemTimer->m_pendingTimers.erase(timerHandle); + timerHandle->m_asyncSystemTimer->m_pendingTimers_mutex.unlock(); + if (!wasRemoved) + return; // The timer was already removed, so we don't want to re-execute it again. + + timerHandle->m_callback(); +} + +std::shared_ptr IO::Timer::impl::windows::AsyncSystemTimer::_scheduleFunctionOnceMs(uint64_t milliseconds, std::function const& function) +{ + MANGOS_ASSERT(this->m_nativeTimerQueueHandle); + + std::shared_ptr timerHandle = std::make_shared(); + timerHandle->m_asyncSystemTimer = this; + timerHandle->m_callback = function; + + // since we are using an opaque pointer model of the kernel here, + // we have to allocate unsafe memory which we will free inside the function + auto handleRawSharedPtr = new std::shared_ptr(timerHandle); + bool wasOkay = ::CreateTimerQueueTimer( + &timerHandle->m_nativeTimerHandle, + m_nativeTimerQueueHandle, + _timerQueueTimeoutCallback, + handleRawSharedPtr, + milliseconds, + 0, // Period = 0: Don't repeat the timer + WT_EXECUTEONLYONCE | WT_EXECUTEINTIMERTHREAD); // Only execute in WT_EXECUTEINTIMERTHREAD (single thread), otherwise we would spam spawn new system threads. + + if (!wasOkay) + { + sLog.Out(LOG_BASIC, LOG_LVL_ERROR, "::CreateTimerQueueTimer failed: %d", GetLastError()); + delete handleRawSharedPtr; + return nullptr; + } + + timerHandle->m_asyncSystemTimer->m_pendingTimers_mutex.lock(); + timerHandle->m_asyncSystemTimer->m_pendingTimers.insert(timerHandle); + timerHandle->m_asyncSystemTimer->m_pendingTimers_mutex.unlock(); + + return timerHandle; +} diff --git a/src/shared/IO/Timer/impl/windows/AsyncSystemTimer.h b/src/shared/IO/Timer/impl/windows/AsyncSystemTimer.h new file mode 100644 index 00000000000..42e16397a8f --- /dev/null +++ b/src/shared/IO/Timer/impl/windows/AsyncSystemTimer.h @@ -0,0 +1,48 @@ +#ifndef MANGOS_IO_TIMER_WIN32_ASYNCSYSTEMTIMER_H +#define MANGOS_IO_TIMER_WIN32_ASYNCSYSTEMTIMER_H + +#include +#include +#include +#include +#include +#include +#include "Policies/SingletonImp.h" +#include "Policies/ThreadingModel.h" +#include "./TimerHandle.h" + +namespace IO { namespace Timer { namespace impl { namespace windows { + class AsyncSystemTimer : public MaNGOS::Singleton> { + friend IO::Timer::TimerHandle; + public: + explicit AsyncSystemTimer(); + ~AsyncSystemTimer() = default; + AsyncSystemTimer(const AsyncSystemTimer&) = delete; + AsyncSystemTimer& operator=(const AsyncSystemTimer&) = delete; + AsyncSystemTimer(AsyncSystemTimer&&) = delete; + AsyncSystemTimer& operator=(AsyncSystemTimer&&) = delete; + + void RemoveAllTimersAndStopThread(); + + /// Low resolution async clock system clock with ~16ms accuracy. + /// Do not use this function for in-game-logic inside mangosd! + /// Use `player->m_Events.AddEvent` instead. + /// Please lock the necessary resources inside this function + template + std::shared_ptr ScheduleFunctionOnce(std::chrono::duration timeFromNow, std::function const& function) + { + uint64_t milliseconds = std::chrono::duration_cast(timeFromNow).count(); + return this->_scheduleFunctionOnceMs(milliseconds, std::move(function)); + } + private: + std::shared_ptr _scheduleFunctionOnceMs(uint64_t milliseconds, std::function const& function); + static void _timerQueueTimeoutCallback(PVOID opaquePointer, BOOLEAN _thisVariableIsNotUsedInTimers); + + std::mutex m_pendingTimers_mutex; + std::unordered_set> m_pendingTimers; + HANDLE m_nativeTimerQueueHandle; + }; +}}}} // namespace IO::Timer::impl::windows +#define sAsyncSystemTimer MaNGOS::Singleton::Instance() + +#endif //MANGOS_IO_TIMER_WIN32_ASYNCSYSTEMTIMER_H diff --git a/src/shared/IO/Timer/impl/windows/TimerHandle.cpp b/src/shared/IO/Timer/impl/windows/TimerHandle.cpp new file mode 100644 index 00000000000..9b9edc1d4d9 --- /dev/null +++ b/src/shared/IO/Timer/impl/windows/TimerHandle.cpp @@ -0,0 +1,20 @@ +#include "TimerHandle.h" +#include "./AsyncSystemTimer.h" +#include "Log.h" + +void IO::Timer::TimerHandle::Cancel() +{ + m_asyncSystemTimer->m_pendingTimers_mutex.lock(); + bool wasRemoved = m_asyncSystemTimer->m_pendingTimers.erase(shared_from_this()); + m_asyncSystemTimer->m_pendingTimers_mutex.unlock(); + if (!wasRemoved) + return; // The timer was already removed, so we don't want to re-execute it again. + + // To avoid race conditions: + // MSDN: If this parameter (last one) is INVALID_HANDLE_VALUE, the function waits for any running timer callback functions to complete before returning. + bool wasOkay = ::DeleteTimerQueueTimer(m_asyncSystemTimer->m_nativeTimerQueueHandle, m_nativeTimerHandle, INVALID_HANDLE_VALUE); + if (!wasOkay) + { + sLog.Out(LOG_BASIC, LOG_LVL_ERROR, "::DeleteTimerQueueTimer failed: %d", GetLastError()); + } +} diff --git a/src/shared/IO/Timer/impl/windows/TimerHandle.h b/src/shared/IO/Timer/impl/windows/TimerHandle.h new file mode 100644 index 00000000000..0bbe6ee2365 --- /dev/null +++ b/src/shared/IO/Timer/impl/windows/TimerHandle.h @@ -0,0 +1,26 @@ +#ifndef MANGOS_IO_TIMER_WIN32_TIMERHANDLE_H +#define MANGOS_IO_TIMER_WIN32_TIMERHANDLE_H + +#include +#include + +namespace IO { namespace Timer { namespace impl { namespace windows { + class AsyncSystemTimer; +}}}} // namespace IO::Timer::impl::windows + +namespace IO { namespace Timer +{ + class TimerHandle : public std::enable_shared_from_this + { + friend IO::Timer::impl::windows::AsyncSystemTimer; + public: + void Cancel(); + private: + void* m_nativeTimerHandle = nullptr; + IO::Timer::impl::windows::AsyncSystemTimer* m_asyncSystemTimer = nullptr; + std::function m_callback = nullptr; + }; +}} // namespace IO::Timer + + +#endif //MANGOS_IO_TIMER_WIN32_TIMERHANDLE_H