diff --git a/checks.go b/checks.go index e9bad2018c..9fbdc53769 100644 --- a/checks.go +++ b/checks.go @@ -6,6 +6,15 @@ import ( nodes "github.com/lfittl/pg_query_go/nodes" ) +type Error struct { + Message string + Code string +} + +func (e Error) Error() string { + return e.Message +} + func validateParamRef(n nodes.Node) error { var allrefs []nodes.ParamRef @@ -24,7 +33,10 @@ func validateParamRef(n nodes.Node) error { for i := 1; i <= len(seen); i += 1 { if _, ok := seen[i]; !ok { - return fmt.Errorf("missing parameter reference: $%d", i) + return Error{ + Code: "42P18", + Message: fmt.Sprintf("could not determine data type of parameter $%d", i), + } } } diff --git a/checks_test.go b/checks_test.go index a65970417d..56ce0159b2 100644 --- a/checks_test.go +++ b/checks_test.go @@ -1,43 +1,54 @@ package dinosql import ( - "fmt" "testing" "github.com/google/go-cmp/cmp" - pg "github.com/lfittl/pg_query_go" ) -func TestValidateParamRef(t *testing.T) { - // equateErrorMessage reports errors to be equal if both are nil - // or both have the same message. - equateErrorMessage := cmp.Comparer(func(x, y error) bool { - if x == nil || y == nil { - return x == nil && y == nil - } - return x.Error() == y.Error() - }) - +func TestParserErrors(t *testing.T) { for _, tc := range []struct { query string - err error + err Error }{ { "SELECT foo FROM bar WHERE baz = $4;", - fmt.Errorf("missing parameter reference: $1"), + Error{Code: "42P18", Message: "could not determine data type of parameter $1"}, + }, + { + "SELECT foo FROM bar WHERE baz = $1 AND baz = $3;", + Error{Code: "42P18", Message: "could not determine data type of parameter $2"}, + }, + { + "ALTER TABLE unknown RENAME TO known;", + Error{Code: "42P01", Message: "relation \"unknown\" does not exist"}, }, { - "SELECT foo FROM bar WHERE baz = $1;", - nil, + "ALTER TABLE unknown DROP COLUMN dropped;", + Error{Code: "42P01", Message: "relation \"unknown\" does not exist"}, + }, + { + ` + CREATE TABLE bar (id serial not null); + + -- name: foo :one + SELECT foo FROM bar; + `, + Error{Code: "42703", Message: "column \"foo\" does not exist"}, }, } { - tree, err := pg.Parse(tc.query) - if err != nil { - t.Fatal(err) - } - actual := validateParamRef(tree.Statements[0]) - if diff := cmp.Diff(tc.err, actual, equateErrorMessage); diff != "" { - t.Errorf("error mismatch: \n%s", diff) - } + test := tc + t.Run(test.query, func(t *testing.T) { + _, err := parseSQL(test.query) + + var actual Error + if err != nil { + actual = err.(Error) + } + + if diff := cmp.Diff(test.err, actual); diff != "" { + t.Errorf("error mismatch: \n%s", diff) + } + }) } } diff --git a/parser.go b/parser.go index 6de1c6d552..82879b77c6 100644 --- a/parser.go +++ b/parser.go @@ -26,11 +26,15 @@ func parseSQL(in string) (*Result, error) { if err != nil { return nil, err } - parse(&s, tree) + if err := parse(&s, tree); err != nil { + return nil, err + } var q []Query r := Result{Schema: &s} - parseFuncs(&s, &r, in, tree) + if err := parseFuncs(&s, &r, in, tree); err != nil { + return nil, err + } q = append(q, r.Queries...) return &Result{Schema: &s, Queries: q}, nil @@ -56,12 +60,14 @@ func ParseSchmea(dir string) (*postgres.Schema, error) { if err != nil { return nil, err } - parse(&s, tree) + if err := parse(&s, tree); err != nil { + return nil, err + } } return &s, nil } -func parse(s *postgres.Schema, tree pg.ParsetreeList) { +func parse(s *postgres.Schema, tree pg.ParsetreeList) error { for _, stmt := range tree.Statements { raw, ok := stmt.(nodes.RawStmt) if !ok { @@ -76,7 +82,10 @@ func parse(s *postgres.Schema, tree pg.ParsetreeList) { } } if idx < 0 { - panic("could not find table " + *n.Relation.Relname) + return Error{ + Code: "42P01", + Message: fmt.Sprintf("relation \"%s\" does not exist", *n.Relation.Relname), + } } for _, cmd := range n.Cmds.Items { @@ -148,7 +157,10 @@ func parse(s *postgres.Schema, tree pg.ParsetreeList) { } } if idx < 0 { - panic("could not find table " + *n.Relation.Relname) + return Error{ + Code: "42P01", + Message: fmt.Sprintf("relation \"%s\" does not exist", *n.Relation.Relname), + } } s.Tables[idx].Name = *n.Newname s.Tables[idx].GoName = structName(*n.Newname) @@ -157,6 +169,8 @@ func parse(s *postgres.Schema, tree pg.ParsetreeList) { // spew.Dump(n) } } + + return nil } func join(list nodes.List, sep string) string { @@ -326,19 +340,21 @@ func rangeVars(root nodes.Node) []nodes.RangeVar { return vars } -func parseFuncs(s *postgres.Schema, r *Result, source string, tree pg.ParsetreeList) { +func parseFuncs(s *postgres.Schema, r *Result, source string, tree pg.ParsetreeList) error { for _, stmt := range tree.Statements { + if err := validateParamRef(stmt); err != nil { + return err + } raw, ok := stmt.(nodes.RawStmt) if !ok { continue } - switch n := raw.Stmt.(type) { + switch raw.Stmt.(type) { case nodes.SelectStmt: case nodes.DeleteStmt: case nodes.InsertStmt: case nodes.UpdateStmt: default: - log.Printf("%T\n", n) continue } @@ -347,17 +363,21 @@ func parseFuncs(s *postgres.Schema, r *Result, source string, tree pg.ParsetreeL c := columnNames(s, t) rawSQL, _ := pluckQuery(source, raw) - meta, err := parseQueryMetadata(rawSQL) - if err != nil { - panic(err) - } - refs := extractArgs(raw.Stmt) outs := findOutputs(raw.Stmt) tab := getTable(s, t) + args, err := resolveRefs(s, rvs, refs) + if err != nil { + return err + } + + meta, err := parseQueryMetadata(rawSQL) + if err != nil { + continue + } meta.Table = tab - meta.Args = resolveRefs(s, rvs, refs) + meta.Args = args if len(outs) == 0 { meta.SQL = rawSQL @@ -373,12 +393,17 @@ func parseFuncs(s *postgres.Schema, r *Result, source string, tree pg.ParsetreeL meta.Fields = fieldsFromRefs(tab, outs) meta.SQL = rawSQL } else { - meta.ReturnType = returnType(tab, outs) + rt, err := returnType(tab, outs) + if err != nil { + return err + } + meta.ReturnType = rt meta.SQL = rawSQL } r.Queries = append(r.Queries, meta) } + return nil } func fieldsFromRefs(t postgres.Table, refs []outputRef) []Field { @@ -416,13 +441,12 @@ func fieldsFromTable(t postgres.Table) []Field { return f } -func returnType(t postgres.Table, refs []outputRef) string { +func returnType(t postgres.Table, refs []outputRef) (string, error) { if len(refs) != 1 { - // panic("too many return columns") - return "interface{}" + return "", fmt.Errorf("too many return columns") } if refs[0].typ != "" { - return refs[0].typ + return refs[0].typ, nil } if refs[0].ref != nil { fields := refs[0].ref.Fields.Items @@ -434,11 +458,15 @@ func returnType(t postgres.Table, refs []outputRef) string { } for _, c := range t.Columns { if c.Name == name { - return c.GoType + return c.GoType, nil } } + return "", Error{ + Code: "42703", + Message: fmt.Sprintf("column \"%s\" does not exist", name), + } } - return "interface{}" + return "", fmt.Errorf("could not figure out return type") } func extractArgs(n nodes.Node) []paramRef { @@ -567,7 +595,7 @@ func findOutputs(root nodes.Node) []outputRef { return v.a.refs } -func resolveRefs(s *postgres.Schema, rvs []nodes.RangeVar, args []paramRef) []Arg { +func resolveRefs(s *postgres.Schema, rvs []nodes.RangeVar, args []paramRef) ([]Arg, error) { typeMap := map[string]map[string]string{} for _, t := range s.Tables { typeMap[t.Name] = map[string]string{} @@ -617,7 +645,10 @@ func resolveRefs(s *postgres.Schema, rvs []nodes.RangeVar, args []paramRef) []Ar if typ, ok := typeMap[table][key]; ok { a = append(a, Arg{Name: argName(key), Type: typ}) } else { - panic("unknown column: " + alias + key) + return nil, Error{ + Code: "42703", + Message: fmt.Sprintf("column \"%s\" does not exist", key), + } } } case nodes.ResTarget: @@ -625,7 +656,10 @@ func resolveRefs(s *postgres.Schema, rvs []nodes.RangeVar, args []paramRef) []Ar if typ, ok := typeMap[defaultTable][key]; ok { a = append(a, Arg{Name: argName(key), Type: typ}) } else { - panic("unknown column: " + key) + return nil, Error{ + Code: "42703", + Message: fmt.Sprintf("column \"%s\" does not exist", key), + } } case nodes.ParamRef: a = append(a, Arg{Name: "_", Type: "interface{}"}) @@ -633,7 +667,7 @@ func resolveRefs(s *postgres.Schema, rvs []nodes.RangeVar, args []paramRef) []Ar panic(fmt.Sprintf("unsupported type: %T", n)) } } - return a + return a, nil } func columnNames(s *postgres.Schema, table string) []string { diff --git a/soup.go b/soup.go index dafc652c00..fb649476e2 100644 --- a/soup.go +++ b/soup.go @@ -166,7 +166,9 @@ func Walk(f Visitor, node nodes.Node) { walkn(f, n.Options) case nodes.AlterSeqStmt: - walkn(f, n.Sequence) + if n.Sequence != nil { + walkn(f, *n.Sequence) + } walkn(f, n.Options) case nodes.AlterSubscriptionStmt: @@ -196,7 +198,9 @@ func Walk(f Visitor, node nodes.Node) { walkn(f, n.Options) case nodes.AlterTableStmt: - walkn(f, n.Relation) + if n.Relation != nil { + walkn(f, *n.Relation) + } walkn(f, n.Cmds) case nodes.AlterUserMappingStmt: @@ -257,7 +261,9 @@ func Walk(f Visitor, node nodes.Node) { // pass case nodes.ClusterStmt: - walkn(f, n.Relation) + if n.Relation != nil { + walkn(f, *n.Relation) + } case nodes.CoalesceExpr: walkn(f, n.Xpr) @@ -304,7 +310,9 @@ func Walk(f Visitor, node nodes.Node) { walkn(f, n.Ctecolcollations) case nodes.CompositeTypeStmt: - walkn(f, n.Typevar) + if n.Typevar != nil { + walkn(f, *n.Typevar) + } walkn(f, n.Coldeflist) case nodes.Const: @@ -316,7 +324,9 @@ func Walk(f Visitor, node nodes.Node) { walkn(f, n.Exclusions) walkn(f, n.Options) walkn(f, n.WhereClause) - walkn(f, n.Pktable) + if n.Pktable != nil { + walkn(f, *n.Pktable) + } walkn(f, n.FkAttrs) walkn(f, n.PkAttrs) walkn(f, n.OldConpfeqop) @@ -329,7 +339,9 @@ func Walk(f Visitor, node nodes.Node) { walkn(f, n.Arg) case nodes.CopyStmt: - walkn(f, n.Relation) + if n.Relation != nil { + walkn(f, *n.Relation) + } walkn(f, n.Query) walkn(f, n.Attlist) walkn(f, n.Options) @@ -416,7 +428,9 @@ func Walk(f Visitor, node nodes.Node) { walkn(f, n.Plvalidator) case nodes.CreatePolicyStmt: - walkn(f, n.Table) + if n.Table != nil { + walkn(f, *n.Table) + } walkn(f, n.Roles) walkn(f, n.Qual) walkn(f, n.WithCheck) @@ -437,7 +451,9 @@ func Walk(f Visitor, node nodes.Node) { walkn(f, n.SchemaElts) case nodes.CreateSeqStmt: - walkn(f, n.Sequence) + if n.Sequence != nil { + walkn(f, *n.Sequence) + } walkn(f, n.Options) case nodes.CreateStatsStmt: @@ -447,7 +463,9 @@ func Walk(f Visitor, node nodes.Node) { walkn(f, n.Relations) case nodes.CreateStmt: - walkn(f, n.Relation) + if n.Relation != nil { + walkn(f, *n.Relation) + } walkn(f, n.TableElts) walkn(f, n.InhRelations) if n.Partbound != nil { @@ -482,13 +500,17 @@ func Walk(f Visitor, node nodes.Node) { } case nodes.CreateTrigStmt: - walkn(f, n.Relation) + if n.Relation != nil { + walkn(f, *n.Relation) + } walkn(f, n.Funcname) walkn(f, n.Args) walkn(f, n.Columns) walkn(f, n.WhenClause) walkn(f, n.TransitionRels) - walkn(f, n.Constrrel) + if n.Constrrel != nil { + walkn(f, *n.Constrrel) + } case nodes.CreateUserMappingStmt: walkn(f, n.User)