Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
85 changes: 42 additions & 43 deletions src/Simplex/Messaging/Notifications/Server.hs
Original file line number Diff line number Diff line change
Expand Up @@ -15,9 +15,7 @@ module Simplex.Messaging.Notifications.Server where
import Control.Concurrent.STM (stateTVar)
import Control.Logger.Simple
import Control.Monad.Except
import Control.Monad.IO.Unlift (MonadUnliftIO)
import Control.Monad.Reader
import Crypto.Random (MonadRandom)
import Data.ByteString.Char8 (ByteString)
import qualified Data.ByteString.Char8 as B
import Data.Functor (($>))
Expand Down Expand Up @@ -57,15 +55,17 @@ import UnliftIO.Directory (doesFileExist, renameFile)
import UnliftIO.Exception
import UnliftIO.STM

runNtfServer :: (MonadRandom m, MonadUnliftIO m) => NtfServerConfig -> m ()
runNtfServer :: NtfServerConfig -> IO ()
runNtfServer cfg = do
started <- newEmptyTMVarIO
runNtfServerBlocking started cfg

runNtfServerBlocking :: (MonadRandom m, MonadUnliftIO m) => TMVar Bool -> NtfServerConfig -> m ()
runNtfServerBlocking :: TMVar Bool -> NtfServerConfig -> IO ()
runNtfServerBlocking started cfg = runReaderT (ntfServer cfg started) =<< newNtfServerEnv cfg

ntfServer :: forall m. (MonadUnliftIO m, MonadReader NtfEnv m) => NtfServerConfig -> TMVar Bool -> m ()
type M a = ReaderT NtfEnv IO a

ntfServer :: NtfServerConfig -> TMVar Bool -> M ()
ntfServer cfg@NtfServerConfig {transports} started = do
restoreServerStats
s <- asks subscriber
Expand All @@ -74,30 +74,30 @@ ntfServer cfg@NtfServerConfig {transports} started = do
void . forkIO $ resubscribe s subs
raceAny_ (ntfSubscriber s : ntfPush ps : map runServer transports <> serverStatsThread_ cfg) `finally` stopServer
where
runServer :: (ServiceName, ATransport) -> m ()
runServer :: (ServiceName, ATransport) -> M ()
runServer (tcpPort, ATransport t) = do
serverParams <- asks tlsServerParams
runTransportServer started tcpPort serverParams (runClient t)

runClient :: Transport c => TProxy c -> c -> m ()
runClient :: Transport c => TProxy c -> c -> M ()
runClient _ h = do
kh <- asks serverIdentity
liftIO (runExceptT $ ntfServerHandshake h kh supportedNTFServerVRange) >>= \case
Right th -> runNtfClientTransport th
Left _ -> pure ()

stopServer :: m ()
stopServer :: M ()
stopServer = do
withNtfLog closeStoreLog
saveServerStats
asks (smpSubscribers . subscriber) >>= readTVarIO >>= mapM_ (\SMPSubscriber {subThreadId} -> readTVarIO subThreadId >>= mapM_ (liftIO . deRefWeak >=> mapM_ killThread))

serverStatsThread_ :: NtfServerConfig -> [m ()]
serverStatsThread_ :: NtfServerConfig -> [M ()]
serverStatsThread_ NtfServerConfig {logStatsInterval = Just interval, logStatsStartTime, serverStatsLogFile} =
[logServerStats logStatsStartTime interval serverStatsLogFile]
serverStatsThread_ _ = []

logServerStats :: Int -> Int -> FilePath -> m ()
logServerStats :: Int -> Int -> FilePath -> M ()
logServerStats startAt logInterval statsFilePath = do
initialDelay <- (startAt -) . fromIntegral . (`div` 1000000_000000) . diffTimeToPicoseconds . utctDayTime <$> liftIO getCurrentTime
liftIO $ putStrLn $ "server stats log enabled: " <> statsFilePath
Expand Down Expand Up @@ -138,7 +138,7 @@ ntfServer cfg@NtfServerConfig {transports} started = do
]
threadDelay interval

resubscribe :: (MonadUnliftIO m, MonadReader NtfEnv m) => NtfSubscriber -> Map NtfSubscriptionId NtfSubData -> m ()
resubscribe :: NtfSubscriber -> Map NtfSubscriptionId NtfSubData -> M ()
resubscribe NtfSubscriber {newSubQ} subs = do
d <- asks $ resubscribeDelay . config
forM_ subs $ \sub@NtfSubData {} ->
Expand All @@ -147,19 +147,19 @@ resubscribe NtfSubscriber {newSubQ} subs = do
threadDelay d
liftIO $ logInfo "SMP connections resubscribed"

ntfSubscriber :: forall m. (MonadUnliftIO m, MonadReader NtfEnv m) => NtfSubscriber -> m ()
ntfSubscriber :: NtfSubscriber -> M ()
ntfSubscriber NtfSubscriber {smpSubscribers, newSubQ, smpAgent = ca@SMPClientAgent {msgQ, agentQ}} = do
raceAny_ [subscribe, receiveSMP, receiveAgent]
where
subscribe :: m ()
subscribe :: M ()
subscribe =
forever $
atomically (readTBQueue newSubQ) >>= \case
sub@(NtfSub NtfSubData {smpQueue = SMPQueueNtf {smpServer}}) -> do
SMPSubscriber {newSubQ = subscriberSubQ} <- getSMPSubscriber smpServer
atomically $ writeTQueue subscriberSubQ sub

getSMPSubscriber :: SMPServer -> m SMPSubscriber
getSMPSubscriber :: SMPServer -> M SMPSubscriber
getSMPSubscriber smpServer =
atomically (TM.lookup smpServer smpSubscribers) >>= maybe createSMPSubscriber pure
where
Expand All @@ -170,7 +170,7 @@ ntfSubscriber NtfSubscriber {smpSubscribers, newSubQ, smpAgent = ca@SMPClientAge
atomically . writeTVar subThreadId $ Just tId
pure sub

runSMPSubscriber :: SMPSubscriber -> m ()
runSMPSubscriber :: SMPSubscriber -> M ()
runSMPSubscriber SMPSubscriber {newSubQ = subscriberSubQ} =
forever $
atomically (peekTQueue subscriberSubQ)
Expand All @@ -188,7 +188,7 @@ ntfSubscriber NtfSubscriber {smpSubscribers, newSubQ, smpAgent = ca@SMPClientAge
PCENetworkError -> pure ()
_ -> void . atomically $ readTQueue subscriberSubQ

receiveSMP :: m ()
receiveSMP :: M ()
receiveSMP = forever $ do
(srv, _, _, ntfId, msg) <- atomically $ readTBQueue msgQ
let smpQueue = SMPQueueNtf srv ntfId
Expand Down Expand Up @@ -227,7 +227,7 @@ ntfSubscriber NtfSubscriber {smpSubscribers, newSubQ, smpAgent = ca@SMPClientAge
where
showServer' = decodeLatin1 . strEncode . host

handleSubError :: SMPQueueNtf -> ProtocolClientError -> m ()
handleSubError :: SMPQueueNtf -> ProtocolClientError -> M ()
handleSubError smpQueue = \case
PCEProtocolError AUTH -> updateSubStatus smpQueue NSAuth
PCEProtocolError e -> updateErr "SMP error " e
Expand All @@ -240,7 +240,7 @@ ntfSubscriber NtfSubscriber {smpSubscribers, newSubQ, smpAgent = ca@SMPClientAge
PCEResponseTimeout -> pure ()
PCENetworkError -> pure ()
where
updateErr :: Show e => ByteString -> e -> m ()
updateErr :: Show e => ByteString -> e -> M ()
updateErr errType e = updateSubStatus smpQueue . NSErr $ errType <> bshow e

updateSubStatus smpQueue status = do
Expand All @@ -252,7 +252,7 @@ ntfSubscriber NtfSubscriber {smpSubscribers, newSubQ, smpAgent = ca@SMPClientAge
withNtfLog $ \sl -> logSubscriptionStatus sl ntfSubId status
)

ntfPush :: forall m. (MonadUnliftIO m, MonadReader NtfEnv m) => NtfPushServer -> m ()
ntfPush :: NtfPushServer -> M ()
ntfPush s@NtfPushServer {pushQ} = forever $ do
(tkn@NtfTknData {ntfTknId, token = DeviceToken pp _, tknStatus}, ntf) <- atomically (readTBQueue pushQ)
liftIO $ logDebug $ "sending push notification to " <> T.pack (show pp)
Expand All @@ -275,7 +275,7 @@ ntfPush s@NtfPushServer {pushQ} = forever $ do
_ ->
liftIO $ logError "bad notification token status"
where
deliverNotification :: PushProvider -> NtfTknData -> PushNotification -> m (Either PushProviderError ())
deliverNotification :: PushProvider -> NtfTknData -> PushNotification -> M (Either PushProviderError ())
deliverNotification pp tkn@NtfTknData {ntfTknId, tknStatus} ntf = do
deliver <- liftIO $ getPushClient s pp
liftIO (runExceptT $ deliver tkn ntf) >>= \case
Expand All @@ -289,33 +289,33 @@ ntfPush s@NtfPushServer {pushQ} = forever $ do
PPTokenInvalid -> updateTknStatus NTInvalid >> err e
PPPermanentError -> err e
where
retryDeliver :: m (Either PushProviderError ())
retryDeliver :: M (Either PushProviderError ())
retryDeliver = do
deliver <- liftIO $ newPushClient s pp
liftIO (runExceptT $ deliver tkn ntf) >>= either err (pure . Right)
updateTknStatus :: NtfTknStatus -> m ()
updateTknStatus :: NtfTknStatus -> M ()
updateTknStatus status = do
atomically $ writeTVar tknStatus status
withNtfLog $ \sl -> logTokenStatus sl ntfTknId status
err e = logError (T.pack $ "Push provider error (" <> show pp <> "): " <> show e) $> Left e

runNtfClientTransport :: (Transport c, MonadUnliftIO m, MonadReader NtfEnv m) => THandle c -> m ()
runNtfClientTransport :: Transport c => THandle c -> M ()
runNtfClientTransport th@THandle {sessionId} = do
qSize <- asks $ clientQSize . config
ts <- liftIO getSystemTime
c <- atomically $ newNtfServerClient qSize sessionId ts
s <- asks subscriber
ps <- asks pushServer
expCfg <- asks $ inactiveClientExpiration . config
raceAny_ ([send th c, client c s ps, receive th c] <> disconnectThread_ c expCfg)
`finally` clientDisconnected c
raceAny_ ([liftIO $ send th c, client c s ps, receive th c] <> disconnectThread_ c expCfg)
`finally` liftIO (clientDisconnected c)
where
disconnectThread_ c expCfg = maybe [] ((: []) . disconnectTransport th c activeAt) expCfg
disconnectThread_ c expCfg = maybe [] ((: []) . liftIO . disconnectTransport th c activeAt) expCfg

clientDisconnected :: MonadUnliftIO m => NtfServerClient -> m ()
clientDisconnected :: NtfServerClient -> IO ()
clientDisconnected NtfServerClient {connected} = atomically $ writeTVar connected False

receive :: (Transport c, MonadUnliftIO m, MonadReader NtfEnv m) => THandle c -> NtfServerClient -> m ()
receive :: Transport c => THandle c -> NtfServerClient -> M ()
receive th NtfServerClient {rcvQ, sndQ, activeAt} = forever $ do
ts <- tGet th
forM_ ts $ \t@(_, _, (corrId, entId, cmdOrError)) -> do
Expand All @@ -330,7 +330,7 @@ receive th NtfServerClient {rcvQ, sndQ, activeAt} = forever $ do
where
write q t = atomically $ writeTBQueue q t

send :: (Transport c, MonadUnliftIO m) => THandle c -> NtfServerClient -> m ()
send :: Transport c => THandle c -> NtfServerClient -> IO ()
send h@THandle {thVersion = v} NtfServerClient {sndQ, sessionId, activeAt} = forever $ do
t <- atomically $ readTBQueue sndQ
void . liftIO $ tPut h [(Nothing, encodeTransmission v sessionId t)]
Expand All @@ -341,8 +341,7 @@ send h@THandle {thVersion = v} NtfServerClient {sndQ, sessionId, activeAt} = for

data VerificationResult = VRVerified NtfRequest | VRFailed

verifyNtfTransmission ::
forall m. (MonadUnliftIO m, MonadReader NtfEnv m) => SignedTransmission NtfCmd -> NtfCmd -> m VerificationResult
verifyNtfTransmission :: SignedTransmission NtfCmd -> NtfCmd -> M VerificationResult
verifyNtfTransmission (sig_, signed, (corrId, entId, _)) cmd = do
st <- asks store
case cmd of
Expand Down Expand Up @@ -384,25 +383,25 @@ verifyNtfTransmission (sig_, signed, (corrId, entId, _)) cmd = do
where
verifiedTknCmd t c = VRVerified (NtfReqCmd SToken (NtfTkn t) (corrId, entId, c))
verifiedSubCmd s c = VRVerified (NtfReqCmd SSubscription (NtfSub s) (corrId, entId, c))
verifyToken :: Maybe NtfTknData -> (NtfTknData -> VerificationResult) -> m VerificationResult
verifyToken :: Maybe NtfTknData -> (NtfTknData -> VerificationResult) -> M VerificationResult
verifyToken t_ positiveVerificationResult =
pure $ case t_ of
Just t@NtfTknData {tknVerifyKey} ->
if verifyCmdSignature sig_ signed tknVerifyKey
then positiveVerificationResult t
else VRFailed
_ -> maybe False (dummyVerifyCmd signed) sig_ `seq` VRFailed
verifyToken' :: Maybe NtfTknData -> VerificationResult -> m VerificationResult
verifyToken' :: Maybe NtfTknData -> VerificationResult -> M VerificationResult
verifyToken' t_ = verifyToken t_ . const

client :: forall m. (MonadUnliftIO m, MonadReader NtfEnv m) => NtfServerClient -> NtfSubscriber -> NtfPushServer -> m ()
client :: NtfServerClient -> NtfSubscriber -> NtfPushServer -> M ()
client NtfServerClient {rcvQ, sndQ} NtfSubscriber {newSubQ, smpAgent = ca} NtfPushServer {pushQ, intervalNotifiers} =
forever $
atomically (readTBQueue rcvQ)
>>= processCommand
>>= atomically . writeTBQueue sndQ
where
processCommand :: NtfRequest -> m (Transmission NtfResponse)
processCommand :: NtfRequest -> M (Transmission NtfResponse)
processCommand = \case
NtfReqNew corrId (ANE SToken newTkn@(NewNtfTkn _ _ dhPubKey)) -> do
logDebug "TNEW - new token"
Expand Down Expand Up @@ -531,28 +530,28 @@ client NtfServerClient {rcvQ, sndQ} NtfSubscriber {newSubQ, smpAgent = ca} NtfPu
incNtfStat subDeleted
pure NROk
PING -> pure NRPong
getId :: m NtfEntityId
getId :: M NtfEntityId
getId = getRandomBytes =<< asks (subIdBytes . config)
getRegCode :: m NtfRegCode
getRegCode :: M NtfRegCode
getRegCode = NtfRegCode <$> (getRandomBytes =<< asks (regCodeBytes . config))
getRandomBytes :: Int -> m ByteString
getRandomBytes :: Int -> M ByteString
getRandomBytes n = do
gVar <- asks idsDrg
atomically (C.pseudoRandomBytes n gVar)
cancelInvervalNotifications :: NtfTokenId -> m ()
cancelInvervalNotifications :: NtfTokenId -> M ()
cancelInvervalNotifications tknId =
atomically (TM.lookupDelete tknId intervalNotifiers)
>>= mapM_ (uninterruptibleCancel . action)

withNtfLog :: (MonadUnliftIO m, MonadReader NtfEnv m) => (StoreLog 'WriteMode -> IO a) -> m ()
withNtfLog :: (StoreLog 'WriteMode -> IO a) -> M ()
withNtfLog action = liftIO . mapM_ action =<< asks storeLog

incNtfStat :: (MonadUnliftIO m, MonadReader NtfEnv m) => (NtfServerStats -> TVar Int) -> m ()
incNtfStat :: (NtfServerStats -> TVar Int) -> M ()
incNtfStat statSel = do
stats <- asks serverStats
atomically $ modifyTVar (statSel stats) (+ 1)

saveServerStats :: (MonadUnliftIO m, MonadReader NtfEnv m) => m ()
saveServerStats :: M ()
saveServerStats =
asks (serverStatsBackupFile . config)
>>= mapM_ (\f -> asks serverStats >>= atomically . getNtfServerStatsData >>= liftIO . saveStats f)
Expand All @@ -562,7 +561,7 @@ saveServerStats =
B.writeFile f $ strEncode stats
logInfo "server stats saved"

restoreServerStats :: (MonadUnliftIO m, MonadReader NtfEnv m) => m ()
restoreServerStats :: M ()
restoreServerStats = asks (serverStatsBackupFile . config) >>= mapM_ restoreStats
where
restoreStats f = whenM (doesFileExist f) $ do
Expand Down
Loading