Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 13 additions & 1 deletion checks.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,15 @@ import (
nodes "github.com/lfittl/pg_query_go/nodes"
)

type Error struct {
Message string
Code string
}

func (e Error) Error() string {
return e.Message
}

func validateParamRef(n nodes.Node) error {
var allrefs []nodes.ParamRef

Expand All @@ -24,7 +33,10 @@ func validateParamRef(n nodes.Node) error {

for i := 1; i <= len(seen); i += 1 {
if _, ok := seen[i]; !ok {
return fmt.Errorf("missing parameter reference: $%d", i)
return Error{
Code: "42P18",
Message: fmt.Sprintf("could not determine data type of parameter $%d", i),
}
}
}

Expand Down
59 changes: 35 additions & 24 deletions checks_test.go
Original file line number Diff line number Diff line change
@@ -1,43 +1,54 @@
package dinosql

import (
"fmt"
"testing"

"github.com/google/go-cmp/cmp"
pg "github.com/lfittl/pg_query_go"
)

func TestValidateParamRef(t *testing.T) {
// equateErrorMessage reports errors to be equal if both are nil
// or both have the same message.
equateErrorMessage := cmp.Comparer(func(x, y error) bool {
if x == nil || y == nil {
return x == nil && y == nil
}
return x.Error() == y.Error()
})

func TestParserErrors(t *testing.T) {
for _, tc := range []struct {
query string
err error
err Error
}{
{
"SELECT foo FROM bar WHERE baz = $4;",
fmt.Errorf("missing parameter reference: $1"),
Error{Code: "42P18", Message: "could not determine data type of parameter $1"},
},
{
"SELECT foo FROM bar WHERE baz = $1 AND baz = $3;",
Error{Code: "42P18", Message: "could not determine data type of parameter $2"},
},
{
"ALTER TABLE unknown RENAME TO known;",
Error{Code: "42P01", Message: "relation \"unknown\" does not exist"},
},
{
"SELECT foo FROM bar WHERE baz = $1;",
nil,
"ALTER TABLE unknown DROP COLUMN dropped;",
Error{Code: "42P01", Message: "relation \"unknown\" does not exist"},
},
{
`
CREATE TABLE bar (id serial not null);

-- name: foo :one
SELECT foo FROM bar;
`,
Error{Code: "42703", Message: "column \"foo\" does not exist"},
},
} {
tree, err := pg.Parse(tc.query)
if err != nil {
t.Fatal(err)
}
actual := validateParamRef(tree.Statements[0])
if diff := cmp.Diff(tc.err, actual, equateErrorMessage); diff != "" {
t.Errorf("error mismatch: \n%s", diff)
}
test := tc
t.Run(test.query, func(t *testing.T) {
_, err := parseSQL(test.query)

var actual Error
if err != nil {
actual = err.(Error)
}

if diff := cmp.Diff(test.err, actual); diff != "" {
t.Errorf("error mismatch: \n%s", diff)
}
})
}
}
86 changes: 60 additions & 26 deletions parser.go
Original file line number Diff line number Diff line change
Expand Up @@ -26,11 +26,15 @@ func parseSQL(in string) (*Result, error) {
if err != nil {
return nil, err
}
parse(&s, tree)
if err := parse(&s, tree); err != nil {
return nil, err
}

var q []Query
r := Result{Schema: &s}
parseFuncs(&s, &r, in, tree)
if err := parseFuncs(&s, &r, in, tree); err != nil {
return nil, err
}
q = append(q, r.Queries...)

return &Result{Schema: &s, Queries: q}, nil
Expand All @@ -56,12 +60,14 @@ func ParseSchmea(dir string) (*postgres.Schema, error) {
if err != nil {
return nil, err
}
parse(&s, tree)
if err := parse(&s, tree); err != nil {
return nil, err
}
}
return &s, nil
}

func parse(s *postgres.Schema, tree pg.ParsetreeList) {
func parse(s *postgres.Schema, tree pg.ParsetreeList) error {
for _, stmt := range tree.Statements {
raw, ok := stmt.(nodes.RawStmt)
if !ok {
Expand All @@ -76,7 +82,10 @@ func parse(s *postgres.Schema, tree pg.ParsetreeList) {
}
}
if idx < 0 {
panic("could not find table " + *n.Relation.Relname)
return Error{
Code: "42P01",
Message: fmt.Sprintf("relation \"%s\" does not exist", *n.Relation.Relname),
}
}

for _, cmd := range n.Cmds.Items {
Expand Down Expand Up @@ -148,7 +157,10 @@ func parse(s *postgres.Schema, tree pg.ParsetreeList) {
}
}
if idx < 0 {
panic("could not find table " + *n.Relation.Relname)
return Error{
Code: "42P01",
Message: fmt.Sprintf("relation \"%s\" does not exist", *n.Relation.Relname),
}
}
s.Tables[idx].Name = *n.Newname
s.Tables[idx].GoName = structName(*n.Newname)
Expand All @@ -157,6 +169,8 @@ func parse(s *postgres.Schema, tree pg.ParsetreeList) {
// spew.Dump(n)
}
}

return nil
}

func join(list nodes.List, sep string) string {
Expand Down Expand Up @@ -326,19 +340,21 @@ func rangeVars(root nodes.Node) []nodes.RangeVar {
return vars
}

func parseFuncs(s *postgres.Schema, r *Result, source string, tree pg.ParsetreeList) {
func parseFuncs(s *postgres.Schema, r *Result, source string, tree pg.ParsetreeList) error {
for _, stmt := range tree.Statements {
if err := validateParamRef(stmt); err != nil {
return err
}
raw, ok := stmt.(nodes.RawStmt)
if !ok {
continue
}
switch n := raw.Stmt.(type) {
switch raw.Stmt.(type) {
case nodes.SelectStmt:
case nodes.DeleteStmt:
case nodes.InsertStmt:
case nodes.UpdateStmt:
default:
log.Printf("%T\n", n)
continue
}

Expand All @@ -347,17 +363,21 @@ func parseFuncs(s *postgres.Schema, r *Result, source string, tree pg.ParsetreeL
c := columnNames(s, t)

rawSQL, _ := pluckQuery(source, raw)
meta, err := parseQueryMetadata(rawSQL)
if err != nil {
panic(err)
}

refs := extractArgs(raw.Stmt)
outs := findOutputs(raw.Stmt)

tab := getTable(s, t)
args, err := resolveRefs(s, rvs, refs)
if err != nil {
return err
}

meta, err := parseQueryMetadata(rawSQL)
if err != nil {
continue
}
meta.Table = tab
meta.Args = resolveRefs(s, rvs, refs)
meta.Args = args

if len(outs) == 0 {
meta.SQL = rawSQL
Expand All @@ -373,12 +393,17 @@ func parseFuncs(s *postgres.Schema, r *Result, source string, tree pg.ParsetreeL
meta.Fields = fieldsFromRefs(tab, outs)
meta.SQL = rawSQL
} else {
meta.ReturnType = returnType(tab, outs)
rt, err := returnType(tab, outs)
if err != nil {
return err
}
meta.ReturnType = rt
meta.SQL = rawSQL
}

r.Queries = append(r.Queries, meta)
}
return nil
}

func fieldsFromRefs(t postgres.Table, refs []outputRef) []Field {
Expand Down Expand Up @@ -416,13 +441,12 @@ func fieldsFromTable(t postgres.Table) []Field {
return f
}

func returnType(t postgres.Table, refs []outputRef) string {
func returnType(t postgres.Table, refs []outputRef) (string, error) {
if len(refs) != 1 {
// panic("too many return columns")
return "interface{}"
return "", fmt.Errorf("too many return columns")
}
if refs[0].typ != "" {
return refs[0].typ
return refs[0].typ, nil
}
if refs[0].ref != nil {
fields := refs[0].ref.Fields.Items
Expand All @@ -434,11 +458,15 @@ func returnType(t postgres.Table, refs []outputRef) string {
}
for _, c := range t.Columns {
if c.Name == name {
return c.GoType
return c.GoType, nil
}
}
return "", Error{
Code: "42703",
Message: fmt.Sprintf("column \"%s\" does not exist", name),
}
}
return "interface{}"
return "", fmt.Errorf("could not figure out return type")
}

func extractArgs(n nodes.Node) []paramRef {
Expand Down Expand Up @@ -567,7 +595,7 @@ func findOutputs(root nodes.Node) []outputRef {
return v.a.refs
}

func resolveRefs(s *postgres.Schema, rvs []nodes.RangeVar, args []paramRef) []Arg {
func resolveRefs(s *postgres.Schema, rvs []nodes.RangeVar, args []paramRef) ([]Arg, error) {
typeMap := map[string]map[string]string{}
for _, t := range s.Tables {
typeMap[t.Name] = map[string]string{}
Expand Down Expand Up @@ -617,23 +645,29 @@ func resolveRefs(s *postgres.Schema, rvs []nodes.RangeVar, args []paramRef) []Ar
if typ, ok := typeMap[table][key]; ok {
a = append(a, Arg{Name: argName(key), Type: typ})
} else {
panic("unknown column: " + alias + key)
return nil, Error{
Code: "42703",
Message: fmt.Sprintf("column \"%s\" does not exist", key),
}
}
}
case nodes.ResTarget:
key := *n.Name
if typ, ok := typeMap[defaultTable][key]; ok {
a = append(a, Arg{Name: argName(key), Type: typ})
} else {
panic("unknown column: " + key)
return nil, Error{
Code: "42703",
Message: fmt.Sprintf("column \"%s\" does not exist", key),
}
}
case nodes.ParamRef:
a = append(a, Arg{Name: "_", Type: "interface{}"})
default:
panic(fmt.Sprintf("unsupported type: %T", n))
}
}
return a
return a, nil
}

func columnNames(s *postgres.Schema, table string) []string {
Expand Down
Loading