Skip to content

Commit

Permalink
Add support for setting a custom role when establishing a connection (#…
Browse files Browse the repository at this point in the history
…60)

* Add support for setting a custom role when establishing a connection

* Test that SET ROLE works if possible

* Bump version
  • Loading branch information
arybczak committed Jan 31, 2023
1 parent 93800e9 commit b97c83c
Show file tree
Hide file tree
Showing 8 changed files with 245 additions and 221 deletions.
3 changes: 3 additions & 0 deletions CHANGELOG.md
@@ -1,3 +1,6 @@
# hpqtypes-1.11.1.0 (2023-??-??)
* Add support for setting a custom role when establishing a connection.

# hpqtypes-1.11.0.0 (2023-01-18)
* Require `resource-pool` >= 0.4 and adjust the `createPool` function to
seamlessly accommodate future changes to the `resource-pool` library.
Expand Down
4 changes: 1 addition & 3 deletions hpqtypes.cabal
@@ -1,5 +1,5 @@
name: hpqtypes
version: 1.11.0.0
version: 1.11.1.0
synopsis: Haskell bindings to libpqtypes

description: Efficient and easy-to-use bindings to (slightly modified)
Expand Down Expand Up @@ -86,7 +86,6 @@ library
, Database.PostgreSQL.PQTypes.Internal.Monad
, Database.PostgreSQL.PQTypes.Internal.Notification
, Database.PostgreSQL.PQTypes.Internal.QueryResult
, Database.PostgreSQL.PQTypes.Internal.Query
, Database.PostgreSQL.PQTypes.Internal.State
, Database.PostgreSQL.PQTypes.Internal.C.Put
, Database.PostgreSQL.PQTypes.Internal.C.Types
Expand Down Expand Up @@ -189,7 +188,6 @@ test-suite hpqtypes-tests
, monad-control >= 1.0.3
, mtl >= 2.1
, random >= 1.0
, resource-pool
, scientific
, test-framework >= 0.8
, test-framework-hunit >= 0.3
Expand Down
1 change: 0 additions & 1 deletion src/Database/PostgreSQL/PQTypes/Class.hs
Expand Up @@ -9,7 +9,6 @@ import Control.Monad.Trans.Control
import Database.PostgreSQL.PQTypes.FromRow
import Database.PostgreSQL.PQTypes.Internal.Connection
import Database.PostgreSQL.PQTypes.Internal.Notification
import Database.PostgreSQL.PQTypes.Internal.Query
import Database.PostgreSQL.PQTypes.Internal.QueryResult
import Database.PostgreSQL.PQTypes.SQL.Class
import Database.PostgreSQL.PQTypes.Transaction.Settings
Expand Down
202 changes: 185 additions & 17 deletions src/Database/PostgreSQL/PQTypes/Internal/Connection.hs
@@ -1,4 +1,5 @@
module Database.PostgreSQL.PQTypes.Internal.Connection (
module Database.PostgreSQL.PQTypes.Internal.Connection
( -- * Connection
Connection(..)
, ConnectionData(..)
, withConnectionData
Expand All @@ -11,22 +12,29 @@ module Database.PostgreSQL.PQTypes.Internal.Connection (
, poolSource
, connect
, disconnect
-- * Running queries
, runQueryIO
, QueryName(..)
, runPreparedQueryIO
) where

import Control.Arrow (first)
import Control.Concurrent
import Control.Concurrent.Async
import Control.Monad
import Control.Monad.Base
import Control.Monad.Catch
import Data.Bifunctor
import Data.Function
import Data.IORef
import Data.Kind
import Data.Pool
import Data.String
import Foreign.C.String
import Foreign.ForeignPtr
import Foreign.Ptr
import GHC.Conc (closeFdWith)
import qualified Control.Exception as E
import qualified Data.ByteString as BS
import qualified Data.ByteString.Char8 as BS
import qualified Data.Foldable as F
import qualified Data.Set as S
import qualified Data.Text as T
Expand All @@ -36,13 +44,20 @@ import Database.PostgreSQL.PQTypes.Internal.C.Interface
import Database.PostgreSQL.PQTypes.Internal.C.Types
import Database.PostgreSQL.PQTypes.Internal.Composite
import Database.PostgreSQL.PQTypes.Internal.Error
import Database.PostgreSQL.PQTypes.Internal.Error.Code
import Database.PostgreSQL.PQTypes.Internal.Exception
import Database.PostgreSQL.PQTypes.Internal.Utils
import Database.PostgreSQL.PQTypes.SQL.Class
import Database.PostgreSQL.PQTypes.SQL.Raw
import Database.PostgreSQL.PQTypes.ToSQL

data ConnectionSettings = ConnectionSettings
{ -- | Connection info string.
csConnInfo :: !T.Text
-- | Client-side encoding. If set to 'Nothing', database encoding is used.
, csClientEncoding :: !(Maybe T.Text)
-- | A custom role to set with "SET ROLE".
, csRole :: !(Maybe (RawSQL ()))
-- | A list of composite types to register. In order to be able to
-- (de)serialize specific composite types, you need to register them.
, csComposites :: ![T.Text]
Expand All @@ -56,6 +71,7 @@ defaultConnectionSettings =
ConnectionSettings
{ csConnInfo = T.empty
, csClientEncoding = Just "UTF-8"
, csRole = Nothing
, csComposites = []
}

Expand Down Expand Up @@ -171,23 +187,26 @@ poolSource cs mkPoolConfig = do
-- 'disconnect', otherwise there will be a resource leak.
connect :: ConnectionSettings -> IO Connection
connect ConnectionSettings{..} = mask $ \unmask -> do
conn <- BS.useAsCString (T.encodeUtf8 csConnInfo) (openConnection unmask)
(`onException` c_PQfinish conn) . unmask $ do
status <- c_PQstatus conn
connPtr <- BS.useAsCString (T.encodeUtf8 csConnInfo) (openConnection unmask)
(`onException` c_PQfinish connPtr) . unmask $ do
status <- c_PQstatus connPtr
when (status /= c_CONNECTION_OK) $
throwLibPQError conn fname
throwLibPQError connPtr fname
F.forM_ csClientEncoding $ \enc -> do
res <- BS.useAsCString (T.encodeUtf8 enc) (c_PQsetClientEncoding conn)
res <- BS.useAsCString (T.encodeUtf8 enc) (c_PQsetClientEncoding connPtr)
when (res == -1) $
throwLibPQError conn fname
c_PQinitTypes conn
registerComposites conn csComposites
preparedQueries <- newIORef S.empty
fmap Connection . newMVar $ Just ConnectionData
{ cdPtr = conn
, cdStats = initialStats
, cdPreparedQueries = preparedQueries
}
throwLibPQError connPtr fname
c_PQinitTypes connPtr
registerComposites connPtr csComposites
conn <- do
preparedQueries <- newIORef S.empty
fmap Connection . newMVar $ Just ConnectionData
{ cdPtr = connPtr
, cdStats = initialStats
, cdPreparedQueries = preparedQueries
}
F.forM_ csRole $ \role -> runQueryIO conn $ "SET ROLE " <> role
pure conn
where
fname = "connect"

Expand Down Expand Up @@ -236,3 +255,152 @@ disconnect (Connection mvconn) = modifyMVar_ mvconn $ \mconn -> do

Nothing -> E.throwIO (HPQTypesError "disconnect: no connection (shouldn't happen)")
return Nothing

----------------------------------------
-- Query running

-- | Low-level function for running an SQL query.
runQueryIO
:: IsSQL sql
=> Connection
-> sql
-> IO (Int, ForeignPtr PGresult)
runQueryIO conn sql = do
runQueryImpl "runQueryIO" conn sql $ \ConnectionData{..} -> do
let allocParam = ParamAllocator $ withPGparam cdPtr
withSQL sql allocParam $ \param query -> (,)
<$> (fromIntegral <$> c_PQparamCount param)
<*> c_PQparamExec cdPtr nullPtr param query c_RESULT_BINARY

-- | Name of a prepared query.
newtype QueryName = QueryName T.Text
deriving (Eq, Ord, Show, IsString)

-- | Low-level function for running a prepared SQL query.
runPreparedQueryIO
:: IsSQL sql
=> Connection
-> QueryName
-> sql
-> IO (Int, ForeignPtr PGresult)
runPreparedQueryIO conn (QueryName queryName) sql = do
runQueryImpl "runPreparedQueryIO" conn sql $ \ConnectionData{..} -> do
when (T.null queryName) $ do
E.throwIO DBException
{ dbeQueryContext = sql
, dbeError = HPQTypesError "runPreparedQueryIO: unnamed prepared query is not supported"
}
let allocParam = ParamAllocator $ withPGparam cdPtr
withSQL sql allocParam $ \param query -> do
preparedQueries <- readIORef cdPreparedQueries
BS.useAsCString (T.encodeUtf8 queryName) $ \cname -> do
when (queryName `S.notMember` preparedQueries) . E.mask_ $ do
-- Mask asynchronous exceptions, because if preparation of the query
-- succeeds, we need to reflect that fact in cdPreparedQueries since
-- you can't prepare a query with the same name more than once.
res <- c_PQparamPrepare cdPtr nullPtr param cname query
void . withForeignPtr res $ verifyResult sql cdPtr
modifyIORef' cdPreparedQueries $ S.insert queryName
(,) <$> (fromIntegral <$> c_PQparamCount param)
<*> c_PQparamExecPrepared cdPtr nullPtr param cname c_RESULT_BINARY

-- | Shared implementation of 'runQueryIO' and 'runPreparedQueryIO'.
runQueryImpl
:: IsSQL sql
=> String
-> Connection
-> sql
-> (ConnectionData -> IO (Int, ForeignPtr PGresult))
-> IO (Int, ForeignPtr PGresult)
runQueryImpl fname conn sql execSql = do
withConnDo $ \cd@ConnectionData{..} -> E.mask $ \restore -> do
-- While the query runs, the current thread will not be able to receive
-- asynchronous exceptions. This prevents clients of the library from
-- interrupting execution of the query. To remedy that we spawn a separate
-- thread for the query execution and while we wait for its completion, we
-- are able to receive asynchronous exceptions (assuming that threaded GHC
-- runtime system is used) and react appropriately.
queryRunner <- async . restore $ do
(paramCount, res) <- execSql cd
affected <- withForeignPtr res $ verifyResult sql cdPtr
stats' <- case affected of
Left _ -> return cdStats {
statsQueries = statsQueries cdStats + 1
, statsParams = statsParams cdStats + paramCount
}
Right rows -> do
columns <- fromIntegral <$> withForeignPtr res c_PQnfields
return ConnectionStats {
statsQueries = statsQueries cdStats + 1
, statsRows = statsRows cdStats + rows
, statsValues = statsValues cdStats + (rows * columns)
, statsParams = statsParams cdStats + paramCount
}
-- Force evaluation of modified stats to squash a space leak.
stats' `seq` return (cd { cdStats = stats' }, (either id id affected, res))
-- If we receive an exception while waiting for the execution to complete,
-- we need to send a request to PostgreSQL for query cancellation and wait
-- for the query runner thread to terminate. It is paramount we make the
-- exception handler uninterruptible as we can't exit from the main block
-- until the query runner thread has terminated.
E.onException (restore $ wait queryRunner) . E.uninterruptibleMask_ $ do
c_PQcancel cdPtr >>= \case
-- If query cancellation request was successfully processed, there is
-- nothing else to do apart from waiting for the runner to terminate.
Nothing -> cancel queryRunner
-- Otherwise we check what happened with the runner. If it already
-- finished we're fine, just ignore the result. If it didn't, something
-- weird is going on. Maybe the cancellation request went through when
-- the thread wasn't making a request to the server? In any case, try to
-- cancel again and wait for the thread to terminate.
Just _ -> poll queryRunner >>= \case
Just _ -> return ()
Nothing -> do
void $ c_PQcancel cdPtr
cancel queryRunner
where
withConnDo = withConnectionData conn fname

verifyResult :: IsSQL sql => sql -> Ptr PGconn -> Ptr PGresult -> IO (Either Int Int)
verifyResult sql conn res = do
-- works even if res is NULL
rst <- c_PQresultStatus res
case rst of
_ | rst == c_PGRES_COMMAND_OK -> do
sn <- c_PQcmdTuples res >>= BS.packCString
case BS.readInt sn of
Nothing
| BS.null sn -> return . Left $ 0
| otherwise -> throwParseError sn
Just (n, rest)
| rest /= BS.empty -> throwParseError sn
| otherwise -> return . Left $ n
_ | rst == c_PGRES_TUPLES_OK -> Right . fromIntegral <$> c_PQntuples res
_ | rst == c_PGRES_FATAL_ERROR -> throwSQLError
_ | rst == c_PGRES_BAD_RESPONSE -> throwSQLError
_ | otherwise -> return . Left $ 0
where
throwSQLError = rethrowWithContext sql =<< if res == nullPtr
then return . E.toException . QueryError
=<< safePeekCString' =<< c_PQerrorMessage conn
else E.toException <$> (DetailedQueryError
<$> field c_PG_DIAG_SEVERITY
<*> (stringToErrorCode <$> field c_PG_DIAG_SQLSTATE)
<*> field c_PG_DIAG_MESSAGE_PRIMARY
<*> mfield c_PG_DIAG_MESSAGE_DETAIL
<*> mfield c_PG_DIAG_MESSAGE_HINT
<*> ((mread =<<) <$> mfield c_PG_DIAG_STATEMENT_POSITION)
<*> ((mread =<<) <$> mfield c_PG_DIAG_INTERNAL_POSITION)
<*> mfield c_PG_DIAG_INTERNAL_QUERY
<*> mfield c_PG_DIAG_CONTEXT
<*> mfield c_PG_DIAG_SOURCE_FILE
<*> ((mread =<<) <$> mfield c_PG_DIAG_SOURCE_LINE)
<*> mfield c_PG_DIAG_SOURCE_FUNCTION)
where
field f = maybe "" id <$> mfield f
mfield f = safePeekCString =<< c_PQresultErrorField res f

throwParseError sn = E.throwIO DBException {
dbeQueryContext = sql
, dbeError = HPQTypesError ("verifyResult: string returned by PQcmdTuples is not a valid number: " ++ show sn)
}
8 changes: 5 additions & 3 deletions src/Database/PostgreSQL/PQTypes/Internal/Monad.hs
Expand Up @@ -14,14 +14,14 @@ import Control.Monad.Reader.Class
import Control.Monad.State.Strict
import Control.Monad.Trans.Control
import Control.Monad.Writer.Class
import Data.Bifunctor
import qualified Control.Monad.Trans.State.Strict as S
import qualified Control.Monad.Fail as MF

import Database.PostgreSQL.PQTypes.Class
import Database.PostgreSQL.PQTypes.Internal.Connection
import Database.PostgreSQL.PQTypes.Internal.Error
import Database.PostgreSQL.PQTypes.Internal.Notification
import Database.PostgreSQL.PQTypes.Internal.Query
import Database.PostgreSQL.PQTypes.Internal.State
import Database.PostgreSQL.PQTypes.SQL
import Database.PostgreSQL.PQTypes.SQL.Class
Expand Down Expand Up @@ -71,8 +71,10 @@ mapDBT f g m = DBT . StateT $ g . runStateT (unDBT m) . f
----------------------------------------

instance (m ~ n, MonadBase IO m, MonadMask m) => MonadDB (DBT_ m n) where
runQuery sql = DBT . StateT $ liftBase . runQueryIO sql
runPreparedQuery name sql = DBT . StateT $ liftBase . runPreparedQueryIO name sql
runQuery sql = DBT . StateT $ \st -> liftBase $ do
second (updateStateWith st sql) <$> runQueryIO (dbConnection st) sql
runPreparedQuery name sql = DBT . StateT $ \st -> liftBase $ do
second (updateStateWith st sql) <$> runPreparedQueryIO (dbConnection st) name sql

getLastQuery = DBT . gets $ dbLastQuery

Expand Down

0 comments on commit b97c83c

Please sign in to comment.