Skip to content
This repository has been archived by the owner on Mar 24, 2022. It is now read-only.

Commit

Permalink
Replace MigrationResult with watcher
Browse files Browse the repository at this point in the history
  • Loading branch information
Kris Hicks committed May 16, 2017
1 parent ca162ef commit 18d3d7c
Show file tree
Hide file tree
Showing 5 changed files with 561 additions and 47 deletions.
3 changes: 2 additions & 1 deletion commands/migrate.go
Expand Up @@ -39,7 +39,8 @@ func (c *MigrateCommand) Execute([]string) error {
}
defer pg.Close()

_, err = pg2mysql.NewMigrator(pg, mysql, c.Truncate).Migrate()
watcher := pg2mysql.NewStdoutPrinter()
err = pg2mysql.NewMigrator(pg, mysql, c.Truncate, watcher).Migrate()
if err != nil {
return fmt.Errorf("failed migrating: %s", err)
}
Expand Down
53 changes: 25 additions & 28 deletions migrator.go
Expand Up @@ -9,48 +9,55 @@ import (
)

type Migrator interface {
Migrate() ([]MigrationResult, error)
Migrate() error
}

func NewMigrator(src, dst DB, truncateFirst bool) Migrator {
func NewMigrator(src, dst DB, truncateFirst bool, watcher MigratorWatcher) Migrator {
return &migrator{
src: src,
dst: dst,
truncateFirst: truncateFirst,
watcher: watcher,
}
}

type migrator struct {
src, dst DB
truncateFirst bool
watcher MigratorWatcher
}

func (m *migrator) Migrate() ([]MigrationResult, error) {
func (m *migrator) Migrate() error {
srcSchema, err := BuildSchema(m.src)
if err != nil {
return nil, fmt.Errorf("failed to build source schema: %s", err)
return fmt.Errorf("failed to build source schema: %s", err)
}

m.watcher.WillDisableConstraints()
err = m.dst.DisableConstraints()
if err != nil {
return nil, fmt.Errorf("failed to disable constraints: %s", err)
return fmt.Errorf("failed to disable constraints: %s", err)
}
m.watcher.DidDisableConstraints()

defer func() {
m.watcher.WillEnableConstraints()
err = m.dst.EnableConstraints()
if err != nil {
fmt.Fprintf(os.Stderr, "failed to enable constraints: %s", err)
m.watcher.EnableConstraintsDidFailWithError(err)
} else {
m.watcher.EnableConstraintsDidFinish()
}
}()

var result []MigrationResult

for _, table := range srcSchema.Tables {
if m.truncateFirst {
m.watcher.WillTruncateTable(table.Name)
_, err := m.dst.DB().Exec(fmt.Sprintf("TRUNCATE TABLE %s", table.Name))
if err != nil {
return nil, fmt.Errorf("failed truncating: %s", err)
return fmt.Errorf("failed truncating: %s", err)
}
m.watcher.TruncateTableDidFinish(table.Name)
}

columnNamesForInsert := make([]string, len(table.Columns))
Expand All @@ -67,15 +74,17 @@ func (m *migrator) Migrate() ([]MigrationResult, error) {
strings.Join(placeholders, ","),
))
if err != nil {
return nil, fmt.Errorf("failed creating prepared statement: %s", err)
return fmt.Errorf("failed creating prepared statement: %s", err)
}

var recordsInserted int64

m.watcher.TableMigrationDidStart(table.Name)

if table.HasColumn("id") {
err = migrateWithIDs(m.src, m.dst, table, &recordsInserted, preparedStmt)
err = migrateWithIDs(m.watcher, m.src, m.dst, table, &recordsInserted, preparedStmt)
if err != nil {
return nil, fmt.Errorf("failed migrating table with ids: %s", err)
return fmt.Errorf("failed migrating table with ids: %s", err)
}
} else {
err = EachMissingRow(m.src, m.dst, table, func(scanArgs []interface{}) {
Expand All @@ -87,24 +96,18 @@ func (m *migrator) Migrate() ([]MigrationResult, error) {
recordsInserted++
})
if err != nil {
return nil, fmt.Errorf("failed migrating table without ids: %s", err)
return fmt.Errorf("failed migrating table without ids: %s", err)
}
}

if recordsInserted > 0 {
result = append(result, MigrationResult{
TableName: table.Name,
RowsMigrated: recordsInserted,
})
}

fmt.Printf("inserted %d records into %s\n", recordsInserted, table.Name)
m.watcher.TableMigrationDidFinish(table.Name, recordsInserted)
}

return result, nil
return nil
}

func migrateWithIDs(
watcher MigratorWatcher,
src DB,
dst DB,
table *Table,
Expand Down Expand Up @@ -183,12 +186,6 @@ func migrateWithIDs(
return nil
}

type MigrationResult struct {
TableName string
RowsMigrated int64
RowsSkipped int64
}

func insert(stmt *sql.Stmt, values []interface{}) error {
result, err := stmt.Exec(values...)
if err != nil {
Expand Down
86 changes: 68 additions & 18 deletions migrator_test.go
Expand Up @@ -7,6 +7,7 @@ import (
. "github.com/onsi/ginkgo"
. "github.com/onsi/gomega"
"github.com/pivotal-cf/pg2mysql"
"github.com/pivotal-cf/pg2mysql/pg2mysqlfakes"
)

var _ = Describe("Migrator", func() {
Expand All @@ -15,6 +16,7 @@ var _ = Describe("Migrator", func() {
mysql pg2mysql.DB
pg pg2mysql.DB
truncateFirst bool
watcher *pg2mysqlfakes.FakeMigratorWatcher
)

BeforeEach(func() {
Expand All @@ -40,7 +42,8 @@ var _ = Describe("Migrator", func() {
err = pg.Open()
Expect(err).NotTo(HaveOccurred())

migrator = pg2mysql.NewMigrator(pg, mysql, truncateFirst)
watcher = &pg2mysqlfakes.FakeMigratorWatcher{}
migrator = pg2mysql.NewMigrator(pg, mysql, truncateFirst, watcher)
})

AfterEach(func() {
Expand All @@ -51,10 +54,35 @@ var _ = Describe("Migrator", func() {
})

Describe("Migrate", func() {
It("returns an empty result", func() {
result, err := migrator.Migrate()
It("notifies the watcher", func() {
err := migrator.Migrate()
Expect(err).NotTo(HaveOccurred())
Expect(result).To(BeNil())
Expect(watcher.TableMigrationDidStartCallCount()).To(Equal(2))
Expect(watcher.TableMigrationDidFinishCallCount()).To(Equal(2))

expected := map[string]int64{
"table_with_id": 0,
"table_without_id": 0,
}

tableName, missingRows := watcher.TableMigrationDidFinishArgsForCall(0)
Expect(missingRows).To(Equal(expected[tableName]))
tableName, missingRows = watcher.TableMigrationDidFinishArgsForCall(1)
Expect(missingRows).To(Equal(expected[tableName]))
})

It("does not insert any data into the target", func() {
err := migrator.Migrate()
Expect(err).NotTo(HaveOccurred())

var count int64
err = mysqlRunner.DB().QueryRow("SELECT COUNT(1) from table_with_id").Scan(&count)
Expect(err).NotTo(HaveOccurred())
Expect(count).To(BeZero())

err = mysqlRunner.DB().QueryRow("SELECT COUNT(1) from table_with_id").Scan(&count)
Expect(err).NotTo(HaveOccurred())
Expect(count).To(BeZero())
})

Context("when there is compatible data in postgres in a table with an 'id' column", func() {
Expand Down Expand Up @@ -84,15 +112,26 @@ var _ = Describe("Migrator", func() {
Expect(rowsAffected).To(BeNumerically("==", 1))
})

It("inserts the data into the target", func() {
result, err := migrator.Migrate()
It("notifies the watcher", func() {
err := migrator.Migrate()
Expect(err).NotTo(HaveOccurred())
Expect(watcher.TableMigrationDidStartCallCount()).To(Equal(2))
Expect(watcher.TableMigrationDidFinishCallCount()).To(Equal(2))

expected := map[string]int64{
"table_with_id": 1,
"table_without_id": 0,
}

tableName, missingRows := watcher.TableMigrationDidFinishArgsForCall(0)
Expect(missingRows).To(Equal(expected[tableName]))
tableName, missingRows = watcher.TableMigrationDidFinishArgsForCall(1)
Expect(missingRows).To(Equal(expected[tableName]))
})

Expect(result).To(ContainElement(pg2mysql.MigrationResult{
TableName: "table_with_id",
RowsMigrated: 1,
RowsSkipped: 0,
}))
It("inserts the data into the target", func() {
err := migrator.Migrate()
Expect(err).NotTo(HaveOccurred())

var id int
var name string
Expand Down Expand Up @@ -137,15 +176,26 @@ var _ = Describe("Migrator", func() {
Expect(rowsAffected).To(BeNumerically("==", 1))
})

It("inserts the data into the target", func() {
result, err := migrator.Migrate()
It("notifies the watcher", func() {
err := migrator.Migrate()
Expect(err).NotTo(HaveOccurred())
Expect(watcher.TableMigrationDidStartCallCount()).To(Equal(2))
Expect(watcher.TableMigrationDidFinishCallCount()).To(Equal(2))

expected := map[string]int64{
"table_with_id": 0,
"table_without_id": 1,
}

tableName, missingRows := watcher.TableMigrationDidFinishArgsForCall(0)
Expect(missingRows).To(Equal(expected[tableName]))
tableName, missingRows = watcher.TableMigrationDidFinishArgsForCall(1)
Expect(missingRows).To(Equal(expected[tableName]))
})

Expect(result).To(ContainElement(pg2mysql.MigrationResult{
TableName: "table_without_id",
RowsMigrated: 1,
RowsSkipped: 0,
}))
It("inserts the data into the target", func() {
err := migrator.Migrate()
Expect(err).NotTo(HaveOccurred())

var name string
var ci_name string
Expand Down

0 comments on commit 18d3d7c

Please sign in to comment.