diff --git a/internal/compiler/compile.go b/internal/compiler/compile.go index 287a92dd7b..aacdc01e59 100644 --- a/internal/compiler/compile.go +++ b/internal/compiler/compile.go @@ -6,6 +6,7 @@ import ( "fmt" "io" "os" + "regexp" "sort" "strings" @@ -23,6 +24,35 @@ type Parser interface { Parse(io.Reader) ([]ast.Statement, error) } +// copied over from gen.go +func structName(name string) string { + out := "" + for _, p := range strings.Split(name, "_") { + if p == "id" { + out += "ID" + } else { + out += strings.Title(p) + } + } + return out +} + +var identPattern = regexp.MustCompile("[^a-zA-Z0-9_]+") + +func enumValueName(value string) string { + name := "" + id := strings.Replace(value, "-", "_", -1) + id = strings.Replace(id, ":", "_", -1) + id = strings.Replace(id, "/", "_", -1) + id = identPattern.ReplaceAllString(id, "") + for _, part := range strings.Split(id, "_") { + name += strings.Title(part) + } + return name +} + +// end copypasta + func Run(conf config.SQL, combo config.CombinedSettings) (*Result, error) { var p Parser @@ -53,25 +83,52 @@ func Run(conf config.SQL, combo config.CombinedSettings) (*Result, error) { } var structs []dinosql.GoStruct + var enums []dinosql.GoEnum for _, schema := range c.Schemas { for _, table := range schema.Tables { s := dinosql.GoStruct{ - Table: pg.FQN{Schema: table.Rel.Schema, Rel: table.Rel.Name}, + Table: pg.FQN{Schema: schema.Name, Rel: table.Rel.Name}, Name: strings.Title(table.Rel.Name), } for _, col := range table.Columns { s.Fields = append(s.Fields, dinosql.GoField{ - Name: strings.Title(col.Name), + Name: structName(col.Name), Type: "string", Tags: map[string]string{"json:": col.Name}, }) } structs = append(structs, s) } + for _, typ := range schema.Types { + switch t := typ.(type) { + case catalog.Enum: + var name string + // TODO: This name should be public, not main + if schema.Name == "main" { + name = t.Name + } else { + name = schema.Name + "_" + t.Name + } + e := dinosql.GoEnum{ + Name: structName(name), + } + for _, v := range t.Vals { + e.Constants = append(e.Constants, dinosql.GoConstant{ + Name: e.Name + enumValueName(v), + Value: v, + Type: e.Name, + }) + } + enums = append(enums, e) + } + } } + if len(structs) > 0 { sort.Slice(structs, func(i, j int) bool { return structs[i].Name < structs[j].Name }) } - - return &Result{structs: structs}, nil + if len(enums) > 0 { + sort.Slice(enums, func(i, j int) bool { return enums[i].Name < enums[j].Name }) + } + return &Result{structs: structs, enums: enums}, nil } diff --git a/internal/compiler/result.go b/internal/compiler/result.go index 950f478c9c..563b51024d 100644 --- a/internal/compiler/result.go +++ b/internal/compiler/result.go @@ -6,6 +6,7 @@ import ( ) type Result struct { + enums []dinosql.GoEnum structs []dinosql.GoStruct queries []dinosql.GoQuery } @@ -19,5 +20,5 @@ func (r *Result) GoQueries(settings config.CombinedSettings) []dinosql.GoQuery { } func (r *Result) Enums(settings config.CombinedSettings) []dinosql.GoEnum { - return nil + return r.enums } diff --git a/internal/endtoend/testdata/experimental_elephant/go/models.go b/internal/endtoend/testdata/experimental_elephant/go/models.go index 518b1bade7..397456ec7f 100644 --- a/internal/endtoend/testdata/experimental_elephant/go/models.go +++ b/internal/endtoend/testdata/experimental_elephant/go/models.go @@ -2,7 +2,29 @@ package querytest -import () +import ( + "fmt" +) + +type Mood string + +const ( + MoodSad Mood = "sad" + MoodOk Mood = "ok" + MoodHappy Mood = "happy" +) + +func (e *Mood) Scan(src interface{}) error { + switch s := src.(type) { + case []byte: + *e = Mood(s) + case string: + *e = Mood(s) + default: + return fmt.Errorf("unsupported scan type for Mood: %T", src) + } + return nil +} type Baz struct { Name string diff --git a/internal/endtoend/testdata/experimental_elephant/query.sql b/internal/endtoend/testdata/experimental_elephant/query.sql index b02a022957..5eaecd2223 100644 --- a/internal/endtoend/testdata/experimental_elephant/query.sql +++ b/internal/endtoend/testdata/experimental_elephant/query.sql @@ -6,6 +6,8 @@ CREATE TABLE bar ( baz text NOT NULL ); +CREATE TYPE mood AS ENUM ('sad', 'ok', 'happy'); + SELECT bar FROM foo; DROP TABLE bar; diff --git a/internal/postgresql/parse.go b/internal/postgresql/parse.go index 70a76fe841..4e6e5d9bc0 100644 --- a/internal/postgresql/parse.go +++ b/internal/postgresql/parse.go @@ -22,6 +22,30 @@ func stringSlice(list nodes.List) []string { return items } +func parseTypeName(node nodes.Node) (*ast.TypeName, error) { + switch n := node.(type) { + + case nodes.List: + parts := stringSlice(n) + switch len(parts) { + case 1: + return &ast.TypeName{ + Name: parts[0], + }, nil + case 2: + return &ast.TypeName{ + Schema: parts[0], + Name: parts[1], + }, nil + default: + return nil, fmt.Errorf("invalid type name: %s", join(n, ".")) + } + + default: + return nil, fmt.Errorf("unexpected node type: %T", n) + } +} + func parseTableName(node nodes.Node) (*ast.TableName, error) { switch n := node.(type) { @@ -180,6 +204,25 @@ func translate(node nodes.Node) (ast.Node, error) { } return create, nil + case nodes.CreateEnumStmt: + name, err := parseTypeName(n.TypeName) + if err != nil { + return nil, err + } + stmt := &ast.CreateEnumStmt{ + TypeName: name, + Vals: &ast.List{}, + } + for _, val := range n.Vals.Items { + switch v := val.(type) { + case nodes.String: + stmt.Vals.Items = append(stmt.Vals.Items, &ast.String{ + Str: v.Str, + }) + } + } + return stmt, nil + case nodes.DropStmt: drop := &ast.DropTableStmt{ IfExists: n.MissingOk, diff --git a/internal/sql/ast/ast.go b/internal/sql/ast/ast.go index fb5ef1cd99..3675466ba8 100644 --- a/internal/sql/ast/ast.go +++ b/internal/sql/ast/ast.go @@ -57,6 +57,15 @@ func (n *AlterTableCmd) Pos() int { return 0 } +type CreateEnumStmt struct { + TypeName *TypeName + Vals *List +} + +func (n *CreateEnumStmt) Pos() int { + return 0 +} + type CreateTableStmt struct { IfNotExists bool Name *TableName @@ -88,7 +97,8 @@ func (n *ColumnDef) Pos() int { } type TypeName struct { - Name string + Schema string + Name string } func (n *TypeName) Pos() int { @@ -127,3 +137,11 @@ type ColumnRef struct { func (n *ColumnRef) Pos() int { return 0 } + +type String struct { + Str string +} + +func (n *String) Pos() int { + return 0 +} diff --git a/internal/sql/catalog/catalog.go b/internal/sql/catalog/catalog.go index aa2c0332e9..0ec54a3783 100644 --- a/internal/sql/catalog/catalog.go +++ b/internal/sql/catalog/catalog.go @@ -21,6 +21,8 @@ func Build(stmts []ast.Statement) (*Catalog, error) { switch n := stmts[i].Raw.Stmt.(type) { case *ast.AlterTableStmt: err = c.alterTable(n) + case *ast.CreateEnumStmt: + err = c.createEnum(n) case *ast.CreateTableStmt: err = c.createTable(n) case *ast.DropTableStmt: @@ -33,8 +35,19 @@ func Build(stmts []ast.Statement) (*Catalog, error) { return c, nil } +func stringSlice(list *ast.List) []string { + items := []string{} + for _, item := range list.Items { + if n, ok := item.(*ast.String); ok { + items = append(items, n.Str) + } + } + return items +} + // TODO: This need to be rich error types var ErrRelationNotFound = errors.New("relation not found") +var ErrRelationAlreadyExists = errors.New("relation already exists") var ErrSchemaNotFound = errors.New("schema not found") var ErrColumnNotFound = errors.New("column not found") var ErrColumnExists = errors.New("column already exists") @@ -159,6 +172,37 @@ func (c *Catalog) alterTable(stmt *ast.AlterTableStmt) error { return nil } +func (c *Catalog) createEnum(stmt *ast.CreateEnumStmt) error { + ns := stmt.TypeName.Schema + if ns == "" { + ns = c.DefaultSchema + } + schema, err := c.getSchema(ns) + if err != nil { + return err + } + // Because tables have associated data types, the type name must also + // be distinct from the name of any existing table in the same + // schema. + // https://www.postgresql.org/docs/current/sql-createtype.html + tbl := &ast.TableName{ + Name: stmt.TypeName.Name, + } + if _, _, err := schema.getTable(tbl); err == nil { + // return wrap(pg.ErrorRelationAlreadyExists(fqn.Rel), raw.StmtLocation) + return ErrRelationAlreadyExists + } + if _, err := schema.getType(stmt.TypeName); err == nil { + // return wrap(pg.ErrorTypeAlreadyExists(fqn.Rel), raw.StmtLocation) + return ErrRelationAlreadyExists + } + schema.Types = append(schema.Types, Enum{ + Name: stmt.TypeName.Name, + Vals: stringSlice(stmt.Vals), + }) + return nil +} + func (c *Catalog) createTable(stmt *ast.CreateTableStmt) error { ns := stmt.Name.Schema if ns == "" { @@ -223,9 +267,22 @@ type Catalog struct { type Schema struct { Name string Tables []*Table + Types []Type Comment string } +func (s *Schema) getType(rel *ast.TypeName) (Type, error) { + for i := range s.Types { + switch typ := s.Types[i].(type) { + case Enum: + if typ.Name == rel.Name { + return s.Types[i], nil + } + } + } + return nil, ErrRelationNotFound +} + func (s *Schema) getTable(rel *ast.TableName) (*Table, int, error) { for i := range s.Tables { if s.Tables[i].Rel.Name == rel.Name { @@ -247,3 +304,16 @@ type Column struct { Type ast.TypeName IsNotNull bool } + +type Type interface { + isType() +} + +type Enum struct { + Name string + Vals []string + Comment string +} + +func (e Enum) isType() { +}