Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

PersistLiteral support for SQL keywords #1122

Merged
merged 16 commits into from
Nov 4, 2020
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions persistent-mongoDB/Database/Persist/MongoDB.hs
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE TypeFamilies #-}
{-# OPTIONS_GHC -Wno-orphans #-}
{-# OPTIONS_GHC -fno-warn-deprecations #-} -- Pattern match 'PersistDbSpecific'
-- | Use persistent-mongodb the same way you would use other persistent
-- libraries and refer to the general persistent documentation.
-- There are some new MongoDB specific filters under the filters section.
Expand Down Expand Up @@ -1047,6 +1048,8 @@ instance DB.Val PersistValue where
val (PersistRational _) = throw $ PersistMongoDBUnsupported "PersistRational not implemented for the MongoDB backend"
val (PersistArray a) = DB.val $ PersistList a
val (PersistDbSpecific _) = throw $ PersistMongoDBUnsupported "PersistDbSpecific not implemented for the MongoDB backend"
val (PersistLiteral _) = throw $ PersistMongoDBUnsupported "PersistLiteral not implemented for the MongoDB backend"
val (PersistLiteralEscaped _) = throw $ PersistMongoDBUnsupported "PersistLiteralEscaped not implemented for the MongoDB backend"
cast' (DB.Float x) = Just (PersistDouble x)
cast' (DB.Int32 x) = Just $ PersistInt64 $ fromIntegral x
cast' (DB.Int64 x) = Just $ PersistInt64 x
Expand Down
2 changes: 1 addition & 1 deletion persistent-mongoDB/test/EmbedTestMongo.hs
Original file line number Diff line number Diff line change
Expand Up @@ -303,7 +303,7 @@ specs = describe "embedded entities" $ do
it "can embed an Entity" $ db $ do
let foo = ARecord "foo"
bar = ARecord "bar"
_ <- insertMany [foo, bar]
insertMany_ [foo, bar]
arecords <- selectList ([ARecordName ==. "foo"] ||. [ARecordName ==. "bar"]) []
length arecords @== 2

Expand Down
114 changes: 91 additions & 23 deletions persistent-mysql/Database/Persist/MySQL.hs
Original file line number Diff line number Diff line change
@@ -1,10 +1,13 @@
{-# LANGUAGE BangPatterns #-}
{-# LANGUAGE ExistentialQuantification #-}
{-# LANGUAGE FlexibleContexts #-}
{-# LANGUAGE FlexibleInstances #-}
{-# LANGUAGE GADTs #-}
{-# LANGUAGE LambdaCase #-}
{-# LANGUAGE OverloadedStrings #-}
{-# LANGUAGE PatternSynonyms #-}
{-# LANGUAGE TypeFamilies #-}
{-# OPTIONS_GHC -fno-warn-deprecations #-} -- Pattern match 'PersistDbSpecific'
-- | A MySQL backend for @persistent@.
module Database.Persist.MySQL
( withMySQLPool
Expand Down Expand Up @@ -64,7 +67,6 @@ import Data.Text (Text, pack)
import qualified Data.Text as T
import qualified Data.Text.Encoding as T
import qualified Data.Text.IO as T
import Text.Read (readMaybe)
import System.Environment (getEnvironment)

import Database.Persist.Sql
Expand Down Expand Up @@ -240,6 +242,8 @@ instance MySQL.Param P where
MySQL.Plain $ BBB.fromString $ show (fromRational r :: Pico)
-- FIXME: Too Ambiguous, can not select precision without information about field
render (P (PersistDbSpecific s)) = MySQL.Plain $ BBS.fromByteString s
render (P (PersistLiteral l)) = MySQL.Plain $ BBS.fromByteString l
ldub marked this conversation as resolved.
Show resolved Hide resolved
render (P (PersistLiteralEscaped e)) = MySQL.Escape e
render (P (PersistArray a)) = MySQL.render (P (PersistList a))
render (P (PersistObjectId _)) =
error "Refusing to serialize a PersistObjectId to a MySQL value"
Expand Down Expand Up @@ -313,7 +317,7 @@ getGetter field = go (MySQLBase.fieldType field)
-- Conversion using PersistDbSpecific
go MySQLBase.Geometry _ _ = \_ m ->
case m of
Just g -> PersistDbSpecific g
Just g -> PersistLiteral g
Nothing -> error "Unexpected null in database specific value"
-- Unsupported
go other _ _ = error $ "MySQL.getGetter: type " ++
Expand Down Expand Up @@ -481,12 +485,13 @@ findMaxLenOfColumn allDefs name col =

-- | Find out the maxlen of a field
findMaxLenOfField :: FieldDef -> Maybe Integer
findMaxLenOfField fieldDef = do
maxLenAttr <- listToMaybe
. mapMaybe (T.stripPrefix "maxlen=" . T.toLower)
findMaxLenOfField fieldDef =
listToMaybe
. mapMaybe (\case
FieldAttrMaxlen x -> Just x
_ -> Nothing)
. fieldAttrs
$ fieldDef
readMaybe $ T.unpack maxLenAttr

-- | Helper for 'AddReference' that finds out the which primary key columns to reference.
addReference
Expand Down Expand Up @@ -517,6 +522,8 @@ data AlterColumn = Change Column
| Drop
| Default String
| NoDefault
| Gen SqlType (Maybe Integer) String
| NoGen SqlType (Maybe Integer)
| Update' String
-- | See the definition of the 'showAlter' function to see how these fields are used.
| AddReference
Expand Down Expand Up @@ -565,7 +572,8 @@ getColumns connectInfo getter def cols = do
, "CHARACTER_MAXIMUM_LENGTH, "
, "NUMERIC_PRECISION, "
, "NUMERIC_SCALE, "
, "COLUMN_DEFAULT "
, "COLUMN_DEFAULT, "
, "GENERATION_EXPRESSION "
, "FROM INFORMATION_SCHEMA.COLUMNS "
, "WHERE TABLE_SCHEMA = ? "
, "AND TABLE_NAME = ? "
Expand Down Expand Up @@ -635,13 +643,15 @@ getColumn connectInfo getter tname [ PersistText cname
, colMaxLen
, colPrecision
, colScale
, default'] cRef =
, default'
, generated
] cRef =
fmap (either (Left . pack) Right) $
runExceptT $ do
-- Default value
default_ <-
case default' of
PersistNull -> return Nothing
PersistNull -> return Nothing
PersistText t -> return (Just t)
PersistByteString bs ->
case T.decodeUtf8' bs of
Expand All @@ -650,12 +660,31 @@ getColumn connectInfo getter tname [ PersistText cname
$ "Invalid default column: "
++ show default'
++ " (error: " ++ show exc ++ ")"
Right t ->
Right t ->
return (Just t)
_ ->
fail $ "Invalid default column: " ++ show default'

generated_ <-
case generated of
PersistNull -> return Nothing
PersistText "" -> return Nothing
PersistByteString "" -> return Nothing
PersistText t -> return (Just t)
PersistByteString bs ->
case T.decodeUtf8' bs of
Left exc ->
fail
$ "Invalid generated column: "
++ show generated
++ " (error: " ++ show exc ++ ")"
Right t ->
return (Just t)
_ ->
fail $ "Invalid generated column: " ++ show generated

ref <- getRef (crConstraintName <$> cRef)

let colMaxLen' =
case colMaxLen of
PersistInt64 l -> Just (fromIntegral l)
Expand All @@ -666,13 +695,16 @@ getColumn connectInfo getter tname [ PersistText cname
, ciNumericPrecision = colPrecision
, ciNumericScale = colScale
}

(typ, maxLen) <- parseColumnType dataType ci

-- Okay!
return Column
{ cName = DBName $ cname
, cNull = null_ == "YES"
, cSqlType = typ
, cDefault = default_
, cGenerated = generated_
, cDefaultConstraintName = Nothing
, cMaxLen = maxLen
, cReference = ref
Expand Down Expand Up @@ -821,7 +853,7 @@ getAlters allDefs edef (c1, u1) (c2, u2) =
(col', ty, ml)


-- | @findAlters newColumn oldColumns@ finds out what needs to be
-- | @findAlters x y newColumn oldColumns@ finds out what needs to be
-- changed in the columns @oldColumns@ for @newColumn@ to be
-- supported.
findAlters
Expand All @@ -830,7 +862,7 @@ findAlters
-> Column
-> [Column]
-> ([AlterColumn'], [Column])
findAlters edef allDefs col@(Column name isNull type_ def _defConstraintName maxLen ref) cols =
findAlters edef allDefs col@(Column name isNull type_ def gen _defConstraintName maxLen ref) cols =
case filter ((name ==) . cName) cols of
-- new fkey that didn't exist before
[] ->
Expand All @@ -842,7 +874,7 @@ findAlters edef allDefs col@(Column name isNull type_ def _defConstraintName max
cnstr = [addReference allDefs cname tname name (crFieldCascade cr)]
in
(map ((,) tname) (Add' col : cnstr), cols)
Column _ isNull' type_' def' _defConstraintName' maxLen' ref' : _ ->
Column _ isNull' type_' def' gen' _defConstraintName' maxLen' ref' : _ ->
let -- Foreign key
refDrop =
case (ref == ref', ref') of
Expand All @@ -861,15 +893,25 @@ findAlters edef allDefs col@(Column name isNull type_ def _defConstraintName max
-- Type and nullability
modType | showSqlType type_ maxLen False `ciEquals` showSqlType type_' maxLen' False && isNull == isNull' = []
| otherwise = [(name, Change col)]

-- Default value
-- Avoid DEFAULT NULL, since it is always unnecessary, and is an error for text/blob fields
modDef | def == def' = []
| otherwise =
case def of
Nothing -> [(name, NoDefault)]
Just s -> if T.toUpper s == "NULL" then []
else [(name, Default $ T.unpack s)]
in ( refDrop ++ modType ++ modDef ++ refAdd
modDef =
if def == def' then []
else case def of
Nothing -> [(name, NoDefault)]
Just s ->
if T.toUpper s == "NULL" then []
else [(name, Default $ T.unpack s)]

-- Does the generated value need to change?
modGen =
if gen == gen' then []
else case gen of
Nothing -> [(name, NoGen type_ maxLen)]
Just genExpr -> [(name, Gen type_ maxLen $ T.unpack genExpr)]

in ( refDrop ++ modType ++ modDef ++ modGen ++ refAdd
, filter ((name /=) . cName) cols
)

Expand All @@ -882,11 +924,16 @@ findAlters edef allDefs col@(Column name isNull type_ def _defConstraintName max
-- | Prints the part of a @CREATE TABLE@ statement about a given
-- column.
showColumn :: Column -> String
showColumn (Column n nu t def _defConstraintName maxLen ref) = concat
showColumn (Column n nu t def gen _defConstraintName maxLen ref) = concat
[ escapeDBName n
, " "
, showSqlType t maxLen True
, " "
, case gen of
Nothing -> ""
Just genExpr ->
if T.toUpper genExpr == "NULL" then ""
else " GENERATED ALWAYS AS (" <> T.unpack genExpr <> ") STORED "
, if nu then "NULL" else "NOT NULL"
, case def of
Nothing -> ""
Expand Down Expand Up @@ -958,14 +1005,14 @@ showAlterTable table (DropUniqueConstraint cname) = concat

-- | Render an action that must be done on a column.
showAlter :: DBName -> AlterColumn' -> String
showAlter table (oldName, Change (Column n nu t def defConstraintName maxLen _ref)) =
showAlter table (oldName, Change (Column n nu t def gen defConstraintName maxLen _ref)) =
concat
[ "ALTER TABLE "
, escapeDBName table
, " CHANGE "
, escapeDBName oldName
, " "
, showColumn (Column n nu t def defConstraintName maxLen Nothing)
, showColumn (Column n nu t def gen defConstraintName maxLen Nothing)
]
showAlter table (_, Add' col) =
concat
Expand Down Expand Up @@ -998,6 +1045,27 @@ showAlter table (n, NoDefault) =
, escapeDBName n
, " DROP DEFAULT"
]
showAlter table (col, Gen typ len expr) =
concat
[ "ALTER TABLE "
, escapeDBName table
, " MODIFY COLUMN "
, escapeDBName col
, " "
, showSqlType typ len True
, " GENERATED ALWAYS AS ("
, expr
, ") STORED"
]
showAlter table (col, NoGen typ len) =
concat
[ "ALTER TABLE "
, escapeDBName table
, " MODIFY COLUMN "
, escapeDBName col
, " "
, showSqlType typ len True
]
showAlter table (n, Update' s) =
concat
[ "UPDATE "
Expand Down
16 changes: 9 additions & 7 deletions persistent-mysql/test/CustomConstraintTest.hs
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,7 @@ specs runDb = do
describe "custom constraint used in migration" $ before_ (runDb $ void $ runMigrationSilent customConstraintMigrate) $ after_ (runDb clean) $ do

it "custom constraint is actually created" $ runDb $ do
runMigrationSilent customConstraintMigrate -- run a second time to ensure the constraint isn't dropped
void $ runMigrationSilent customConstraintMigrate -- run a second time to ensure the constraint isn't dropped
let query = T.concat ["SELECT COUNT(*) "
,"FROM information_schema.key_column_usage "
,"WHERE ordinal_position=1 "
Expand All @@ -53,12 +53,14 @@ specs runDb = do
,"AND table_name=? "
,"AND column_name=? "
,"AND constraint_name=?"]
[Single exists] <- rawSql query [PersistText "custom_constraint1"
,PersistText "id"
,PersistText "custom_constraint2"
,PersistText "cc_id"
,PersistText "custom_constraint"]
liftIO $ 1 @?= (exists :: Int)
[Single exists_] <- rawSql query
[ PersistText "custom_constraint1"
, PersistText "id"
, PersistText "custom_constraint2"
, PersistText "cc_id"
, PersistText "custom_constraint"
]
liftIO $ 1 @?= (exists_ :: Int)

it "allows multiple constraints on a single column" $ runDb $ do
-- Here we add another foreign key on the same column where the
Expand Down
2 changes: 1 addition & 1 deletion persistent-mysql/test/InsertDuplicateUpdate.hs
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,7 @@ specs = describe "DuplicateKeyUpdate" $ do
it "performs only updates given if record already exists" $ db $ do
deleteWhere ([] :: [Filter Item])
let newDescription = "I am a new description"
_ <- insert item1
insert_ item1
insertOnDuplicateKeyUpdate
(Item "item1" "i am inserted description" (Just 1) (Just 2))
[ItemDescription =. newDescription]
Expand Down
2 changes: 2 additions & 0 deletions persistent-mysql/test/main.hs
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,7 @@ import qualified UniqueTest
import qualified UpsertTest
import qualified CustomConstraintTest
import qualified LongIdentifierTest
import qualified GeneratedColumnTestSQL
import qualified ForeignKey

type Tuple a b = (a, b)
Expand Down Expand Up @@ -199,6 +200,7 @@ main = do
-- TODO: implement automatic truncation for too long foreign keys, so we can run this test.
xdescribe "The migration for this test currently fails because of MySQL's 64 character limit for identifiers. See https://github.com/yesodweb/persistent/issues/1000 for details" $
LongIdentifierTest.specsWith db
GeneratedColumnTestSQL.specsWith db

roundFn :: RealFrac a => a -> Integer
roundFn = round
Expand Down
Loading