diff --git a/src/Simplex/Messaging/Agent.hs b/src/Simplex/Messaging/Agent.hs index 36dcee49d..3a39686bf 100644 --- a/src/Simplex/Messaging/Agent.hs +++ b/src/Simplex/Messaging/Agent.hs @@ -358,8 +358,8 @@ getNetworkConfig = readTVarIO . useNetworkConfig reconnectAllServers :: MonadUnliftIO m => AgentClient -> m () reconnectAllServers c = liftIO $ do - closeProtocolServerClients c smpClients - closeProtocolServerClients c ntfClients + reconnectServerClients c smpClients + reconnectServerClients c ntfClients -- | Register device notifications token registerNtfToken :: AgentErrorMonad m => AgentClient -> DeviceToken -> NotificationsMode -> m NtfTknStatus diff --git a/src/Simplex/Messaging/Agent/Client.hs b/src/Simplex/Messaging/Agent/Client.hs index 9c6571f14..d24af2be8 100644 --- a/src/Simplex/Messaging/Agent/Client.hs +++ b/src/Simplex/Messaging/Agent/Client.hs @@ -28,6 +28,7 @@ module Simplex.Messaging.Agent.Client withInvLock, closeAgentClient, closeProtocolServerClients, + reconnectServerClients, closeXFTPServerClient, runSMPServerTest, runXFTPServerTest, @@ -140,6 +141,7 @@ import Data.Bifunctor (bimap, first, second) import Data.ByteString.Base64 import Data.ByteString.Char8 (ByteString) import qualified Data.ByteString.Char8 as B +import Data.Composition ((.:.)) import Data.Either (lefts, partitionEithers) import Data.Functor (($>)) import Data.List (deleteFirstsBy, foldl', partition, (\\)) @@ -499,11 +501,15 @@ instance ProtocolServerClient XFTPErrorType FileResponse where getSMPServerClient :: forall m. AgentMonad m => AgentClient -> SMPTransportSession -> m SMPClient getSMPServerClient c@AgentClient {active, smpClients, msgQ} tSess@(userId, srv, _) = do unlessM (readTVarIO active) . 
throwError $ INACTIVE - v <- atomically (getTSessVar c tSess smpClients) - either newClient (waitForProtocolClient c tSess) v - `catchAgentError` \e -> resubscribeSMPSession c tSess >> throwError e + atomically (getTSessVar c tSess smpClients) + >>= either newClient (waitForProtocolClient c tSess) where - newClient = newProtocolClient c tSess smpClients connectClient + -- we resubscribe only on newClient error, but not on waitForProtocolClient error, + -- as the large number of delivery workers waiting for the client TMVar + -- makes it expensive to check for pending subscriptions. + newClient v = + newProtocolClient c tSess smpClients connectClient v + `catchAgentError` \e -> resubscribeSMPSession c tSess >> throwError e connectClient :: SMPClientVar -> m SMPClient connectClient v = do cfg <- getClientConfig c smpCfg @@ -515,14 +521,19 @@ getSMPServerClient c@AgentClient {active, smpClients, msgQ} tSess@(userId, srv, removeClientAndSubs >>= serverDown logInfo . decodeUtf8 $ "Agent disconnected from " <> showServer srv where + -- we make active subscriptions pending only if the client for tSess was current (in the map) and active, + -- because we can have a race condition when a new current client could have already + -- made subscriptions active, and the old client would be processing disconnection later. 
removeClientAndSubs :: IO ([RcvQueue], [ConnId]) - removeClientAndSubs = atomically $ do - removeTSessVar v tSess smpClients - qs <- RQ.getDelSessQueues tSess $ activeSubs c - mapM_ (`RQ.addQueue` pendingSubs c) qs - let cs = S.fromList $ map qConnId qs - cs' <- RQ.getConns $ activeSubs c - pure (qs, S.toList $ cs `S.difference` cs') + removeClientAndSubs = atomically $ ifM currentActiveClient removeSubs $ pure ([], []) + where + currentActiveClient = (&&) <$> removeTSessVar' v tSess smpClients <*> readTVar active + removeSubs = do + qs <- RQ.getDelSessQueues tSess $ activeSubs c + mapM_ (`RQ.addQueue` pendingSubs c) qs + let cs = S.fromList $ map qConnId qs + cs' <- RQ.getConns $ activeSubs c + pure (qs, S.toList $ cs `S.difference` cs') serverDown :: ([RcvQueue], [ConnId]) -> IO () serverDown (qs, conns) = whenM (readTVarIO active) $ do @@ -648,9 +659,13 @@ getTSessVar c tSess vs = maybe (Left <$> newSessionVar) (pure . Right) =<< TM.lo pure v removeTSessVar :: SessionVar a -> TransportSession msg -> TMap (TransportSession msg) (SessionVar a) -> STM () -removeTSessVar v tSess vs = - TM.lookup tSess vs - >>= mapM_ (\v' -> when (sessionVarId v == sessionVarId v') $ TM.delete tSess vs) +removeTSessVar = void .:. removeTSessVar' + +removeTSessVar' :: SessionVar a -> TransportSession msg -> TMap (TransportSession msg) (SessionVar a) -> STM Bool +removeTSessVar' v tSess vs = + TM.lookup tSess vs >>= \case + Just v' | sessionVarId v == sessionVarId v' -> TM.delete tSess vs $> True + _ -> pure False waitForProtocolClient :: (AgentMonad m, ProtocolTypeI (ProtoType msg)) => AgentClient -> TransportSession msg -> ClientVar msg -> m (Client msg) waitForProtocolClient c (_, srv, _) v = do @@ -738,6 +753,10 @@ closeProtocolServerClients :: ProtocolServerClient err msg => AgentClient -> (Ag closeProtocolServerClients c clientsSel = atomically (clientsSel c `swapTVar` M.empty) >>= mapM_ (forkIO . 
closeClient_ c) +reconnectServerClients :: ProtocolServerClient err msg => AgentClient -> (AgentClient -> TMap (TransportSession msg) (ClientVar msg)) -> IO () +reconnectServerClients c clientsSel = + readTVarIO (clientsSel c) >>= mapM_ (forkIO . closeClient_ c) + closeClient :: ProtocolServerClient err msg => AgentClient -> (AgentClient -> TMap (TransportSession msg) (ClientVar msg)) -> TransportSession msg -> IO () closeClient c clientSel tSess = atomically (TM.lookupDelete tSess $ clientSel c) >>= mapM_ (closeClient_ c) diff --git a/src/Simplex/Messaging/Client.hs b/src/Simplex/Messaging/Client.hs index a3aaaa84d..cd7797c03 100644 --- a/src/Simplex/Messaging/Client.hs +++ b/src/Simplex/Messaging/Client.hs @@ -118,7 +118,6 @@ data ProtocolClient err msg = ProtocolClient sessionId :: SessionId, sessionTs :: UTCTime, thVersion :: Version, - timeoutPerBlock :: Int, blockSize :: Int, batch :: Bool, client_ :: PClient err msg @@ -151,7 +150,6 @@ clientStub sessionId = do sessionId, sessionTs = undefined, thVersion = 5, - timeoutPerBlock = undefined, blockSize = smpBlockSize, batch = undefined, client_ = @@ -314,7 +312,7 @@ getProtocolClient transportSession@(_, srv, _) cfg@ProtocolClientConfig {qSize, `catch` \(e :: IOException) -> pure . Left $ PCEIOError e Left e -> pure $ Left e where - NetworkConfig {tcpConnectTimeout, tcpTimeout, tcpTimeoutPerKb, smpPingInterval} = networkConfig + NetworkConfig {tcpConnectTimeout, tcpTimeout, smpPingInterval} = networkConfig mkProtocolClient :: TransportHost -> STM (PClient err msg) mkProtocolClient transportHost = do connected <- newTVar False @@ -365,8 +363,7 @@ getProtocolClient transportSession@(_, srv, _) cfg@ProtocolClientConfig {qSize, Left e -> atomically . putTMVar cVar . 
Left $ PCETransportError e Right th@THandle {sessionId, thVersion, blockSize, batch} -> do sessionTs <- getCurrentTime - let timeoutPerBlock = (blockSize * tcpTimeoutPerKb) `div` 1024 - c' = ProtocolClient {action = Nothing, client_ = c, sessionId, thVersion, sessionTs, timeoutPerBlock, blockSize, batch} + let c' = ProtocolClient {action = Nothing, client_ = c, sessionId, thVersion, sessionTs, blockSize, batch} atomically $ do writeTVar (connected c) True putTMVar cVar $ Right c' diff --git a/src/Simplex/Messaging/Server.hs b/src/Simplex/Messaging/Server.hs index 1bf7de2eb..f81f66d8d 100644 --- a/src/Simplex/Messaging/Server.hs +++ b/src/Simplex/Messaging/Server.hs @@ -50,6 +50,7 @@ import qualified Data.ByteString.Char8 as B import Data.Either (fromRight, partitionEithers) import Data.Functor (($>)) import Data.Int (Int64) +import qualified Data.IntMap.Strict as IM import Data.List (intercalate) import qualified Data.List.NonEmpty as L import qualified Data.Map.Strict as M @@ -157,7 +158,7 @@ smpServer started cfg@ServerConfig {transports, transportConfig = tCfg} = do updateSubscribers = do (qId, clnt) <- readTQueue $ subQ s let clientToBeNotified c' = - if sameClientSession clnt c' + if sameClientId clnt c' then pure Nothing else do yes <- readTVar $ connected c' @@ -165,9 +166,12 @@ smpServer started cfg@ServerConfig {transports, transportConfig = tCfg} = do TM.lookupInsert qId clnt (subs s) $>>= clientToBeNotified endPreviousSubscriptions :: (QueueId, Client) -> M (Maybe s) endPreviousSubscriptions (qId, c) = do - void . forkIO $ do + tId <- atomically $ stateTVar (endThreadSeq c) $ \next -> (next, next + 1) + t <- forkIO $ do labelMyThread $ label <> ".endPreviousSubscriptions" atomically $ writeTBQueue (sndQ c) [(CorrId "", qId, END)] + atomically $ modifyTVar' (endThreads c) $ IM.delete tId + mkWeakThreadId t >>= atomically . modifyTVar' (endThreads c) . 
IM.insert tId atomically $ TM.lookupDelete qId (clientSubs c) expireMessagesThread_ :: ServerConfig -> [M ()] @@ -389,23 +393,26 @@ runClientTransport th@THandle {thVersion, sessionId} = do noSubscriptions c = atomically $ (&&) <$> TM.null (subscriptions c) <*> TM.null (ntfSubscriptions c) clientDisconnected :: Client -> M () -clientDisconnected c@Client {clientId, subscriptions, connected, sessionId} = do +clientDisconnected c@Client {clientId, subscriptions, connected, sessionId, endThreads} = do labelMyThread . B.unpack $ "client $" <> encode sessionId <> " disc" - atomically $ writeTVar connected False - subs <- readTVarIO subscriptions + subs <- atomically $ do + writeTVar connected False + swapTVar subscriptions M.empty liftIO $ mapM_ cancelSub subs - atomically $ writeTVar subscriptions M.empty - cs <- asks $ subscribers . server - atomically . mapM_ (\rId -> TM.update deleteCurrentClient rId cs) $ M.keys subs + srvSubs <- asks $ subscribers . server + atomically $ modifyTVar' srvSubs $ \cs -> + M.foldrWithKey (\sub _ -> M.update deleteCurrentClient sub) cs subs asks clients >>= atomically . 
TM.delete clientId + tIds <- atomically $ swapTVar endThreads IM.empty + liftIO $ mapM_ (mapM_ killThread <=< deRefWeak) tIds where deleteCurrentClient :: Client -> Maybe Client deleteCurrentClient c' - | sameClientSession c c' = Nothing + | sameClientId c c' = Nothing | otherwise = Just c' -sameClientSession :: Client -> Client -> Bool -sameClientSession Client {sessionId} Client {sessionId = s'} = sessionId == s' +sameClientId :: Client -> Client -> Bool +sameClientId Client {clientId} Client {clientId = cId'} = clientId == cId' cancelSub :: TVar Sub -> IO () cancelSub sub = diff --git a/src/Simplex/Messaging/Server/Env/STM.hs b/src/Simplex/Messaging/Server/Env/STM.hs index ab88331f6..aa86cbd29 100644 --- a/src/Simplex/Messaging/Server/Env/STM.hs +++ b/src/Simplex/Messaging/Server/Env/STM.hs @@ -10,6 +10,8 @@ import Control.Monad.IO.Unlift import Crypto.Random import Data.ByteString.Char8 (ByteString) import Data.Int (Int64) +import Data.IntMap.Strict (IntMap) +import qualified Data.IntMap.Strict as IM import Data.List.NonEmpty (NonEmpty) import Data.Map.Strict (Map) import qualified Data.Map.Strict as M @@ -124,6 +126,8 @@ data Client = Client ntfSubscriptions :: TMap NotifierId (), rcvQ :: TBQueue (NonEmpty (Maybe QueueRec, Transmission Cmd)), sndQ :: TBQueue (NonEmpty (Transmission BrokerMsg)), + endThreads :: TVar (IntMap (Weak ThreadId)), + endThreadSeq :: TVar Int, thVersion :: Version, sessionId :: ByteString, connected :: TVar Bool, @@ -155,10 +159,12 @@ newClient nextClientId qSize thVersion sessionId createdAt = do ntfSubscriptions <- TM.empty rcvQ <- newTBQueue qSize sndQ <- newTBQueue qSize + endThreads <- newTVar IM.empty + endThreadSeq <- newTVar 0 connected <- newTVar True rcvActiveAt <- newTVar createdAt sndActiveAt <- newTVar createdAt - return Client {clientId, subscriptions, ntfSubscriptions, rcvQ, sndQ, thVersion, sessionId, connected, createdAt, rcvActiveAt, sndActiveAt} + return Client {clientId, subscriptions, ntfSubscriptions, rcvQ, sndQ, 
endThreads, endThreadSeq, thVersion, sessionId, connected, createdAt, rcvActiveAt, sndActiveAt} newSubscription :: SubscriptionThread -> STM Sub newSubscription subThread = do