From 2216fdedf150a69c8ecdde2bdae887f2db2fa5b3 Mon Sep 17 00:00:00 2001 From: Evgeny Poberezkin <2769109+epoberezkin@users.noreply.github.com> Date: Mon, 30 Jan 2023 17:36:58 +0000 Subject: [PATCH] avoid possible race conditions when cancelled clients/asyncs can be removed after the new ones are added (so that the new are removed as well) --- src/Simplex/Messaging/Agent/Client.hs | 4 +- tests/AgentTests/FunctionalAPITests.hs | 51 ++++++++++---------------- tests/SMPClient.hs | 31 ++++++++-------- 3 files changed, 37 insertions(+), 49 deletions(-) diff --git a/src/Simplex/Messaging/Agent/Client.hs b/src/Simplex/Messaging/Agent/Client.hs index b683a27866..400e58c144 100644 --- a/src/Simplex/Messaging/Agent/Client.hs +++ b/src/Simplex/Messaging/Agent/Client.hs @@ -520,7 +520,7 @@ throwWhenNoDelivery c SndQueue {server, sndId} = closeProtocolServerClients :: AgentClient -> (AgentClient -> TMap (TransportSession msg) (ClientVar msg)) -> IO () closeProtocolServerClients c clientsSel = - readTVarIO cs >>= mapM_ (forkIO . closeClient) >> atomically (writeTVar cs M.empty) + atomically (swapTVar cs M.empty) >>= mapM_ (forkIO . closeClient) where cs = clientsSel c closeClient cVar = do @@ -530,7 +530,7 @@ closeProtocolServerClients c clientsSel = _ -> pure () cancelActions :: (Foldable f, Monoid (f (Async ()))) => TVar (f (Async ())) -> IO () -cancelActions as = readTVarIO as >>= mapM_ (forkIO . uninterruptibleCancel) >> atomically (writeTVar as mempty) +cancelActions as = atomically (swapTVar as mempty) >>= mapM_ (forkIO . uninterruptibleCancel) withConnLock :: MonadUnliftIO m => AgentClient -> ConnId -> String -> m a -> m a withConnLock _ "" _ = id diff --git a/tests/AgentTests/FunctionalAPITests.hs b/tests/AgentTests/FunctionalAPITests.hs index 1a5c9c4858..fdf931df04 100644 --- a/tests/AgentTests/FunctionalAPITests.hs +++ b/tests/AgentTests/FunctionalAPITests.hs @@ -56,10 +56,10 @@ import Simplex.Messaging.Version import Test.Hspec import UnliftIO -(##>) :: MonadIO m => m (ATransmission 'Agent) -> ATransmission 'Agent -> m () +(##>) :: (HasCallStack, MonadIO m) => m (ATransmission 'Agent) -> ATransmission 'Agent -> m () a ##> t = a >>= \t' -> liftIO (t' `shouldBe` t) -(=##>) :: MonadIO m => m (ATransmission 'Agent) -> (ATransmission 'Agent -> Bool) -> m () +(=##>) :: (HasCallStack, MonadIO m) => m (ATransmission 'Agent) -> (ATransmission 'Agent -> Bool) -> m () a =##> p = a >>= \t -> liftIO (t `shouldSatisfy` p) get :: MonadIO m => AgentClient -> m (ATransmission 'Agent) @@ -85,10 +85,10 @@ agentCfgRatchetV1 = agentCfg {e2eEncryptVRange = vr11} vr11 :: VersionRange vr11 = mkVersionRange 1 1 -runRight_ :: ExceptT AgentErrorType IO () -> Expectation +runRight_ :: HasCallStack => ExceptT AgentErrorType IO () -> Expectation runRight_ action = runExceptT action `shouldReturn` Right () -runRight :: ExceptT AgentErrorType IO a -> IO a +runRight :: HasCallStack => ExceptT AgentErrorType IO a -> IO a runRight action = runExceptT action >>= \case Right x -> pure x @@ -240,7 +240,7 @@ runTestCfg2 aliceCfg bobCfg baseMsgId runTest = do bob <- getSMPAgentClient bobCfg {database = testDB2} initAgentServers runTest alice bob baseMsgId -runAgentClientTest :: AgentClient -> AgentClient -> AgentMsgId -> IO () +runAgentClientTest :: HasCallStack => AgentClient -> AgentClient -> AgentMsgId -> IO () runAgentClientTest alice bob baseId = do runRight_ $ do (bobId, qInfo) <- createConnection alice 1 True SCMInvitation Nothing @@ -275,7 +275,7 @@ runAgentClientTest alice bob baseId = do where msgId = subtract baseId -runAgentClientContactTest :: AgentClient -> AgentClient -> AgentMsgId -> IO () +runAgentClientContactTest :: HasCallStack => AgentClient -> AgentClient -> AgentMsgId -> IO () runAgentClientContactTest alice bob baseId = do runRight_ $ do (_, qInfo) <- createConnection alice 1 True SCMContact Nothing @@ -312,7 +312,7 @@ runAgentClientContactTest alice bob baseId = do where msgId = subtract baseId -noMessages :: AgentClient -> String -> Expectation +noMessages :: HasCallStack => AgentClient -> String -> Expectation noMessages c err = tryGet `shouldReturn` () where tryGet = @@ -320,7 +320,7 @@ noMessages c err = tryGet `shouldReturn` () Just msg -> error $ err <> ": " <> show msg _ -> return () -testAsyncInitiatingOffline :: IO () +testAsyncInitiatingOffline :: HasCallStack => IO () testAsyncInitiatingOffline = do alice <- getSMPAgentClient agentCfg initAgentServers bob <- getSMPAgentClient agentCfg {database = testDB2} initAgentServers @@ -337,7 +337,7 @@ testAsyncInitiatingOffline = do get bob ##> ("", aliceId, CON) exchangeGreetings alice' bobId bob aliceId -testAsyncJoiningOfflineBeforeActivation :: IO () +testAsyncJoiningOfflineBeforeActivation :: HasCallStack => IO () testAsyncJoiningOfflineBeforeActivation = do alice <- getSMPAgentClient agentCfg initAgentServers bob <- getSMPAgentClient agentCfg {database = testDB2} initAgentServers @@ -354,7 +354,7 @@ testAsyncJoiningOfflineBeforeActivation = do get bob' ##> ("", aliceId, CON) exchangeGreetings alice bobId bob' aliceId -testAsyncBothOffline :: IO () +testAsyncBothOffline :: HasCallStack => IO () testAsyncBothOffline = do alice <- getSMPAgentClient agentCfg initAgentServers bob <- getSMPAgentClient agentCfg {database = testDB2} initAgentServers @@ -374,7 +374,7 @@ testAsyncBothOffline = do get bob' ##> ("", aliceId, CON) exchangeGreetings alice' bobId bob' aliceId -testAsyncServerOffline :: ATransport -> IO () +testAsyncServerOffline :: HasCallStack => ATransport -> IO () testAsyncServerOffline t = do alice <- getSMPAgentClient agentCfg initAgentServers bob <- getSMPAgentClient agentCfg {database = testDB2} initAgentServers @@ -400,7 +400,7 @@ testAsyncServerOffline t = do get bob ##> ("", aliceId, CON) exchangeGreetings alice bobId bob aliceId -testAsyncHelloTimeout :: IO () +testAsyncHelloTimeout :: HasCallStack => IO () testAsyncHelloTimeout = do -- this test would only work if any of the agent is v1, there is no HELLO timeout in v2 alice <- getSMPAgentClient agentCfgV1 initAgentServers @@ -411,7 +411,7 @@ testAsyncHelloTimeout = do aliceId <- joinConnection bob 1 True cReq "bob's connInfo" get bob ##> ("", aliceId, ERR $ CONN NOT_ACCEPTED) -testDuplicateMessage :: ATransport -> IO () +testDuplicateMessage :: HasCallStack => ATransport -> IO () testDuplicateMessage t = do alice <- getSMPAgentClient agentCfg initAgentServers bob <- getSMPAgentClient agentCfg {database = testDB2} initAgentServers @@ -791,11 +791,10 @@ testDeleteUserQuietly = do exchangeGreetingsMsgId 6 a bId b aId liftIO $ noMessages a "nothing else should be delivered to alice" -testUsersNoServer :: ATransport -> IO () +testUsersNoServer :: HasCallStack => ATransport -> IO () testUsersNoServer t = do a <- getSMPAgentClient agentCfg {initialCleanupDelay = 10000, cleanupInterval = 10000, deleteErrorCount = 3} initAgentServers b <- getSMPAgentClient agentCfg {database = testDB2} initAgentServers - liftIO $ print 1 (aId, bId, auId, _aId', bId') <- withSmpServerStoreLogOn t testPort $ \_ -> runRight $ do (aId, bId) <- makeConnection a b exchangeGreetingsMsgId 4 a bId b aId @@ -803,26 +802,18 @@ testUsersNoServer t = do (aId', bId') <- makeConnectionForUsers a auId b 1 exchangeGreetingsMsgId 4 a bId' b aId' pure (aId, bId, auId, aId', bId') - liftIO $ print 2 get a =##> \case ("", "", DOWN _ [c]) -> c == bId || c == bId'; _ -> False get a =##> \case ("", "", DOWN _ [c]) -> c == bId || c == bId'; _ -> False get b =##> \case ("", "", DOWN _ cs) -> length cs == 2; _ -> False - liftIO $ print 3 runRight_ $ do deleteUser a auId True - liftIO $ print 4 get a =##> \case ("", c, DEL_RCVQ _ _ (Just (BROKER _ e))) -> c == bId' && (e == TIMEOUT || e == NETWORK); _ -> False - liftIO $ print 4.1 get a =##> \case ("", c, DEL_CONN) -> c == bId'; _ -> False - liftIO $ print 4.2 get a =##> \case ("", "", DEL_USER u) -> u == auId; _ -> False - liftIO $ print 5 liftIO $ noMessages a "nothing else should be delivered to alice" withSmpServerStoreLogOn t testPort $ \_ -> runRight_ $ do - liftIO $ print 6 get a =##> \case ("", "", UP _ [c]) -> c == bId; _ -> False get b =##> \case ("", "", UP _ cs) -> length cs == 2; _ -> False - liftIO $ print 7 exchangeGreetingsMsgId 6 a bId b aId testSwitchConnection :: InitialAgentServers -> IO () @@ -859,28 +850,23 @@ phase c connId d p = SWITCH {} <- pure r pure () -testSwitchAsync :: InitialAgentServers -> IO () +testSwitchAsync :: HasCallStack => InitialAgentServers -> IO () testSwitchAsync servers = do - liftIO $ print 1 (aId, bId) <- withA $ \a -> withB $ \b -> runRight $ do (aId, bId) <- makeConnection a b exchangeGreetingsMsgId 4 a bId b aId pure (aId, bId) - liftIO $ print 2 let withA' = session withA bId withB' = session withB aId withA' $ \a -> do switchConnectionAsync a "" bId phase a bId QDRcv SPStarted - liftIO $ print 3 withB' $ \b -> phase b aId QDSnd SPStarted withA' $ \a -> phase a bId QDRcv SPConfirmed - liftIO $ print 4 withB' $ \b -> do phase b aId QDSnd SPConfirmed phase b aId QDSnd SPCompleted withA' $ \a -> phase a bId QDRcv SPCompleted - liftIO $ print 5 withA $ \a -> withB $ \b -> runRight_ $ do subscribeConnection a bId subscribeConnection b aId @@ -956,7 +942,7 @@ testRatchetAdHash = do ad2 <- getConnectionRatchetAdHash b aId liftIO $ ad1 `shouldBe` ad2 -testTwoUsers :: IO () +testTwoUsers :: HasCallStack => IO () testTwoUsers = do let nc = netCfg initAgentServers a <- getSMPAgentClient agentCfg initAgentServers @@ -1021,12 +1007,13 @@ testTwoUsers = do exchangeGreetingsMsgId 8 a bId2 b aId2 exchangeGreetingsMsgId 8 a bId2' b aId2' where + hasClients :: HasCallStack => AgentClient -> Int -> ExceptT AgentErrorType IO () hasClients c n = liftIO $ M.size <$> readTVarIO (smpClients c) `shouldReturn` n -exchangeGreetings :: AgentClient -> ConnId -> AgentClient -> ConnId -> ExceptT AgentErrorType IO () +exchangeGreetings :: HasCallStack => AgentClient -> ConnId -> AgentClient -> ConnId -> ExceptT AgentErrorType IO () exchangeGreetings = exchangeGreetingsMsgId 4 -exchangeGreetingsMsgId :: Int64 -> AgentClient -> ConnId -> AgentClient -> ConnId -> ExceptT AgentErrorType IO () +exchangeGreetingsMsgId :: HasCallStack => Int64 -> AgentClient -> ConnId -> AgentClient -> ConnId -> ExceptT AgentErrorType IO () exchangeGreetingsMsgId msgId alice bobId bob aliceId = do msgId1 <- sendMessage alice bobId SMP.noMsgFlags "hello" liftIO $ msgId1 `shouldBe` msgId diff --git a/tests/SMPClient.hs b/tests/SMPClient.hs index b005699d3a..8fb553bf13 100644 --- a/tests/SMPClient.hs +++ b/tests/SMPClient.hs @@ -5,6 +5,7 @@ {-# LANGUAGE NumericUnderscores #-} {-# LANGUAGE OverloadedLists #-} {-# LANGUAGE OverloadedStrings #-} +{-# LANGUAGE RankNTypes #-} {-# LANGUAGE ScopedTypeVariables #-} {-# LANGUAGE TypeApplications #-} @@ -91,25 +92,25 @@ cfg = logTLSErrors = True } -withSmpServerStoreMsgLogOnV2 :: ATransport -> ServiceName -> (ThreadId -> IO a) -> IO a +withSmpServerStoreMsgLogOnV2 :: HasCallStack => ATransport -> ServiceName -> (HasCallStack => ThreadId -> IO a) -> IO a withSmpServerStoreMsgLogOnV2 t = withSmpServerConfigOn t cfgV2 {storeLogFile = Just testStoreLogFile, storeMsgsFile = Just testStoreMsgsFile} -withSmpServerStoreMsgLogOn :: ATransport -> ServiceName -> (ThreadId -> IO a) -> IO a +withSmpServerStoreMsgLogOn :: HasCallStack => ATransport -> ServiceName -> (HasCallStack => ThreadId -> IO a) -> IO a withSmpServerStoreMsgLogOn t = withSmpServerConfigOn t cfg {storeLogFile = Just testStoreLogFile, storeMsgsFile = Just testStoreMsgsFile, serverStatsBackupFile = Just testServerStatsBackupFile} -withSmpServerStoreLogOn :: ATransport -> ServiceName -> (ThreadId -> IO a) -> IO a +withSmpServerStoreLogOn :: HasCallStack => ATransport -> ServiceName -> (HasCallStack => ThreadId -> IO a) -> IO a withSmpServerStoreLogOn t = withSmpServerConfigOn t cfg {storeLogFile = Just testStoreLogFile, serverStatsBackupFile = Just testServerStatsBackupFile} -withSmpServerConfigOn :: ATransport -> ServerConfig -> ServiceName -> (ThreadId -> IO a) -> IO a +withSmpServerConfigOn :: HasCallStack => ATransport -> ServerConfig -> ServiceName -> (HasCallStack => ThreadId -> IO a) -> IO a withSmpServerConfigOn t cfg' port' = serverBracket (\started -> runSMPServerBlocking started cfg' {transports = [(port', t)]}) (pure ()) -withSmpServerThreadOn :: ATransport -> ServiceName -> (ThreadId -> IO a) -> IO a +withSmpServerThreadOn :: HasCallStack => ATransport -> ServiceName -> (HasCallStack => ThreadId -> IO a) -> IO a withSmpServerThreadOn t = withSmpServerConfigOn t cfg -serverBracket :: MonadUnliftIO m => (TMVar Bool -> m ()) -> m () -> (ThreadId -> m a) -> m a +serverBracket :: (HasCallStack, MonadUnliftIO m) => (TMVar Bool -> m ()) -> m () -> (HasCallStack => ThreadId -> m a) -> m a serverBracket process afterProcess f = do started <- newEmptyTMVarIO E.bracket @@ -122,16 +123,16 @@ serverBracket process afterProcess f = do Nothing -> error $ "server did not " <> s _ -> pure () -withSmpServerOn :: ATransport -> ServiceName -> IO a -> IO a +withSmpServerOn :: HasCallStack => ATransport -> ServiceName -> IO a -> IO a withSmpServerOn t port' = withSmpServerThreadOn t port' . const -withSmpServer :: ATransport -> IO a -> IO a +withSmpServer :: HasCallStack => ATransport -> IO a -> IO a withSmpServer t = withSmpServerOn t testPort -runSmpTest :: forall c a. Transport c => (THandle c -> IO a) -> IO a +runSmpTest :: forall c a. (HasCallStack, Transport c) => (HasCallStack => THandle c -> IO a) -> IO a runSmpTest test = withSmpServer (transport @c) $ testSMPClient test -runSmpTestN :: forall c a. Transport c => Int -> ([THandle c] -> IO a) -> IO a +runSmpTestN :: forall c a. (HasCallStack, Transport c) => Int -> (HasCallStack => [THandle c] -> IO a) -> IO a runSmpTestN nClients test = withSmpServer (transport @c) $ run nClients [] where run :: Int -> [THandle c] -> IO a @@ -154,25 +155,25 @@ smpServerTest _ t = runSmpTest $ \h -> tPut' h t >> tGet' h [(Nothing, _, (CorrId corrId, qId, Right cmd))] <- tGet h pure (Nothing, corrId, qId, cmd) -smpTest :: Transport c => TProxy c -> (THandle c -> IO ()) -> Expectation +smpTest :: (HasCallStack, Transport c) => TProxy c -> (HasCallStack => THandle c -> IO ()) -> Expectation smpTest _ test' = runSmpTest test' `shouldReturn` () -smpTestN :: Transport c => Int -> ([THandle c] -> IO ()) -> Expectation +smpTestN :: (HasCallStack, Transport c) => Int -> (HasCallStack => [THandle c] -> IO ()) -> Expectation smpTestN n test' = runSmpTestN n test' `shouldReturn` () -smpTest2 :: Transport c => TProxy c -> (THandle c -> THandle c -> IO ()) -> Expectation +smpTest2 :: (HasCallStack, Transport c) => TProxy c -> (HasCallStack => THandle c -> THandle c -> IO ()) -> Expectation smpTest2 _ test' = smpTestN 2 _test where _test [h1, h2] = test' h1 h2 _test _ = error "expected 2 handles" -smpTest3 :: Transport c => TProxy c -> (THandle c -> THandle c -> THandle c -> IO ()) -> Expectation +smpTest3 :: (HasCallStack, Transport c) => TProxy c -> (HasCallStack => THandle c -> THandle c -> THandle c -> IO ()) -> Expectation smpTest3 _ test' = smpTestN 3 _test where _test [h1, h2, h3] = test' h1 h2 h3 _test _ = error "expected 3 handles" -smpTest4 :: Transport c => TProxy c -> (THandle c -> THandle c -> THandle c -> THandle c -> IO ()) -> Expectation +smpTest4 :: (HasCallStack, Transport c) => TProxy c -> (HasCallStack => THandle c -> THandle c -> THandle c -> THandle c -> IO ()) -> Expectation smpTest4 _ test' = smpTestN 4 _test where _test [h1, h2, h3, h4] = test' h1 h2 h3 h4