Skip to content

ast: Add support for CREATE TYPE as ENUM #388

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Mar 9, 2020
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
65 changes: 61 additions & 4 deletions internal/compiler/compile.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ import (
"fmt"
"io"
"os"
"regexp"
"sort"
"strings"

Expand All @@ -23,6 +24,35 @@ type Parser interface {
Parse(io.Reader) ([]ast.Statement, error)
}

// copied over from gen.go
func structName(name string) string {
out := ""
for _, p := range strings.Split(name, "_") {
if p == "id" {
out += "ID"
} else {
out += strings.Title(p)
}
}
return out
}

var identPattern = regexp.MustCompile("[^a-zA-Z0-9_]+")

func enumValueName(value string) string {
name := ""
id := strings.Replace(value, "-", "_", -1)
id = strings.Replace(id, ":", "_", -1)
id = strings.Replace(id, "/", "_", -1)
id = identPattern.ReplaceAllString(id, "")
for _, part := range strings.Split(id, "_") {
name += strings.Title(part)
}
return name
}

// end copypasta

func Run(conf config.SQL, combo config.CombinedSettings) (*Result, error) {
var p Parser

Expand Down Expand Up @@ -53,25 +83,52 @@ func Run(conf config.SQL, combo config.CombinedSettings) (*Result, error) {
}

var structs []dinosql.GoStruct
var enums []dinosql.GoEnum
for _, schema := range c.Schemas {
for _, table := range schema.Tables {
s := dinosql.GoStruct{
Table: pg.FQN{Schema: table.Rel.Schema, Rel: table.Rel.Name},
Table: pg.FQN{Schema: schema.Name, Rel: table.Rel.Name},
Name: strings.Title(table.Rel.Name),
}
for _, col := range table.Columns {
s.Fields = append(s.Fields, dinosql.GoField{
Name: strings.Title(col.Name),
Name: structName(col.Name),
Type: "string",
Tags: map[string]string{"json:": col.Name},
})
}
structs = append(structs, s)
}
for _, typ := range schema.Types {
switch t := typ.(type) {
case catalog.Enum:
var name string
// TODO: This name should be public, not main
if schema.Name == "main" {
name = t.Name
} else {
name = schema.Name + "_" + t.Name
}
e := dinosql.GoEnum{
Name: structName(name),
}
for _, v := range t.Vals {
e.Constants = append(e.Constants, dinosql.GoConstant{
Name: e.Name + enumValueName(v),
Value: v,
Type: e.Name,
})
}
enums = append(enums, e)
}
}
}

if len(structs) > 0 {
sort.Slice(structs, func(i, j int) bool { return structs[i].Name < structs[j].Name })
}

return &Result{structs: structs}, nil
if len(enums) > 0 {
sort.Slice(enums, func(i, j int) bool { return enums[i].Name < enums[j].Name })
}
return &Result{structs: structs, enums: enums}, nil
}
3 changes: 2 additions & 1 deletion internal/compiler/result.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ import (
)

type Result struct {
enums []dinosql.GoEnum
structs []dinosql.GoStruct
queries []dinosql.GoQuery
}
Expand All @@ -19,5 +20,5 @@ func (r *Result) GoQueries(settings config.CombinedSettings) []dinosql.GoQuery {
}

func (r *Result) Enums(settings config.CombinedSettings) []dinosql.GoEnum {
return nil
return r.enums
}
24 changes: 23 additions & 1 deletion internal/endtoend/testdata/experimental_elephant/go/models.go

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

2 changes: 2 additions & 0 deletions internal/endtoend/testdata/experimental_elephant/query.sql
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,8 @@ CREATE TABLE bar (
baz text NOT NULL
);

CREATE TYPE mood AS ENUM ('sad', 'ok', 'happy');

SELECT bar FROM foo;

DROP TABLE bar;
Expand Down
43 changes: 43 additions & 0 deletions internal/postgresql/parse.go
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,30 @@ func stringSlice(list nodes.List) []string {
return items
}

func parseTypeName(node nodes.Node) (*ast.TypeName, error) {
switch n := node.(type) {

case nodes.List:
parts := stringSlice(n)
switch len(parts) {
case 1:
return &ast.TypeName{
Name: parts[0],
}, nil
case 2:
return &ast.TypeName{
Schema: parts[0],
Name: parts[1],
}, nil
default:
return nil, fmt.Errorf("invalid type name: %s", join(n, "."))
}

default:
return nil, fmt.Errorf("unexpected node type: %T", n)
}
}

func parseTableName(node nodes.Node) (*ast.TableName, error) {
switch n := node.(type) {

Expand Down Expand Up @@ -180,6 +204,25 @@ func translate(node nodes.Node) (ast.Node, error) {
}
return create, nil

case nodes.CreateEnumStmt:
name, err := parseTypeName(n.TypeName)
if err != nil {
return nil, err
}
stmt := &ast.CreateEnumStmt{
TypeName: name,
Vals: &ast.List{},
}
for _, val := range n.Vals.Items {
switch v := val.(type) {
case nodes.String:
stmt.Vals.Items = append(stmt.Vals.Items, &ast.String{
Str: v.Str,
})
}
}
return stmt, nil

case nodes.DropStmt:
drop := &ast.DropTableStmt{
IfExists: n.MissingOk,
Expand Down
20 changes: 19 additions & 1 deletion internal/sql/ast/ast.go
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,15 @@ func (n *AlterTableCmd) Pos() int {
return 0
}

type CreateEnumStmt struct {
TypeName *TypeName
Vals *List
}

func (n *CreateEnumStmt) Pos() int {
return 0
}

type CreateTableStmt struct {
IfNotExists bool
Name *TableName
Expand Down Expand Up @@ -88,7 +97,8 @@ func (n *ColumnDef) Pos() int {
}

type TypeName struct {
Name string
Schema string
Name string
}

func (n *TypeName) Pos() int {
Expand Down Expand Up @@ -127,3 +137,11 @@ type ColumnRef struct {
func (n *ColumnRef) Pos() int {
return 0
}

type String struct {
Str string
}

func (n *String) Pos() int {
return 0
}
70 changes: 70 additions & 0 deletions internal/sql/catalog/catalog.go
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,8 @@ func Build(stmts []ast.Statement) (*Catalog, error) {
switch n := stmts[i].Raw.Stmt.(type) {
case *ast.AlterTableStmt:
err = c.alterTable(n)
case *ast.CreateEnumStmt:
err = c.createEnum(n)
case *ast.CreateTableStmt:
err = c.createTable(n)
case *ast.DropTableStmt:
Expand All @@ -33,8 +35,19 @@ func Build(stmts []ast.Statement) (*Catalog, error) {
return c, nil
}

func stringSlice(list *ast.List) []string {
items := []string{}
for _, item := range list.Items {
if n, ok := item.(*ast.String); ok {
items = append(items, n.Str)
}
}
return items
}

// TODO: This need to be rich error types
var ErrRelationNotFound = errors.New("relation not found")
var ErrRelationAlreadyExists = errors.New("relation already exists")
var ErrSchemaNotFound = errors.New("schema not found")
var ErrColumnNotFound = errors.New("column not found")
var ErrColumnExists = errors.New("column already exists")
Expand Down Expand Up @@ -159,6 +172,37 @@ func (c *Catalog) alterTable(stmt *ast.AlterTableStmt) error {
return nil
}

func (c *Catalog) createEnum(stmt *ast.CreateEnumStmt) error {
ns := stmt.TypeName.Schema
if ns == "" {
ns = c.DefaultSchema
}
schema, err := c.getSchema(ns)
if err != nil {
return err
}
// Because tables have associated data types, the type name must also
// be distinct from the name of any existing table in the same
// schema.
// https://www.postgresql.org/docs/current/sql-createtype.html
tbl := &ast.TableName{
Name: stmt.TypeName.Name,
}
if _, _, err := schema.getTable(tbl); err == nil {
// return wrap(pg.ErrorRelationAlreadyExists(fqn.Rel), raw.StmtLocation)
return ErrRelationAlreadyExists
}
if _, err := schema.getType(stmt.TypeName); err == nil {
// return wrap(pg.ErrorTypeAlreadyExists(fqn.Rel), raw.StmtLocation)
return ErrRelationAlreadyExists
}
schema.Types = append(schema.Types, Enum{
Name: stmt.TypeName.Name,
Vals: stringSlice(stmt.Vals),
})
return nil
}

func (c *Catalog) createTable(stmt *ast.CreateTableStmt) error {
ns := stmt.Name.Schema
if ns == "" {
Expand Down Expand Up @@ -223,9 +267,22 @@ type Catalog struct {
type Schema struct {
Name string
Tables []*Table
Types []Type
Comment string
}

func (s *Schema) getType(rel *ast.TypeName) (Type, error) {
for i := range s.Types {
switch typ := s.Types[i].(type) {
case Enum:
if typ.Name == rel.Name {
return s.Types[i], nil
}
}
}
return nil, ErrRelationNotFound
}

func (s *Schema) getTable(rel *ast.TableName) (*Table, int, error) {
for i := range s.Tables {
if s.Tables[i].Rel.Name == rel.Name {
Expand All @@ -247,3 +304,16 @@ type Column struct {
Type ast.TypeName
IsNotNull bool
}

type Type interface {
isType()
}

type Enum struct {
Name string
Vals []string
Comment string
}

func (e Enum) isType() {
}