Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
26 changes: 11 additions & 15 deletions src/Simplex/Messaging/Agent.hs
Original file line number Diff line number Diff line change
Expand Up @@ -386,13 +386,17 @@ deleteUser' c userId = do

newConnAsync :: forall m c. (AgentMonad m, ConnectionModeI c) => AgentClient -> UserId -> ACorrId -> Bool -> SConnectionMode c -> m ConnId
newConnAsync c userId corrId enableNtfs cMode = do
g <- asks idsDrg
connAgentVersion <- asks $ maxVersion . smpAgentVRange . config
let cData = ConnData {userId, connId = "", connAgentVersion, enableNtfs, duplexHandshake = Nothing, deleted = False} -- connection mode is determined by the accepting agent
connId <- withStore c $ \db -> createNewConn db g cData cMode
connId <- newConnNoQueues c userId "" enableNtfs cMode
enqueueCommand c corrId connId Nothing $ AClientCommand $ NEW enableNtfs (ACM cMode)
pure connId

newConnNoQueues :: AgentMonad m => AgentClient -> UserId -> ConnId -> Bool -> SConnectionMode c -> m ConnId
newConnNoQueues c userId connId enableNtfs cMode = do
g <- asks idsDrg
connAgentVersion <- asks $ maxVersion . smpAgentVRange . config
let cData = ConnData {userId, connId, connAgentVersion, enableNtfs, duplexHandshake = Nothing, deleted = False} -- connection mode is determined by the accepting agent
withStore c $ \db -> createNewConn db g cData cMode

joinConnAsync :: AgentMonad m => AgentClient -> UserId -> ACorrId -> Bool -> ConnectionRequestUri c -> ConnInfo -> m ConnId
joinConnAsync c userId corrId enableNtfs cReqUri@(CRInvitationUri ConnReqUriData {crAgentVRange} _) cInfo = do
aVRange <- asks $ smpAgentVRange . config
Expand Down Expand Up @@ -473,9 +477,9 @@ newConn c userId connId asyncMode enableNtfs cMode clientData =
newConnSrv :: AgentMonad m => AgentClient -> UserId -> ConnId -> Bool -> Bool -> SConnectionMode c -> Maybe CRClientData -> SMPServerWithAuth -> m (ConnId, ConnectionRequestUri c)
newConnSrv c userId connId asyncMode enableNtfs cMode clientData srv = do
AgentConfig {smpClientVRange, smpAgentVRange, e2eEncryptVRange} <- asks config
(q, qUri) <- newRcvQueue c userId "" srv smpClientVRange
connId' <- setUpConn asyncMode q $ maxVersion smpAgentVRange
let rq = (q :: RcvQueue) {connId = connId'}
connId' <- if asyncMode then pure connId else newConnNoQueues c userId connId enableNtfs cMode
(rq, qUri) <- newRcvQueue c userId connId' srv smpClientVRange `catchError` \e -> liftIO (print e) >> throwError e
void . withStore c $ \db -> updateNewConnRcv db connId' rq
addSubscription c rq
when enableNtfs $ do
ns <- asks ntfSupervisor
Expand All @@ -487,14 +491,6 @@ newConnSrv c userId connId asyncMode enableNtfs cMode clientData srv = do
(pk1, pk2, e2eRcvParams) <- liftIO . CR.generateE2EParams $ maxVersion e2eEncryptVRange
withStore' c $ \db -> createRatchetX3dhKeys db connId' pk1 pk2
pure (connId', CRInvitationUri crData $ toVersionRangeT e2eRcvParams e2eEncryptVRange)
where
setUpConn True rq _ = do
void . withStore c $ \db -> updateNewConnRcv db connId rq
pure connId
setUpConn False rq connAgentVersion = do
g <- asks idsDrg
let cData = ConnData {userId, connId, connAgentVersion, enableNtfs, duplexHandshake = Nothing, deleted = False} -- connection mode is determined by the accepting agent
withStore c $ \db -> createRcvConn db g cData rq cMode

joinConn :: AgentMonad m => AgentClient -> UserId -> ConnId -> Bool -> Bool -> ConnectionRequestUri c -> ConnInfo -> m ConnId
joinConn c userId connId asyncMode enableNtfs cReq cInfo = do
Expand Down
59 changes: 27 additions & 32 deletions src/Simplex/Messaging/Agent/Client.hs
Original file line number Diff line number Diff line change
Expand Up @@ -99,7 +99,7 @@ import Data.Bifunctor (bimap, first, second)
import Data.ByteString.Base64
import Data.ByteString.Char8 (ByteString)
import qualified Data.ByteString.Char8 as B
import Data.Either (isRight, partitionEithers)
import Data.Either (partitionEithers)
import Data.Functor (($>))
import Data.List (foldl', partition)
import Data.List.NonEmpty (NonEmpty (..), (<|))
Expand Down Expand Up @@ -352,17 +352,10 @@ getSMPServerClient c@AgentClient {active, smpClients, msgQ} tSess@(userId, srv,
where
resubscribe :: NonEmpty RcvQueue -> m ()
resubscribe qs = do
connected <- maybe False isRight <$> atomically (TM.lookup tSess smpClients $>>= tryReadTMVar)
cs <- atomically . RQ.getConns $ activeSubs c
-- TODO the unsolved problem is changing session mode
-- currently it would reconnect the same client and will keep it around, until the app is restarted.
-- instead it should check whether session mode changed and connect somehow differently if it did...
(client_, rs) <- subscribeQueues_ c tSess qs
let (errs, okConns) = partitionEithers $ map (\(RcvQueue {connId}, r) -> bimap (connId,) (const connId) r) $ L.toList rs
rs <- subscribeQueues c $ L.toList qs
let (errs, okConns) = partitionEithers $ map (\(RcvQueue {connId}, r) -> bimap (connId,) (const connId) r) rs
liftIO $ do
unless connected . forM_ client_ $ \cl -> do
incClientStat c userId cl "CONNECT" ""
notifySub "" $ hostEvent CONNECT cl
let conns = S.toList $ S.fromList okConns `S.difference` cs
unless (null conns) $ notifySub "" $ UP srv conns
let (tempErrs, finalErrs) = partition (temporaryAgentError . snd) errs
Expand Down Expand Up @@ -610,12 +603,19 @@ runSMPServerTest c userId (ProtoServerWithAuth srv auth) = do
testErr step = SMPTestFailure step . protocolClientError SMP addr

mkTransportSession :: AgentMonad m => AgentClient -> UserId -> ProtoServer msg -> EntityId -> m (TransportSession msg)
mkTransportSession c userId srv entityId = do
mode <- sessionMode <$> readTVarIO (useNetworkConfig c)
pure (userId, srv, if mode == TSMEntity then Just entityId else Nothing)
mkTransportSession c userId srv entityId = mkTSession userId srv entityId <$> getSessionMode c

mkTSession :: UserId -> ProtoServer msg -> EntityId -> TransportSessionMode -> TransportSession msg
mkTSession userId srv entityId mode = (userId, srv, if mode == TSMEntity then Just entityId else Nothing)

mkSMPTransportSession :: (AgentMonad m, SMPQueueRec q) => AgentClient -> q -> m SMPTransportSession
mkSMPTransportSession c q = mkTransportSession c (qUserId q) (qServer q) (queueId q)
mkSMPTransportSession c q = mkSMPTSession q <$> getSessionMode c

mkSMPTSession :: SMPQueueRec q => q -> TransportSessionMode -> SMPTransportSession
mkSMPTSession q = mkTSession (qUserId q) (qServer q) (qConnId q)

getSessionMode :: AgentMonad m => AgentClient -> m TransportSessionMode
getSessionMode = fmap sessionMode . readTVarIO . useNetworkConfig

newRcvQueue :: AgentMonad m => AgentClient -> UserId -> ConnId -> SMPServerWithAuth -> VersionRange -> m (RcvQueue, SMPQueueUri)
newRcvQueue c userId connId (ProtoServerWithAuth srv auth) vRange = do
Expand Down Expand Up @@ -684,7 +684,7 @@ temporaryOrHostError = \case
BROKER _ HOST -> True
e -> temporaryAgentError e

subscribeQueues :: AgentMonad m => AgentClient -> [RcvQueue] -> m [(RcvQueue, Either AgentErrorType ())]
subscribeQueues :: forall m. AgentMonad m => AgentClient -> [RcvQueue] -> m [(RcvQueue, Either AgentErrorType ())]
subscribeQueues c qs = do
(errs, qs') <- partitionEithers <$> mapM checkQueue qs
forM_ qs' $ \rq@RcvQueue {connId} -> atomically $ do
Expand All @@ -693,32 +693,27 @@ subscribeQueues c qs = do
(errs <>) <$> do
mode <- sessionMode <$> readTVarIO (useNetworkConfig c)
let sessRcvQs = foldl' (addRcvQueue mode) M.empty qs'
concat <$> mapConcurrently (fmap (L.toList . snd) . uncurry (subscribeQueues_ c)) (M.assocs sessRcvQs)
concat <$> mapConcurrently (fmap L.toList . uncurry (subscribeQueues_ c)) (M.assocs sessRcvQs)
where
checkQueue rq@RcvQueue {rcvId, server} = do
prohibited <- atomically . TM.member (server, rcvId) $ getMsgLocks c
pure $ if prohibited then Left (rq, Left $ CMD PROHIBITED) else Right rq
addRcvQueue :: TransportSessionMode -> Map SMPTransportSession (NonEmpty RcvQueue) -> RcvQueue -> Map SMPTransportSession (NonEmpty RcvQueue)
addRcvQueue mode m rq@RcvQueue {userId, server, rcvId} =
let tSess = (userId, server, if mode == TSMEntity then Just rcvId else Nothing)
addRcvQueue mode m rq =
let tSess = mkSMPTSession rq mode
in M.alter (Just . maybe [rq] (rq <|)) tSess m

-- | subscribe multiple queues - all passed queues should be on the same server
-- TODO check session? depends on how session mode change will be handled
subscribeQueues_ :: AgentMonad m => AgentClient -> SMPTransportSession -> NonEmpty RcvQueue -> m (Maybe SMPClient, NonEmpty (RcvQueue, Either AgentErrorType ()))
subscribeQueues_ c tSess@(userId, srv, _) qs = do
smp_ <- tryError $ getSMPServerClient c tSess
(eitherToMaybe smp_,) <$> case smp_ of
subscribeQueues_ :: AgentMonad m => AgentClient -> SMPTransportSession -> NonEmpty RcvQueue -> m (NonEmpty (RcvQueue, Either AgentErrorType ()))
subscribeQueues_ c tSess@(userId, srv, _) qs =
tryError (getSMPServerClient c tSess) >>= \case
Left e -> pure $ L.map (,Left e) qs
Right smp -> do
Right smp -> liftIO $ do
logServer "-->" c srv (bshow (length qs) <> " queues") "SUB"
let qs2 = L.map queueCreds qs
n = (length qs2 - 1) `div` 90 + 1
liftIO $ incClientStatN c userId smp n "SUBS" "OK"
liftIO $ do
rs <- L.zip qs <$> subscribeSMPQueues smp qs2
mapM_ (uncurry $ processSubResult c) rs
pure $ L.map (second . first $ protocolClientError SMP $ clientServer smp) rs
let n = (length qs - 1) `div` 90 + 1
incClientStatN c userId smp n "SUBS" "OK"
rs <- L.zip qs <$> subscribeSMPQueues smp (L.map queueCreds qs)
mapM_ (uncurry $ processSubResult c) rs
pure $ L.map (second . first $ protocolClientError SMP $ clientServer smp) rs
where
queueCreds RcvQueue {rcvPrivateKey, rcvId} = (rcvPrivateKey, rcvId)

Expand Down
14 changes: 7 additions & 7 deletions src/Simplex/Messaging/Agent/TRcvQueues.hs
Original file line number Diff line number Diff line change
Expand Up @@ -47,21 +47,21 @@ addQueue rq (TRcvQueues qs) = TM.insert (qKey rq) rq qs
deleteQueue :: RcvQueue -> TRcvQueues -> STM ()
deleteQueue rq (TRcvQueues qs) = TM.delete (qKey rq) qs

getSessQueues :: (UserId, SMPServer, Maybe RecipientId) -> TRcvQueues -> STM [RcvQueue]
getSessQueues :: (UserId, SMPServer, Maybe ConnId) -> TRcvQueues -> STM [RcvQueue]
getSessQueues tSess (TRcvQueues qs) = M.foldl' addQ [] <$> readTVar qs
where
addQ qs' rq = if rq `isSession` tSess then rq : qs' else qs'

getDelSessQueues :: (UserId, SMPServer, Maybe RecipientId) -> TRcvQueues -> STM ([RcvQueue], Set ConnId)
getDelSessQueues :: (UserId, SMPServer, Maybe ConnId) -> TRcvQueues -> STM ([RcvQueue], Set ConnId)
getDelSessQueues tSess (TRcvQueues qs) = stateTVar qs $ M.foldl' addQ (([], S.empty), M.empty)
where
addQ (removed@(remQs, remConns), qs') rq@RcvQueue {connId}
| rq `isSession` tSess = ((rq : remQs, S.insert connId remConns), qs')
| otherwise = (removed, M.insert (qKey rq) rq qs')

isSession :: RcvQueue -> (UserId, SMPServer, Maybe RecipientId) -> Bool
isSession RcvQueue {userId, server, rcvId} (uId, srv, qId_) =
userId == uId && server == srv && maybe True (rcvId ==) qId_
isSession :: RcvQueue -> (UserId, SMPServer, Maybe ConnId) -> Bool
isSession rq (uId, srv, connId_) =
userId rq == uId && server rq == srv && maybe True (connId rq ==) connId_

qKey :: RcvQueue -> (UserId, SMPServer, RecipientId)
qKey RcvQueue {userId, server, rcvId} = (userId, server, rcvId)
qKey :: RcvQueue -> (UserId, SMPServer, ConnId)
qKey rq = (userId rq, server rq, connId rq)
4 changes: 2 additions & 2 deletions tests/AgentTests.hs
Original file line number Diff line number Diff line change
Expand Up @@ -381,8 +381,8 @@ testMsgDeliveryAgentRestart t bob = do
(corrId == "3" && cmd == OK)
|| (corrId == "" && cmd == SENT 5)
_ -> False
bob <# ("", "", UP server ["alice"])
bob <#= \case ("", "alice", Msg "hello again") -> True; _ -> False
bob <#= \case ("", "alice", Msg "hello again") -> True; ("", "", UP s ["alice"]) -> s == server; _ -> False
bob <#= \case ("", "alice", Msg "hello again") -> True; ("", "", UP s ["alice"]) -> s == server; _ -> False
bob #: ("12", "alice", "ACK 5") #> ("12", "alice", OK)

removeFile testStoreLogFile
Expand Down
90 changes: 81 additions & 9 deletions tests/AgentTests/FunctionalAPITests.hs
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@ where

import Control.Concurrent (killThread, threadDelay)
import Control.Monad
import Control.Monad.Except (ExceptT, MonadError (throwError), runExceptT)
import Control.Monad.Except (ExceptT, runExceptT, throwError)
import Control.Monad.IO.Unlift
import Data.ByteString.Char8 (ByteString)
import qualified Data.ByteString.Char8 as B
Expand All @@ -41,7 +41,8 @@ import Simplex.Messaging.Agent
import Simplex.Messaging.Agent.Client (SMPTestFailure (..), SMPTestStep (..))
import Simplex.Messaging.Agent.Env.SQLite (AgentConfig (..), InitialAgentServers (..))
import Simplex.Messaging.Agent.Protocol
import Simplex.Messaging.Client (ProtocolClientConfig (..), defaultClientConfig)
import Simplex.Messaging.Agent.Store (UserId)
import Simplex.Messaging.Client (NetworkConfig (..), ProtocolClientConfig (..), TransportSessionMode (TSMEntity, TSMUser), defaultClientConfig)
import Simplex.Messaging.Encoding.String
import Simplex.Messaging.Protocol (BasicAuth, ErrorType (..), MsgBody, ProtocolServer (..))
import qualified Simplex.Messaging.Protocol as SMP
Expand Down Expand Up @@ -183,6 +184,9 @@ functionalAPITests t = do
describe "getRatchetAdHash" $
it "should return the same data for both peers" $
withSmpServer t testRatchetAdHash
describe "multiple users" $
it "should connect two users and switch session mode" $
withSmpServer t testTwoUsers

testBasicAuth :: ATransport -> Bool -> (Maybe BasicAuth, Version) -> (Maybe BasicAuth, Version) -> (Maybe BasicAuth, Version) -> IO Int
testBasicAuth t allowNewQueues srv@(srvAuth, srvVersion) clnt1 clnt2 = do
Expand Down Expand Up @@ -446,9 +450,12 @@ testDuplicateMessage t = do
get bob2 =##> \case ("", c, Msg "hello 3") -> c == aliceId; _ -> False

makeConnection :: AgentClient -> AgentClient -> ExceptT AgentErrorType IO (ConnId, ConnId)
makeConnection alice bob = do
(bobId, qInfo) <- createConnection alice 1 True SCMInvitation Nothing
aliceId <- joinConnection bob 1 True qInfo "bob's connInfo"
makeConnection alice bob = makeConnectionForUsers alice 1 bob 1

makeConnectionForUsers :: AgentClient -> UserId -> AgentClient -> UserId -> ExceptT AgentErrorType IO (ConnId, ConnId)
makeConnectionForUsers alice aliceUserId bob bobUserId = do
(bobId, qInfo) <- createConnection alice aliceUserId True SCMInvitation Nothing
aliceId <- joinConnection bob bobUserId True qInfo "bob's connInfo"
("", _, CONF confId _ "bob's connInfo") <- get alice
allowConnection alice bobId confId "alice's connInfo"
get alice ##> ("", bobId, CON)
Expand Down Expand Up @@ -536,10 +543,8 @@ testSuspendingAgentCompleteSending t = do
get b =##> \case ("", c, SENT 6) -> c == aId; ("", "", UP {}) -> True; _ -> False
("", "", SUSPENDED) <- get b

r <- get a
liftIO $ print r
("", "", UP {}) <- pure r
get a =##> \case ("", c, Msg "hello too") -> c == bId; _ -> False
get a =##> \case ("", c, Msg "hello too") -> c == bId; ("", "", UP {}) -> True; _ -> False
get a =##> \case ("", c, Msg "hello too") -> c == bId; ("", "", UP {}) -> True; _ -> False
ackMessage a bId 5
get a =##> \case ("", c, Msg "how are you?") -> c == bId; _ -> False
ackMessage a bId 6
Expand Down Expand Up @@ -831,6 +836,73 @@ testRatchetAdHash = do
ad2 <- getConnectionRatchetAdHash b aId
liftIO $ ad1 `shouldBe` ad2

testTwoUsers :: IO ()
testTwoUsers = do
let nc = netCfg initAgentServers
a <- getSMPAgentClient agentCfg initAgentServers
b <- getSMPAgentClient agentCfg {database = testDB2} initAgentServers
sessionMode nc `shouldBe` TSMUser
runRight_ $ do
(aId1, bId1) <- makeConnectionForUsers a 1 b 1
exchangeGreetings a bId1 b aId1
(aId1', bId1') <- makeConnectionForUsers a 1 b 1
exchangeGreetings a bId1' b aId1'
a `hasClients` 1
b `hasClients` 1
setNetworkConfig a nc {sessionMode = TSMEntity}
liftIO $ threadDelay 250000
("", "", DOWN _ _) <- get a
("", "", UP _ _) <- get a
a `hasClients` 2

exchangeGreetingsMsgId 6 a bId1 b aId1
exchangeGreetingsMsgId 6 a bId1' b aId1'
liftIO $ threadDelay 250000
setNetworkConfig a nc {sessionMode = TSMUser}
liftIO $ threadDelay 250000
("", "", DOWN _ _) <- get a
("", "", DOWN _ _) <- get a
("", "", UP _ _) <- get a
("", "", UP _ _) <- get a
a `hasClients` 1

aUserId2 <- createUser a [noAuthSrv testSMPServer]
(aId2, bId2) <- makeConnectionForUsers a aUserId2 b 1
exchangeGreetings a bId2 b aId2
(aId2', bId2') <- makeConnectionForUsers a aUserId2 b 1
exchangeGreetings a bId2' b aId2'
a `hasClients` 2
b `hasClients` 1
setNetworkConfig a nc {sessionMode = TSMEntity}
liftIO $ threadDelay 250000
("", "", DOWN _ _) <- get a
("", "", DOWN _ _) <- get a
("", "", UP _ _) <- get a
("", "", UP _ _) <- get a
a `hasClients` 4
exchangeGreetingsMsgId 8 a bId1 b aId1
exchangeGreetingsMsgId 8 a bId1' b aId1'
exchangeGreetingsMsgId 6 a bId2 b aId2
exchangeGreetingsMsgId 6 a bId2' b aId2'
liftIO $ threadDelay 250000
setNetworkConfig a nc {sessionMode = TSMUser}
liftIO $ threadDelay 250000
("", "", DOWN _ _) <- get a
("", "", DOWN _ _) <- get a
("", "", DOWN _ _) <- get a
("", "", DOWN _ _) <- get a
("", "", UP _ _) <- get a
("", "", UP _ _) <- get a
("", "", UP _ _) <- get a
("", "", UP _ _) <- get a
a `hasClients` 2
exchangeGreetingsMsgId 10 a bId1 b aId1
exchangeGreetingsMsgId 10 a bId1' b aId1'
exchangeGreetingsMsgId 8 a bId2 b aId2
exchangeGreetingsMsgId 8 a bId2' b aId2'
where
hasClients c n = liftIO $ M.size <$> readTVarIO (smpClients c) `shouldReturn` n

exchangeGreetings :: AgentClient -> ConnId -> AgentClient -> ConnId -> ExceptT AgentErrorType IO ()
exchangeGreetings = exchangeGreetingsMsgId 4

Expand Down