-
Notifications
You must be signed in to change notification settings - Fork 518
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
Showing
9 changed files
with
554 additions
and
6 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,129 @@ | ||
package migrationstats | ||
|
||
import ( | ||
"errors" | ||
"fmt" | ||
"go/ast" | ||
"go/parser" | ||
"go/token" | ||
"io" | ||
) | ||
|
||
const ( | ||
registerGoFuncName = "AddMigration" | ||
registerGoFuncNameNoTx = "AddMigrationNoTx" | ||
) | ||
|
||
type goMigration struct { | ||
name string | ||
useTx *bool | ||
upFuncName, downFuncName string | ||
} | ||
|
||
func parseGoFile(r io.Reader) (*goMigration, error) { | ||
astFile, err := parser.ParseFile( | ||
token.NewFileSet(), | ||
"", // filename | ||
r, | ||
// We don't need to resolve imports, so we can skip it. | ||
// This speeds up the parsing process. | ||
// See https://github.com/golang/go/issues/46485 | ||
parser.SkipObjectResolution, | ||
) | ||
if err != nil { | ||
return nil, err | ||
} | ||
for _, decl := range astFile.Decls { | ||
fn, ok := decl.(*ast.FuncDecl) | ||
if !ok || fn == nil || fn.Name == nil { | ||
continue | ||
} | ||
if fn.Name.Name == "init" { | ||
return parseInitFunc(fn) | ||
} | ||
} | ||
return nil, errors.New("no init function") | ||
} | ||
|
||
func parseInitFunc(fd *ast.FuncDecl) (*goMigration, error) { | ||
if fd == nil { | ||
return nil, fmt.Errorf("function declaration must not be nil") | ||
} | ||
if fd.Body == nil { | ||
return nil, fmt.Errorf("no function body") | ||
} | ||
if len(fd.Body.List) == 0 { | ||
return nil, fmt.Errorf("no registered goose functions") | ||
} | ||
gf := new(goMigration) | ||
for _, statement := range fd.Body.List { | ||
expr, ok := statement.(*ast.ExprStmt) | ||
if !ok { | ||
continue | ||
} | ||
call, ok := expr.X.(*ast.CallExpr) | ||
if !ok { | ||
continue | ||
} | ||
sel, ok := call.Fun.(*ast.SelectorExpr) | ||
if !ok || sel == nil { | ||
continue | ||
} | ||
funcName := sel.Sel.Name | ||
b := false | ||
switch funcName { | ||
case registerGoFuncName: | ||
b = true | ||
gf.useTx = &b | ||
case registerGoFuncNameNoTx: | ||
gf.useTx = &b | ||
default: | ||
continue | ||
} | ||
if gf.name != "" { | ||
return nil, fmt.Errorf("found duplicate registered functions:\nprevious: %v\ncurrent: %v", gf.name, funcName) | ||
} | ||
gf.name = funcName | ||
|
||
if len(call.Args) != 2 { | ||
return nil, fmt.Errorf("registered goose functions have 2 arguments: got %d", len(call.Args)) | ||
} | ||
getNameFromExpr := func(expr ast.Expr) (string, error) { | ||
arg, ok := expr.(*ast.Ident) | ||
if !ok { | ||
return "", fmt.Errorf("failed to assert argument identifer: got %T", arg) | ||
} | ||
return arg.Name, nil | ||
} | ||
var err error | ||
gf.upFuncName, err = getNameFromExpr(call.Args[0]) | ||
if err != nil { | ||
return nil, err | ||
} | ||
gf.downFuncName, err = getNameFromExpr(call.Args[1]) | ||
if err != nil { | ||
return nil, err | ||
} | ||
} | ||
// validation | ||
switch gf.name { | ||
case registerGoFuncName, registerGoFuncNameNoTx: | ||
default: | ||
return nil, fmt.Errorf("goose register function must be one of: %s or %s", | ||
registerGoFuncName, | ||
registerGoFuncNameNoTx, | ||
) | ||
} | ||
if gf.useTx == nil { | ||
return nil, errors.New("validation error: failed to identify transaction: got nil bool") | ||
} | ||
// The up and down functions can either be named Go functions or "nil", an | ||
// empty string means there is a flaw in our parsing logic of the Go source code. | ||
if gf.upFuncName == "" { | ||
return nil, fmt.Errorf("validation error: up function is empty string") | ||
} | ||
if gf.downFuncName == "" { | ||
return nil, fmt.Errorf("validation error: down function is empty string") | ||
} | ||
return gf, nil | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,47 @@ | ||
package migrationstats | ||
|
||
import ( | ||
"bytes" | ||
"fmt" | ||
"io" | ||
"io/ioutil" | ||
|
||
"github.com/pressly/goose/v3/internal/sqlparser" | ||
) | ||
|
||
type sqlMigration struct { | ||
useTx bool | ||
upCount, downCount int | ||
} | ||
|
||
func parseSQLFile(r io.Reader, debug bool) (*sqlMigration, error) { | ||
by, err := ioutil.ReadAll(r) | ||
if err != nil { | ||
return nil, err | ||
} | ||
upStatements, txUp, err := sqlparser.ParseSQLMigration( | ||
bytes.NewReader(by), | ||
sqlparser.DirectionUp, | ||
debug, | ||
) | ||
if err != nil { | ||
return nil, err | ||
} | ||
downStatements, txDown, err := sqlparser.ParseSQLMigration( | ||
bytes.NewReader(by), | ||
sqlparser.DirectionDown, | ||
debug, | ||
) | ||
if err != nil { | ||
return nil, err | ||
} | ||
// This is a sanity check to ensure that the parser is behaving as expected. | ||
if txUp != txDown { | ||
return nil, fmt.Errorf("up and down statements must have the same transaction mode") | ||
} | ||
return &sqlMigration{ | ||
useTx: txUp, | ||
upCount: len(upStatements), | ||
downCount: len(downStatements), | ||
}, nil | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,78 @@ | ||
package migrationstats | ||
|
||
import ( | ||
"fmt" | ||
"io" | ||
"path/filepath" | ||
|
||
"github.com/pressly/goose/v3" | ||
) | ||
|
||
// FileWalker walks all files for GatherStats. | ||
type FileWalker interface { | ||
// Walk invokes fn for each file. | ||
Walk(fn func(filename string, r io.Reader) error) error | ||
} | ||
|
||
// Stats contains the stats for a migration file. | ||
type Stats struct { | ||
// FileName is the name of the file. | ||
FileName string | ||
// Version is the version of the migration. | ||
Version int64 | ||
// Tx is true if the .sql migration file has a +goose NO TRANSACTION annotation | ||
// or the .go migration file calls AddMigrationNoTx. | ||
Tx bool | ||
// UpCount is the number of statements in the Up migration. | ||
UpCount int | ||
// DownCount is the number of statements in the Down migration. | ||
DownCount int | ||
} | ||
|
||
// GatherStats returns the migration file stats. | ||
func GatherStats(fw FileWalker, debug bool) ([]*Stats, error) { | ||
var stats []*Stats | ||
err := fw.Walk(func(filename string, r io.Reader) error { | ||
version, err := goose.NumericComponent(filename) | ||
if err != nil { | ||
return fmt.Errorf("failed to get version from file %q: %w", filename, err) | ||
} | ||
var up, down int | ||
var tx bool | ||
switch filepath.Ext(filename) { | ||
case ".sql": | ||
m, err := parseSQLFile(r, debug) | ||
if err != nil { | ||
return fmt.Errorf("failed to parse file %q: %w", filename, err) | ||
} | ||
up, down = m.upCount, m.downCount | ||
tx = m.useTx | ||
case ".go": | ||
m, err := parseGoFile(r) | ||
if err != nil { | ||
return fmt.Errorf("failed to parse file %q: %w", filename, err) | ||
} | ||
up, down = nilAsNumber(m.upFuncName), nilAsNumber(m.downFuncName) | ||
tx = *m.useTx | ||
} | ||
stats = append(stats, &Stats{ | ||
FileName: filename, | ||
Version: version, | ||
Tx: tx, | ||
UpCount: up, | ||
DownCount: down, | ||
}) | ||
return nil | ||
}) | ||
if err != nil { | ||
return nil, err | ||
} | ||
return stats, nil | ||
} | ||
|
||
func nilAsNumber(s string) int { | ||
if s != "nil" { | ||
return 1 | ||
} | ||
return 0 | ||
} |
Oops, something went wrong.