Skip to content

internal/config: Create new config package #313

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Jan 30, 2020
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 5 additions & 4 deletions internal/cmd/cmd.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,11 +9,12 @@ import (
"os/exec"
"path/filepath"

"github.com/kyleconroy/sqlc/internal/dinosql"

"github.com/davecgh/go-spew/spew"
pg "github.com/lfittl/pg_query_go"
"github.com/spf13/cobra"

"github.com/kyleconroy/sqlc/internal/config"
"github.com/kyleconroy/sqlc/internal/dinosql"
)

// Do runs the command logic.
Expand Down Expand Up @@ -74,7 +75,7 @@ var initCmd = &cobra.Command{
if _, err := os.Stat("sqlc.json"); !os.IsNotExist(err) {
return nil
}
blob, err := json.MarshalIndent(dinosql.GenerateSettings{Version: "1"}, "", " ")
blob, err := json.MarshalIndent(config.GenerateSettings{Version: "1"}, "", " ")
if err != nil {
return err
}
Expand Down Expand Up @@ -117,7 +118,7 @@ var checkCmd = &cobra.Command{
return err
}

settings, err := dinosql.ParseConfig(file)
settings, err := config.ParseConfig(file)
if err != nil {
return err
}
Expand Down
15 changes: 8 additions & 7 deletions internal/cmd/generate.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ import (
"path/filepath"
"strings"

"github.com/kyleconroy/sqlc/internal/config"
"github.com/kyleconroy/sqlc/internal/dinosql"
"github.com/kyleconroy/sqlc/internal/mysql"
)
Expand Down Expand Up @@ -39,14 +40,14 @@ func Generate(dir string, stderr io.Writer) (map[string]string, error) {
return nil, err
}

settings, err := dinosql.ParseConfig(bytes.NewReader(blob))
settings, err := config.ParseConfig(bytes.NewReader(blob))
if err != nil {
switch err {
case dinosql.ErrMissingVersion:
case config.ErrMissingVersion:
fmt.Fprintf(stderr, errMessageNoVersion)
case dinosql.ErrUnknownVersion:
case config.ErrUnknownVersion:
fmt.Fprintf(stderr, errMessageUnknownVersion)
case dinosql.ErrNoPackages:
case config.ErrNoPackages:
fmt.Fprintf(stderr, errMessageNoPackages)
}
fmt.Fprintf(stderr, "error parsing sqlc.json: %s\n", err)
Expand All @@ -58,7 +59,7 @@ func Generate(dir string, stderr io.Writer) (map[string]string, error) {

for _, pkg := range settings.Packages {
name := pkg.Name
combo := dinosql.Combine(settings, pkg)
combo := config.Combine(settings, pkg)
var result dinosql.Generateable

// TODO: This feels like a hack that will bite us later
Expand All @@ -67,7 +68,7 @@ func Generate(dir string, stderr io.Writer) (map[string]string, error) {

switch pkg.Engine {

case dinosql.EngineMySQL:
case config.EngineMySQL:
// Experimental MySQL support
q, err := mysql.GeneratePkg(name, pkg.Schema, pkg.Queries, combo)
if err != nil {
Expand All @@ -84,7 +85,7 @@ func Generate(dir string, stderr io.Writer) (map[string]string, error) {
}
result = q

case dinosql.EnginePostgreSQL:
case config.EnginePostgreSQL:
c, err := dinosql.ParseCatalog(pkg.Schema)
if err != nil {
fmt.Fprintf(stderr, "# package %s\n", name)
Expand Down
12 changes: 6 additions & 6 deletions internal/dinosql/config.go → internal/config/config.go
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
package dinosql
package config

import (
"encoding/json"
Expand Down Expand Up @@ -74,8 +74,8 @@ type Override struct {
ColumnName string
Table pg.FQN
GoTypeName string
goPackage string
goBasicType bool
GoPackage string
GoBasicType bool
}

func (c *GenerateSettings) ValidateGlobalOverrides() error {
Expand Down Expand Up @@ -154,7 +154,7 @@ func (o *Override) Parse() error {
if !found {
return fmt.Errorf("Package override `go_type` specifier %q is not a Go basic type e.g. 'string'", o.GoType)
}
o.goBasicType = true
o.GoBasicType = true
} else {
// assume the type lives in a Go package
if lastDot == -1 {
Expand All @@ -174,12 +174,12 @@ func (o *Override) Parse() error {
if strings.HasSuffix(typename, "-go") {
typename = typename[:len(typename)-len("-go")]
}
o.goPackage = o.GoType[:lastDot]
o.GoPackage = o.GoType[:lastDot]
}
o.GoTypeName = typename
isPointer := o.GoType[0] == '*'
if isPointer {
o.goPackage = o.goPackage[1:]
o.GoPackage = o.GoPackage[1:]
o.GoTypeName = "*" + o.GoTypeName
}

Expand Down
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
package dinosql
package config

import (
"strings"
Expand Down Expand Up @@ -107,10 +107,10 @@ func TestTypeOverrides(t *testing.T) {
if diff := cmp.Diff(tt.typeName, tt.override.GoTypeName); diff != "" {
t.Errorf("type name mismatch;\n%s", diff)
}
if diff := cmp.Diff(tt.pkg, tt.override.goPackage); diff != "" {
if diff := cmp.Diff(tt.pkg, tt.override.GoPackage); diff != "" {
t.Errorf("package mismatch;\n%s", diff)
}
if diff := cmp.Diff(tt.basic, tt.override.goBasicType); diff != "" {
if diff := cmp.Diff(tt.basic, tt.override.GoBasicType); diff != "" {
t.Errorf("basic mismatch;\n%s", diff)
}
})
Expand Down
49 changes: 25 additions & 24 deletions internal/dinosql/gen.go
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@ import (
"unicode"

"github.com/kyleconroy/sqlc/internal/catalog"
"github.com/kyleconroy/sqlc/internal/config"
core "github.com/kyleconroy/sqlc/internal/pg"

"github.com/jinzhu/inflection"
Expand Down Expand Up @@ -159,12 +160,12 @@ type GoQuery struct {
}

type Generateable interface {
Structs(settings CombinedSettings) []GoStruct
GoQueries(settings CombinedSettings) []GoQuery
Enums(settings CombinedSettings) []GoEnum
Structs(settings config.CombinedSettings) []GoStruct
GoQueries(settings config.CombinedSettings) []GoQuery
Enums(settings config.CombinedSettings) []GoEnum
}

func UsesType(r Generateable, typ string, settings CombinedSettings) bool {
func UsesType(r Generateable, typ string, settings config.CombinedSettings) bool {
for _, strct := range r.Structs(settings) {
for _, f := range strct.Fields {
fType := strings.TrimPrefix(f.Type, "[]")
Expand All @@ -176,7 +177,7 @@ func UsesType(r Generateable, typ string, settings CombinedSettings) bool {
return false
}

func UsesArrays(r Generateable, settings CombinedSettings) bool {
func UsesArrays(r Generateable, settings config.CombinedSettings) bool {
for _, strct := range r.Structs(settings) {
for _, f := range strct.Fields {
if strings.HasPrefix(f.Type, "[]") {
Expand All @@ -187,7 +188,7 @@ func UsesArrays(r Generateable, settings CombinedSettings) bool {
return false
}

func Imports(r Generateable, settings CombinedSettings) func(string) [][]string {
func Imports(r Generateable, settings config.CombinedSettings) func(string) [][]string {
return func(filename string) [][]string {
if filename == "db.go" {
imps := []string{"context", "database/sql"}
Expand All @@ -209,7 +210,7 @@ func Imports(r Generateable, settings CombinedSettings) func(string) [][]string
}
}

func InterfaceImports(r Generateable, settings CombinedSettings) [][]string {
func InterfaceImports(r Generateable, settings config.CombinedSettings) [][]string {
gq := r.GoQueries(settings)
uses := func(name string) bool {
for _, q := range gq {
Expand Down Expand Up @@ -246,10 +247,10 @@ func InterfaceImports(r Generateable, settings CombinedSettings) [][]string {
pkg := make(map[string]struct{})
overrideTypes := map[string]string{}
for _, o := range settings.Overrides {
if o.goBasicType {
if o.GoBasicType {
continue
}
overrideTypes[o.GoTypeName] = o.goPackage
overrideTypes[o.GoTypeName] = o.GoPackage
}

_, overrideNullTime := overrideTypes["pq.NullTime"]
Expand Down Expand Up @@ -283,7 +284,7 @@ func InterfaceImports(r Generateable, settings CombinedSettings) [][]string {
return [][]string{stds, pkgs}
}

func ModelImports(r Generateable, settings CombinedSettings) [][]string {
func ModelImports(r Generateable, settings config.CombinedSettings) [][]string {
std := make(map[string]struct{})
if UsesType(r, "sql.Null", settings) {
std["database/sql"] = struct{}{}
Expand All @@ -302,10 +303,10 @@ func ModelImports(r Generateable, settings CombinedSettings) [][]string {
pkg := make(map[string]struct{})
overrideTypes := map[string]string{}
for _, o := range settings.Overrides {
if o.goBasicType {
if o.GoBasicType {
continue
}
overrideTypes[o.GoTypeName] = o.goPackage
overrideTypes[o.GoTypeName] = o.GoPackage
}

_, overrideNullTime := overrideTypes["pq.NullTime"]
Expand Down Expand Up @@ -339,7 +340,7 @@ func ModelImports(r Generateable, settings CombinedSettings) [][]string {
return [][]string{stds, pkgs}
}

func QueryImports(r Generateable, settings CombinedSettings, filename string) [][]string {
func QueryImports(r Generateable, settings config.CombinedSettings, filename string) [][]string {
// for _, strct := range r.Structs() {
// for _, f := range strct.Fields {
// if strings.HasPrefix(f.Type, "[]") {
Expand Down Expand Up @@ -437,10 +438,10 @@ func QueryImports(r Generateable, settings CombinedSettings, filename string) []
pkg := make(map[string]struct{})
overrideTypes := map[string]string{}
for _, o := range settings.Overrides {
if o.goBasicType {
if o.GoBasicType {
continue
}
overrideTypes[o.GoTypeName] = o.goPackage
overrideTypes[o.GoTypeName] = o.GoPackage
}

if sliceScan() {
Expand Down Expand Up @@ -489,7 +490,7 @@ func enumValueName(value string) string {
return name
}

func (r Result) Enums(settings CombinedSettings) []GoEnum {
func (r Result) Enums(settings config.CombinedSettings) []GoEnum {
var enums []GoEnum
for name, schema := range r.Catalog.Schemas {
if name == "pg_catalog" {
Expand Down Expand Up @@ -522,7 +523,7 @@ func (r Result) Enums(settings CombinedSettings) []GoEnum {
return enums
}

func StructName(name string, settings CombinedSettings) string {
func StructName(name string, settings config.CombinedSettings) string {
if rename := settings.Global.Rename[name]; rename != "" {
return rename
}
Expand All @@ -537,7 +538,7 @@ func StructName(name string, settings CombinedSettings) string {
return out
}

func (r Result) Structs(settings CombinedSettings) []GoStruct {
func (r Result) Structs(settings config.CombinedSettings) []GoStruct {
var structs []GoStruct
for name, schema := range r.Catalog.Schemas {
if name == "pg_catalog" {
Expand Down Expand Up @@ -572,7 +573,7 @@ func (r Result) Structs(settings CombinedSettings) []GoStruct {
return structs
}

func (r Result) goType(col core.Column, settings CombinedSettings) string {
func (r Result) goType(col core.Column, settings config.CombinedSettings) string {
// package overrides have a higher precedence
for _, oride := range settings.Overrides {
if oride.Column != "" && oride.ColumnName == col.Name && oride.Table == col.Table {
Expand All @@ -586,7 +587,7 @@ func (r Result) goType(col core.Column, settings CombinedSettings) string {
return typ
}

func (r Result) goInnerType(col core.Column, settings CombinedSettings) string {
func (r Result) goInnerType(col core.Column, settings config.CombinedSettings) string {
columnType := col.DataType
notNull := col.NotNull || col.IsArray

Expand Down Expand Up @@ -741,7 +742,7 @@ func (r Result) goInnerType(col core.Column, settings CombinedSettings) string {
// JSON tags: count, count_2, count_2
//
// This is unlikely to happen, so don't fix it yet
func (r Result) columnsToStruct(name string, columns []core.Column, settings CombinedSettings) *GoStruct {
func (r Result) columnsToStruct(name string, columns []core.Column, settings config.CombinedSettings) *GoStruct {
gs := GoStruct{
Name: name,
}
Expand Down Expand Up @@ -801,7 +802,7 @@ func compareFQN(a *core.FQN, b *core.FQN) bool {
return a.Catalog == b.Catalog && a.Schema == b.Schema && a.Rel == b.Rel
}

func (r Result) GoQueries(settings CombinedSettings) []GoQuery {
func (r Result) GoQueries(settings config.CombinedSettings) []GoQuery {
structs := r.Structs(settings)

qs := make([]GoQuery, 0, len(r.Queries))
Expand Down Expand Up @@ -1182,7 +1183,7 @@ type tmplCtx struct {
Enums []GoEnum
Structs []GoStruct
GoQueries []GoQuery
Settings GenerateSettings
Settings config.GenerateSettings

// TODO: Race conditions
SourceName string
Expand All @@ -1198,7 +1199,7 @@ func LowerTitle(s string) string {
return string(a)
}

func Generate(r Generateable, settings CombinedSettings) (map[string]string, error) {
func Generate(r Generateable, settings config.CombinedSettings) (map[string]string, error) {
funcMap := template.FuncMap{
"lowerTitle": LowerTitle,
"imports": Imports(r, settings),
Expand Down
3 changes: 2 additions & 1 deletion internal/dinosql/parser.go
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@ import (
"unicode"

"github.com/kyleconroy/sqlc/internal/catalog"
"github.com/kyleconroy/sqlc/internal/config"
core "github.com/kyleconroy/sqlc/internal/pg"
"github.com/kyleconroy/sqlc/internal/postgres"

Expand Down Expand Up @@ -188,7 +189,7 @@ type Result struct {
Catalog core.Catalog
}

func ParseQueries(c core.Catalog, pkg PackageSettings) (*Result, error) {
func ParseQueries(c core.Catalog, pkg config.PackageSettings) (*Result, error) {
f, err := os.Stat(pkg.Queries)
if err != nil {
return nil, fmt.Errorf("path %s does not exist", pkg.Queries)
Expand Down
7 changes: 4 additions & 3 deletions internal/mysql/errors_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,8 +4,9 @@ import (
"testing"

"github.com/google/go-cmp/cmp"
"github.com/kyleconroy/sqlc/internal/dinosql"
"vitess.io/vitess/go/vt/sqlparser"

"github.com/kyleconroy/sqlc/internal/config"
)

func TestCustomArgErr(t *testing.T) {
Expand Down Expand Up @@ -38,7 +39,7 @@ func TestCustomArgErr(t *testing.T) {
},
},
}
settings := dinosql.Combine(dinosql.GenerateSettings{}, dinosql.PackageSettings{})
settings := config.Combine(config.GenerateSettings{}, config.PackageSettings{})
generator := PackageGenerator{mockSchema, settings, "db"}
for _, tcase := range tests {
q, err := generator.parseContents("queries.sql", tcase.input)
Expand Down Expand Up @@ -82,7 +83,7 @@ func TestPositionedErr(t *testing.T) {
},
}

settings := dinosql.Combine(dinosql.GenerateSettings{}, dinosql.PackageSettings{})
settings := config.Combine(config.GenerateSettings{}, config.PackageSettings{})
for _, tcase := range tests {
generator := PackageGenerator{mockSchema, settings, "db"}
q, err := generator.parseContents("queries.sql", tcase.input)
Expand Down
Loading