Skip to content

Commit

Permalink
Store pid of the backend when connecting to Postgres (#71)
Browse files Browse the repository at this point in the history
* Store pid of the backend when connecting to Postgres

* Add BackendPid
  • Loading branch information
arybczak committed Mar 8, 2024
1 parent 8fb828c commit 00ddea3
Show file tree
Hide file tree
Showing 10 changed files with 119 additions and 33 deletions.
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
# hpqtypes-1.12.0.0 (????-??-??)
* Drop support for GHC 8.8.
* Attach `CallStack` to `DBException`.
* Store ID of the server process attached to the current session.

# hpqtypes-1.11.1.2 (2023-11-08)
* Support multihost setups and the `connect_timeout` parameter in the connection
Expand Down
3 changes: 3 additions & 0 deletions hpqtypes.cabal
Original file line number Diff line number Diff line change
Expand Up @@ -77,6 +77,7 @@ library
, Database.PostgreSQL.PQTypes.SQL.Class
, Database.PostgreSQL.PQTypes.Transaction.Settings
, Database.PostgreSQL.PQTypes.XML
, Database.PostgreSQL.PQTypes.Internal.BackendPid
, Database.PostgreSQL.PQTypes.Internal.Error
, Database.PostgreSQL.PQTypes.Internal.Error.Code
, Database.PostgreSQL.PQTypes.Internal.Composite
Expand Down Expand Up @@ -146,6 +147,7 @@ library
, ConstraintKinds
, DataKinds
, DeriveFunctor
, DerivingStrategies
, ExistentialQuantification
, FlexibleContexts
, FlexibleInstances
Expand Down Expand Up @@ -213,6 +215,7 @@ test-suite hpqtypes-tests
, ConstraintKinds
, DataKinds
, DeriveFunctor
, DerivingStrategies
, ExistentialQuantification
, FlexibleContexts
, FlexibleInstances
Expand Down
13 changes: 11 additions & 2 deletions src/Database/PostgreSQL/PQTypes/Class.hs
Original file line number Diff line number Diff line change
@@ -1,13 +1,18 @@
module Database.PostgreSQL.PQTypes.Class
( QueryName (..)
, MonadDB (..)
( -- * Class
MonadDB (..)

-- * Misc
, BackendPid (..)
, QueryName (..)
) where

import Control.Monad.Trans
import Control.Monad.Trans.Control
import GHC.Stack

import Database.PostgreSQL.PQTypes.FromRow
import Database.PostgreSQL.PQTypes.Internal.BackendPid
import Database.PostgreSQL.PQTypes.Internal.Connection
import Database.PostgreSQL.PQTypes.Internal.Notification
import Database.PostgreSQL.PQTypes.Internal.QueryResult
Expand All @@ -32,6 +37,9 @@ class (Applicative m, Monad m) => MonadDB m where
-- 'getLastQuery'.
withFrozenLastQuery :: m a -> m a

-- | Get ID of the server process attached to the current session.
getBackendPid :: m BackendPid

-- | Get current connection statistics.
getConnectionStats :: HasCallStack => m ConnectionStats

Expand Down Expand Up @@ -89,6 +97,7 @@ instance
runPreparedQuery name = withFrozenCallStack $ lift . runPreparedQuery name
getLastQuery = lift getLastQuery
withFrozenLastQuery m = controlT $ \run -> withFrozenLastQuery (run m)
getBackendPid = lift getBackendPid
getConnectionStats = withFrozenCallStack $ lift getConnectionStats
getQueryResult = lift getQueryResult
clearQueryResult = lift clearQueryResult
Expand Down
7 changes: 7 additions & 0 deletions src/Database/PostgreSQL/PQTypes/Internal/BackendPid.hs
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
module Database.PostgreSQL.PQTypes.Internal.BackendPid
( BackendPid (..)
) where

-- | Process ID of the server process attached to the current session.
newtype BackendPid = BackendPid Int
deriving newtype (Eq, Ord, Show)
46 changes: 38 additions & 8 deletions src/Database/PostgreSQL/PQTypes/Internal/Connection.hs
Original file line number Diff line number Diff line change
@@ -1,6 +1,9 @@
{-# LANGUAGE TypeApplications #-}

module Database.PostgreSQL.PQTypes.Internal.Connection
( -- * Connection
Connection (..)
, getBackendPidIO
, ConnectionData (..)
, withConnectionData
, ConnectionStats (..)
Expand All @@ -26,10 +29,11 @@ import Control.Exception qualified as E
import Control.Monad
import Control.Monad.Base
import Control.Monad.Catch
import Data.Bifunctor
import Data.ByteString.Char8 qualified as BS
import Data.Foldable qualified as F
import Data.Functor.Identity
import Data.IORef
import Data.Int
import Data.Kind
import Data.Pool
import Data.Set qualified as S
Expand All @@ -42,12 +46,14 @@ import Foreign.Ptr
import GHC.Conc (closeFdWith)
import GHC.Stack

import Database.PostgreSQL.PQTypes.Internal.BackendPid
import Database.PostgreSQL.PQTypes.Internal.C.Interface
import Database.PostgreSQL.PQTypes.Internal.C.Types
import Database.PostgreSQL.PQTypes.Internal.Composite
import Database.PostgreSQL.PQTypes.Internal.Error
import Database.PostgreSQL.PQTypes.Internal.Error.Code
import Database.PostgreSQL.PQTypes.Internal.Exception
import Database.PostgreSQL.PQTypes.Internal.QueryResult
import Database.PostgreSQL.PQTypes.Internal.Utils
import Database.PostgreSQL.PQTypes.SQL.Class
import Database.PostgreSQL.PQTypes.SQL.Raw
Expand Down Expand Up @@ -114,6 +120,8 @@ initialStats =
data ConnectionData = ConnectionData
{ cdPtr :: !(Ptr PGconn)
-- ^ Pointer to connection object.
, cdBackendPid :: !BackendPid
-- ^ Process ID of the server process attached to the current session.
, cdStats :: !ConnectionStats
-- ^ Statistics associated with the connection.
, cdPreparedQueries :: !(IORef (S.Set T.Text))
Expand All @@ -125,6 +133,11 @@ newtype Connection = Connection
{ unConnection :: MVar (Maybe ConnectionData)
}

getBackendPidIO :: Connection -> IO BackendPid
getBackendPidIO conn = do
withConnectionData conn "getBackendPidIO" $ \cd -> do
pure (cd, cdBackendPid cd)

withConnectionData
:: Connection
-> String
Expand All @@ -133,7 +146,9 @@ withConnectionData
withConnectionData (Connection mvc) fname f =
modifyMVar mvc $ \mc -> case mc of
Nothing -> hpqTypesError $ fname ++ ": no connection"
Just cd -> first Just <$> f cd
Just cd -> do
(cd', r) <- f cd
cd' `seq` pure (Just cd', r)

-- | Database connection supplier.
newtype ConnectionSourceM m = ConnectionSourceM
Expand Down Expand Up @@ -215,12 +230,25 @@ connect ConnectionSettings {..} = mask $ \unmask -> do
Just
ConnectionData
{ cdPtr = connPtr
, cdBackendPid = noBackendPid
, cdStats = initialStats
, cdPreparedQueries = preparedQueries
}
F.forM_ csRole $ \role -> runQueryIO conn $ "SET ROLE " <> role

let selectPid = "SELECT pg_backend_pid()" :: RawSQL ()
(_, res) <- runQueryIO conn selectPid
case F.toList $ mkQueryResult @(Identity Int32) selectPid noBackendPid res of
[pid] -> withConnectionData conn fname $ \cd -> do
pure (cd {cdBackendPid = BackendPid $ fromIntegral pid}, ())
pids -> do
let err = HPQTypesError $ "unexpected backend pid: " ++ show pids
rethrowWithContext selectPid noBackendPid $ toException err

pure conn
where
noBackendPid = BackendPid 0

fname = "connect"

openConnection :: (forall r. IO r -> IO r) -> CString -> IO (Ptr PGconn)
Expand Down Expand Up @@ -317,6 +345,7 @@ runPreparedQueryIO conn (QueryName queryName) sql = do
E.throwIO
DBException
{ dbeQueryContext = sql
, dbeBackendPid = cdBackendPid
, dbeError = HPQTypesError "runPreparedQueryIO: unnamed prepared query is not supported"
, dbeCallStack = callStack
}
Expand All @@ -329,7 +358,7 @@ runPreparedQueryIO conn (QueryName queryName) sql = do
-- succeeds, we need to reflect that fact in cdPreparedQueries since
-- you can't prepare a query with the same name more than once.
res <- c_PQparamPrepare cdPtr nullPtr param cname query
void . withForeignPtr res $ verifyResult sql cdPtr
void . withForeignPtr res $ verifyResult sql cdBackendPid cdPtr
modifyIORef' cdPreparedQueries $ S.insert queryName
(,)
<$> (fromIntegral <$> c_PQparamCount param)
Expand All @@ -353,7 +382,7 @@ runQueryImpl fname conn sql execSql = do
-- runtime system is used) and react appropriately.
queryRunner <- async . restore $ do
(paramCount, res) <- execSql cd
affected <- withForeignPtr res $ verifyResult sql cdPtr
affected <- withForeignPtr res $ verifyResult sql cdBackendPid cdPtr
stats' <- case affected of
Left _ ->
return
Expand All @@ -370,8 +399,7 @@ runQueryImpl fname conn sql execSql = do
, statsValues = statsValues cdStats + (rows * columns)
, statsParams = statsParams cdStats + paramCount
}
-- Force evaluation of modified stats to squash a space leak.
stats' `seq` return (cd {cdStats = stats'}, (either id id affected, res))
return (cd {cdStats = stats'}, (either id id affected, res))
-- If we receive an exception while waiting for the execution to complete,
-- we need to send a request to PostgreSQL for query cancellation and wait
-- for the query runner thread to terminate. It is paramount we make the
Expand Down Expand Up @@ -399,10 +427,11 @@ runQueryImpl fname conn sql execSql = do
verifyResult
:: (HasCallStack, IsSQL sql)
=> sql
-> BackendPid
-> Ptr PGconn
-> Ptr PGresult
-> IO (Either Int Int)
verifyResult sql conn res = do
verifyResult sql pid conn res = do
-- works even if res is NULL
rst <- c_PQresultStatus res
case rst of
Expand All @@ -421,7 +450,7 @@ verifyResult sql conn res = do
_ | otherwise -> return . Left $ 0
where
throwSQLError =
rethrowWithContext sql
rethrowWithContext sql pid
=<< if res == nullPtr
then
return . E.toException . QueryError
Expand Down Expand Up @@ -451,6 +480,7 @@ verifyResult sql conn res = do
E.throwIO
DBException
{ dbeQueryContext = sql
, dbeBackendPid = pid
, dbeError = HPQTypesError ("verifyResult: string returned by PQcmdTuples is not a valid number: " ++ show sn)
, dbeCallStack = callStack
}
13 changes: 11 additions & 2 deletions src/Database/PostgreSQL/PQTypes/Internal/Exception.hs
Original file line number Diff line number Diff line change
Expand Up @@ -7,13 +7,16 @@ module Database.PostgreSQL.PQTypes.Internal.Exception
import Control.Exception qualified as E
import GHC.Stack

import Database.PostgreSQL.PQTypes.Internal.BackendPid
import Database.PostgreSQL.PQTypes.SQL.Class

-- | Main exception type. All exceptions thrown by
-- the library are additionally wrapped in this type.
data DBException = forall e sql. (E.Exception e, Show sql) => DBException
{ dbeQueryContext :: !sql
-- ^ Last SQL query that was executed.
, dbeBackendPid :: !BackendPid
-- ^ Process ID of the server process attached to the current session.
, dbeError :: !e
-- ^ Specific error.
, dbeCallStack :: CallStack
Expand All @@ -24,11 +27,17 @@ deriving instance Show DBException
instance E.Exception DBException

-- | Rethrow supplied exception enriched with given SQL.
rethrowWithContext :: (HasCallStack, IsSQL sql) => sql -> E.SomeException -> IO a
rethrowWithContext sql (E.SomeException e) =
rethrowWithContext
:: (HasCallStack, IsSQL sql)
=> sql
-> BackendPid
-> E.SomeException
-> IO a
rethrowWithContext sql pid (E.SomeException e) =
E.throwIO
DBException
{ dbeQueryContext = sql
, dbeBackendPid = pid
, dbeError = e
, dbeCallStack = callStack
}
13 changes: 7 additions & 6 deletions src/Database/PostgreSQL/PQTypes/Internal/Monad.hs
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,6 @@ import Control.Monad.State.Strict
import Control.Monad.Trans.Control
import Control.Monad.Trans.State.Strict qualified as S
import Control.Monad.Writer.Class
import Data.Bifunctor
import GHC.Stack

import Database.PostgreSQL.PQTypes.Class
Expand Down Expand Up @@ -77,9 +76,9 @@ mapDBT f g m = DBT . StateT $ g . runStateT (unDBT m) . f

instance (m ~ n, MonadBase IO m, MonadMask m) => MonadDB (DBT_ m n) where
runQuery sql = withFrozenCallStack $ DBT . StateT $ \st -> liftBase $ do
second (updateStateWith st sql) <$> runQueryIO (dbConnection st) sql
updateStateWith st sql =<< runQueryIO (dbConnection st) sql
runPreparedQuery name sql = withFrozenCallStack $ DBT . StateT $ \st -> liftBase $ do
second (updateStateWith st sql) <$> runPreparedQueryIO (dbConnection st) name sql
updateStateWith st sql =<< runPreparedQueryIO (dbConnection st) name sql

getLastQuery = DBT . gets $ dbLastQuery

Expand All @@ -88,6 +87,9 @@ instance (m ~ n, MonadBase IO m, MonadMask m) => MonadDB (DBT_ m n) where
(x, st'') <- runStateT (unDBT callback) st'
pure (x, st'' {dbRecordLastQuery = dbRecordLastQuery st})

getBackendPid = DBT . StateT $ \st -> do
(,st) <$> liftBase (getBackendPidIO $ dbConnection st)

getConnectionStats = withFrozenCallStack $ do
mconn <- DBT $ liftBase . readMVar =<< gets (unConnection . dbConnection)
case mconn of
Expand All @@ -100,9 +102,8 @@ instance (m ~ n, MonadBase IO m, MonadMask m) => MonadDB (DBT_ m n) where
getTransactionSettings = DBT . gets $ dbTransactionSettings
setTransactionSettings ts = DBT . modify $ \st -> st {dbTransactionSettings = ts}

getNotification time = DBT . StateT $ \st ->
(,st)
<$> liftBase (getNotificationIO st time)
getNotification time = DBT . StateT $ \st -> do
(,st) <$> liftBase (getNotificationIO st time)

withNewConnection m = DBT . StateT $ \st -> do
let cs = dbConnectionSource st
Expand Down
24 changes: 21 additions & 3 deletions src/Database/PostgreSQL/PQTypes/Internal/QueryResult.hs
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

module Database.PostgreSQL.PQTypes.Internal.QueryResult
( QueryResult (..)
, mkQueryResult
, ntuples
, nfields

Expand All @@ -24,6 +25,7 @@ import System.IO.Unsafe

import Database.PostgreSQL.PQTypes.Format
import Database.PostgreSQL.PQTypes.FromRow
import Database.PostgreSQL.PQTypes.Internal.BackendPid
import Database.PostgreSQL.PQTypes.Internal.C.Interface
import Database.PostgreSQL.PQTypes.Internal.C.Types
import Database.PostgreSQL.PQTypes.Internal.Error
Expand All @@ -35,12 +37,27 @@ import Database.PostgreSQL.PQTypes.SQL.Class
-- extraction appropriately.
data QueryResult t = forall row. FromRow row => QueryResult
{ qrSQL :: !SomeSQL
, qrBackendPid :: !BackendPid
, qrResult :: !(ForeignPtr PGresult)
, qrFromRow :: !(row -> t)
}

mkQueryResult
:: (FromRow t, IsSQL sql)
=> sql
-> BackendPid
-> ForeignPtr PGresult
-> QueryResult t
mkQueryResult sql pid res =
QueryResult
{ qrSQL = SomeSQL sql
, qrBackendPid = pid
, qrResult = res
, qrFromRow = id
}

instance Functor QueryResult where
f `fmap` QueryResult ctx fres g = QueryResult ctx fres (f . g)
f `fmap` QueryResult ctx pid fres g = QueryResult ctx pid fres (f . g)

instance Foldable QueryResult where
foldr f acc = runIdentity . foldrImpl False (coerce f) acc
Expand Down Expand Up @@ -77,7 +94,7 @@ foldImpl
-> acc
-> QueryResult t
-> m acc
foldImpl initCtr termCtr advCtr strict f iacc (QueryResult (SomeSQL ctx) fres g) =
foldImpl initCtr termCtr advCtr strict f iacc (QueryResult (SomeSQL ctx) pid fres g) =
unsafePerformIO $ withForeignPtr fres $ \res -> do
-- This bit is referentially transparent iff appropriate
-- FrowRow and FromSQL instances are (the ones provided
Expand All @@ -87,6 +104,7 @@ foldImpl initCtr termCtr advCtr strict f iacc (QueryResult (SomeSQL ctx) fres g)
E.throwIO
DBException
{ dbeQueryContext = ctx
, dbeBackendPid = pid
, dbeError =
RowLengthMismatch
{ lengthExpected = pqVariablesP rowp
Expand All @@ -101,7 +119,7 @@ foldImpl initCtr termCtr advCtr strict f iacc (QueryResult (SomeSQL ctx) fres g)
then return acc
else do
-- mask asynchronous exceptions so they won't be wrapped in DBException
obj <- E.mask_ (g <$> fromRow res err 0 i `E.catch` rethrowWithContext ctx)
obj <- E.mask_ (g <$> fromRow res err 0 i `E.catch` rethrowWithContext ctx pid)
worker `apply` (f obj =<< acc) $ advCtr i
worker (pure iacc) =<< initCtr res
where
Expand Down
Loading

0 comments on commit 00ddea3

Please sign in to comment.