From 420d8dad63c0f376064f6249d7404c2e19af8d8d Mon Sep 17 00:00:00 2001 From: crazycs Date: Sat, 24 Nov 2018 21:06:34 +0800 Subject: [PATCH] ddl: add check when add foreign key. (#8050) (#8421) --- ddl/ddl_api.go | 7 +++++-- ddl/ddl_db_test.go | 7 ++++++- 2 files changed, 11 insertions(+), 3 deletions(-) diff --git a/ddl/ddl_api.go b/ddl/ddl_api.go index 5a9dfb1fb1a6..478f1ccad23e 100644 --- a/ddl/ddl_api.go +++ b/ddl/ddl_api.go @@ -1713,13 +1713,16 @@ func (d *ddl) CreateIndex(ctx sessionctx.Context, ti ast.Ident, unique bool, ind return errors.Trace(err) } -func buildFKInfo(fkName model.CIStr, keys []*ast.IndexColName, refer *ast.ReferenceDef) (*model.FKInfo, error) { +func buildFKInfo(fkName model.CIStr, keys []*ast.IndexColName, refer *ast.ReferenceDef, cols []*table.Column) (*model.FKInfo, error) { var fkInfo model.FKInfo fkInfo.Name = fkName fkInfo.RefTable = refer.Table.Name fkInfo.Cols = make([]model.CIStr, len(keys)) for i, key := range keys { + if table.FindCol(cols, key.Column.Name.O) == nil { + return nil, errKeyColumnDoesNotExits.Gen("key column %s doesn't exist in table", key.Column.Name) + } fkInfo.Cols[i] = key.Column.Name } @@ -1747,7 +1750,7 @@ func (d *ddl) CreateForeignKey(ctx sessionctx.Context, ti ast.Ident, fkName mode return errors.Trace(infoschema.ErrTableNotExists.GenByArgs(ti.Schema, ti.Name)) } - fkInfo, err := buildFKInfo(fkName, keys, refer) + fkInfo, err := buildFKInfo(fkName, keys, refer, t.Cols()) if err != nil { return errors.Trace(err) } diff --git a/ddl/ddl_db_test.go b/ddl/ddl_db_test.go index d3cabcad097d..a7ce45e60fd7 100644 --- a/ddl/ddl_db_test.go +++ b/ddl/ddl_db_test.go @@ -1447,9 +1447,14 @@ func (s *testDBSuite) TestTableForeignKey(c *C) { s.tk = testkit.NewTestKit(c, s.store) s.tk.MustExec("use test") s.tk.MustExec("create table t1 (a int, b int);") + // test create table with foreign key. failSQL := "create table t2 (c int, foreign key (a) references t1(a));" s.testErrorCode(c, failSQL, tmysql.ErrKeyColumnDoesNotExits) - s.tk.MustExec("drop table if exists t1,t2;") + // test add foreign key. + s.tk.MustExec("create table t3 (a int, b int);") + failSQL = "alter table t1 add foreign key (c) REFERENCES t3(a);" + s.testErrorCode(c, failSQL, tmysql.ErrKeyColumnDoesNotExits) + s.tk.MustExec("drop table if exists t1,t2,t3;") } func (s *testDBSuite) TestCreateTableWithPartition(c *C) {