diff --git a/reparo/config.go b/reparo/config.go index 5c618dabf..3de84f3da 100644 --- a/reparo/config.go +++ b/reparo/config.go @@ -14,6 +14,7 @@ package reparo import ( + "encoding/json" "flag" "fmt" "os" @@ -37,7 +38,7 @@ const ( // Config is the main configuration for the retore tool. type Config struct { - *flag.FlagSet + *flag.FlagSet `toml:"-" json:"-"` Dir string `toml:"data-dir" json:"data-dir"` StartDatetime string `toml:"start-datetime" json:"start-datetime"` StopDatetime string `toml:"stop-datetime" json:"stop-datetime"` @@ -56,6 +57,8 @@ type Config struct { LogFile string `toml:"log-file" json:"log-file"` LogLevel string `toml:"log-level" json:"log-level"` + SafeMode bool `toml:"safe-mode" json:"safe-mode"` + configFile string printVersion bool } @@ -79,9 +82,19 @@ func NewConfig() *Config { fs.StringVar(&c.LogLevel, "L", "info", "log level: debug, info, warn, error, fatal") fs.StringVar(&c.configFile, "config", "", "[REQUIRED] path to configuration file") fs.BoolVar(&c.printVersion, "V", false, "print reparo version info") + fs.BoolVar(&c.SafeMode, "safe-mode", false, "enable safe mode to make syncer reentrant") return c } +func (c *Config) String() string { + cfgBytes, err := json.Marshal(c) + if err != nil { + log.Error("marshal config failed", zap.Error(err)) + } + + return string(cfgBytes) +} + // Parse parses keys/values from command line flags and toml configuration file. func (c *Config) Parse(args []string) (err error) { // Parse first to get config file diff --git a/reparo/reparo.go b/reparo/reparo.go index edf89c958..1671d123e 100644 --- a/reparo/reparo.go +++ b/reparo/reparo.go @@ -35,9 +35,9 @@ type Reparo struct { // New creates a Reparo object. func New(cfg *Config) (*Reparo, error) { - log.Info("New Reparo", zap.Reflect("config", cfg)) + log.Info("New Reparo", zap.Stringer("config", cfg)) - syncer, err := syncer.New(cfg.DestType, cfg.DestDB) + syncer, err := syncer.New(cfg.DestType, cfg.DestDB, cfg.SafeMode) if err != nil { return nil, errors.Trace(err) } diff --git a/reparo/syncer/mysql.go b/reparo/syncer/mysql.go index e08bdad95..a7d32c665 100644 --- a/reparo/syncer/mysql.go +++ b/reparo/syncer/mysql.go @@ -51,21 +51,22 @@ var ( // should be only used for unit test to create mock db var createDB = loader.CreateDB -func newMysqlSyncer(cfg *DBConfig) (*mysqlSyncer, error) { +func newMysqlSyncer(cfg *DBConfig, safemode bool) (*mysqlSyncer, error) { db, err := createDB(cfg.User, cfg.Password, cfg.Host, cfg.Port) if err != nil { return nil, errors.Trace(err) } - return newMysqlSyncerFromSQLDB(db) + return newMysqlSyncerFromSQLDB(db, safemode) } -func newMysqlSyncerFromSQLDB(db *sql.DB) (*mysqlSyncer, error) { +func newMysqlSyncerFromSQLDB(db *sql.DB, safemode bool) (*mysqlSyncer, error) { loader, err := loader.NewLoader(db, loader.WorkerCount(defaultWorkerCount), loader.BatchSize(defaultBatchSize)) if err != nil { return nil, errors.Annotate(err, "new loader failed") } + loader.SetSafeMode(safemode) syncer := &mysqlSyncer{db: db, loader: loader} syncer.runLoader() diff --git a/reparo/syncer/mysql_test.go b/reparo/syncer/mysql_test.go index 235ef2ed0..a05eab161 100644 --- a/reparo/syncer/mysql_test.go +++ b/reparo/syncer/mysql_test.go @@ -14,6 +14,11 @@ type testMysqlSuite struct{} var _ = check.Suite(&testMysqlSuite{}) func (s *testMysqlSuite) TestMysqlSyncer(c *check.C) { + s.testMysqlSyncer(c, true) + s.testMysqlSyncer(c, false) +} + +func (s *testMysqlSuite) testMysqlSyncer(c *check.C, safemode bool) { var ( mock sqlmock.Sqlmock ) @@ -32,14 +37,14 @@ func (s *testMysqlSuite) TestMysqlSyncer(c *check.C) { createDB = oldCreateDB }() - syncer, err := newMysqlSyncer(&DBConfig{}) + syncer, err := newMysqlSyncer(&DBConfig{}, safemode) c.Assert(err, check.IsNil) mock.ExpectBegin() mock.ExpectExec("create database test").WillReturnResult(sqlmock.NewResult(0, 0)) mock.ExpectCommit() - mock.ExpectQuery("SELECT column_name, extra FROM information_schema.columns").WithArgs("test", "t1").WillReturnRows(sqlmock.NewRows([]string{"column_name", "extra"}).AddRow("a", "").AddRow("b", "")) + mock.ExpectQuery("SELECT column_name, extra FROM information_schema.columns").WithArgs("test", "t1").WillReturnRows(sqlmock.NewRows([]string{"column_name", "extra"}).AddRow("a", "").AddRow("b", "").AddRow("c", "")) rows := sqlmock.NewRows([]string{"non_unique", "index_name", "seq_in_index", "column_name"}) mock.ExpectQuery("SELECT non_unique, index_name, seq_in_index, column_name FROM information_schema.statistics"). @@ -47,9 +52,18 @@ func (s *testMysqlSuite) TestMysqlSyncer(c *check.C) { WillReturnRows(rows) mock.ExpectBegin() - mock.ExpectExec("INSERT INTO").WithArgs(1, "test").WillReturnResult(sqlmock.NewResult(0, 1)) + insertPattern := "INSERT INTO" + if safemode { + insertPattern = "REPLACE INTO" + } + mock.ExpectExec(insertPattern).WithArgs(1, "test", nil).WillReturnResult(sqlmock.NewResult(0, 1)) mock.ExpectExec("DELETE FROM").WithArgs(1, "test").WillReturnResult(sqlmock.NewResult(0, 1)) - mock.ExpectExec("UPDATE").WithArgs("abc").WillReturnResult(sqlmock.NewResult(0, 1)) + if safemode { + mock.ExpectExec("DELETE FROM").WithArgs().WillReturnResult(sqlmock.NewResult(0, 1)) + mock.ExpectExec(insertPattern).WithArgs(nil, nil, "abc").WillReturnResult(sqlmock.NewResult(0, 1)) + } else { + mock.ExpectExec("UPDATE").WithArgs("abc", "test").WillReturnResult(sqlmock.NewResult(0, 1)) + } mock.ExpectCommit() syncTest(c, Syncer(syncer)) diff --git a/reparo/syncer/syncer.go b/reparo/syncer/syncer.go index 6bcc89d89..13edb2d47 100644 --- a/reparo/syncer/syncer.go +++ b/reparo/syncer/syncer.go @@ -29,10 +29,10 @@ type Syncer interface { } // New creates a new executor based on the name. -func New(name string, cfg *DBConfig) (Syncer, error) { +func New(name string, cfg *DBConfig, safemode bool) (Syncer, error) { switch name { case "mysql": - return newMysqlSyncer(cfg) + return newMysqlSyncer(cfg, safemode) case "print": return newPrintSyncer() case "memory": diff --git a/reparo/syncer/syncer_test.go b/reparo/syncer/syncer_test.go index f6610b79f..7b2a4ce5a 100644 --- a/reparo/syncer/syncer_test.go +++ b/reparo/syncer/syncer_test.go @@ -34,7 +34,7 @@ func (s *testSyncerSuite) TestNewSyncer(c *check.C) { } for _, testCase := range testCases { - syncer, err := New(testCase.typeStr, cfg) + syncer, err := New(testCase.typeStr, cfg, false) c.Assert(err, check.IsNil) c.Assert(reflect.TypeOf(syncer), testCase.checker, testCase.tp) }