Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
16 changes: 16 additions & 0 deletions internal/compiler/parse.go
Original file line number Diff line number Diff line change
Expand Up @@ -106,6 +106,22 @@ func (c *Compiler) parseQuery(stmt ast.Node, src string, o opts.Parser) (*Query,
}
}

// Override the inferrerd type and nullability of annotated named params
for i, param := range anlys.Parameters {
if !param.Column.IsNamedParam {
continue
}
if paramMetadata, ok := md.Params[param.Column.Name]; ok {
anlys.Parameters[i].Column.DataType = paramMetadata.DatabaseType
switch paramMetadata.Nullability {
case metadata.ParamNullabilityForceNotNull:
anlys.Parameters[i].Column.NotNull = true
case metadata.ParamNullabilityForceNullable:
anlys.Parameters[i].Column.NotNull = false
}
}
}

expanded := anlys.Query

// If the query string was edited, make sure the syntax is valid
Expand Down
31 changes: 31 additions & 0 deletions internal/endtoend/testdata/param_type_annotations/db/db.go

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

14 changes: 14 additions & 0 deletions internal/endtoend/testdata/param_type_annotations/db/models.go

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

73 changes: 73 additions & 0 deletions internal/endtoend/testdata/param_type_annotations/db/query.sql.go

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

19 changes: 19 additions & 0 deletions internal/endtoend/testdata/param_type_annotations/query.sql
Original file line number Diff line number Diff line change
@@ -0,0 +1,19 @@
-- name: TestSqlcArg :one
-- @param foo text
SELECT * FROM test WHERE id = sqlc.arg('foo');

-- name: TestAt :one
-- @param foo integer
SELECT * FROM test WHERE name = @foo;

-- name: TestForceNotNull :one
-- @param foo! uuid
SELECT * FROM test WHERE name = @foo;

-- name: TestForceNullable :one
-- @param foo? uuid
SELECT * FROM test WHERE id = @foo;

-- name: TestGibberish :one
-- @param foo? uuid sdfagyi
SELECT * FROM test WHERE id = @foo;
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
CREATE TABLE test (id INTEGER NOT NULL, name TEXT);
8 changes: 8 additions & 0 deletions internal/endtoend/testdata/param_type_annotations/sqlc.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,8 @@
version: "2"
sql:
- schema: schema.sql
queries: query.sql
engine: postgresql
gen:
go:
out: db
33 changes: 29 additions & 4 deletions internal/metadata/meta.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@ type Metadata struct {
Name string
Cmd string
Comments []string
Params map[string]string
Params map[string]ParamMetadata
Flags map[string]bool

Filename string
Expand All @@ -34,6 +34,19 @@ const (
CmdBatchOne = ":batchone"
)

type ParamMetadata struct {
DatabaseType string
Nullability ParamNullability
}

type ParamNullability int

const (
ParamNullabilityUnspecified ParamNullability = iota // ""
ParamNullabilityForceNotNull // "!"
ParamNullabilityForceNullable // "?"
)

// A query name must be a valid Go identifier
//
// https://golang.org/ref/spec#Identifiers
Expand Down Expand Up @@ -113,8 +126,8 @@ func ParseQueryNameAndType(t string, commentStyle CommentSyntax) (string, string
return "", "", nil
}

func ParseParamsAndFlags(comments []string) (map[string]string, map[string]bool, error) {
params := make(map[string]string)
func ParseParamsAndFlags(comments []string) (map[string]ParamMetadata, map[string]bool, error) {
params := make(map[string]ParamMetadata)
flags := make(map[string]bool)

for _, line := range comments {
Expand All @@ -137,7 +150,19 @@ func ParseParamsAndFlags(comments []string) (map[string]string, map[string]bool,
paramToken := s.Text()
rest = append(rest, paramToken)
}
params[name] = strings.Join(rest, " ")
var nullability ParamNullability
switch {
case strings.HasSuffix(name, "!"):
name = name[:len(name)-1]
nullability = ParamNullabilityForceNotNull
case strings.HasSuffix(name, "?"):
name = name[:len(name)-1]
nullability = ParamNullabilityForceNullable
}
params[name] = ParamMetadata{
DatabaseType: strings.Join(rest, " "),
Nullability: nullability,
}
default:
flags[token] = true
}
Expand Down
2 changes: 1 addition & 1 deletion internal/metadata/meta_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -87,7 +87,7 @@ func TestParseQueryParams(t *testing.T) {
t.Errorf("expected param not found")
}

if pt != "UUID" {
if pt.DatabaseType != "UUID" {
t.Error("unexpected param metadata:", pt)
}

Expand Down
18 changes: 6 additions & 12 deletions internal/sql/validate/param_ref.go
Original file line number Diff line number Diff line change
Expand Up @@ -10,32 +10,26 @@ import (
)

func ParamRef(n ast.Node) (map[int]bool, bool, error) {
var allrefs []*ast.ParamRef
var dollar bool
var nodollar bool
seen := map[int]bool{}
var dollar, nodollar bool
// Find all parameter references
astutils.Walk(astutils.VisitorFunc(func(node ast.Node) {
switch n := node.(type) {
case *ast.ParamRef:
ref := node.(*ast.ParamRef)
if ref.Dollar {
if n.Dollar {
dollar = true
} else {
nodollar = true
}
allrefs = append(allrefs, n)
if n.Number > 0 {
seen[n.Number] = true
}
}
}), n)
if dollar && nodollar {
return nil, false, errors.New("can not mix $1 format with ? format")
}

seen := map[int]bool{}
for _, r := range allrefs {
if r.Number > 0 {
seen[r.Number] = true
}
}
for i := 1; i <= len(seen); i += 1 {
if _, ok := seen[i]; !ok {
return seen, !nodollar, &sqlerr.Error{
Expand Down