Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: support 'CHECK' option of CREATE TABLE command for psqldef #97

Merged
merged 1 commit into from Feb 26, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
35 changes: 24 additions & 11 deletions adapter/postgres/postgres.go
Expand Up @@ -136,6 +136,12 @@ func buildDumpTableDDL(table string, columns []column, pkeyCols, indexDefs, fore
if col.Default != "" && !col.IsAutoIncrement {
fmt.Fprintf(&queryBuilder, " DEFAULT %s", col.Default)
}
if col.Check != "" {
fmt.Fprintf(&queryBuilder, " CHECK %s", col.Check)
if col.CheckNoInherit {
fmt.Fprint(&queryBuilder, " NO INHERIT")
}
}
}
if len(pkeyCols) > 0 {
fmt.Fprint(&queryBuilder, ",\n"+indent)
Expand All @@ -162,6 +168,8 @@ type column struct {
Default string
IsAutoIncrement bool
IsUnique bool
Check string
CheckNoInherit bool
}

func (c *column) GetDataType() string {
Expand Down Expand Up @@ -195,16 +203,18 @@ func (c *column) GetDataType() string {
}

func (d *PostgresDatabase) getColumns(table string) ([]column, error) {
query := `SELECT column_name, column_default, is_nullable, character_maximum_length,
CASE WHEN data_type = 'ARRAY' THEN format_type(atttypid, atttypmod) ELSE data_type END,
CASE WHEN p.contype = 'u' THEN true ELSE false END AS uniquekey
query := `SELECT s.column_name, s.column_default, s.is_nullable, s.character_maximum_length,
CASE WHEN s.data_type = 'ARRAY' THEN format_type(f.atttypid, f.atttypmod) ELSE s.data_type END,
CASE WHEN p.contype = 'u' THEN true ELSE false END AS uniquekey,
CASE WHEN pc.contype = 'c' THEN pc.consrc ELSE NULL END AS check,
CASE WHEN pc.connoinherit = 't' THEN true ELSE false END AS no_inherit
FROM pg_attribute f
JOIN pg_class c ON c.oid = f.attrelid JOIN pg_type t ON t.oid = f.atttypid
LEFT JOIN pg_attrdef d ON d.adrelid = c.oid AND d.adnum = f.attnum
LEFT JOIN pg_namespace n ON n.oid = c.relnamespace
LEFT JOIN pg_constraint p ON p.conrelid = c.oid AND f.attnum = ANY (p.conkey)
LEFT JOIN pg_class AS g ON p.confrelid = g.oid
LEFT JOIN information_schema.columns s ON s.column_name=f.attname AND s.table_name=c.relname
LEFT JOIN pg_constraint p ON p.conrelid = c.oid AND f.attnum = ANY (p.conkey) AND p.contype = 'u'
LEFT JOIN pg_constraint pc ON pc.conrelid = c.oid AND f.attnum = ANY (pc.conkey) AND pc.contype = 'c'
LEFT JOIN information_schema.columns s ON s.column_name=f.attname AND s.table_name = c.relname
WHERE c.relkind = 'r'::char AND n.nspname = $1 AND c.relname = $2 AND f.attnum > 0 ORDER BY f.attnum;`

schema, table := splitTableName(table)
Expand All @@ -218,9 +228,9 @@ WHERE c.relkind = 'r'::char AND n.nspname = $1 AND c.relname = $2 AND f.attnum >
for rows.Next() {
col := column{}
var colName, isNullable, dataType string
var maxLenStr, colDefault *string
var isUnique bool
err = rows.Scan(&colName, &colDefault, &isNullable, &maxLenStr, &dataType, &isUnique)
var maxLenStr, colDefault, check *string
var isUnique, noInherit bool
err = rows.Scan(&colName, &colDefault, &isNullable, &maxLenStr, &dataType, &isUnique, &check, &noInherit)
if err != nil {
return nil, err
}
Expand All @@ -239,10 +249,13 @@ WHERE c.relkind = 'r'::char AND n.nspname = $1 AND c.relname = $2 AND f.attnum >
if colDefault != nil && strings.HasPrefix(*colDefault, "nextval(") {
col.IsAutoIncrement = true
}
col.Nullable = (isNullable == "YES")
col.Nullable = isNullable == "YES"
col.dataType = dataType
col.Length = maxLen

if check != nil {
col.Check = *check
col.CheckNoInherit = noInherit
}
cols = append(cols, col)
}
return cols, nil
Expand Down
60 changes: 60 additions & 0 deletions cmd/psqldef/psqldef_test.go
Expand Up @@ -245,6 +245,66 @@ func TestCreateTableWithReferences(t *testing.T) {
assertApplyOutput(t, createTableA+createTableB, nothingModified)
}

func TestCreateTableWithCheck(t *testing.T) {
resetTestDatabase()

createTable := stripHeredoc(`
CREATE TABLE a (
a_id INTEGER PRIMARY KEY CHECK (a_id > 0),
my_text TEXT UNIQUE NOT NULL
);
`,
)
assertApplyOutput(t, createTable, applyPrefix+createTable)
assertApplyOutput(t, createTable, nothingModified)

createTable = stripHeredoc(`
CREATE TABLE a (
a_id INTEGER PRIMARY KEY CHECK (a_id > 1),
my_text TEXT UNIQUE NOT NULL
);
`,
)
assertApplyOutput(t, createTable, applyPrefix+
`ALTER TABLE "public"."a" DROP CONSTRAINT a_a_id_check;`+"\n"+
`ALTER TABLE "public"."a" ADD CONSTRAINT a_a_id_check CHECK (a_id > 1);`+"\n")
assertApplyOutput(t, createTable, nothingModified)

createTable = stripHeredoc(`
CREATE TABLE a (
a_id INTEGER PRIMARY KEY,
my_text TEXT UNIQUE NOT NULL
);
`,
)
assertApplyOutput(t, createTable, applyPrefix+
`ALTER TABLE "public"."a" DROP CONSTRAINT a_a_id_check;`+"\n")
assertApplyOutput(t, createTable, nothingModified)

createTable = stripHeredoc(`
CREATE TABLE a (
a_id INTEGER PRIMARY KEY CHECK (a_id > 2) NO INHERIT,
my_text TEXT UNIQUE NOT NULL
);
`,
)
assertApplyOutput(t, createTable, applyPrefix+
`ALTER TABLE "public"."a" ADD CONSTRAINT a_a_id_check CHECK (a_id > 2) NO INHERIT;`+"\n")
assertApplyOutput(t, createTable, nothingModified)

createTable = stripHeredoc(`
CREATE TABLE a (
a_id INTEGER PRIMARY KEY CHECK (a_id > 3) NO INHERIT,
my_text TEXT UNIQUE NOT NULL
);
`,
)
assertApplyOutput(t, createTable, applyPrefix+
`ALTER TABLE "public"."a" DROP CONSTRAINT a_a_id_check;`+"\n"+
`ALTER TABLE "public"."a" ADD CONSTRAINT a_a_id_check CHECK (a_id > 3) NO INHERIT;`+"\n")
assertApplyOutput(t, createTable, nothingModified)
}

func TestCreatePolicy(t *testing.T) {
resetTestDatabase()

Expand Down
36 changes: 19 additions & 17 deletions schema/ast.go
Expand Up @@ -49,23 +49,25 @@ type Table struct {
}

type Column struct {
name string
position int
typeName string
unsigned bool
notNull *bool
autoIncrement bool
array bool
defaultVal *Value
length *Value
scale *Value
charset string
collate string
timezone bool // for Postgres `with time zone`
keyOption ColumnKeyOption
onUpdate *Value
enumValues []string
references string
name string
position int
typeName string
unsigned bool
notNull *bool
autoIncrement bool
array bool
defaultVal *Value
length *Value
scale *Value
check string
checkNoInherit bool
charset string
collate string
timezone bool // for Postgres `with time zone`
keyOption ColumnKeyOption
onUpdate *Value
enumValues []string
references string
// TODO: keyopt
// XXX: zerofill?
}
Expand Down
23 changes: 23 additions & 0 deletions schema/generator.go
Expand Up @@ -260,6 +260,21 @@ func (g *Generator) generateDDLsForCreateTable(currentTable Table, desired Creat
ddls = append(ddls, ddl)
}

if currentColumn.check != desiredColumn.check || currentColumn.checkNoInherit != desiredColumn.checkNoInherit {
constraintName := fmt.Sprintf("%s_%s_check", strings.Replace(desired.table.name, "public.", "", 1), desiredColumn.name)
if currentColumn.check != "" {
ddl := fmt.Sprintf("ALTER TABLE %s DROP CONSTRAINT %s", g.escapeTableName(desired.table.name), constraintName)
ddls = append(ddls, ddl)
}
if desiredColumn.check != "" {
ddl := fmt.Sprintf("ALTER TABLE %s ADD CONSTRAINT %s CHECK (%s)", g.escapeTableName(desired.table.name), constraintName, desiredColumn.check)
if desiredColumn.checkNoInherit {
ddl += " NO INHERIT"
}
ddls = append(ddls, ddl)
}
}

// TODO: support adding a column's `references`
// TODO: support SET/DROP NOT NULL and other properties
default:
Expand Down Expand Up @@ -620,6 +635,13 @@ func (g *Generator) generateColumnDefinition(column Column, enableUnique bool) (
definition += fmt.Sprintf("ON UPDATE %s ", string(column.onUpdate.raw))
}

if column.check != "" {
definition += fmt.Sprintf("CHECK (%s) ", column.check)
}
if column.checkNoInherit {
definition += "NO INHERIT "
}

switch column.keyOption {
case ColumnKeyNone:
// noop
Expand Down Expand Up @@ -893,6 +915,7 @@ func (g *Generator) haveSameColumnDefinition(current Column, desired Column) boo
(current.unsigned == desired.unsigned) &&
((current.notNull != nil && *current.notNull) == ((desired.notNull != nil && *desired.notNull) || desired.keyOption == ColumnKeyPrimary)) && // `PRIMARY KEY` implies `NOT NULL`
(current.timezone == desired.timezone) &&
(current.check == desired.check) &&
(desired.charset == "" || current.charset == desired.charset) && // detect change column only when set explicitly. TODO: can we calculate implicit charset?
(desired.collate == "" || current.collate == desired.collate) && // detect change column only when set explicitly. TODO: can we calculate implicit collate?
reflect.DeepEqual(current.onUpdate, desired.onUpdate)
Expand Down
4 changes: 4 additions & 0 deletions schema/parser.go
Expand Up @@ -99,6 +99,10 @@ func parseTable(mode GeneratorMode, stmt *sqlparser.DDL) Table {
enumValues: parsedCol.Type.EnumValues,
references: parsedCol.Type.References,
}
if parsedCol.Type.Check != nil {
column.check = sqlparser.String(parsedCol.Type.Check.Expr)
}
column.checkNoInherit = castBool(parsedCol.Type.CheckNoInherit)
columns = append(columns, column)
}

Expand Down
20 changes: 14 additions & 6 deletions sqlparser/ast.go
Expand Up @@ -909,12 +909,14 @@ type ColumnType struct {
Type string

// Generic field options.
NotNull *BoolVal
Autoincrement BoolVal
Default *SQLVal
OnUpdate *SQLVal
Comment *SQLVal
Array BoolVal
NotNull *BoolVal
Autoincrement BoolVal
Default *SQLVal
OnUpdate *SQLVal
Comment *SQLVal
Check *Where
CheckNoInherit BoolVal
Array BoolVal

// Numeric field options
Length *SQLVal
Expand Down Expand Up @@ -986,6 +988,12 @@ func (ct *ColumnType) Format(buf *TrackedBuffer) {
if ct.Comment != nil {
opts = append(opts, keywordStrings[COMMENT_KEYWORD], String(ct.Comment))
}
if ct.Check != nil {
opts = append(opts, keywordStrings[CHECK], String(ct.Check))
}
if ct.CheckNoInherit {
opts = append(opts, keywordStrings[NO], keywordStrings[INHERIT])
}
if ct.KeyOpt == colKeyPrimary {
opts = append(opts, keywordStrings[PRIMARY], keywordStrings[KEY])
}
Expand Down