Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

schemadiff: validate() table structure at the end of apply() #10189

Merged
18 changes: 17 additions & 1 deletion go/vt/schemadiff/diff_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -296,6 +296,17 @@ func TestDiffSchemas(t *testing.T) {
from: "create table t(id int primary key)",
to: "create table t(id int primary key)",
},
{
name: "change of table column",
from: "create table t(id int primary key, v varchar(10))",
to: "create table t(id int primary key, v varchar(20))",
diffs: []string{
"alter table t modify column v varchar(20)",
},
cdiffs: []string{
"ALTER TABLE `t` MODIFY COLUMN `v` varchar(20)",
},
},
{
name: "change of table columns, added",
from: "create table t(id int primary key)",
Expand Down Expand Up @@ -329,7 +340,7 @@ func TestDiffSchemas(t *testing.T) {
},
},
{
name: "create table (2)",
name: "create table 2",
from: ";;; ; ; ;;;",
to: "create table t(id int primary key)",
diffs: []string{
Expand Down Expand Up @@ -501,11 +512,16 @@ func TestDiffSchemas(t *testing.T) {
// Validate "apply()" on "from" converges with "to"
schema1, err := NewSchemaFromSQL(ts.from)
assert.NoError(t, err)
schema1SQL := schema1.ToSQL()

schema2, err := NewSchemaFromSQL(ts.to)
assert.NoError(t, err)
applied, err := schema1.Apply(diffs)
require.NoError(t, err)

// validate schema1 unaffected by Apply
assert.Equal(t, schema1SQL, schema1.ToSQL())

appliedDiff, err := schema2.Diff(applied, hints)
require.NoError(t, err)
assert.Empty(t, appliedDiff)
Expand Down
6 changes: 5 additions & 1 deletion go/vt/schemadiff/schema.go
Original file line number Diff line number Diff line change
Expand Up @@ -426,7 +426,11 @@ func (s *Schema) apply(diffs []EntityDiff) error {
// These diffs are CREATE/DROP/ALTER TABLE/VIEW.
// The operation does not modify this object. Instead, if successful, a new (modified) Schema is returned.
func (s *Schema) Apply(diffs []EntityDiff) (*Schema, error) {
dup, err := NewSchemaFromStatements(s.ToStatements())
// we export to queries, then import back.
// The reason we don't just clone this object's fields, or even export/import to Statements,
// is that we want this schema to be immutable an unaffected by the apply() on the duplicate.
// statements/slices/maps will have shared pointers and changes will propagate back to this schema.
dup, err := NewSchemaFromQueries(s.ToQueries())
if err != nil {
return nil, err
}
Expand Down
108 changes: 104 additions & 4 deletions go/vt/schemadiff/table.go
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,8 @@ package schemadiff

import (
"fmt"
"math"
"sort"
"strconv"
"strings"

Expand Down Expand Up @@ -869,9 +871,37 @@ func (c *CreateTableEntity) Drop() EntityDiff {
return &DropTableEntityDiff{from: c, dropTable: dropTable}
}

func sortAlterOptions(diff *AlterTableEntityDiff) {
optionOrder := func(opt sqlparser.AlterOption) int {
switch opt.(type) {
case *sqlparser.DropKey:
return 1
case *sqlparser.DropColumn:
return 2
case *sqlparser.ModifyColumn:
return 3
case *sqlparser.AddColumns:
return 4
case *sqlparser.AddIndexDefinition:
return 5
case *sqlparser.AddConstraintDefinition:
return 6
case sqlparser.TableOptions, *sqlparser.TableOptions:
return 7
default:
return math.MaxInt
}
}
opts := diff.alterTable.AlterOptions
sort.SliceStable(opts, func(i, j int) bool {
return optionOrder(opts[i]) < optionOrder(opts[j])
})
}

// apply attempts to apply an ALTER TABLE diff onto this entity's table definition.
// supported modifications are only those created by schemadiff's Diff() function.
func (c *CreateTableEntity) apply(diff *AlterTableEntityDiff) error {
sortAlterOptions(diff)
if spec := diff.alterTable.PartitionSpec; spec != nil {
switch {
case spec.Action == sqlparser.RemoveAction && spec.IsAll:
Expand Down Expand Up @@ -930,6 +960,11 @@ func (c *CreateTableEntity) apply(diff *AlterTableEntityDiff) error {
return nil
}

columnExists := map[string]bool{}
for _, col := range c.CreateTable.TableSpec.Columns {
columnExists[col.Name.String()] = true
}

// apply a single AlterOption; only supported types are those generated by Diff()
applyAlterOption := func(opt sqlparser.AlterOption) error {
switch opt := opt.(type) {
Expand Down Expand Up @@ -962,9 +997,16 @@ func (c *CreateTableEntity) apply(diff *AlterTableEntityDiff) error {
}
case *sqlparser.AddIndexDefinition:
// validate no existing key by same name
keyName := opt.IndexDefinition.Info.Name.String()
for _, index := range c.TableSpec.Indexes {
if index.Info.Name.String() == opt.IndexDefinition.Info.Name.String() {
return errors.Wrap(ErrApplyDuplicateKey, opt.IndexDefinition.Info.Name.String())
if index.Info.Name.String() == keyName {
return errors.Wrap(ErrApplyDuplicateKey, keyName)
}
}
for _, col := range opt.IndexDefinition.Columns {
colName := col.Column.String()
if !columnExists[colName] {
return errors.Wrapf(ErrInvalidColumnInKey, "key: %v, column: %v", keyName, colName)
}
}
c.TableSpec.Indexes = append(c.TableSpec.Indexes, opt.IndexDefinition)
Expand All @@ -979,8 +1021,9 @@ func (c *CreateTableEntity) apply(diff *AlterTableEntityDiff) error {
case *sqlparser.DropColumn:
// we expect the column to exist
found := false
colName := opt.Name.Name.String()
for i, col := range c.TableSpec.Columns {
if col.Name.String() == opt.Name.Name.String() {
if col.Name.String() == colName {
found = true
c.TableSpec.Columns = append(c.TableSpec.Columns[0:i], c.TableSpec.Columns[i+1:]...)
break
Expand All @@ -989,15 +1032,17 @@ func (c *CreateTableEntity) apply(diff *AlterTableEntityDiff) error {
if !found {
return errors.Wrap(ErrApplyColumnNotFound, opt.Name.Name.String())
}
delete(columnExists, colName)
case *sqlparser.AddColumns:
if len(opt.Columns) != 1 {
// our Diff only ever generates a single column per AlterOption
return errors.Wrap(ErrUnsupportedApplyOperation, sqlparser.String(opt))
}
// validate no column by same name
addedCol := opt.Columns[0]
colName := addedCol.Name.String()
for _, col := range c.TableSpec.Columns {
if col.Name.String() == addedCol.Name.String() {
if col.Name.String() == colName {
return errors.Wrap(ErrApplyDuplicateColumn, addedCol.Name.String())
}
}
Expand All @@ -1006,6 +1051,7 @@ func (c *CreateTableEntity) apply(diff *AlterTableEntityDiff) error {
if err := reorderColumn(len(c.TableSpec.Columns)-1, opt.First, opt.After); err != nil {
return err
}
columnExists[colName] = true
case *sqlparser.ModifyColumn:
// we expect the column to exist
found := false
Expand Down Expand Up @@ -1053,6 +1099,12 @@ func (c *CreateTableEntity) apply(diff *AlterTableEntityDiff) error {
return err
}
}
if err := c.postApplyNormalize(); err != nil {
return err
}
if err := c.validate(); err != nil {
return err
}
return nil
}

Expand Down Expand Up @@ -1090,3 +1142,51 @@ func (c *CreateTableEntity) Apply(diff EntityDiff) (Entity, error) {
}
return dup, nil
}

// postApplyNormalize runs at the end of apply() and to reorganize/edit things that
// a MySQL will do implicitly:
// - edit or remove keys if referenced columns are dropped
func (c *CreateTableEntity) postApplyNormalize() error {
// reduce or remove keys based on existing column list
// (a column may have been removed)postApplyNormalize
columnExists := map[string]bool{}
for _, col := range c.CreateTable.TableSpec.Columns {
columnExists[col.Name.String()] = true
}
nonEmptyIndexes := []*sqlparser.IndexDefinition{}
for _, key := range c.CreateTable.TableSpec.Indexes {
existingColumns := []*sqlparser.IndexColumn{}
for _, col := range key.Columns {
colName := col.Column.String()
if columnExists[colName] {
existingColumns = append(existingColumns, col)
}
}
if len(existingColumns) > 0 {
key.Columns = existingColumns
nonEmptyIndexes = append(nonEmptyIndexes, key)
}
}
c.CreateTable.TableSpec.Indexes = nonEmptyIndexes

return nil
}

// validate checks that the table structure is valid:
// - all columns referenced by keys exist
func (c *CreateTableEntity) validate() error {
// validate all columns referenced by indexes do in fact exist
columnExists := map[string]bool{}
for _, col := range c.CreateTable.TableSpec.Columns {
columnExists[col.Name.String()] = true
}
for _, key := range c.CreateTable.TableSpec.Indexes {
for _, col := range key.Columns {
colName := col.Column.String()
if !columnExists[colName] {
return errors.Wrapf(ErrInvalidColumnInKey, "key: %v, column: %v", key.Info.Name.String(), colName)
}
}
}
return nil
}
119 changes: 119 additions & 0 deletions go/vt/schemadiff/table_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@ limitations under the License.
package schemadiff

import (
"errors"
"testing"

"github.com/stretchr/testify/assert"
Expand Down Expand Up @@ -736,3 +737,121 @@ func TestCreateTableDiff(t *testing.T) {
})
}
}

func TestValidate(t *testing.T) {
tt := []struct {
name string
from string
to string
alter string
expectErr error
}{
{
name: "add column",
from: "create table t (id int primary key)",
alter: "alter table t add column i int",
to: "create table t (id int primary key, i int)",
},
{
name: "add key",
from: "create table t (id int primary key, i int)",
alter: "alter table t add key i_idx(i)",
to: "create table t (id int primary key, i int, key i_idx(i))",
},
{
name: "add column and key",
from: "create table t (id int primary key)",
alter: "alter table t add column i int, add key i_idx(i)",
to: "create table t (id int primary key, i int, key i_idx(i))",
},
{
name: "add key, missing column",
from: "create table t (id int primary key, i int)",
alter: "alter table t add key j_idx(j)",
expectErr: ErrInvalidColumnInKey,
},
{
name: "add key, missing column 2",
from: "create table t (id int primary key, i int)",
alter: "alter table t add key j_idx(j, i)",
expectErr: ErrInvalidColumnInKey,
},
{
name: "drop column, ok",
from: "create table t (id int primary key, i int, i2 int, key i_idx(i))",
alter: "alter table t drop column i2",
to: "create table t (id int primary key, i int, key i_idx(i))",
},
{
name: "drop column, affect keys",
from: "create table t (id int primary key, i int, key i_idx(i))",
alter: "alter table t drop column i",
to: "create table t (id int primary key)",
},
{
name: "drop column, affect keys 2",
from: "create table t (id int primary key, i int, i2 int, key i_idx(i, i2))",
alter: "alter table t drop column i",
to: "create table t (id int primary key, i2 int, key i_idx(i2))",
},
{
name: "drop column, affect keys 3",
from: "create table t (id int primary key, i int, i2 int, key i_idx(i, i2))",
alter: "alter table t drop column i2",
to: "create table t (id int primary key, i int, key i_idx(i))",
},
{
name: "drop column, affect keys 4",
from: "create table t (id int primary key, i int, i2 int, key some_key(id, i), key i_idx(i, i2))",
alter: "alter table t drop column i2",
to: "create table t (id int primary key, i int, key some_key(id, i), key i_idx(i))",
},
{
name: "add multiple keys, multi columns, ok",
from: "create table t (id int primary key, i1 int, i2 int, i3 int)",
alter: "alter table t add key i12_idx(i1, i2), add key i32_idx(i3, i2), add key i21_idx(i2, i1)",
to: "create table t (id int primary key, i1 int, i2 int, i3 int, key i12_idx(i1, i2), key i32_idx(i3, i2), key i21_idx(i2, i1))",
},
{
name: "add multiple keys, multi columns, missing column",
from: "create table t (id int primary key, i1 int, i2 int, i4 int)",
alter: "alter table t add key i12_idx(i1, i2), add key i32_idx(i3, i2), add key i21_idx(i2, i1)",
expectErr: ErrInvalidColumnInKey,
},
}
hints := DiffHints{}
for _, ts := range tt {
t.Run(ts.name, func(t *testing.T) {
stmt, err := sqlparser.Parse(ts.from)
require.NoError(t, err)
fromCreateTable, ok := stmt.(*sqlparser.CreateTable)
require.True(t, ok)

stmt, err = sqlparser.Parse(ts.alter)
require.NoError(t, err)
alterTable, ok := stmt.(*sqlparser.AlterTable)
require.True(t, ok)

from := NewCreateTableEntity(fromCreateTable)
a := &AlterTableEntityDiff{from: from, alterTable: alterTable}
applied, err := from.Apply(a)
if ts.expectErr != nil {
assert.Error(t, err)
assert.True(t, errors.Is(err, ts.expectErr))
} else {
assert.NoError(t, err)
assert.NotNil(t, applied)

stmt, err := sqlparser.Parse(ts.to)
require.NoError(t, err)
toCreateTable, ok := stmt.(*sqlparser.CreateTable)
require.True(t, ok)

to := NewCreateTableEntity(toCreateTable)
diff, err := applied.Diff(to, &hints)
require.NoError(t, err)
assert.Empty(t, diff, "diff found: %v.\applied: %v\nto: %v", diff.CanonicalStatementString(), applied.Create().CanonicalStatementString(), to.Create().CanonicalStatementString())
}
})
}
}
2 changes: 2 additions & 0 deletions go/vt/schemadiff/types.go
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,8 @@ var (
ErrApplyDuplicateKey = errors.New("duplicate key")
ErrApplyDuplicateColumn = errors.New("duplicate column")
ErrApplyDuplicateConstraint = errors.New("duplicate constraint")

ErrInvalidColumnInKey = errors.New("invalid column referenced by key")
)

// Entity stands for a database object we can diff:
Expand Down