diff --git a/cmd/dinosql/main.go b/cmd/dinosql/main.go index c65eac4b49..53fe791051 100644 --- a/cmd/dinosql/main.go +++ b/cmd/dinosql/main.go @@ -8,20 +8,8 @@ import ( ) func main() { - pkg := flag.String("package", "db", "package name for Go code") - sch := flag.String("schema", "", "input directory of SQL migrations") - prepare := flag.Bool("prepare", false, "include prepared query support") - tags := flag.Bool("tags", false, "add tags to database records") - out := flag.String("out", "db.go", "output file") flag.Parse() - - settings := dinosql.GenerateSettings{ - Package: *pkg, - EmitPreparedQueries: *prepare, - EmitTags: *tags, - } - - if err := dinosql.Exec(*sch, flag.Arg(0), *out, settings); err != nil { + if err := dinosql.Exec(flag.Arg(0)); err != nil { log.Fatal(err) } } diff --git a/exec.go b/exec.go index dd1a320827..bcd0c0e1f5 100644 --- a/exec.go +++ b/exec.go @@ -1,20 +1,31 @@ package dinosql import ( + "encoding/json" "io/ioutil" ) -func Exec(schemaDir, queryDir, out string, settings GenerateSettings) error { - s, err := ParseSchmea(schemaDir) +func Exec(settingsPath string) error { + blob, err := ioutil.ReadFile(settingsPath) if err != nil { return err } - q, err := ParseQueries(s, queryDir) + var settings GenerateSettings + if err := json.Unmarshal(blob, &settings); err != nil { + return err + } + + s, err := ParseSchmea(settings.SchemaDir) + if err != nil { + return err + } + + q, err := ParseQueries(s, settings.QueryDir) if err != nil { return err } source := generate(q, settings) - return ioutil.WriteFile(out, []byte(source), 0644) + return ioutil.WriteFile(settings.Out, []byte(source), 0644) } diff --git a/goose.go b/goose.go new file mode 100644 index 0000000000..aa4037d0c2 --- /dev/null +++ b/goose.go @@ -0,0 +1,15 @@ +package dinosql + +import "strings" + +// Remove all lines after a `-- +goose Down` comment +func RemoveGooseRollback(contents string) string { + lines := strings.Split(contents, "\n") + for i, line := range lines { + if strings.HasPrefix(strings.TrimSpace(line), "-- +goose Down") { + lines = lines[:i] + break + } + } + return strings.Join(lines, "\n") +} diff --git a/goose_test.go b/goose_test.go new file mode 100644 index 0000000000..2aadaec17e --- /dev/null +++ b/goose_test.go @@ -0,0 +1,26 @@ +package dinosql + +import ( + "testing" + + "github.com/google/go-cmp/cmp" +) + +const inputMigration = ` +-- +goose Up +ALTER TABLE archived_jobs ADD COLUMN expires_at TIMESTAMP WITH TIME ZONE; + +-- +goose Down +ALTER TABLE archived_jobs DROP COLUMN expires_at; +` + +const outputMigration = ` +-- +goose Up +ALTER TABLE archived_jobs ADD COLUMN expires_at TIMESTAMP WITH TIME ZONE; +` + +func TestRemoveGooseRollback(t *testing.T) { + if diff := cmp.Diff(outputMigration, RemoveGooseRollback(inputMigration)); diff != "" { + t.Errorf("migration mismatch:\n%s", diff) + } +} diff --git a/parser.go b/parser.go index 2c98849d33..2e8cda3cd5 100644 --- a/parser.go +++ b/parser.go @@ -57,7 +57,8 @@ func ParseSchmea(dir string) (*postgres.Schema, error) { if err != nil { return nil, err } - tree, err := pg.Parse(string(blob)) + contents := RemoveGooseRollback(string(blob)) + tree, err := pg.Parse(contents) if err != nil { return nil, err } @@ -1053,10 +1054,20 @@ func lowerTitle(s string) string { return string(a) } +type TypeOverride struct { + Package string `json:"package"` + PostgresType string `json:"postgres_type"` + GoType string `json:"go_type"` +} + type GenerateSettings struct { - Package string - EmitPreparedQueries bool - EmitTags bool + SchemaDir string `json:"schema"` + QueryDir string `json:"queries"` + Out string `json:"out"` + Package string `json:"package"` + EmitPreparedQueries bool `json:"emit_prepared_queries"` + EmitTags bool `json:"emit_tags"` + Overrides []TypeOverride `json:"overrides"` } func generate(r *Result, settings GenerateSettings) string {