Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix multiple fk issue with MySQL migrations #1025

Merged
merged 4 commits into from
Feb 1, 2020
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions persistent-mysql/ChangeLog.md
Original file line number Diff line number Diff line change
@@ -1,5 +1,9 @@
# Changelog for persistent-mysql

## 2.10.2.3

* Fix issue with multiple foreign keys on single column. [#1025](https://github.com/yesodweb/persistent/pull/1025)

## 2.10.2.2

* Compatibility with latest persistent-template for test suite [#1002](https://github.com/yesodweb/persistent/pull/1002/files)
Expand Down
86 changes: 53 additions & 33 deletions persistent-mysql/Database/Persist/MySQL.hs
Original file line number Diff line number Diff line change
Expand Up @@ -330,8 +330,8 @@ migrate' :: MySQL.ConnectInfo
-> IO (Either [Text] [(Bool, Text)])
migrate' connectInfo allDefs getter val = do
let name = entityDB val
(idClmn, old) <- getColumns connectInfo getter val
let (newcols, udefs, fdefs) = mkColumns allDefs val
(idClmn, old) <- getColumns connectInfo getter val newcols
let udspair = map udToPair udefs
case (idClmn, old, partitionEithers old) of
-- Nothing found, create everything
Expand Down Expand Up @@ -467,11 +467,11 @@ udToPair ud = (uniqueDBName ud, map snd $ uniqueFields ud)
-- in the database.
getColumns :: MySQL.ConnectInfo
-> (Text -> IO Statement)
-> EntityDef
-> EntityDef -> [Column]
-> IO ( [Either Text (Either Column (DBName, [DBName]))] -- ID column
, [Either Text (Either Column (DBName, [DBName]))] -- everything else
)
getColumns connectInfo getter def = do
getColumns connectInfo getter def cols = do
-- Find out ID column.
stmtIdClmn <- getter $ T.concat
[ "SELECT COLUMN_NAME, "
Expand Down Expand Up @@ -522,15 +522,22 @@ getColumns connectInfo getter def = do
-- Return both
return (ids, cs ++ us)
where
refMap = Map.fromList $ foldl ref [] cols
where ref rs c = case cReference c of
Nothing -> rs
(Just r) -> (unDBName $ cName c, r) : rs
vals = [ PersistText $ pack $ MySQL.connectDatabase connectInfo
, PersistText $ unDBName $ entityDB def
, PersistText $ unDBName $ fieldDB $ entityId def ]

helperClmns = CL.mapM getIt .| CL.consume
where
getIt = fmap (either Left (Right . Left)) .
liftIO .
getColumn connectInfo getter (entityDB def)
getIt row = fmap (either Left (Right . Left)) .
liftIO .
getColumn connectInfo getter (entityDB def) row $ ref
where ref = case row of
(PersistText cname : _) -> (Map.lookup cname refMap)
_ -> Nothing

helperCntrs = do
let check [ PersistText cntrName
Expand All @@ -546,6 +553,7 @@ getColumn :: MySQL.ConnectInfo
-> (Text -> IO Statement)
-> DBName
-> [PersistValue]
-> Maybe (DBName, DBName)
-> IO (Either Text Column)
getColumn connectInfo getter tname [ PersistText cname
, PersistText null_
Expand All @@ -554,7 +562,7 @@ getColumn connectInfo getter tname [ PersistText cname
, colMaxLen
, colPrecision
, colScale
, default'] =
, default'] refName =
fmap (either (Left . pack) Right) $
runExceptT $ do
-- Default value
Expand All @@ -569,30 +577,7 @@ getColumn connectInfo getter tname [ PersistText cname
Right t -> return (Just t)
_ -> fail $ "Invalid default column: " ++ show default'

-- Foreign key (if any)
stmt <- lift . getter $ T.concat
[ "SELECT REFERENCED_TABLE_NAME, "
, "CONSTRAINT_NAME, "
, "ORDINAL_POSITION "
, "FROM INFORMATION_SCHEMA.KEY_COLUMN_USAGE "
, "WHERE TABLE_SCHEMA = ? "
, "AND TABLE_NAME = ? "
, "AND COLUMN_NAME = ? "
, "AND REFERENCED_TABLE_SCHEMA = ? "
, "ORDER BY CONSTRAINT_NAME, "
, "COLUMN_NAME"
]
let vars = [ PersistText $ pack $ MySQL.connectDatabase connectInfo
, PersistText $ unDBName $ tname
, PersistText cname
, PersistText $ pack $ MySQL.connectDatabase connectInfo ]
cntrs <- liftIO $ with (stmtQuery stmt vars) (\src -> runConduit $ src .| CL.consume)
ref <- case cntrs of
[] -> return Nothing
[[PersistText tab, PersistText ref, PersistInt64 pos]] ->
return $ if pos == 1 then Just (DBName tab, DBName ref) else Nothing
_ -> fail "MySQL.getColumn/getRef: never here"

ref <- getRef refName
let colMaxLen' = case colMaxLen of
PersistInt64 l -> Just (fromIntegral l)
_ -> Nothing
Expand All @@ -613,8 +598,43 @@ getColumn connectInfo getter tname [ PersistText cname
, cMaxLen = maxLen
, cReference = ref
}

getColumn _ _ _ x =
where getRef Nothing = return Nothing
getRef (Just (_, refName')) = do
-- Foreign key (if any)
stmt <- lift . getter $ T.concat
[ "SELECT REFERENCED_TABLE_NAME, "
, "CONSTRAINT_NAME, "
, "ORDINAL_POSITION "
, "FROM INFORMATION_SCHEMA.KEY_COLUMN_USAGE "
, "WHERE TABLE_SCHEMA = ? "
, "AND TABLE_NAME = ? "
, "AND COLUMN_NAME = ? "
, "AND REFERENCED_TABLE_SCHEMA = ? "
, "AND CONSTRAINT_NAME = ? "
, "ORDER BY CONSTRAINT_NAME, "
, "COLUMN_NAME"
]
let vars = [ PersistText $ pack $ MySQL.connectDatabase connectInfo
, PersistText $ unDBName $ tname
, PersistText cname
, PersistText $ pack $ MySQL.connectDatabase connectInfo
, PersistText $ unDBName refName' ]
cntrs <- liftIO $ with (stmtQuery stmt vars) (\src -> runConduit $ src .| CL.consume)
case cntrs of
[] -> return Nothing
[[PersistText tab, PersistText ref, PersistInt64 pos]] ->
return $ if pos == 1 then Just (DBName tab, DBName ref) else Nothing
xs -> error $ mconcat
[ "MySQL.getColumn/getRef: error fetching constraints. Expected a single result for foreign key query for table: "
, T.unpack (unDBName tname)
, " and column: "
, T.unpack cname
, " but got: "
, show xs
]


getColumn _ _ _ x _ =
return $ Left $ pack $ "Invalid result from INFORMATION_SCHEMA: " ++ show x

-- | Extra column information from MySQL schema
Expand Down
2 changes: 1 addition & 1 deletion persistent-mysql/persistent-mysql.cabal
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
name: persistent-mysql
version: 2.10.2.2
version: 2.10.2.3
license: MIT
license-file: LICENSE
author: Felipe Lessa <felipe.lessa@gmail.com>, Michael Snoyman
Expand Down
13 changes: 13 additions & 0 deletions persistent-mysql/test/CustomConstraintTest.hs
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,12 @@ CustomConstraint1
CustomConstraint2
cc_id CustomConstraint1Id constraint=custom_constraint
deriving Show

CustomConstraint3
-- | This will lead to a constraint with the name custom_constraint3_cc_id1_fkey
cc_id1 CustomConstraint1Id
cc_id2 CustomConstraint1Id
deriving Show
|]

specs :: (MonadIO m, MonadFail m) => RunDb SqlBackend m -> Spec
Expand All @@ -45,3 +51,10 @@ specs runDb = do
,PersistText "cc_id"
,PersistText "custom_constraint"]
liftIO $ 1 @?= (exists :: Int)
it "allows multiple constraints on a single column" $ runDb $ do
runMigration customConstraintMigrate
-- | Here we add another foreign key on the same column where the default one already exists. In practice, this could be a compound key with another field.
rawExecute "ALTER TABLE custom_constraint3 ADD CONSTRAINT extra_constraint FOREIGN KEY(cc_id1) REFERENCES custom_constraint1(id)" []
-- | This is where the error is thrown in `getColumn`
_ <- getMigration customConstraintMigrate
pure ()