Skip to content

Commit

Permalink
feat: goose validate command (#449)
Browse files Browse the repository at this point in the history
  • Loading branch information
mfridman authored Feb 25, 2023
1 parent 60610d3 commit 8c25e3b
Show file tree
Hide file tree
Showing 9 changed files with 554 additions and 6 deletions.
2 changes: 1 addition & 1 deletion .github/workflows/lint.yml
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ jobs:
uses: golangci/golangci-lint-action@v3
with:
# Optional: version of golangci-lint to use in form of v1.2 or v1.2.3 or `latest` to use the latest version
version: v1.50.1
version: latest

# Optional: working directory, useful for monorepos
# working-directory: somedir
Expand Down
4 changes: 2 additions & 2 deletions Makefile
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ lint: tools

.PHONY: tools
tools:
@go install github.com/golangci/golangci-lint/cmd/golangci-lint@v1.50.1
@go install github.com/golangci/golangci-lint/cmd/golangci-lint@latest

test-packages:
go test $(GO_TEST_FLAGS) $$(go list ./... | grep -v -e /tests -e /bin -e /cmd -e /examples)
Expand Down Expand Up @@ -49,4 +49,4 @@ docker-start-postgres:
-e POSTGRES_DB=${GOOSE_POSTGRES_DBNAME} \
-p ${GOOSE_POSTGRES_PORT}:5432 \
-l goose_test \
postgres:14-alpine
postgres:14-alpine -c log_statement=all
67 changes: 67 additions & 0 deletions cmd/goose/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,12 +7,18 @@ import (
"io/fs"
"log"
"os"
"path/filepath"
"runtime/debug"
"sort"
"strconv"
"strings"
"text/tabwriter"
"text/template"

"github.com/pressly/goose/v3"
"github.com/pressly/goose/v3/internal/cfg"
"github.com/pressly/goose/v3/internal/migrationstats"
"github.com/pressly/goose/v3/internal/migrationstats/migrationstatsos"
)

var (
Expand Down Expand Up @@ -95,6 +101,11 @@ func main() {
fmt.Printf("%s=%q\n", env.Name, env.Value)
}
return
case "validate":
if err := printValidate(*dir, *verbose); err != nil {
log.Fatalf("goose validate: %v", err)
}
return
}

args = mergeArgs(args)
Expand Down Expand Up @@ -278,3 +289,59 @@ func gooseInit(dir string) error {
}
return goose.CreateWithTemplate(nil, dir, sqlMigrationTemplate, "initial", "sql")
}

func gatherFilenames(filename string) ([]string, error) {
stat, err := os.Stat(filename)
if err != nil {
return nil, err
}
var filenames []string
if stat.IsDir() {
for _, pattern := range []string{"*.sql", "*.go"} {
file, err := filepath.Glob(filepath.Join(filename, pattern))
if err != nil {
return nil, err
}
filenames = append(filenames, file...)
}
} else {
filenames = append(filenames, filename)
}
sort.Strings(filenames)
return filenames, nil
}

func printValidate(filename string, verbose bool) error {
filenames, err := gatherFilenames(filename)
if err != nil {
return err
}
fileWalker := migrationstatsos.NewFileWalker(filenames...)
stats, err := migrationstats.GatherStats(fileWalker, false)
if err != nil {
return err
}
// TODO(mf): we should introduce a --debug flag, which allows printing
// more internal debug information and leave verbose for additional information.
if !verbose {
return nil
}
w := tabwriter.NewWriter(os.Stdout, 0, 0, 3, ' ', tabwriter.TabIndent)
fmtPattern := "%v\t%v\t%v\t%v\t%v\t\n"
fmt.Fprintf(w, fmtPattern, "Type", "Txn", "Up", "Down", "Name")
fmt.Fprintf(w, fmtPattern, "────", "───", "──", "────", "────")
for _, m := range stats {
txnStr := "✔"
if !m.Tx {
txnStr = "✘"
}
fmt.Fprintf(w, fmtPattern,
strings.TrimPrefix(filepath.Ext(m.FileName), "."),
txnStr,
m.UpCount,
m.DownCount,
filepath.Base(m.FileName),
)
}
return w.Flush()
}
129 changes: 129 additions & 0 deletions internal/migrationstats/migration_go.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,129 @@
package migrationstats

import (
"errors"
"fmt"
"go/ast"
"go/parser"
"go/token"
"io"
)

const (
registerGoFuncName = "AddMigration"
registerGoFuncNameNoTx = "AddMigrationNoTx"
)

type goMigration struct {
name string
useTx *bool
upFuncName, downFuncName string
}

func parseGoFile(r io.Reader) (*goMigration, error) {
astFile, err := parser.ParseFile(
token.NewFileSet(),
"", // filename
r,
// We don't need to resolve imports, so we can skip it.
// This speeds up the parsing process.
// See https://github.com/golang/go/issues/46485
parser.SkipObjectResolution,
)
if err != nil {
return nil, err
}
for _, decl := range astFile.Decls {
fn, ok := decl.(*ast.FuncDecl)
if !ok || fn == nil || fn.Name == nil {
continue
}
if fn.Name.Name == "init" {
return parseInitFunc(fn)
}
}
return nil, errors.New("no init function")
}

func parseInitFunc(fd *ast.FuncDecl) (*goMigration, error) {
if fd == nil {
return nil, fmt.Errorf("function declaration must not be nil")
}
if fd.Body == nil {
return nil, fmt.Errorf("no function body")
}
if len(fd.Body.List) == 0 {
return nil, fmt.Errorf("no registered goose functions")
}
gf := new(goMigration)
for _, statement := range fd.Body.List {
expr, ok := statement.(*ast.ExprStmt)
if !ok {
continue
}
call, ok := expr.X.(*ast.CallExpr)
if !ok {
continue
}
sel, ok := call.Fun.(*ast.SelectorExpr)
if !ok || sel == nil {
continue
}
funcName := sel.Sel.Name
b := false
switch funcName {
case registerGoFuncName:
b = true
gf.useTx = &b
case registerGoFuncNameNoTx:
gf.useTx = &b
default:
continue
}
if gf.name != "" {
return nil, fmt.Errorf("found duplicate registered functions:\nprevious: %v\ncurrent: %v", gf.name, funcName)
}
gf.name = funcName

if len(call.Args) != 2 {
return nil, fmt.Errorf("registered goose functions have 2 arguments: got %d", len(call.Args))
}
getNameFromExpr := func(expr ast.Expr) (string, error) {
arg, ok := expr.(*ast.Ident)
if !ok {
return "", fmt.Errorf("failed to assert argument identifer: got %T", arg)
}
return arg.Name, nil
}
var err error
gf.upFuncName, err = getNameFromExpr(call.Args[0])
if err != nil {
return nil, err
}
gf.downFuncName, err = getNameFromExpr(call.Args[1])
if err != nil {
return nil, err
}
}
// validation
switch gf.name {
case registerGoFuncName, registerGoFuncNameNoTx:
default:
return nil, fmt.Errorf("goose register function must be one of: %s or %s",
registerGoFuncName,
registerGoFuncNameNoTx,
)
}
if gf.useTx == nil {
return nil, errors.New("validation error: failed to identify transaction: got nil bool")
}
// The up and down functions can either be named Go functions or "nil", an
// empty string means there is a flaw in our parsing logic of the Go source code.
if gf.upFuncName == "" {
return nil, fmt.Errorf("validation error: up function is empty string")
}
if gf.downFuncName == "" {
return nil, fmt.Errorf("validation error: down function is empty string")
}
return gf, nil
}
47 changes: 47 additions & 0 deletions internal/migrationstats/migration_sql.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,47 @@
package migrationstats

import (
"bytes"
"fmt"
"io"
"io/ioutil"

"github.com/pressly/goose/v3/internal/sqlparser"
)

type sqlMigration struct {
useTx bool
upCount, downCount int
}

func parseSQLFile(r io.Reader, debug bool) (*sqlMigration, error) {
by, err := ioutil.ReadAll(r)
if err != nil {
return nil, err
}
upStatements, txUp, err := sqlparser.ParseSQLMigration(
bytes.NewReader(by),
sqlparser.DirectionUp,
debug,
)
if err != nil {
return nil, err
}
downStatements, txDown, err := sqlparser.ParseSQLMigration(
bytes.NewReader(by),
sqlparser.DirectionDown,
debug,
)
if err != nil {
return nil, err
}
// This is a sanity check to ensure that the parser is behaving as expected.
if txUp != txDown {
return nil, fmt.Errorf("up and down statements must have the same transaction mode")
}
return &sqlMigration{
useTx: txUp,
upCount: len(upStatements),
downCount: len(downStatements),
}, nil
}
78 changes: 78 additions & 0 deletions internal/migrationstats/migrationstats.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,78 @@
package migrationstats

import (
"fmt"
"io"
"path/filepath"

"github.com/pressly/goose/v3"
)

// FileWalker walks all files for GatherStats.
type FileWalker interface {
// Walk invokes fn for each file.
Walk(fn func(filename string, r io.Reader) error) error
}

// Stats contains the stats for a migration file.
type Stats struct {
// FileName is the name of the file.
FileName string
// Version is the version of the migration.
Version int64
// Tx is true if the .sql migration file has a +goose NO TRANSACTION annotation
// or the .go migration file calls AddMigrationNoTx.
Tx bool
// UpCount is the number of statements in the Up migration.
UpCount int
// DownCount is the number of statements in the Down migration.
DownCount int
}

// GatherStats returns the migration file stats.
func GatherStats(fw FileWalker, debug bool) ([]*Stats, error) {
var stats []*Stats
err := fw.Walk(func(filename string, r io.Reader) error {
version, err := goose.NumericComponent(filename)
if err != nil {
return fmt.Errorf("failed to get version from file %q: %w", filename, err)
}
var up, down int
var tx bool
switch filepath.Ext(filename) {
case ".sql":
m, err := parseSQLFile(r, debug)
if err != nil {
return fmt.Errorf("failed to parse file %q: %w", filename, err)
}
up, down = m.upCount, m.downCount
tx = m.useTx
case ".go":
m, err := parseGoFile(r)
if err != nil {
return fmt.Errorf("failed to parse file %q: %w", filename, err)
}
up, down = nilAsNumber(m.upFuncName), nilAsNumber(m.downFuncName)
tx = *m.useTx
}
stats = append(stats, &Stats{
FileName: filename,
Version: version,
Tx: tx,
UpCount: up,
DownCount: down,
})
return nil
})
if err != nil {
return nil, err
}
return stats, nil
}

func nilAsNumber(s string) int {
if s != "nil" {
return 1
}
return 0
}
Loading

0 comments on commit 8c25e3b

Please sign in to comment.