diff --git a/src/Simplex/Messaging/Agent.hs b/src/Simplex/Messaging/Agent.hs index 1074d9abd0..67edbd4581 100644 --- a/src/Simplex/Messaging/Agent.hs +++ b/src/Simplex/Messaging/Agent.hs @@ -386,13 +386,17 @@ deleteUser' c userId = do newConnAsync :: forall m c. (AgentMonad m, ConnectionModeI c) => AgentClient -> UserId -> ACorrId -> Bool -> SConnectionMode c -> m ConnId newConnAsync c userId corrId enableNtfs cMode = do - g <- asks idsDrg - connAgentVersion <- asks $ maxVersion . smpAgentVRange . config - let cData = ConnData {userId, connId = "", connAgentVersion, enableNtfs, duplexHandshake = Nothing, deleted = False} -- connection mode is determined by the accepting agent - connId <- withStore c $ \db -> createNewConn db g cData cMode + connId <- newConnNoQueues c userId "" enableNtfs cMode enqueueCommand c corrId connId Nothing $ AClientCommand $ NEW enableNtfs (ACM cMode) pure connId +newConnNoQueues :: AgentMonad m => AgentClient -> UserId -> ConnId -> Bool -> SConnectionMode c -> m ConnId +newConnNoQueues c userId connId enableNtfs cMode = do + g <- asks idsDrg + connAgentVersion <- asks $ maxVersion . smpAgentVRange . config + let cData = ConnData {userId, connId, connAgentVersion, enableNtfs, duplexHandshake = Nothing, deleted = False} -- connection mode is determined by the accepting agent + withStore c $ \db -> createNewConn db g cData cMode + joinConnAsync :: AgentMonad m => AgentClient -> UserId -> ACorrId -> Bool -> ConnectionRequestUri c -> ConnInfo -> m ConnId joinConnAsync c userId corrId enableNtfs cReqUri@(CRInvitationUri ConnReqUriData {crAgentVRange} _) cInfo = do aVRange <- asks $ smpAgentVRange . config @@ -473,9 +477,9 @@ newConn c userId connId asyncMode enableNtfs cMode clientData = newConnSrv :: AgentMonad m => AgentClient -> UserId -> ConnId -> Bool -> Bool -> SConnectionMode c -> Maybe CRClientData -> SMPServerWithAuth -> m (ConnId, ConnectionRequestUri c) newConnSrv c userId connId asyncMode enableNtfs cMode clientData srv = do AgentConfig {smpClientVRange, smpAgentVRange, e2eEncryptVRange} <- asks config - (q, qUri) <- newRcvQueue c userId "" srv smpClientVRange - connId' <- setUpConn asyncMode q $ maxVersion smpAgentVRange - let rq = (q :: RcvQueue) {connId = connId'} + connId' <- if asyncMode then pure connId else newConnNoQueues c userId connId enableNtfs cMode + (rq, qUri) <- newRcvQueue c userId connId' srv smpClientVRange `catchError` \e -> liftIO (print e) >> throwError e + void . withStore c $ \db -> updateNewConnRcv db connId' rq addSubscription c rq when enableNtfs $ do ns <- asks ntfSupervisor @@ -487,14 +491,6 @@ newConnSrv c userId connId asyncMode enableNtfs cMode clientData srv = do (pk1, pk2, e2eRcvParams) <- liftIO . CR.generateE2EParams $ maxVersion e2eEncryptVRange withStore' c $ \db -> createRatchetX3dhKeys db connId' pk1 pk2 pure (connId', CRInvitationUri crData $ toVersionRangeT e2eRcvParams e2eEncryptVRange) - where - setUpConn True rq _ = do - void . withStore c $ \db -> updateNewConnRcv db connId rq - pure connId - setUpConn False rq connAgentVersion = do - g <- asks idsDrg - let cData = ConnData {userId, connId, connAgentVersion, enableNtfs, duplexHandshake = Nothing, deleted = False} -- connection mode is determined by the accepting agent - withStore c $ \db -> createRcvConn db g cData rq cMode joinConn :: AgentMonad m => AgentClient -> UserId -> ConnId -> Bool -> Bool -> ConnectionRequestUri c -> ConnInfo -> m ConnId joinConn c userId connId asyncMode enableNtfs cReq cInfo = do diff --git a/src/Simplex/Messaging/Agent/Client.hs b/src/Simplex/Messaging/Agent/Client.hs index 1b152b87fd..6efbb8b0d0 100644 --- a/src/Simplex/Messaging/Agent/Client.hs +++ b/src/Simplex/Messaging/Agent/Client.hs @@ -99,7 +99,7 @@ import Data.Bifunctor (bimap, first, second) import Data.ByteString.Base64 import Data.ByteString.Char8 (ByteString) import qualified Data.ByteString.Char8 as B -import Data.Either (isRight, partitionEithers) +import Data.Either (partitionEithers) import Data.Functor (($>)) import Data.List (foldl', partition) import Data.List.NonEmpty (NonEmpty (..), (<|)) @@ -352,17 +352,10 @@ getSMPServerClient c@AgentClient {active, smpClients, msgQ} tSess@(userId, srv, where resubscribe :: NonEmpty RcvQueue -> m () resubscribe qs = do - connected <- maybe False isRight <$> atomically (TM.lookup tSess smpClients $>>= tryReadTMVar) cs <- atomically . RQ.getConns $ activeSubs c - -- TODO the unsolved problem is changing session mode - -- currently it would reconnect the same client and will keep it around, until the app is restarted. - -- instead it should check whether session mode changed and connect somehow differently if it did... - (client_, rs) <- subscribeQueues_ c tSess qs - let (errs, okConns) = partitionEithers $ map (\(RcvQueue {connId}, r) -> bimap (connId,) (const connId) r) $ L.toList rs + rs <- subscribeQueues c $ L.toList qs + let (errs, okConns) = partitionEithers $ map (\(RcvQueue {connId}, r) -> bimap (connId,) (const connId) r) rs liftIO $ do - unless connected . forM_ client_ $ \cl -> do - incClientStat c userId cl "CONNECT" "" - notifySub "" $ hostEvent CONNECT cl let conns = S.toList $ S.fromList okConns `S.difference` cs unless (null conns) $ notifySub "" $ UP srv conns let (tempErrs, finalErrs) = partition (temporaryAgentError . snd) errs @@ -610,12 +603,19 @@ runSMPServerTest c userId (ProtoServerWithAuth srv auth) = do testErr step = SMPTestFailure step . protocolClientError SMP addr mkTransportSession :: AgentMonad m => AgentClient -> UserId -> ProtoServer msg -> EntityId -> m (TransportSession msg) -mkTransportSession c userId srv entityId = do - mode <- sessionMode <$> readTVarIO (useNetworkConfig c) - pure (userId, srv, if mode == TSMEntity then Just entityId else Nothing) +mkTransportSession c userId srv entityId = mkTSession userId srv entityId <$> getSessionMode c + +mkTSession :: UserId -> ProtoServer msg -> EntityId -> TransportSessionMode -> TransportSession msg +mkTSession userId srv entityId mode = (userId, srv, if mode == TSMEntity then Just entityId else Nothing) mkSMPTransportSession :: (AgentMonad m, SMPQueueRec q) => AgentClient -> q -> m SMPTransportSession -mkSMPTransportSession c q = mkTransportSession c (qUserId q) (qServer q) (queueId q) +mkSMPTransportSession c q = mkSMPTSession q <$> getSessionMode c + +mkSMPTSession :: SMPQueueRec q => q -> TransportSessionMode -> SMPTransportSession +mkSMPTSession q = mkTSession (qUserId q) (qServer q) (qConnId q) + +getSessionMode :: AgentMonad m => AgentClient -> m TransportSessionMode +getSessionMode = fmap sessionMode . readTVarIO . useNetworkConfig newRcvQueue :: AgentMonad m => AgentClient -> UserId -> ConnId -> SMPServerWithAuth -> VersionRange -> m (RcvQueue, SMPQueueUri) newRcvQueue c userId connId (ProtoServerWithAuth srv auth) vRange = do @@ -684,7 +684,7 @@ temporaryOrHostError = \case BROKER _ HOST -> True e -> temporaryAgentError e -subscribeQueues :: AgentMonad m => AgentClient -> [RcvQueue] -> m [(RcvQueue, Either AgentErrorType ())] +subscribeQueues :: forall m. AgentMonad m => AgentClient -> [RcvQueue] -> m [(RcvQueue, Either AgentErrorType ())] subscribeQueues c qs = do (errs, qs') <- partitionEithers <$> mapM checkQueue qs forM_ qs' $ \rq@RcvQueue {connId} -> atomically $ do @@ -693,32 +693,27 @@ subscribeQueues c qs = do (errs <>) <$> do mode <- sessionMode <$> readTVarIO (useNetworkConfig c) let sessRcvQs = foldl' (addRcvQueue mode) M.empty qs' - concat <$> mapConcurrently (fmap (L.toList . snd) . uncurry (subscribeQueues_ c)) (M.assocs sessRcvQs) + concat <$> mapConcurrently (fmap L.toList . uncurry (subscribeQueues_ c)) (M.assocs sessRcvQs) where checkQueue rq@RcvQueue {rcvId, server} = do prohibited <- atomically . TM.member (server, rcvId) $ getMsgLocks c pure $ if prohibited then Left (rq, Left $ CMD PROHIBITED) else Right rq addRcvQueue :: TransportSessionMode -> Map SMPTransportSession (NonEmpty RcvQueue) -> RcvQueue -> Map SMPTransportSession (NonEmpty RcvQueue) - addRcvQueue mode m rq@RcvQueue {userId, server, rcvId} = - let tSess = (userId, server, if mode == TSMEntity then Just rcvId else Nothing) + addRcvQueue mode m rq = + let tSess = mkSMPTSession rq mode in M.alter (Just . maybe [rq] (rq <|)) tSess m --- | subscribe multiple queues - all passed queues should be on the same server --- TODO check session? depends on how session mode change will be handled -subscribeQueues_ :: AgentMonad m => AgentClient -> SMPTransportSession -> NonEmpty RcvQueue -> m (Maybe SMPClient, NonEmpty (RcvQueue, Either AgentErrorType ())) -subscribeQueues_ c tSess@(userId, srv, _) qs = do - smp_ <- tryError $ getSMPServerClient c tSess - (eitherToMaybe smp_,) <$> case smp_ of +subscribeQueues_ :: AgentMonad m => AgentClient -> SMPTransportSession -> NonEmpty RcvQueue -> m (NonEmpty (RcvQueue, Either AgentErrorType ())) +subscribeQueues_ c tSess@(userId, srv, _) qs = + tryError (getSMPServerClient c tSess) >>= \case Left e -> pure $ L.map (,Left e) qs - Right smp -> do + Right smp -> liftIO $ do logServer "-->" c srv (bshow (length qs) <> " queues") "SUB" - let qs2 = L.map queueCreds qs - n = (length qs2 - 1) `div` 90 + 1 - liftIO $ incClientStatN c userId smp n "SUBS" "OK" - liftIO $ do - rs <- L.zip qs <$> subscribeSMPQueues smp qs2 - mapM_ (uncurry $ processSubResult c) rs - pure $ L.map (second . first $ protocolClientError SMP $ clientServer smp) rs + let n = (length qs - 1) `div` 90 + 1 + incClientStatN c userId smp n "SUBS" "OK" + rs <- L.zip qs <$> subscribeSMPQueues smp (L.map queueCreds qs) + mapM_ (uncurry $ processSubResult c) rs + pure $ L.map (second . first $ protocolClientError SMP $ clientServer smp) rs where queueCreds RcvQueue {rcvPrivateKey, rcvId} = (rcvPrivateKey, rcvId) diff --git a/src/Simplex/Messaging/Agent/TRcvQueues.hs b/src/Simplex/Messaging/Agent/TRcvQueues.hs index e25c6a8ffb..ff525d4bbd 100644 --- a/src/Simplex/Messaging/Agent/TRcvQueues.hs +++ b/src/Simplex/Messaging/Agent/TRcvQueues.hs @@ -47,21 +47,21 @@ addQueue rq (TRcvQueues qs) = TM.insert (qKey rq) rq qs deleteQueue :: RcvQueue -> TRcvQueues -> STM () deleteQueue rq (TRcvQueues qs) = TM.delete (qKey rq) qs -getSessQueues :: (UserId, SMPServer, Maybe RecipientId) -> TRcvQueues -> STM [RcvQueue] +getSessQueues :: (UserId, SMPServer, Maybe ConnId) -> TRcvQueues -> STM [RcvQueue] getSessQueues tSess (TRcvQueues qs) = M.foldl' addQ [] <$> readTVar qs where addQ qs' rq = if rq `isSession` tSess then rq : qs' else qs' -getDelSessQueues :: (UserId, SMPServer, Maybe RecipientId) -> TRcvQueues -> STM ([RcvQueue], Set ConnId) +getDelSessQueues :: (UserId, SMPServer, Maybe ConnId) -> TRcvQueues -> STM ([RcvQueue], Set ConnId) getDelSessQueues tSess (TRcvQueues qs) = stateTVar qs $ M.foldl' addQ (([], S.empty), M.empty) where addQ (removed@(remQs, remConns), qs') rq@RcvQueue {connId} | rq `isSession` tSess = ((rq : remQs, S.insert connId remConns), qs') | otherwise = (removed, M.insert (qKey rq) rq qs') -isSession :: RcvQueue -> (UserId, SMPServer, Maybe RecipientId) -> Bool -isSession RcvQueue {userId, server, rcvId} (uId, srv, qId_) = - userId == uId && server == srv && maybe True (rcvId ==) qId_ +isSession :: RcvQueue -> (UserId, SMPServer, Maybe ConnId) -> Bool +isSession rq (uId, srv, connId_) = + userId rq == uId && server rq == srv && maybe True (connId rq ==) connId_ -qKey :: RcvQueue -> (UserId, SMPServer, RecipientId) -qKey RcvQueue {userId, server, rcvId} = (userId, server, rcvId) +qKey :: RcvQueue -> (UserId, SMPServer, ConnId) +qKey rq = (userId rq, server rq, connId rq) diff --git a/tests/AgentTests.hs b/tests/AgentTests.hs index bdc3a28d07..3c6bfa6966 100644 --- a/tests/AgentTests.hs +++ b/tests/AgentTests.hs @@ -381,8 +381,8 @@ testMsgDeliveryAgentRestart t bob = do (corrId == "3" && cmd == OK) || (corrId == "" && cmd == SENT 5) _ -> False - bob <# ("", "", UP server ["alice"]) - bob <#= \case ("", "alice", Msg "hello again") -> True; _ -> False + bob <#= \case ("", "alice", Msg "hello again") -> True; ("", "", UP s ["alice"]) -> s == server; _ -> False + bob <#= \case ("", "alice", Msg "hello again") -> True; ("", "", UP s ["alice"]) -> s == server; _ -> False bob #: ("12", "alice", "ACK 5") #> ("12", "alice", OK) removeFile testStoreLogFile diff --git a/tests/AgentTests/FunctionalAPITests.hs b/tests/AgentTests/FunctionalAPITests.hs index fc780ad1f2..989d652308 100644 --- a/tests/AgentTests/FunctionalAPITests.hs +++ b/tests/AgentTests/FunctionalAPITests.hs @@ -26,7 +26,7 @@ where import Control.Concurrent (killThread, threadDelay) import Control.Monad -import Control.Monad.Except (ExceptT, MonadError (throwError), runExceptT) +import Control.Monad.Except (ExceptT, runExceptT, throwError) import Control.Monad.IO.Unlift import Data.ByteString.Char8 (ByteString) import qualified Data.ByteString.Char8 as B @@ -41,7 +41,8 @@ import Simplex.Messaging.Agent import Simplex.Messaging.Agent.Client (SMPTestFailure (..), SMPTestStep (..)) import Simplex.Messaging.Agent.Env.SQLite (AgentConfig (..), InitialAgentServers (..)) import Simplex.Messaging.Agent.Protocol -import Simplex.Messaging.Client (ProtocolClientConfig (..), defaultClientConfig) +import Simplex.Messaging.Agent.Store (UserId) +import Simplex.Messaging.Client (NetworkConfig (..), ProtocolClientConfig (..), TransportSessionMode (TSMEntity, TSMUser), defaultClientConfig) import Simplex.Messaging.Encoding.String import Simplex.Messaging.Protocol (BasicAuth, ErrorType (..), MsgBody, ProtocolServer (..)) import qualified Simplex.Messaging.Protocol as SMP @@ -183,6 +184,9 @@ functionalAPITests t = do describe "getRatchetAdHash" $ it "should return the same data for both peers" $ withSmpServer t testRatchetAdHash + describe "multiple users" $ + it "should connect two users and switch session mode" $ + withSmpServer t testTwoUsers testBasicAuth :: ATransport -> Bool -> (Maybe BasicAuth, Version) -> (Maybe BasicAuth, Version) -> (Maybe BasicAuth, Version) -> IO Int testBasicAuth t allowNewQueues srv@(srvAuth, srvVersion) clnt1 clnt2 = do @@ -446,9 +450,12 @@ testDuplicateMessage t = do get bob2 =##> \case ("", c, Msg "hello 3") -> c == aliceId; _ -> False makeConnection :: AgentClient -> AgentClient -> ExceptT AgentErrorType IO (ConnId, ConnId) -makeConnection alice bob = do - (bobId, qInfo) <- createConnection alice 1 True SCMInvitation Nothing - aliceId <- joinConnection bob 1 True qInfo "bob's connInfo" +makeConnection alice bob = makeConnectionForUsers alice 1 bob 1 + +makeConnectionForUsers :: AgentClient -> UserId -> AgentClient -> UserId -> ExceptT AgentErrorType IO (ConnId, ConnId) +makeConnectionForUsers alice aliceUserId bob bobUserId = do + (bobId, qInfo) <- createConnection alice aliceUserId True SCMInvitation Nothing + aliceId <- joinConnection bob bobUserId True qInfo "bob's connInfo" ("", _, CONF confId _ "bob's connInfo") <- get alice allowConnection alice bobId confId "alice's connInfo" get alice ##> ("", bobId, CON) @@ -536,10 +543,8 @@ testSuspendingAgentCompleteSending t = do get b =##> \case ("", c, SENT 6) -> c == aId; ("", "", UP {}) -> True; _ -> False ("", "", SUSPENDED) <- get b - r <- get a - liftIO $ print r - ("", "", UP {}) <- pure r - get a =##> \case ("", c, Msg "hello too") -> c == bId; _ -> False + get a =##> \case ("", c, Msg "hello too") -> c == bId; ("", "", UP {}) -> True; _ -> False + get a =##> \case ("", c, Msg "hello too") -> c == bId; ("", "", UP {}) -> True; _ -> False ackMessage a bId 5 get a =##> \case ("", c, Msg "how are you?") -> c == bId; _ -> False ackMessage a bId 6 @@ -831,6 +836,73 @@ testRatchetAdHash = do ad2 <- getConnectionRatchetAdHash b aId liftIO $ ad1 `shouldBe` ad2 +testTwoUsers :: IO () +testTwoUsers = do + let nc = netCfg initAgentServers + a <- getSMPAgentClient agentCfg initAgentServers + b <- getSMPAgentClient agentCfg {database = testDB2} initAgentServers + sessionMode nc `shouldBe` TSMUser + runRight_ $ do + (aId1, bId1) <- makeConnectionForUsers a 1 b 1 + exchangeGreetings a bId1 b aId1 + (aId1', bId1') <- makeConnectionForUsers a 1 b 1 + exchangeGreetings a bId1' b aId1' + a `hasClients` 1 + b `hasClients` 1 + setNetworkConfig a nc {sessionMode = TSMEntity} + liftIO $ threadDelay 250000 + ("", "", DOWN _ _) <- get a + ("", "", UP _ _) <- get a + a `hasClients` 2 + + exchangeGreetingsMsgId 6 a bId1 b aId1 + exchangeGreetingsMsgId 6 a bId1' b aId1' + liftIO $ threadDelay 250000 + setNetworkConfig a nc {sessionMode = TSMUser} + liftIO $ threadDelay 250000 + ("", "", DOWN _ _) <- get a + ("", "", DOWN _ _) <- get a + ("", "", UP _ _) <- get a + ("", "", UP _ _) <- get a + a `hasClients` 1 + + aUserId2 <- createUser a [noAuthSrv testSMPServer] + (aId2, bId2) <- makeConnectionForUsers a aUserId2 b 1 + exchangeGreetings a bId2 b aId2 + (aId2', bId2') <- makeConnectionForUsers a aUserId2 b 1 + exchangeGreetings a bId2' b aId2' + a `hasClients` 2 + b `hasClients` 1 + setNetworkConfig a nc {sessionMode = TSMEntity} + liftIO $ threadDelay 250000 + ("", "", DOWN _ _) <- get a + ("", "", DOWN _ _) <- get a + ("", "", UP _ _) <- get a + ("", "", UP _ _) <- get a + a `hasClients` 4 + exchangeGreetingsMsgId 8 a bId1 b aId1 + exchangeGreetingsMsgId 8 a bId1' b aId1' + exchangeGreetingsMsgId 6 a bId2 b aId2 + exchangeGreetingsMsgId 6 a bId2' b aId2' + liftIO $ threadDelay 250000 + setNetworkConfig a nc {sessionMode = TSMUser} + liftIO $ threadDelay 250000 + ("", "", DOWN _ _) <- get a + ("", "", DOWN _ _) <- get a + ("", "", DOWN _ _) <- get a + ("", "", DOWN _ _) <- get a + ("", "", UP _ _) <- get a + ("", "", UP _ _) <- get a + ("", "", UP _ _) <- get a + ("", "", UP _ _) <- get a + a `hasClients` 2 + exchangeGreetingsMsgId 10 a bId1 b aId1 + exchangeGreetingsMsgId 10 a bId1' b aId1' + exchangeGreetingsMsgId 8 a bId2 b aId2 + exchangeGreetingsMsgId 8 a bId2' b aId2' + where + hasClients c n = liftIO $ M.size <$> readTVarIO (smpClients c) `shouldReturn` n + exchangeGreetings :: AgentClient -> ConnId -> AgentClient -> ConnId -> ExceptT AgentErrorType IO () exchangeGreetings = exchangeGreetingsMsgId 4