diff --git a/internal/codegen/golang/go_type.go b/internal/codegen/golang/go_type.go index 3050cae3da..718a8dca1c 100644 --- a/internal/codegen/golang/go_type.go +++ b/internal/codegen/golang/go_type.go @@ -11,8 +11,8 @@ func goType(r *compiler.Result, col *compiler.Column, settings config.CombinedSe if oride.GoTypeName == "" { continue } - sameTable := sameTableName(col.Table, oride.Table, r.Catalog.DefaultSchema) - if oride.Column != "" && oride.ColumnName == col.Name && sameTable { + sameTable := oride.Matches(col.Table, r.Catalog.DefaultSchema) + if oride.Column != "" && oride.ColumnName.MatchString(col.Name) && sameTable { return oride.GoTypeName } } diff --git a/internal/codegen/python/gen.go b/internal/codegen/python/gen.go index a30dd3ef50..3173bc6d8f 100644 --- a/internal/codegen/python/gen.go +++ b/internal/codegen/python/gen.go @@ -181,8 +181,8 @@ func pyInnerType(r *compiler.Result, col *compiler.Column, settings config.Combi if !oride.PythonType.IsSet() { continue } - sameTable := sameTableName(col.Table, oride.Table, r.Catalog.DefaultSchema) - if oride.Column != "" && oride.ColumnName == col.Name && sameTable { + sameTable := oride.Matches(col.Table, r.Catalog.DefaultSchema) + if oride.Column != "" && oride.ColumnName.MatchString(col.Name) && sameTable { return oride.PythonType.TypeString() } if oride.DBType != "" && oride.DBType == col.DataType && oride.Nullable != (col.NotNull || col.IsArray) { diff --git a/internal/config/config.go b/internal/config/config.go index d7277226b0..90f41da815 100644 --- a/internal/config/config.go +++ b/internal/config/config.go @@ -9,7 +9,8 @@ import ( "os" "strings" - "github.com/kyleconroy/sqlc/internal/core" + "github.com/kyleconroy/sqlc/internal/sql/ast" + yaml "gopkg.in/yaml.v3" ) @@ -165,15 +166,50 @@ type Override struct { // fully qualified name of the column, e.g. `accounts.id` Column string `json:"column" yaml:"column"` - ColumnName string - Table core.FQN + ColumnName *Match + TableCatalog *Match + TableSchema *Match + TableRel *Match GoImportPath string GoPackage string GoTypeName string GoBasicType bool } -func (o *Override) Parse() error { +func (o *Override) Matches(n *ast.TableName, defaultSchema string) bool { + if n == nil { + return false + } + + schema := n.Schema + if n.Schema == "" { + schema = defaultSchema + } + + if o.TableCatalog != nil && !o.TableCatalog.MatchString(n.Catalog) { + return false + } + + if o.TableSchema == nil && schema != "" { + return false + } + + if o.TableSchema != nil && !o.TableSchema.MatchString(schema) { + return false + } + + if o.TableRel == nil && n.Name != "" { + return false + } + + if o.TableRel != nil && !o.TableRel.MatchString(n.Name) { + return false + } + + return true +} + +func (o *Override) Parse() (err error) { // validate deprecated postgres_type field if o.Deprecated_PostgresType != "" { @@ -203,16 +239,40 @@ func (o *Override) Parse() error { colParts := strings.Split(o.Column, ".") switch len(colParts) { case 2: - o.ColumnName = colParts[1] - o.Table = core.FQN{Schema: "public", Rel: colParts[0]} + if o.ColumnName, err = MatchCompile(colParts[1]); err != nil { + return err + } + if o.TableRel, err = MatchCompile(colParts[0]); err != nil { + return err + } + if o.TableSchema, err = MatchCompile("public"); err != nil { + return err + } case 3: - o.ColumnName = colParts[2] - o.Table = core.FQN{Schema: colParts[0], Rel: colParts[1]} + if o.ColumnName, err = MatchCompile(colParts[2]); err != nil { + return err + } + if o.TableRel, err = MatchCompile(colParts[1]); err != nil { + return err + } + if o.TableSchema, err = MatchCompile(colParts[0]); err != nil { + return err + } case 4: - o.ColumnName = colParts[3] - o.Table = core.FQN{Catalog: colParts[0], Schema: colParts[1], Rel: colParts[2]} + if o.ColumnName, err = MatchCompile(colParts[3]); err != nil { + return err + } + if o.TableRel, err = MatchCompile(colParts[2]); err != nil { + return err + } + if o.TableSchema, err = MatchCompile(colParts[1]); err != nil { + return err + } + if o.TableCatalog, err = MatchCompile(colParts[0]); err != nil { + return err + } default: - return fmt.Errorf("Override `column` specifier %q is not the proper format, expected '[catalog.][schema.]colname.tablename'", o.Column) + return fmt.Errorf("Override `column` specifier %q is not the proper format, expected '[catalog.][schema.]tablename.colname'", o.Column) } } diff --git a/internal/config/match.go b/internal/config/match.go new file mode 100644 index 0000000000..b0f4d534f3 --- /dev/null +++ b/internal/config/match.go @@ -0,0 +1,57 @@ +package config + +import ( + "fmt" + "regexp" +) + +// Match is a wrapper of *regexp.Regexp. +// It contains the match pattern compiled into a regular expression. +type Match struct { + *regexp.Regexp +} + +// Compile takes our match expression as a string, and compiles it into a *Match object. +// Will return an error on an invalid pattern. +func MatchCompile(pattern string) (match *Match, err error) { + regex := "" + escaped := false + arr := []byte(pattern) + + for i := 0; i < len(arr); i++ { + if escaped { + escaped = false + switch arr[i] { + case '*', '?', '\\': + regex += "\\" + string(arr[i]) + default: + return nil, fmt.Errorf("Invalid escaped character '%c'", arr[i]) + } + } else { + switch arr[i] { + case '\\': + escaped = true + case '*': + regex += ".*" + case '?': + regex += "." + case '.', '(', ')', '+', '|', '^', '$', '[', ']', '{', '}': + regex += "\\" + string(arr[i]) + default: + regex += string(arr[i]) + } + } + } + + if escaped { + return nil, fmt.Errorf("Unterminated escape at end of pattern") + } + + var r *regexp.Regexp + + if r, err = regexp.Compile("^" + regex + "$"); err != nil { + return nil, err + } + + return &Match{r}, nil +} diff --git a/internal/core/fqn.go b/internal/core/fqn.go index 5d76342787..ebdcba8c11 100644 --- a/internal/core/fqn.go +++ b/internal/core/fqn.go @@ -7,14 +7,3 @@ type FQN struct { Schema string Rel string } - -func (f FQN) String() string { - s := f.Rel - if f.Schema != "" { - s = f.Schema + "." + s - } - if f.Catalog != "" { - s = f.Catalog + "." + s - } - return s -} diff --git a/internal/endtoend/testdata/overrides_go_types/mysql/go/models.go b/internal/endtoend/testdata/overrides_go_types/mysql/go/models.go index 9b987955d3..28c4e9ecf2 100644 --- a/internal/endtoend/testdata/overrides_go_types/mysql/go/models.go +++ b/internal/endtoend/testdata/overrides_go_types/mysql/go/models.go @@ -6,6 +6,18 @@ import ( "github.com/kyleconroy/sqlc-testdata/pkg" ) +type Bar struct { + Other string + Total int64 + AlsoRetyped pkg.CustomType +} + +type Baz struct { + Other string + Total int64 + AlsoRetyped pkg.CustomType +} + type Foo struct { Other string Total int64 diff --git a/internal/endtoend/testdata/overrides_go_types/mysql/schema.sql b/internal/endtoend/testdata/overrides_go_types/mysql/schema.sql index c0c5fc47dc..31b183ad86 100644 --- a/internal/endtoend/testdata/overrides_go_types/mysql/schema.sql +++ b/internal/endtoend/testdata/overrides_go_types/mysql/schema.sql @@ -3,3 +3,15 @@ CREATE TABLE foo ( total bigint NOT NULL, retyped text NOT NULL ); + +CREATE TABLE bar ( + other text NOT NULL, + total bigint NOT NULL, + also_retyped text NOT NULL +); + +CREATE TABLE baz ( + other text NOT NULL, + total bigint NOT NULL, + also_retyped text NOT NULL +); diff --git a/internal/endtoend/testdata/overrides_go_types/mysql/sqlc.json b/internal/endtoend/testdata/overrides_go_types/mysql/sqlc.json index 592fb072a0..787bf1459f 100644 --- a/internal/endtoend/testdata/overrides_go_types/mysql/sqlc.json +++ b/internal/endtoend/testdata/overrides_go_types/mysql/sqlc.json @@ -11,6 +11,10 @@ { "go_type": "github.com/kyleconroy/sqlc-testdata/pkg.CustomType", "column": "foo.retyped" + }, + { + "go_type": "github.com/kyleconroy/sqlc-testdata/pkg.CustomType", + "column": "*.also_retyped" } ] } diff --git a/internal/endtoend/testdata/overrides_go_types/postgresql/pgx/go/models.go b/internal/endtoend/testdata/overrides_go_types/postgresql/pgx/go/models.go index 6a07b893c1..8edb1fad2d 100644 --- a/internal/endtoend/testdata/overrides_go_types/postgresql/pgx/go/models.go +++ b/internal/endtoend/testdata/overrides_go_types/postgresql/pgx/go/models.go @@ -3,13 +3,25 @@ package override import ( + "database/sql" + orm "database/sql" - "github.com/gofrs/uuid" fuid "github.com/gofrs/uuid" + uuid "github.com/gofrs/uuid" null "github.com/volatiletech/null/v8" null_v4 "gopkg.in/guregu/null.v4" ) +type Bar struct { + ID uuid.UUID + OtherID fuid.UUID + MoreID fuid.UUID + Age sql.NullInt32 + Balance interface{} + Bio sql.NullString + About sql.NullString +} + type Foo struct { ID uuid.UUID OtherID fuid.UUID diff --git a/internal/endtoend/testdata/overrides_go_types/postgresql/pgx/go/query.sql.go b/internal/endtoend/testdata/overrides_go_types/postgresql/pgx/go/query.sql.go index 84a15c92cc..b02ed88fa4 100644 --- a/internal/endtoend/testdata/overrides_go_types/postgresql/pgx/go/query.sql.go +++ b/internal/endtoend/testdata/overrides_go_types/postgresql/pgx/go/query.sql.go @@ -6,7 +6,7 @@ package override import ( "context" - "github.com/gofrs/uuid" + uuid "github.com/gofrs/uuid" ) const loadFoo = `-- name: LoadFoo :many diff --git a/internal/endtoend/testdata/overrides_go_types/postgresql/pgx/schema.sql b/internal/endtoend/testdata/overrides_go_types/postgresql/pgx/schema.sql index 4e0d1f5af7..04551110f3 100644 --- a/internal/endtoend/testdata/overrides_go_types/postgresql/pgx/schema.sql +++ b/internal/endtoend/testdata/overrides_go_types/postgresql/pgx/schema.sql @@ -6,3 +6,13 @@ CREATE TABLE foo ( bio text, about text ); + +CREATE TABLE bar ( + id uuid NOT NULL, + other_id uuid NOT NULL, + more_id uuid NOT NULL, + age integer, + balance double, + bio text, + about text +); diff --git a/internal/endtoend/testdata/overrides_go_types/postgresql/pgx/sqlc.json b/internal/endtoend/testdata/overrides_go_types/postgresql/pgx/sqlc.json index 7290df6fcd..238b937a36 100644 --- a/internal/endtoend/testdata/overrides_go_types/postgresql/pgx/sqlc.json +++ b/internal/endtoend/testdata/overrides_go_types/postgresql/pgx/sqlc.json @@ -10,14 +10,15 @@ "queries": "query.sql", "overrides": [ { - "column": "foo.id", + "column": "*.id", "go_type": { "import": "github.com/gofrs/uuid", + "package": "uuid", "type": "UUID" }, }, { - "column": "foo.other_id", + "column": "*.*_id", "go_type": { "import": "github.com/gofrs/uuid", "package": "fuid",