Skip to content

Commit

Permalink
Merge pull request #201 from schemahero/viper-containment
Browse files Browse the repository at this point in the history
Prevent viper from leaking into pkg
  • Loading branch information
marccampbell committed May 29, 2020
2 parents 6ab68a9 + 0218630 commit cdd2062
Show file tree
Hide file tree
Showing 6 changed files with 76 additions and 53 deletions.
9 changes: 8 additions & 1 deletion pkg/cli/schemaherocli/apply.go
Expand Up @@ -48,7 +48,14 @@ func Apply() *cobra.Command {
return err
}

db := database.NewDatabase()
db := database.Database{
InputDir: v.GetString("input-dir"),
OutputDir: v.GetString("output-dir"),
Driver: v.GetString("driver"),
URI: v.GetString("uri"),
VaultURIRef: v.GetString("vault-uri-ref"),
}

if fi.Mode().IsDir() {
commands := []string{}
err := filepath.Walk(v.GetString("ddl"), func(path string, info os.FileInfo, err error) error {
Expand Down
11 changes: 10 additions & 1 deletion pkg/cli/schemaherocli/fixtures.go
Expand Up @@ -16,7 +16,16 @@ func Fixtures() *cobra.Command {
viper.BindPFlags(cmd.Flags())
},
RunE: func(cmd *cobra.Command, args []string) error {
db := database.NewDatabase()
v := viper.GetViper()

db := database.Database{
InputDir: v.GetString("input-dir"),
OutputDir: v.GetString("output-dir"),
Driver: v.GetString("driver"),
URI: v.GetString("uri"),
VaultURIRef: v.GetString("vault-uri-ref"),
}

return db.CreateFixturesSync()
},
}
Expand Down
9 changes: 8 additions & 1 deletion pkg/cli/schemaherocli/generate.go
Expand Up @@ -16,7 +16,14 @@ func Generate() *cobra.Command {
viper.BindPFlags(cmd.Flags())
},
RunE: func(cmd *cobra.Command, args []string) error {
g := generate.NewGenerator()
v := viper.GetViper()

g := generate.Generator{
Driver: v.GetString("driver"),
URI: v.GetString("uri"),
DBName: v.GetString("dbname"),
OutputDir: v.GetString("output-dir"),
}
return g.RunSync()
},
}
Expand Down
9 changes: 8 additions & 1 deletion pkg/cli/schemaherocli/plan.go
Expand Up @@ -69,7 +69,14 @@ func Plan() *cobra.Command {
defer f.Close()
}

db := database.NewDatabase()
db := database.Database{
InputDir: v.GetString("input-dir"),
OutputDir: v.GetString("output-dir"),
Driver: v.GetString("driver"),
URI: v.GetString("uri"),
VaultURIRef: v.GetString("vault-uri-ref"),
}

if fi.Mode().IsDir() {
err := filepath.Walk(v.GetString("spec-file"), func(path string, info os.FileInfo, err error) error {
if !info.IsDir() {
Expand Down
55 changes: 26 additions & 29 deletions pkg/database/database.go
Expand Up @@ -12,24 +12,21 @@ import (
"github.com/schemahero/schemahero/pkg/database/mysql"
"github.com/schemahero/schemahero/pkg/database/postgres"
"github.com/schemahero/schemahero/pkg/logger"
"github.com/spf13/viper"
"go.uber.org/zap"
"gopkg.in/yaml.v2"
)

type Database struct {
Viper *viper.Viper
}

func NewDatabase() *Database {
return &Database{
Viper: viper.GetViper(),
}
InputDir string
OutputDir string
Driver string
URI string
VaultURIRef string
}

func (d *Database) CreateFixturesSync() error {
logger.Infof("generating fixtures",
zap.String("input-dir", d.Viper.GetString("input-dir")))
zap.String("input-dir", d.InputDir))

statements := []string{}
handleFile := func(path string, info os.FileInfo, err error) error {
Expand All @@ -41,7 +38,7 @@ func (d *Database) CreateFixturesSync() error {
return nil
}

fileData, err := ioutil.ReadFile(filepath.Join(d.Viper.GetString("input-dir"), info.Name()))
fileData, err := ioutil.ReadFile(filepath.Join(d.InputDir, info.Name()))
if err != nil {
return err
}
Expand All @@ -68,7 +65,7 @@ func (d *Database) CreateFixturesSync() error {
return nil
}

if d.Viper.GetString("driver") == "postgres" {
if d.Driver == "postgres" {
if spec.Schema.Postgres == nil {
return nil
}
Expand All @@ -79,7 +76,7 @@ func (d *Database) CreateFixturesSync() error {
}

statements = append(statements, statement)
} else if d.Viper.GetString("driver") == "mysql" {
} else if d.Driver == "mysql" {
if spec.Schema.Mysql == nil {
return nil
}
Expand All @@ -90,7 +87,7 @@ func (d *Database) CreateFixturesSync() error {
}

statements = append(statements, statement)
} else if d.Viper.GetString("driver") == "cockroachdb" {
} else if d.Driver == "cockroachdb" {
if spec.Schema.CockroachDB == nil {
return nil
}
Expand All @@ -106,7 +103,7 @@ func (d *Database) CreateFixturesSync() error {
return nil
}

err := filepath.Walk(d.Viper.GetString("input-dir"), handleFile)
err := filepath.Walk(d.InputDir, handleFile)
if err != nil {
fmt.Printf("%#v\n", err)
return err
Expand All @@ -115,11 +112,11 @@ func (d *Database) CreateFixturesSync() error {
output := strings.Join(statements, ";\n")
output = fmt.Sprintf("/* Auto generated file. Do not edit by hand. This file was generated by SchemaHero. */\n\n %s;\n\n", output)

if _, err := os.Stat(d.Viper.GetString("output-dir")); os.IsNotExist(err) {
os.MkdirAll(d.Viper.GetString("output-dir"), 0755)
if _, err := os.Stat(d.OutputDir); os.IsNotExist(err) {
os.MkdirAll(d.OutputDir, 0755)
}

err = ioutil.WriteFile(filepath.Join(d.Viper.GetString("output-dir"), "fixtures.sql"), []byte(output), 0644)
err = ioutil.WriteFile(filepath.Join(d.OutputDir, "fixtures.sql"), []byte(output), 0644)
if err != nil {
fmt.Printf("%#v\n", err)
return err
Expand Down Expand Up @@ -160,27 +157,27 @@ func (d *Database) PlanSync(filename string) ([]string, error) {
return nil, err
}

if d.Viper.GetString("driver") == "postgres" {
if d.Driver == "postgres" {
return postgres.PlanPostgresTable(uri, spec.Name, spec.Schema.Postgres)
} else if d.Viper.GetString("driver") == "mysql" {
return mysql.PlanMysqlTable(d.Viper.GetString("uri"), spec.Name, spec.Schema.Mysql)
} else if d.Viper.GetString("driver") == "cockroachdb" {
return postgres.PlanPostgresTable(d.Viper.GetString("uri"), spec.Name, spec.Schema.CockroachDB)
} else if d.Driver == "mysql" {
return mysql.PlanMysqlTable(uri, spec.Name, spec.Schema.Mysql)
} else if d.Driver == "cockroachdb" {
return postgres.PlanPostgresTable(uri, spec.Name, spec.Schema.CockroachDB)
}

return nil, errors.New("unknown database driver")
}

func getURI(d *Database) (string, error) {
if uriRef := d.Viper.GetString("vault-uri-ref"); uriRef != "" {
b, err := ioutil.ReadFile(uriRef)
if d.VaultURIRef != "" {
b, err := ioutil.ReadFile(d.VaultURIRef)
if err != nil {
return "", err
return "", errors.Wrap(err, "failed to read vault uri file")
}
return string(b), nil
}

return d.Viper.GetString("uri"), nil
return d.URI, nil
}

func (d *Database) ApplySync(statements []string) error {
Expand All @@ -190,11 +187,11 @@ func (d *Database) ApplySync(statements []string) error {
return err
}

if d.Viper.GetString("driver") == "postgres" {
if d.Driver == "postgres" {
return postgres.DeployPostgresStatements(uri, statements)
} else if d.Viper.GetString("driver") == "mysql" {
} else if d.Driver == "mysql" {
return mysql.DeployMysqlStatements(uri, statements)
} else if d.Viper.GetString("driver") == "cockroachdb" {
} else if d.Driver == "cockroachdb" {
return postgres.DeployPostgresStatements(uri, statements)
}

Expand Down
36 changes: 16 additions & 20 deletions pkg/generate/generate.go
Expand Up @@ -12,32 +12,28 @@ import (
"github.com/schemahero/schemahero/pkg/database/mysql"
"github.com/schemahero/schemahero/pkg/database/postgres"
"github.com/schemahero/schemahero/pkg/database/types"
"github.com/spf13/viper"
"gopkg.in/yaml.v2"
)

type Generator struct {
Viper *viper.Viper
}

func NewGenerator() *Generator {
return &Generator{
Viper: viper.GetViper(),
}
Driver string
URI string
DBName string
OutputDir string
}

func (g *Generator) RunSync() error {
fmt.Printf("connecting to %s\n", g.Viper.GetString("uri"))
fmt.Printf("connecting to %s\n", g.URI)

var db interfaces.SchemaHeroDatabaseConnection
if g.Viper.GetString("driver") == "postgres" {
pgDb, err := postgres.Connect(g.Viper.GetString("uri"))
if g.Driver == "postgres" {
pgDb, err := postgres.Connect(g.URI)
if err != nil {
return errors.Wrap(err, "failed to connect to postgres")
}
db = pgDb
} else if g.Viper.GetString("driver") == "mysql" {
mysqlDb, err := mysql.Connect(g.Viper.GetString("uri"))
} else if g.Driver == "mysql" {
mysqlDb, err := mysql.Connect(g.URI)
if err != nil {
return errors.Wrap(err, "failed to connect to mysql")
}
Expand All @@ -56,12 +52,12 @@ func (g *Generator) RunSync() error {
return errors.Wrap(err, "failed to get table primary key")
}

foreignKeys, err := db.ListTableForeignKeys(g.Viper.GetString("dbname"), tableName)
foreignKeys, err := db.ListTableForeignKeys(g.DBName, tableName)
if err != nil {
return errors.Wrap(err, "failed to list table foreign keys")
}

indexes, err := db.ListTableIndexes(g.Viper.GetString("dbname"), tableName)
indexes, err := db.ListTableIndexes(g.DBName, tableName)
if err != nil {
return errors.Wrap(err, "failed to list table indexes")
}
Expand All @@ -75,14 +71,14 @@ func (g *Generator) RunSync() error {
if primaryKey != nil {
primaryKeyColumns = primaryKey.Columns
}
tableYAML, err := generateTableYAML(g.Viper.GetString("driver"), g.Viper.GetString("dbname"), tableName, primaryKeyColumns, foreignKeys, indexes, columns)
tableYAML, err := generateTableYAML(g.Driver, g.DBName, tableName, primaryKeyColumns, foreignKeys, indexes, columns)
if err != nil {
return errors.Wrap(err, "failed to generate table yaml")
}

// If there was a outputdir set, write it, else print it
if g.Viper.GetString("output-dir") != "" {
if err := ioutil.WriteFile(filepath.Join(g.Viper.GetString("output-dir"), fmt.Sprintf("%s.yaml", sanitizeName(tableName))), []byte(tableYAML), 0644); err != nil {
if g.OutputDir != "" {
if err := ioutil.WriteFile(filepath.Join(g.OutputDir, fmt.Sprintf("%s.yaml", sanitizeName(tableName))), []byte(tableYAML), 0644); err != nil {
return err
}

Expand All @@ -95,7 +91,7 @@ func (g *Generator) RunSync() error {
}

// If there was an output-dir, write a kustomization.yaml too -- this should be optional
if g.Viper.GetString("output-dir") != "" {
if g.OutputDir != "" {
kustomization := struct {
Resources []string `yaml:"resources"`
}{
Expand All @@ -107,7 +103,7 @@ func (g *Generator) RunSync() error {
return err
}

if err := ioutil.WriteFile(filepath.Join(g.Viper.GetString("output-dir"), "kustomization.yaml"), kustomizeDoc, 0644); err != nil {
if err := ioutil.WriteFile(filepath.Join(g.OutputDir, "kustomization.yaml"), kustomizeDoc, 0644); err != nil {
return err
}
}
Expand Down

0 comments on commit cdd2062

Please sign in to comment.