diff --git a/src/network/network.cpp b/src/network/network.cpp index 32fc4dec9531e..2ba64f0d04255 100644 --- a/src/network/network.cpp +++ b/src/network/network.cpp @@ -760,20 +760,48 @@ bool NetworkClientConnectGame(NetworkAddress &address, CompanyID join_as, const if (!_network_available) return false; if (!NetworkValidateClientName()) return false; - strecpy(_settings_client.network.last_joined, address.GetAddressAsString(false).c_str(), lastof(_settings_client.network.last_joined)); - + _network_join.address = address; _network_join.company = join_as; _network_join.server_password = join_server_password; _network_join.company_password = join_company_password; + if (_game_mode == GM_MENU) { + /* From the menu we can immediately continue with the actual join. */ + NetworkClientJoinGame(); + } else { + /* When not in the main menu, force going to the main menu to prevent half + * joined clients being in weird states. Like most of the UI still believing + * they are in a network game when other parts have been disconnected. + * Not going to the main menu will cause all kinds of invalid pointer dereferences. + * See NetworkClientJoinGame for more information. + */ + _switch_mode = SM_JOIN_GAME; + } + return true; +} + +/** + * Actually perform the joining to the server. Use #NetworkClientConnectGame + * when you want to connect to a specific server/company. + * + * This is a helper function to be able to load the main menu before joining a + * server to prevent all kinds of invalid pointer dereferences. The most obvious + * set of invalid pointers are everything that has information from a network + * "pool", like NetworkClientInfo. At NetworkDisconnect those pools get cleared + * and nothing gets added until the client is authorized, meaning that until the + * client has entered the right password anything dereferencing NetworkClientInfo, + * such as receiving a chat message will be tainted. + */ +void NetworkClientJoinGame() +{ NetworkDisconnect(); NetworkInitialize(); + strecpy(_settings_client.network.last_joined, _network_join.address.GetAddressAsString(false).c_str(), lastof(_settings_client.network.last_joined)); _network_join_status = NETWORK_JOIN_STATUS_CONNECTING; ShowJoinStatusWindow(); - new TCPClientConnecter(address); - return true; + new TCPClientConnecter(_network_join.address); } static void NetworkInitGameInfo() diff --git a/src/network/network_client.h b/src/network/network_client.h index 28d2d00214b03..81d5b720cdc26 100644 --- a/src/network/network_client.h +++ b/src/network/network_client.h @@ -115,6 +115,7 @@ void NetworkClientSetCompanyPassword(const char *password); /** Information required to join a server. */ struct NetworkJoinInfo { NetworkJoinInfo() : company(COMPANY_SPECTATOR), server_password(nullptr), company_password(nullptr) {} + NetworkAddress address; ///< The address of the server to join. CompanyID company; ///< The company to join. const char *server_password; ///< The password of the server to join. const char *company_password; ///< The password of the company to join. diff --git a/src/network/network_func.h b/src/network/network_func.h index 252d207db54e3..5f3e27c12fae6 100644 --- a/src/network/network_func.h +++ b/src/network/network_func.h @@ -52,6 +52,7 @@ void NetworkPopulateCompanyStats(NetworkCompanyStats *stats); void NetworkUpdateClientInfo(ClientID client_id); void NetworkClientsToSpectators(CompanyID cid); bool NetworkClientConnectGame(const std::string &connection_string, CompanyID default_company, const char *join_server_password = nullptr, const char *join_company_password = nullptr); +void NetworkClientJoinGame(); void NetworkClientRequestMove(CompanyID company, const char *pass = ""); void NetworkClientSendRcon(const char *password, const char *command); void NetworkClientSendChat(NetworkAction action, DestType type, int dest, const char *msg, int64 data = 0); diff --git a/src/openttd.cpp b/src/openttd.cpp index fbeeba793ed21..25fcd3eed814d 100644 --- a/src/openttd.cpp +++ b/src/openttd.cpp @@ -1068,6 +1068,14 @@ void SwitchToMode(SwitchMode new_mode) break; } + case SM_JOIN_GAME: + /* Used in case a (re)connect is triggered from inside a (network) game to prevent + * invalid pointer dereferences. See #NetworkClientJoinGame for more information. + * Loads the main menu and then starts the actual join process. */ + LoadIntroGame(); + NetworkClientJoinGame(); + break; + case SM_MENU: // Switch to game intro menu LoadIntroGame(); if (BaseSounds::ini_set.empty() && BaseSounds::GetUsedSet()->fallback && SoundDriver::GetInstance()->HasOutput()) { diff --git a/src/openttd.h b/src/openttd.h index 77fafab1d11a2..2cd9cc1f0929a 100644 --- a/src/openttd.h +++ b/src/openttd.h @@ -36,6 +36,7 @@ enum SwitchMode { SM_START_HEIGHTMAP, ///< Load a heightmap and start a new game from it. SM_LOAD_HEIGHTMAP, ///< Load heightmap from scenario editor. SM_RESTART_HEIGHTMAP, ///< Load a heightmap and start a new game from it with current settings. + SM_JOIN_GAME, ///< Join a network game. }; /** Display Options */