diff --git a/cmd/split.go b/cmd/split.go index b941417..6764151 100644 --- a/cmd/split.go +++ b/cmd/split.go @@ -209,9 +209,9 @@ func (app *AppSplitEstimate) RunE(cmd *cobra.Command, args []string) error { type AppSplitSampling struct { *AppSplit // embedded parent command storage EstimateTableRows int - EstimateTableSize int - RegionSize int - ColumnName string + BaseDbName string + BaseTableName string + BaseIndexName string NewDbName string NewTableName string NewIndexName string @@ -225,67 +225,68 @@ func (app *AppSplit) AppSplitSampling() Cmder { func (app *AppSplitSampling) Cmd() *cobra.Command { cmd := &cobra.Command{ Use: "sampling", - Short: "Split single and joint index region base sampling data", - Long: `Split single and joint index region base sampling data`, + Short: "Generate split region from the distinct value of base table index", + Long: `Generate split region from the distinct value of base table index`, RunE: app.RunE, SilenceUsage: true, } cmd.Flags().IntVar(&app.EstimateTableRows, "new-table-row", 0, "estimate need be split table rows") - cmd.Flags().IntVar(&app.EstimateTableSize, "new-table-size", 0, "estimate need be split table size(M)") - cmd.Flags().IntVar(&app.RegionSize, "region-size", 96, "estimate need be split table region size(M)") - cmd.Flags().StringVar(&app.ColumnName, "col", "", "configure base estimate table column name") + cmd.Flags().StringVar(&app.BaseDbName, "base-db", "", "base estimate table db name") + cmd.Flags().StringVar(&app.BaseTableName, "base-table", "", "base estimate table name") + cmd.Flags().StringVar(&app.BaseIndexName, "base-index", "", "base estimate table index name") cmd.Flags().StringVar(&app.NewDbName, "new-db", "", "configure generate split table new db name through base estimate table column name") cmd.Flags().StringVar(&app.NewTableName, "new-table", "", "configure generate split table new table name through base estimate table column name") cmd.Flags().StringVar(&app.NewIndexName, "new-index", "", "configure generate split table index name through base estimate table column name") cmd.Flags().StringVarP(&app.OutDir, "out-dir", "o", "/tmp/split", "split sql file output dir") - return cmd } -func (app *AppSplitSampling) RunE(cmd *cobra.Command, args []string) error { - if app.DBName == "" { - return fmt.Errorf("flag db name is requirement, can not null") +func (app *AppSplitSampling) validateParameters() error { + msg := "flag `%s` is requirement, can not null" + if app.BaseDbName == "" { + return fmt.Errorf(msg, "base-db") } - engine, err := db.NewMysqlDSN(app.User, app.Password, app.Host, app.Port, app.DBName) + if app.BaseTableName == "" { + return fmt.Errorf(msg, "base-table") + } + if app.BaseIndexName == "" { + return fmt.Errorf(msg, "base-index") + } + if app.NewDbName == "" { + return fmt.Errorf(msg, "new-db") + } + if app.NewTableName == "" { + return fmt.Errorf(msg, "new-table") + } + if app.NewIndexName == "" { + return fmt.Errorf(msg, "new-index") + } + if app.EstimateTableRows == 0 { + return fmt.Errorf(msg, "new-table-row") + } + return nil +} + +func (app *AppSplitSampling) RunE(cmd *cobra.Command, args []string) error { + err := app.validateParameters() if err != nil { return err } - if !engine.IsExistDbName(app.DBName) { + engine, err := db.NewMysqlDSN(app.User, app.Password, app.Host, app.Port, app.BaseDbName) + if err != nil { return err } - //only support single table - switch { - case app.IncludeTable != nil && app.ExcludeTable == nil && app.RegexTable == "": - if len(app.IncludeTable) != 1 { - return fmt.Errorf(" flag include only support configre single table") - } - if app.NewIndexName == "" { - return fmt.Errorf("flag new index name is requirement, can not null") - - } - if err := split.IncludeTableSplitEstimate(engine, - app.DBName, - app.IncludeTable[0], - app.ColumnName, - app.NewDbName, - app.NewTableName, - app.NewIndexName, - app.EstimateTableRows, - app.EstimateTableSize, - app.RegionSize, - app.Concurrency, - app.OutDir); err != nil { - return err - } - default: - if err := cmd.Help(); err != nil { - return err - } - return fmt.Errorf("only support configre flag include, and only single table") - } - return nil + return split.GenerateSplitByBaseTable(engine, + app.BaseDbName, + app.BaseTableName, + app.BaseIndexName, + app.NewDbName, + app.NewTableName, + app.NewIndexName, + app.OutDir, + app.EstimateTableRows) } /* diff --git a/pkg/split/sampling.go b/pkg/split/sampling.go index 238d749..d39735f 100644 --- a/pkg/split/sampling.go +++ b/pkg/split/sampling.go @@ -14,3 +14,268 @@ See the License for the specific language governing permissions and limitations under the License. */ package split + +import ( + "bufio" + "bytes" + "database/sql" + "fmt" + "github.com/WentaoJin/tidba/pkg/db" + "os" + "path" + "strconv" + "strings" + "sync" +) + +func GenerateSplitByBaseTable(engine *db.Engine, baseDB, baseTable, baseIndex, newDB, newTable, newIndex, outDir string, totalWriteRows int) error { + // 1. get distinct value. + s := &splitByBase{ + baseDB: baseDB, + baseTable: baseTable, + baseIndex: baseIndex, + newDB: newDB, + newTable: newTable, + newIndex: newIndex, + outFilePath: path.Join(outDir, "split_by_base.sql"), + } + err := s.init(engine) + if err != nil { + return err + } + var wg sync.WaitGroup + var err1, err2 error + var regionCount int + + wg.Add(2) + go func() { + defer wg.Done() + err1 = s.getBaseDistinctValues(engine) + }() + + go func() { + defer wg.Done() + regionCount, err2 = s.calculateRegionNum(engine, totalWriteRows) + }() + wg.Wait() + + if err1 != nil { + return err1 + } + if err2 != nil { + return err2 + } + + err = s.generateSplit(regionCount) + if err != nil { + return err2 + } + + s.close() + return nil +} + +type splitByBase struct { + baseIndexInfo IndexInfo + file *os.File + fileWriter *bufio.Writer + distinctValues [][]string + + baseDB string + baseTable string + baseIndex string + newDB string + newTable string + newIndex string + outFilePath string +} + +func (s *splitByBase) init(engine *db.Engine) error { + err := s.initOutFile() + if err != nil { + return err + } + err = s.getBaseTableIndex(engine) + if err != nil { + return err + } + + return nil +} + +func (s *splitByBase) generateSplit(regionCount int) error { + if regionCount < 1 { + regionCount = 1 + } + sqlBuf := bytes.NewBuffer(nil) + sqlBuf.WriteString(fmt.Sprintf("split table %s index %s by ", s.tableName(s.newDB, s.newTable), s.newIndex)) + step := len(s.distinctValues) / regionCount + if step < 1 { + step = 1 + } + for i := 0; i < len(s.distinctValues); i += step { + if i > 0 { + sqlBuf.WriteString(",") + } + vs := s.distinctValues[i] + sqlBuf.WriteString("(") + sqlBuf.WriteString(strings.Join(vs, ",")) + sqlBuf.WriteString(")") + } + + _, err := s.fileWriter.WriteString(sqlBuf.String() + ";\n\n") + return err +} + +func (s *splitByBase) initOutFile() error { + outFile, err := os.OpenFile(s.outFilePath, os.O_CREATE|os.O_RDWR|os.O_TRUNC, 0644) + if err != nil { + return err + } + s.fileWriter, s.file = bufio.NewWriter(outFile), outFile + return nil +} + +func (s *splitByBase) close() { + if s.file != nil { + s.fileWriter.Flush() + s.file.Close() + } +} + +func (s *splitByBase) getBaseTableIndex(engine *db.Engine) error { + s.baseIndexInfo = IndexInfo{ + IndexName: s.baseIndex, + ColumnName: nil, + } + condition := fmt.Sprintf("where lower(table_name)=lower('%s') and lower(table_schema)=lower('%s') and lower(KEY_NAME) = lower ('%s')", + s.baseTable, s.baseDB, s.baseIndex) + query := fmt.Sprintf("select COLUMN_NAME from INFORMATION_SCHEMA.TIDB_INDEXES %s", condition) + err := queryRows(engine.DB, query, func(row, cols []string) error { + if len(row) != 1 { + panic("result row is not index column name, should never happen") + } + s.baseIndexInfo.ColumnName = append(s.baseIndexInfo.ColumnName, row[0]) + return nil + }) + if err != nil { + return err + } + if len(s.baseIndexInfo.ColumnName) == 0 { + return fmt.Errorf("unknow index %v in %v", s.baseIndex, s.tableName(s.baseDB, s.baseTable)) + } + return err +} + +func (s *splitByBase) getBaseDistinctValues(engine *db.Engine) error { + idxCols := strings.Join(s.baseIndexInfo.ColumnName, ",") + query := fmt.Sprintf("select distinct %s from %s order by %s", + idxCols, s.tableName(s.baseDB, s.baseTable), idxCols) + rows, err := queryAllRows(engine.DB, query) + if err != nil { + return err + } + s.distinctValues = rows + return nil +} + +func (s *splitByBase) calculateRegionNum(engine *db.Engine, totalWriteRows int) (int, error) { + baseRows, err := s.getBaseTableCount(engine) + if err != nil { + return 0, err + } + baseIndexRegions, err := s.getBaseTableIndexRegionCount(engine) + if err != nil { + return 0, err + } + if baseIndexRegions < 1 { + baseIndexRegions = 1 + } + capacity := baseRows / baseIndexRegions + if capacity < 10000 { + capacity = 10000 + } + count := totalWriteRows / capacity + if count < 1 { + count = 1 + } + return count, nil +} + +func (s *splitByBase) getBaseTableCount(engine *db.Engine) (int, error) { + count := 0 + query := fmt.Sprintf("select count(1) from %v", s.tableName(s.baseDB, s.baseTable)) + err := queryRows(engine.DB, query, func(row, cols []string) error { + if len(row) != 1 { + panic("result row is not row counts, should never happen") + } + v, err := strconv.Atoi(row[0]) + if err != nil { + return err + } + count = v + return nil + }) + return count, err +} + +func (s *splitByBase) getBaseTableIndexRegionCount(engine *db.Engine) (int, error) { + count := 0 + query := fmt.Sprintf("show table %s index %s regions", s.tableName(s.baseDB, s.baseTable), s.baseIndex) + err := queryRows(engine.DB, query, func(row, cols []string) error { + count++ + return nil + }) + return count, err +} + +func (s *splitByBase) tableName(db, table string) string { + return fmt.Sprintf("%s.%s", db, table) +} + +func queryAllRows(Engine *sql.DB, SQL string) ([][]string, error) { + rows, err := Engine.Query(SQL) + if err == nil { + defer rows.Close() + } + + if err != nil { + return nil, err + } + + cols, err1 := rows.Columns() + if err1 != nil { + return nil, err1 + } + // Read all rows. + var actualRows [][]string + for rows.Next() { + + rawResult := make([][]byte, len(cols)) + result := make([]string, len(cols)) + dest := make([]interface{}, len(cols)) + for i := range rawResult { + dest[i] = &rawResult[i] + } + + err1 = rows.Scan(dest...) + if err1 != nil { + return nil, err1 + } + + for i, raw := range rawResult { + if raw == nil { + result[i] = "NULL" + } else { + val := string(raw) + result[i] = "'" + val + "'" + } + } + + actualRows = append(actualRows, result) + } + if err = rows.Err(); err != nil { + return nil, err + } + return actualRows, nil +}