diff --git a/src/app/OperationalSessionSetup.cpp b/src/app/OperationalSessionSetup.cpp index 2996cf2dd891d4..2921a8cc03ea01 100644 --- a/src/app/OperationalSessionSetup.cpp +++ b/src/app/OperationalSessionSetup.cpp @@ -94,88 +94,13 @@ bool OperationalSessionSetup::AttachToExistingSecureSession() void OperationalSessionSetup::Connect(Callback::Callback * onConnection, Callback::Callback * onFailure) { - CHIP_ERROR err = CHIP_NO_ERROR; - bool isConnected = false; - - // - // Always enqueue our user provided callbacks into our callback list. - // If anything goes wrong below, we'll trigger failures (including any queued from - // a previous iteration which in theory shouldn't happen, but this is written to be more defensive) - // - EnqueueConnectionCallbacks(onConnection, onFailure); - - switch (mState) - { - case State::Uninitialized: - err = CHIP_ERROR_INCORRECT_STATE; - break; - - case State::NeedsAddress: - isConnected = AttachToExistingSecureSession(); - if (!isConnected) - { - // LookupPeerAddress could perhaps call back with a result - // synchronously, so do our state update first. - MoveToState(State::ResolvingAddress); - err = LookupPeerAddress(); - if (err != CHIP_NO_ERROR) - { - // Roll back the state change, since we are presumably not in - // the middle of a lookup. - MoveToState(State::NeedsAddress); - } - } - - break; - - case State::ResolvingAddress: - case State::WaitingForRetry: - isConnected = AttachToExistingSecureSession(); - break; - - case State::HasAddress: - isConnected = AttachToExistingSecureSession(); - if (!isConnected) - { - // We should not actually every be in be in State::HasAddress. This - // is because in the same call that we moved to State::HasAddress - // we either move to State::Connecting or call - // DequeueConnectionCallbacks with an error thus releasing - // ourselves before any call would reach this section of code. - err = CHIP_ERROR_INCORRECT_STATE; - } - - break; - - case State::Connecting: - break; - - case State::SecureConnected: - isConnected = true; - break; - - default: - err = CHIP_ERROR_INCORRECT_STATE; - } - - if (isConnected) - { - MoveToState(State::SecureConnected); - } + Connect(onConnection, onFailure, nullptr); +} - // - // Dequeue all our callbacks on either encountering an error - // or if we successfully connected. Both should not be set - // simultaneously. - // - if (err != CHIP_NO_ERROR || isConnected) - { - DequeueConnectionCallbacks(err); - // Do not touch `this` instance anymore; it has been destroyed in DequeueConnectionCallbacks. - // While it is odd to have an explicit return here at the end of the function, we do so - // as a precaution in case someone later on adds something to the end of this function. - return; - } +void OperationalSessionSetup::Connect(Callback::Callback * onConnection, + Callback::Callback * onSetupFailure) +{ + Connect(onConnection, nullptr, onSetupFailure); } void OperationalSessionSetup::UpdateDeviceData(const Transport::PeerAddress & addr, const ReliableMessageProtocolConfig & config) @@ -290,8 +215,97 @@ CHIP_ERROR OperationalSessionSetup::EstablishConnection(const ReliableMessagePro return CHIP_NO_ERROR; } +void OperationalSessionSetup::Connect(Callback::Callback * onConnection, + Callback::Callback * onFailure, + Callback::Callback * onSetupFailure) +{ + CHIP_ERROR err = CHIP_NO_ERROR; + bool isConnected = false; + + // + // Always enqueue our user provided callbacks into our callback list. + // If anything goes wrong below, we'll trigger failures (including any queued from + // a previous iteration which in theory shouldn't happen, but this is written to be more defensive) + // + EnqueueConnectionCallbacks(onConnection, onFailure, onSetupFailure); + + switch (mState) + { + case State::Uninitialized: + err = CHIP_ERROR_INCORRECT_STATE; + break; + + case State::NeedsAddress: + isConnected = AttachToExistingSecureSession(); + if (!isConnected) + { + // LookupPeerAddress could perhaps call back with a result + // synchronously, so do our state update first. + MoveToState(State::ResolvingAddress); + err = LookupPeerAddress(); + if (err != CHIP_NO_ERROR) + { + // Roll back the state change, since we are presumably not in + // the middle of a lookup. + MoveToState(State::NeedsAddress); + } + } + + break; + + case State::ResolvingAddress: + case State::WaitingForRetry: + isConnected = AttachToExistingSecureSession(); + break; + + case State::HasAddress: + isConnected = AttachToExistingSecureSession(); + if (!isConnected) + { + // We should not actually every be in be in State::HasAddress. This + // is because in the same call that we moved to State::HasAddress + // we either move to State::Connecting or call + // DequeueConnectionCallbacks with an error thus releasing + // ourselves before any call would reach this section of code. + err = CHIP_ERROR_INCORRECT_STATE; + } + + break; + + case State::Connecting: + break; + + case State::SecureConnected: + isConnected = true; + break; + + default: + err = CHIP_ERROR_INCORRECT_STATE; + } + + if (isConnected) + { + MoveToState(State::SecureConnected); + } + + // + // Dequeue all our callbacks on either encountering an error + // or if we successfully connected. Both should not be set + // simultaneously. + // + if (err != CHIP_NO_ERROR || isConnected) + { + DequeueConnectionCallbacks(err); + // Do not touch `this` instance anymore; it has been destroyed in DequeueConnectionCallbacks. + // While it is odd to have an explicit return here at the end of the function, we do so + // as a precaution in case someone later on adds something to the end of this function. + return; + } +} + void OperationalSessionSetup::EnqueueConnectionCallbacks(Callback::Callback * onConnection, - Callback::Callback * onFailure) + Callback::Callback * onFailure, + Callback::Callback * onSetupFailure) { if (onConnection != nullptr) { @@ -302,11 +316,17 @@ void OperationalSessionSetup::EnqueueConnectionCallbacks(Callback::CallbackCancel()); } + + if (onSetupFailure != nullptr) + { + mSetupFailure.Enqueue(onSetupFailure->Cancel()); + } } -void OperationalSessionSetup::DequeueConnectionCallbacks(CHIP_ERROR error, ReleaseBehavior releaseBehavior) +void OperationalSessionSetup::DequeueConnectionCallbacks(CHIP_ERROR error, SessionEstablishmentStage stage, + ReleaseBehavior releaseBehavior) { - Cancelable failureReady, successReady; + Cancelable failureReady, setupFailureReady, successReady; // // Dequeue both failure and success callback lists into temporary stack args before invoking either of them. @@ -314,6 +334,7 @@ void OperationalSessionSetup::DequeueConnectionCallbacks(CHIP_ERROR error, Relea // since the callee may destroy this object as part of that callback. // mConnectionFailure.DequeueAll(failureReady); + mSetupFailure.DequeueAll(setupFailureReady); mConnectionSuccess.DequeueAll(successReady); #if CHIP_DEVICE_CONFIG_ENABLE_AUTOMATIC_CASE_RETRIES @@ -339,13 +360,14 @@ void OperationalSessionSetup::DequeueConnectionCallbacks(CHIP_ERROR error, Relea // DO NOT touch any members of this object after this point. It's dead. - NotifyConnectionCallbacks(failureReady, successReady, error, peerId, performingAddressUpdate, exchangeMgr, - optionalSessionHandle); + NotifyConnectionCallbacks(failureReady, setupFailureReady, successReady, error, stage, peerId, performingAddressUpdate, + exchangeMgr, optionalSessionHandle); } -void OperationalSessionSetup::NotifyConnectionCallbacks(Cancelable & failureReady, Cancelable & successReady, CHIP_ERROR error, - const ScopedNodeId & peerId, bool performingAddressUpdate, - Messaging::ExchangeManager * exchangeMgr, +void OperationalSessionSetup::NotifyConnectionCallbacks(Cancelable & failureReady, Cancelable & setupFailureReady, + Cancelable & successReady, CHIP_ERROR error, + SessionEstablishmentStage stage, const ScopedNodeId & peerId, + bool performingAddressUpdate, Messaging::ExchangeManager * exchangeMgr, const Optional & optionalSessionHandle) { // @@ -367,6 +389,22 @@ void OperationalSessionSetup::NotifyConnectionCallbacks(Cancelable & failureRead } } + while (setupFailureReady.mNext != &setupFailureReady) + { + // We expect that we only have callbacks if we are not performing just address update. + VerifyOrDie(!performingAddressUpdate); + Callback::Callback * cb = Callback::Callback::FromCancelable(setupFailureReady.mNext); + + cb->Cancel(); + + if (error != CHIP_NO_ERROR) + { + // Initialize the ConnnectionFailureInfo object + ConnnectionFailureInfo failureInfo(peerId, error, stage); + cb->mCall(cb->mContext, failureInfo); + } + } + while (successReady.mNext != &successReady) { // We expect that we only have callbacks if we are not performing just address update. @@ -383,7 +421,7 @@ void OperationalSessionSetup::NotifyConnectionCallbacks(Cancelable & failureRead } } -void OperationalSessionSetup::OnSessionEstablishmentError(CHIP_ERROR error) +void OperationalSessionSetup::OnSessionEstablishmentError(CHIP_ERROR error, SessionEstablishmentStage stage) { VerifyOrReturn(mState == State::Connecting, ChipLogError(Discovery, "OnSessionEstablishmentError was called while we were not connecting")); @@ -438,7 +476,7 @@ void OperationalSessionSetup::OnSessionEstablishmentError(CHIP_ERROR error) #endif // CHIP_DEVICE_CONFIG_ENABLE_AUTOMATIC_CASE_RETRIES } - DequeueConnectionCallbacks(error); + DequeueConnectionCallbacks(error, stage); // Do not touch `this` instance anymore; it has been destroyed in DequeueConnectionCallbacks. } diff --git a/src/app/OperationalSessionSetup.h b/src/app/OperationalSessionSetup.h index 2925259066e6db..e1389f8d5ce7bb 100644 --- a/src/app/OperationalSessionSetup.h +++ b/src/app/OperationalSessionSetup.h @@ -155,6 +155,19 @@ typedef void (*OnDeviceConnectionRetry)(void * context, const ScopedNodeId & pee class DLL_EXPORT OperationalSessionSetup : public SessionEstablishmentDelegate, public AddressResolve::NodeListener { public: + struct ConnnectionFailureInfo + { + const ScopedNodeId peerId; + CHIP_ERROR error; + SessionEstablishmentStage sessionStage; + + ConnnectionFailureInfo(const ScopedNodeId & peer, CHIP_ERROR err, SessionEstablishmentStage stage) : + peerId(peer), error(err), sessionStage(stage) + {} + }; + + using OnSetupFailure = void (*)(void * context, const ConnnectionFailureInfo & failureInfo); + ~OperationalSessionSetup() override; OperationalSessionSetup(const CASEClientInitParams & params, CASEClientPoolDelegate * clientPool, ScopedNodeId peerId, @@ -180,8 +193,8 @@ class DLL_EXPORT OperationalSessionSetup : public SessionEstablishmentDelegate, * The device is expected to have been commissioned, A CASE session * setup will be triggered. * - * On establishing the session, the callback function `onConnection` will be called. If the - * session setup fails, `onFailure` will be called. + * On establishing the session, if the session setup succeeds, the callback function `onConnection` will be called. + * If the session setup fails, `onFailure` will be called. * * If the session already exists, `onConnection` will be called immediately, * before the Connect call returns. @@ -192,11 +205,28 @@ class DLL_EXPORT OperationalSessionSetup : public SessionEstablishmentDelegate, */ void Connect(Callback::Callback * onConnection, Callback::Callback * onFailure); + /* + * This function can be called to establish a secure session with the device. + * + * The device is expected to have been commissioned, A CASE session + * setup will be triggered. + * + * On establishing the session, if the session setup succeeds, the callback function `onConnection` will be called. + * If the session setup fails, `onSetupFailure` will be called. + * + * If the session already exists, `onConnection` will be called immediately, + * before the Connect call returns. + * + * `onSetupFailure` may be called before the Connect call returns, for error cases that are detected synchronously + * (e.g. inability to start an address lookup). + */ + void Connect(Callback::Callback * onConnection, Callback::Callback * onSetupFailure); + bool IsForAddressUpdate() const { return mPerformingAddressUpdate; } //////////// SessionEstablishmentDelegate Implementation /////////////// void OnSessionEstablished(const SessionHandle & session) override; - void OnSessionEstablishmentError(CHIP_ERROR error) override; + void OnSessionEstablishmentError(CHIP_ERROR error, SessionEstablishmentStage stage) override; ScopedNodeId GetPeerId() const { return mPeerId; } @@ -264,6 +294,7 @@ class DLL_EXPORT OperationalSessionSetup : public SessionEstablishmentDelegate, Callback::CallbackDeque mConnectionSuccess; Callback::CallbackDeque mConnectionFailure; + Callback::CallbackDeque mSetupFailure; OperationalSessionReleaseDelegate * mReleaseDelegate; @@ -306,8 +337,12 @@ class DLL_EXPORT OperationalSessionSetup : public SessionEstablishmentDelegate, void CleanupCASEClient(); + void Connect(Callback::Callback * onConnection, Callback::Callback * onFailure, + Callback::Callback * onSetupFailure); + void EnqueueConnectionCallbacks(Callback::Callback * onConnection, - Callback::Callback * onFailure); + Callback::Callback * onFailure, + Callback::Callback * onSetupFailure); enum class ReleaseBehavior { @@ -316,11 +351,13 @@ class DLL_EXPORT OperationalSessionSetup : public SessionEstablishmentDelegate, }; /* - * This dequeues all failure and success callbacks and appropriately - * invokes either set depending on the value of error. + * This dequeues all failure and success callbacks and appropriately invokes either set depending + * on the value of error. + * + * If error == CHIP_NO_ERROR, only success callbacks are invoked. Otherwise, only failure callbacks are invoked. * - * If error == CHIP_NO_ERROR, only success callbacks are invoked. - * Otherwise, only failure callbacks are invoked. + * The state offers additional context regarding the failure, indicating the specific state in which + * the error occurs. It is only relayed through failure callbacks when the error is not equal to CHIP_NO_ERROR. * * If releaseBehavior is Release, this uses mReleaseDelegate to release * ourselves (aka `this`). As a result any caller should return right away @@ -328,15 +365,22 @@ class DLL_EXPORT OperationalSessionSetup : public SessionEstablishmentDelegate, * * Setting releaseBehavior to DoNotRelease is meant for use from the destructor */ - void DequeueConnectionCallbacks(CHIP_ERROR error, ReleaseBehavior releaseBehavior = ReleaseBehavior::Release); + void DequeueConnectionCallbacks(CHIP_ERROR error, SessionEstablishmentStage stage, + ReleaseBehavior releaseBehavior = ReleaseBehavior::Release); + + void DequeueConnectionCallbacks(CHIP_ERROR error, ReleaseBehavior releaseBehavior = ReleaseBehavior::Release) + { + this->DequeueConnectionCallbacks(error, SessionEstablishmentStage::kNotInKeyExchange, releaseBehavior); + } /** * Helper for DequeueConnectionCallbacks that handles the actual callback * notifications. This happens after the object has been released, if it's * being released. */ - static void NotifyConnectionCallbacks(Callback::Cancelable & failureReady, Callback::Cancelable & successReady, - CHIP_ERROR error, const ScopedNodeId & peerId, bool performingAddressUpdate, + static void NotifyConnectionCallbacks(Callback::Cancelable & failureReady, Callback::Cancelable & setupFailureReady, + Callback::Cancelable & successReady, CHIP_ERROR error, SessionEstablishmentStage stage, + const ScopedNodeId & peerId, bool performingAddressUpdate, Messaging::ExchangeManager * exchangeMgr, const Optional & optionalSessionHandle); diff --git a/src/protocols/secure_channel/CASESession.cpp b/src/protocols/secure_channel/CASESession.cpp index 6e51d000e3822c..48d0ffd0f4e5e3 100644 --- a/src/protocols/secure_channel/CASESession.cpp +++ b/src/protocols/secure_channel/CASESession.cpp @@ -556,9 +556,10 @@ void CASESession::OnResponseTimeout(ExchangeContext * ec) void CASESession::AbortPendingEstablish(CHIP_ERROR err) { + SessionEstablishmentStage state = MapCASEStateToSessionEstablishmentStage(mState); Clear(); // Do this last in case the delegate frees us. - NotifySessionEstablishmentError(err); + NotifySessionEstablishmentError(err, state); } CHIP_ERROR CASESession::DeriveSecureSession(CryptoContext & session) const @@ -2255,4 +2256,29 @@ bool CASESession::InvokeBackgroundWorkWatchdog() return watchdogFired; } +// Helper function to map CASESession::State to SessionEstablishmentStage +SessionEstablishmentStage CASESession::MapCASEStateToSessionEstablishmentStage(State caseState) +{ + switch (caseState) + { + case State::kInitialized: + return SessionEstablishmentStage::kNotInKeyExchange; + case State::kSentSigma1: + case State::kSentSigma1Resume: + return SessionEstablishmentStage::kSentSigma1; + case State::kSentSigma2: + case State::kSentSigma2Resume: + return SessionEstablishmentStage::kSentSigma2; + case State::kSendSigma3Pending: + return SessionEstablishmentStage::kReceivedSigma2; + case State::kSentSigma3: + return SessionEstablishmentStage::kSentSigma3; + case State::kHandleSigma3Pending: + return SessionEstablishmentStage::kReceivedSigma3; + // Add more mappings here for other states + default: + return SessionEstablishmentStage::kUnknown; // Default mapping + } +} + } // namespace chip diff --git a/src/protocols/secure_channel/CASESession.h b/src/protocols/secure_channel/CASESession.h index 7453b6b5002dc4..6fc58dfb90dc83 100644 --- a/src/protocols/secure_channel/CASESession.h +++ b/src/protocols/secure_channel/CASESession.h @@ -320,6 +320,8 @@ class DLL_EXPORT CASESession : public Messaging::UnsolicitedMessageHandler, #if CONFIG_BUILD_FOR_HOST_UNIT_TEST Optional mStopHandshakeAtState = Optional::Missing(); #endif // CONFIG_BUILD_FOR_HOST_UNIT_TEST + + SessionEstablishmentStage MapCASEStateToSessionEstablishmentStage(State caseState); }; } // namespace chip diff --git a/src/protocols/secure_channel/PairingSession.cpp b/src/protocols/secure_channel/PairingSession.cpp index 63a1701e66541f..23daea30f2800c 100644 --- a/src/protocols/secure_channel/PairingSession.cpp +++ b/src/protocols/secure_channel/PairingSession.cpp @@ -255,7 +255,7 @@ void PairingSession::Clear() mSessionManager = nullptr; } -void PairingSession::NotifySessionEstablishmentError(CHIP_ERROR error) +void PairingSession::NotifySessionEstablishmentError(CHIP_ERROR error, SessionEstablishmentStage stage) { if (mDelegate == nullptr) { @@ -265,7 +265,7 @@ void PairingSession::NotifySessionEstablishmentError(CHIP_ERROR error) auto * delegate = mDelegate; mDelegate = nullptr; - delegate->OnSessionEstablishmentError(error); + delegate->OnSessionEstablishmentError(error, stage); } void PairingSession::OnSessionReleased() diff --git a/src/protocols/secure_channel/PairingSession.h b/src/protocols/secure_channel/PairingSession.h index 844fa33a41ae68..ffbb6d9966485d 100644 --- a/src/protocols/secure_channel/PairingSession.h +++ b/src/protocols/secure_channel/PairingSession.h @@ -218,10 +218,14 @@ class DLL_EXPORT PairingSession : public SessionDelegate void Clear(); /** - * Notify our delegate about a session establishment error, if we have not - * notified it of an error or success before. + * Notify our delegate about a session establishment error and the stage when the error occurs + * if we have not already notified it of an error or success before. + * + * @param error The error code to report. + * @param stage The stage of the session when the error occurs, defaults to kNotInKeyExchange. */ - void NotifySessionEstablishmentError(CHIP_ERROR error); + void NotifySessionEstablishmentError(CHIP_ERROR error, + SessionEstablishmentStage stage = SessionEstablishmentStage::kNotInKeyExchange); protected: CryptoContext::SessionRole mRole; diff --git a/src/protocols/secure_channel/SessionEstablishmentDelegate.h b/src/protocols/secure_channel/SessionEstablishmentDelegate.h index dc73a0ffe6997d..a074e5ee074c12 100644 --- a/src/protocols/secure_channel/SessionEstablishmentDelegate.h +++ b/src/protocols/secure_channel/SessionEstablishmentDelegate.h @@ -32,6 +32,18 @@ namespace chip { +enum class SessionEstablishmentStage : uint8_t +{ + kUnknown = 0, + kNotInKeyExchange = 1, + kSentSigma1 = 2, + kReceivedSigma1 = 3, + kSentSigma2 = 4, + kReceivedSigma2 = 5, + kSentSigma3 = 6, + kReceivedSigma3 = 7, +}; + class DLL_EXPORT SessionEstablishmentDelegate { public: @@ -42,6 +54,16 @@ class DLL_EXPORT SessionEstablishmentDelegate */ virtual void OnSessionEstablishmentError(CHIP_ERROR error) {} + /** + * Called when session establishment fails with an error and state at the + * failure. This will be called at most once per session establishment and + * will not be called if OnSessionEstablished is called. + */ + virtual void OnSessionEstablishmentError(CHIP_ERROR error, SessionEstablishmentStage stage) + { + OnSessionEstablishmentError(error); + } + /** * Called on start of session establishment process */