Skip to content

Commit

Permalink
Fix multiple fk issue with MySQL migrations (#1025)
Browse files Browse the repository at this point in the history
* Added failing test case

* Ported fix for multiple foreign keys

* Bump version number

* Update changelog
  • Loading branch information
robbassi committed Feb 1, 2020
1 parent 72f7618 commit 09d9545
Show file tree
Hide file tree
Showing 4 changed files with 71 additions and 34 deletions.
4 changes: 4 additions & 0 deletions persistent-mysql/ChangeLog.md
Original file line number Diff line number Diff line change
@@ -1,5 +1,9 @@
# Changelog for persistent-mysql

## 2.10.2.3

* Fix issue with multiple foreign keys on single column. [#1025](https://github.com/yesodweb/persistent/pull/1025)

## 2.10.2.2

* Compatibility with latest persistent-template for test suite [#1002](https://github.com/yesodweb/persistent/pull/1002/files)
Expand Down
86 changes: 53 additions & 33 deletions persistent-mysql/Database/Persist/MySQL.hs
Original file line number Diff line number Diff line change
Expand Up @@ -330,8 +330,8 @@ migrate' :: MySQL.ConnectInfo
-> IO (Either [Text] [(Bool, Text)])
migrate' connectInfo allDefs getter val = do
let name = entityDB val
(idClmn, old) <- getColumns connectInfo getter val
let (newcols, udefs, fdefs) = mkColumns allDefs val
(idClmn, old) <- getColumns connectInfo getter val newcols
let udspair = map udToPair udefs
case (idClmn, old, partitionEithers old) of
-- Nothing found, create everything
Expand Down Expand Up @@ -467,11 +467,11 @@ udToPair ud = (uniqueDBName ud, map snd $ uniqueFields ud)
-- in the database.
getColumns :: MySQL.ConnectInfo
-> (Text -> IO Statement)
-> EntityDef
-> EntityDef -> [Column]
-> IO ( [Either Text (Either Column (DBName, [DBName]))] -- ID column
, [Either Text (Either Column (DBName, [DBName]))] -- everything else
)
getColumns connectInfo getter def = do
getColumns connectInfo getter def cols = do
-- Find out ID column.
stmtIdClmn <- getter $ T.concat
[ "SELECT COLUMN_NAME, "
Expand Down Expand Up @@ -522,15 +522,22 @@ getColumns connectInfo getter def = do
-- Return both
return (ids, cs ++ us)
where
refMap = Map.fromList $ foldl ref [] cols
where ref rs c = case cReference c of
Nothing -> rs
(Just r) -> (unDBName $ cName c, r) : rs
vals = [ PersistText $ pack $ MySQL.connectDatabase connectInfo
, PersistText $ unDBName $ entityDB def
, PersistText $ unDBName $ fieldDB $ entityId def ]

helperClmns = CL.mapM getIt .| CL.consume
where
getIt = fmap (either Left (Right . Left)) .
liftIO .
getColumn connectInfo getter (entityDB def)
getIt row = fmap (either Left (Right . Left)) .
liftIO .
getColumn connectInfo getter (entityDB def) row $ ref
where ref = case row of
(PersistText cname : _) -> (Map.lookup cname refMap)
_ -> Nothing

helperCntrs = do
let check [ PersistText cntrName
Expand All @@ -546,6 +553,7 @@ getColumn :: MySQL.ConnectInfo
-> (Text -> IO Statement)
-> DBName
-> [PersistValue]
-> Maybe (DBName, DBName)
-> IO (Either Text Column)
getColumn connectInfo getter tname [ PersistText cname
, PersistText null_
Expand All @@ -554,7 +562,7 @@ getColumn connectInfo getter tname [ PersistText cname
, colMaxLen
, colPrecision
, colScale
, default'] =
, default'] refName =
fmap (either (Left . pack) Right) $
runExceptT $ do
-- Default value
Expand All @@ -569,30 +577,7 @@ getColumn connectInfo getter tname [ PersistText cname
Right t -> return (Just t)
_ -> fail $ "Invalid default column: " ++ show default'

-- Foreign key (if any)
stmt <- lift . getter $ T.concat
[ "SELECT REFERENCED_TABLE_NAME, "
, "CONSTRAINT_NAME, "
, "ORDINAL_POSITION "
, "FROM INFORMATION_SCHEMA.KEY_COLUMN_USAGE "
, "WHERE TABLE_SCHEMA = ? "
, "AND TABLE_NAME = ? "
, "AND COLUMN_NAME = ? "
, "AND REFERENCED_TABLE_SCHEMA = ? "
, "ORDER BY CONSTRAINT_NAME, "
, "COLUMN_NAME"
]
let vars = [ PersistText $ pack $ MySQL.connectDatabase connectInfo
, PersistText $ unDBName $ tname
, PersistText cname
, PersistText $ pack $ MySQL.connectDatabase connectInfo ]
cntrs <- liftIO $ with (stmtQuery stmt vars) (\src -> runConduit $ src .| CL.consume)
ref <- case cntrs of
[] -> return Nothing
[[PersistText tab, PersistText ref, PersistInt64 pos]] ->
return $ if pos == 1 then Just (DBName tab, DBName ref) else Nothing
_ -> fail "MySQL.getColumn/getRef: never here"

ref <- getRef refName
let colMaxLen' = case colMaxLen of
PersistInt64 l -> Just (fromIntegral l)
_ -> Nothing
Expand All @@ -613,8 +598,43 @@ getColumn connectInfo getter tname [ PersistText cname
, cMaxLen = maxLen
, cReference = ref
}

getColumn _ _ _ x =
where getRef Nothing = return Nothing
getRef (Just (_, refName')) = do
-- Foreign key (if any)
stmt <- lift . getter $ T.concat
[ "SELECT REFERENCED_TABLE_NAME, "
, "CONSTRAINT_NAME, "
, "ORDINAL_POSITION "
, "FROM INFORMATION_SCHEMA.KEY_COLUMN_USAGE "
, "WHERE TABLE_SCHEMA = ? "
, "AND TABLE_NAME = ? "
, "AND COLUMN_NAME = ? "
, "AND REFERENCED_TABLE_SCHEMA = ? "
, "AND CONSTRAINT_NAME = ? "
, "ORDER BY CONSTRAINT_NAME, "
, "COLUMN_NAME"
]
let vars = [ PersistText $ pack $ MySQL.connectDatabase connectInfo
, PersistText $ unDBName $ tname
, PersistText cname
, PersistText $ pack $ MySQL.connectDatabase connectInfo
, PersistText $ unDBName refName' ]
cntrs <- liftIO $ with (stmtQuery stmt vars) (\src -> runConduit $ src .| CL.consume)
case cntrs of
[] -> return Nothing
[[PersistText tab, PersistText ref, PersistInt64 pos]] ->
return $ if pos == 1 then Just (DBName tab, DBName ref) else Nothing
xs -> error $ mconcat
[ "MySQL.getColumn/getRef: error fetching constraints. Expected a single result for foreign key query for table: "
, T.unpack (unDBName tname)
, " and column: "
, T.unpack cname
, " but got: "
, show xs
]


getColumn _ _ _ x _ =
return $ Left $ pack $ "Invalid result from INFORMATION_SCHEMA: " ++ show x

-- | Extra column information from MySQL schema
Expand Down
2 changes: 1 addition & 1 deletion persistent-mysql/persistent-mysql.cabal
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
name: persistent-mysql
version: 2.10.2.2
version: 2.10.2.3
license: MIT
license-file: LICENSE
author: Felipe Lessa <felipe.lessa@gmail.com>, Michael Snoyman
Expand Down
13 changes: 13 additions & 0 deletions persistent-mysql/test/CustomConstraintTest.hs
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,12 @@ CustomConstraint1
CustomConstraint2
cc_id CustomConstraint1Id constraint=custom_constraint
deriving Show

CustomConstraint3
-- | This will lead to a constraint with the name custom_constraint3_cc_id1_fkey
cc_id1 CustomConstraint1Id
cc_id2 CustomConstraint1Id
deriving Show
|]

specs :: (MonadIO m, MonadFail m) => RunDb SqlBackend m -> Spec
Expand All @@ -45,3 +51,10 @@ specs runDb = do
,PersistText "cc_id"
,PersistText "custom_constraint"]
liftIO $ 1 @?= (exists :: Int)
it "allows multiple constraints on a single column" $ runDb $ do
runMigration customConstraintMigrate
-- | Here we add another foreign key on the same column where the default one already exists. In practice, this could be a compound key with another field.
rawExecute "ALTER TABLE custom_constraint3 ADD CONSTRAINT extra_constraint FOREIGN KEY(cc_id1) REFERENCES custom_constraint1(id)" []
-- | This is where the error is thrown in `getColumn`
_ <- getMigration customConstraintMigrate
pure ()

0 comments on commit 09d9545

Please sign in to comment.