Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions src/Simplex/Messaging/Agent/Client.hs
Original file line number Diff line number Diff line change
Expand Up @@ -520,7 +520,7 @@ throwWhenNoDelivery c SndQueue {server, sndId} =

closeProtocolServerClients :: AgentClient -> (AgentClient -> TMap (TransportSession msg) (ClientVar msg)) -> IO ()
closeProtocolServerClients c clientsSel =
readTVarIO cs >>= mapM_ (forkIO . closeClient) >> atomically (writeTVar cs M.empty)
atomically (swapTVar cs M.empty) >>= mapM_ (forkIO . closeClient)
where
cs = clientsSel c
closeClient cVar = do
Expand All @@ -530,7 +530,7 @@ closeProtocolServerClients c clientsSel =
_ -> pure ()

cancelActions :: (Foldable f, Monoid (f (Async ()))) => TVar (f (Async ())) -> IO ()
cancelActions as = readTVarIO as >>= mapM_ (forkIO . uninterruptibleCancel) >> atomically (writeTVar as mempty)
cancelActions as = atomically (swapTVar as mempty) >>= mapM_ (forkIO . uninterruptibleCancel)

withConnLock :: MonadUnliftIO m => AgentClient -> ConnId -> String -> m a -> m a
withConnLock _ "" _ = id
Expand Down
51 changes: 19 additions & 32 deletions tests/AgentTests/FunctionalAPITests.hs
Original file line number Diff line number Diff line change
Expand Up @@ -56,10 +56,10 @@ import Simplex.Messaging.Version
import Test.Hspec
import UnliftIO

(##>) :: MonadIO m => m (ATransmission 'Agent) -> ATransmission 'Agent -> m ()
(##>) :: (HasCallStack, MonadIO m) => m (ATransmission 'Agent) -> ATransmission 'Agent -> m ()
a ##> t = a >>= \t' -> liftIO (t' `shouldBe` t)

(=##>) :: MonadIO m => m (ATransmission 'Agent) -> (ATransmission 'Agent -> Bool) -> m ()
(=##>) :: (HasCallStack, MonadIO m) => m (ATransmission 'Agent) -> (ATransmission 'Agent -> Bool) -> m ()
a =##> p = a >>= \t -> liftIO (t `shouldSatisfy` p)

get :: MonadIO m => AgentClient -> m (ATransmission 'Agent)
Expand All @@ -85,10 +85,10 @@ agentCfgRatchetV1 = agentCfg {e2eEncryptVRange = vr11}
vr11 :: VersionRange
vr11 = mkVersionRange 1 1

runRight_ :: ExceptT AgentErrorType IO () -> Expectation
runRight_ :: HasCallStack => ExceptT AgentErrorType IO () -> Expectation
runRight_ action = runExceptT action `shouldReturn` Right ()

runRight :: ExceptT AgentErrorType IO a -> IO a
runRight :: HasCallStack => ExceptT AgentErrorType IO a -> IO a
runRight action =
runExceptT action >>= \case
Right x -> pure x
Expand Down Expand Up @@ -240,7 +240,7 @@ runTestCfg2 aliceCfg bobCfg baseMsgId runTest = do
bob <- getSMPAgentClient bobCfg {database = testDB2} initAgentServers
runTest alice bob baseMsgId

runAgentClientTest :: AgentClient -> AgentClient -> AgentMsgId -> IO ()
runAgentClientTest :: HasCallStack => AgentClient -> AgentClient -> AgentMsgId -> IO ()
runAgentClientTest alice bob baseId = do
runRight_ $ do
(bobId, qInfo) <- createConnection alice 1 True SCMInvitation Nothing
Expand Down Expand Up @@ -275,7 +275,7 @@ runAgentClientTest alice bob baseId = do
where
msgId = subtract baseId

runAgentClientContactTest :: AgentClient -> AgentClient -> AgentMsgId -> IO ()
runAgentClientContactTest :: HasCallStack => AgentClient -> AgentClient -> AgentMsgId -> IO ()
runAgentClientContactTest alice bob baseId = do
runRight_ $ do
(_, qInfo) <- createConnection alice 1 True SCMContact Nothing
Expand Down Expand Up @@ -312,15 +312,15 @@ runAgentClientContactTest alice bob baseId = do
where
msgId = subtract baseId

noMessages :: AgentClient -> String -> Expectation
noMessages :: HasCallStack => AgentClient -> String -> Expectation
noMessages c err = tryGet `shouldReturn` ()
where
tryGet =
10000 `timeout` get c >>= \case
Just msg -> error $ err <> ": " <> show msg
_ -> return ()

testAsyncInitiatingOffline :: IO ()
testAsyncInitiatingOffline :: HasCallStack => IO ()
testAsyncInitiatingOffline = do
alice <- getSMPAgentClient agentCfg initAgentServers
bob <- getSMPAgentClient agentCfg {database = testDB2} initAgentServers
Expand All @@ -337,7 +337,7 @@ testAsyncInitiatingOffline = do
get bob ##> ("", aliceId, CON)
exchangeGreetings alice' bobId bob aliceId

testAsyncJoiningOfflineBeforeActivation :: IO ()
testAsyncJoiningOfflineBeforeActivation :: HasCallStack => IO ()
testAsyncJoiningOfflineBeforeActivation = do
alice <- getSMPAgentClient agentCfg initAgentServers
bob <- getSMPAgentClient agentCfg {database = testDB2} initAgentServers
Expand All @@ -354,7 +354,7 @@ testAsyncJoiningOfflineBeforeActivation = do
get bob' ##> ("", aliceId, CON)
exchangeGreetings alice bobId bob' aliceId

testAsyncBothOffline :: IO ()
testAsyncBothOffline :: HasCallStack => IO ()
testAsyncBothOffline = do
alice <- getSMPAgentClient agentCfg initAgentServers
bob <- getSMPAgentClient agentCfg {database = testDB2} initAgentServers
Expand All @@ -374,7 +374,7 @@ testAsyncBothOffline = do
get bob' ##> ("", aliceId, CON)
exchangeGreetings alice' bobId bob' aliceId

testAsyncServerOffline :: ATransport -> IO ()
testAsyncServerOffline :: HasCallStack => ATransport -> IO ()
testAsyncServerOffline t = do
alice <- getSMPAgentClient agentCfg initAgentServers
bob <- getSMPAgentClient agentCfg {database = testDB2} initAgentServers
Expand All @@ -400,7 +400,7 @@ testAsyncServerOffline t = do
get bob ##> ("", aliceId, CON)
exchangeGreetings alice bobId bob aliceId

testAsyncHelloTimeout :: IO ()
testAsyncHelloTimeout :: HasCallStack => IO ()
testAsyncHelloTimeout = do
-- this test would only work if any of the agent is v1, there is no HELLO timeout in v2
alice <- getSMPAgentClient agentCfgV1 initAgentServers
Expand All @@ -411,7 +411,7 @@ testAsyncHelloTimeout = do
aliceId <- joinConnection bob 1 True cReq "bob's connInfo"
get bob ##> ("", aliceId, ERR $ CONN NOT_ACCEPTED)

testDuplicateMessage :: ATransport -> IO ()
testDuplicateMessage :: HasCallStack => ATransport -> IO ()
testDuplicateMessage t = do
alice <- getSMPAgentClient agentCfg initAgentServers
bob <- getSMPAgentClient agentCfg {database = testDB2} initAgentServers
Expand Down Expand Up @@ -791,38 +791,29 @@ testDeleteUserQuietly = do
exchangeGreetingsMsgId 6 a bId b aId
liftIO $ noMessages a "nothing else should be delivered to alice"

testUsersNoServer :: ATransport -> IO ()
testUsersNoServer :: HasCallStack => ATransport -> IO ()
testUsersNoServer t = do
a <- getSMPAgentClient agentCfg {initialCleanupDelay = 10000, cleanupInterval = 10000, deleteErrorCount = 3} initAgentServers
b <- getSMPAgentClient agentCfg {database = testDB2} initAgentServers
liftIO $ print 1
(aId, bId, auId, _aId', bId') <- withSmpServerStoreLogOn t testPort $ \_ -> runRight $ do
(aId, bId) <- makeConnection a b
exchangeGreetingsMsgId 4 a bId b aId
auId <- createUser a [noAuthSrv testSMPServer]
(aId', bId') <- makeConnectionForUsers a auId b 1
exchangeGreetingsMsgId 4 a bId' b aId'
pure (aId, bId, auId, aId', bId')
liftIO $ print 2
get a =##> \case ("", "", DOWN _ [c]) -> c == bId || c == bId'; _ -> False
get a =##> \case ("", "", DOWN _ [c]) -> c == bId || c == bId'; _ -> False
get b =##> \case ("", "", DOWN _ cs) -> length cs == 2; _ -> False
liftIO $ print 3
runRight_ $ do
deleteUser a auId True
liftIO $ print 4
get a =##> \case ("", c, DEL_RCVQ _ _ (Just (BROKER _ e))) -> c == bId' && (e == TIMEOUT || e == NETWORK); _ -> False
liftIO $ print 4.1
get a =##> \case ("", c, DEL_CONN) -> c == bId'; _ -> False
liftIO $ print 4.2
get a =##> \case ("", "", DEL_USER u) -> u == auId; _ -> False
liftIO $ print 5
liftIO $ noMessages a "nothing else should be delivered to alice"
withSmpServerStoreLogOn t testPort $ \_ -> runRight_ $ do
liftIO $ print 6
get a =##> \case ("", "", UP _ [c]) -> c == bId; _ -> False
get b =##> \case ("", "", UP _ cs) -> length cs == 2; _ -> False
liftIO $ print 7
exchangeGreetingsMsgId 6 a bId b aId

testSwitchConnection :: InitialAgentServers -> IO ()
Expand Down Expand Up @@ -859,28 +850,23 @@ phase c connId d p =
SWITCH {} <- pure r
pure ()

testSwitchAsync :: InitialAgentServers -> IO ()
testSwitchAsync :: HasCallStack => InitialAgentServers -> IO ()
testSwitchAsync servers = do
liftIO $ print 1
(aId, bId) <- withA $ \a -> withB $ \b -> runRight $ do
(aId, bId) <- makeConnection a b
exchangeGreetingsMsgId 4 a bId b aId
pure (aId, bId)
liftIO $ print 2
let withA' = session withA bId
withB' = session withB aId
withA' $ \a -> do
switchConnectionAsync a "" bId
phase a bId QDRcv SPStarted
liftIO $ print 3
withB' $ \b -> phase b aId QDSnd SPStarted
withA' $ \a -> phase a bId QDRcv SPConfirmed
liftIO $ print 4
withB' $ \b -> do
phase b aId QDSnd SPConfirmed
phase b aId QDSnd SPCompleted
withA' $ \a -> phase a bId QDRcv SPCompleted
liftIO $ print 5
withA $ \a -> withB $ \b -> runRight_ $ do
subscribeConnection a bId
subscribeConnection b aId
Expand Down Expand Up @@ -956,7 +942,7 @@ testRatchetAdHash = do
ad2 <- getConnectionRatchetAdHash b aId
liftIO $ ad1 `shouldBe` ad2

testTwoUsers :: IO ()
testTwoUsers :: HasCallStack => IO ()
testTwoUsers = do
let nc = netCfg initAgentServers
a <- getSMPAgentClient agentCfg initAgentServers
Expand Down Expand Up @@ -1021,12 +1007,13 @@ testTwoUsers = do
exchangeGreetingsMsgId 8 a bId2 b aId2
exchangeGreetingsMsgId 8 a bId2' b aId2'
where
hasClients :: HasCallStack => AgentClient -> Int -> ExceptT AgentErrorType IO ()
hasClients c n = liftIO $ M.size <$> readTVarIO (smpClients c) `shouldReturn` n

exchangeGreetings :: AgentClient -> ConnId -> AgentClient -> ConnId -> ExceptT AgentErrorType IO ()
exchangeGreetings :: HasCallStack => AgentClient -> ConnId -> AgentClient -> ConnId -> ExceptT AgentErrorType IO ()
exchangeGreetings = exchangeGreetingsMsgId 4

exchangeGreetingsMsgId :: Int64 -> AgentClient -> ConnId -> AgentClient -> ConnId -> ExceptT AgentErrorType IO ()
exchangeGreetingsMsgId :: HasCallStack => Int64 -> AgentClient -> ConnId -> AgentClient -> ConnId -> ExceptT AgentErrorType IO ()
exchangeGreetingsMsgId msgId alice bobId bob aliceId = do
msgId1 <- sendMessage alice bobId SMP.noMsgFlags "hello"
liftIO $ msgId1 `shouldBe` msgId
Expand Down
31 changes: 16 additions & 15 deletions tests/SMPClient.hs
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
{-# LANGUAGE NumericUnderscores #-}
{-# LANGUAGE OverloadedLists #-}
{-# LANGUAGE OverloadedStrings #-}
{-# LANGUAGE RankNTypes #-}
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE TypeApplications #-}

Expand Down Expand Up @@ -91,25 +92,25 @@ cfg =
logTLSErrors = True
}

withSmpServerStoreMsgLogOnV2 :: ATransport -> ServiceName -> (ThreadId -> IO a) -> IO a
withSmpServerStoreMsgLogOnV2 :: HasCallStack => ATransport -> ServiceName -> (HasCallStack => ThreadId -> IO a) -> IO a
withSmpServerStoreMsgLogOnV2 t = withSmpServerConfigOn t cfgV2 {storeLogFile = Just testStoreLogFile, storeMsgsFile = Just testStoreMsgsFile}

withSmpServerStoreMsgLogOn :: ATransport -> ServiceName -> (ThreadId -> IO a) -> IO a
withSmpServerStoreMsgLogOn :: HasCallStack => ATransport -> ServiceName -> (HasCallStack => ThreadId -> IO a) -> IO a
withSmpServerStoreMsgLogOn t = withSmpServerConfigOn t cfg {storeLogFile = Just testStoreLogFile, storeMsgsFile = Just testStoreMsgsFile, serverStatsBackupFile = Just testServerStatsBackupFile}

withSmpServerStoreLogOn :: ATransport -> ServiceName -> (ThreadId -> IO a) -> IO a
withSmpServerStoreLogOn :: HasCallStack => ATransport -> ServiceName -> (HasCallStack => ThreadId -> IO a) -> IO a
withSmpServerStoreLogOn t = withSmpServerConfigOn t cfg {storeLogFile = Just testStoreLogFile, serverStatsBackupFile = Just testServerStatsBackupFile}

withSmpServerConfigOn :: ATransport -> ServerConfig -> ServiceName -> (ThreadId -> IO a) -> IO a
withSmpServerConfigOn :: HasCallStack => ATransport -> ServerConfig -> ServiceName -> (HasCallStack => ThreadId -> IO a) -> IO a
withSmpServerConfigOn t cfg' port' =
serverBracket
(\started -> runSMPServerBlocking started cfg' {transports = [(port', t)]})
(pure ())

withSmpServerThreadOn :: ATransport -> ServiceName -> (ThreadId -> IO a) -> IO a
withSmpServerThreadOn :: HasCallStack => ATransport -> ServiceName -> (HasCallStack => ThreadId -> IO a) -> IO a
withSmpServerThreadOn t = withSmpServerConfigOn t cfg

serverBracket :: MonadUnliftIO m => (TMVar Bool -> m ()) -> m () -> (ThreadId -> m a) -> m a
serverBracket :: (HasCallStack, MonadUnliftIO m) => (TMVar Bool -> m ()) -> m () -> (HasCallStack => ThreadId -> m a) -> m a
serverBracket process afterProcess f = do
started <- newEmptyTMVarIO
E.bracket
Expand All @@ -122,16 +123,16 @@ serverBracket process afterProcess f = do
Nothing -> error $ "server did not " <> s
_ -> pure ()

withSmpServerOn :: ATransport -> ServiceName -> IO a -> IO a
withSmpServerOn :: HasCallStack => ATransport -> ServiceName -> IO a -> IO a
withSmpServerOn t port' = withSmpServerThreadOn t port' . const

withSmpServer :: ATransport -> IO a -> IO a
withSmpServer :: HasCallStack => ATransport -> IO a -> IO a
withSmpServer t = withSmpServerOn t testPort

runSmpTest :: forall c a. Transport c => (THandle c -> IO a) -> IO a
runSmpTest :: forall c a. (HasCallStack, Transport c) => (HasCallStack => THandle c -> IO a) -> IO a
runSmpTest test = withSmpServer (transport @c) $ testSMPClient test

runSmpTestN :: forall c a. Transport c => Int -> ([THandle c] -> IO a) -> IO a
runSmpTestN :: forall c a. (HasCallStack, Transport c) => Int -> (HasCallStack => [THandle c] -> IO a) -> IO a
runSmpTestN nClients test = withSmpServer (transport @c) $ run nClients []
where
run :: Int -> [THandle c] -> IO a
Expand All @@ -154,25 +155,25 @@ smpServerTest _ t = runSmpTest $ \h -> tPut' h t >> tGet' h
[(Nothing, _, (CorrId corrId, qId, Right cmd))] <- tGet h
pure (Nothing, corrId, qId, cmd)

smpTest :: Transport c => TProxy c -> (THandle c -> IO ()) -> Expectation
smpTest :: (HasCallStack, Transport c) => TProxy c -> (HasCallStack => THandle c -> IO ()) -> Expectation
smpTest _ test' = runSmpTest test' `shouldReturn` ()

smpTestN :: Transport c => Int -> ([THandle c] -> IO ()) -> Expectation
smpTestN :: (HasCallStack, Transport c) => Int -> (HasCallStack => [THandle c] -> IO ()) -> Expectation
smpTestN n test' = runSmpTestN n test' `shouldReturn` ()

smpTest2 :: Transport c => TProxy c -> (THandle c -> THandle c -> IO ()) -> Expectation
smpTest2 :: (HasCallStack, Transport c) => TProxy c -> (HasCallStack => THandle c -> THandle c -> IO ()) -> Expectation
smpTest2 _ test' = smpTestN 2 _test
where
_test [h1, h2] = test' h1 h2
_test _ = error "expected 2 handles"

smpTest3 :: Transport c => TProxy c -> (THandle c -> THandle c -> THandle c -> IO ()) -> Expectation
smpTest3 :: (HasCallStack, Transport c) => TProxy c -> (HasCallStack => THandle c -> THandle c -> THandle c -> IO ()) -> Expectation
smpTest3 _ test' = smpTestN 3 _test
where
_test [h1, h2, h3] = test' h1 h2 h3
_test _ = error "expected 3 handles"

smpTest4 :: Transport c => TProxy c -> (THandle c -> THandle c -> THandle c -> THandle c -> IO ()) -> Expectation
smpTest4 :: (HasCallStack, Transport c) => TProxy c -> (HasCallStack => THandle c -> THandle c -> THandle c -> THandle c -> IO ()) -> Expectation
smpTest4 _ test' = smpTestN 4 _test
where
_test [h1, h2, h3, h4] = test' h1 h2 h3 h4
Expand Down